diff --git a/configs/dapo_qwen3_30b_a3b_A3.yaml b/configs/dapo_qwen3_30b_a3b_A3.yaml index f4dfca88dec2f4a375cf86092385d06464677982..c6cf3819cb276afa55f510b331e1fcaedddbe3bf 100644 --- a/configs/dapo_qwen3_30b_a3b_A3.yaml +++ b/configs/dapo_qwen3_30b_a3b_A3.yaml @@ -52,8 +52,8 @@ actor_config: adam_beta1: 0.9 adam_beta2: 0.999 finetune: true - load: ./Qwen3-30B-A3B-tp4-pp1-ep2/ - save: ./save + load: ./ckpt + save: ./ckpt no_load_optim: true no_load_rng: true no_save_optim: false @@ -76,7 +76,7 @@ rl_config: epochs: 1 clip_ratio: 0.2 entropy_coeff: 0.0 - # shuffle_minibatch: false + shuffle_minibatch: false n_samples_per_prompt: 8 rule_reward: true verifier_function: ["acc_for_dapo"] @@ -119,6 +119,7 @@ generate_config: infer_tensor_parallel_size: 1 infer_pipeline_parallel_size: 1 infer_expert_parallel_size: 8 # 同步修改runtime_env.yaml中VLLM_DP_SIZE的值,建议保持一致 + infer_expert_tensor_parallel_size: 2 # vllm 模型相关设置 max_num_seqs: 1024 diff --git a/configs/ppo_qwen25_7b_A2.yaml b/configs/ppo_qwen25_7b_A2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e8c4e6888d306913e0c3ec9f3e3736a5d12d9373 --- /dev/null +++ b/configs/ppo_qwen25_7b_A2.yaml @@ -0,0 +1,132 @@ +defaults: + - model: + - qwen25_7b + +megatron_training: + model: qwen25_7b + use_fused_rmsnorm: true + use_mcore_models: true + sequence_parallel: true + use_flash_attn: true + no_masked_softmax_fusion: true + attention_softmax_in_fp32: true + no_gradient_accumulation_fusion: true + use_fused_swiglu: true + use_fused_rotary_pos_emb: true + bf16: true + use_distributed_optimizer: true + tokenizer_type: PretrainedFromHF + tokenizer_name_or_path: ./Qwen2.5-7B-Instruct + global_batch_size: 8 + seq_length: 2048 + save_interval: 200 + train_iters: 20000 + stage: ray_ppo + attention_dropout: 0.0 + init_method_std: 0.01 + hidden_dropout: 0.0 + distributed_backend: nccl + no_shared_storage: true + variable_seq_lengths: true + dataset_additional_keys: ['labels',] + data_path: ./data + split: 100,0,0 + 
no_shuffle: true + full_shuffle_instruction_dataset: false + +actor_config: + model: qwen25_7b + micro_batch_size: 1 + tensor_model_parallel_size: 4 + pipeline_model_parallel_size: 1 + lr: 1e-6 + lr_decay_style: constant + min_lr: 0 + weight_decay: 0.01 + lr_warmup_fraction: 0.0 + clip_grad: 1.0 + adam_beta1: 0.9 + adam_beta2: 0.95 + finetune: true + load: ./ckpt + save: ./ckpt + no_load_optim: true + no_load_rng: true + +critic_config: + model: qwen25_7b + tensor_model_parallel_size: 4 + pipeline_model_parallel_size: 1 + micro_batch_size: 1 + lr: 5e-6 + lr_decay_style: constant + min_lr: 0 + weight_decay: 0.01 + lr_warmup_fraction: 0.0 + clip_grad: 1.0 + adam_beta1: 0.9 + adam_beta2: 0.95 + finetune: true + no_load_optim: True + no_load_rng: True + load: ./ckpt + save: ./ckpt + +rl_config: + guarantee_order: true + use_integrated_worker: true + blocking: true + actor_forward_micro_batch_size: 1 + ref_forward_micro_batch_size: 1 + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + kl_penalty: kl + kl_ctrl_type: fixed + init_kl_coef: 0.0 + mini_batch_size: 256 + max_prompt_length: 2048 + epochs: 1 + clip_ratio: 0.2 + cliprange_value: 0.5 + entropy_coeff: 0 + shuffle_mini_batch: false + n_samples_per_prompt: 1 + rule_reward: true + verifier_function: ["acc_for_ppo"] + verifier_weight: [1.0] + num_cpus_for_local_task: 1.0 + use_tensorboard: true + actor_resource: + num_npus: 8 + critic_resource: + num_npus: 8 + +generate_config: + enforce_eager: True + trust_remote_code: true + offload_train_optimizer: true + offload_train_grad: true + offload_train_param: true + + # 推理时的并行配置 + infer_tensor_parallel_size: 4 + infer_pipeline_parallel_size: 1 + infer_expert_parallel_size: 1 + + # vllm 模型相关设置 + max_num_seqs: 1024 + max_model_len: 4096 + max_num_batched_tokens: 8192 + dtype: "bfloat16" + gpu_memory_utilization: 0.8 + + # 采样配置 + sampling_config: + logprobs: 1 + max_tokens: 2048 + top_p: 1 + top_k: -1 + min_p: 0.0 + temperature: 1.0 + detokenize: false \ No newline at end of file 
diff --git a/examples/dapo/dapo_trainer_qwen3_30b_a3b.sh b/examples/dapo/dapo_trainer_qwen3_30b_a3b.sh index 69799f75a701418cf2af015160ab793d1a510afe..bff2bb1f43b72d9ddb4805e6d61ef7c99ad3296c 100644 --- a/examples/dapo/dapo_trainer_qwen3_30b_a3b.sh +++ b/examples/dapo/dapo_trainer_qwen3_30b_a3b.sh @@ -19,7 +19,7 @@ export LCAL_COMM_ID=127.0.0.1:27001 NNODES=1 NPUS_PER_NODE=16 #修改为对应主节点IP -MASTER_ADDR="localhost" +MASTER_ADDR="IP FOR MASTER NODE" #获取当前机器IP CURRENT_IP=$(ip -4 addr show $(ip -o -4 route show to default | awk '{print $5}') | grep -oP '(?<=inet\s)\d+(\.\d+){3}') @@ -62,4 +62,4 @@ else done fi -sleep 999999 \ No newline at end of file +sleep 999999 diff --git a/examples/ppo/ppo_trainer_qwen25_7b.sh b/examples/ppo/ppo_trainer_qwen25_7b.sh new file mode 100644 index 0000000000000000000000000000000000000000..23d739a8293ef42949fbf043bef22e5d1516762a --- /dev/null +++ b/examples/ppo/ppo_trainer_qwen25_7b.sh @@ -0,0 +1,62 @@ +pkill -9 python +ray stop --force +export RAY_DEDUP_LOGS=0 +export HYDRA_FULL_ERROR=1 + +DEFAULT_YAML="ppo_qwen25_7b_A2" +YAML=${1:-$DEFAULT_YAML} +echo "Use $YAML" + +ulimit -n 32768 +mkdir logs + +export TASK_QUEUE_ENABLE=2 +export HCCL_IF_BASE_PORT=24703 + +NNODES=1 +NPUS_PER_NODE=16 +#修改为对应主节点IP +MASTER_ADDR="localhost" +#获取当前机器IP +CURRENT_IP=$(ip -4 addr show $(ip -o -4 route show to default | awk '{print $5}') | grep -oP '(?<=inet\s)\d+(\.\d+){3}') + +if [ "$MASTER_ADDR" = "$CURRENT_IP" ]; then + # 主节点启动 + ray start --head --port 6866 --dashboard-host=0.0.0.0 --node-ip-address=$CURRENT_IP --dashboard-port=8260 --resources='{"NPU": '$NPUS_PER_NODE'}' + + while true; do + ray_status_output=$(ray status) + npu_count=$(echo "$ray_status_output" | grep -oP '(?<=/)\d+\.\d+(?=\s*NPU)' | head -n 1) + npu_count_int=$(echo "$npu_count" | awk '{print int($1)}') + device_count=$((npu_count_int / $NPUS_PER_NODE)) + + # 判断 device_count 是否与 NNODES 相等 + if [ "$device_count" -eq "$NNODES" ]; then + echo "Ray cluster is ready with $device_count 
devices (from $npu_count NPU resources), starting Python script." + ray status + python cli/train_ppo.py --config-name $YAML 2>&1 | tee logs/training.log + break + else + echo "Waiting for Ray to allocate $NNODES devices. Current device count: $device_count" + sleep 5 + fi + done +else + # 子节点尝试往主节点注册ray直到成功 + while true; do + # 尝试连接 Ray 集群 + ray start --address="$MASTER_ADDR:6866" --resources='{"NPU": '$NPUS_PER_NODE'}' --node-ip-address=$CURRENT_IP + + # 检查连接是否成功 + ray status + if [ $? -eq 0 ]; then + echo "Successfully connected to the Ray cluster!" + break + else + echo "Failed to connect to the Ray cluster. Retrying in 5 seconds..." + sleep 5 + fi + done +fi + +sleep 999999 \ No newline at end of file diff --git a/mindspeed_rl/config_cls/generate_config.py b/mindspeed_rl/config_cls/generate_config.py index e3f46dbc0d536bccbab6279170d552a0e6655d33..3c8bebea33a33c648366460704715af2316a7726 100644 --- a/mindspeed_rl/config_cls/generate_config.py +++ b/mindspeed_rl/config_cls/generate_config.py @@ -63,6 +63,9 @@ class GenerateConfig(BaseConfig): # 推理时的专家并行大小,默认为 1 self.infer_expert_parallel_size = 1 + # 推理时的ETP并行大小,默认为 1 + self.infer_expert_tensor_parallel_size = 1 + # 最大可处理的序列数量,默认为 1 self.max_num_seqs = 1 diff --git a/mindspeed_rl/models/base/base_inference_engine.py b/mindspeed_rl/models/base/base_inference_engine.py index 966c8c6775e590758a35388fb859aa73d5d6b58a..53066881a12e379e42935afa1944f18013e79b02 100644 --- a/mindspeed_rl/models/base/base_inference_engine.py +++ b/mindspeed_rl/models/base/base_inference_engine.py @@ -21,6 +21,7 @@ class BaseInferEngine(ABC): infer_tensor_parallel_size: int = 8, infer_pipeline_parallel_size: int = 1, infer_expert_parallel_size: int = 1, + infer_expert_tensor_parallel_size: int = 1, max_num_seqs: int = 1, # Default value set to 1 max_model_len: int = 2048, # Default value set to 2048 dtype: str = "bfloat16", # Default value set to "bfloat16" @@ -40,6 +41,7 @@ class BaseInferEngine(ABC): infer_tensor_parallel_size (int): 
Tensor parallel size during inference. infer_pipeline_parallel_size (int): Pipeline parallel size during inference. infer_expert_parallel_size (int): Expert parallel size during inference. + infer_expert_tensor_parallel_size (int): Expert tensor parallel size during inference. max_num_seqs (int): Maximum number of sequences to process simultaneously. Default is 1. max_model_len (int): Maximum model length (in tokens). Default is 2048. dtype (str): Data type for model weights. Default is "bfloat16". @@ -57,6 +59,7 @@ class BaseInferEngine(ABC): self.infer_tensor_parallel_size = infer_tensor_parallel_size self.infer_pipeline_parallel_size = infer_pipeline_parallel_size self.infer_expert_parallel_size = infer_expert_parallel_size + self.infer_expert_tensor_parallel_size = infer_expert_tensor_parallel_size self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len self.dtype = dtype diff --git a/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py b/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py index 45fac868790b26bfbd7c2459f69de69457edc9ca..7105d50c91bba46b0982d06a0e2487b6022b7899 100644 --- a/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py +++ b/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py @@ -9,10 +9,11 @@ from transformers.configuration_utils import PretrainedConfig class InferParallelConfig: - def __init__(self, infer_tensor_parallel_size: int, infer_pipeline_parallel_size: int, infer_expert_parallel_size: int): + def __init__(self, infer_tensor_parallel_size: int, infer_pipeline_parallel_size: int, infer_expert_parallel_size: int, infer_expert_tensor_parallel_size:int): self.infer_tensor_parallel_size = infer_tensor_parallel_size self.infer_pipeline_parallel_size = infer_pipeline_parallel_size self.infer_expert_parallel_size = infer_expert_parallel_size + self.infer_expert_tensor_parallel_size = infer_expert_tensor_parallel_size def load_megatron_weights(actor_weights: Dict, 
vllm_model: nn.Module, diff --git a/mindspeed_rl/models/rollout/vllm_engine.py b/mindspeed_rl/models/rollout/vllm_engine.py index 73fe1af46c13aa360b532e1cb9e2b05fbab2d053..bfc6b62bd6ccf984456e004bdabf65057c4d5571 100644 --- a/mindspeed_rl/models/rollout/vllm_engine.py +++ b/mindspeed_rl/models/rollout/vllm_engine.py @@ -41,6 +41,7 @@ class VLLMInferEngine(BaseInferEngine): infer_tensor_parallel_size: int, infer_pipeline_parallel_size: int, infer_expert_parallel_size: int, + infer_expert_tensor_parallel_size: int, sampling_config: dict, prompt_type: str = None, prompt_type_path: str = None, @@ -72,6 +73,7 @@ class VLLMInferEngine(BaseInferEngine): infer_tensor_parallel_size (int): Tensor parallel size during inference. infer_pipeline_parallel_size (int): Pipeline parallel size during inference. infer_expert_parallel_size (int): Expert parallel size during inference. + infer_expert_tensor_parallel_size (int): Expert tensor parallel size during inference. sampling_config (dict): Configuration for text generation sampling. enable_prefix_caching (bool): Whether to enable prefix caching. num_scheduler_steps (int): Num scheduler steps. Default is 1. 
@@ -94,6 +96,7 @@ class VLLMInferEngine(BaseInferEngine): infer_tensor_parallel_size=infer_tensor_parallel_size, infer_pipeline_parallel_size=infer_pipeline_parallel_size, infer_expert_parallel_size=infer_expert_parallel_size, + infer_expert_tensor_parallel_size=infer_expert_tensor_parallel_size, max_num_seqs=max_num_seqs, max_model_len=max_model_len, dtype=dtype, @@ -150,7 +153,8 @@ class VLLMInferEngine(BaseInferEngine): train_pipeline_model_parallel_size=train_pipeline_parallel_size, train_expert_model_parallel_size=train_expert_parallel_size, infer_expert_model_parallel_size=infer_expert_parallel_size, - train_context_model_parallel_size=train_context_parallel_size + train_context_model_parallel_size=train_context_parallel_size, + infer_expert_tensor_parallel_size=infer_expert_tensor_parallel_size ) if load_format == "megatron": diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index f317908b20cba2f2ac8942a8ab8df05bb8785dc1..3d768a766f3ea516ae4cd4dd51d5614504db2a61 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -526,6 +526,7 @@ class ActorHybridWorkerBase(BaseWorker): infer_tensor_parallel_size=self.generate_config.infer_tensor_parallel_size, infer_pipeline_parallel_size=self.generate_config.infer_pipeline_parallel_size, infer_expert_parallel_size=self.generate_config.infer_expert_parallel_size, + infer_expert_tensor_parallel_size=self.generate_config.infer_expert_tensor_parallel_size, megatron_config=self.megatron_config, sampling_config=sampling_config, enable_prefix_caching=self.generate_config.enable_prefix_caching, @@ -552,6 +553,7 @@ class ActorHybridWorkerBase(BaseWorker): infer_tensor_parallel_size=self.generate_config.infer_tensor_parallel_size, infer_pipeline_parallel_size=self.generate_config.infer_pipeline_parallel_size, infer_expert_parallel_size=self.generate_config.infer_expert_parallel_size, + 
infer_expert_tensor_parallel_size=self.generate_config.infer_expert_tensor_parallel_size, num_layer_list=self.megatron_config.num_layer_list, moe_tp_extend_ep=self.megatron_config.moe_tp_extend_ep, parallel_state=self.parallel_state, diff --git a/mindspeed_rl/workers/resharding/megatron_sharding_manager.py b/mindspeed_rl/workers/resharding/megatron_sharding_manager.py index 0bf33117825b4547bd330f18ba1131ac148f34c2..3197e60b3b575a54eaa8bb85349487a222aa26d4 100644 --- a/mindspeed_rl/workers/resharding/megatron_sharding_manager.py +++ b/mindspeed_rl/workers/resharding/megatron_sharding_manager.py @@ -148,6 +148,7 @@ class MegatronShardingManager: infer_tensor_parallel_size=None, infer_pipeline_parallel_size=None, infer_expert_parallel_size=None, + infer_expert_tensor_parallel_size=None, num_layer_list=None, moe_tp_extend_ep=None, parallel_state=None, @@ -167,6 +168,7 @@ class MegatronShardingManager: infer_tensor_parallel_size (int): Tensor parallel size during inference. infer_pipeline_parallel_size (int): Pipeline parallel size during inference. infer_expert_parallel_size (int): Expert parallel size during inference. + infer_expert_tensor_parallel_size (int): Expert tensor parallel size during inference. num_layer_list (str): a list of number of layers, seperated by comma; e.g., 4,4,4,4. moe_tp_extend_ep (bool): Controls whether expert model parameters are split across multiple GPUs. parallel_state (ModuleType): Megatron parallel state of the model. 
@@ -184,6 +186,7 @@ class MegatronShardingManager: infer_tensor_parallel_size=infer_tensor_parallel_size, infer_pipeline_parallel_size=infer_pipeline_parallel_size, infer_expert_parallel_size=infer_expert_parallel_size, + infer_expert_tensor_parallel_size=infer_expert_tensor_parallel_size, num_layer_list=num_layer_list, moe_tp_extend_ep=moe_tp_extend_ep, parallel_state=parallel_state, diff --git a/mindspeed_rl/workers/resharding/vllm_weight_container.py b/mindspeed_rl/workers/resharding/vllm_weight_container.py index 729a602e51ec34d34eb7658d1f56edb2506f697b..c7cd07d26876de4863f1da028adc42d925374d9b 100644 --- a/mindspeed_rl/workers/resharding/vllm_weight_container.py +++ b/mindspeed_rl/workers/resharding/vllm_weight_container.py @@ -44,6 +44,7 @@ class MegatronStyleVllmWeightContainer: def __init__(self, megatron_model, vllm_model, model_config, infer_tensor_parallel_size, infer_pipeline_parallel_size, infer_expert_parallel_size, + infer_expert_tensor_parallel_size, num_layer_list, moe_tp_extend_ep=False, parallel_state=None, @@ -59,6 +60,7 @@ class MegatronStyleVllmWeightContainer: infer_tensor_parallel_size (int): Inference tensor parallel size infer_pipeline_parallel_size (int): Inference pipeline parallel size infer_expert_parallel_size (int): Inference expert parallel size + infer_expert_tensor_parallel_size (int): Inference expert tensor parallel size num_layer_list (str): a list of number of layers, seperated by comma; e.g., 4,4,4,4. moe_tp_extend_ep (bool): Controls whether expert model parameters are split across multiple GPUs. parallel_state (ModuleType): Megatron parallel state of the model. @@ -118,10 +120,10 @@ class MegatronStyleVllmWeightContainer: self.moe_tp_extend_ep = moe_tp_extend_ep # TODO: infer_expert_tensor_parallel_size and num_process is fixed. 
- self.infer_expert_tensor_parallel_size = 1 + self._infer_etp_size = infer_expert_tensor_parallel_size self.num_process = 1 self._infer_ep_size = self._infer_ep_size * self._infer_tp_size - self.experts_memory_expand_N = self._infer_ep_size // self._ep_size + self.experts_memory_expand_N = self._infer_ep_size // self._ep_size * self._infer_etp_size # validate parallel configs self._validate_parallel_config() @@ -163,6 +165,10 @@ class MegatronStyleVllmWeightContainer: raise ValueError( f"Do not support split train tp size {self._tp_size} to infer tp size {self._infer_tp_size} " f"with train dp size {(self._world_size // (self._tp_size * self._pp_size))}.") + if self._world_size % self._infer_ep_size != 0: + raise ValueError( + f"Do not support infer etp size {self._infer_etp_size} with infer ep size {self._infer_ep_size} and infer tp size {self._infer_tp_size}" + ) def get_infer_params(self): """ @@ -334,6 +340,64 @@ class MegatronStyleVllmWeightContainer: return res.permute(1, 0, 2).contiguous() + def expert_tensor_parallel(self, megatron_param, name): + # the true ep size, equal to ep_size * etp_size + if self._infer_etp_size == 1: + return megatron_param + + etp_size = self._infer_etp_size + + num_experts = self.num_local_experts + + # 1. build ep_etp matrix buffer + # For megatron param [e0, e1], we make it [a0, a1, b0, b1], in which e0 == [a0, b0] + + # weight1: column cut, be like [g0, u0, g1, u1, ...] + if 'weight1' in name: + + hidden_size = megatron_param.shape[0] + megatron_param = torch.cat(megatron_param.view(num_experts, hidden_size, -1).unbind(0), dim=1) + + # We can treat both the gate and the up weight as 2 independent experts. 
+ num_experts *= 2 + + # weight2: row cut, be like [ d0, d1, d2, ...]^T + elif 'weight2' in name: + + hidden_size = megatron_param.shape[1] + megatron_param = torch.cat(megatron_param.view(num_experts, -1, hidden_size).unbind(0), dim=0) + + # transpose params to handle uniformly with column cut + megatron_param = megatron_param.t() + + else: + return megatron_param + + # chunk to etp * ep parts + chunks = torch.chunk(megatron_param, etp_size * num_experts, dim=1) + + # re-select by etp-ep order + # e.g. ETP=2 num_experts=4, old order [1,2,3,4,5,6,7,8], new order [1,3,5,7,2,4,6,8] + new_order = [] + for i in range(self._ep_size): + for j in range(num_experts // self._ep_size): + for k in range(self._infer_etp_size): + new_order.append(chunks[i * num_experts // self._ep_size * etp_size + j * etp_size + k]) + + reordered_x = torch.cat(new_order, dim=1) + new_size = etp_size * num_experts // self._world_size + final_chunks = torch.chunk(reordered_x, new_size, dim=1) + + etp_rank = torch.distributed.get_rank(group=self._ep_group) + res = final_chunks[etp_rank] + total_experts = self.num_local_experts * etp_size + res = res.reshape(hidden_size, total_experts, -1) + + if 'weight2' in name: + return res.permute(1, 2, 0).contiguous() + return res.permute(1, 0, 2).contiguous() + + def _update_weight_buffers_intra_pp(self): """ Here, we only update the current training pp_rank's buffer. 
@@ -346,6 +410,7 @@ class MegatronStyleVllmWeightContainer: infer_param = self.allgather_tp_param(megatron_param, name) infer_param = self.split_tp_params(infer_param, name) infer_param = self.trans_ep_params_to_tp(infer_param, name) + infer_param = self.expert_tensor_parallel(infer_param, name) return infer_param pp_rank = self._pp_rank diff --git a/tests/st/resharding/test_resharding.py b/tests/st/resharding/test_resharding.py index 20a1d42ec8de6d4ac65c419d4b40f0e296d7e27d..e229c7feb14c30cb77b201f0f32d1884c82fffc7 100644 --- a/tests/st/resharding/test_resharding.py +++ b/tests/st/resharding/test_resharding.py @@ -315,6 +315,7 @@ class TestActor(): infer_tensor_parallel_size=args.infer_tp, infer_pipeline_parallel_size=args.infer_pp, infer_expert_parallel_size=args.infer_ep, + infer_expert_tensor_parallel_size=args.infer_etp, sampling_config=sampling_config, max_num_seqs=16, max_model_len=4096, @@ -331,6 +332,7 @@ class TestActor(): infer_tensor_parallel_size=args.infer_tp, infer_pipeline_parallel_size=args.infer_pp, infer_expert_parallel_size=args.infer_ep, + infer_expert_tensor_parallel_size=args.infer_etp, num_layer_list=None, moe_tp_extend_ep=False, parallel_state=mpu, @@ -375,6 +377,7 @@ def parse_args(): parser.add_argument("--infer-tp", type=int, default=4) parser.add_argument("--infer-pp", type=int, default=1) parser.add_argument("--infer-ep", type=int, default=1) + parser.add_argument("--infer-etp", type=int, default=1) return parser.parse_args()