From c6d8701975245b9242ec24912705ab5eff04a883 Mon Sep 17 00:00:00 2001 From: fineartz Date: Mon, 7 Oct 2024 14:59:12 +0000 Subject: [PATCH] add Simulator --- StreamLearn/Simulator/Policy.py | 53 ++++++ StreamLearn/Simulator/StreamEnv.py | 220 +++++++++++++++++++++++ StreamLearn/Simulator/StreamGenerator.py | 87 +++++++++ StreamLearn/Simulator/TaskRepr.py | 35 ++++ StreamLearn/Simulator/__init__.py | 0 StreamLearn/Simulator/test/algo.py | 75 ++++++++ StreamLearn/Simulator/test/dataset.py | 100 +++++++++++ StreamLearn/Simulator/test/main.py | 45 +++++ StreamLearn/Simulator/utils.py | 181 +++++++++++++++++++ 9 files changed, 796 insertions(+) create mode 100644 StreamLearn/Simulator/Policy.py create mode 100644 StreamLearn/Simulator/StreamEnv.py create mode 100644 StreamLearn/Simulator/StreamGenerator.py create mode 100644 StreamLearn/Simulator/TaskRepr.py create mode 100644 StreamLearn/Simulator/__init__.py create mode 100644 StreamLearn/Simulator/test/algo.py create mode 100644 StreamLearn/Simulator/test/dataset.py create mode 100644 StreamLearn/Simulator/test/main.py create mode 100644 StreamLearn/Simulator/utils.py diff --git a/StreamLearn/Simulator/Policy.py b/StreamLearn/Simulator/Policy.py new file mode 100644 index 0000000..01d28c0 --- /dev/null +++ b/StreamLearn/Simulator/Policy.py @@ -0,0 +1,53 @@ +from abc import ABC +from typing import List +import numpy as np + +from StreamLearn.Base.SemiEstimator import StreamEstimator + + +class BasePolicy(ABC): + + def __init__( + self, + num_node: int, + ) -> None: + self.num_node = num_node + + def act(self, obs): + """ + Choose an action based on observation. + The action should be a list of integers, each integer represents the action of a computing node. + + :param obs: observation. 
+ """ + raise NotImplementedError("The act() method of BasePolicy must be implemented.") + + +class IdentityPolicy(BasePolicy): + ''' + An identity policy for testing + Always put the task to the first node + ''' + + def act(self, obs): + return 0 + + +class RandomPolicy(BasePolicy): + ''' + A random policy for testing + Put the task to a random node + ''' + + def act(self, obs): + return np.random.randint(self.num_node) + + +class GreedyPolicy(BasePolicy): + ''' + A greedy policy for testing + Put the task to the node with the nearest finish time + ''' + def act(self, obs): + node_status = obs['node_status'] + return np.argmin(node_status) diff --git a/StreamLearn/Simulator/StreamEnv.py b/StreamLearn/Simulator/StreamEnv.py new file mode 100644 index 0000000..b7658ee --- /dev/null +++ b/StreamLearn/Simulator/StreamEnv.py @@ -0,0 +1,220 @@ +from enum import Enum +from typing import Any, Dict, List, Tuple +from queue import PriorityQueue +import numpy as np +from torch.utils.data import DataLoader + +from StreamLearn.Simulator.StreamGenerator import StreamGenerator +from StreamLearn.Base.SemiEstimator import StreamEstimator +from StreamLearn.Simulator.TaskRepr import TaskRepr +from StreamLearn.Simulator.utils import restrict_to_throughput + + +DEFAULT_NODE_THROUGHPUT = 1024 + +class Node: + + def __init__( + self, + throughput: int = DEFAULT_NODE_THROUGHPUT, + model: StreamEstimator = None, + ) -> None: + self.throughput = throughput + self.model = model + self._init() + + def _init(self): + self.t = 0 + # A list of (task_id, start_time, end_time) + self.task_list: List[Tuple[int, int, int]] = [] + self.task_cnt = 0 + # time when all tasks are finished + self.finish_time = 0 + # task_id -> start_time, end_time, statistics + self.full_history: Dict[int, Tuple[int, int, Any]] = {} + + def reset(self): + self._init() + + def add_task(self, task_id, estimated_time): + start_time = max(self.t, self.finish_time) + self.task_list.append((task_id, start_time, start_time + 
estimated_time)) + self.finish_time = start_time + estimated_time + self.full_history[task_id] = (start_time, self.finish_time, None) + return start_time, self.finish_time + + def forward_time(self, t: int): + self.t += t + while self.task_cnt < len(self.task_list): + task_id, start_time, end_time = self.task_list[self.task_cnt] + if self.t >= end_time: + self.task_cnt += 1 + else: + break + + def set_task_statistics(self, task_id, statistics): + if task_id not in self.full_history: + raise ValueError(f'Task {task_id} is not executed in this node.') + start_time, end_time, _ = self.full_history[task_id] + self.full_history[task_id] = (start_time, end_time, statistics) + + @property + def is_idle(self): + return self.t >= self.finish_time + + +class Env: + + def __init__( + self, + num_nodes: int, # number of computing nodes + stream_generator: StreamGenerator, # stream generator to generate stream tasks + model: StreamEstimator = None, # model for solving the tasks + task_repr: TaskRepr = None, # task representation + ) -> None: + self.num_node = num_nodes + self.stream_generator = stream_generator + self.model = model + self.task_repr = task_repr + + self._init() + self.obs, _ = self.reset() + + def _init(self): + self.t = 0 + # current task id + self.task_cnt = 0 + # each task is represented by task_id -> (start_time, end_time) + self.tasks: Dict[int, Tuple[int, int]] = {} + self.task_data: Dict[int, Any] = {} + # active tasks are sorted by their end time + self.task_queue = PriorityQueue() + # task_id -> node_id, only active tasks are in this map + self.task_map: Dict[int, int] = {} + # task_id -> statistics + self.task_statistics: Dict[int, Any] = {} + # node list + self.nodes = [Node() for _ in range(self.num_node)] + self.num_idle = self.num_node + + def _get_next_observation(self): + terminal = False + while True: + try: + task = next(self.stream_generator) + if task is not None: + break + self._forward_time(1) + except StopIteration: + terminal = True + 
break
+        if terminal:
+            return None, terminal
+        self.task_data[self.task_cnt] = task
+        # time to each node being idle
+        node_status = [max(0, n.finish_time - self.t) for n in self.nodes]
+        estimated_time = self.task_repr.get_required_time(task)
+        self.tasks[self.task_cnt] = (self.t, self.t + estimated_time)
+        self.task_map[self.task_cnt] = None
+        self.task_cnt += 1
+        self.obs = {
+            'node_status': node_status,
+            'estimated_time': estimated_time,
+        }
+        return self.obs, terminal
+
+    def _forward_time(self, t: int):
+        self.t += t
+        self.stream_generator.forward_time(t)
+        for node in self.nodes:
+            node.forward_time(t)
+        while not self.task_queue.empty():
+            end_time, task_id = self.task_queue.get()
+            if end_time > self.t:
+                self.task_queue.put((end_time, task_id))
+                break
+            self._handle_task(task_id)
+
+    def reset(self):
+        self._init()
+        self.stream_generator.reset()
+        obs, terminal = self._get_next_observation()
+        return obs, terminal
+
+    def step(self, act):
+        '''
+        act: an int to allocate the task to a node.
+        '''
+        try:
+            act = int(act)
+        except (TypeError, ValueError):  # int(None)/int(obj) raises TypeError, not ValueError
+            raise ValueError(f'Action {act} should be an integer.')
+        assert 0 <= act < self.num_node, f'Action {act} is out of range [0, {self.num_node - 1}].'
+        if self.nodes[act].is_idle:
+            self.num_idle -= 1
+        task_id = self.task_cnt - 1
+        estimated_time = self.obs['estimated_time']
+        st, ed = self.nodes[act].add_task(task_id, estimated_time)
+        self.task_map[task_id] = act
+        self.task_queue.put((ed, task_id))
+        if self.num_idle == 0:
+            # skip to the next time when a node is idle
+            end_time, task_id = self.task_queue.get()
+            assert end_time > self.t
+            self._forward_time(end_time - self.t)
+            self._handle_task(task_id)
+            self.num_idle += 1
+        else:
+            self._forward_time(1)
+        obs, terminal = self._get_next_observation()
+        return obs, terminal
+
+    def handle_remaining_tasks(self):
+        '''
+        Handle all remaining tasks after the stream is finished.
+ ''' + all_end_time = self.t + while not self.task_queue.empty(): + end_time, task_id = self.task_queue.get() + all_end_time = end_time + # self._forward_time(end_time - self.t) + self._handle_task(task_id) + self.t = all_end_time + + def _handle_task(self, task_id) -> Any: + data = self.task_data[task_id] + train_data, test_data = data.split() + self.model.fit(train_data) + + test_loader = DataLoader(test_data, batch_size=32, shuffle=False) + y_true = [] + y_pred = [] + + for batch in test_loader: + X_test, y_batch = batch + y_true.append(y_batch.numpy()) + y_pred_batch = self.model.predict(X_test) + y_pred.append(y_pred_batch.numpy()) + y_true = np.concatenate(y_true) + y_pred = np.concatenate(y_pred) + + statistics = self.model.evaluate(y_pred, y_true) + self.task_statistics[task_id] = statistics + node_id = self.task_map[task_id] + self.nodes[node_id].set_task_statistics(task_id, statistics) + if self.nodes[node_id].is_idle: + self.num_idle += 1 + + def log_statistics(self): + print(f'----- Time: {self.t} -----') + for task_id in self.tasks.keys(): + node_id = self.task_map[task_id] + if node_id is None: + print(f'Task {task_id}: To be allocated.') + continue + st, ed, stat = self.nodes[node_id].full_history[task_id] + if stat is None: + print(f'Task {task_id}: In node {node_id}, start_time={st}, running...') + else: + print(f'Task {task_id}: In node {node_id}, start_time={st}, end_time={ed}, statistics={stat}') + \ No newline at end of file diff --git a/StreamLearn/Simulator/StreamGenerator.py b/StreamLearn/Simulator/StreamGenerator.py new file mode 100644 index 0000000..431058a --- /dev/null +++ b/StreamLearn/Simulator/StreamGenerator.py @@ -0,0 +1,87 @@ +from typing import Dict +import numpy as np +from StreamLearn.Dataset.StreamDataset import StreamDataset + + +class StreamGenerator: + + def __init__( + self, + tasks: list[StreamDataset], + ) -> None: + self.tasks = tasks + self.num_tasks = len(tasks) + self._init() + + def _init(self): + self.task_cnt = 0 
+        self.t = 0
+        self.flag = False
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.task_cnt == self.num_tasks:
+            raise StopIteration()
+        task = None
+        self.flag = self._generate_new_task()
+        if self.flag:
+            task = self.tasks[self.task_cnt]
+            self.task_cnt += 1
+        return task
+
+    def reset(self):
+        self._init()
+
+    def forward_time(self, t: int):
+        self.t += t
+
+    def _generate_new_task(self) -> bool:
+        # This function determines whether to generate a new task
+        # Implement this in the subclass
+        raise NotImplementedError()  # was `return`: a returned exception instance is truthy
+
+
+class UniformStreamGenerator(StreamGenerator):
+
+    def __init__(
+        self,
+        tasks: list[StreamDataset],
+        generate_gap: int = 10,
+    ) -> None:
+        super().__init__(tasks)
+        assert generate_gap > 0
+        self.generate_gap = generate_gap
+        self.last_generation = -np.inf
+
+    def _generate_new_task(self) -> bool:
+        if self.t - self.last_generation >= self.generate_gap:
+            self.last_generation = self.t
+            return True
+        else:
+            return False
+
+    def reset(self):
+        super().reset()
+        self.last_generation = -np.inf
+
+
+class ProbStreamGenerator(StreamGenerator):
+
+    def __init__(
+        self,
+        tasks: list[StreamDataset],
+        generate_prob: float = 0.1,
+        seed: int = None,
+    ) -> None:
+        super().__init__(tasks)
+        self.generate_prob = generate_prob
+        if seed is not None:
+            np.random.seed(seed)
+
+    def _generate_new_task(self) -> bool:
+        if np.random.rand() < self.generate_prob:
+            return True
+        else:
+            return False
\ No newline at end of file
diff --git a/StreamLearn/Simulator/TaskRepr.py b/StreamLearn/Simulator/TaskRepr.py
new file mode 100644
index 0000000..d801093
--- /dev/null
+++ b/StreamLearn/Simulator/TaskRepr.py
@@ -0,0 +1,35 @@
+import numpy as np
+from StreamLearn.Dataset.StreamDataset import StreamDataset
+
+
+class TaskRepr:
+    '''
+    Base class for task representation.
+    The task representation is used for downstream tasks,
+    such as resource allocation.
+ ''' + + def get_required_time(self, task: StreamDataset): + ''' + Returns the estimated time required for a task. + ''' + raise NotImplementedError() + + +class IdentityTaskRepr(TaskRepr): + + def __init__(self, t: int = 4) -> None: + self.t = t + + def get_required_time(self, task: StreamDataset): + return self.t + + +class RandomTaskRepr(TaskRepr): + + def __init__(self, t_min: int = 1, t_max: int = 5) -> None: + self.t_min = t_min + self.t_max = t_max + + def get_required_time(self, task: StreamDataset): + return np.random.randint(self.t_min, self.t_max) diff --git a/StreamLearn/Simulator/__init__.py b/StreamLearn/Simulator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/StreamLearn/Simulator/test/algo.py b/StreamLearn/Simulator/test/algo.py new file mode 100644 index 0000000..9c91294 --- /dev/null +++ b/StreamLearn/Simulator/test/algo.py @@ -0,0 +1,75 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from sklearn.base import BaseEstimator + +from StreamLearn.Base.SemiEstimator import StreamEstimator + + +class TestEstimator(StreamEstimator): + ''' + An estimator for testing + All functions are implemented with random values + ''' + def __init__(self): + pass + + def fit(self, stream_dataset): + pass + + def predict(self, X): + return torch.rand((X.shape[0], 1)) + + def evaluate(self, y_pred, y_true): + return np.random.rand() + + +class LREstimator(StreamEstimator): + + def __init__(self): + self.coefficients = None + + def fit(self, stream_dataset): + X, y = stream_dataset[:, :-1], stream_dataset[:, -1] + X_bias = np.hstack([np.ones((X.shape[0], 1)), X]) + self.coefficients = np.linalg.inv(X_bias.T @ X_bias) @ X_bias.T @ y + + def predict(self, X): + X_bias = np.hstack([np.ones((X.shape[0], 1)), X]) + return X_bias @ self.coefficients + + def evaluate(self, y_pred, y_true): + return np.mean((y_pred - y_true) ** 2) + + +class MLPEstimator(StreamEstimator): + + def __init__(self, input_size, 
hidden_size, output_size, lr=0.001): + super().__init__() + self.model = nn.Sequential( + nn.Linear(input_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, output_size) + ) + self.criterion = nn.MSELoss() + self.optimizer = optim.Adam(self.model.parameters(), lr=lr) + + def fit(self, stream_dataset): + X, y = stream_dataset[:, :-1], stream_dataset[:, -1] + X, y = torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32).unsqueeze(1) + self.optimizer.zero_grad() + output = self.model(X) + loss = self.criterion(output, y) + loss.backward() + self.optimizer.step() + + def predict(self, X): + X = torch.tensor(X, dtype=torch.float32) + with torch.no_grad(): + return self.model(X).numpy() + + def evaluate(self, y_pred, y_true): + y_pred = torch.tensor(y_pred, dtype=torch.float32) + y_true = torch.tensor(y_true, dtype=torch.float32) + return self.criterion(y_pred, y_true).item() diff --git a/StreamLearn/Simulator/test/dataset.py b/StreamLearn/Simulator/test/dataset.py new file mode 100644 index 0000000..7b44d2a --- /dev/null +++ b/StreamLearn/Simulator/test/dataset.py @@ -0,0 +1,100 @@ +import torch +import torch.utils +from torch.utils.data import Dataset +from torchvision import datasets, transforms + +from StreamLearn.Dataset.StreamDataset import StreamDataset + + +class TestStreamDataset(StreamDataset): + ''' + A dataset for testing + All functions are implemented with random values + ''' + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._data_shape = (10,) + + def __len__(self): + return 100 + + def __getitem__(self, index): + return torch.rand(self._data_shape), torch.randint(0, 10, (1,)) + + def split(self): + return torch.utils.data.random_split(self, [int(len(self) * 0.8), len(self) - int(len(self) * 0.8)]) + + +class MnistStreamDataset(StreamDataset): + def __init__(self, train=True, transform=None): + """ + Initialize the MnistStreamDataset. 
+
+        Args:
+        - train (bool): If True, loads the training dataset; if False, loads the test dataset.
+        - transform (callable, optional): A function/transform to apply to the data.
+        """
+        if transform is None:
+            self.transform = transform = transforms.Compose([
+                transforms.ToTensor(),
+                transforms.Normalize((0.1307,), (0.3081,))  # Mean and std for MNIST
+            ])
+        else:
+            self.transform = transform
+
+        # Load MNIST dataset
+        self.mnist_data = datasets.MNIST(
+            root='./data',
+            train=train,
+            download=True,
+            transform=None  # apply self.transform once in __getitem__; passing it here ran it twice
+        )
+
+    def __len__(self):
+        return len(self.mnist_data)
+
+    def __getitem__(self, index):
+        img, label = self.mnist_data[index]
+
+        # Apply any transformations if specified
+        if self.transform:
+            img = self.transform(img)
+
+        return img, label
+
+
+class Cifar10StreamDataset(Dataset):
+    def __init__(self, train=True, transform=None):
+        """
+        Initialize the Cifar10StreamDataset.
+
+        Args:
+        - train (bool): If True, loads the training dataset; if False, loads the test dataset.
+        - transform (callable, optional): A function/transform to apply to the data.
+ """ + if transform is None: + self.transform = transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize CIFAR-10 to [-1, 1] + ]) + else: + self.transform = transform + # Load CIFAR-10 dataset + self.cifar_data = datasets.CIFAR10( + root='./data', + train=train, + download=True, + transform=transform + ) + + def __len__(self): + return len(self.cifar_data) + + def __getitem__(self, index): + img, label = self.cifar_data[index] + + # Apply any transformations if specified + if self.transform: + img = self.transform(img) + + return img, label diff --git a/StreamLearn/Simulator/test/main.py b/StreamLearn/Simulator/test/main.py new file mode 100644 index 0000000..f18de50 --- /dev/null +++ b/StreamLearn/Simulator/test/main.py @@ -0,0 +1,45 @@ +import numpy as np +import torch + +from StreamLearn.Simulator.StreamEnv import Env +from StreamLearn.Simulator.Policy import IdentityPolicy, RandomPolicy, GreedyPolicy +from StreamLearn.Simulator.StreamGenerator import UniformStreamGenerator, ProbStreamGenerator +from StreamLearn.Simulator.TaskRepr import IdentityTaskRepr, RandomTaskRepr +from StreamLearn.Simulator.test.dataset import TestStreamDataset, MnistStreamDataset, Cifar10StreamDataset +from StreamLearn.Simulator.test.algo import TestEstimator, LREstimator, MLPEstimator + + +def main(): + np.random.seed(0) + torch.random.manual_seed(0) + + num_tasks = 10 + num_nodes = 10 + task_list = [TestStreamDataset() for _ in range(num_tasks)] + stream_generator = ProbStreamGenerator(task_list, generate_prob=0.8) + model = TestEstimator() + task_repr = RandomTaskRepr(t_min=9, t_max=15) + + # Create a stream environment + env = Env( + num_nodes=num_nodes, + stream_generator=stream_generator, + model=model, + task_repr=task_repr + ) + + # Create a policy + policy = GreedyPolicy(num_node=num_nodes) + + # Run the simulation + obs, terminal = env.reset() + while not terminal: + action = policy.act(obs) + obs, terminal 
= env.step(action) + env.log_statistics() + env.handle_remaining_tasks() + env.log_statistics() + + +if __name__ == "__main__": + main() diff --git a/StreamLearn/Simulator/utils.py b/StreamLearn/Simulator/utils.py new file mode 100644 index 0000000..e30e88e --- /dev/null +++ b/StreamLearn/Simulator/utils.py @@ -0,0 +1,181 @@ +from typing import Any +import numpy as np +import torch + +#---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + + +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = 
int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + + +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +@torch.inference_mode() +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = dict(named_params_and_buffers(src_module)) + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + try: + if name in src_tensors: + tensor.copy_(src_tensors[name]) + except RuntimeError as err: + raise RuntimeError(f'Error copying "{name}" from {src_module} to {dst_module}') from err + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. 
+ entries = [] + nesting = [0] + def pre_hook(_mod, _inputs): + nesting[0] += 1 + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(EasyDict(mod=mod, outputs=outputs)) + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. 
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] + rows += [['---'] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = '' if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(t.shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] + rows += [[ + name + (':0' if len(e.outputs) >= 2 else ''), + str(param_size) if param_size else '-', + str(buffer_size) if buffer_size else '-', + (output_shapes + ['-'])[0], + (output_dtypes + ['-'])[0], + ]] + for idx in range(1, len(e.outputs)): + rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] + param_total += param_size + buffer_total += buffer_size + rows += [['---'] * len(rows[0])] + rows += [['Total', str(param_total), str(buffer_total), '-', '-']] + + # Print table. + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) + print() + return outputs + + +def restrict_to_throughput(throughput, data): + if data is None: + return None + if data.shape[0] > throughput: + return data[:throughput] + return data -- Gitee