代码拉取完成,页面将自动刷新
# 该实现为新闻分类模型的优化,并利用训练好的模型对新数据进行预测
## 总结:
# 1. 如果要对N个类别的数据点进行分类,那么模型的最后一层应该是大小为N的Dense层
# 2. 对于单标签、多分类问题,模型的最后一层应该使用softmax激活函数,这样可以输出一个在N个输出类别上的概率分布
# 3. 对于这种问题,损失函数几乎总是应该使用分类交叉熵。它将模型输出的概率分布与目标的真实分布之间的距离最小化。
# 4. 处理多分类问题的标签有两种方法:
# a. 通过分类编码(也叫one-hot编码)对标签进行编码,然后使用categorical_crossentropy损失函数;
# b. 将标签编码为整数,然后使用sparse_categorical_crossentropy损失函数。
# 5. 如果你需要将数据划分到多个类别中,那么应避免使用太小的中间层,以免在模型中造成信息瓶颈
import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
from keras.src.datasets import reuters
import numpy as np
from keras.src import Sequential
from keras.src.layers import Dense
# 1. 导入Reuters新闻数据集
(X_train, y_train), (X_test, y_test) = reuters.load_data(num_words=10000)
# 2. 数据预处理
def vectorize_sequences(sequences, dimension=10000):
result = np.zeros((len(sequences), dimension))
for i, sequence in enumerate(sequences):
for j in sequence:
result[i, j] = 1
return result
# def to_one_hot(labes, dimension=46):
# result = np.zeros((len(labes), dimension))
# for i, label in enumerate(labes):
# result[i, label] = 1
# return result
X_train = vectorize_sequences(X_train)
X_test = vectorize_sequences(X_test)
# 处理多分类问题的标签有两种方法:
# 1. 通过分类编码(也叫one-hot编码)对标签进行编码,然后使用categorical_crossentropy损失函数
# 2. 将标签编码为整数,然后使用sparse_categorical_crossentropy(稀疏分类交叉熵)损失函数
# y_train = to_one_hot(y_train)
# y_test = to_one_hot(y_test)
y_train = np.array(y_train)
y_test = np.array(y_test)
# 3. 模型处理
# 第一步:导入模型
model = Sequential()
# 第二步:添加网络层
model.add(Dense(64, activation='relu'))
# 如果需要将数据划分到多个类别中,那么应避免使用太小的中间层,以免在模型中造成信息瓶颈。本例有46个分类
model.add(Dense(64, activation='relu'))
model.add(Dense(46, activation='softmax'))
# 如果要对N个类别的数据点进行分类,那么模型的最后一层应该是大小为N的Dense层
# 对于单标签、多分类问题,模型的最后一层应该使用softmax激活函数,这样可以输出一个在N个输出类别上的概率分布
# 第三步:编译模型
# 对于这种单标签、多分类问题,损失函数几乎总是应该使用分类交叉熵(categorical_crossentropy)。
# 它将模型输出的概率分布与目标的真实分布之间的距离最小化
model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 第四步:训练模型
history = model.fit(X_train, y_train, epochs=9, batch_size=512)
# 第五步:评估模型
result = model.evaluate(X_test, y_test)
print("evaluate result: ", result)
# 第六步:预测模型
y_pred = model.predict(X_test)
print("y_pred: ", y_pred)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。