diff --git a/README.md b/README.md index dfd3d10fdcf0dc6a8812393e22f0aa82bd368fda..02b6701134ae7f7ba1c96f46207ae21f2dc818c2 100644 --- a/README.md +++ b/README.md @@ -3,75 +3,76 @@ -- [MindSpore Community](#mindspore-community) - - [Charter](#charter) - - [Code Of Conduct](#code-of-conduct) - - [Contributor License Agreement](#contributor-license-agreement) - - [Individual contributors](#individual-contributors) - - [Corporation contributors](#corporation-contributors) - - [Useful CI Commands](#useful-ci-commands) - - [Communications](#communications) - - [CVE Report](#cve-report) - - [Slide Template](#slide-template) - - [License](#license) +- [MindSpore社区](#mindspore社区) + - [治理架构](#治理架构) + - [行为准则](#行为准则) + - [贡献者协议](#贡献者协议) + - [个人贡献者](#个人贡献者) + - [企业贡献者](#企业贡献者) + - [CI指令](#ci指令) + - [交流渠道](#交流渠道) + - [CVE上报](#cve上报) + - [材料模板](#材料模板) + - [许可证](#许可证) -English | [查看中文](./README_CN.md) +中文 | [View English](./README.md) -# MindSpore Community +# MindSpore社区 -This is the repo for all the community related materials. You can find the -following information. +该仓库托管了MindSpore社区相关的所有材料,具体信息如下。 -## Charter +## 治理架构 -Community charter is documented in [governance.md](governance.md), this is -an initial draft and will need to be approved and updated by the TSC. We -also have [sig document](sigs/README.md) and [working group document](working-groups/README.md) -charter provided. +社区治理架构的内容详见[governance.md](governance.md),该提案处于初稿阶段,后期可能会在 +技术指导委员会(TSC)的许可下进行刷新。除此之外,社区还提供了特别兴趣小组([SIG](sigs/README.md)) +和工作组([Working Group](working-groups/README.md))的资料介绍。 -## Code Of Conduct +## 行为准则 -One of the most important community document, we provided both the -[Chinese version](code-of-conduct_zh_cn.md) and [English version](code-of-conduct_en.md) -based on CNCF Code Of Conduct. +作为社区运作的核心组成部分,我们在CNCF社区行为准则的基础上,同时提供了[中文版本](code-of-conduct_zh_cn.md) +和[英文版本](code-of-conduct_en.md)的行为准则。 -## Contributor License Agreement +## 贡献者协议 -You can find both the [Individual Contributor License Agreement](ICLA.pdf) -and [Corporate Contributor License Agreement](CCLA.pdf). +MindSpore社区针对个人贡献者和企业贡献者分别提供了[Individual Contributor License Agreement](ICLA.pdf) +和[Corporate Contributor License Agreement](CCLA.pdf)。 -### Individual contributors +### 个人贡献者 -For individual contributor, please click [CLA online sign page](https://clasign.osinfra.cn/sign/Z2l0ZWUlMkZtaW5kc3BvcmU=) -and choose the `Sign Individual CLA` button to sign Contributor License Agreement. +针对想要参与社区的个人贡献者,请打开[CLA在线签署平台](https://clasign.osinfra.cn/sign/Z2l0ZWUlMkZtaW5kc3BvcmU=) +并点击`Sign Individual CLA`按钮,然后根据系统提示完成协议签署。 -### Corporation contributors +### 企业贡献者 -Corporation employee would not be permitted to sign the CLA until the corporation -has signed CCLA document, and he(she) can click [CLA online sign page](https://clasign.osinfra.cn/sign/Z2l0ZWUlMkZtaW5kc3BvcmU=) -and choose the `Sign Employee CLA` button to sign Contributor License Agreement. +企业签署流程包括`企业主体`和`企业员工`签署两个环节。 -## Useful CI Commands +企业主体签署环节需要该企业联络人打开[CLA在线签署平台](https://clasign.osinfra.cn/sign/Z2l0ZWUlMkZtaW5kc3BvcmU=) +并点击`Sign Corporation CLA`按钮,然后根据系统提示完成协议签署;线上签署完成之后请打开 +签署所用邮箱获取企业签署协议(`电子版`),打印并交给企业业务负责人签名盖章,最后将企业 +签署协议扫描并通过邮件回传给CLA签署系统。 -Please check out some of the most useful [CI command](command.md) -you could use. +企业员工需要在其所属企业签署CCLA协议之后才准许签署贡献者协议,若其所属公司已签署,请打开 +[CLA在线签署平台](https://clasign.osinfra.cn/sign/Z2l0ZWUlMkZtaW5kc3BvcmU=) +并点击`Sign Employee CLA`按钮,然后根据系统提示完成协议签署。 -## Communications +## CI指令 -Please find all the necessary information regarding how we use `IRC`, `Slack`, -and `mailing-list` for discussions in the community. +请查阅[CI指令手册](command.md)来学习CI机器人的操作命令。 -## CVE Report +## 交流渠道 -If you want to file a CVE report, please refer to information in the `security` -folder. +社区提供了多种交流渠道:包括`IRC`、`Slack`以及`邮件列表`等,详情查阅`communication`目录。 -## Slide Template +## CVE上报 -MindSpore community themed slide templates if you need to make a presentation. +如果您发现社区存在任何安全漏洞,请查阅`security`目录了解CVE上报流程。 -## License +## 材料模板 + +如果您想开展MindSpore相关的主题演讲,请查阅`slides`目录获取社区主题材料模板。 + +## 许可证 [Apache License 2.0](LICENSE) diff --git a/work/README.md b/work/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2a3efecdbffed4cce0e4a7a5d7a4ffbad8501c8b --- /dev/null +++ b/work/README.md @@ -0,0 +1,13 @@ + +# {{ community_name }} Community + +This repository hosts all community-related materials. For more information, please refer to the following sections: + +## Governance +Please see [governance.md](governance.md) for details on our community governance structure. + +## Code of Conduct +We provide both [Chinese](code-of-conduct_zh_cn.md) and [English](code-of-conduct_en.md) versions of our Code of Conduct, based on the CNCF Code of Conduct. + +## Contributor License Agreement +You can find both the [Individual](ICLA.pdf) and [Corporate](CCLA.pdf) Contributor License Agreements. \ No newline at end of file diff --git a/work/requirements.txt b/work/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..184d33901972708a5439ebc5965eda7d805da2c5 --- /dev/null +++ b/work/requirements.txt @@ -0,0 +1,5 @@ +python==3.9.25 +mindspore==2.7.0 +mindnlp==0.5.1 +gradio +tqdm \ No newline at end of file diff --git a/work/train_lora.py b/work/train_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b882480ddf22d95a7960fc6938fb1159803572 --- /dev/null +++ b/work/train_lora.py @@ -0,0 +1,373 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os + # 必须在导入其他库之前禁用 Tokenizers 并行,防止死锁 +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +MODEL_PATH = os.getenv("MODEL_PATH", "./pretrained/Qwen/Qwen2.5-7B-Instruct") +ADAPTER_DIR = os.getenv("ADAPTER_DIR", "./final_lora_output") +MERGED_DIR = os.getenv("MERGED_DIR", "./merged_model") +DATA_DIR = os.getenv("DATA_DIR", "./data") + +import argparse +import json +import math +import mindspore as ms +from mindspore import context +from mindnlp.transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, TrainerCallback, TextIteratorStreamer +from tqdm.auto import tqdm +from mindnlp.peft import LoraConfig, get_peft_model, TaskType, PeftModel +from threading import Thread + +# 设置上下文 +context.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend", device_id=0) + +def build_prompt(tokenizer, instruction, user_input): + """构造推理/训练时的提示词 + - 优先使用 `tokenizer.apply_chat_template` 以适配聊天模型格式 + - 无聊天模板时,退化为简单的指令-用户-助手三段式 + 参数: + tokenizer: 分词器对象 + instruction: 指令(必填) + user_input: 额外输入(可选) + """ + system = "你是严谨的中文法律助手。" + if hasattr(tokenizer, "apply_chat_template"): + content = instruction + ("\n" + user_input if user_input else "") + messages = [{"role": "system", "content": system}, {"role": "user", "content": content}] + return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + prefix = "系统:你是严谨的中文法律助手。\n用户:" + if user_input: + return f"{prefix}{instruction}\n{user_input}\n助手:" + return f"{prefix}{instruction}\n助手:" + +def load_raw_dataset(json_path): + """加载 JSON 数据集 + - 优先使用 `datasets.load_dataset` 以支持大文件与切分 + - 失败时退回到原生 `json.load` + 返回:list[dict],每条包含 instruction/input/output + """ + try: + from datasets import load_dataset + ds = load_dataset("json", data_files=json_path, split="train") + return [dict(x) for x in ds] + except Exception: + with open(json_path, "r", encoding="utf-8") as f: + return json.load(f) + +def tokenize_examples(examples, tokenizer, max_length): + """将每条样本转为 token 序列并构造 label + - 拼接:prompt + output + eos + - 对 prompt 段落的 label 置为 -100 以忽略 loss + - 支持按 `max_length` 截断 + 返回:[{input_ids, labels}] + """ + tokenized = [] + for ex in examples: + inst = ex.get("instruction", "") + inp = ex.get("input", "") + out = ex.get("output", "") + prompt = build_prompt(tokenizer, inst, inp) + full_text = prompt + out + tokenizer.eos_token + + # 分词 + full_ids = tokenizer(full_text, max_length=max_length, truncation=True, padding=False)["input_ids"] + prompt_ids = tokenizer(prompt, max_length=max_length, truncation=True, padding=False)["input_ids"] + + # 制作 Label Mask + labels = full_ids.copy() + # 将 prompt 部分的 label 设为 -100 (不计算 loss) + prompt_len = len(prompt_ids) + if prompt_len < len(labels): + for i in range(prompt_len): + labels[i] = -100 + else: + # 异常情况保护:如果截断后全是 prompt + for i in range(len(labels)): + labels[i] = -100 + + tokenized.append({"input_ids": full_ids, "labels": labels}) + return tokenized + +def make_collate_fn(tokenizer): + """构造动态 padding 的 `data_collator` + - 以 batch 内最大长度对齐 + - `input_ids`/`attention_mask`/`labels` 均返回 MindSpore Tensor + - 张量移动到 `npu:0` + """ + pad_id = tokenizer.pad_token_id + def collate_fn(batch): + max_len = max(len(x["input_ids"]) for x in batch) + input_ids = [] + labels = [] + attention_mask = [] + for x in batch: + ids = x["input_ids"] + lbs = x["labels"] + pad_len = max_len - len(ids) + input_ids.append(ids + [pad_id] * pad_len) + attention_mask.append([1] * len(ids) + [0] * pad_len) + labels.append(lbs + [-100] * pad_len) + + return { + "input_ids": ms.Tensor(input_ids, dtype=ms.int64).to("npu:0"), + "attention_mask": ms.Tensor(attention_mask, dtype=ms.int64).to("npu:0"), + "labels": ms.Tensor(labels, dtype=ms.int64).to("npu:0"), + } + return collate_fn + +class ListDataset: + """最小化的数据集封装,用于将普通 list 数据适配为 MindNLP Trainer 所需的数据集格式""" + def __init__(self, data): + self.data = data + def __len__(self): + return len(self.data) + def __getitem__(self, idx): + return self.data[idx] + +class TqdmProgress(TrainerCallback): + """使用 tqdm 展示训练进度的回调""" + def __init__(self): + self.bar = None + def on_train_begin(self, args, state, control, **kwargs): + total = args.max_steps if getattr(args, "max_steps", 0) else None + self.bar = tqdm(total=total, desc="Training") + def on_step_end(self, args, state, control, **kwargs): + if self.bar is not None: + self.bar.update(1) + def on_train_end(self, args, state, control, **kwargs): + if self.bar is not None: + self.bar.close() + +def merge_weights(args): + """独立的合并逻辑,防止显存溢出""" + print(f"\n[Merge] Starting merge process...") + print(f"[Merge] Loading base model from: {args.model_path}") + + # 重新加载纯净的底座模型 + base_model = AutoModelForCausalLM.from_pretrained( + args.model_path, + ms_dtype=ms.float16, + low_cpu_mem_usage=True + ) + + print(f"[Merge] Loading adapter from: {args.adapter_dir}") + peft_model = PeftModel.from_pretrained(base_model, args.adapter_dir) + + print("[Merge] Merging weights (merge_and_unload)...") + # 真正合并权重 + merged_model = peft_model.merge_and_unload() + merged_model.set_train(False) + + print(f"[Merge] Saving full model to: {args.merged_dir}") + merged_model.save_pretrained(args.merged_dir) + + # 同时保存 tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + tokenizer.save_pretrained(args.merged_dir) + print("[Merge] Done successfully!") + +def train(args): + """主训练流程 + - 加载 tokenizer/底座模型并应用 LoRA + - 构造数据与 `Trainer` + - 断点续训与保存 Adapter + - 可选:释放显存后进行权重合并并保存全量模型 + """ + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + base_model = AutoModelForCausalLM.from_pretrained(args.model_path, ms_dtype=ms.float16) + + peft_cfg = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=args.target_modules.split(","), + ) + + model = get_peft_model(base_model, peft_cfg) + model = model.to("npu:0") + try: + model.print_trainable_parameters() + except Exception: + pass + + raw = load_raw_dataset(args.dataset_json) + tok = tokenize_examples(raw, tokenizer, args.max_seq_length) + collate_fn = make_collate_fn(tokenizer) + + steps = args.max_steps if args.max_steps > 0 else math.ceil(len(tok) * args.num_train_epochs / (args.per_device_train_batch_size * args.gradient_accumulation_steps)) + + training_args = TrainingArguments( + output_dir=args.output_dir, + per_device_train_batch_size=args.per_device_train_batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps, + learning_rate=args.learning_rate, + num_train_epochs=args.num_train_epochs, + logging_steps=args.logging_steps, + save_steps=args.save_steps, + fp16=False, + save_total_limit=args.save_total_limit, + max_steps=steps, + disable_tqdm=False, + ) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=ListDataset(tok), + data_collator=collate_fn, + tokenizer=tokenizer, + callbacks=[TqdmProgress()], + ) + + resume_path = None + if args.resume_from_checkpoint: + resume_path = args.resume_from_checkpoint + elif args.resume_auto: + try: + ckpts = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint-")] + if ckpts: + last = sorted(ckpts, key=lambda x: int(x.split("-")[-1]))[-1] + resume_path = os.path.join(args.output_dir, last) + except Exception: + resume_path = None + trainer.train(resume_from_checkpoint=resume_path) + + # 保存 Adapter + print(f"Saving LoRA adapter to {args.adapter_dir}") + model.save_pretrained(args.adapter_dir) + + +def infer(args): + """推理流程 + - 优先加载已合并的全量模型;否则加载底座+Adapter + - 构造提示词与采样参数 + - 采用 `TextIteratorStreamer` 流式输出生成文本 + """ + use_merged = os.path.isdir(args.merged_dir) and bool(os.listdir(args.merged_dir)) if os.path.exists(args.merged_dir) else False + if use_merged: + print(f"Loading model for inference from: {args.merged_dir}") + tokenizer = AutoTokenizer.from_pretrained(args.merged_dir) + model = AutoModelForCausalLM.from_pretrained(args.merged_dir, ms_dtype=ms.float16) + else: + print(f"Loading model for inference from: {args.model_path} with adapter {args.adapter_dir}") + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + base = AutoModelForCausalLM.from_pretrained(args.model_path, ms_dtype=ms.float16) + model = PeftModel.from_pretrained(base, args.adapter_dir) + model = model.to("npu:0") + model.set_train(False) + + prompt = build_prompt(tokenizer, args.infer_instruction, args.infer_input) + inputs = tokenizer(prompt, return_tensors="ms") + inputs = {k: v.to("npu:0") for k, v in inputs.items()} + + print("-" * 20) + print(f"Question: {args.infer_instruction}") + print("Answer: ", end="", flush=True) + + # 生成配置 + generate_kwargs = dict( + input_ids=inputs["input_ids"], + max_new_tokens=args.max_new_tokens, + do_sample=args.do_sample, + top_p=args.top_p, + top_k=getattr(args, "top_k", 50), + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + no_repeat_ngram_size=args.no_repeat_ngram_size, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id + ) + + # 流式输出 + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + generate_kwargs["streamer"] = streamer + + thread = Thread(target=model.generate, kwargs=generate_kwargs) + thread.start() + + for new_text in streamer: + print(new_text, end="", flush=True) + print("\n" + "-" * 20) + +def parse_args(): + """命令行参数定义 + - 路径/训练/LoRA/控制/推理参数分组 + - `no_stream`/`no_prewarm` 等参数为预留,当前未使用 + """ + p = argparse.ArgumentParser() + # 路径参数 + p.add_argument("--model_path", type=str, default="MODEL_PATH", help="底座模型路径") + p.add_argument("--dataset_json", type=str, default="DATA_PDIR") + p.add_argument("--adapter_dir", type=str, default="ADAPTER_DIR", help="LoRA权重保存路径") + p.add_argument("--merged_dir", type=str, default="MERGED_DIR", help="合并后全量模型保存路径") + + # 训练参数 + train_args = p.add_argument_group("Training Arguments") + train_args.add_argument("--learning_rate", type=float, default=2e-4) + train_args.add_argument("--num_train_epochs", type=float, default=3.0) + + lora_args = p.add_argument_group("LoRA Configuration") + lora_args.add_argument("--lora_r", type=int, default=32) + lora_args.add_argument("--lora_alpha", type=int, default=64) + + # LoRA 参数 + p.add_argument("--lora_dropout", type=float, default=0.1) + p.add_argument("--target_modules", type=str, default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj") + + # 控制参数 + p.add_argument("--max_steps", type=int, default=0) + p.add_argument("--merge_and_save", action="store_true", help="训练完是否合并权重") + p.add_argument("--resume_from_checkpoint", type=str, default="") + p.add_argument("--resume_auto", action="store_true") + p.add_argument("--do_train", action="store_true") + p.add_argument("--do_infer", action="store_true") + + # 推理参数 + p.add_argument("--infer_instruction", type=str, default="某人在交通事故中受到了腹壁穿透创伤,该如何鉴定他的人体损伤程度?") + p.add_argument("--infer_input", type=str, default="") + p.add_argument("--max_new_tokens", type=int, default=1024) + p.add_argument("--do_sample", action="store_true") + p.add_argument("--top_p", type=float, default=0.9) + p.add_argument("--top_k", type=int, default=50) + p.add_argument("--temperature", type=float, default=0.7) + p.add_argument("--repetition_penalty", type=float, default=1.1) + p.add_argument("--no_repeat_ngram_size", type=int, default=0) + p.add_argument("--no_stream", action="store_true") + p.add_argument("--no_prewarm", action="store_true") + p.add_argument("--warmup_tokens", type=int, default=1) + + return p.parse_args() + +# 入口:根据命令行开关执行训练与推理 +# - `--do_train`:执行训练并保存 Adapter(可选合并) +# - `--do_infer`:加载模型并进行流式推理 +if __name__ == "__main__": + args = parse_args() + + if args.do_train: + os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(args.adapter_dir, exist_ok=True) + if args.merge_and_save: + os.makedirs(args.merged_dir, exist_ok=True) + train(args) + + if args.do_infer: + infer(args) diff --git a/work/web_infer.py b/work/web_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..98aff3760cd5de614f0e8585f747777774a79d61 --- /dev/null +++ b/work/web_infer.py @@ -0,0 +1,117 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import threading +import mindspore as ms +from mindspore import context +from mindnlp.transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer +from mindnlp.peft import PeftModel +import gradio as gr + +# =================配置区域================= +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +MODEL_PATH = os.getenv("MODEL_PATH", "./pretrained/Qwen/Qwen2.5-7B-Instruct") +ADAPTER_DIR = os.getenv("ADAPTER_DIR", "./final_lora_output") +MERGED_DIR = os.getenv("MERGED_DIR", "./merged_model") + + + +context.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend", device_id=0) + +# =================模型加载================= +print("正在加载模型...") +use_merged = os.path.isdir(MERGED_DIR) and bool(os.listdir(MERGED_DIR)) if os.path.exists(MERGED_DIR) else False + +if use_merged: + tokenizer = AutoTokenizer.from_pretrained(MERGED_DIR) + model = AutoModelForCausalLM.from_pretrained(MERGED_DIR, ms_dtype=ms.float16) +else: + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) + base_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, ms_dtype=ms.float16) + model = PeftModel.from_pretrained(base_model, ADAPTER_DIR) + +model = model.to("npu:0") +model.set_train(False) +if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id +print("模型加载完毕!") + +# =================推理逻辑================= + +def predict(message, history): + """ + 基于输入消息生成法律问答回复,使用流式输出。 + + 参数: + message (str): 用户当前输入的问题。 + history (list): 聊天历史记录(Gradio 传入,当前未使用)。 + + 生成: + str: 逐字生成的回答内容。 + """ + # 内部固定参数 + max_len = 1024 + temperature = 0.7 + top_p = 0.9 + + full_prompt = build_prompt(tokenizer, message) + + inputs = tokenizer(full_prompt, return_tensors="ms") + inputs = {k: v.to("npu:0") for k, v in inputs.items()} + + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + + generation_kwargs = dict( + input_ids=inputs["input_ids"], + max_new_tokens=max_len, + do_sample=True, + temperature=temperature, + top_p=top_p, + pad_token_id=tokenizer.pad_token_id, + streamer=streamer + ) + + thread = threading.Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + partial_message = "" + for new_token in streamer: + partial_message += new_token + yield partial_message + +# =================搭建网页================= +with gr.Blocks(title="法律大模型") as demo: + gr.Markdown("# ⚖️ 法律大模型助手 (MindSpore版)") + + chatbot_config = gr.Chatbot( + height=600, + bubble_full_width=False, + show_copy_button=True + ) + + gr.ChatInterface( + predict, + chatbot=chatbot_config, # 传入自定义的 chatbot + examples=[ + ["某人在交通事故中受到了腹壁穿透创伤,该如何鉴定?"], + ["盗窃罪的立案标准是什么?"], + ["请说明注册商标的申请流程?"] + ], + description="基于 Qwen2.5 + LoRA 微调的法律问答助手" + ) + +if __name__ == "__main__": + demo.queue().launch(share=True, server_name="0.0.0.0") \ No newline at end of file