335 Star 1.5K Fork 858

MindSpore / docs

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
enable_auto_augmentation.md 9.95 KB
一键复制 编辑 原始数据 按行查看 历史
俞涵 提交于 2023-11-16 14:36 . update obs links

应用自动数据增强

Linux Ascend GPU CPU 数据准备 中级 高级

查看源文件    查看notebook    在线运行

概述

自动数据增强(AutoAugment)[1]是在一系列图像增强子策略的搜索空间中,通过搜索算法找到适合特定数据集的图像增强方案。MindSpore的c_transforms模块提供了丰富的C++算子来实现AutoAugment,用户也可以自定义函数或者算子来实现。更多MindSpore算子的详细说明参见API文档

MindSpore算子和AutoAugment中的算子的对应关系如下:

AutoAugment算子 MindSpore算子 描述
shearX RandomAffine 横向剪切
shearY RandomAffine 纵向剪切
translateX RandomAffine 水平平移
translateY RandomAffine 垂直平移
rotate RandomRotation 旋转变换
color RandomColor 颜色变换
posterize RandomPosterize 减少颜色通道位数
solarize RandomSolarize 指定的阈值范围内,反转所有的像素点
contrast RandomColorAdjust 调整对比度
sharpness RandomSharpness 调整锐度
brightness RandomColorAdjust 调整亮度
autocontrast AutoContrast 最大化图像对比度
equalize Equalize 均衡图像直方图
invert Invert 反转图像

ImageNet自动数据增强

本教程以在ImageNet数据集上实现AutoAugment作为示例。

针对ImageNet数据集的数据增强策略包含25条子策略,每条子策略中包含两种变换,针对一个batch中的每张图像随机挑选一个子策略的组合,以预定的概率来决定是否执行子策略中的每种变换。

用户可以使用MindSpore中c_transforms模块的RandomSelectSubpolicy接口来实现AutoAugment,在ImageNet分类训练中标准的数据增强方式分以下几个步骤:

  • RandomCropDecodeResize:随机裁剪后进行解码。

  • RandomHorizontalFlip:水平方向上随机翻转。

  • Normalize:归一化。

  • HWC2CHW:图片通道变化。

RandomCropDecodeResize后插入AutoAugment变换,如下所示:

  1. 引入MindSpore数据增强模块。

    import matplotlib.pyplot as plt
    
    import mindspore.dataset as ds
    import mindspore.dataset.transforms.c_transforms as c_transforms
    import mindspore.dataset.vision.c_transforms as c_vision
    from mindspore import dtype as mstype
  2. 定义MindSpore算子到AutoAugment算子的映射:

    # define Auto Augmentation operators
    PARAMETER_MAX = 10
    
    def float_parameter(level, maxval):
        return float(level) * maxval /  PARAMETER_MAX
    
    def int_parameter(level, maxval):
        return int(level * maxval / PARAMETER_MAX)
    
    def shear_x(level):
        v = float_parameter(level, 0.3)
        return c_transforms.RandomChoice([c_vision.RandomAffine(degrees=0, shear=(-v,-v)), c_vision.RandomAffine(degrees=0, shear=(v, v))])
    
    def shear_y(level):
        v = float_parameter(level, 0.3)
        return c_transforms.RandomChoice([c_vision.RandomAffine(degrees=0, shear=(0, 0, -v,-v)), c_vision.RandomAffine(degrees=0, shear=(0, 0, v, v))])
    
    def translate_x(level):
        v = float_parameter(level, 150 / 331)
        return c_transforms.RandomChoice([c_vision.RandomAffine(degrees=0, translate=(-v,-v)), c_vision.RandomAffine(degrees=0, translate=(v, v))])
    
    def translate_y(level):
        v = float_parameter(level, 150 / 331)
        return c_transforms.RandomChoice([c_vision.RandomAffine(degrees=0, translate=(0, 0, -v,-v)), c_vision.RandomAffine(degrees=0, translate=(0, 0, v, v))])
    
    def color_impl(level):
        v = float_parameter(level, 1.8) + 0.1
        return c_vision.RandomColor(degrees=(v, v))
    
    def rotate_impl(level):
        v = int_parameter(level, 30)
        return c_transforms.RandomChoice([c_vision.RandomRotation(degrees=(-v, -v)), c_vision.RandomRotation(degrees=(v, v))])
    
    def solarize_impl(level):
        level = int_parameter(level, 256)
        v = 256 - level
        return c_vision.RandomSolarize(threshold=(0, v))
    
    def posterize_impl(level):
        level = int_parameter(level, 4)
        v = 4 - level
        return c_vision.RandomPosterize(bits=(v, v))
    
    def contrast_impl(level):
        v = float_parameter(level, 1.8) + 0.1
        return c_vision.RandomColorAdjust(contrast=(v, v))
    
    def autocontrast_impl(level):
        return c_vision.AutoContrast()
    
    def sharpness_impl(level):
        v = float_parameter(level, 1.8) + 0.1
        return c_vision.RandomSharpness(degrees=(v, v))
    
    def brightness_impl(level):
        v = float_parameter(level, 1.8) + 0.1
        return c_vision.RandomColorAdjust(brightness=(v, v))
  3. 定义ImageNet数据集的AutoAugment策略:

    # define the Auto Augmentation policy
    imagenet_policy = [
        [(posterize_impl(8), 0.4), (rotate_impl(9), 0.6)],
        [(solarize_impl(5), 0.6), (autocontrast_impl(5), 0.6)],
        [(c_vision.Equalize(), 0.8), (c_vision.Equalize(), 0.6)],
        [(posterize_impl(7), 0.6), (posterize_impl(6), 0.6)],
        [(c_vision.Equalize(), 0.4), (solarize_impl(4), 0.2)],
    
        [(c_vision.Equalize(), 0.4), (rotate_impl(8), 0.8)],
        [(solarize_impl(3), 0.6), (c_vision.Equalize(), 0.6)],
        [(posterize_impl(5), 0.8), (c_vision.Equalize(), 1.0)],
        [(rotate_impl(3), 0.2), (solarize_impl(8), 0.6)],
        [(c_vision.Equalize(), 0.6), (posterize_impl(6), 0.4)],
    
        [(rotate_impl(8), 0.8), (color_impl(0), 0.4)],
        [(rotate_impl(9), 0.4), (c_vision.Equalize(), 0.6)],
        [(c_vision.Equalize(), 0.0), (c_vision.Equalize(), 0.8)],
        [(c_vision.Invert(), 0.6), (c_vision.Equalize(), 1.0)],
        [(color_impl(4), 0.6), (contrast_impl(8), 1.0)],
    
        [(rotate_impl(8), 0.8), (color_impl(2), 1.0)],
        [(color_impl(8), 0.8), (solarize_impl(7), 0.8)],
        [(sharpness_impl(7), 0.4), (c_vision.Invert(), 0.6)],
        [(shear_x(5), 0.6), (c_vision.Equalize(), 1.0)],
        [(color_impl(0), 0.4), (c_vision.Equalize(), 0.6)],
    
        [(c_vision.Equalize(), 0.4), (solarize_impl(4), 0.2)],
        [(solarize_impl(5), 0.6), (autocontrast_impl(5), 0.6)],
        [(c_vision.Invert(), 0.6), (c_vision.Equalize(), 1.0)],
        [(color_impl(4), 0.6), (contrast_impl(8), 1.0)],
        [(c_vision.Equalize(), 0.8), (c_vision.Equalize(), 0.6)],
    ]
  4. RandomCropDecodeResize操作后插入AutoAugment变换。

    def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, shuffle=True, num_samples=5, target="Ascend"):
        # create a train or eval imagenet2012 dataset for ResNet-50
        dataset = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8,
                                   shuffle=shuffle, num_samples=num_samples)
    
        image_size = 224
        mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
        std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    
        # define map operations
        if do_train:
            trans = [
                c_vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            ]
    
            post_trans = [
                c_vision.RandomHorizontalFlip(prob=0.5),
            ]
        else:
            trans = [
                c_vision.Decode(),
                c_vision.Resize(256),
                c_vision.CenterCrop(image_size),
                c_vision.Normalize(mean=mean, std=std),
                c_vision.HWC2CHW()
            ]
        dataset = dataset.map(operations=trans, input_columns="image")
        if do_train:
            dataset = dataset.map(operations=c_vision.RandomSelectSubpolicy(imagenet_policy), input_columns=["image"])
            dataset = dataset.map(operations=post_trans, input_columns="image")
        type_cast_op = c_transforms.TypeCast(mstype.int32)
        dataset = dataset.map(operations=type_cast_op, input_columns="label")
        # apply the batch operation
        dataset = dataset.batch(batch_size, drop_remainder=True)
        # apply the repeat operation
        dataset = dataset.repeat(repeat_num)
    
        return dataset
  5. 验证自动数据增强效果。

    # Define the path to image folder directory. This directory needs to contain sub-directories which contain the images
    DATA_DIR = "/path/to/image_folder_directory"
    dataset = create_dataset(dataset_path=DATA_DIR, do_train=True, batch_size=5, shuffle=False, num_samples=5)
    
    epochs = 5
    itr = dataset.create_dict_iterator()
    fig=plt.figure(figsize=(8, 8))
    columns = 5
    rows = 5
    
    step_num = 0
    for ep_num in range(epochs):
        for data in itr:
            step_num += 1
            for index in range(rows):
                fig.add_subplot(rows, columns, ep_num * rows + index + 1)
                plt.imshow(data['image'].asnumpy()[index])
    plt.show()

    为了更好地演示效果,此处只加载5张图片,且读取时不进行shuffle操作,自动数据增强时也不进行NormalizeHWC2CHW操作。

    augment

    运行结果可以看到,batch中每张图像的增强效果,水平方向表示1个batch的5张图像,垂直方向表示5个batch。

参考文献

[1] AutoAugment: Learning Augmentation Policies from Data.

1
https://gitee.com/mindspore/docs.git
git@gitee.com:mindspore/docs.git
mindspore
docs
docs
r1.2

搜索帮助