import os
import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
import torch
from safetensors.torch import load_file, save_file
# from kernel import weight_dequant  # original kernel import; a pure-PyTorch equivalent is defined below


def weight_dequant(weight: torch.Tensor, scale: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """
    Dequantizes the given weight tensor using the provided scale tensor, handling cases where
    `weight`'s dimensions are not multiples of `block_size` by broadcasting and trimming `scale`.

    Args:
        weight (torch.Tensor): The quantized weight tensor of shape (M, N).
        scale (torch.Tensor): The scale tensor of shape (ceil(M / block_size), ceil(N / block_size)).
        block_size (int, optional): The block size used for quantization. Defaults to 128.

    Returns:
        torch.Tensor: The dequantized weight tensor of the same shape as `weight`, converted
        to the default dtype.

    Raises:
        AssertionError: If `scale`'s dimensions do not align with `weight`'s shape after blocking.
    """
    # Get the original dimensions of weight
    M, N = weight.shape
    # Validate that scale has one entry per block_size x block_size block, rounding up
    scale_m, scale_n = scale.shape
    assert scale_m == (M + block_size - 1) // block_size, "Mismatch in scale rows and weight rows."
    assert scale_n == (N + block_size - 1) // block_size, "Mismatch in scale columns and weight columns."
    # Convert weight to float32 for the arithmetic (FP8 cannot be multiplied directly)
    weight = weight.to(torch.float32)
    # Expand scale so each element covers one block_size x block_size tile of weight
    scale_expanded = scale.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
    # Trim scale_expanded to weight's shape when M or N is not a multiple of block_size
    scale_expanded = scale_expanded[:M, :N]
    # Element-wise multiplication recovers the dequantized values
    dequantized_weight = weight * scale_expanded
    # Cast the result to the default dtype (bfloat16, set in main)
    dequantized_weight = dequantized_weight.to(torch.get_default_dtype())
    return dequantized_weight
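

# A minimal, illustrative sanity check for weight_dequant (not part of the
# conversion flow; shapes and values here are made up, not from a checkpoint).
def _weight_dequant_sanity_check():
    """Shows how block scales broadcast: with block_size=2, a 3x3 weight needs a
    ceil(3/2) x ceil(3/2) = 2x2 scale, and each scale entry multiplies one
    2x2 tile, trimmed at the ragged edge."""
    w = torch.ones(3, 3, dtype=torch.float32)
    s = torch.tensor([[2.0, 3.0], [4.0, 5.0]])
    out = weight_dequant(w, s, block_size=2)
    assert out[0, 0].item() == 2.0  # top-left tile scaled by s[0, 0]
    assert out[2, 2].item() == 5.0  # bottom-right (trimmed) tile scaled by s[1, 1]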


def main(fp8_path, bf16_path):
    """
    Converts FP8 weights to BF16 and saves the converted weights.

    This function reads FP8 weights from the specified directory, converts them to BF16,
    and saves the converted weights to another specified directory. It also updates the
    model index file to reflect the changes.

    Args:
        fp8_path (str): The path to the directory containing the FP8 weights and model index file.
        bf16_path (str): The path to the directory where the converted BF16 weights will be saved.

    Notes:
        - The function assumes that the FP8 weights are stored in safetensors files.
        - Loaded safetensors files are cached, keeping at most two in memory at a time.
        - Weights whose `scale_inv` tensor is missing are copied unconverted, with a warning.
        - The model index file is rewritten without references to the `scale_inv` tensors.
    """
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(bf16_path, exist_ok=True)
    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]
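    # `weight_map` maps each tensor name to the safetensors shard that stores it.
    # Illustrative entries (tensor and shard names are hypothetical):
    #   "model.layers.0.mlp.up_proj.weight":           "model-00001-of-00xxx.safetensors"
    #   "model.layers.0.mlp.up_proj.weight_scale_inv": "model-00001-of-00xxx.safetensors"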

    # Cache for loaded safetensors files
    loaded_files = {}
    fp8_weight_names = []

    # Helper function to get a tensor from whichever shard holds it
    def get_tensor(tensor_name):
        """
        Retrieves a tensor from the cached safetensors files, loading the shard
        from disk if it is not cached yet.

        Args:
            tensor_name (str): The name of the tensor to retrieve.

        Returns:
            torch.Tensor: The retrieved tensor.

        Raises:
            KeyError: If the tensor is not listed in the weight map or not present
                in its safetensors file.
        """
        file_name = weight_map[tensor_name]
        if file_name not in loaded_files:
            file_path = os.path.join(fp8_path, file_name)
            loaded_files[file_name] = load_file(file_path, device="cpu")
        return loaded_files[file_name][tensor_name]
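
    # Note: a weight's `scale_inv` tensor can live in a different shard than the
    # weight itself, so get_tensor may pull a second shard into the cache; the
    # eviction below keeps at most two shards resident to bound memory use.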
    safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
    safetensor_files.sort()
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)
        current_state_dict = load_file(safetensor_file, device="cpu")
        loaded_files[file_name] = current_state_dict

        new_state_dict = {}
        for weight_name, weight in current_state_dict.items():
            if weight_name.endswith("_scale_inv"):
                continue
            elif weight.element_size() == 1:  # FP8 weight (1 byte per element)
                scale_inv_name = f"{weight_name}_scale_inv"
                try:
                    # Get scale_inv from whichever shard holds it
                    scale_inv = get_tensor(scale_inv_name)
                    fp8_weight_names.append(weight_name)
                    new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
                except KeyError:
                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
                    new_state_dict[weight_name] = weight
            else:
                new_state_dict[weight_name] = weight

        new_safetensor_file = os.path.join(bf16_path, file_name)
        save_file(new_state_dict, new_safetensor_file)

        # Memory management: keep only the 2 most recently loaded shards.
        # Dicts preserve insertion order, so next(iter(...)) is the oldest entry.
        if len(loaded_files) > 2:
            oldest_file = next(iter(loaded_files))
            del loaded_files[oldest_file]
            # torch.cuda.empty_cache()  # not needed here since tensors stay on CPU

    # Update the model index: drop the scale_inv entries for converted weights
    new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
    for weight_name in fp8_weight_names:
        scale_inv_name = f"{weight_name}_scale_inv"
        if scale_inv_name in weight_map:
            weight_map.pop(scale_inv_name)
    with open(new_model_index_file, "w") as f:
        # The original "metadata" block (e.g. total_size) would be stale after
        # the FP8 -> BF16 conversion, so an empty one is written instead.
        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input-fp8-hf-path", type=str, required=True)
    parser.add_argument("--output-bf16-hf-path", type=str, required=True)
    args = parser.parse_args()
    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
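
# Example invocation (assuming this script is saved as fp8_cast_bf16.py; the
# paths are illustrative):
#   python fp8_cast_bf16.py \
#       --input-fp8-hf-path /path/to/fp8_checkpoint \
#       --output-bf16-hf-path /path/to/bf16_output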