代码拉取完成,页面将自动刷新
from sklearn.svm import SVC
from sklearn.datasets import load_digits
from sklearn import metrics
from matplotlib import pyplot as plt
import random
digits = load_digits()
# 将数据分开保存
images = digits.images
labels = digits.target
print(images)
print(images.shape)
# 此时 images 是三维的(1797 * 8 * 8),即: 1797个8 * 8的矩阵,
print(images.ndim)
print(labels)
# 来个图片看一下结构
fig = plt.figure(figsize=(10, 10))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
# 显示前100个看看
for i in range(100):
sub_image = fig.add_subplot(10, 10, i + 1, xticks=[], yticks=[])
sub_image.imshow(digits.images[i], cmap=plt.cm.binary, interpolation='nearest')
sub_image.text(0, 9, str(digits.target[i]))
plt.show()
# 将8*8的图片转换为一维向量
cnt = len(images)
images_vector = images.reshape((cnt, -1))
print(images_vector)
print(images_vector.shape)
# 随机选择训练集和测试集
sample = list(range(cnt))
test_size = int(cnt * 0.3)
random.shuffle(sample)
train, test = sample[test_size:], sample[:test_size]
X_train, Y_train = images_vector[train], labels[train]
X_test, Y_test = images_vector[test], labels[test]
# 使用rbf核函数
classifier = SVC(kernel='rbf', C=1.0, gamma=0.001)
classifier.fit(X_train, Y_train)
print(classifier)
prediction = classifier.predict(X_test)
print(prediction)
print(metrics.classification_report(Y_test, prediction))
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。