diff --git a/omniadvisor/compile.py b/omniadvisor/compile.py index 1ad55968f1605119f39d04726dcdcca5925a36c4..ad63aee52bb9b08a85f8571bab948973efed8398 100644 --- a/omniadvisor/compile.py +++ b/omniadvisor/compile.py @@ -73,6 +73,31 @@ def delete_files_recursively(root_dir, patterns): shutil.rmtree(TEMP_BUILD_DIR) +def set_permissions(root_dir: dir): + """ + 为项目目录下的文件和目录设置权限: + + :param root_dir: 要搜索的根目录 + """ + config_dir_list = [ + f'{root_dir}/config', + ] + + for root, _, files in os.walk(root_dir): + # 判断是否配置目录 + if root in config_dir_list: + os.chmod(root, 0o750) + for file in files: + file_path = os.path.join(root, file) + os.chmod(file_path, 0o640) + # 非配置目录,则为程序目录 + os.chmod(root, 0o550) + for file in files: + file_path = os.path.join(root, file) + os.chmod(file_path, 0o550) + + + if __name__ == '__main__': if os.path.exists(OUTPUT_ROOT_DIR): shutil.rmtree(OUTPUT_ROOT_DIR) @@ -98,3 +123,6 @@ if __name__ == '__main__': # 删除build_output 目录下所有无用文件 delete_files_recursively(OUTPUT_SRC_DIR, RETAIN_FILES) + + # 设置输出文件的权限 + set_permissions(OUTPUT_ROOT_DIR) diff --git a/omniadvisor/src/omniadvisor/utils/logger.py b/omniadvisor/src/omniadvisor/utils/logger.py index 9ecacd4624ce5dbf02914b4bcdd3771f11bd2af9..532e5503ae1b45cd388d2fe41df8b2fc96842af4 100755 --- a/omniadvisor/src/omniadvisor/utils/logger.py +++ b/omniadvisor/src/omniadvisor/utils/logger.py @@ -2,11 +2,53 @@ import os import logging import warnings from logging.config import dictConfig +from logging.handlers import RotatingFileHandler + from common.constant import OA_CONF CONSOLE_FORMAT = '%(asctime)s [%(levelname)s] %(message)s' FILE_FORMAT = '%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d in %(funcName)s - %(message)s' + +class SecureRotatingFileHandler(RotatingFileHandler): + """ + 自定义RotatingFileHandler,支持设置文件权限 + """ + def _open(self): + """ + 打开文件并设置权限 + """ + stream = super()._open() + # 正在记录的日志文件权限设置为640 + os.chmod(self.baseFilename, 0o640) + return stream + + def doRollover(self): + """ + 执行日志滚动并设置归档文件权限 + """ + # 调用父类方法执行滚动 + super().doRollover() + + # 设置所有归档文件的权限 + self._set_archive_permissions() + + def _set_archive_permissions(self): + """ + 设置所有归档日志文件的权限为440 + """ + # 获取日志目录和基本文件名 + dir_name, base_name = os.path.split(self.baseFilename) + + # 获取所有可能的归档文件 + file_names = os.listdir(dir_name) + for file_name in file_names: + if file_name.startswith(base_name) and file_name != os.path.basename(self.baseFilename): + archive_path = os.path.join(dir_name, file_name) + # 归档的日志文件权限设置为440 + os.chmod(archive_path, 0o440) + + LOGGING_CONFIG = { 'version': 1, 'disable_existing_loggers': True, @@ -36,7 +78,7 @@ LOGGING_CONFIG = { 'stream': 'ext://sys.stdout', }, 'fileHandler': { - 'class': 'logging.handlers.RotatingFileHandler', + '()': SecureRotatingFileHandler, 'level': 'DEBUG', 'formatter': 'fileFormatter', 'filename': OA_CONF.log_file_path, @@ -56,6 +98,8 @@ LOGGING_CONFIG = { # 检查并创建log文件夹 if not os.path.exists(OA_CONF.log_dir): os.makedirs(OA_CONF.log_dir) + # 日志文件目录权限设置为750 + os.chmod(OA_CONF.log_dir, 0o750) # 使用dictConfig加载配置 dictConfig(LOGGING_CONFIG) diff --git a/omniadvisor/src/omniadvisor/utils/utils.py b/omniadvisor/src/omniadvisor/utils/utils.py index 7bc693626217302148d7c591e2c641853486dc75..e822a09dab1a7ff44e1dda8d3624f7eb89f0d014 100644 --- a/omniadvisor/src/omniadvisor/utils/utils.py +++ b/omniadvisor/src/omniadvisor/utils/utils.py @@ -38,10 +38,14 @@ def save_trace_data(data: List[Dict[str, str]], data_dir): """ if not os.path.exists(data_dir): os.makedirs(data_dir) + # 数据目录权限设置为750 + os.chmod(data_dir, 0o750) file_path = "".join([data_dir, "/", str(uuid.uuid4())]) try: with open(file_path, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=4) + # 数据文件权限设置为640 + os.chmod(file_path, 0o640) global_logger.debug(f"数据已成功保存到 {file_path}") except IOError as e: raise IOError(f"出现IO错误: {e}") from e diff --git a/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py b/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py index 1926f6cc09816c8a6411558cbe0d9d77eef56e5c..a078b85f552304a1bd975b2498701aaa791b7938 100644 --- a/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py +++ b/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py @@ -11,6 +11,10 @@ from omniadvisor.interface.hijack_recommend import hijack_recommend, _process_lo # 测试代码 class TestHijackRecommend: + def setup_class(self): + # 为了避免抓取trace子进程的等待 + OA_CONF.spark_fetch_trace_timeout = 0 + # 场景 1: 正常流程,Spark 执行成功 @patch("omniadvisor.interface.hijack_recommend._process_load_config") @patch("omniadvisor.interface.hijack_recommend.spark_run") diff --git a/omniadvisor/tests/omniadvisor/utils/test_utils.py b/omniadvisor/tests/omniadvisor/utils/test_utils.py index cd6293ac89e1d2ace444c91af07d934e25714185..7a580175398d030c0d8e067d08490be17bd0e97b 100644 --- a/omniadvisor/tests/omniadvisor/utils/test_utils.py +++ b/omniadvisor/tests/omniadvisor/utils/test_utils.py @@ -32,9 +32,10 @@ class TestRunCmd: class TestTraceDataSaver: @patch('builtins.open', new_callable=mock_open, create=True) + @patch('os.chmod') @patch('os.makedirs') @patch('uuid.uuid4', return_value="test-uuid") - def test_save_trace_data_success(self, mock_uuid, mock_makedirs, mock_file): + def test_save_trace_data_success(self, mock_uuid, mock_makedirs, mock_chmod, mock_file): data = [{"key": "value"}] data_dir = "/tmp" @@ -46,8 +47,9 @@ class TestTraceDataSaver: @patch('builtins.open', new_callable=mock_open, create=True) @patch('uuid.uuid4', return_value="test-uuid") + @patch('os.chmod') @patch('os.makedirs') - def test_save_trace_data_ioerror(self, mock_makedirs, mock_uuid, mock_open_file): + def test_save_trace_data_ioerror(self, mock_makedirs, mock_chmod, mock_uuid, mock_open_file): data = [{"key": "value"}] data_dir = "/tmp" @@ -60,10 +62,11 @@ class TestTraceDataSaver: except IOError as e: assert str(e) == "出现IO错误: IO Error", f"Unexpected error message: {str(e)}" + @patch('os.chmod') @patch('os.makedirs') @patch('uuid.uuid4', return_value="test-uuid") @patch('builtins.open', new_callable=mock_open) - def test_save_trace_data_exception(self, mock_open, mock_uuid, mock_makedirs): + def test_save_trace_data_exception(self, mock_open, mock_uuid, mock_makedirs, mock_chmod): data = [{"key": "value"}] data_dir = "/tmp" mock_open.side_effect = Exception("Unexpected error") # 设置 mock_open 抛出异常