diff --git a/mindspeed_mm/models/diffusion/cogvideo_diffusion.py b/mindspeed_mm/models/diffusion/cogvideo_diffusion.py index 335979bbb94b897c119df31c76176205358d7141..e5d539df10f3413f5210ac249d867dc06251f54e 100644 --- a/mindspeed_mm/models/diffusion/cogvideo_diffusion.py +++ b/mindspeed_mm/models/diffusion/cogvideo_diffusion.py @@ -391,7 +391,7 @@ class CogVideoDiffusion(nn.Module): latent_model_input = torch.cat([latents] * 2) latent_model_input = self.diffusion.scale_model_input(latent_model_input, t) current_timestep = t.expand(latent_model_input.shape[0]) - model_kwargs["latents"] = latent_model_input.permute(0, 2, 1, 3, 4) + model_kwargs["latent"] = latent_model_input.permute(0, 2, 1, 3, 4) with torch.no_grad(): noise_pred = model(timestep=current_timestep, **model_kwargs) diff --git a/mindspeed_mm/models/diffusion/diffusers_scheduler.py b/mindspeed_mm/models/diffusion/diffusers_scheduler.py index cb19c1547d5e205c7d0612450a1e99301f838d4b..5307b2a27bf64c05252b978f12a100c8e67b1d8f 100644 --- a/mindspeed_mm/models/diffusion/diffusers_scheduler.py +++ b/mindspeed_mm/models/diffusion/diffusers_scheduler.py @@ -241,9 +241,9 @@ class DiffusersScheduler: current_timestep = t.expand(latent_model_input.shape[0]) if use_dynamic_cfg: # b t c h w -> b c t h w - model_kwargs["latents"] = latent_model_input.permute(0, 2, 1, 3, 4) + model_kwargs["latent"] = latent_model_input.permute(0, 2, 1, 3, 4) else: - model_kwargs["latents"] = latent_model_input + model_kwargs["latent"] = latent_model_input video_mask = torch.ones_like(latent_model_input)[:, 0] world_size = model_kwargs.get("world_size", 1) video_mask = video_mask.repeat(1, world_size, 1, 1) diff --git a/mindspeed_mm/models/predictor/dits/hunyuan_video_dit.py b/mindspeed_mm/models/predictor/dits/hunyuan_video_dit.py index 9828881d562b7ac902e16c38afa44e208faf076a..cdc46bfd29912872d3cb5ee0ca4b4bcf5c1ca0b2 100644 --- a/mindspeed_mm/models/predictor/dits/hunyuan_video_dit.py +++ b/mindspeed_mm/models/predictor/dits/hunyuan_video_dit.py @@ -400,27 +400,27 @@ class HunyuanVideoDiT(MultiModalModule): def forward( self, - x: torch.Tensor, + latent: torch.Tensor, timestep: torch.Tensor, - prompt: List[torch.Tensor], - prompt_mask: Union[torch.Tensor, List[torch.Tensor]] = None, + text_embed: List[torch.Tensor], + text_mask: Union[torch.Tensor, List[torch.Tensor]] = None, guidance: torch.Tensor = None, **kwargs ): - bs, _, ot, oh, ow = x.shape + bs, _, ot, oh, ow = latent.shape tt, th, tw = ( ot // self.patch_size[0], oh // self.patch_size[1], ow // self.patch_size[2], ) - x = x.to(self.dtype) + latent = latent.to(self.dtype) - if isinstance(prompt_mask, list): - prompt_mask = prompt_mask[0] - prompt_mask = prompt_mask.to(x.device) - prompt[0] = prompt[0].view(-1, prompt[0].shape[-2], prompt[0].shape[-1]).to(self.dtype) # B*N, seq_len, Dim - prompt[1] = prompt[1].to(self.dtype) # B, N, Dim - prompt_mask = prompt_mask.view(-1, prompt_mask.shape[-1]) # B*N, seqlen + if isinstance(text_mask, list): + text_mask = text_mask[0] + text_mask = text_mask.to(latent.device) + text_embed[0] = text_embed[0].view(-1, text_embed[0].shape[-2], text_embed[0].shape[-1]).to(self.dtype) # B*N, seq_len, Dim + text_embed[1] = text_embed[1].to(self.dtype) # B, N, Dim + text_mask = text_mask.view(-1, text_mask.shape[-1]) # B*N, seqlen # Prepare modulation vectors vec = self.time_in(timestep) @@ -435,7 +435,7 @@ class HunyuanVideoDiT(MultiModalModule): frist_frame_token_num = None # text modulation - vec_2 = self.vector_in(prompt[1]) + vec_2 = self.vector_in(text_embed[1]) vec = 
vec + vec_2 if self.i2v_condition_type == "token_replace": token_replace_vec = token_replace_vec + vec_2 @@ -444,17 +444,17 @@ class HunyuanVideoDiT(MultiModalModule): if self.guidance_embed: if guidance is None: guidance = torch.tensor( - [self.embeded_guidance_scale] * x.shape[0], + [self.embeded_guidance_scale] * latent.shape[0], dtype=torch.float32, ).to(vec.device).to(vec.dtype) * 1000.0 vec = vec + self.guidance_in(guidance) - img = self.img_in(x) + img = self.img_in(latent) if self.text_projection == "linear": - txt = self.txt_in(prompt[0]) + txt = self.txt_in(text_embed[0]) elif self.text_projection == "single_refiner": - txt = self.txt_in(prompt[0], timestep, prompt_mask if self.use_attention_mask else None) + txt = self.txt_in(text_embed[0], timestep, text_mask if self.use_attention_mask else None) else: raise NotImplementedError( f"Unsupported text_projection: {self.text_projection}" @@ -464,12 +464,12 @@ class HunyuanVideoDiT(MultiModalModule): img_seq_len = img.shape[1] # compute cu_squlens and max_seqlen for flash attention - cu_seqlens_q = _get_cu_seqlens(prompt_mask, img_seq_len) + cu_seqlens_q = _get_cu_seqlens(text_mask, img_seq_len) cu_seqlens_kv = cu_seqlens_q max_seqlen_q = img_seq_len + txt_seq_len max_seqlen_kv = max_seqlen_q - rope_sizes = list(x.shape)[-3:] + rope_sizes = list(latent.shape)[-3:] rope_sizes = [rope_sizes[i] // self.patch_size[i] for i in range(3)] freqs_cos, freqs_sin = get_nd_rotary_pos_embed( self.rope_dim_list, @@ -481,7 +481,7 @@ class HunyuanVideoDiT(MultiModalModule): freqs_sin = freqs_sin.unsqueeze(0).unsqueeze(2).to(device=vec.device, dtype=vec.dtype) # b s n d if bs == 1: - txt = txt[:, :prompt_mask.sum()] + txt = txt[:, :text_mask.sum()] txt_seq_len = txt.shape[1] # RNG context @@ -542,14 +542,14 @@ class HunyuanVideoDiT(MultiModalModule): if self.sequence_parallel: txt = txt.transpose(0, 1).contiguous() - x = torch.cat([img, txt], dim=0) + latent = torch.cat([img, txt], dim=0) else: - x = torch.cat([img, txt], dim=1) + latent = torch.cat([img, txt], dim=1) # single_stream - x = self._checkpointed_forward( + latent = self._checkpointed_forward( "single_stream", - (x, ), + (latent,), vec, img_seq_len, cu_seqlens_q, @@ -579,13 +579,13 @@ class HunyuanVideoDiT(MultiModalModule): ) if self.sequence_parallel: txt = txt.transpose(0, 1).contiguous() - x = torch.cat([img, txt], dim=0) + latent = torch.cat([img, txt], dim=0) else: - x = torch.cat([img, txt], dim=1) + latent = torch.cat([img, txt], dim=1) for _, block in enumerate(self.single_blocks): - x = block( - x=x, + latent = block( + x=latent, vec=vec, img_len=img_seq_len, cu_seqlens_q=cu_seqlens_q, @@ -599,9 +599,9 @@ class HunyuanVideoDiT(MultiModalModule): )[0] if self.sequence_parallel: - img = x[:img_seq_len // mpu.get_tensor_model_parallel_world_size()] + img = latent[:img_seq_len // mpu.get_tensor_model_parallel_world_size()] else: - img = x[:, :img_seq_len] + img = latent[:, :img_seq_len] # --------------------- Final layer ------------ if self.enable_tensor_parallel: @@ -618,28 +618,28 @@ class HunyuanVideoDiT(MultiModalModule): else: shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1) - x = self.norm_final(img) * (1 + scale) + shift + latent = self.norm_final(img) * (1 + scale) + shift if self.enable_tensor_parallel: - x = self.proj_out(x)[0] + latent = self.proj_out(latent)[0] if self.sequence_parallel: - x = tensor_parallel.mappings.all_gather_last_dim_from_tensor_parallel_region(x) + latent = tensor_parallel.mappings.all_gather_last_dim_from_tensor_parallel_region(latent) 
else: - x = tensor_parallel.mappings.gather_from_tensor_model_parallel_region(x) + latent = tensor_parallel.mappings.gather_from_tensor_model_parallel_region(latent) else: - x = self.proj_out(x) + latent = self.proj_out(latent) if self.sequence_parallel: - x = x.transpose(0, 1).contiguous() # s b h -> b s h + latent = latent.transpose(0, 1).contiguous() # s b h -> b s h if self.context_parallel_algo is not None: - x = gather_forward_split_backward( - x, + latent = gather_forward_split_backward( + latent, mpu.get_context_parallel_group(), dim=1, grad_scale="up" ) - output = self.unpatchify(x, tt, th, tw) + output = self.unpatchify(latent, tt, th, tw) return output diff --git a/mindspeed_mm/models/predictor/dits/sat_dit.py b/mindspeed_mm/models/predictor/dits/sat_dit.py index 98c3aab82fb5c92ef170206e340cd2052ee0993b..f2fc96d29b6d37e780a3ae245e3c7dc1f836545e 100644 --- a/mindspeed_mm/models/predictor/dits/sat_dit.py +++ b/mindspeed_mm/models/predictor/dits/sat_dit.py @@ -231,26 +231,20 @@ class SatDiT(MultiModalModule): def forward( self, - latents: torch.Tensor, + latent: torch.Tensor, timestep: Optional[torch.Tensor] = None, - prompt: Optional[torch.Tensor] = None, + text_embed: Optional[torch.Tensor] = None, + text_mask: Optional[torch.Tensor] = None, video_mask: Optional[torch.Tensor] = None, - prompt_mask: Optional[torch.Tensor] = None, - added_cond_kwargs: Dict[str, torch.Tensor] = None, - class_labels: Optional[torch.Tensor] = None, - use_image_num: Optional[int] = 0, **kwargs ) -> torch.Tensor: """ Args: - latents: Shape (batch size, num latent pixels) if discrete, shape (batch size, channel, height, width) if continuous. + latent: Shape (batch size, num latent pixels) if discrete, shape (batch size, channel, height, width) if continuous. timestep: Used to indicate denoising step. Optional timestep to be applied as an embedding in AdaLayerNorm. - prompt: Conditional embeddings for cross attention layer. + text_embed: Conditional embeddings for cross attention layer. video_mask: An attention mask of shape (batch, key_tokens) is applied to latents. - prompt_mask: Cross-attention mask applied to prompt. - added_cond_kwargs: resolution or aspect_ratio. - class_labels: Used to indicate class labels conditioning. - use_image_num: The number of images use for trainning. + text_mask: Cross-attention mask applied to prompt. 
""" # RNG context @@ -260,8 +254,8 @@ class SatDiT(MultiModalModule): rng_context = nullcontext() if self.pre_process: - _, _, t, h, w = latents.shape - frames = t - use_image_num + _, _, t, h, w = latent.shape + frames = t vid_mask, img_mask = None, None prompt_vid_mask, prompt_img_mask = None, None @@ -270,25 +264,23 @@ class SatDiT(MultiModalModule): height, width = h // self.patch_size_h, w // self.patch_size_w if "masked_video" in kwargs.keys() and kwargs["masked_video"] is not None: - latents = torch.cat([latents, kwargs["masked_video"]], dim=1) + latent = torch.cat([latent, kwargs["masked_video"]], dim=1) - added_cond_kwargs = {"resolution": None, "aspect_ratio": None} - latents_vid, latents_img, prompt_vid, prompt_img, timestep_vid, timestep_img, \ - embedded_timestep_vid, embedded_timestep_img = self._operate_on_patched_inputs( - latents, prompt, timestep, frames) + latent_vid, latent_img, prompt_vid, prompt_img, timestep_vid = self._operate_on_patched_inputs( + latent, text_embed, timestep, frames) if self.concat_text_embed: - latents_vid = torch.cat((prompt_vid, latents_vid), dim=1) + latent_vid = torch.cat((prompt_vid, latent_vid), dim=1) if self.enable_sequence_parallelism or self.sequence_parallel: - latents_vid = latents_vid.transpose(0, 1).contiguous() + latent_vid = latent_vid.transpose(0, 1).contiguous() if self.enable_sequence_parallelism: - latents_vid = split_forward_gather_backward(latents_vid, mpu.get_context_parallel_group(), dim=0, + latent_vid = split_forward_gather_backward(latent_vid, mpu.get_context_parallel_group(), dim=0, grad_scale='down') if self.sequence_parallel: - latents_vid = tensor_parallel.scatter_to_sequence_parallel_region(latents_vid) + latent_vid = tensor_parallel.scatter_to_sequence_parallel_region(latent_vid) else: t, h, w = self.input_size # PP currently does not support dynamic resolution. - frames = t - use_image_num + frames = t vid_mask, img_mask = None, None prompt_vid_mask, prompt_img_mask = None, None @@ -296,8 +288,8 @@ class SatDiT(MultiModalModule): frames = ((frames - 1) // self.patch_size_t + 1) if frames % 2 == 1 else frames // self.patch_size_t # patchfy height, width = h // self.patch_size_h, w // self.patch_size_w - latents_vid = latents - prompt_vid = prompt + latent_vid = latent + prompt_vid = text_embed timestep_vid = timestep frames = torch.tensor(frames) @@ -316,14 +308,13 @@ class SatDiT(MultiModalModule): with rng_context: if self.recompute_granularity == "full": - if latents_vid is not None: - latents_vid = self._checkpointed_forward( - latents_vid, + if latent_vid is not None: + latent_vid = self._checkpointed_forward( + latent_vid, video_mask=vid_mask, prompt=prompt_vid, prompt_mask=prompt_vid_mask, timestep=timestep_vid, - class_labels=class_labels, frames=frames, height=height, width=width, @@ -332,14 +323,13 @@ class SatDiT(MultiModalModule): ) else: for block in self.videodit_blocks: - if latents_vid is not None: - latents_vid = block( - latents_vid, + if latent_vid is not None: + latent_vid = block( + latent_vid, video_mask=vid_mask, prompt=prompt_vid, prompt_mask=prompt_vid_mask, timestep=timestep_vid, - class_labels=class_labels, frames=frames, height=height, width=width, @@ -352,9 +342,9 @@ class SatDiT(MultiModalModule): # 3. 
Output if self.post_process: output_vid, output_img = None, None - if latents_vid is not None: + if latent_vid is not None: output_vid = self._get_output_for_patched_inputs( - latents=latents_vid, + latent=latent_vid, timestep=timestep_vid, height=height, width=width, @@ -368,7 +358,7 @@ class SatDiT(MultiModalModule): output = output_img return output, prompt_vid, timestep_vid else: - return latents_vid, prompt_vid, timestep_vid + return latent_vid, prompt_vid, timestep_vid def _get_block(self, layer_number): return self.videodit_blocks[layer_number] @@ -380,7 +370,6 @@ class SatDiT(MultiModalModule): prompt, prompt_mask, timestep, - class_labels, frames, height, width, @@ -413,7 +402,6 @@ class SatDiT(MultiModalModule): video_mask, prompt_mask, timestep, - class_labels, frames, height, width, @@ -432,7 +420,6 @@ class SatDiT(MultiModalModule): video_mask, prompt_mask, timestep, - class_labels, frames, height, width, @@ -447,7 +434,6 @@ class SatDiT(MultiModalModule): prompt=prompt, prompt_mask=prompt_mask, timestep=timestep, - class_labels=class_labels, frames=frames, height=height, width=width, @@ -469,47 +455,46 @@ class SatDiT(MultiModalModule): buffers = tuple(self.buffers()) return buffers[0].dtype - def _operate_on_patched_inputs(self, latents, prompt, timestep, frames): - b, _, t, h, w = latents.shape + def _operate_on_patched_inputs(self, latent, text_embed, timestep, frames): + b, _, t, h, w = latent.shape if self.rope is not None: - latents_vid, latents_img = self.patch_embed(latents.to(self.dtype), prompt, + latents_vid, latents_img = self.patch_embed(latent.to(self.dtype), text_embed, rope_T=t // self.patch_size[0], rope_H=h // self.patch_size[1], rope_W=w // self.patch_size[2]) _, seq_len, _ = latents_vid.shape - pos_emb = self.rope.position_embedding_forward(latents.to(self.dtype), + pos_emb = self.rope.position_embedding_forward(latent.to(self.dtype), seq_length=seq_len - self.rope.text_length) if pos_emb is not None: latents_vid = latents_vid + pos_emb else: - latents_vid, latents_img = self.patch_embed(latents.to(self.dtype), frames, + latents_vid, latents_img = self.patch_embed(latent.to(self.dtype), frames, rope_T=t // self.patch_size[0], rope_H=h // self.patch_size[1], rope_W=w // self.patch_size[2]) - timestep_vid, timestep_img = None, None - embedded_timestep_vid, embedded_timestep_img = None, None - prompt_vid, prompt_img = None, None + timestep_vid = None + text_embed_vid, text_embed_img = None, None if self.time_embed is not None: timestep_vid = self.time_embed(timestep) if self.ofs_embed_dim is not None: - ofs_emb = timestep_embedding(latents.new_full((1,), fill_value=2.0), self.ofs_embed_dim, dtype=self.dtype) + ofs_emb = timestep_embedding(latent.new_full((1,), fill_value=2.0), self.ofs_embed_dim, dtype=self.dtype) ofs_emb = self.ofs_embed(ofs_emb) timestep_vid = timestep_vid + ofs_emb if self.caption_projection is not None: - prompt = self.caption_projection(prompt) + text_embed = self.caption_projection(text_embed) if latents_vid is None: - prompt_img = rearrange(prompt, 'b 1 l d -> (b 1) l d') + text_embed_img = rearrange(text_embed, 'b 1 l d -> (b 1) l d') else: - prompt_vid = rearrange(prompt[:, :1], 'b 1 l d -> (b 1) l d') + text_embed_vid = rearrange(text_embed[:, :1], 'b 1 l d -> (b 1) l d') if latents_img is not None: - prompt_img = rearrange(prompt[:, 1:], 'b i l d -> (b i) l d') + text_embed_img = rearrange(text_embed[:, 1:], 'b i l d -> (b i) l d') - return latents_vid, latents_img, prompt_vid, prompt_img, timestep_vid, timestep_img, 
embedded_timestep_vid, embedded_timestep_img + return latents_vid, latents_img, text_embed_vid, text_embed_img, timestep_vid - def _get_output_for_patched_inputs(self, latents, timestep, height=None, width=None): - x = self.norm_final(latents) + def _get_output_for_patched_inputs(self, latent, timestep, height=None, width=None): + x = self.norm_final(latent) _scale_shift_table = self.adaLN_modulation(timestep)[0] if self.sequence_parallel: _scale_shift_table = tensor_parallel.mappings.all_gather_last_dim_from_tensor_parallel_region( @@ -532,11 +517,11 @@ class SatDiT(MultiModalModule): x = gather_forward_split_backward(x, mpu.get_context_parallel_group(), dim=1, grad_scale="up") x = x[:, self.rope.text_length:, :] x = self.proj_out_linear(x) - latents = x + latent = x # unpatchify - output = rearrange(latents, "b (t h w) (c o p q) -> b (t o) c (h p) (w q)", - b=latents.shape[0], h=height, w=width, + output = rearrange(latent, "b (t h w) (c o p q) -> b (t o) c (h p) (w q)", + b=latent.shape[0], h=height, w=width, o=self.patch_size_t, p=self.patch_size_h, q=self.patch_size_w, c=self.out_channels).transpose(1, 2) return output @@ -757,14 +742,12 @@ class VideoDiTBlock(nn.Module): video_mask: Optional[torch.Tensor] = None, prompt_mask: Optional[torch.Tensor] = None, timestep: Dict[str, torch.Tensor] = None, - class_labels: Optional[torch.Tensor] = None, frames: torch.int64 = None, height: torch.int64 = None, width: torch.int64 = None, rotary_pos_emb=None, text_length: torch.int64 = None, checkpoint_skip_core_attention: bool = False, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.FloatTensor: # before_self_attention if checkpoint_skip_core_attention: diff --git a/mindspeed_mm/models/predictor/dits/stdit.py b/mindspeed_mm/models/predictor/dits/stdit.py index d6cb03e595fb6847831e493910a25db0f0e6f9a6..867aca853fc4ea5635534513dbee62e7f14ec4ec 100644 --- a/mindspeed_mm/models/predictor/dits/stdit.py +++ b/mindspeed_mm/models/predictor/dits/stdit.py @@ -341,23 +341,23 @@ class STDiT(MultiModalModule): d_s=self.num_spatial, ) - def forward(self, video, timestep, prompt, prompt_mask=None, **kwargs): + def forward(self, latent, timestep, text_embed, text_mask=None, **kwargs): """ Forward pass of STDiT. 
Args: - video (torch.Tensor): latent representation of video; of shape [B, C, T, H, W] + latent (torch.Tensor): latent representation of video; of shape [B, C, T, H, W] timestep (torch.Tensor): diffusion time steps; of shape [B] - prompt (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C] - prompt_mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token] + text_embed (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C] + text_mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token] Returns: x (torch.Tensor): output latent representation; of shape [B, C, T, H, W] """ if mpu.is_pipeline_first_stage(): - x = video.to(self.dtype) + x = latent.to(self.dtype) timestep = timestep.to(self.dtype) - y = prompt.to(self.dtype) - mask = prompt_mask + y = text_embed.to(self.dtype) + mask = text_mask # embedding x = self.x_embedder(x) # [B, N, C] diff --git a/mindspeed_mm/models/predictor/dits/stdit3.py b/mindspeed_mm/models/predictor/dits/stdit3.py index f14dee03f3d852e09caa900c8233ad14b808746d..38586e54a8bee0f1419369c7c05f040889b92871 100644 --- a/mindspeed_mm/models/predictor/dits/stdit3.py +++ b/mindspeed_mm/models/predictor/dits/stdit3.py @@ -344,122 +344,122 @@ class STDiT3(MultiModalModule): y = y.squeeze(1).view(1, -1, self.hidden_size) return y, y_lens - def forward(self, video, timestep, prompt, prompt_mask=None, video_mask=None, fps=None, height=None, width=None, **kwargs): + def forward(self, latent, timestep, text_embed, text_mask=None, video_mask=None, **kwargs): dtype = self.x_embedder.proj.weight.dtype - B = video.size(0) - video = video.to(dtype) + B = latent.size(0) + latent = latent.to(dtype) timestep = timestep.to(dtype) - prompt = prompt.to(dtype) + text_embed = text_embed.to(dtype) # === get pos embed === - _, _, Tx, Hx, Wx = video.size() - T, H, W = self.get_dynamic_size(video) + _, _, Tx, Hx, Wx = latent.size() + T, H, W = self.get_dynamic_size(latent) S = H * W base_size = round(S ** 0.5) - resolution_sq = (height[0].item() * width[0].item()) ** 0.5 + resolution_sq = (kwargs["height"][0].item() * kwargs["width"][0].item()) ** 0.5 scale = resolution_sq / self.input_sq_size - pos_emb = self.pos_embed(video, H, W, scale=scale, base_size=base_size) + pos_emb = self.pos_embed(latent, H, W, scale=scale, base_size=base_size) # === get timestep embed === - t = self.t_embedder(timestep, dtype=video.dtype) # [B, C] - fps = self.fps_embedder(fps.unsqueeze(1), B) + t = self.t_embedder(timestep, dtype=latent.dtype) # [B, C] + fps = self.fps_embedder(kwargs["fps"].unsqueeze(1), B) t = t + fps t_mlp = self.t_block(t) t0 = t0_mlp = None if video_mask is not None: t0_timestep = torch.zeros_like(timestep) - t0 = self.t_embedder(t0_timestep, dtype=video.dtype) + t0 = self.t_embedder(t0_timestep, dtype=latent.dtype) t0 = t0 + fps t0_mlp = self.t_block(t0) # === get y embed === if self.skip_y_embedder: - y_lens = prompt_mask + y_lens = text_mask if isinstance(y_lens, torch.Tensor): y_lens = y_lens.long().tolist() else: - prompt, y_lens = self.encode_text(prompt, prompt_mask) + text_embed, y_lens = self.encode_text(text_embed, text_mask) # === get x embed === - video = self.x_embedder(video) # [B, N, C] - video = rearrange(video, "B (T S) C -> B T S C", T=T, S=S) - video = video + pos_emb + latent = self.x_embedder(latent) # [B, N, C] + latent = rearrange(latent, "B (T S) C -> B T S C", T=T, S=S) + latent = latent + pos_emb # === process video mask === if video_mask is not None: - video_mask = video_mask[:, :, None, 
None].expand(B, T, S, video.shape[-1]).contiguous() + video_mask = video_mask[:, :, None, None].expand(B, T, S, latent.shape[-1]).contiguous() # shard over the sequence dim if sp is enabled if self.enable_sequence_parallelism: - s_split_sizes = cal_split_sizes(dim_size=video.size(2), world_size=self.sp_size) - t_split_sizes = cal_split_sizes(dim_size=video.size(1), world_size=self.sp_size) - video = split_forward_gather_backward(video, mpu.get_context_parallel_group(), - dim=1, grad_scale="down", split_sizes=t_split_sizes) + s_split_sizes = cal_split_sizes(dim_size=latent.size(2), world_size=self.sp_size) + t_split_sizes = cal_split_sizes(dim_size=latent.size(1), world_size=self.sp_size) + latent = split_forward_gather_backward(latent, mpu.get_context_parallel_group(), + dim=1, grad_scale="down", split_sizes=t_split_sizes) sp_rank = mpu.get_context_parallel_rank() if video_mask is not None: video_mask_split_s = video_mask[:, :, sum(s_split_sizes[:sp_rank]): sum(s_split_sizes[:sp_rank + 1]), :] video_mask_split_t = video_mask[:, sum(t_split_sizes[:sp_rank]): sum(t_split_sizes[:sp_rank + 1]), :, :] - video_mask_split_s = video_mask_split_s.view(B, -1, video.shape[-1]).to(video.dtype) - video_mask_split_t = video_mask_split_t.view(B, -1, video.shape[-1]).to(video.dtype) + video_mask_split_s = video_mask_split_s.view(B, -1, latent.shape[-1]).to(latent.dtype) + video_mask_split_t = video_mask_split_t.view(B, -1, latent.shape[-1]).to(latent.dtype) else: video_mask_split_s, video_mask_split_t = None, None - T, S = video.size(1), video.size(2) + T, S = latent.size(1), latent.size(2) if video_mask is not None: - video_mask = video_mask.view(B, -1, video.shape[-1]).to(video.dtype) - video = rearrange(video, "B T S C -> B (T S) C", T=T, S=S) + video_mask = video_mask.view(B, -1, latent.shape[-1]).to(latent.dtype) + latent = rearrange(latent, "B T S C -> B (T S) C", T=T, S=S) # === blocks === for i, (spatial_block, temporal_block) in enumerate(zip(self.spatial_blocks, self.temporal_blocks)): if self.enable_sequence_parallelism: # === spatial block === - video = auto_grad_checkpoint(spatial_block, video, prompt, t_mlp, y_lens, video_mask_split_t, t0_mlp, T, S) + latent = auto_grad_checkpoint(spatial_block, latent, text_embed, t_mlp, y_lens, video_mask_split_t, t0_mlp, T, S) # split T, gather S - video = rearrange(video, "B (T S) C -> B T S C", T=T, S=S) - video = all_to_all(video, mpu.get_context_parallel_group(), + latent = rearrange(latent, "B (T S) C -> B T S C", T=T, S=S) + latent = all_to_all(latent, mpu.get_context_parallel_group(), scatter_dim=2, scatter_sizes=s_split_sizes, gather_dim=1, gather_sizes=t_split_sizes) - T, S = video.size(1), video.size(2) - video = rearrange(video, "B T S C -> B (T S) C", T=T, S=S) + T, S = latent.size(1), latent.size(2) + latent = rearrange(latent, "B T S C -> B (T S) C", T=T, S=S) # === temporal block === - video = auto_grad_checkpoint(temporal_block, video, prompt, t_mlp, y_lens, video_mask_split_s, t0_mlp, T, S) + latent = auto_grad_checkpoint(temporal_block, latent, text_embed, t_mlp, y_lens, video_mask_split_s, t0_mlp, T, S) if i == self.depth - 1: #final block break else: # split s, gather t - video = rearrange(video, "B (T S) C -> B T S C", T=T, S=S) - video = all_to_all(video, mpu.get_context_parallel_group(), - scatter_dim=1, scatter_sizes=t_split_sizes, - gather_dim=2, gather_sizes=s_split_sizes) - T, S = video.size(1), video.size(2) - video = rearrange(video, "B T S C -> B (T S) C", T=T, S=S) + latent = rearrange(latent, "B (T S) C -> B T S C", T=T, 
S=S) + latent = all_to_all(latent, mpu.get_context_parallel_group(), + scatter_dim=1, scatter_sizes=t_split_sizes, + gather_dim=2, gather_sizes=s_split_sizes) + T, S = latent.size(1), latent.size(2) + latent = rearrange(latent, "B T S C -> B (T S) C", T=T, S=S) else: - video = auto_grad_checkpoint(spatial_block, video, prompt, t_mlp, y_lens, video_mask, t0_mlp, T, S) - video = auto_grad_checkpoint(temporal_block, video, prompt, t_mlp, y_lens, video_mask, t0_mlp, T, S) + latent = auto_grad_checkpoint(spatial_block, latent, text_embed, t_mlp, y_lens, video_mask, t0_mlp, T, S) + latent = auto_grad_checkpoint(temporal_block, latent, text_embed, t_mlp, y_lens, video_mask, t0_mlp, T, S) if self.enable_sequence_parallelism: # === final layer === - video = self.final_layer(video, t, video_mask_split_s, t0, T, S) - video = rearrange(video, "B (T S) C -> B T S C", T=T, S=S) - video = gather_forward_split_backward(video, mpu.get_context_parallel_group(), - dim=2, grad_scale="up", gather_sizes=s_split_sizes) - S = video.size(2) - video = rearrange(video, "B T S C -> B (T S) C", T=T, S=S) + latent = self.final_layer(latent, t, video_mask_split_s, t0, T, S) + latent = rearrange(latent, "B (T S) C -> B T S C", T=T, S=S) + latent = gather_forward_split_backward(latent, mpu.get_context_parallel_group(), + dim=2, grad_scale="up", gather_sizes=s_split_sizes) + S = latent.size(2) + latent = rearrange(latent, "B T S C -> B (T S) C", T=T, S=S) else: # === final layer === - video = self.final_layer(video, t, video_mask, t0, T, S) - video = self.unpatchify(video, T, H, W, Tx, Hx, Wx) + latent = self.final_layer(latent, t, video_mask, t0, T, S) + latent = self.unpatchify(latent, T, H, W, Tx, Hx, Wx) # cast to float32 for better accuracy - video = video.to(torch.float32) - return video + latent = latent.to(torch.float32) + return latent def unpatchify(self, x, N_t, N_h, N_w, R_t, R_h, R_w): """ diff --git a/mindspeed_mm/models/predictor/dits/step_video_dit.py b/mindspeed_mm/models/predictor/dits/step_video_dit.py index a705e4020277d9ede4b8349c588bcc7e47234164..4fd9d4239f585cee95b90311726ecb5b55cbcae4 100644 --- a/mindspeed_mm/models/predictor/dits/step_video_dit.py +++ b/mindspeed_mm/models/predictor/dits/step_video_dit.py @@ -160,11 +160,11 @@ class StepVideoDiT(MultiModalModule): def forward( self, - hidden_states: torch.Tensor, + latent: torch.Tensor, timestep: Optional[torch.LongTensor] = None, - prompt: Optional[list] = None, + text_embed: Optional[list] = None, + text_mask: Optional[torch.Tensor] = None, added_cond_kwargs: Dict[str, torch.Tensor] = None, - prompt_mask: Optional[torch.Tensor] = None, fps: torch.Tensor = None, **kwargs ): @@ -175,45 +175,45 @@ class StepVideoDiT(MultiModalModule): rng_context = nullcontext() if self.pre_process: - if hidden_states.ndim != 5: + if latent.ndim != 5: raise ValueError("hidden_states's shape should be (bsz, f, ch, h ,w)") - encoder_hidden_states = prompt[0]# b 1 s h - encoder_hidden_states_2 = prompt[1]# b 1 s h + encoder_hidden_states = text_embed[0]# b 1 s h + encoder_hidden_states_2 = text_embed[1]# b 1 s h motion_score = kwargs.get("motion_score", 5.0) condition_hidden_states = kwargs.get("image_latents") # Only retain stepllm's mask - if isinstance(prompt_mask, list): - encoder_attention_mask = prompt_mask[0] + if isinstance(text_mask, list): + encoder_attention_mask = text_mask[0] # Padding 1 on the mask of the stepllm len_clip = encoder_hidden_states_2.shape[2] encoder_attention_mask = encoder_attention_mask.squeeze(1).to( - hidden_states.device) # 
stepchat_tokenizer_mask: b 1 s => b s + latent.device) # stepchat_tokenizer_mask: b 1 s => b s encoder_attention_mask = torch.nn.functional.pad(encoder_attention_mask, (len_clip, 0), value=1) # pad attention_mask with clip's length - bsz, frame, _, height, width = hidden_states.shape + bsz, frame, _, height, width = latent.shape if mpu.get_context_parallel_world_size() > 1: frame //= mpu.get_context_parallel_world_size() - hidden_states = split_forward_gather_backward(hidden_states, mpu.get_context_parallel_group(), dim=1, - grad_scale='down') + latent = split_forward_gather_backward(latent, mpu.get_context_parallel_group(), dim=1, + grad_scale='down') height, width = height // self.patch_size, width // self.patch_size - hidden_states = self.patchfy(hidden_states, condition_hidden_states) - len_frame = hidden_states.shape[1] + latent = self.patchfy(latent, condition_hidden_states) + len_frame = latent.shape[1] if self.use_additional_conditions: if condition_hidden_states is not None: added_cond_kwargs = { - "motion_score": torch.tensor([motion_score], device=hidden_states.device, - dtype=hidden_states.dtype).repeat(bsz) + "motion_score": torch.tensor([motion_score], device=latent.device, + dtype=latent.dtype).repeat(bsz) } else: added_cond_kwargs = { - "resolution": torch.tensor([(height, width)] * bsz, device=hidden_states.device, - dtype=hidden_states.dtype), - "nframe": torch.tensor([frame] * bsz, device=hidden_states.device, dtype=hidden_states.dtype), + "resolution": torch.tensor([(height, width)] * bsz, device=latent.device, + dtype=latent.dtype), + "nframe": torch.tensor([frame] * bsz, device=latent.device, dtype=latent.dtype), "fps": fps } else: @@ -228,26 +228,26 @@ class StepVideoDiT(MultiModalModule): clip_embedding = self.clip_projection(encoder_hidden_states_2) encoder_hidden_states = torch.cat([clip_embedding, encoder_hidden_states], dim=2) - hidden_states = rearrange(hidden_states, '(b f) l d-> b (f l) d', b=bsz, f=frame, l=len_frame).contiguous() + latent = rearrange(latent, '(b f) l d-> b (f l) d', b=bsz, f=frame, l=len_frame).contiguous() encoder_hidden_states, attn_mask = self.prepare_attn_mask(encoder_attention_mask, encoder_hidden_states, q_seqlen=frame * len_frame) # Rotary positional embeddings - rotary_pos_emb = self.rope(bsz, frame * mpu.get_context_parallel_world_size(), height, width, hidden_states.device)# s b 1 d + rotary_pos_emb = self.rope(bsz, frame * mpu.get_context_parallel_world_size(), height, width, latent.device)# s b 1 d if mpu.get_context_parallel_world_size() > 1: rotary_pos_emb = rotary_pos_emb.chunk(mpu.get_context_parallel_world_size(), dim=0)[mpu.get_context_parallel_rank()] else: - encoder_hidden_states = prompt - attn_mask = prompt_mask.to(torch.bool) + encoder_hidden_states = text_embed + attn_mask = text_mask.to(torch.bool) embedded_timestep = kwargs["embedded_timestep"] rotary_pos_emb = kwargs["rotary_pos_emb"] bsz, frame, height, width, len_frame = kwargs["batch_size"], kwargs["frames"], kwargs["h"], kwargs["w"], kwargs["len_frame"] with rng_context: if self.recompute_granularity == "full": - hidden_states = self._checkpointed_forward( - hidden_states, + latent = self._checkpointed_forward( + latent, encoder_hidden_states, timestep, attn_mask, @@ -255,33 +255,33 @@ class StepVideoDiT(MultiModalModule): ) else: for _, block in zip(self.global_layer_idx, self.transformer_blocks): - hidden_states = block( - hidden_states, + latent = block( + latent, encoder_hidden_states, timestep, attn_mask, rotary_pos_emb ) - output = hidden_states + output 
= latent if self.post_process: - hidden_states = rearrange(hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame) + latent = rearrange(latent, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame) embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame).contiguous() shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) + latent = self.norm_out(latent) # Modulation - hidden_states = hidden_states * (1 + scale) + shift - hidden_states = self.proj_out(hidden_states) + latent = latent * (1 + scale) + shift + latent = self.proj_out(latent) # unpatchify - hidden_states = hidden_states.reshape( + latent = latent.reshape( shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) ) - hidden_states = rearrange(hidden_states, 'n h w p q c -> n c h p w q') - output = hidden_states.reshape( + latent = rearrange(latent, 'n h w p q c -> n c h p w q') + output = latent.reshape( shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) ) diff --git a/mindspeed_mm/models/predictor/dits/video_dit.py b/mindspeed_mm/models/predictor/dits/video_dit.py index dba37679ebb38c9a6d2f4f1dba832f2fbdebb220..e5ca8e323c2d60767b9e63234803c62f986ea2c3 100644 --- a/mindspeed_mm/models/predictor/dits/video_dit.py +++ b/mindspeed_mm/models/predictor/dits/video_dit.py @@ -181,11 +181,11 @@ class VideoDiT(MultiModalModule): def forward( self, - latents: torch.Tensor, + latent: torch.Tensor, timestep: Optional[torch.Tensor] = None, - prompt: Optional[torch.Tensor] = None, + text_embed: Optional[torch.Tensor] = None, + text_mask: Optional[torch.Tensor] = None, video_mask: Optional[torch.Tensor] = None, - prompt_mask: Optional[torch.Tensor] = None, added_cond_kwargs: Dict[str, torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, use_image_num: Optional[int] = 0, @@ -193,19 +193,19 @@ class VideoDiT(MultiModalModule): ) -> torch.Tensor: """ Args: - latents: Shape (batch size, num latent pixels) if discrete, shape (batch size, channel, height, width) if continuous. + latent: Shape (batch size, num latent pixels) if discrete, shape (batch size, channel, height, width) if continuous. timestep: Used to indicate denoising step. Optional timestep to be applied as an embedding in AdaLayerNorm. - prompt: Conditional embeddings for cross attention layer. + text_embed: Conditional embeddings for cross attention layer. video_mask: An attention mask of shape (batch, key_tokens) is applied to latents. - prompt_mask: Cross-attention mask applied to prompt. + text_mask: Cross-attention mask applied to prompt. added_cond_kwargs: resolution or aspect_ratio. class_labels: Used to indicate class labels conditioning. use_image_num: The number of images use for trainning. 
""" - batch_size, _, t, _, _ = latents.shape + batch_size, _, t, _, _ = latent.shape frames = t - use_image_num vid_mask, img_mask = None, None - prompt_mask = prompt_mask.view(batch_size, -1, prompt_mask.shape[-1]) + text_mask = text_mask.view(batch_size, -1, text_mask.shape[-1]) if video_mask is not None and video_mask.ndim == 4: video_mask = video_mask.to(self.dtype) vid_mask = video_mask[:, :frames] # [b, frames, h, w] @@ -230,13 +230,13 @@ class VideoDiT(MultiModalModule): img_mask = vid_mask vid_mask = None # convert prompt_mask to a bias the same way we do for video_mask - if prompt_mask is not None and prompt_mask.ndim == 3: - prompt_mask = (1 - prompt_mask.to(self.dtype)) * -10000.0 - in_t = prompt_mask.shape[1] - prompt_vid_mask = prompt_mask[:, :in_t - use_image_num] + if text_mask is not None and text_mask.ndim == 3: + text_mask = (1 - text_mask.to(self.dtype)) * -10000.0 + in_t = text_mask.shape[1] + prompt_vid_mask = text_mask[:, :in_t - use_image_num] prompt_vid_mask = rearrange(prompt_vid_mask, 'b 1 l -> (b 1) 1 l') if prompt_vid_mask.numel() > 0 else None - prompt_img_mask = prompt_mask[:, in_t - use_image_num:] + prompt_img_mask = text_mask[:, in_t - use_image_num:] prompt_img_mask = rearrange(prompt_img_mask, 'b i l -> (b i) 1 l') if prompt_img_mask.numel() > 0 else None if frames == 1 and use_image_num == 0 and not self.enable_context_parallelism: @@ -258,15 +258,15 @@ class VideoDiT(MultiModalModule): # 1. Input frames = ((frames - 1) // self.patch_size_t + 1) if frames % 2 == 1 else frames // self.patch_size_t # patchfy - height, width = latents.shape[-2] // self.patch_size_h, latents.shape[-1] // self.patch_size_w + height, width = latent.shape[-2] // self.patch_size_h, latent.shape[-1] // self.patch_size_w added_cond_kwargs = {"resolution": None, "aspect_ratio": None} latents_vid, latents_img, prompt_vid, prompt_img, timestep_vid, timestep_img, \ embedded_timestep_vid, embedded_timestep_img = \ self._operate_on_patched_inputs( - latents=latents, - prompt=prompt, + latents=latent, + prompt=text_embed, timestep=timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, @@ -292,7 +292,7 @@ class VideoDiT(MultiModalModule): width = torch.tensor(width) # Rotary positional embeddings - rotary_pos_emb = self.rope(batch_size, frames, height, width, latents.device)# s b 1 d + rotary_pos_emb = self.rope(batch_size, frames, height, width, latent.device)# s b 1 d if mpu.get_context_parallel_world_size() > 1: rotary_pos_emb = rotary_pos_emb.chunk(mpu.get_context_parallel_world_size(), dim=0)[mpu.get_context_parallel_rank()] diff --git a/mindspeed_mm/models/predictor/dits/video_dit_sparse.py b/mindspeed_mm/models/predictor/dits/video_dit_sparse.py index 085835d5be675fe2cfcef792539299883788016d..1d1d26916de62bdacae37997fe43542b6b2af2b8 100644 --- a/mindspeed_mm/models/predictor/dits/video_dit_sparse.py +++ b/mindspeed_mm/models/predictor/dits/video_dit_sparse.py @@ -229,11 +229,11 @@ class VideoDitSparse(MultiModalModule): def forward( self, - latents: torch.Tensor, + latent: torch.Tensor, timestep: Optional[torch.Tensor] = None, - prompt: Optional[torch.Tensor] = None, + text_embed: Optional[torch.Tensor] = None, + text_mask: Optional[torch.Tensor] = None, video_mask: Optional[torch.Tensor] = None, - prompt_mask: Optional[torch.Tensor] = None, **kwargs ): # RNG context. 
@@ -244,31 +244,31 @@ class VideoDitSparse(MultiModalModule): if self.pre_process: # pre_process latents - batch_size, c, frames, h, w = latents.shape + batch_size, c, frames, h, w = latent.shape if mpu.get_context_parallel_world_size() > 1: frames //= mpu.get_context_parallel_world_size() - latents = split_forward_gather_backward(latents, mpu.get_context_parallel_group(), dim=2, - grad_scale='down') - prompt = split_forward_gather_backward(prompt, mpu.get_context_parallel_group(), - dim=2, grad_scale='down') + latent = split_forward_gather_backward(latent, mpu.get_context_parallel_group(), dim=2, + grad_scale='down') + text_embed = split_forward_gather_backward(text_embed, mpu.get_context_parallel_group(), + dim=2, grad_scale='down') - latents, prompt, timestep, embedded_timestep = self._operate_on_patched_inputs( - latents, prompt, timestep, batch_size, **kwargs + latent, text_embed, timestep, embedded_timestep = self._operate_on_patched_inputs( + latent, text_embed, timestep, batch_size, **kwargs ) - latents = rearrange(latents, 'b s h -> s b h', b=batch_size).contiguous() - prompt = rearrange(prompt, 'b s h -> s b h', b=batch_size).contiguous() + latent = rearrange(latent, 'b s h -> s b h', b=batch_size).contiguous() + text_embed = rearrange(text_embed, 'b s h -> s b h', b=batch_size).contiguous() timestep = timestep.view(batch_size, 6, -1).transpose(0, 1).contiguous() if self.sequence_parallel: - latents = tensor_parallel.scatter_to_sequence_parallel_region(latents) - prompt = tensor_parallel.scatter_to_sequence_parallel_region(prompt) + latent = tensor_parallel.scatter_to_sequence_parallel_region(latent) + text_embed = tensor_parallel.scatter_to_sequence_parallel_region(text_embed) - prompt_mask = prompt_mask.view(batch_size, -1, prompt_mask.shape[-1]) + text_mask = text_mask.view(batch_size, -1, text_mask.shape[-1]) # convert encoder_attention_mask to a bias the same way we do for attention_mask - if prompt_mask is not None and prompt_mask.ndim == 3: + if text_mask is not None and text_mask.ndim == 3: # b, 1, l - prompt_mask = (1 - prompt_mask.to(self.dtype)) * -10000.0 + text_mask = (1 - text_mask.to(self.dtype)) * -10000.0 else: embedded_timestep = kwargs['embedded_timestep'] batch_size, c, frames, h, w = kwargs['batch_size'], kwargs['c'], kwargs['frames'], kwargs['h'], kwargs['w'] @@ -279,12 +279,12 @@ class VideoDitSparse(MultiModalModule): frames, height, width = torch.tensor(frames), torch.tensor(height), torch.tensor(width) # Rotary positional embeddings - rotary_pos_emb = self.rope(batch_size, frames * mpu.get_context_parallel_world_size(), height, width, latents.device)# s b 1 d + rotary_pos_emb = self.rope(batch_size, frames * mpu.get_context_parallel_world_size(), height, width, latent.device)# s b 1 d if mpu.get_context_parallel_world_size() > 1: rotary_pos_emb = rotary_pos_emb.chunk(mpu.get_context_parallel_world_size(), dim=0)[mpu.get_context_parallel_rank()] origin_video_mask = video_mask.clone().detach().to(self.dtype) - origin_prompt_mask = prompt_mask.clone().detach().to(self.dtype) + origin_prompt_mask = text_mask.clone().detach().to(self.dtype) if video_mask is not None and video_mask.ndim == 4: video_mask = video_mask.to(self.dtype) @@ -299,19 +299,19 @@ class VideoDitSparse(MultiModalModule): sparse_mask = {} for sparse_n in [1, 4]: - sparse_mask[sparse_n] = self.prepare_sparse_mask(video_mask, prompt_mask, sparse_n) + sparse_mask[sparse_n] = self.prepare_sparse_mask(video_mask, text_mask, sparse_n) if (video_mask == 0).all(): video_mask = None with 
rng_context: if self.recompute_granularity == "full": - latents = self._checkpointed_forward( + latent = self._checkpointed_forward( sparse_mask, - latents, + latent, video_mask=video_mask, - prompt=prompt, - prompt_mask=prompt_mask, + prompt=text_embed, + prompt_mask=text_mask, timestep=timestep, frames=frames, height=height, @@ -322,19 +322,19 @@ class VideoDitSparse(MultiModalModule): for i, block in zip(self.global_layer_idx, self.videodit_sparse_blocks): if i > 1 and i < 30: try: - video_mask, prompt_mask = sparse_mask[block.self_atten.sparse_n][block.self_atten.sparse_group] + video_mask, text_mask = sparse_mask[block.self_atten.sparse_n][block.self_atten.sparse_group] except KeyError: - video_mask, prompt_mask = None, None + video_mask, text_mask = None, None else: try: - video_mask, prompt_mask = sparse_mask[1][block.self_atten.sparse_group] + video_mask, text_mask = sparse_mask[1][block.self_atten.sparse_group] except KeyError: - video_mask, prompt_mask = None, None - latents = block( - latents, + video_mask, text_mask = None, None + latent = block( + latent, video_mask=video_mask, - prompt=prompt, - prompt_mask=prompt_mask, + prompt=text_embed, + prompt_mask=text_mask, timestep=timestep, frames=frames, height=height, @@ -342,12 +342,12 @@ class VideoDitSparse(MultiModalModule): rotary_pos_emb=rotary_pos_emb, ) - output = latents + output = latent if self.post_process: # 3. Output output = self._get_output_for_patched_inputs( - latents=latents, + latents=latent, timestep=timestep, embedded_timestep=embedded_timestep, num_frames=frames, @@ -358,7 +358,7 @@ class VideoDitSparse(MultiModalModule): if mpu.get_context_parallel_world_size() > 1: output = gather_forward_split_backward(output, mpu.get_context_parallel_group(), dim=2, grad_scale='up') - rtn = (output, prompt, timestep, embedded_timestep, origin_video_mask, origin_prompt_mask) + rtn = (output, text_embed, timestep, embedded_timestep, origin_video_mask, origin_prompt_mask) return rtn def pipeline_set_prev_stage_tensor(self, input_tensor_list, extra_kwargs): diff --git a/mindspeed_mm/models/predictor/dits/wan_dit.py b/mindspeed_mm/models/predictor/dits/wan_dit.py index e24a3511591239cead9cf49c4d7e0ac340c63b37..63b2dd7791169d4aed45a02b0786d2e12d143d04 100644 --- a/mindspeed_mm/models/predictor/dits/wan_dit.py +++ b/mindspeed_mm/models/predictor/dits/wan_dit.py @@ -105,7 +105,7 @@ class WanDiT(MultiModalModule): self.recompute_num_layers_skip_core_attention = args.recompute_num_layers_skip_core_attention self.attention_async_offload = attention_async_offload self.fp32_calculate = fp32_calculate - + self.h2d_stream = torch_npu.npu.Stream() if attention_async_offload else None self.d2h_stream = torch_npu.npu.Stream() if attention_async_offload else None @@ -312,16 +312,16 @@ class WanDiT(MultiModalModule): def forward( self, - x: torch.Tensor, + latent: torch.Tensor, timestep: torch.Tensor, - prompt: torch.Tensor, - prompt_mask: torch.Tensor = None, + text_embed: torch.Tensor, + text_mask: torch.Tensor = None, i2v_clip_feature: torch.Tensor = None, i2v_vae_feature: torch.Tensor = None, **kwargs, ): if self.pre_process: - timestep = timestep.to(x[0].device) + timestep = timestep.to(latent[0].device) # time embeddings times = self.time_embedding( self.sinusoidal_embedding_1d(self.freq_dim, timestep) @@ -329,24 +329,24 @@ class WanDiT(MultiModalModule): time_emb = self.time_projection(times).unflatten(1, (6, self.hidden_size)) # prompt embeddings - bs = prompt.size(0) - prompt = prompt.view(bs, -1, prompt.size(-1)) - if prompt_mask 
is not None: - seq_lens = prompt_mask.view(bs, -1).sum(dim=-1) + bs = text_embed.size(0) + text_embed = text_embed.view(bs, -1, text_embed.size(-1)) + if text_mask is not None: + seq_lens = text_mask.view(bs, -1).sum(dim=-1) for i, seq_len in enumerate(seq_lens): - prompt[i, seq_len:] = 0 - prompt_emb = self.text_embedding(prompt) + text_embed[i, seq_len:] = 0 + prompt_emb = self.text_embedding(text_embed) # cat i2v & flf2v if self.model_type in ["i2v", "flf2v"]: - i2v_clip_feature = i2v_clip_feature.to(x) - i2v_vae_feature = i2v_vae_feature.to(x) - x = torch.cat([x, i2v_vae_feature], dim=1) # (b, c[x+y], f, h, w) + i2v_clip_feature = i2v_clip_feature.to(latent) + i2v_vae_feature = i2v_vae_feature.to(latent) + latent = torch.cat([latent, i2v_vae_feature], dim=1) # (b, c[x+y], f, h, w) clip_embedding = self.img_emb(i2v_clip_feature.float() if self.fp32_calculate else i2v_clip_feature.to(time_emb.dtype)) prompt_emb = torch.cat([clip_embedding, prompt_emb], dim=1) # patch embedding - patch_emb = self.patch_embedding(x.to(time_emb.dtype)) + patch_emb = self.patch_embedding(latent.to(time_emb.dtype)) embs, grid_sizes = self.patchify(patch_emb) @@ -363,7 +363,7 @@ class WanDiT(MultiModalModule): prompt_emb = kwargs['prompt_emb'] time_emb = kwargs['time_emb'] times = kwargs['times'] - embs = x + embs = latent rotary_pos_emb = self.rope(batch_size, frames, height, width) @@ -408,7 +408,7 @@ class WanDiT(MultiModalModule): embs_out = self.head(embs, times) out = self.unpatchify(embs_out, frames, height, width) - rtn = (out, prompt, prompt_emb, time_emb, times, prompt_mask) + rtn = (out, text_embed, prompt_emb, time_emb, times, text_mask) return rtn diff --git a/mindspeed_mm/models/sora_model.py b/mindspeed_mm/models/sora_model.py index 7292f4fc285eccd13a6973934dd43dbfee3cd266..ffab5a325b3397715ec70defb3184e09127fa455 100644 --- a/mindspeed_mm/models/sora_model.py +++ b/mindspeed_mm/models/sora_model.py @@ -186,10 +186,10 @@ class SoRAModel(nn.Module): output = self.predictor( predictor_input_latent, - timestep=predictor_timesteps, - prompt=predictor_prompt, + predictor_timesteps, + predictor_prompt, + predictor_prompt_mask, video_mask=predictor_video_mask, - prompt_mask=predictor_prompt_mask, **kwargs, ) diff --git a/mindspeed_mm/tasks/inference/pipeline/cogvideox_pipeline.py b/mindspeed_mm/tasks/inference/pipeline/cogvideox_pipeline.py index bc1342c020cd54e10851c66ae8e0852c80ea07ad..9af6343720094ea264f217171ffca43715d95d65 100644 --- a/mindspeed_mm/tasks/inference/pipeline/cogvideox_pipeline.py +++ b/mindspeed_mm/tasks/inference/pipeline/cogvideox_pipeline.py @@ -196,8 +196,8 @@ class CogVideoXPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): # 6 prepare extra step kwargs extra_step_kwargs = self.prepare_extra_step_kwargs(self.generator, eta) - model_kwargs = {"prompt": prompt_embeds.unsqueeze(1), - "prompt_mask": prompt_embeds_attention_mask, + model_kwargs = {"text_embed": prompt_embeds.unsqueeze(1), + "text_mask": prompt_embeds_attention_mask, "masked_video": image_latents} self.scheduler.guidance_scale = self.guidance_scale diff --git a/mindspeed_mm/tasks/inference/pipeline/hunyuanvideo_pipeline.py b/mindspeed_mm/tasks/inference/pipeline/hunyuanvideo_pipeline.py index b2ea6b78e9580cfbf77c3e41147ca47133fd654d..daddf752f0e965273a1db985047d33e829736675 100644 --- a/mindspeed_mm/tasks/inference/pipeline/hunyuanvideo_pipeline.py +++ b/mindspeed_mm/tasks/inference/pipeline/hunyuanvideo_pipeline.py @@ -204,7 +204,7 @@ class HunyuanVideoPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): 
guidance_scale=self.guidance_scale, guidance_rescale=self.guidance_rescale, embedded_guidance_scale=self.embedded_guidance_scale, - model_kwargs={"prompt": [prompt_embeds, prompt_embeds_2], "prompt_mask": prompt_mask}, + model_kwargs={"text_embed": [prompt_embeds, prompt_embeds_2], "text_mask": prompt_mask}, extra_step_kwargs=extra_step_kwargs, **i2v_kwargs ) diff --git a/mindspeed_mm/tasks/inference/pipeline/opensora_pipeline.py b/mindspeed_mm/tasks/inference/pipeline/opensora_pipeline.py index 76fa3156b0e94b68fe0e3dc158981f25ae555e70..125c2cd70c79ee5692b392b3e71ad6ef180a7d08 100644 --- a/mindspeed_mm/tasks/inference/pipeline/opensora_pipeline.py +++ b/mindspeed_mm/tasks/inference/pipeline/opensora_pipeline.py @@ -60,13 +60,13 @@ class OpenSoraPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): else: model_args = dict(prompt=prompt_embeds, prompt_mask=prompt_embeds_attention_mask) y_null = self.null(batch_size) - model_args["prompt"] = torch.cat([model_args["prompt"], y_null], 0) + model_args["text_embed"] = torch.cat([model_args["prompt"], y_null], 0) model_args["fps"] = torch.tensor([fps], device=device, dtype=dtype).repeat(batch_size) model_args["height"] = torch.tensor([self.height], device=device, dtype=dtype).repeat(batch_size) model_args["width"] = torch.tensor([self.width], device=device, dtype=dtype).repeat(batch_size) model_args["num_frames"] = torch.tensor([self.num_frames], device=device, dtype=dtype).repeat(batch_size) model_args["ar"] = torch.tensor([self.height / self.width], device=device, dtype=dtype).repeat(batch_size) - model_args["mask"] = prompt_embeds_attention_mask + model_args["text_mask"] = prompt_embeds_attention_mask # 5. Prepare latents image_size = (self.height, self.width) diff --git a/mindspeed_mm/tasks/inference/pipeline/opensoraplan_pipeline.py b/mindspeed_mm/tasks/inference/pipeline/opensoraplan_pipeline.py index a0079d7089c08ab9cb191c6ca3cbca6682d2fd9b..75884e538847ef00ebd4f2a705554622ce678cb5 100644 --- a/mindspeed_mm/tasks/inference/pipeline/opensoraplan_pipeline.py +++ b/mindspeed_mm/tasks/inference/pipeline/opensoraplan_pipeline.py @@ -3,7 +3,6 @@ import math import inspect import torch -import torch.nn as nn from mindspeed_mm.tasks.inference.pipeline.pipeline_base import MMPipeline from mindspeed_mm.tasks.inference.pipeline.pipeline_mixin.encode_mixin import MMEncoderMixin @@ -116,9 +115,9 @@ class OpenSoraPlanPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): clean_caption=clean_caption, use_prompt_preprocess=use_prompt_preprocess ) - + # multi text encoder - else: + else: prompt_embeds, prompt_embeds_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = [], [], [], [] for tokenizer, text_encoder in zip(self.tokenizer, self.text_encoder): prompt_embed, attention_mask, negative_prompt_embed, negative_attention_mask = self.encode_texts( @@ -140,7 +139,7 @@ class OpenSoraPlanPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): if do_classifier_free_guidance: if isinstance(prompt_embeds, list): prompt_embeds = [ - torch.cat([negative_prompt_embed, prompt_embed], dim=0) + torch.cat([negative_prompt_embed, prompt_embed], dim=0) for negative_prompt_embed, prompt_embed in zip(negative_prompt_embeds, prompt_embeds) ] prompt_embeds_attention_mask = [ @@ -180,10 +179,10 @@ class OpenSoraPlanPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): math.ceil(int(self.width) / self.vae.vae_scale_factor[2]), ) latents = self.prepare_latents( - shape, - generator=generator, - device=device, - dtype=prompt_embeds[0].dtype + shape, + 
generator=generator, + device=device, + dtype=prompt_embeds[0].dtype if isinstance(prompt_embeds, list) else prompt_embeds.dtype, latents=latents ) @@ -199,10 +198,10 @@ class OpenSoraPlanPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d if prompt_embeds_attention_mask.ndim == 2: prompt_embeds_attention_mask = prompt_embeds_attention_mask.unsqueeze(1) # b l -> b 1 l - model_kwargs = {"prompt": prompt_embeds, + model_kwargs = {"text_embed": prompt_embeds, "added_cond_kwargs": added_cond_kwargs, "enable_temporal_attentions": enable_temporal_attentions, - "prompt_mask": prompt_embeds_attention_mask, + "text_mask": prompt_embeds_attention_mask, "return_dict": False} if self.model_type == "i2v": model_kwargs.update(i2v_kwargs) diff --git a/mindspeed_mm/tasks/inference/pipeline/stepvideo_pipeline.py b/mindspeed_mm/tasks/inference/pipeline/stepvideo_pipeline.py index f9b634c3f2d8be171587b56b2076985ab2c06f3f..6220634ed032926b524e1ce569c21ff394ef80ff 100644 --- a/mindspeed_mm/tasks/inference/pipeline/stepvideo_pipeline.py +++ b/mindspeed_mm/tasks/inference/pipeline/stepvideo_pipeline.py @@ -271,7 +271,7 @@ class StepVideoPipeline(MMPipeline, InputsCheckMixin, MMEncoderMixin): device=device, do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=self.guidance_scale, - model_kwargs={"prompt": [prompt_embeds, clip_embedding], "prompt_mask": [prompt_mask, clip_mask], + model_kwargs={"text_embed": [prompt_embeds, clip_embedding], "text_mask": [prompt_mask, clip_mask], "motion_score": self.motion_score, "image_latents": image_latents} ) # predict model offload to 'cpu'
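
The patch above is a mechanical interface rename: predictor `forward` methods now take `latent`, `text_embed`, and `text_mask` instead of `x`/`latents`, `prompt`, and `prompt_mask`, and the schedulers/pipelines populate `model_kwargs` with the matching keys (`"latent"`, `"text_embed"`, `"text_mask"`). The sketch below is a minimal, self-contained illustration of that convention for external callers; `ToyPredictor`, `build_model_kwargs`, and every tensor shape are hypothetical stand-ins, and only the keyword names are taken from the diff.

```python
import torch
from torch import nn


class ToyPredictor(nn.Module):
    """Hypothetical stand-in for the DiT predictors touched by this patch.

    Only the argument names (latent, timestep, text_embed, text_mask, video_mask)
    mirror the renamed interface; the body is illustrative, not the real model.
    """

    def forward(self, latent, timestep, text_embed, text_mask=None,
                video_mask=None, **kwargs):
        # A real predictor would denoise `latent`; here we just echo its shape.
        return torch.zeros_like(latent)


def build_model_kwargs(prompt_embeds, prompt_mask, latent_model_input):
    # Keys follow the post-patch convention: "text_embed"/"text_mask" replace
    # "prompt"/"prompt_mask", and the noisy input is stored under "latent"
    # rather than "latents".
    return {
        "text_embed": prompt_embeds,
        "text_mask": prompt_mask,
        "latent": latent_model_input,
    }


if __name__ == "__main__":
    # Hypothetical shapes, chosen only to make the sketch runnable.
    bsz, c, t, h, w = 1, 4, 3, 8, 8
    latent = torch.randn(bsz, c, t, h, w)        # b c t h w
    text_embed = torch.randn(bsz, 1, 16, 32)     # b 1 seq dim
    text_mask = torch.ones(bsz, 1, 16)

    model = ToyPredictor()
    model_kwargs = build_model_kwargs(text_embed, text_mask, latent)
    current_timestep = torch.tensor([500])

    # Mirrors the scheduler-side call shape used in the patch:
    # noise_pred = model(timestep=current_timestep, **model_kwargs)
    noise_pred = model(timestep=current_timestep, **model_kwargs)
    print(noise_pred.shape)
```

The caller assembles keyword arguments exactly once and forwards them as `**model_kwargs`, which is why the rename must be applied consistently on both the pipeline/scheduler side and every predictor signature.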