my_chatglm / utils.py
wfhe committed on 2024-01-15 15:29: add tool call for openai api
import gc
import json
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.generation.logits_process import LogitsProcessor
from typing import Union, Tuple


class InvalidScoreLogitsProcessor(LogitsProcessor):
    # Guards generation against degenerate logits: if any score is NaN/Inf,
    # zero everything and put all probability mass on a single safe token id.
    def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
    # Parse a raw ChatGLM3 completion: plain replies come back as a string,
    # tool calls come back as {"name": ..., "arguments": <JSON string>}.
    content = ""
    for response in output.split("<|assistant|>"):
        metadata, content = response.split("\n", maxsplit=1)
        if not metadata.strip():
            content = content.strip()
            content = content.replace("[[训练时间]]", "2023年")
        else:
            if use_tool:
                content = "\n".join(content.split("\n")[1:-1])

                def tool_call(**kwargs):
                    return kwargs

                # NOTE: eval() trusts the model output; the stripped body is
                # expected to be a single tool_call(...) expression.
                parameters = eval(content)
                content = {
                    "name": metadata.strip(),
                    "arguments": json.dumps(parameters, ensure_ascii=False)
                }
            else:
                content = {
                    "name": metadata.strip(),
                    "content": content
                }
    return content
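
# Hedged example (the model output below is illustrative, not taken from the repo):
# a tool-call turn is expected to look like "<tool name>\n```python\ntool_call(arg=...)\n```", so
#     process_response('get_weather\n```python\ntool_call(city="Beijing")\n```', use_tool=True)
# would return roughly
#     {"name": "get_weather", "arguments": '{"city": "Beijing"}'}
# while a plain reply (empty metadata line) is returned as a stripped string.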


@torch.inference_mode()
def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    messages = params["messages"]
    tools = params["tools"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 256))
    echo = params.get("echo", True)

    messages = process_chatglm_messages(messages, tools=tools)
    query, role = messages[-1]["content"], messages[-1]["role"]

    inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role)
    inputs = inputs.to(model.device)
    input_echo_len = len(inputs["input_ids"][0])

    if input_echo_len >= model.config.seq_length:
        print(f"Input length larger than {model.config.seq_length}")

    eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command("<|user|>"),
    ]

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True if temperature > 1e-5 else False,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "logits_processor": [InvalidScoreLogitsProcessor()],
    }
    if temperature > 1e-5:
        gen_kwargs["temperature"] = temperature

    total_len = 0
    for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs):
        total_ids = total_ids.tolist()[0]
        total_len = len(total_ids)
        if echo:
            output_ids = total_ids[:-1]
        else:
            output_ids = total_ids[input_echo_len:-1]

        response = tokenizer.decode(output_ids)
        # Skip chunks that end in an incomplete multi-byte character.
        if response and response[-1] != "�":
            response, stop_found = apply_stopping_strings(response, ["<|observation|>"])
            yield {
                "text": response,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": total_len - input_echo_len,
                    "total_tokens": total_len,
                },
                "finish_reason": "function_call" if stop_found else None,
            }

            if stop_found:
                break

    # Only the final streamed chunk carries a terminal finish_reason; mark it as "stop".
    ret = {
        "text": response,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
        "finish_reason": "stop",
    }
    yield ret

    gc.collect()
    torch.cuda.empty_cache()
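
# Hedged usage sketch (names and values are illustrative, not from the repo): `params`
# is an OpenAI-style request dict; "messages" holds objects exposing .role / .content /
# .function_call, and the "tools" key must be present (it may be None or an empty list), e.g.
#     params = {
#         "messages": request_messages,
#         "tools": request_tools,
#         "temperature": 0.8,
#         "top_p": 0.8,
#         "max_tokens": 512,
#         "echo": False,
#     }
#     for chunk in generate_stream_chatglm3(model, tokenizer, params):
#         latest_text = chunk["text"]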


def process_chatglm_messages(messages, tools=None):
    # Convert OpenAI-style messages into the ChatGLM3 chat format: prepend a system
    # message when tools are provided, map "function" results to "observation" turns,
    # and split assistant tool-call turns into (metadata, content) pairs.
    _messages = messages
    messages = []

    if tools:
        messages.append(
            {
                "role": "system",
                "content": "Answer the following questions as best as you can. You have access to the following tools:",
                "tools": tools
            }
        )

    for m in _messages:
        role, content, func_call = m.role, m.content, m.function_call
        if role == "function":
            messages.append(
                {
                    "role": "observation",
                    "content": content
                }
            )
        elif role == "assistant" and func_call is not None:
            for response in content.split("<|assistant|>"):
                metadata, sub_content = response.split("\n", maxsplit=1)
                messages.append(
                    {
                        "role": role,
                        "metadata": metadata,
                        "content": sub_content.strip()
                    }
                )
        else:
            messages.append({"role": role, "content": content})
    return messages
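
# Hedged example (illustrative objects, not from the repo): given pydantic-style messages
#     [ChatMessage(role="user", content="What's the weather in Beijing?", function_call=None),
#      ChatMessage(role="function", content='{"temp": "25C"}', function_call=None)]
# and a non-empty `tools` list, the function would return roughly
#     [{"role": "system", "content": "Answer the following questions as best as you can. ...", "tools": tools},
#      {"role": "user", "content": "What's the weather in Beijing?"},
#      {"role": "observation", "content": '{"temp": "25C"}'}]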


def generate_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    # Non-streaming wrapper: drain the stream and return only the final chunk.
    for response in generate_stream_chatglm3(model, tokenizer, params):
        pass
    return response


def apply_stopping_strings(reply, stop_strings) -> Tuple[str, bool]:
    stop_found = False
    for string in stop_strings:
        idx = reply.find(string)
        if idx != -1:
            reply = reply[:idx]
            stop_found = True
            break

    if not stop_found:
        # If something like "\nYo" is generated just before "\nYou:" is completed, trim it
        for string in stop_strings:
            for j in range(len(string) - 1, 0, -1):
                if reply[-j:] == string[:j]:
                    reply = reply[:-j]
                    break
            else:
                continue
            break

    return reply, stop_found
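
# Hedged examples (illustrative strings, not from the repo):
#     apply_stopping_strings("result<|observation|>junk", ["<|observation|>"])
#         -> ("result", True)            # full stop string found, tail dropped
#     apply_stopping_strings("partial answer<|obs", ["<|observation|>"])
#         -> ("partial answer", False)   # trailing prefix of a stop string trimmed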