diff --git a/config_checking/checkers/checkpoint_checker.py b/config_checking/checkers/checkpoint_checker.py index 30e3bf91f35a05b0c0152ac362aee8af04043f6a..be5cb6fc2109a9dd1084202349df4c9b3d212d1b 100644 --- a/config_checking/checkers/checkpoint_checker.py +++ b/config_checking/checkers/checkpoint_checker.py @@ -1,19 +1,19 @@ -import os import json +import os + import torch -import hashlib -import zlib + from config_checking.checkers.base_checker import BaseChecker +from config_checking.config_checker import register_checker_item +from config_checking.utils.hash import bytes_hash from config_checking.utils.packing import create_file_in_zip -from config_checking.utils.utils import load_json, compare_dict, write_list_to_file -from config_checking.config_checker import register_checker_item from config_checking.utils.utils import config_checking_print +from config_checking.utils.utils import load_json, compare_dict, write_list_to_file def tensor_to_hash(tensor): tensor_bytes = tensor.cpu().numpy().tobytes() - hash_object = hashlib.sha256(tensor_bytes) - return hash_object.hexdigest() + return bytes_hash(tensor_bytes) def tensor_in_state_dict_to_hash(state_dict): result = {} diff --git a/config_checking/utils/hash.py b/config_checking/utils/hash.py index 153cd05818391f4dc56073c834b6c5f727566ae1..d5721aad21f669a99597a4a12cc67ca3d4ce18c2 100644 --- a/config_checking/utils/hash.py +++ b/config_checking/utils/hash.py @@ -2,7 +2,6 @@ import hashlib import os from concurrent.futures import ThreadPoolExecutor - BLOCK_SIZE = 64 << 20 # 64MB MAX_THREAD_WORKERS = 16 @@ -41,3 +40,9 @@ def calculate_hash(file_path, max_workers=MAX_THREAD_WORKERS): def string_hash(input_str): return hashlib.sha256(input_str.encode('utf-8')).hexdigest() + + +def bytes_hash(obj: bytes): + hex_dig = hashlib.sha256(obj).hexdigest() + short_hash = int(hex_dig, 16) % (2 ** 16) + return short_hash diff --git a/config_checking/utils/packing.py b/config_checking/utils/packing.py index 22690b229c720ddd10e387af6b36acc46089de24..621e713f9b9cb13316b5b8f36196a672b6e7534c 100644 --- a/config_checking/utils/packing.py +++ b/config_checking/utils/packing.py @@ -3,6 +3,8 @@ import zipfile import hashlib import multiprocessing +from config_checking.utils.hash import string_hash + proc_lock = multiprocessing.Lock() @@ -80,7 +82,7 @@ class DirPacker: hash_file_path = f"{rel_path}.hash" target_file_path = os.path.join(self.result_dirname, hash_file_path) with open(file_path, 'rb') as f: - file_hash = hashlib.sha256(f.read()).hexdigest() + file_hash = string_hash(f.read()) zip_info = zipfile.ZipInfo(target_file_path) self.zip_handler.writestr(zip_info, file_hash) diff --git a/config_checking/utils/random_patch.py b/config_checking/utils/random_patch.py index bb24919fc1b1e265c2e006139cb9fa0d8924cc7e..a6b3e4d95a277c5e6a23f57032d3ad34d638d054 100644 --- a/config_checking/utils/random_patch.py +++ b/config_checking/utils/random_patch.py @@ -14,7 +14,6 @@ DEFAULT_RANDOM_LOG_PATH = './random_patch.log' # 1、日志写入文件并发处理 # 2、支持日志文件路径设置 # 3、多个装饰器可改为责任链模式 -logging.basicConfig(filename=DEFAULT_RANDOM_LOG_PATH, level=logging.INFO) def __log_stack(func): @@ -50,6 +49,9 @@ def __track_func(func): def apply_patches(): + # init logging + logging.basicConfig(filename=DEFAULT_RANDOM_LOG_PATH, level=logging.INFO) + # Patch random module random.random = __track_func(random.random) random.randint = __track_func(random.randint) @@ -71,4 +73,7 @@ def apply_patches(): torch.randn_like = __track_func(torch.randn_like) torch.manual_seed = __track_func(torch.manual_seed) + # Patch torch.Tensor random function + torch.Tensor.exponential_ = __track_func(torch.Tensor.exponential_) + config_checking_print(f"random patches saved to file: {DEFAULT_RANDOM_LOG_PATH}") diff --git a/config_checking/utils/utils.py b/config_checking/utils/utils.py index 37705fff7a8c695d21e66b3b5db967218f902f3c..d0213c7e411338a93cf16418fa0a630f3292be0e 100644 --- a/config_checking/utils/utils.py +++ b/config_checking/utils/utils.py @@ -4,7 +4,8 @@ import os import re import torch import torch.distributed as dist -import hashlib + +from config_checking.utils.hash import bytes_hash def load_txt(file_path): @@ -61,9 +62,7 @@ def config_checking_print(msg): def tensor_to_hash(tensor): """Compute the hash value of a tensor""" tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes() - hash_object = hashlib.sha256() - hash_object.update(tensor_bytes) - return hash_object.hexdigest() + return bytes_hash(tensor_bytes) features = {