文章目录
- 1、前言
- 2、鸢尾花分类
- 3、结果打印
1、前言
【tensorflow框架神经网络实现鸢尾花分类】一文中使用自定义的方式,实现了鸢尾花数据集的分类工作。在这里使用tensorflow中的keras模块快速、极简实现鸢尾花分类任务。
2、鸢尾花分类
import tensorflow as tf
from sklearn import datasets
import numpy as np
# 加载数据集
np.random.seed(0)
iris = datasets.load_iris()
x_train, y_train = iris.data, iris.target
np.random.seed(0)
np.random.shuffle(x_train)
np.random.seed(0)
np.random.shuffle(y_train)
# 设置随机种子
tf.random.set_seed(0)
# 构建模型
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2(0.01))
])
# 编译模型
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
# 训练模型
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
# 打印模型摘要
model.summary()