diff --git a/scripts/benchmark/run_pretrain.py b/scripts/benchmark/run_pretrain.py index 32b9e46d515b7b28c84f3d9c01fd1186e26fdace..575c64311317ac7539012a973bdd34e4494c6be9 100644 --- a/scripts/benchmark/run_pretrain.py +++ b/scripts/benchmark/run_pretrain.py @@ -18,6 +18,7 @@ import zipfile import subprocess import tempfile from enum import Enum +from typing import Optional from glob import glob from pathlib import Path import argparse @@ -44,6 +45,13 @@ def convert_path(src_path, save_path): DATASET_TYPES = ['wiki', 'alpaca'] +class EncodingFormat(Enum): + """ + Encoding Format Enumeration Class + """ + UTF_8 = "utf-8" + + class DatasetType(Enum): ZIP = "zip" JSON = "json" @@ -332,17 +340,46 @@ class ModelPretrain(BaseInitModel): raise ValueError(f"Invalid dataset type: {dataset_type}. Must be one of {DATASET_TYPES}.") - def download_file(self, url, save_path): - """download_file""" - flags_ = os.O_WRONLY | os.O_CREAT | os.O_TRUNC + def download_file(self, url: str, save_path: str, encoding: Optional[EncodingFormat] = None) -> str: + """ + Downloads a file from the specified URL and saves it to the given path. + If an encoding is specified, the file is saved using that encoding. + Otherwise, the file is saved in binary mode. + + Args: + url (str): The URL of the file to download. + save_path (str): The local path where the file will be saved. + encoding (Optional[EncodingFormat]): The encoding format to use when saving the file. Defaults to None. + + Returns: + str: The path where the file was saved. + + Raises: + Exception: If the file download fails. + """ try: - with requests.get(url, stream=True, verify=False) as r: - r.raise_for_status() - with os.fdopen(os.open(save_path, flags_, 0o750), 'wb', encoding='utf-8') as f: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) + with requests.get(url, stream=True, verify=False) as response: + response.raise_for_status() + + if encoding is None: + # Binary mode: Write bytes to the file + flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC + with os.fdopen(os.open(save_path, flags, 0o750), 'wb') as file: + for chunk in response.iter_content(chunk_size=8192): + if chunk: # Filter out keep-alive chunks + file.write(chunk) + else: + # Text mode: Decode content using the specified encoding and write as text + text_content = response.content.decode(encoding.value) + with open(save_path, 'w', encoding=encoding.value) as file: + file.write(text_content) + except requests.exceptions.RequestException as e: - raise Exception(f"Failed to download the file,URL:{url}: {e}") from e + raise Exception(f"Failed to download the file. URL: {url}. Error: {e}") from e + except UnicodeDecodeError as e: + raise Exception(f"Failed to decode the file using encoding {encoding.name}. Error: {e}") from e + except OSError as e: + raise Exception(f"Failed to write the file to {save_path}. Error: {e}") from e return save_path