准确工作:需要下载music.csv
已上传
构建模型内容:
import pandas as pd
music_data = pd.read_csv('music.csv')
# music_data
X=music_data.drop(columns=['genre'])
# 删除的那一列的名字为genre
Y=music_data['genre']
# 访问指定的列
Y
预测用的是决策树,找数据关系
fit这里训练的模型:
X表示如果输入的是这样,希望的输出是Y
在本篇代码中,x是年龄和性别,y是所对应的爱好
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
music_data = pd.read_csv('music.csv')
# music_data
X=music_data.drop(columns=['genre'])
# 删除的那一列的名字为genre
Y=music_data['genre']
# 访问指定的列
# Y
model=DecisionTreeClassifier()
model.fit(X,Y)
# .fit(input,output)
# 预测21岁男性喜欢什么,20岁女性喜欢什么
predictions=model.predict([[21,1],[20,0]])
predictions
进一步优化
用20%数据进行测试,80%数据用于建立模型
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
music_data = pd.read_csv('music.csv')
X=music_data.drop(columns=['genre'])
# 删除的那一列的名字为genre
Y=music_data['genre']
X_train,X_test,Y_train,Y_test=train_test_split(X,Y,test_size=0.2)
model=DecisionTreeClassifier()
model.fit(X_train,Y_train)
predictions=model.predict(X_test)
# 返回预测的精确度0~1
score=accuracy_score(Y_test,predictions)
score
模型持久化:创建-保存-加载
保存和加载模块
import joblib
创建和保存
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import joblib
# 保存和加载模块
music_data = pd.read_csv('music.csv')
X=music_data.drop(columns=['genre'])
# 删除的那一列的名字为genre
Y=music_data['genre']
# 训练模型
model=DecisionTreeClassifier()
model.fit(X,Y)
# 保存
joblib.dump(model,'music-recommender.joblib')
加载模型
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import joblib
# 加载
model=joblib.load('music-recommender.joblib')
model
predictions=model.predict([[21,1]])
predictions
可视化
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
music_data = pd.read_csv('music.csv')
X=music_data.drop(columns=['genre'])
Y=music_data['genre']
model=DecisionTreeClassifier()
model.fit(X,Y)
# Y.unique() 去重
tree.export_graphviz(model,out_file='music-rec.dot',
feature_names=['age','gender'],
class_names=sorted(Y.unique()),
# 每个节点都有可读的标签
label='all',
# 设置形状:节点框为圆角
rounded=True,
# 每个节点都有颜色
filled=True)
用vs打开保存的.dot即可