From c6bd8a6792e2456e44e7df2c3fcdc75347e1a123 Mon Sep 17 00:00:00 2001 From: dingli Date: Wed, 8 Jun 2022 18:37:26 +0800 Subject: [PATCH] add conv1d to conv2d optimizer --- magiconnx/optimize/optimizers/__init__.py | 3 +- .../optimizers/conv1dtoconv2d_optimizer.py | 73 +++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 magiconnx/optimize/optimizers/conv1dtoconv2d_optimizer.py diff --git a/magiconnx/optimize/optimizers/__init__.py b/magiconnx/optimize/optimizers/__init__.py index 581dfd3..010f979 100644 --- a/magiconnx/optimize/optimizers/__init__.py +++ b/magiconnx/optimize/optimizers/__init__.py @@ -1,10 +1,11 @@ from .int64toint32_optimizer import Int64ToInt32Optimizer from .continuousslice_optimizer import ContinuousSliceOptimizer - +from .conv1dtoconv2d_optimizer import Conv1dtoConv2dOptimizer def get_optimizers_info(): supported_optimizers = { "Int64ToInt32Optimizer": Int64ToInt32Optimizer, "ContinuousSliceOptimizer": ContinuousSliceOptimizer, + "Conv1dtoConv2dOptimizer": Conv1dtoConv2dOptimizer, } return supported_optimizers diff --git a/magiconnx/optimize/optimizers/conv1dtoconv2d_optimizer.py b/magiconnx/optimize/optimizers/conv1dtoconv2d_optimizer.py new file mode 100644 index 0000000..14a75d9 --- /dev/null +++ b/magiconnx/optimize/optimizers/conv1dtoconv2d_optimizer.py @@ -0,0 +1,73 @@ +import numpy as np + +from ..base_optimizer import BaseOptimizer + +class Conv1dtoConv2dOptimizer(BaseOptimizer): + visited = [] + node_count = {'sq': 0, 'us': 0} + element_wise_list = ['Mul', 'Add', 'Sub', 'Div', 'BatchNormalization', 'LeakyRelu'] + + def optimize(self, graph): + input_nodes = [graph[i] for i in graph.inputs] + for i in input_nodes: + flag = self.walk_through_graph(self, graph, i, False, False) + return graph, flag + + @staticmethod + def conv1d2conv2d(graph, conv): + attrs = ('dilations', 'kernel_shape', 'strides') + for attr in attrs: + if attr in conv.attrs.keys(): + val = conv[attr][0] + conv[attr] = [1, val] + + if 'pads' in conv.attrs.keys(): + pds = conv['pads'][0] + conv['pads'] = [0, pds, 0, pds] + conv_w = graph[conv.inputs[1]].value + conv_w = np.expand_dims(conv_w, axis=-2) + graph[conv.inputs[1]].value = conv_w + + @staticmethod + def walk_through_graph(self, graph, node, modify, flag): + if node.name in self.visited: + return flag + if node.op_type in graph.outputs: + if modify: + id = self.node_count.get('sq') + self.node_count['sq'] = id + 1 + sq = graph.add_node('Sqz_%d' % id, 'Squeeze', {'axes': [2]}) + graph.insert_node(modeify, sq, mode='after') + return flag + + if not modify: + if node.op_type == 'Conv': + id = self.node_count.get('us') + self.node_count['us'] = id + 1 + us = graph.add_node('Unsqz_%d' % id, 'Unsqueeze', {'axes': [2]}) + graph.insert_node(node.name, us, mode='before') + self.conv1d2conv2d(graph, node) + modify = node.name + self.visited.append(node.name) + else: + if node.op_type == 'Conv': + self.conv1d2conv2d(graph, node) + modify = node.name + self.visited.append(node.name) + elif node.op_type not in Conv1dtoConv2dOptimizer.element_wise_list: + id = self.node_count.get('sq') + self.node_count['sq'] = id + 1 + sq = graph.add_node('Sqz_%d' % id, 'Squeeze', {'axes': [2]}) + graph.insert_node(modify, sq, mode='after') + modify = '' + else: + modify = node.name + self.visited.append(node.name) + + nexts = graph.get_next_nodes(node.name) + for i in nexts: + flag = self.walk_through_graph(self, graph, i, modify, flag) + + if self.visited: + return True + return False \ No newline at end of file -- Gitee