# lego minifigures pytorch lightning tutorial **Repository Path**: yichaoyyds/lego-minifigures-pytorch-lightning-tutorial ## Basic Information - **Project Name**: lego minifigures pytorch lightning tutorial - **Description**: No description available - **Primary Language**: Python - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 1 - **Forks**: 0 - **Created**: 2022-05-08 - **Last Updated**: 2023-11-23 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # MobileNetV2 pyTorch Lightning LEGO Minifigures 图像分类案例 此案例中,我们将通过 pyTorch Lightning 对 MobileNetV2 预训练模型进行迁移学习,对象是 LEGO Minifigures 数据集。 运行环境: - 平台:Win10。 - IDE:Visual Studio Code - **建议**预装:Anaconda3 - [代码](https://gitee.com/yichaoyyds/lego-minifigures-pytorch-lightning-tutorial) 目录: - [MobileNetV2 pyTorch Lightning LEGO Minifigures 图像分类案例](#mobilenetv2-pytorch-lightning-lego-minifigures-图像分类案例) - [1 关于 LEGO Minifigures 数据集](#1-关于-lego-minifigures-数据集) - [2 代码运行](#2-代码运行) - [3 数据读取与处理](#3-数据读取与处理) - [3.1 pl.LightningDataModule](#31-pllightningdatamodule) - [3.2 数据预处理](#32-数据预处理) - [3.3 数据增强(仅对于训练集)](#33-数据增强仅对于训练集) - [4 模型训练/验证/测试/优化类](#4-模型训练验证测试优化类) - [4.1 pl.LightningModule](#41-pllightningmodule) - [4.2 模型导入](#42-模型导入) - [4.3 模型训练](#43-模型训练) - [4.4 模型验证](#44-模型验证) - [4.5 模型测试](#45-模型测试) - [4.5 模型优化](#45-模型优化) - [5 迁移学习](#5-迁移学习) - [5.1 实例化数据读取处理模块以及训练/验证/测试/优化类](#51-实例化数据读取处理模块以及训练验证测试优化类) - [5.2 迁移学习的训练,验证,以及测试](#52-迁移学习的训练验证以及测试) - [6 结果](#6-结果) - [6.1 训练/验证/测试 Loss 以及 Accuracy 值](#61-训练验证测试-loss-以及-accuracy-值) - [6.2 通过 ModelCheckpoint 文件导入最优模型并进行模型推理](#62-通过-modelcheckpoint-文件导入最优模型并进行模型推理) - [6.3 找出误识别的图片](#63-找出误识别的图片) ## 1 关于 LEGO Minifigures 数据集 ![lego](./img/1.png) 该数据集包含各种乐高人仔的图片。 数据集中的每个人仔都有几张不同姿势和不同环境的图像。 目前,它包含来自乐高套装的 28 个人物(总共 300 多张图像):Yoda's Hut, Spider Mech vs. Venom, General Grievous' Combat Speeder, Kylo Ren's Shuttle™ Microfighter, AT-ST™ Raider from The Mandalorian, Molten Man Battle, Aragog's Lair, Black Widow's Helicopter Chase, Captain America: Outriders Attack, Pteranodon Chase, Iron Man Hall of Armor。 该数据集的[GITHUB链接](https://github.com/yisaienkov/tinysets#pypi)。Kaggle中也有其数据集可以直接下载,[Kaggle链接](https://www.kaggle.com/datasets/ihelon/lego-minifigures-classification)。 这里我们已经下载了这个数据集,在`archive`文件夹下。需要注意: - `index.csv`:包含了训练集/验证集(一共361个数据)图片的位置,以及对应的标签(所以这个csv文件有2列); - `test.csv`:包含了测试集(一共76个数据)图片位置,以及对应的标签; - `metadata.csv`:包含了这个数据集的37个标签。 ## 2 代码运行 我们建议在运行代码前,我们先在本地建一个虚拟环境。如果我们在本地安装了Anaconda,那么可以使用`conda create -n your_env_name python=X.X(2.7、3.6等)`命令创建python版本为X.X、名字为your_env_name的虚拟环境。 这里我们输入`conda create -n mlFlowEx python=3.8.2`。 安装完默认的依赖后,我们进入虚拟环境:`conda activate mlFlowEx`。注意,如果需要退出,则输入`conda deactivate`。另外,如果Terminal没有成功切换到虚拟环境,可以尝试`conda init powershell`,然后重启terminal。 然后,我们在虚拟环境中下载好相关依赖:`pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple`。 如果我们需要在jupyter notebook中运行代码,这里需要添加名为`mlFlowEx_env`的kernal: ```command python -m ipykernel install --user --name mlFlowEx_env ``` 这里有两种运行代码的方式: - 我们可以直接在主路径下运行:`python LegoCharacterRecognition.py`; - 我们也可以在`lego-minifigures-pytorch-lightning-tutorial.ipynb`逐条运行代码; ## 3 数据读取与处理 首先,我们需要读取数据,拆分成训练集、验证集、测试集,并且进行一些数据预处理,必须数据增强等。 ### 3.1 pl.LightningDataModule `pl.LightningDataModule` 标准化了训练、验证、测试集的拆分、数据准备和转换。主要优点是一致的数据拆分、数据准备和跨模型转换,一个例子如下: ```python class MyDataModule(LightningDataModule): def __init__(self): super().__init__() def prepare_data(self): # download, split, etc... # only called on 1 GPU/TPU in distributed def setup(self, stage): # make assignments here (val/train/test split) # called on every process in DDP def train_dataloader(self): train_split = Dataset(...) return DataLoader(train_split) def val_dataloader(self): val_split = Dataset(...) return DataLoader(val_split) def test_dataloader(self): test_split = Dataset(...) return DataLoader(test_split) def teardown(self): # clean up after fit or test # called on every process in DDP ``` 此项目代码: ```python class LEGOMinifiguresDataModule(pl.LightningDataModule): def __init__( self, train_batch_size, valid_batch_size, test_batch_size, image_size, base_dir, train_augmentations=None ): """ Initialization of inherited lightning data module """ super().__init__() self.train_batch_size = train_batch_size self.valid_batch_size = valid_batch_size self.test_batch_size = test_batch_size self.image_size = image_size self.base_dir = base_dir self.train_augmentations=train_augmentations def setup(self, stage): """ Load the data, parse it and split the data into train, test, validation data """ # Load train dataset, test dataset self.df_train = pd.read_csv(os.path.join(self.base_dir, 'index.csv')) self.df_test = pd.read_csv(os.path.join(self.base_dir, 'test.csv')) # Split train dataset into train dataset and validation dataset X, y = df_train.path, df_train.class_id train_paths, valid_paths, y_train, y_valid = train_test_split(X, y, random_state=0) # Store train/validation/test dataset path and label self.train_targets = y_train - 1 self.train_paths = list(map(lambda x: os.path.join(self.base_dir, x), train_paths)) self.valid_targets = y_valid - 1 self.valid_paths = list(map(lambda x: os.path.join(self.base_dir, x), valid_paths)) tmp_test = self.df_test test_paths = tmp_test.path self.test_targets = tmp_test.class_id - 1 self.test_paths = list(map(lambda x: os.path.join(self.base_dir, x), test_paths)) def train_dataloader(self): """ :return: output - Train data loader for the given input """ train_data_retriever = DataRetriever( self.train_paths, self.train_targets, image_size=self.image_size, transforms=self.train_augmentations ) train_loader = torch_data.DataLoader( train_data_retriever, batch_size=self.train_batch_size, shuffle=True, ) return train_loader def val_dataloader(self): """ :return: output - Validation data loader for the given input """ valid_data_retriever = DataRetriever( self.valid_paths, self.valid_targets, image_size=self.image_size, ) valid_loader = torch_data.DataLoader( valid_data_retriever, batch_size=self.valid_batch_size, shuffle=True, ) return valid_loader def test_dataloader(self): """ :return: output - Test data loader for the given input """ test_data_retriever = DataRetriever( self.test_paths, self.test_targets, image_size=self.image_size, ) test_loader = torch_data.DataLoader( test_data_retriever, batch_size=self.test_batch_size, shuffle=False, ) return test_loader ``` `LEGOMinifiguresDataModule`类继承于`pl.LightningDataModule`,大体上可以说由两部分组成。第一部分是读取数据集,并且拆分成训练数据集/验证数据集/测试数据集(`setup`函数)。第二部分是对于这三个数据集的dataloader(`train_dataloader`,`val_dataloader`,`test_dataloader`函数)。 ### 3.2 数据预处理 对于`train_dataloader`,`val_dataloader`,`test_dataloader`函数,我们需要先对各自输入的数据集进行一些预处理,见`DataRetriever`类: ```python class DataRetriever(torch_data.Dataset): def __init__( self, paths, targets, image_size=(224, 224), transforms=None ): self.paths = list(paths) self.targets = list(targets) self.image_size = image_size self.transforms = transforms self.preprocess = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), ]) def __len__(self): return len(self.targets) def __getitem__(self, index): img = cv2.imread(self.paths[index]) # 我们需要调整图像大小到与MoblieNetV2的输入尺寸匹配 img = cv2.resize(img, self.image_size) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.transforms: img = self.transforms(image=img)['image'] img = self.preprocess(img) y = torch.tensor(self.targets[index], dtype=torch.long) return {'X': img, 'y': y} ``` 我们先读取图像(`cv2.imread(self.paths[index])`),然后调整图像大小到与MoblieNetV2的输入尺寸匹配(`cv2.resize(img, self.image_size)`),将图像格式从BGR转换到RGB(`cv2.cvtColor(img, cv2.COLOR_BGR2RGB)`)。对于训练集,我们会对其进行数据增强(`self.transforms(image=img)['image']`)。然后,我们对于所有数据集需要进行预处理(`self.preprocess(img)`): ```python self.preprocess = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), ]) ``` 最后,`__getitem__`返回图像和标签:`{'X': img, 'y': y}`。 ### 3.3 数据增强(仅对于训练集) 和验证集/测试集略有不同,训练集的`train_dataloader`需要对于数据进行数据增强。 ```python def get_train_transforms(): return A.Compose( [ A.Rotate(limit=30, border_mode=cv2.BORDER_REPLICATE, p=0.5), A.Cutout(num_holes=8, max_h_size=25, max_w_size=25, fill_value=0, p=0.25), A.Cutout(num_holes=8, max_h_size=25, max_w_size=25, fill_value=255, p=0.25), A.HorizontalFlip(p=0.5), A.RandomContrast(limit=(-0.3, 0.3), p=0.5), A.RandomBrightness(limit=(-0.4, 0.4), p=0.5), A.Blur(p=0.25), ], p=1.0 ) ``` ## 4 模型训练/验证/测试/优化类 pyTorch Lightning的一大优点就是将机器学习/深度学习流程化,标准化。与模型训练/验证/测试相关的代码都写在`LitModel`类中,继承于`pl.LightningModule`。 ### 4.1 pl.LightningModule 对于 PyTorch Lightning,有两个函数是至关重要,一个是`pl.LightningModule`,一个是`pl.LightningDataModule`。前者的包含了训练/验证/预测/优化的所有模块,后者则是数据集读取模块。我们通过PyTorch Lightning进行模型训练的时候,通常会继承这两个类。目前我对 PyTorch Lightning 不是很了解,所以这里我作为一个初学者的角度,针对这个案例进行一些相关的解读。 关于`pl.LightningModule`,和我们这个案例相关的函数包括: - `forward`,作用和`torch.nn.Module.forward()`一样,这里我们不再赘述; - `training_step`,我们计算并返回训练损失和一些额外的metrics。 - `validation_step`,我们计算并返回验证损失和一些额外的metrics。 - `test_step`,我们计算并返回测试损失和一些额外的metrics。 - `validation_epoch_end`,在验证epoch结束后,计算这个epoch的平均验证accuracy。 - `test_epoch_end`,在测试epoch结束后,计算计算这个epoch的平均测试accuracy。 - `configure_optimizers`,选择要在优化中使用的优化器和学习率调度器。 [此网页](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.LightningModule.html?highlight=LightningModule#pytorch_lightning.core.LightningModule.training_step)有详细的描述,这里不再赘述。 ### 4.2 模型导入 这里我们基于 MobileNetV2 的预训练模型,对 LEGO Minifigures 数据集进行迁移训练。所以,首先我们需要导入预训练模型,并且加上最后一层全连接层,组成一个完整的神经网络。代码位于`LitModel`类的`__init__`函数中: ```python def __init__(self, n_classes): super().__init__() self.net = torch.hub.load( 'pytorch/vision:v0.6.0', 'mobilenet_v2', pretrained=True ) self.net.classifier = torch_nn.Linear( in_features=1280, out_features=n_classes, bias=True ) self.save_hyperparameters() def forward(self, x): x = self.net(x) return x ``` ### 4.3 模型训练 模型训练相关的代码位于`LitModel`类的`training_step`函数中。模型训练的输入`batch`可以理解为`LEGOMinifiguresDataModule`的`train_dataloader`的输出(我们可以这么理解)。相关代码如下: ```python def training_step(self, batch, batch_idx): """ Training the data as batches and returns training loss on each batch :param train_batch: Batch data :param batch_idx: Batch indices :return: output - Training loss """ X, y = batch['X'], batch['y'] y_hat = self(X) train_loss = torch_F.cross_entropy(y_hat, y) train_acc = accuracy( y_hat, y, num_classes=self.hparams.n_classes ) #result = pl.TrainResult(train_loss) self.log('train_loss', train_loss, prog_bar=True, on_epoch=True, on_step=False) self.log('train_acc', train_acc, prog_bar=True, on_epoch=True, on_step=False) return {"loss": train_loss} ``` 我们通过`training_step`可以计算得到当前循环下 train loss 以及 train accuracy。函数返回 train loss 值。 ### 4.4 模型验证 模型验证相关的代码位于`LitModel`类的`validation_step`与`validation_epoch_end`函数中。逻辑与模型训练类似: ```python def validation_step(self, batch, batch_idx): """ Performs validation of data in batches :param batch: Batch data :param batch_idx: Batch indices :return: output - valid step loss """ X, y = batch['X'], batch['y'] y_hat = self(X) valid_loss = torch_F.cross_entropy(y_hat, y) valid_acc = accuracy( y_hat, y, num_classes=self.hparams.n_classes ) #result = pl.EvalResult(checkpoint_on=valid_loss, early_stop_on=valid_loss) self.log('valid_loss', valid_loss, prog_bar=True, on_epoch=True, on_step=False) self.log('valid_acc', valid_acc, prog_bar=True, on_epoch=True, on_step=False) return {"val_step_loss": valid_loss} def validation_epoch_end(self, outputs): """ Computes average validation accuracy :param outputs: outputs after every epoch end :return: output - average valid loss """ avg_loss = torch.stack([x["val_step_loss"] for x in outputs]).mean() self.log("val_loss", avg_loss, sync_dist=True) ``` ### 4.5 模型测试 模型验证相关的代码位于`LitModel`类的`test_step`与`test_epoch_end`函数中。逻辑与模型训练类似: ```python def test_step(self, batch, batch_idx): """ Performs test and computes the accuracy of the model :param batch: Batch data :param batch_idx: Batch indices :return: output - Testing accuracy """ X, y = batch['X'], batch['y'] y_hat = self(X) test_loss = torch_F.cross_entropy(y_hat, y) test_acc = accuracy( y_hat, y, num_classes=self.hparams.n_classes ) #result = pl.EvalResult(checkpoint_on=valid_loss, early_stop_on=valid_loss) self.log('test_loss', test_loss, prog_bar=True, on_epoch=True, on_step=False) self.log('test_acc', test_acc, prog_bar=True, on_epoch=True, on_step=False) return {"test_acc": test_acc} def test_epoch_end(self, outputs): """ Computes average test accuracy score :param outputs: outputs after every epoch end :return: output - average test loss """ avg_test_acc = torch.stack([x["test_acc"] for x in outputs]).mean() self.log("avg_test_acc", avg_test_acc) ``` ### 4.5 模型优化 最后一块就是模型优化,我们可以自行选择优化器。相关代码如下: ```python def configure_optimizers(self): """ Initializes the optimizer and learning rate scheduler :return: output - Initialized optimizer and scheduler """ self.optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) self.scheduler = { "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer, mode="min", factor=0.2, patience=2, min_lr=1e-6, verbose=True, ), "monitor": "val_loss", } return [self.optimizer], [self.scheduler] ``` ## 5 迁移学习 这里,我们就开始迁移学习的训练,验证,以及最后的测试了。 ### 5.1 实例化数据读取处理模块以及训练/验证/测试/优化类 上面章节介绍了 pytorch lightning 的 `LightningDataModule` 以及 `LightningModule` 类如何在我们这个案例中被继承与使用。这里,我们首先需要实例化 `LitModel` 以及 `LEGOMinifiguresDataModule`。 ```python # 实例化模型训练/验证/测试/优化类 model = LitModel(n_classes=N_CLASSES) # 实例化数据读取处理模块 data_module = LEGOMinifiguresDataModule( train_batch_size=4, valid_batch_size=1, test_batch_size=1, image_size=(512, 512), base_dir=BASE_DIR, train_augmentations=get_train_transforms() ) ``` ### 5.2 迁移学习的训练,验证,以及测试 我们定义 EarlyStopping 条件,ModelCheckpoint 的保存机制,以及对 learning rate 的保存跟踪,然后就可以开始训练,验证,以及测试了。代码如下: ```python # 定义 EarlyStopping Criteria early_stopping = EarlyStopping( monitor='valid_loss', mode='min', verbose=True, patience=3, ) # 定义 ModelCheckpoint,我们可以通过导入.ckpt文件模型进行模型继续训练,或者模型推理。 callback_model_checkpoint = ModelCheckpoint( dirpath=os.getcwd(), filename='sample-{epoch}-{valid_loss:.3f}', save_top_k=1, verbose=True, monitor='valid_loss', mode='min', ) lr_logger = LearningRateMonitor() # 实例化Trainer,这里epoch我只设定了12次,由于我这边的运行环境是CPU。 trainer = pl.Trainer( gpus=0, callbacks=[lr_logger, early_stopping, callback_model_checkpoint], checkpoint_callback=True, max_epochs=12 ) # 模型训练/验证 trainer.fit( model, data_module, ) # 模型测试 trainer.test(datamodule=data_module) ``` ## 6 结果 ### 6.1 训练/验证/测试 Loss 以及 Accuracy 值 由于我这边用CPU进行的训练,所以epoch的数量设置的比较小,如果是GPU的同学,建议设置50。运行的Terminal记录如下: ```terminal Using cache found in C:\Users\XXX/.cache\torch\hub\pytorch_vision_v0.6.0 GPU available: False, used: False TPU available: False, using: 0 TPU cores IPU available: False, using: 0 IPUs HPU available: False, using: 0 HPUs | Name | Type | Params ------------------------------------- 0 | net | MobileNetV2 | 2.3 M ------------------------------------- 2.3 M Trainable params 0 Non-trainable params 2.3 M Total params 9.085 Total estimated model params size (MB) Epoch 0: 8%|▊ | 13/159 [00:43<08:05, 3.32s/it, loss=3.67, v_num=13]Metric valid_loss improved. New best score: 3.316 Epoch 0: 100%|██████████| 159/159 [03:52<00:00, 1.46s/it, loss=3.51, v_num=13, valid_loss=3.320, valid_acc=0.143, train_loss=3.600, train_acc=0.0741]Epoch 0, global step 68: 'valid_loss' reached 3.31555 (best 3.31555), saving model to 'XXX\\sample-epoch=0-valid_loss=3.316.ckpt' as top 1 Epoch 1: 100%|██████████| 159/159 [07:42<00:00, 2.91s/it, loss=3, v_num=13, valid_loss=2.820, valid_acc=0.308, train_loss=3.600, train_acc=0.0741] Metric valid_loss improved by 0.494 >= min_delta = 0.0. New best score: 2.822 Epoch 1: 100%|██████████| 159/159 [07:42<00:00, 2.91s/it, loss=3, v_num=13, valid_loss=2.820, valid_acc=0.308, train_loss=3.100, train_acc=0.215] Epoch 1, global step 136: 'valid_loss' reached 2.82201 (best 2.82201), saving model to 'XXX\\sample-epoch=1-valid_loss=2.822.ckpt' as top 1 ... Epoch 11: 100%|██████████| 159/159 [48:15<00:00, 18.21s/it, loss=0.569, v_num=13, valid_loss=0.391, valid_acc=0.945, train_loss=0.670, train_acc=0.944] Metric valid_loss improved by 0.081 >= min_delta = 0.0. New best score: 0.391 Epoch 11: 100%|██████████| 159/159 [48:15<00:00, 18.21s/it, loss=0.569, v_num=13, valid_loss=0.391, valid_acc=0.945, train_loss=0.548, train_acc=0.978]Epoch 11, global step 816: 'valid_loss' reached 0.39098 (best 0.39098), saving model to 'XXX\\sample-epoch=11-valid_loss=0.391.ckpt' as top 1 Epoch 11: 100%|██████████| 159/159 [48:16<00:00, 18.21s/it, loss=0.569, v_num=13, valid_loss=0.391, valid_acc=0.945, train_loss=0.548, train_acc=0.978] Restoring states from the checkpoint path at D:\yichao\learning\courses\MLOps\mlflow-ex\ex\8_LegoCharacterRecognition\sample-epoch=11-valid_loss=0.391.ckpt Loaded model weights from checkpoint at D:\yichao\learning\courses\MLOps\mlflow-ex\ex\8_LegoCharacterRecognition\sample-epoch=11-valid_loss=0.391.ckpt Testing DataLoader 0: 100%|██████████| 76/76 [00:20<00:00, 3.66it/s] ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── Test metric DataLoader 0 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── avg_test_acc 0.8684210777282715 test_acc 0.8684210777282715 test_loss 0.5233738422393799 ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── [{'test_loss': 0.5233738422393799, 'test_acc': 0.8684210777282715, 'avg_test_acc': 0.8684210777282715}] ``` 我们可以看到,当我们最终得到的结果: - loss=0.569 - valid_loss=0.391 - valid_acc=0.945 - train_loss=0.548 - train_acc=0.978 - test_loss=0.5233738422393799 - test_acc=0.8684210777282715 - avg_test_acc=0.8684210777282715 ### 6.2 通过 ModelCheckpoint 文件导入最优模型并进行模型推理 这里我们尝试读取训练过程中保存的 checkpoint 文件:`sample-epoch=11-valid_loss=0.391.ckpt`,然后用其进行模型推理: ```python device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 读取 checkpoint 文件。注意,这个案例中这个 checkpoint 保存在当前路径下。 best_model_path = callback_model_checkpoint.best_model_path model = LitModel.load_from_checkpoint( checkpoint_path=best_model_path ) model = model.to(device) model.freeze() # 模型推理 y_pred = [] y_gt = [] for ind, batch in enumerate(data_module.test_dataloader()): pred_probs = model(batch['X']) y_pred.extend(pred_probs.argmax(axis=-1).cpu().numpy()) y_gt.extend(batch['y']) # Calculate needed metrics print(f'Accuracy score on test data:\t{sk_metrics.accuracy_score(y_gt, y_pred)}') print(f'Macro F1 score on test data:\t{sk_metrics.f1_score(y_gt, y_pred, average="macro")}') ``` 我们得到的结果如下: ```terminal Accuracy score on test data: 0.868421052631579 Macro F1 score on test data: 0.8436293436293436 ``` 需要注意,由于这里我设置epoch的值只有12,这么模型完全可以精度再高一些,但由于我这边只有CPU,所以就不花时间等结果了。 ### 6.3 找出误识别的图片 最后,我们可以更进一步,找出哪些误识别的图片(由于我们这个数据集比较小)。 首先,让我们来看看整个数据集的 confusion matrix 出来的效果: ```python # Load metadata to get classes people-friendly names labels = df_metadata['minifigure_name'].tolist() # Calculate confusion matrix confusion_matrix = sk_metrics.confusion_matrix(y_gt, y_pred) df_confusion_matrix = pd.DataFrame(confusion_matrix, index=labels, columns=labels) # Show confusion matrix plt.figure(figsize=(12, 12)) sn.heatmap(df_confusion_matrix, annot=True, cbar=False, cmap='Oranges', linewidths=1, linecolor='black') plt.xlabel('Predicted labels', fontsize=15) plt.xticks(fontsize=12) plt.ylabel('True labels', fontsize=15) plt.yticks(fontsize=12); ``` ![lego](./img/2.png) 最后,我们将那些错误识别的图片都显示出来: ```python error_images = [] error_label = [] error_pred = [] error_prob = [] for batch in data_module.test_dataloader(): _X_test, _y_test = batch['X'], batch['y'] pred = torch.softmax(model(_X_test), axis=-1).cpu().numpy() pred_class = pred.argmax(axis=-1) if pred_class != _y_test.cpu().numpy(): error_images.extend(_X_test) error_label.extend(_y_test) error_pred.extend(pred_class) error_prob.extend(pred.max(axis=-1)) def denormalize_image(image): return image * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406] plt.figure(figsize=(16, 16)) w_size = int(len(error_images) ** 0.5) h_size = math.ceil(len(error_images) / w_size) for ind, image in enumerate(error_images): plt.subplot(h_size, w_size, ind + 1) plt.imshow(denormalize_image(image.permute(1, 2, 0).numpy())) pred_label = labels[error_pred[ind]] pred_prob = error_prob[ind] true_label = labels[error_label[ind]] plt.title(f'predict: {pred_label} ({pred_prob:.2f}) true: {true_label}') plt.axis('off') ``` 结果如下: ![lego](./img/3.png)