infer_hidream_text2img_lora_bf16.py
J石页 committed on 2025-06-27 10:52 +08:00 · !1157 [Feature] Support for HiDream Image
# Copyright 2025 Huawei Technologies Co., Ltd
import os

import torch
from diffusers import HiDreamImagePipeline
from prompt_utils import run_inference
from transformer_patches import apply_patches
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

apply_patches()

MODEL_PATH = "HiDream-ai/HiDream-I1-Full"  # Model path for HiDream
FORTH_PATH = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # Pretrained model path for the fourth tokenizer & text encoder
OUTPUT_PATH = "./infer_result"  # Output path
LORA_WEIGHTS = "./logs/pytorch_lora_weights.safetensors"  # Path to the saved LoRA weights
DEVICE = "npu"

os.makedirs(OUTPUT_PATH, exist_ok=True)  # Create the output folder

tokenizer = PreTrainedTokenizerFast.from_pretrained(FORTH_PATH)
text_encoder = LlamaForCausalLM.from_pretrained(
    FORTH_PATH,
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

pipe = HiDreamImagePipeline.from_pretrained(
    MODEL_PATH,
    tokenizer_4=tokenizer,
    text_encoder_4=text_encoder,
    torch_dtype=torch.bfloat16,
    local_files_only=True,
)

if os.path.exists(LORA_WEIGHTS):  # Load LoRA weights if they exist
    print(f"Loading LoRA weights from {LORA_WEIGHTS}")
    pipe.load_lora_weights(LORA_WEIGHTS)
else:
    print("LoRA weights not found. Using the base model.")

pipe = pipe.to(DEVICE)
pipe.enable_model_cpu_offload()

run_inference(pipe, OUTPUT_PATH)
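
For reference, prompt_utils is a repository-local module whose contents are not shown on this page. Below is a minimal sketch of what a run_inference-style helper could look like when driving a standard diffusers pipeline; the helper name run_inference_sketch, the example prompt, the seed, and the sampling parameters (num_inference_steps, guidance_scale) are illustrative assumptions, not the repository's actual values.

# Hypothetical sketch only: the real prompt_utils.run_inference may differ.
import os
import torch


def run_inference_sketch(pipe, output_path, prompts=None, seed=42):
    """Generate one image per prompt and save it under output_path (assumed interface)."""
    prompts = prompts or ["A photo of a cat holding a sign that says hello world"]
    generator = torch.Generator(device="cpu").manual_seed(seed)  # fixed seed for reproducibility
    for idx, prompt in enumerate(prompts):
        image = pipe(
            prompt,
            num_inference_steps=50,   # assumed sampling steps
            guidance_scale=5.0,       # assumed guidance strength
            generator=generator,
        ).images[0]
        image.save(os.path.join(output_path, f"result_{idx}.png"))

Used in place of run_inference above, such a helper would simply loop over the prompts, call the pipeline once per prompt, and write the resulting PNG files into OUTPUT_PATH.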