diff --git a/MindIE/MultiModal/Flux.1-DEV/inference_flux.py b/MindIE/MultiModal/Flux.1-DEV/inference_flux.py index 87465f8dcdcad172d97e6ee39b6ed6894098d117..44db5468ac6b089275e6776bb2b3c33ef257503c 100644 --- a/MindIE/MultiModal/Flux.1-DEV/inference_flux.py +++ b/MindIE/MultiModal/Flux.1-DEV/inference_flux.py @@ -187,8 +187,9 @@ def infer(args): T5_model, tensor_parallel={"tp_size": get_world_size()}, ) + T5_model.module.to("cpu") - pipe = FluxPipeline.from_pretrained(args.path, text_encoder_2=T5_model, torch_dtype=torch.bfloat16, local_files_only=True) + pipe = FluxPipeline.from_pretrained(args.path, torch_dtype=torch.bfloat16, local_files_only=True) if args.device_type == "A2-32g-single": torch.npu.set_device(args.device_id) @@ -198,6 +199,8 @@ def infer(args): pipe.to(f"npu:{args.device_id}") else: pipe.to(f"npu:{local_rank}") + pipe.text_encoder_2.to("cpu") + pipe.text_encoder_2 = T5_model.module.to(f"npu:{local_rank}") if args.use_cache: d_stream_config = CacheConfig( @@ -227,6 +230,10 @@ def infer(args): method="dit_block_cache", blocks_count=19, steps_count=args.infer_steps, + step_start=args.infer_steps, + step_interval=2, + block_start=18, + block_end=18, ) d_stream_agent = CacheAgent(d_stream_config) pipe.transformer.d_stream_agent = d_stream_agent @@ -234,6 +241,10 @@ def infer(args): method="dit_block_cache", blocks_count=38, steps_count=args.infer_steps, + step_start=args.infer_steps, + step_interval=2, + block_start=37, + block_end=37, ) s_stream_agent = CacheAgent(s_stream_config) pipe.transformer.s_stream_agent = s_stream_agent