From 789c6e0cbf8aff97cfe3f02a0e0c2316a0b92739 Mon Sep 17 00:00:00 2001 From: zhaoting Date: Thu, 6 Jan 2022 14:28:42 +0800 Subject: [PATCH] fix lenet perf issue --- official/cv/lenet/src/dataset.py | 8 ++------ official/nlp/gnmt_v2/src/gnmt_model/dynamic_rnn.py | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/official/cv/lenet/src/dataset.py b/official/cv/lenet/src/dataset.py index df9eecda1..cbe6918b6 100644 --- a/official/cv/lenet/src/dataset.py +++ b/official/cv/lenet/src/dataset.py @@ -33,27 +33,23 @@ def create_dataset(data_path, batch_size=32, repeat_size=1, resize_height, resize_width = 32, 32 rescale = 1.0 / 255.0 - shift = 0.0 rescale_nml = 1 / 0.3081 shift_nml = -1 * 0.1307 / 0.3081 # define map operations resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode - rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) - rescale_op = CV.Rescale(rescale, shift) + rescale_nml_op = CV.Rescale(rescale_nml * rescale, shift_nml) hwc2chw_op = CV.HWC2CHW() type_cast_op = C.TypeCast(mstype.int32) # apply map operations on images mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers) mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers) - mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers) mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers) mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers) # apply DatasetOps - buffer_size = 10000 - mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script + mnist_ds = mnist_ds.shuffle(buffer_size=1024) mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True) mnist_ds = mnist_ds.repeat(repeat_size) diff --git a/official/nlp/gnmt_v2/src/gnmt_model/dynamic_rnn.py b/official/nlp/gnmt_v2/src/gnmt_model/dynamic_rnn.py index 48ca08c30..7f4c91ab3 100644 --- a/official/nlp/gnmt_v2/src/gnmt_model/dynamic_rnn.py +++ b/official/nlp/gnmt_v2/src/gnmt_model/dynamic_rnn.py @@ -68,7 +68,7 @@ class DynamicRNNCell(nn.Cell): num_layers=1, has_bias=True, batch_first=False, - dropout=0, + dropout=0.0, bidirectional=False) def construct(self, x, init_h=None, init_c=None): -- Gitee