diff --git a/mindspeed_mm/tasks/inference/pipeline/parallel_wrapper.py b/mindspeed_mm/tasks/inference/pipeline/parallel_wrapper.py index 936a2a08e5a71b93927b9cc76ea511115ff959ef..143e26ce2e223a09f235b053832dfdea7cb271a9 100644 --- a/mindspeed_mm/tasks/inference/pipeline/parallel_wrapper.py +++ b/mindspeed_mm/tasks/inference/pipeline/parallel_wrapper.py @@ -167,7 +167,8 @@ class ParallelWrapper: Ensure the first dimension of `model_forward_kwargs` is the batch size. """ - first_dims = [v.shape[0] for k, v in model_forward_kwargs.items() if (k != "position_ids" and k != "cache_position" and v is not None)] + # first_dims = [v.shape[0] for k, v in model_forward_kwargs.items() if (k != "position_ids" and k != "cache_position" and v is not None)] + first_dims = [v.shape[0] for k, v in model_forward_kwargs.items() if (k != "position_ids" and k != "cache_position" and k != "pixel_values" and v is not None)] if "position_ids" in model_forward_kwargs.keys(): first_dims.append(model_forward_kwargs["position_ids"].shape[1]) if not len(set(first_dims)) == 1: