diff --git a/inferrt/python/mrt/torch/fx_backend.py b/inferrt/python/mrt/torch/fx_backend.py
index 13d76931689a0157e4abf4fba066e7cc41b10085..dd65658cca8df62209814607284d1dd6485d5ed9 100644
--- a/inferrt/python/mrt/torch/fx_backend.py
+++ b/inferrt/python/mrt/torch/fx_backend.py
@@ -185,13 +185,83 @@ def _get_op(target):
         return Op.custom_call
     return None
 
+def _argument_to_real_value(value_type, value, arg_len):
+    """
+    Convert a torch fx value to its real value.
+
+    Args:
+        value_type (torch._C.Type): The schema type declared for the argument.
+        value (Any): The value of the argument.
+        arg_len (int or None): The declared list length (the schema argument's N); when set, a scalar is repeated that many times.
+
+    Returns:
+        Any: The real value of the argument.
+    """
+    if isinstance(value_type, torch.OptionalType):
+        return _argument_to_real_value(value_type.getElementType(), value, arg_len)
+    if isinstance(value_type, torch.ListType):
+        if isinstance(value, list):
+            return value
+        if value is None:
+            return value
+        if not arg_len:
+            return [value]
+        return [value for _ in range(arg_len)]
+    return value
+
+
+def _create_args(schema: torch.FunctionSchema, node: Node) -> tuple[List[Argument], bool]:
+    """
+    Flatten a torch fx node's args and kwargs into a positional list that matches the given schema.
 
-def _get_op_schemas(
-    target: OpOverload | OpOverloadPacket,
-) -> Optional[List[torch._C.FunctionSchema]]:
+    Args:
+        schema (torch.FunctionSchema): The candidate schema of the node's target op.
+        node (torch.fx.Node): The FX node whose arguments should be flattened.
+
+    Returns:
+        List[Argument]: The flattened argument values, in the order declared by the schema.
+        bool: Whether the node's arguments are valid for this schema.
+    """
+    flat_args = []
+    args = node.args
+    kwargs = node.kwargs
+    arg_idx = 0
+    if len(args) + len(kwargs) > len(schema.arguments):
+        return flat_args, False
+
+    for arg in args:
+        if schema.arguments[arg_idx].kwarg_only:
+            return flat_args, False
+        real_arg = _argument_to_real_value(schema.arguments[arg_idx].real_type, arg, schema.arguments[arg_idx].N)
+        flat_args.append(real_arg)
+        arg_idx += 1
+
+    consumed_kwargs = 0
+    for argument in schema.arguments[arg_idx:]:
+        if argument.name in kwargs:
+            real_arg = _argument_to_real_value(argument.real_type, kwargs[argument.name], argument.N)
+            flat_args.append(real_arg)
+            consumed_kwargs += 1
+        elif hasattr(argument, "default_value"):
+            flat_args.append(argument.default_value)
+        else:
+            return flat_args, False
+
+    if consumed_kwargs != len(kwargs):
+        return flat_args, False
+    return flat_args, True
+
+
+def _get_op_schemas(target) -> Optional[List[torch._C.FunctionSchema]]:
     """
     Retrieve torch schema(s) for a given op target. Returns None if unavailable.
     """
+    if isinstance(target, str):
+        for ns in iter(torch.ops):
+            ops_ns = getattr(torch.ops, ns)
+            if hasattr(ops_ns, target):
+                op_target = getattr(ops_ns, target)
+                return [getattr(op_target, overload)._schema for overload in op_target.overloads()]
+        return None
     if isinstance(target, OpOverload):
         return [target._schema]
 
@@ -217,30 +287,18 @@ def _flatten_args(op: Op, node: Node) -> List[Argument]:
     Returns:
         List[Argument]: A flat list of all Argument objects in the node's arguments, preserving order.
""" - flat_args = list(node.args) - # for custom op - if op == Op.custom_call: - op_name = node.target.__name__ - flat_args = [op_name] + flat_args - return flat_args - kwargs = node.kwargs - if not kwargs: - return flat_args - # if kwargs has only one element, add the value to flat_args and return - if len(kwargs) == 1: - flat_args.append(list(kwargs.values())[0]) - return flat_args + flat_args = [] schemas = _get_op_schemas(node.target) if not schemas: - raise RuntimeError(f"Cannot resolve schemas for op: {node.target}") - if len(schemas) != 1: - raise RuntimeError("Currently, do not support op overload") - schema = next(iter(schemas)) - for argument in schema.arguments: - if not argument.kwarg_only: - continue - if argument.name in kwargs: - flat_args.append(kwargs[argument.name]) + return list(node.args) + list(node.kwargs.values()) + found = False + for schema in schemas: + flat_args, found = _create_args(schema, node) + if found: + break + if not found: + err_msg = f"Failed to find a valid schema for {node.target} with arguments {node.args} and kwargs {node.kwargs}" + raise ValueError(err_msg) return flat_args @@ -311,6 +369,9 @@ def backend(gm: GraphModule, example_inputs: List[torch.Tensor]): raise NotImplementedError(f"Unsupported op: {node.target}") flat_node_args = _flatten_args(op, node) + if op == Op.custom_call: + op_name = node.target.__name__ + flat_node_args = [op_name] + flat_node_args input_nodes = _map_args(flat_node_args, env, executor) hook_func = get_arg_mapping_hook(op) if hook_func is not None: