From d2b3710fcdeb0e2a7bd95670dfaffdc3277e72aa Mon Sep 17 00:00:00 2001 From: yanhe13 Date: Fri, 11 Jul 2025 16:51:34 +0800 Subject: [PATCH 1/2] calculate performance of perf data upgrade --- .../default_perf_metric_calculator.py | 18 ++- .../stable_perf_metric_calculator.py | 81 ++++++------ .../benchmark/clients/base_client.py | 6 +- .../models/huggingface_above_v4_33.py | 2 +- .../benchmark/models/performance_api.py | 4 +- .../benchmark/models/vllm_custom_api_chat.py | 2 +- .../icl_inferencer/icl_gen_inferencer.py | 19 +-- .../icl_inferencer/icl_gen_perf_inferencer.py | 30 ++--- .../icl_gen_pressure_inferencer.py | 9 +- .../ais_bench/benchmark/runners/local_api.py | 2 +- .../benchmark/summarizers/default_perf.py | 26 ++-- .../benchmark/tasks/openicl_infer.py | 4 +- .../ais_bench/benchmark/utils/results.py | 15 ++- .../benchmark/utils/summarize_plot.py | 120 +++++++++--------- .../benchmark/requirements/runtime.txt | 1 + 15 files changed, 174 insertions(+), 165 deletions(-) diff --git a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/calculators/default_perf_metric_calculator.py b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/calculators/default_perf_metric_calculator.py index 83344578..85853b11 100644 --- a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/calculators/default_perf_metric_calculator.py +++ b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/calculators/default_perf_metric_calculator.py @@ -75,7 +75,8 @@ class DefaultPerfMetricCalculator(BasePerfMetricCalculator): else: result["average_decode_latencies"] = result["prefill_latency"] self.logger.info("Converting perf results of stage ...") - self.result[stage_name] = self.convert_result(copy.deepcopy(result)) + self.result[stage_name] = self.convert_result(result) + self.logger.info("Finish Converting!") def get_common_res(self): return {k: v for k, v in self.common_metrics.items() if v is not None} @@ -155,9 +156,13 @@ class DefaultPerfMetricCalculator(BasePerfMetricCalculator): return ans def calculate(self): + self.logger.info("Start calculating metrics ...") self.__calc_metrics() + self.logger.info("Start calculating common metrics ...") self.__calc_common_metrics() + self.logger.info("Start calculating add units ...") self.add_units() + self.logger.info("Finish calculating perf data!") def __calc_metrics(self): """Calculate various statistical metrics for performance analysis.""" @@ -171,17 +176,18 @@ class DefaultPerfMetricCalculator(BasePerfMetricCalculator): value = self.__statistic_prefill_or_decode_batch_size(value) # Compute statistical values + arr = np.array(value) for stat in self.stats_list: if stat == "Average": - stats[stat] = round(np.average(value), 4) + stats[stat] = round(arr.mean(), 4) elif stat == "Min": - stats[stat] = round(float(min(value)), 4) + stats[stat] = round(float(arr.min()), 4) elif stat == "Max": - stats[stat] = round(float(max(value)), 4) + stats[stat] = round(float(arr.max()), 4) elif stat == "Median": - stats[stat] = round(np.percentile(value, 50), 4) + stats[stat] = round(np.percentile(arr, 50), 4) elif is_legal_percentage_str(stat): - stats[stat] = round(np.percentile(value, int(stat[1:])), 4) + stats[stat] = round(np.percentile(arr, int(stat[1:])), 4) # Store the computed metrics if self.metrics.get(metric) is None: diff --git a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/calculators/stable_perf_metric_calculator.py 
b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/calculators/stable_perf_metric_calculator.py index 83361045..4d81cea4 100644 --- a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/calculators/stable_perf_metric_calculator.py +++ b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/calculators/stable_perf_metric_calculator.py @@ -1,4 +1,5 @@ import csv +import heapq from tqdm import tqdm import collections import math @@ -9,7 +10,7 @@ from ais_bench.benchmark.calculators.base_perf_metric_calculator import BasePerf from ais_bench.benchmark.registry import PERF_METRIC_CALCULATORS from ais_bench.benchmark.calculators.base_perf_metric_calculator import is_legal_percentage_str, DEFAULT_STATS, MAX_STATS_LEN -WAVE_OFFSET = 0.02 +WAVE_OFFSET = 0.05 @PERF_METRIC_CALCULATORS.register_module() class StablePerfMetricCalculator(BasePerfMetricCalculator): @@ -41,43 +42,45 @@ class StablePerfMetricCalculator(BasePerfMetricCalculator): self._process_result(perf_details.get("requests"), stage_name) def _get_requests_id(self, perf_details): - request_time_sections = [] - for id in range(len(perf_details["requests"]["id"])): - request_time_sections.append({ - "id": id, - "start_time": perf_details["requests"]["start_time"][id], - "end_time": perf_details["requests"]["end_time"][id], - }) - + request_time_sections = [ + {"id": id, "start_time": perf_details["requests"]["start_time"][id], "end_time": perf_details["requests"]["end_time"][id]} + for id in range(len(perf_details["requests"]["id"])) + ] sorted_time_sections = sorted(request_time_sections, key=lambda x: x["start_time"]) + + active_heap = [] # min-heap holding (end_time, id) id_lists = [] - working_reqs = {} self.logger.info("Calculating stable stage ...") - for i, section in enumerate(tqdm(sorted_time_sections)): - poped_ids = [] - for k in list(working_reqs.keys()): - if working_reqs[k][1] < section["start_time"]: - poped_ids.append(k) - working_reqs.pop(k, None) - working_reqs[section["id"]] = [section["start_time"], section["end_time"]] - if len(working_reqs) == self.max_concurrency: + + for section in tqdm(sorted_time_sections): + # 1. Evict expired requests and record the smallest end time among them + poped_min_end = None + while active_heap and active_heap[0][0] < section["start_time"]: + end_time, req_id = heapq.heappop(active_heap) + poped_min_end = end_time if poped_min_end is None else min(poped_min_end, end_time) + + # 2. Add the current request + heapq.heappush(active_heap, (section["end_time"], section["id"])) + current_active = len(active_heap) + + # 3. Check whether the stable stage has been reached + if current_active == self.max_concurrency: id_lists.append(section["id"]) if len(id_lists) == 1: - self.stage_section[0] = min([perf_details["requests"]["end_time"][id] for id in list(working_reqs.keys())]) # total start time - elif len(working_reqs) >= int(self.max_concurrency * (1 - WAVE_OFFSET)) and len(id_lists) > 0: + self.stage_section[0] = active_heap[0][0] # heap top is the smallest end time + elif current_active >= int(self.max_concurrency * (1 - WAVE_OFFSET)) and len(id_lists) > 0: id_lists.append(section["id"]) - else: - if len(id_lists) > 0: # start to leave stable - self.stage_section[1] = min([perf_details["requests"]["end_time"][id] for id in poped_ids]) - break - - if len(id_lists) > 0: - id_lists.pop(0) # ignore first request that reached max concurrency - if len(id_lists) == 0: + elif len(id_lists) > 0: # leaving the stable stage + self.stage_section[1] = poped_min_end if poped_min_end is not None else active_heap[0][0] + break + + # 4. 
后处理 + if id_lists: + id_lists.pop(0) + if not id_lists: raise RuntimeError("Can not find a stable stage!") - if self.stage_section[1] == 0: - self.stage_section[1] = min([perf_details["requests"]["end_time"][id] for id in list(working_reqs.keys())]) # total end time + self.stage_section[1] = active_heap[0][0] if active_heap else sorted_time_sections[-1]["end_time"] return id_lists def _get_legal_stats_list(self, stats_list): @@ -116,7 +119,8 @@ class StablePerfMetricCalculator(BasePerfMetricCalculator): else: result["average_decode_latencies"] = result["prefill_latency"] self.logger.info("Converting perf results of stage ...") - self.result[stage_name] = self.convert_result(copy.deepcopy(result)) + self.result[stage_name] = self.convert_result(result) + self.logger.info("Finish Converting!") def get_common_res(self): return {k: v for k, v in self.common_metrics.items() if v is not None} @@ -196,9 +200,13 @@ class StablePerfMetricCalculator(BasePerfMetricCalculator): return ans def calculate(self): + self.logger.info("Start calculating metrics ...") self.__calc_metrics() + self.logger.info("Start calculating common metrics ...") self.__calc_common_metrics() + self.logger.info("Start calculating add units ...") self.add_units() + self.logger.info("Finish calculating perf data!") def __calc_metrics(self): """Calculate various statistical metrics for performance analysis.""" @@ -212,17 +220,18 @@ class StablePerfMetricCalculator(BasePerfMetricCalculator): value = self.__statistic_prefill_or_decode_batch_size(value) # Compute statistical values + arr = np.array(value) for stat in self.stats_list: if stat == "Average": - stats[stat] = round(np.average(value), 4) + stats[stat] = round(arr.mean(), 4) elif stat == "Min": - stats[stat] = round(float(min(value)), 4) + stats[stat] = round(float(arr.min()), 4) elif stat == "Max": - stats[stat] = round(float(max(value)), 4) + stats[stat] = round(float(arr.max()), 4) elif stat == "Median": - stats[stat] = round(np.percentile(value, 50), 4) + stats[stat] = round(np.percentile(arr, 50), 4) elif is_legal_percentage_str(stat): - stats[stat] = round(np.percentile(value, int(stat[1:])), 4) + stats[stat] = round(np.percentile(arr, int(stat[1:])), 4) # Store the computed metrics if self.metrics.get(metric) is None: diff --git a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/clients/base_client.py b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/clients/base_client.py index 2cabfca3..5229529b 100644 --- a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/clients/base_client.py +++ b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/clients/base_client.py @@ -179,7 +179,7 @@ class BaseClient(ABC): raise_error(f"Error processing stream response: {e}", self.lock, self.request_counter) except HTTPError as e: raise_error(f"HTTP error during stream response processing: {e}.", self.lock, self.request_counter) - + self.rev_count() self.update_request_time(inputs, start_time) return "".join(response) @@ -209,9 +209,7 @@ class BaseStreamClient(BaseClient, ABC): cur_time_point = time.perf_counter() response_dict = self.process_stream_line(json_content) if time_name not in response_dict.keys(): - response_dict[time_name] = ( - cur_time_point - last_time_point - ) * 1000 + response_dict[time_name] = round((cur_time_point - last_time_point) * 1000, 4) response_dict["chunk_time_point"] = cur_time_point * 1000 yield response_dict time_name = "decode_time" diff --git 
a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/models/huggingface_above_v4_33.py b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/models/huggingface_above_v4_33.py index 85c2c32a..2c9982e5 100644 --- a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/models/huggingface_above_v4_33.py +++ b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/models/huggingface_above_v4_33.py @@ -201,7 +201,7 @@ class HuggingFacewithChatTemplate(PerformanceModel): for k, v in other_kwargs.items(): if v is not None: self.logger.warning(f'Unused argument {k}={v}') - + def handle_perf_result(self, output_filepath, output_filename): e2e_latency = max(self.timestamps) - min(self.timestamps) return {"Benchmark Duration":{"total":str(round(e2e_latency, 4)) + ' ms'}} diff --git a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/models/performance_api.py b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/models/performance_api.py index ba94dda0..2b056031 100644 --- a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/models/performance_api.py +++ b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/models/performance_api.py @@ -77,7 +77,7 @@ class PerformanceAPIModel(BaseAPIModel): cache_data.num_input_chars = 0 cache_data.input_token_id = token_id cache_data.num_input_tokens = len(token_id) - + def set_result(self, data: MiddleData) -> None: """Update decoding information for a given request.""" if not data.output: @@ -146,10 +146,12 @@ class PerformanceAPIModel(BaseAPIModel): self.result_cache[key].num_generated_tokens = len(tokens) performance_data = [] try: + self.logger.info("Start converting origin perf data ...") performance_data = [ cache_data.convert_to_performance_data() for cache_data in self.result_cache.values() ] + self.logger.info("Finish converting origin perf data") except Exception as e: self.logger.error(f"Error converting performance data: {e}") finally: diff --git a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/models/vllm_custom_api_chat.py b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/models/vllm_custom_api_chat.py index f7764548..3743c4ad 100644 --- a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/models/vllm_custom_api_chat.py +++ b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/models/vllm_custom_api_chat.py @@ -176,7 +176,7 @@ class VLLMCustomAPIChat(PerformanceAPIModel): elif item['role'] == 'SYSTEM': msg['role'] = 'system' messages.append(msg) - + generation_kwargs = self.generation_kwargs.copy() generation_kwargs.update({"max_tokens": max_out_len}) generation_kwargs.update({"model": self.model}) diff --git a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_inferencer.py b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_inferencer.py index 75c3eba3..57c0676c 100644 --- a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_inferencer.py +++ b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_inferencer.py @@ -9,10 +9,10 @@ import multiprocessing from pathlib import Path from multiprocessing import RLock, freeze_support from typing import List, Optional, Tuple, Any - import torch import shutil from tqdm import tqdm +import itertools from 
ais_bench.benchmark.models.base import BaseModel from ais_bench.benchmark.registry import ICL_INFERENCERS @@ -24,6 +24,7 @@ from ..icl_prompt_template import PromptTemplate from ..icl_retriever import BaseRetriever from ..utils.logging import get_logger from .icl_base_inferencer import BaseInferencer, GenInferencerOutputHandler +import concurrent.futures logger = get_logger(__name__) @@ -43,7 +44,7 @@ def submit_single_model(model_cfg, mp_queue, **extra_gen_kwargs): raise AttributeError(f'{model} has no except outputs, please check model config') return model.get_performance_data() - + @ICL_INFERENCERS.register_module() class GenInferencer(BaseInferencer): """Generation Inferencer class to directly evaluate by generation. @@ -114,7 +115,7 @@ class GenInferencer(BaseInferencer): logger.warning(f"Inputs data number is {len(inputs)}, result will be empty") return results max_concurrency = extra_gen_kwargs.get("batch_size", 1) - + # Maximum MAX_CONCURRENCY_PER_PROCESS concurrency per process, number of processes less than number of cores workers_num = min( multiprocessing.cpu_count(), (max_concurrency - 1) // DEFAULT_MAX_CONCURRENCY_PER_PROCESS + 1 @@ -140,7 +141,7 @@ class GenInferencer(BaseInferencer): data_buckets = [] real_data_nums = [] bucket_index = 0 - data_index = 0 + data_index = 0 while data_index Tuple[List, List]: @@ -232,7 +233,7 @@ class GenInferencer(BaseInferencer): if hasattr(self.model, "set_performance"): extra_kwargs['do_performance'] = self.model.do_performance return extra_kwargs - + def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTemplate] = None, diff --git a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_perf_inferencer.py b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_perf_inferencer.py index 6e1b96d6..d054a5cf 100644 --- a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_perf_inferencer.py +++ b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_perf_inferencer.py @@ -21,7 +21,7 @@ from ..icl_prompt_template import PromptTemplate from ..icl_retriever import BaseRetriever from ..utils.logging import get_logger from .icl_base_inferencer import GenInferencerOutputHandler -from ais_bench.benchmark.utils.results import dump_results_dict +from ais_bench.benchmark.utils.results import dump_results_dict, fast_dump_results_dict from .icl_gen_inferencer import GenInferencer logger = get_logger(__name__) @@ -121,29 +121,15 @@ class GenPerfInferencer(GenInferencer): parsed_entries = self.model.parse_template(entry, mode='gen') results = self.inference_with_multi_process( self.model, self.model_cfg, parsed_entries, golds, **extra_gen_kwargs) - results.sort(key=lambda x: x['id']) + logger.info("Start extracting pref datas ...") preds = self.extract_preds(results) + logger.info("Finish extracting pref datas!") task_params = {"max_concurrency": self.batch_size} num_return_sequences = getattr(self.model, "generation_kwargs", {}).get( "num_return_sequences", 1 ) - for prediction in batched(results, num_return_sequences): - if num_return_sequences == 1: - prediction = prediction[0] - if not prediction.get('is_success'): - pred = "" - else: - pred = prediction.get('output') - data_id = prediction.get('id') - if data_id >= len(golds) or data_id < 0: - raise IndexError(f"No gold of output id {data_id}") - output_handler.save_results(parsed_entries[data_id], - 
pred, - data_id, - gold=golds[data_id]) - end_time_stamp = time.perf_counter() if self.is_main_process: @@ -153,11 +139,12 @@ class GenPerfInferencer(GenInferencer): "requests": preds, } logger.info("Dumping detail perf data ...") - dump_results_dict( + dump_start = time.perf_counter() + fast_dump_results_dict( perf_details, - osp.join(output_filepath, output_filename + "_details.json"), - False + osp.join(output_filepath, output_filename + "_details.json") ) + logger.info(f"Dump detail perf data cost: {time.perf_counter() - dump_start}(s)") if self.dump_timer and self.is_main_process: timer_filepath = osp.join(output_filepath, "timer", "time.jsonl") @@ -188,6 +175,9 @@ class GenPerfInferencer(GenInferencer): } preds["is_success"] = [pred.get("is_success", False) for pred in results] preds["is_empty"] = [pred.get("is_empty", False) for pred in results] + del preds["chunk_time_point_list"] + del preds["input_data"] + del preds["output"] return preds diff --git a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_pressure_inferencer.py b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_pressure_inferencer.py index efd239e5..fd32052c 100644 --- a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_pressure_inferencer.py +++ b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_pressure_inferencer.py @@ -20,7 +20,7 @@ from ais_bench.benchmark.utils.build import build_model_from_cfg from ais_bench.benchmark.global_consts import WORKERS_NUM from ..utils.logging import get_logger from .icl_base_inferencer import GenInferencerOutputHandler -from ais_bench.benchmark.utils.results import dump_results_dict +from ais_bench.benchmark.utils.results import dump_results_dict, fast_dump_results_dict from .icl_gen_perf_inferencer import GenPerfInferencer from .icl_gen_inferencer import DEFAULT_MAX_CONCURRENCY_PER_PROCESS @@ -167,7 +167,6 @@ class GenPressureInferencer(GenPerfInferencer): parsed_entries = self.model.parse_template(entry, mode='gen') results = self.pressure_infer_with_multiprocess( self.model, self.model_cfg, parsed_entries, golds, **extra_gen_kwargs) - results.sort(key=lambda x: x['id']) preds = self.extract_preds(results) preds['id'] = [i for i in range(len(preds['request_id']))] task_params = {"max_concurrency": self.batch_size} @@ -180,11 +179,13 @@ class GenPressureInferencer(GenPerfInferencer): "task": task_params, "requests": preds, } - dump_results_dict( + logger.info("Dumping detail perf data ...") + dump_start = time.perf_counter() + fast_dump_results_dict( perf_details, osp.join(output_filepath, output_filename + "_details.json"), - False ) + logger.info(f"Dump detail perf data cost: {time.perf_counter() - dump_start}(s)") if self.dump_timer and self.is_main_process: timer_filepath = osp.join(output_filepath, "timer", "time.jsonl") diff --git a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/runners/local_api.py b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/runners/local_api.py index c779ae73..903b0146 100644 --- a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/runners/local_api.py +++ b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/runners/local_api.py @@ -17,7 +17,7 @@ from ais_bench.benchmark.registry import RUNNERS, TASKS from ais_bench.benchmark.tasks import OpenICLInferTask, OpenICLPerfTask, 
OpenICLInferMergedTask from ais_bench.benchmark.tasks.base import BaseTask from ais_bench.benchmark.utils import (build_dataset_from_cfg, build_synthetic_dataset_from_cfg, - build_model_from_cfg, get_infer_output_path, + build_model_from_cfg, get_infer_output_path, get_logger, task_abbr_from_cfg) from .base import BaseRunner diff --git a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/summarizers/default_perf.py b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/summarizers/default_perf.py index 2305ca1e..73bc6de8 100644 --- a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/summarizers/default_perf.py +++ b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/summarizers/default_perf.py @@ -5,6 +5,7 @@ import getpass import math import csv import json +import orjson import os.path as osp from datetime import datetime from typing import Any, Dict, List, Optional @@ -75,19 +76,18 @@ class DefaultPerfSummarizer: perf_details_file = osp.join(self.work_dir, "performances", model, f"{dataset}_details.json") if not osp.exists(perf_details_file): continue - with open(perf_details_file, 'r', encoding='utf-8') as file: - self.logger.info(f"Loading detail perf data of {model=} {dataset=} ...") - details_data = json.load(file) - plot_file_path = osp.join(self.work_dir, "performances", model, f"{dataset}_plot.html") - has_plot = plot_sorted_request_timelines( - details_data["requests"]["start_time"], - details_data["requests"]["prefill_latency"], - details_data["requests"]["end_time"], - details_data["requests"]["decode_token_latencies"], - output_file=plot_file_path, unit="s" - ) - if has_plot: - self.logger.info(f"The {dataset}_plot has been saved in {plot_file_path}") + self.logger.info(f"Loading detail perf data of {model=} {dataset=} ...") + details_data = orjson.loads(open(perf_details_file, "rb").read()) + plot_file_path = osp.join(self.work_dir, "performances", model, f"{dataset}_plot.html") + has_plot = plot_sorted_request_timelines( + details_data["requests"]["start_time"], + details_data["requests"]["prefill_latency"], + details_data["requests"]["end_time"], + details_data["requests"]["decode_token_latencies"], + output_file=plot_file_path, unit="s" + ) + if has_plot: + self.logger.info(f"The {dataset}_plot has been saved in {plot_file_path}") calculators_per_model[dataset] = build_perf_metric_calculator_from_cfg(calculator_conf) try: calculators_per_model[dataset]._init_datas(details_data) diff --git a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/tasks/openicl_infer.py b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/tasks/openicl_infer.py index dd22d370..89e4abc2 100644 --- a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/tasks/openicl_infer.py +++ b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/tasks/openicl_infer.py @@ -106,7 +106,7 @@ class OpenICLInferTask(BaseTask): self._set_default_value(inferencer_cfg, 'batch_size', self.batch_size) inferencer_cfg['max_seq_len'] = self.model_cfg.get('max_seq_len') self.inferencer = ICL_INFERENCERS.build(inferencer_cfg) - + def _inference(self): self.logger.info( f'Start inferencing {task_abbr_from_cfg(self.sub_cfg)}') @@ -127,7 +127,7 @@ class OpenICLInferTask(BaseTask): # set inferencer's default value according to model's config' self.build_inference() - + self.inferencer.update_model_cfg(self.model_cfg) out_path = get_infer_output_path( diff --git 
a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/utils/results.py b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/utils/results.py index 9f70a987..2bc40a79 100644 --- a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/utils/results.py +++ b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/utils/results.py @@ -1,5 +1,6 @@ import csv import json +import orjson import collections import math from typing import Optional, Dict, Any @@ -17,6 +18,10 @@ def dump_results_dict(results_dict, filename, formatted = True): else: json.dump(results_dict, json_file, ensure_ascii=False) +def fast_dump_results_dict(results_dict, filename): + with open(filename, 'wb') as f: + f.write(orjson.dumps(results_dict)) + @dataclass class MiddleData: @@ -65,9 +70,7 @@ class MiddleData: "output": self.output, "output_token_id": self.output_token_id, "prefill_latency": self.prefill_latency, - "prefill_throughput": len(self.input_token_id) - / self.prefill_latency - * 1000 if self.prefill_latency > 0 else 0, + "prefill_throughput": round(len(self.input_token_id) / self.prefill_latency * 1000, 4) if self.prefill_latency > 0 else 0, "decode_token_latencies": self.decode_cost[:], "last_decode_latency": self.decode_cost[-1] if self.decode_cost else 0.0, "decode_max_token_latency": ( @@ -76,13 +79,11 @@ class MiddleData: "seq_latency": self.req_latency, "input_tokens_len": self.num_input_tokens, "generate_tokens_len": self.num_generated_tokens, - "generate_tokens_speed": self.num_generated_tokens - / self.req_latency - * 1000 if self.req_latency > 0 else 0, + "generate_tokens_speed": round(self.num_generated_tokens / self.req_latency * 1000, 4) if self.req_latency > 0 else 0, "input_characters_len": len(self.input_data), "generate_characters_len": self.num_generated_chars, "characters_per_token": ( - self.num_generated_chars / self.num_generated_tokens + round(self.num_generated_chars / self.num_generated_tokens, 4) if self.num_generated_tokens else 0.0 ), diff --git a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/utils/summarize_plot.py b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/utils/summarize_plot.py index 28f18a74..9a384111 100644 --- a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/utils/summarize_plot.py +++ b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/utils/summarize_plot.py @@ -29,9 +29,9 @@ TIMELINE_POINTS_PER_REQUEST = 3 # 每个请求在时间线图中占3个点( # ================== 辅助函数 ================== def validate_input_data( - start_time_list: List[float], + start_time_list: List[float], prefill_latency_list: List[float], - end_time_list: List[float], + end_time_list: List[float], decode_token_latencies_list: List[List[float]], ) -> bool: """验证输入数据是否合法""" @@ -40,15 +40,15 @@ def validate_input_data( if n_requests == 0: logger.warning("No requests to plot!") return False - - if (n_requests != len(prefill_latency_list) or + + if (n_requests != len(prefill_latency_list) or n_requests != len(end_time_list) or n_requests != len(decode_token_latencies_list)): logger.warning("Input list lengths mismatch! 
Details: ") logger.warning(f"start_list:{n_requests}, prefill_latency_list:{len(prefill_latency_list)}") logger.warning(f"end_list:{len(end_time_list)}, decode_token_latencies_list:{len(decode_token_latencies_list)}") return False - + return True def is_non_streaming_scenario( @@ -59,9 +59,9 @@ def is_non_streaming_scenario( return all(p == 0.0 for p in prefill_latency_list) def preprocess_data( - start_time_list: List[float], + start_time_list: List[float], prefill_latency_list: List[float], - end_time_list: List[float], + end_time_list: List[float], decode_token_latencies_list: List[List[float]], ) -> Tuple[Optional[np.ndarray], np.ndarray, np.ndarray, bool]: """ @@ -71,13 +71,13 @@ def preprocess_data( start = np.asarray(start_time_list, dtype=np.float64) prefill = np.asarray(prefill_latency_list, dtype=np.float64) / 1000 # prefill数据单位为ms,而其他数据均为s end = np.asarray(end_time_list, dtype=np.float64) - + # 检测是否是非流式场景 is_non_streaming = is_non_streaming_scenario(prefill_latency_list, decode_token_latencies_list) - + # 计算首token时间 first_token_times = (start + prefill) if not is_non_streaming else None - + # 对每条请求是否含有非首token时延判断请求索引对应的end_time是否需要更新, # 因为end_time_list因为打点位置会有误差,需用first_token_time_list的值修正 # 仅在非流式场景修正结束时间 @@ -87,15 +87,15 @@ def preprocess_data( end[no_decode_indices] = first_token_times[no_decode_indices] get_logger().debug(f"Adjusted {len(no_decode_indices)} requests with no decode tokens") del no_decode_indices - + # 计算全局最小时间 global_x_min = np.min(start) if len(start) > 0 else 0.0 - + # 计算相对时间 adjusted_starts = start - global_x_min adjusted_first_tokens = (first_token_times - global_x_min) if not is_non_streaming else None adjusted_ends = end - global_x_min - + return adjusted_first_tokens, adjusted_starts, adjusted_ends, is_non_streaming def generate_timeline_traces( @@ -108,29 +108,29 @@ def generate_timeline_traces( n_requests = len(adjusted_starts) if n_requests == 0: return [] - + # 预分配内存 red_x = np.full(TIMELINE_POINTS_PER_REQUEST * n_requests, np.nan, dtype=np.float32) red_y = np.full_like(red_x, np.nan) blue_x = np.full_like(red_x, np.nan) blue_y = np.full_like(red_x, np.nan) hover_text = np.full(TIMELINE_POINTS_PER_REQUEST * n_requests, None, dtype=object) - sorted_indices = np.argsort(adjusted_starts) - + sorted_indices = np.argsort(adjusted_starts) + for sorted_pos, orig_idx in enumerate(sorted_indices): # 获取当前请求的关键时间点 start_t = adjusted_starts[orig_idx] first_token_t = adjusted_first_tokens[orig_idx] end_t = adjusted_ends[orig_idx] - + # 计算数组中的位置 arr_idx = sorted_pos * 3 - + # 红线段(TTFT):从开始到第一个token red_x[arr_idx] = start_t red_x[arr_idx + 1] = first_token_t red_y[arr_idx:arr_idx + 2] = sorted_pos + 1 - + blue_content_data = "NaN" # 蓝线段(Decode):从第一个token到结束 @@ -140,11 +140,11 @@ def generate_timeline_traces( blue_y[arr_idx:arr_idx + 2] = sorted_pos + 1 decode_time = end_t - first_token_t blue_content_data = f"{first_token_t:.2f}→{end_t:.2f}={decode_time:.2f}" - + # 悬停文本,触发点在红线段起点 ttft = first_token_t - start_t e2e = end_t - start_t - + red_content = f"TTFT({unit}): {start_t:.2f}→{first_token_t:.2f}={ttft:.2f}
" blue_content = f"Decode({unit}): {blue_content_data}
" e2e_content = f"E2E({unit}): {start_t:.2f}→{end_t:.2f}={e2e:.2f}" @@ -155,12 +155,12 @@ def generate_timeline_traces( n_points = len(red_x) chunk_size = min(n_points, MAX_POINTS_PER_TRACE) n_chunks = (n_points + chunk_size - 1) // chunk_size - + for i in range(n_chunks): start_idx = i * chunk_size end_idx = min((i + 1) * chunk_size, n_points) chunk = slice(start_idx, end_idx) - + # 红线段 if np.any(~np.isnan(red_x[chunk])): traces.append(go.Scattergl( @@ -173,7 +173,7 @@ def generate_timeline_traces( showlegend=False, connectgaps=False )) - + # 蓝线段 if np.any(~np.isnan(blue_x[chunk])): traces.append(go.Scattergl( @@ -200,22 +200,22 @@ def generate_concurrency_traces( if not np.any(valid_mask): get_logger().warning("No valid requests for concurrency plot!") return [] - + valid_starts = adjusted_starts[valid_mask] valid_ends = adjusted_ends[valid_mask] n_events = len(valid_starts) * 2 - + # 生成事件数组 events = np.empty((n_events, 2), dtype=np.float32) events[:len(valid_starts), 0] = valid_starts events[:len(valid_starts), 1] = 1 # 开始事件 events[len(valid_starts):, 0] = valid_ends events[len(valid_starts):, 1] = -1 # 结束事件 - + # 稳定排序(时间相同则开始事件优先) sort_indices = np.lexsort((events[:, 1], events[:, 0])) events = events[sort_indices] - + # 计算并发数 unique_times, inverse_indices = np.unique(events[:, 0], return_inverse=True) delta_per_time = np.bincount(inverse_indices, weights=events[:, 1]) @@ -223,28 +223,28 @@ def generate_concurrency_traces( conc_times = unique_times conc_counts = cumulative - + # 创建悬停文本 conc_hover_text = [ - f"Time: {t:.4f}{unit}
Concurrency: {c:.0f}" + f"Time: {t:.4f}{unit}
Concurrency: {c:.0f}" for t, c in zip(conc_times, conc_counts) ] - + # 分块渲染 traces = [] n_points = len(conc_times) chunk_size = min(n_points, MAX_POINTS_PER_TRACE) n_chunks = (n_points + chunk_size - 1) // chunk_size - + for i in range(n_chunks): start_idx = i * chunk_size end_idx = min((i + 1) * chunk_size, n_points) - + if i > 0: start_idx = max(0, start_idx - 1) # 确保连续 - + chunk = slice(start_idx, end_idx) - + traces.append(go.Scattergl( x=conc_times[chunk], y=conc_counts[chunk], @@ -257,7 +257,7 @@ def generate_concurrency_traces( showlegend=False, connectgaps=True )) - + # 清理大数组释放内存 del events, sort_indices, unique_times, inverse_indices, delta_per_time, cumulative del conc_times, conc_counts, conc_hover_text @@ -280,14 +280,14 @@ def create_plot_layout( title=f"Relative Time ({unit})", range=[0, max_time], ) - + yaxis_config = dict( **AXIS_CONFIG, rangemode='nonnegative', tickmode='auto', nticks=10, ) - + if has_timeline: # 双图模式 return dict( @@ -299,7 +299,7 @@ def create_plot_layout( ), yaxis1=dict( **yaxis_config, - title="Request Index", + title="Request Index", ), xaxis2=dict( **xaxis_config, @@ -334,24 +334,24 @@ def create_plot_layout( # ================== 对文件外使用的主函数 ================== def plot_sorted_request_timelines( - start_time_list: List[float], + start_time_list: List[float], prefill_latency_list: List[float], - end_time_list: List[float], + end_time_list: List[float], decode_token_latencies_list: List[List[float]], - output_file: str = "timeline.html", + output_file: str = "timeline.html", unit: str = "s" ) -> None: """绘制请求时间线和并发图表""" logger = get_logger() start_timestamp = time.perf_counter() - + # ===== 1. 数据验证和预处理 ===== logger.info("Starting request timeline processing...") - + # 验证输入数据 if not validate_input_data(start_time_list, prefill_latency_list, end_time_list, decode_token_latencies_list): return False - + # 数据预处理 preprocess_start = time.perf_counter() adjusted_first_token_times, adjusted_starts, adjusted_ends, is_non_streaming = preprocess_data( @@ -360,13 +360,13 @@ def plot_sorted_request_timelines( if is_non_streaming: logger.warning("[Non-streaming scenario] The plot will only show the request concurrency chart!") - + n_requests = len(start_time_list) has_timeline = not is_non_streaming and adjusted_first_token_times is not None and n_requests > 0 max_time = np.max(adjusted_ends) if n_requests > 0 else 1.0 - + logger.info(f"Data preprocessing completed in {time.perf_counter() - preprocess_start:.4f}s") - + # ===== 2. 生成时间线图轨迹(仅流式场景下) ===== timeline_traces = [] if has_timeline: @@ -376,26 +376,26 @@ def plot_sorted_request_timelines( adjusted_starts, adjusted_ends, adjusted_first_token_times, unit ) logger.info(f"Generated timeline trace chunks in {time.perf_counter() - timeline_start:.4f}s") - + # ===== 3. 生成并发图轨迹 ===== logger.info("Generating concurrency traces...") concurrency_start = time.perf_counter() concurrency_traces = generate_concurrency_traces(adjusted_starts, adjusted_ends, unit) - + logger.info(f"Generated concurrency trace chunks in {time.perf_counter() - concurrency_start:.4f}s") - + # ===== 4. 
创建图表 ===== logger.info("Creating figure layout...") figure_start = time.perf_counter() - + # 创建布局配置 layout = create_plot_layout(max_time, unit, has_timeline) - + # 创建图表对象 if has_timeline: fig = make_subplots( - rows=2, - cols=1, + rows=2, + cols=1, vertical_spacing=0.1, shared_xaxes=True ) @@ -407,16 +407,16 @@ def plot_sorted_request_timelines( fig = go.Figure() for trace in concurrency_traces: fig.add_trace(trace) - + # 应用布局配置 fig.update_layout(layout) - + logger.info(f"Figure layout created in {time.perf_counter() - figure_start:.4f}s") - + # ===== 5. 输出HTML ===== logger.info(f"Writing to {output_file}...") write_start = time.perf_counter() - + fig.write_html( output_file, include_plotlyjs='cdn', @@ -424,7 +424,7 @@ def plot_sorted_request_timelines( auto_open=False, full_html=True, ) - + logger.info(f"HTML written in {time.perf_counter() - write_start:.4f}s") total_time = time.perf_counter() - start_timestamp logger.info(f"Completed! Total execution time: {total_time:.4f}s") diff --git a/ais-bench_workload/experimental_tools/benchmark/requirements/runtime.txt b/ais-bench_workload/experimental_tools/benchmark/requirements/runtime.txt index 5bc15c19..3d7af85e 100644 --- a/ais-bench_workload/experimental_tools/benchmark/requirements/runtime.txt +++ b/ais-bench_workload/experimental_tools/benchmark/requirements/runtime.txt @@ -19,6 +19,7 @@ nltk>=3.7 numpy>=1.23.4,<2.0.0 openai opencv-python-headless +orjson pandas<2.0.0 plotly prettytable -- Gitee From 6eba9e02c07c18022571877607e505fe0160dfc3 Mon Sep 17 00:00:00 2001 From: yanhe13 Date: Mon, 14 Jul 2025 09:58:45 +0800 Subject: [PATCH 2/2] recover WAVE_OFFSET --- .../benchmark/calculators/stable_perf_metric_calculator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/calculators/stable_perf_metric_calculator.py b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/calculators/stable_perf_metric_calculator.py index 4d81cea4..c3794dd3 100644 --- a/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/calculators/stable_perf_metric_calculator.py +++ b/ais-bench_workload/experimental_tools/benchmark/ais_bench/benchmark/calculators/stable_perf_metric_calculator.py @@ -10,7 +10,7 @@ from ais_bench.benchmark.calculators.base_perf_metric_calculator import BasePerf from ais_bench.benchmark.registry import PERF_METRIC_CALCULATORS from ais_bench.benchmark.calculators.base_perf_metric_calculator import is_legal_percentage_str, DEFAULT_STATS, MAX_STATS_LEN -WAVE_OFFSET = 0.05 +WAVE_OFFSET = 0.02 @PERF_METRIC_CALCULATORS.register_module() class StablePerfMetricCalculator(BasePerfMetricCalculator): -- Gitee