博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
机器学习之路: python 决策树分类DecisionTreeClassifier 预测泰坦尼克号乘客是否幸存...
阅读量:4594 次
发布时间:2019-06-09

本文共 2152 字,大约阅读时间需要 7 分钟。

 

 

使用python3 学习了决策树分类器的api

涉及到 特征的提取,数据类型保留,分类类型抽取出来新的类型

需要网上下载数据集,我把他们下载到了本地,

可以到我的git下载代码和数据集: https://github.com/linyi0604/MachineLearning

 

1 import pandas as pd 2 from sklearn.cross_validation import train_test_split 3 from sklearn.feature_extraction import DictVectorizer 4 from sklearn.tree import DecisionTreeClassifier 5 from sklearn.metrics import classification_report 6  7 ''' 8 决策树 9 涉及多个特征,没有明显的线性关系10 推断逻辑非常直观11 不需要对数据进行标准化12 '''13 14 '''15 1 准备数据16 '''17 # 读取泰坦尼克乘客数据,已经从互联网下载到本地18 titanic = pd.read_csv("./data/titanic/titanic.txt")19 # 观察数据发现有缺失现象20 # print(titanic.head())21 22 # 提取关键特征,sex, age, pclass都很有可能影响是否幸免23 x = titanic[['pclass', 'age', 'sex']]24 y = titanic['survived']25 # 查看当前选择的特征26 # print(x.info())27 '''28 
29 RangeIndex: 1313 entries, 0 to 131230 Data columns (total 3 columns):31 pclass 1313 non-null object32 age 633 non-null float6433 sex 1313 non-null object34 dtypes: float64(1), object(2)35 memory usage: 30.9+ KB36 None37 '''38 # age数据列 只有633个,对于空缺的 采用平均数或者中位数进行补充 希望对模型影响小39 x['age'].fillna(x['age'].mean(), inplace=True)40 41 '''42 2 数据分割43 '''44 x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.25, random_state=33)45 # 使用特征转换器进行特征抽取46 vec = DictVectorizer()47 # 类别型的数据会抽离出来 数据型的会保持不变48 x_train = vec.fit_transform(x_train.to_dict(orient="record"))49 # print(vec.feature_names_) # ['age', 'pclass=1st', 'pclass=2nd', 'pclass=3rd', 'sex=female', 'sex=male']50 x_test = vec.transform(x_test.to_dict(orient="record"))51 52 '''53 3 训练模型 进行预测54 '''55 # 初始化决策树分类器56 dtc = DecisionTreeClassifier()57 # 训练58 dtc.fit(x_train, y_train)59 # 预测 保存结果60 y_predict = dtc.predict(x_test)61 62 '''63 4 模型评估64 '''65 print("准确度:", dtc.score(x_test, y_test))66 print("其他指标:\n", classification_report(y_predict, y_test, target_names=['died', 'survived']))67 '''68 准确度: 0.781155015197568469 其他指标:70 precision recall f1-score support71 72 died 0.91 0.78 0.84 23673 survived 0.58 0.80 0.67 9374 75 avg / total 0.81 0.78 0.79 32976 '''

 

转载于:https://www.cnblogs.com/Lin-Yi/p/8970609.html

你可能感兴趣的文章
13、对象与类
查看>>
Sublime Text3 个人使用心得
查看>>
jquery 编程的最佳实践
查看>>
MeetMe
查看>>
IP报文格式及各字段意义
查看>>
(转载)rabbitmq与springboot的安装与集成
查看>>
C2. Power Transmission (Hard Edition)(线段相交)
查看>>
STM32F0使用LL库实现SHT70通讯
查看>>
Atitit. Xss 漏洞的原理and应用xss木马
查看>>
MySQL源码 数据结构array
查看>>
(文件过多时)删除目录下全部文件
查看>>
T-SQL函数总结
查看>>
python 序列:列表
查看>>
web移动端
查看>>
pythonchallenge闯关 第13题
查看>>
linux上很方便的上传下载文件工具rz和sz使用介绍
查看>>
React之特点及常见用法
查看>>
【WEB前端经验之谈】时间一年半,或沉淀、或从零开始。
查看>>
优云软件助阵GOPS·2017全球运维大会北京站
查看>>
linux 装mysql的方法和步骤
查看>>