From 1d3a0e3065ca3085871db316d1cfdbd8078c6213 Mon Sep 17 00:00:00 2001 From: Xin Wen Date: Wed, 12 Feb 2025 20:15:55 +0000 Subject: [PATCH] =?UTF-8?q?[bugfix]=20=E6=9B=BF=E6=8D=A2CaptionEmbedder?= =?UTF-8?q?=E4=B8=ADcuda=E8=B0=83=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Xin Wen --- mindspeed_mm/models/common/embeddings/common_embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspeed_mm/models/common/embeddings/common_embeddings.py b/mindspeed_mm/models/common/embeddings/common_embeddings.py index 4c6b6d73..2b3b8568 100644 --- a/mindspeed_mm/models/common/embeddings/common_embeddings.py +++ b/mindspeed_mm/models/common/embeddings/common_embeddings.py @@ -120,7 +120,7 @@ class CaptionEmbedder(nn.Module): Drops labels to enable classifier-free guidance. """ if force_drop_ids is None: - drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob + drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob else: drop_ids = force_drop_ids == 1 caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) -- Gitee