From cb0aa237ccb899e31182ce9c037e59b71d4eb931 Mon Sep 17 00:00:00 2001
From: y30062407
Date: Wed, 11 Dec 2024 11:24:36 +0800
Subject: [PATCH] =?UTF-8?q?=E3=80=90r1.3.0=E3=80=91=E3=80=90Bugfix?=
 =?UTF-8?q?=E3=80=91=E4=BF=AE=E5=A4=8Dbenchmark=E8=AE=AD=E7=BB=83=E5=B7=A5?=
 =?UTF-8?q?=E5=85=B7=E5=9C=A8=E7=BA=BF=E4=B8=8B=E8=BD=BD=E5=A4=84=E7=90=86?=
 =?UTF-8?q?=E6=95=B0=E6=8D=AE=E9=9B=86error?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 scripts/benchmark/run_pretrain.py | 65 ++++++++++++++++++++++++++-----
 1 file changed, 56 insertions(+), 9 deletions(-)

diff --git a/scripts/benchmark/run_pretrain.py b/scripts/benchmark/run_pretrain.py
index 32b9e46d51..575c643113 100644
--- a/scripts/benchmark/run_pretrain.py
+++ b/scripts/benchmark/run_pretrain.py
@@ -18,6 +18,7 @@ import zipfile
 import subprocess
 import tempfile
 from enum import Enum
+from typing import Optional
 from glob import glob
 from pathlib import Path
 import argparse
@@ -44,6 +45,13 @@ def convert_path(src_path, save_path):
 DATASET_TYPES = ['wiki', 'alpaca']
 
 
+class EncodingFormat(Enum):
+    """
+    Text encodings that download_file can use to decode and save a download.
+    """
+    UTF_8 = "utf-8"
+
+
 class DatasetType(Enum):
     ZIP = "zip"
     JSON = "json"
@@ -332,17 +340,56 @@ class ModelPretrain(BaseInitModel):
             raise ValueError(f"Invalid dataset type: {dataset_type}. Must be one of {DATASET_TYPES}.")
 
 
-    def download_file(self, url, save_path):
-        """download_file"""
-        flags_ = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+    def download_file(self, url: str, save_path: str, encoding: Optional[EncodingFormat] = None) -> str:
+        """
+        Download a file from the given URL and save it to the given path.
+
+        When no encoding is given, the response body is streamed to disk
+        unchanged (binary mode). Otherwise the whole body is decoded with the
+        given encoding and written back as text in that encoding, so content
+        that cannot be decoded is rejected before the file is opened.
+
+        Args:
+            url (str): The URL of the file to download.
+            save_path (str): The local path where the file will be saved.
+            encoding (Optional[EncodingFormat]): The encoding used to decode and
+                save the file. Defaults to None (binary mode).
+
+        Returns:
+            str: The path where the file was saved.
+
+        Raises:
+            Exception: If the download, the decoding, or the local write fails.
+        """
+        # Both branches create the file via os.open() so it is always created
+        # with the same 0o750 mode; a bare open() would use the looser default.
+        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
+        mode = 0o750
         try:
-            with requests.get(url, stream=True, verify=False) as r:
-                r.raise_for_status()
-                with os.fdopen(os.open(save_path, flags_, 0o750), 'wb', encoding='utf-8') as f:
-                    for chunk in r.iter_content(chunk_size=8192):
-                        f.write(chunk)
+            # NOTE(review): verify=False turns off TLS certificate verification.
+            # It is kept to preserve behavior; confirm it is really required.
+            with requests.get(url, stream=True, verify=False) as response:
+                response.raise_for_status()
+
+                if encoding is None:
+                    # Binary mode: stream the raw bytes to the file.
+                    with os.fdopen(os.open(save_path, flags, mode), 'wb') as file:
+                        for chunk in response.iter_content(chunk_size=8192):
+                            if chunk:  # Filter out keep-alive chunks
+                                file.write(chunk)
+                else:
+                    # Text mode: decode first so a decoding failure happens
+                    # before the target file is created or truncated.
+                    text_content = response.content.decode(encoding.value)
+                    with os.fdopen(os.open(save_path, flags, mode), 'w',
+                                   encoding=encoding.value) as file:
+                        file.write(text_content)
         except requests.exceptions.RequestException as e:
-            raise Exception(f"Failed to download the file,URL:{url}: {e}") from e
+            raise Exception(f"Failed to download the file. URL: {url}. Error: {e}") from e
+        except UnicodeDecodeError as e:
+            raise Exception(f"Failed to decode the file using encoding {encoding.name}. Error: {e}") from e
+        except OSError as e:
+            raise Exception(f"Failed to write the file to {save_path}. Error: {e}") from e
         return save_path
 
 
-- 
Gitee