diff --git a/MindChem/applications/orb/README.md b/MindChem/applications/orb/README.md
index a6598369f7c1b093ca1d0e01031946f2c5a2988c..f0c953d2b4461440797684f0acf21a5a39e23a8e 100644
--- a/MindChem/applications/orb/README.md
+++ b/MindChem/applications/orb/README.md
@@ -10,7 +10,7 @@
 
 ## Environment Requirements
 
-> 1. Install `mindspore (2.5.0)`
+> 1. Install `mindspore (2.7.0)`
 > 2. Install dependencies: `pip install -r requirement.txt`
 
 ## Quick Start
@@ -28,32 +28,53 @@
 
 ```text
 The main code modules are in the src folder, with the dataset folder containing the datasets, the orb_ckpts folder containing pre-trained models and trained model weight files, and the configs folder containing parameter configuration files for each code.
-orb_models                         # Model name
+orb_models                         # ORB pre-training / fine-tuning project
 ├── dataset
-    ├── train_mptrj_ase.db         # Training dataset for fine-tuning stage
-    └── val_mptrj_ase.db           # Test dataset for fine-tuning stage
-├── orb_ckpts
-    └── orb-mptraj-only-v2.ckpt    # Pre-trained model checkpoint
-├── configs
-    ├── config.yaml                # Single-card training parameter configuration file
-    ├── config_parallel.yaml       # Multi-card parallel training parameter configuration file
-    └── config_eval.yaml           # Inference parameter configuration file
-├── src
-    ├── __init__.py
-    ├── ase_dataset.py             # Process and load datasets
-    ├── atomic_system.py           # Define data structure for atomic systems
-    ├── base.py                    # Base class definitions
-    ├── featurization_utilities.py # Provide tools to convert atomic systems into feature vectors
-    ├── pretrained.py              # Pre-trained model related functions
-    ├── property_definitions.py    # Define calculation methods and naming rules for various physical properties in atomic systems
-    ├── trainer.py                 # Model loss class definitions
-    ├── segment_ops.py             # Provide tools for segmenting data
-    └── utils.py                   # Utility module
-├── finetune.py                    # Model fine-tuning code
-├── evaluate.py                    # Model inference code
-├── run.sh                         # Single-card training startup script
-├── run_parallel.sh                # Multi-card parallel training startup script
-└── requirement.txt                # Environment
+│   ├── train_mptrj_ase.db         # Training dataset for fine-tuning (ASE trajectories, SQLite)
+│   └── val_mptrj_ase.db           # Validation / test dataset for fine-tuning
+│
+├── orb_ckpts                      # Directory for pre-trained & fine-tuned checkpoints
+│   └── orb-mptraj-only-v2.ckpt    # Pre-trained ORB checkpoint (mptraj-only task)
+│
+├── configs                        # Config files for training / inference
+│   ├── config.yaml                # Single-card training configuration (lr, batch_size, etc.)
+│   ├── config_parallel.yaml       # Multi-card data-parallel training configuration
+│   └── config_eval.yaml           # Inference / evaluation configuration
+│
+├── src                            # Core code for data processing and training
+│   ├── __init__.py                # Package initializer for src
+│   ├── ase_dataset.py             # Load and wrap ASE datasets (read SQLite, build atomic graphs)
+│   ├── atomic_system.py           # Data structures for atomic systems (positions, species, cell, etc.)
+│   ├── base.py                    # Common base classes and utilities (e.g., batch_graphs)
+│   ├── featurization_utilities.py # Tools to convert atomic systems into model input features
+│   ├── pretrained.py              # Interfaces for building and loading pre-trained ORB models
+│   ├── property_definitions.py    # Config and naming rules for energy / forces / stress, etc.
+│   ├── trainer.py                 # Training loop and loss wrappers (e.g., OrbLoss)
+│   ├── segment_ops.py             # Segment-wise reduction ops (segment_sum / mean / max)
+│   └── utils.py                   # Utility functions (seeding, logging, optimizer & LR scheduler)
+│
+├── models                         # Model definitions (GNN / ORB networks)
+│   ├── __init__.py                # Package initializer for models; export model entry points
+│   └── cell                       # Basic building blocks for NN / GNN
+│       ├── __init__.py            # Package initializer for cell
+│       ├── activation.py          # Activation functions and non-linear modules
+│       ├── basic_block.py         # Basic network blocks (e.g., MLP block, ResBlock)
+│       ├── convolution.py         # Graph / spatial convolution layers
+│       ├── embedding.py           # Embedding layers for atom types, edge features, etc.
+│       ├── message_passing.py     # Message passing layers and GNN backbone
+│       └── orb                    # ORB-specific submodules
+│           ├── __init__.py        # Package initializer for orb
+│           ├── gns.py             # GNS (Graph Network Simulator) related structures / APIs
+│           ├── orb.py             # Main ORB architecture (encoder + heads)
+│           └── utils.py           # Internal utilities and helper modules for ORB
+│
+├── finetune.py                    # Entry script for model fine-tuning
+├── evaluate.py                    # Entry script for model inference / evaluation
+│
+├── run.sh                         # Single-card training launcher (wraps finetune.py + config.yaml)
+├── run_parallel.sh                # Multi-card training launcher (msrun + config_parallel.yaml)
+└── requirement.txt                # Python dependency list for environment setup
+
 ```
 
 ## Download Dataset
diff --git a/MindChem/applications/orb/README_CN.md b/MindChem/applications/orb/README_CN.md
index 6812131cf695b3ee4b8822122270dc13bc3d9c5e..b2ec65f1bd162010ccb8ecaaf6efdaa26537dd51 100644
--- a/MindChem/applications/orb/README_CN.md
+++ b/MindChem/applications/orb/README_CN.md
@@ -10,7 +10,7 @@
 
 ## 环境要求
 
-> 1. 安装`mindspore(2.5.0)`
+> 1. 安装`mindspore(2.7.0)`
 > 2. 安装依赖包：`pip install -r requirement.txt`
 
 ## 快速入门
@@ -28,32 +28,53 @@
 
 ```text
 代码主要模块在src文件夹下,其中dataset文件夹下是数据集,orb_ckpts文件夹下是预训练模型和训练好的模型权重文件,configs文件夹下是各代码的参数配置文件。
-orb_models                         # 模型名
+orb_models                         # ORB 预训练 / 微调工程
 ├── dataset
-    ├── train_mptrj_ase.db         # 微调阶段训练数据集
-    └── val_mptrj_ase.db           # 微调阶段测试数据集
-├── orb_ckpts
-    └── orb-mptraj-only-v2.ckpt    # 预训练模型checkpoint
-├── configs
-    ├── config.yaml                # 单卡训练参数配置文件
-    ├── config_parallel.yaml       # 多卡并行训练参数配置文件
-    └── config_eval.yaml           # 推理参数配置文件
-├── src
-    ├── __init__.py
-    ├── ase_dataset.py             # 处理和加载数据集
-    ├── atomic_system.py           # 定义原子系统的数据结构
-    ├── base.py                    # 基础类定义
-    ├── featurization_utilities.py # 提供将原子系统转换为特征向量的工具
-    ├── pretrained.py              # 预训练模型相关函数
-    ├── property_definitions.py    # 定义原子系统中各种物理性质的计算方式和命名规则
-    ├── trainer.py                 # 模型loss类定义
-    ├── segment_ops.py             # 提供对数据进行分段处理的工具
-    └── utils.py                   # 工具模块
-├── finetune.py                    # 模型微调代码
-├── evaluate.py                    # 模型推理代码
-├── run.sh                         # 单卡训练启动脚本
-├── run_parallel.sh                # 多卡并行训练启动脚本
-└── requirement.txt                # 环境
+│   ├── train_mptrj_ase.db         # 微调训练集（ASE 轨迹，SQLite 格式）
+│   └── val_mptrj_ase.db           # 微调验证 / 测试集
+│
+├── orb_ckpts                      # 预训练 & 微调模型 ckpt 存放目录
+│   └── orb-mptraj-only-v2.ckpt    # 仅 mptraj 任务的预训练 ORB 模型
+│
+├── configs                        # 训练 / 推理配置
+│   ├── config.yaml                # 单卡训练配置（学习率、batch_size 等）
+│   ├── config_parallel.yaml       # 多卡数据并行训练配置
+│   └── config_eval.yaml           # 推理 / 评估配置
+│
+├── src                            # 数据处理与训练核心源码
+│   ├── __init__.py                # src 包初始化
+│   ├── ase_dataset.py             # ASE 数据集读取与封装（读 SQLite、组装原子图）
+│   ├── atomic_system.py           # 原子系统数据结构定义（坐标、原子种类、晶胞信息等）
+│   ├── base.py                    # 通用基类与工具（batch_graphs 等图数据打包）
+│   ├── featurization_utilities.py # 原子系统 → 模型输入特征张量的特征化工具
+│   ├── pretrained.py              # 预训练 ORB 模型构造与加载接口
+│   ├── property_definitions.py    # 能量 / 力 / 应力等物理量配置与命名
+│   ├── trainer.py                 # 训练循环与 OrbLoss 等损失封装
+│   ├── segment_ops.py             # segment_sum / mean / max 等分段归约算子
+│   └── utils.py                   # 通用工具函数（随机种子、日志、优化器、LR scheduler 等）
+│
+├── models                         # 模型结构定义（GNN / ORB 等）
+│   ├── __init__.py                # models 包初始化，统一暴露模型接口
+│   └── cell                       # 神经网络 / 图网络基础模块
+│       ├── __init__.py            # cell 子包初始化
+│       ├── activation.py          # 激活函数与非线性模块
+│       ├── basic_block.py         # 基础网络模块（MLP Block、ResBlock 等）
+│       ├── convolution.py         # 图卷积 / 空间卷积相关层
+│       ├── embedding.py           # 原子类型、边特征等嵌入层
+│       ├── message_passing.py     # 消息传递层与 GNN 主体结构
+│       └── orb                    # ORB 专用子模块
+│           ├── __init__.py        # orb 子包初始化
+│           ├── gns.py             # GNS（Graph Network Simulator）相关结构 / 接口
+│           ├── orb.py             # ORB 主体网络（encoder + heads）
+│           └── utils.py           # ORB 内部工具与辅助模块
+│
+├── finetune.py                    # 模型微调入口脚本
+├── evaluate.py                    # 推理 / 评估入口脚本
+│
+├── run.sh                         # 单卡训练启动脚本（调用 finetune.py + config.yaml）
+├── run_parallel.sh                # 多卡并行训练启动脚本（msrun + config_parallel.yaml）
+└── requirement.txt                # Python 依赖列表（环境搭建用）
+
 ```
 
 ## 下载数据集
diff --git a/MindChem/applications/orb/configs/config_eval.yaml b/MindChem/applications/orb/configs/config_eval.yaml
index 1e98c5f0b0d98f1c2dd87bcb095b0c19036ef2b4..e0ffea036b7ba6bdc6a44db41f1fd6f87b4115f8 100644
--- a/MindChem/applications/orb/configs/config_eval.yaml
+++ b/MindChem/applications/orb/configs/config_eval.yaml
@@ -6,7 +6,7 @@
 device_id: 0
 val_data_path: dataset/val_mptrj_ase.db
 num_workers: 8
 batch_size: 64
-checkpoint_path: orb_ckpts/orb-ft-checkpoint_epoch99.ckpt
+checkpoint_path: orb_ckpts/orb-mptraj-only-v2.ckpt
 random_seed: 1234
 output_dir: results/
diff --git a/MindChem/applications/orb/configs/config_parallel.yaml b/MindChem/applications/orb/configs/config_parallel.yaml
index c6a5e0857b52d3cf543e0e0388160e6764184c1c..51436b345e9d77a412c8a6402d2e77c2fe84b312 100644
--- a/MindChem/applications/orb/configs/config_parallel.yaml
+++ b/MindChem/applications/orb/configs/config_parallel.yaml
@@ -2,9 +2,9 @@
 train_data_path: dataset/train_mptrj_ase.db
 val_data_path: dataset/val_mptrj_ase.db
 num_workers: 8
-batch_size: 256
+batch_size: 64
 gradient_clip_val: 0.5
-max_epochs: 100
+max_epochs: 2
 checkpoint_path: orb_ckpts/
 lr: 3.0e-4
 random_seed: 666
diff --git a/MindChem/applications/orb/mindchemistry/graph/__init__.py b/MindChem/applications/orb/mindchemistry/graph/__init__.py
deleted file mode 100644
index 1ae7d9a34cfe890b31a59b76be9d3d1a0c8b3c8a..0000000000000000000000000000000000000000
--- a/MindChem/applications/orb/mindchemistry/graph/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""graph"""
diff --git a/MindChem/applications/orb/mindchemistry/graph/dataloader.py b/MindChem/applications/orb/mindchemistry/graph/dataloader.py
deleted file mode 100644
index 6af294afb6a4f85e69db0526a5a864c58dceebbc..0000000000000000000000000000000000000000
--- a/MindChem/applications/orb/mindchemistry/graph/dataloader.py
+++ /dev/null
@@ -1,408 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""dataloader
-"""
-import random
-import numpy as np
-from mindspore import Tensor
-import mindspore as ms
-
-
-class DataLoaderBase:
-    r"""
-    DataLoader that stacks a batch of graph data to fixed-size Tensors
-
-    For specific dataset, usually the following functions should be customized to include different fields:
-    __init__, shuffle_action, __iter__
-
-    """
-
-    def __init__(self,
-                 batch_size,
-                 edge_index,
-                 label=None,
-                 node_attr=None,
-                 edge_attr=None,
-                 padding_std_ratio=3.5,
-                 dynamic_batch_size=True,
-                 shuffle_dataset=True,
-                 max_node=None,
-                 max_edge=None):
-        self.batch_size = batch_size
-        self.edge_index = edge_index
-        self.index = 0
-        self.step = 0
-        self.padding_std_ratio = padding_std_ratio
-        self.batch_change_num = 0
-        self.batch_exceeding_num = 0
-        self.dynamic_batch_size = dynamic_batch_size
-        self.shuffle_dataset = shuffle_dataset
-
-        ### can be customized to specific dataset
-        self.label = label
-        self.node_attr = node_attr
-        self.edge_attr = edge_attr
-        self.sample_num = len(self.node_attr)
-        batch_size_div = self.batch_size
-        if batch_size_div != 0:
-            self.step_num = int(self.sample_num / batch_size_div)
-        else:
-            raise ValueError
-
-        if dynamic_batch_size:
-            self.max_start_sample = self.sample_num
-        else:
-            self.max_start_sample = self.sample_num - self.batch_size + 1
-
-        self.set_global_max_node_edge_num(self.node_attr, self.edge_attr, max_node, max_edge, shuffle_dataset,
-                                          dynamic_batch_size)
-        #######
-
-    def __len__(self):
-        return self.sample_num
-
-    ### example of generating data of each step, can be customized to specific dataset
-    def __iter__(self):
-        if self.shuffle_dataset:
-            self.shuffle()
-        else:
-            self.restart()
-
-        while self.index < self.max_start_sample:
-            # pylint: disable=W0612
-            edge_index_step, node_batch_step, node_mask, edge_mask, batch_size_mask, node_num, edge_num, batch_size \
-                = self.gen_common_data(self.node_attr, self.edge_attr)
-
-            ### can be customized to generate different attributes or labels according to specific dataset
-            node_attr_step = self.gen_node_attr(self.node_attr, batch_size, node_num)
-            edge_attr_step = self.gen_edge_attr(self.edge_attr, batch_size, edge_num)
-            label_step = self.gen_global_attr(self.label, batch_size)
-
-            self.add_step_index(batch_size)
-
-            ### make number to Tensor, if it is used as a Tensor in the network
-            node_num = Tensor(node_num)
-            batch_size = Tensor(batch_size)
-
-            yield node_attr_step, edge_attr_step, label_step, edge_index_step, node_batch_step, \
-                node_mask, edge_mask, node_num, batch_size
-
-    @staticmethod
-    def pad_zero_to_end(src, axis, zeros_len):
-        """pad_zero_to_end"""
-        pad_shape = []
-        for i in range(src.ndim):
-            if i == axis:
-                pad_shape.append((0, zeros_len))
-            else:
-                pad_shape.append((0, 0))
-        return np.pad(src, pad_shape)
-
-    @staticmethod
-    def gen_mask(total_len, real_len):
-        """gen_mask"""
-        mask = np.concatenate((np.full((real_len,), np.float32(1)), np.full((total_len - real_len,), np.float32(0))))
-        return mask
-
-    ### example of computing global max length of node_attr and edge_attr, can be customized to specific dataset
-    def set_global_max_node_edge_num(self,
-                                     node_attr,
-                                     edge_attr,
-                                     max_node=None,
-                                     max_edge=None,
-                                     shuffle_dataset=True,
-                                     dynamic_batch_size=True):
-        """set_global_max_node_edge_num
-
-        Args:
-            node_attr: node_attr
-            edge_attr: edge_attr
-            max_node: max_node. Defaults to None.
-            max_edge: max_edge. Defaults to None.
-            shuffle_dataset: shuffle_dataset. Defaults to True.
-            dynamic_batch_size: dynamic_batch_size. Defaults to True.
-
-        Raises:
-            ValueError: ValueError
-        """
-        if not shuffle_dataset:
-            max_node_num, max_edge_num = self.get_max_node_edge_num(node_attr, edge_attr, dynamic_batch_size)
-            self.max_node_num_global = max_node_num if max_node is None else max(max_node, max_node_num)
-            self.max_edge_num_global = max_edge_num if max_edge is None else max(max_edge, max_edge_num)
-            return
-
-        sum_node = 0
-        sum_edge = 0
-        count = 0
-        max_node_single = 0
-        max_edge_single = 0
-        for step in range(self.sample_num):
-            node_len = len(node_attr[step])
-            edge_len = len(edge_attr[step])
-            sum_node += node_len
-            sum_edge += edge_len
-            max_node_single = max(max_node_single, node_len)
-            max_edge_single = max(max_edge_single, edge_len)
-            count += 1
-        if count != 0:
-            mean_node = sum_node / count
-            mean_edge = sum_edge / count
-        else:
-            raise ValueError
-
-        if max_node is not None and max_edge is not None:
-            if max_node < max_node_single:
-                raise ValueError(
-                    f"the max_node {max_node} is less than the max length of a single sample {max_node_single}")
-            if max_edge < max_edge_single:
-                raise ValueError(
-                    f"the max_edge {max_edge} is less than the max length of a single sample {max_edge_single}")
-
-            self.max_node_num_global = max_node
-            self.max_edge_num_global = max_edge
-        elif max_node is None and max_edge is None:
-            sum_node = 0
-            sum_edge = 0
-            for step in range(self.sample_num):
-                sum_node += (len(node_attr[step]) - mean_node) ** 2
-                sum_edge += (len(edge_attr[step]) - mean_edge) ** 2
-
-            if count != 0:
-                std_node = np.sqrt(sum_node / count)
-                std_edge = np.sqrt(sum_edge / count)
-            else:
-                raise ValueError
-
-            self.max_node_num_global = int(self.batch_size * mean_node +
-                                           self.padding_std_ratio * np.sqrt(self.batch_size) * std_node)
-            self.max_edge_num_global = int(self.batch_size * mean_edge +
-                                           self.padding_std_ratio * np.sqrt(self.batch_size) * std_edge)
-            self.max_node_num_global = max(self.max_node_num_global, max_node_single)
-            self.max_edge_num_global = max(self.max_edge_num_global, max_edge_single)
-        elif max_node is None:
-            if max_edge < max_edge_single:
-                raise ValueError(
-                    f"the max_edge {max_edge} is less than the max length of a single sample {max_edge_single}")
-
-            if mean_edge != 0:
-                self.max_node_num_global = int(max_edge * mean_node / mean_edge)
-            else:
-                raise ValueError
-            self.max_node_num_global = max(self.max_node_num_global, max_node_single)
-            self.max_edge_num_global = max_edge
-        else:
-            if max_node < max_node_single:
-                raise ValueError(
-                    f"the max_node {max_node} is less than the max length of a single sample {max_node_single}")
-
-            self.max_node_num_global = max_node
-            if mean_node != 0:
-                self.max_edge_num_global = int(max_node * mean_edge / mean_node)
-            else:
-                raise ValueError
-            self.max_edge_num_global = max(self.max_edge_num_global, max_edge_single)
-
-    def get_max_node_edge_num(self, node_attr, edge_attr, remainder=True):
-        """get_max_node_edge_num
-
-        Args:
-            node_attr: node_attr
-            edge_attr: edge_attr
-            remainder (bool, optional): remainder. Defaults to True.
-
-        Returns:
-            max_node_num, max_edge_num
-        """
-        max_node_num = 0
-        max_edge_num = 0
-        index = 0
-        for _ in range(self.step_num):
-            node_num = 0
-            edge_num = 0
-            for _ in range(self.batch_size):
-                node_num += len(node_attr[index])
-                edge_num += len(edge_attr[index])
-                index += 1
-            max_node_num = max(max_node_num, node_num)
-            max_edge_num = max(max_edge_num, edge_num)
-
-        if remainder:
-            remain_num = self.sample_num - index - 1
-            node_num = 0
-            edge_num = 0
-            for _ in range(remain_num):
-                node_num += len(node_attr[index])
-                edge_num += len(edge_attr[index])
-                index += 1
-            max_node_num = max(max_node_num, node_num)
-            max_edge_num = max(max_edge_num, edge_num)
-
-        return max_node_num, max_edge_num
-
-    def shuffle_index(self):
-        """shuffle_index"""
-        indices = list(range(self.sample_num))
-        random.shuffle(indices)
-        return indices
-
-    ### example of shuffling the input dataset, can be customized to specific dataset
-    def shuffle_action(self):
-        """shuffle_action"""
-        indices = self.shuffle_index()
-        self.edge_index = [self.edge_index[i] for i in indices]
-        self.label = [self.label[i] for i in indices]
-        self.node_attr = [self.node_attr[i] for i in indices]
-        self.edge_attr = [self.edge_attr[i] for i in indices]
-
-    ### example of generating the final shuffled dataset, can be customized to specific dataset
-    def shuffle(self):
-        """shuffle"""
-        self.shuffle_action()
-        if not self.dynamic_batch_size:
-            max_node_num, max_edge_num = self.get_max_node_edge_num(self.node_attr, self.edge_attr, remainder=False)
-            while max_node_num > self.max_node_num_global or max_edge_num > self.max_edge_num_global:
-                self.shuffle_action()
-                max_node_num, max_edge_num = self.get_max_node_edge_num(self.node_attr, self.edge_attr, remainder=False)
-
-        self.step = 0
-        self.index = 0
-
-    def restart(self):
-        """restart"""
-        self.step = 0
-        self.index = 0
-
-    ### example of calculating dynamic batch size to avoid exceeding the max length of node and edge, can be customized to specific dataset
-    def get_batch_size(self, node_attr, edge_attr, start_batch_size):
-        """get_batch_size
-
-        Args:
-            node_attr: node_attr
-            edge_attr: edge_attr
-            start_batch_size: start_batch_size
-
-        Returns:
-            batch_size
-        """
-        node_num = 0
-        edge_num = 0
-        for i in range(start_batch_size):
-            index = self.index + i
-            node_num += len(node_attr[index])
-            edge_num += len(edge_attr[index])
-
-        exceeding = False
-        while node_num > self.max_node_num_global or edge_num > self.max_edge_num_global:
-            node_num -= len(node_attr[index])
-            edge_num -= len(edge_attr[index])
-            index -= 1
-            exceeding = True
-            self.batch_exceeding_num += 1
-        if exceeding:
-            self.batch_change_num += 1
-
-        return index - self.index + 1
-
-    def gen_common_data(self, node_attr, edge_attr):
-        """gen_common_data
-
-        Args:
-            node_attr: node_attr
-            edge_attr: edge_attr
-
-        Returns:
-            common_data
-        """
-        if self.dynamic_batch_size:
-            if self.step >= self.step_num:
-                batch_size = self.get_batch_size(node_attr, edge_attr,
-                                                 min((self.sample_num - self.index), self.batch_size))
-            else:
-                batch_size = self.get_batch_size(node_attr, edge_attr, self.batch_size)
-        else:
-            batch_size = self.batch_size
-
-        ######################## node_batch
-        node_batch_step = []
-        sample_num = 0
-        for i in range(self.index, self.index + batch_size):
-            node_batch_step.extend([sample_num] * node_attr[i].shape[0])
-            sample_num += 1
-        node_batch_step = np.array(node_batch_step)
-        node_num = node_batch_step.shape[0]
-
-        ######################## edge_index
-        edge_index_step = np.array([[], []], dtype=np.int64)
-        max_edge_index = 0
-        for i in range(self.index, self.index + batch_size):
-            edge_index_step = np.concatenate((edge_index_step, self.edge_index[i] + max_edge_index), 1)
-            max_edge_index = np.max(edge_index_step) + 1
-        edge_num = edge_index_step.shape[1]
-
-        ######################### padding
-        edge_index_step = self.pad_zero_to_end(edge_index_step, 1, self.max_edge_num_global - edge_num)
-        node_batch_step = self.pad_zero_to_end(node_batch_step, 0, self.max_node_num_global - node_num)
-
-        ######################### mask
-        node_mask = self.gen_mask(self.max_node_num_global, node_num)
-        edge_mask = self.gen_mask(self.max_edge_num_global, edge_num)
-        batch_size_mask = self.gen_mask(self.batch_size, batch_size)
-
-        ######################### make Tensor
-        edge_index_step = Tensor(edge_index_step, ms.int32)
-        node_batch_step = Tensor(node_batch_step, ms.int32)
-        node_mask = Tensor(node_mask)
-        edge_mask = Tensor(edge_mask)
-        batch_size_mask = Tensor(batch_size_mask)
-
-        return CommonData(edge_index_step, node_batch_step, node_mask, edge_mask, batch_size_mask, node_num, edge_num,
-                          batch_size).get_tuple_data()
-
-    def gen_node_attr(self, node_attr, batch_size, node_num):
-        """gen_node_attr"""
-        node_attr_step = np.concatenate(node_attr[self.index:self.index + batch_size], 0)
-        node_attr_step = self.pad_zero_to_end(node_attr_step, 0, self.max_node_num_global - node_num)
-        node_attr_step = Tensor(node_attr_step)
-        return node_attr_step
-
-    def gen_edge_attr(self, edge_attr, batch_size, edge_num):
-        """gen_edge_attr"""
-        edge_attr_step = np.concatenate(edge_attr[self.index:self.index + batch_size], 0)
-        edge_attr_step = self.pad_zero_to_end(edge_attr_step, 0, self.max_edge_num_global - edge_num)
-        edge_attr_step = Tensor(edge_attr_step)
-        return edge_attr_step
-
-    def gen_global_attr(self, global_attr, batch_size):
-        """gen_global_attr"""
-        global_attr_step = np.stack(global_attr[self.index:self.index + batch_size], 0)
-        global_attr_step = self.pad_zero_to_end(global_attr_step, 0, self.batch_size - batch_size)
-        global_attr_step = Tensor(global_attr_step)
-        return global_attr_step
-
-    def add_step_index(self, batch_size):
-        """add_step_index"""
-        self.index = self.index + batch_size
-        self.step += 1
-
-class CommonData:
-    """CommonData"""
-    def __init__(self, edge_index_step, node_batch_step, node_mask, edge_mask, batch_size_mask, node_num, edge_num,
-                 batch_size):
-        self.tuple_data = (edge_index_step, node_batch_step, node_mask, edge_mask, batch_size_mask, node_num, edge_num,
-                           batch_size)
-
-    def get_tuple_data(self):
-        """get_tuple_data"""
-        return self.tuple_data
diff --git a/MindChem/applications/orb/mindchemistry/graph/graph.py b/MindChem/applications/orb/mindchemistry/graph/graph.py
deleted file mode 100644
index 1f13b8869cf684124bddc0fee911d683aea1f902..0000000000000000000000000000000000000000
--- a/MindChem/applications/orb/mindchemistry/graph/graph.py
+++ /dev/null
@@ -1,294 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""graph"""
-import mindspore as ms
-from mindspore import ops, nn
-
-
-def degree(index, dim_size, mask=None):
-    r"""
-    Computes the degree of a one-dimensional index tensor.
-    """
-    if index.ndim != 1:
-        raise ValueError(f"the dimension of index {index.ndim} is not equal to 1")
-
-    if mask is not None:
-        if mask.shape[0] != index.shape[0]:
-            raise ValueError(f"mask.shape[0] {mask.shape[0]} is not equal to index.shape[0] {index.shape[0]}")
-        if mask.ndim != 1:
-            st = [0] * mask.ndim
-            slice_size = [1] * mask.ndim
-            slice_size[0] = mask.shape[0]
-            mask = ops.slice(mask, st, slice_size).squeeze()
-        src = mask.astype(ms.int32)
-    else:
-        src = ops.ones(index.shape, ms.int32)
-
-    index = index.unsqueeze(-1)
-    out = ops.zeros((dim_size,), ms.int32)
-
-    return ops.tensor_scatter_add(out, index, src)
-
-
-class Aggregate(nn.Cell):
-    r"""
-    Easy-use version of scatter.
-
-    Args:
-        mode (str): {'add', 'sum', 'mean', 'avg'}, scatter mode.
-
-    Raises:
-        ValueError: If `mode` is not legal.
-
-    Supported Platforms:
-        ``CPU`` ``GPU`` ``Ascend``
-
-    """
-
-    def __init__(self, mode='add'):
-        super().__init__()
-        self.mode = mode
-        if mode in ('add', 'sum'):
-            self.scatter = self.scatter_sum
-        elif mode in ('mean', 'avg'):
-            self.scatter = self.scatter_mean
-        else:
-            raise ValueError(f"Unexpected scatter mode {mode}")
-
-    @staticmethod
-    def scatter_sum(src, index, out=None, dim_size=None, mask=None):
-        r"""
-        Computes the scatter sum of a source tensor. The index should be one-dimensional
-        """
-        if index.ndim != 1:
-            raise ValueError(f"the dimension of index {index.ndim} is not equal to 1")
-        if index.shape[0] != src.shape[0]:
-            raise ValueError(f"index.shape[0] {index.shape[0]} is not equal to src.shape[0] {src.shape[0]}")
-        if out is None and dim_size is None:
-            raise ValueError("the out Tensor and out dim_size cannot be both None")
-
-        index = index.unsqueeze(-1)
-
-        if out is None:
-            out = ops.zeros((dim_size,) + src.shape[1:], dtype=src.dtype)
-        elif dim_size is not None and out.shape[0] != dim_size:
-            raise ValueError(f"the out.shape[0] {out.shape[0]} is not equal to dim_size {dim_size}")
-
-        if mask is not None:
-            if mask.shape[0] != src.shape[0]:
-                raise ValueError(f"mask.shape[0] {mask.shape[0]} is not equal to src.shape[0] {src.shape[0]}")
-            if src.ndim != mask.ndim:
-                if mask.size != mask.shape[0]:
-                    raise ValueError("mask.ndim dose not match src.ndim, and cannot be broadcasted to the same")
-                shape = [1] * src.ndim
-                shape[0] = -1
-                mask = ops.reshape(mask, shape)
-            src = ops.mul(src, mask.astype(src.dtype))
-
-        return ops.tensor_scatter_add(out, index, src)
-
-    @staticmethod
-    def scatter_mean(src, index, out=None, dim_size=None, mask=None):
-        r"""
-        Computes the scatter mean of a source tensor. The index should be one-dimensional
-        """
-        if out is None and dim_size is None:
-            raise ValueError("the out Tensor and out dim_size cannot be both None")
-
-        if dim_size is None:
-            dim_size = out.shape[0]
-        elif out is not None and out.shape[0] != dim_size:
-            raise ValueError(f"the out.shape[0] {out.shape[0]} is not equal to dim_size {dim_size}")
-
-        count = degree(index, dim_size, mask=mask)
-        eps = 1e-5
-        count = ops.maximum(count, eps)
-
-        scatter_sum = Aggregate.scatter_sum(src, index, dim_size=dim_size, mask=mask)
-
-        shape = [1] * scatter_sum.ndim
-        shape[0] = -1
-        count = ops.reshape(count, shape).astype(scatter_sum.dtype)
-        res = ops.true_divide(scatter_sum, count)
-
-        if out is not None:
-            res = res + out
-
-        return res
-
-
-class AggregateNodeToGlobal(Aggregate):
-    """AggregateNodeToGlobal"""
-
-    def __init__(self, mode='add'):
-        super().__init__(mode=mode)
-
-    def construct(self, node_attr, batch, out=None, dim_size=None, mask=None):
-        r"""
-        Args:
-            node_attr (Tensor): The source tensor of node attributes.
-            batch (Tensor): The indices of sample to scatter to.
-            out (Tensor): The destination tensor. Default: None.
-            dim_size (int): If `out` is not given, automatically create output with size `dim_size`. Default: None.
-                out and dim_size cannot be both None.
-            mask (Tensor): The mask of the node_attr tensor
-        Returns:
-            Tensor.
-        """
-        return self.scatter(node_attr, batch, out=out, dim_size=dim_size, mask=mask)
-
-
-class AggregateEdgeToGlobal(Aggregate):
-    """AggregateEdgeToGlobal"""
-
-    def __init__(self, mode='add'):
-        super().__init__(mode=mode)
-
-    def construct(self, edge_attr, batch_edge, out=None, dim_size=None, mask=None):
-        r"""
-        Args:
-            edge_attr (Tensor): The source tensor of edge attributes.
-            batch_edge (Tensor): The indices of sample to scatter to.
-            out (Tensor): The destination tensor. Default: None.
-            dim_size (int): If `out` is not given, automatically create output with size `dim_size`. Default: None.
-                out and dim_size cannot be both None.
-            mask (Tensor): The mask of the node_attr tensor
-        Returns:
-            Tensor.
-        """
-        return self.scatter(edge_attr, batch_edge, out=out, dim_size=dim_size, mask=mask)
-
-
-class AggregateEdgeToNode(Aggregate):
-    """AggregateEdgeToNode"""
-
-    def __init__(self, mode='add', dim=0):
-        super().__init__(mode=mode)
-        self.dim = dim
-
-    def construct(self, edge_attr, edge_index, out=None, dim_size=None, mask=None):
-        r"""
-        Args:
-            edge_attr (Tensor): The source tensor of edge attributes.
-            edge_index (Tensor): The indices of nodes in each edge.
-            out (Tensor): The destination tensor. Default: None.
-            dim_size (int): If `out` is not given, automatically create output with size `dim_size`. Default: None.
-                out and dim_size cannot be both None.
-            mask (Tensor): The mask of the node_attr tensor
-        Returns:
-            Tensor.
-        """
-        return self.scatter(edge_attr, edge_index[self.dim], out=out, dim_size=dim_size, mask=mask)
-
-
-class Lift(nn.Cell):
-    """Lift"""
-
-    def __init__(self, mode="multi_graph"):
-        super().__init__()
-        self.mode = mode
-        if mode not in ["multi_graph", "single_graph"]:
-            raise ValueError(f"Unexpected lift mode {mode}")
-
-    @staticmethod
-    def lift(src, index, axis=0, mask=None):
-        """lift"""
-        res = ops.index_select(src, axis, index)
-
-        if mask is not None:
-            if mask.shape[0] != res.shape[0]:
-                raise ValueError(f"mask.shape[0] {mask.shape[0]} is not equal to res.shape[0] {res.shape[0]}")
-            if res.ndim != mask.ndim:
-                if mask.size != mask.shape[0]:
-                    raise ValueError("mask.ndim dose not match src.ndim, and cannot be broadcasted to the same")
-                shape = [1] * res.ndim
-                shape[0] = -1
-                mask = ops.reshape(mask, shape)
-            res = ops.mul(res, mask.astype(res.dtype))
-
-        return res
-
-    @staticmethod
-    def repeat(src, num, axis=0, max_len=None):
-        res = ops.repeat_elements(src, num, axis)
-
-        if (max_len is not None) and (max_len > num):
-            padding = ops.zeros((max_len - num,) + res.shape[1:], dtype=res.dtype)
-            res = ops.cat((res, padding), axis=0)
-
-        return res
-
-
-class LiftGlobalToNode(Lift):
-    """LiftGlobalToNode"""
-
-    def __init__(self, mode="multi_graph"):
-        super().__init__(mode=mode)
-
-    def construct(self, global_attr, batch=None, num_node=None, mask=None, max_len=None):
-        r"""
-        Args:
-            global_attr (Tensor): The source tensor of global attributes.
-            batch (Tensor): The indices of samples to get.
-            num_node (Int): The number of node in the graph, when there is only 1 graph.
-            mask (Tensor): The mask of the output tensor.
-            max_len (Int): The output length.
-        Returns:
-            Tensor.
-        """
-        if global_attr.shape[0] > 1 or self.mode == "multi_graph":
-            return self.lift(global_attr, batch, mask=mask)
-        return self.repeat(global_attr, num_node, max_len=max_len)
-
-
-class LiftGlobalToEdge(Lift):
-    """LiftGlobalToEdge"""
-
-    def __init__(self, mode="multi_graph"):
-        super().__init__(mode=mode)
-
-    def construct(self, global_attr, batch_edge=None, num_edge=None, mask=None, max_len=None):
-        r"""
-        Args:
-            global_attr (Tensor): The source tensor of global attributes.
-            batch_edge (Tensor): The indices of samples to get.
-            num_edge (Int): The number of edge in the graph, when there is only 1 graph.
-            mask (Tensor): The mask of the output tensor.
-            max_len (Int): The output length.
-        Returns:
-            Tensor.
-        """
-        if global_attr.shape[0] > 1 or self.mode == "multi_graph":
-            return self.lift(global_attr, batch_edge, mask=mask)
-        return self.repeat(global_attr, num_edge, max_len=max_len)
-
-
-class LiftNodeToEdge(Lift):
-    """LiftNodeToEdge"""
-
-    def __init__(self, dim=0):
-        super().__init__(mode="multi_graph")
-        self.dim = dim
-
-    def construct(self, global_attr, edge_index, mask=None):
-        r"""
-        Args:
-            global_attr (Tensor): The source tensor of global attributes.
-            edge_index (Tensor): The indices of nodes for each edge.
-            mask (Tensor): The mask of the output tensor.
-        Returns:
-            Tensor.
-        """
-        return self.lift(global_attr, edge_index[self.dim], mask=mask)
diff --git a/MindChem/applications/orb/mindchemistry/graph/loss.py b/MindChem/applications/orb/mindchemistry/graph/loss.py
deleted file mode 100644
index 28e241e6fbcc72a8e7b8f2045a113e63fd8293f1..0000000000000000000000000000000000000000
--- a/MindChem/applications/orb/mindchemistry/graph/loss.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""loss"""
-import mindspore as ms
-from mindspore import ops, nn
-
-
-class LossMaskBase(nn.Cell):
-    """LossMaskBase"""
-
-    def __init__(self, reduction='mean', dtype=None):
-        super().__init__()
-        self.reduction = reduction
-        if reduction not in ["mean", "sum"]:
-            raise ValueError(f"Unexpected reduction mode {reduction}")
-
-        self.dtype = dtype if dtype is not None else ms.float32
-
-    def construct(self, logits, labels, mask=None, num=None):
-        """construct"""
-        if logits.shape != labels.shape:
-            raise ValueError(f"logits.shape {logits.shape} is not equal to labels.shape {labels.shape}")
-
-        x = self.loss(logits.astype(self.dtype), labels.astype(self.dtype))
-
-        if mask is not None:
-            if mask.shape[0] != x.shape[0]:
-                raise ValueError(f"mask.shape[0] {mask.shape[0]} is not equal to input.shape[0] {x.shape[0]}")
-            if x.ndim != mask.ndim:
-                if mask.size != mask.shape[0]:
-                    raise ValueError("mask.ndim dose not match src.ndim, and cannot be broadcasted to the same")
-                shape = [1] * x.ndim
-                shape[0] = -1
-                mask = ops.reshape(mask, shape)
-            x = ops.mul(x, mask.astype(x.dtype))
-
-        # pylint: disable=W0622
-        sum = ops.sum(x)
-        if self.reduction == "sum":
-            return sum
-        if num is None:
-            num = x.size
-        else:
-            num_div = x.shape[0]
-            if num_div != 0:
-                num = x.size / num_div * num
-            else:
-                raise ValueError
-        return ops.true_divide(sum, num)
-
-class L1LossMask(LossMaskBase):
-
-    def __init__(self, reduction='mean'):
-        super().__init__(reduction)
-
-    def loss(self, logits, labels):
-        return ops.abs(logits - labels)
-
-
-class L2LossMask(LossMaskBase):
-
-    def __init__(self, reduction='mean'):
-        super().__init__(reduction)
-
-    def loss(self, logits, labels):
-        return ops.square(logits - labels)
diff --git a/MindChem/applications/orb/mindchemistry/graph/normlization.py b/MindChem/applications/orb/mindchemistry/graph/normlization.py
deleted file mode 100644
index caccfdb4246b7bd6b670ef6d5e3931a0a2342b0b..0000000000000000000000000000000000000000
--- a/MindChem/applications/orb/mindchemistry/graph/normlization.py
+++ /dev/null
@@ -1,278 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""norm"""
-import mindspore as ms
-from mindspore import ops, Parameter, nn
-from .graph import AggregateNodeToGlobal, LiftGlobalToNode
-
-
-class BatchNormMask(nn.Cell):
-    """BatchNormMask"""
-
-    def __init__(self, num_features, eps=1e-5, momentum=0.9, affine=True):
-        super().__init__()
-        self.num_features = num_features
-        self.eps = eps
-        self.momentum = momentum
-        self.affine = affine
-        self.moving_mean = Parameter(ops.zeros((num_features,), ms.float32), name="moving_mean", requires_grad=False)
-        self.moving_variance = Parameter(ops.ones((num_features,), ms.float32),
-                                         name="moving_variance",
-                                         requires_grad=False)
-        if affine:
-            self.gamma = Parameter(ops.ones((num_features,), ms.float32), name="gamma", requires_grad=True)
-            self.beta = Parameter(ops.zeros((num_features,), ms.float32), name="beta", requires_grad=True)
-
-    def construct(self, x, mask, num):
-        """construct"""
-        if x.shape[1] != self.num_features:
-            raise ValueError(f"x.shape[1] {x.shape[1]} is not equal to num_features {self.num_features}")
-        if x.shape[0] != mask.shape[0]:
-            raise ValueError(f"x.shape[0] {x.shape[0]} is not equal to mask.shape[0] {mask.shape[0]}")
-
-        if x.ndim != mask.ndim:
-            if mask.size != mask.shape[0]:
-                raise ValueError("mask.ndim dose not match src.ndim, and cannot be broadcasted to the same")
-            shape = [1] * x.ndim
-            shape[0] = -1
-            mask = ops.reshape(mask, shape).astype(x.dtype)
-        x = ops.mul(x, mask)
-
-        # pylint: disable=R1705
-        if x.ndim > 2:
-            norm_axis = []
-            shape = [-1]
-            for i in range(2, x.ndim):
-                norm_axis.append(i)
-                shape.append(1)
-
-            if self.training:
-                mean = ops.div(ops.sum(x, 0), num)
-                mean = ops.mean(mean, norm_axis)
-                self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
-                mean = ops.reshape(mean, shape)
-                mean = ops.mul(mean, mask)
-                x = x - mean
-
-                var = ops.div(ops.sum(ops.pow(x, 2), 0), num)
-                var = ops.mean(var, norm_axis)
-                self.moving_variance = self.momentum * self.moving_variance + (1 - self.momentum) * var
-                std = ops.sqrt(ops.add(var, self.eps))
-                std = ops.reshape(std, shape)
-                y = ops.true_divide(x, std)
-            else:
-                mean = ops.reshape(self.moving_mean.astype(x.dtype), shape)
-                mean = ops.mul(mean, mask)
-                std = ops.sqrt(ops.add(self.moving_variance.astype(x.dtype), self.eps))
-                std = ops.reshape(std, shape)
-                y = ops.true_divide(ops.sub(x, mean), std)
-
-            if self.affine:
-                gamma = ops.reshape(self.gamma.astype(x.dtype), shape)
-                beta = ops.reshape(self.beta.astype(x.dtype), shape) * mask
-                y = y * gamma + beta
-
-            return y
-        else:
-            if self.training:
-                mean = ops.div(ops.sum(x, 0), num)
-                self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
-                mean = ops.mul(mean, mask)
-                x = x - mean
-
-                var = ops.div(ops.sum(ops.pow(x, 2), 0), num)
-                self.moving_variance = self.momentum * self.moving_variance + (1 - self.momentum) * var
-                std = ops.sqrt(ops.add(var, self.eps))
-                y = ops.true_divide(x, std)
-            else:
-                mean = ops.mul(self.moving_mean.astype(x.dtype), mask)
-                std = ops.sqrt(ops.add(self.moving_variance.astype(x.dtype), self.eps))
-                y = ops.true_divide(ops.sub(x, mean), std)
-
-            if self.affine:
-                beta = self.beta.astype(x.dtype) * mask
-                y = y * self.gamma.astype(x.dtype) + beta
-
-            return y
-
-
-class GraphLayerNormMask(nn.Cell):
-    """GraphLayerNormMask"""
-
-    def __init__(self,
-                 normalized_shape,
-                 begin_norm_axis=-1,
-                 eps=1e-5,
-                 sub_mean=True,
-                 divide_std=True,
-                 affine_weight=True,
-                 affine_bias=True,
-                 aggr_mode="mean"):
-        super().__init__()
-        self.normalized_shape = normalized_shape
-        self.begin_norm_axis = begin_norm_axis
-        self.eps = eps
-        self.sub_mean = sub_mean
-        self.divide_std = divide_std
-        self.affine_weight = affine_weight
-        self.affine_bias = affine_bias
-        self.mean = ops.ReduceMean(keep_dims=True)
-        self.aggregate = AggregateNodeToGlobal(mode=aggr_mode)
-        self.lift = LiftGlobalToNode(mode="multi_graph")
-
-        if affine_weight:
-            self.gamma = Parameter(ops.ones(normalized_shape, ms.float32), name="gamma", requires_grad=True)
-        if affine_bias:
-            self.beta = Parameter(ops.zeros(normalized_shape, ms.float32), name="beta", requires_grad=True)
-
-    def construct(self, x, batch, mask, dim_size, scale=None):
-        """construct"""
-        begin_norm_axis = self.begin_norm_axis if self.begin_norm_axis >= 0 else self.begin_norm_axis + x.ndim
-        if begin_norm_axis not in range(1, x.ndim):
-            raise ValueError(f"begin_norm_axis {begin_norm_axis} is not in range 1 to {x.ndim}")
-
-        norm_axis = []
-        for i in range(begin_norm_axis, x.ndim):
-            norm_axis.append(i)
-            if self.normalized_shape[i - begin_norm_axis] != x.shape[i]:
-                raise ValueError(f"x.shape[{i}] {x.shape[i]} is not equal to normalized_shape[{i - begin_norm_axis}] "
-                                 f"{self.normalized_shape[i - begin_norm_axis]}")
-
-        if x.shape[0] != mask.shape[0]:
-            raise ValueError(f"x.shape[0] {x.shape[0]} is not equal to mask.shape[0] {mask.shape[0]}")
-        if x.shape[0] != batch.shape[0]:
-            raise ValueError(f"x.shape[0] {x.shape[0]} is not equal to batch.shape[0] {batch.shape[0]}")
-
-        if x.ndim != mask.ndim:
-            if mask.size != mask.shape[0]:
-                raise ValueError("mask.ndim dose not match src.ndim, and cannot be broadcasted to the same")
-            shape = [1] * x.ndim
-            shape[0] = -1
-            mask = ops.reshape(mask, shape).astype(x.dtype)
-        x = ops.mul(x, mask)
-
-        if self.sub_mean:
-            mean = self.aggregate(x, batch, dim_size=dim_size, mask=mask)
-            mean = self.mean(mean, norm_axis)
-            mean = self.lift(mean, batch)
-            mean = ops.mul(mean, mask)
-            x = x - mean
-
-        if self.divide_std:
-            var = self.aggregate(ops.square(x), batch, dim_size=dim_size, mask=mask)
-            var = self.mean(var, norm_axis)
-            if scale is not None:
-                var = var * scale
-            std = ops.sqrt(var + self.eps)
-            std = self.lift(std, batch)
-            x = ops.true_divide(x, std)
-
-        if self.affine_weight:
-            x = x * self.gamma.astype(x.dtype)
-
-        if self.affine_bias:
-            beta = ops.mul(self.beta.astype(x.dtype), mask)
-            x = x + beta
-
-        return x
-
-
-class GraphInstanceNormMask(nn.Cell):
-    """GraphInstanceNormMask"""
-
-    def __init__(self,
-                 num_features,
-                 eps=1e-5,
-                 sub_mean=True,
-                 divide_std=True,
-                 affine_weight=True,
-                 affine_bias=True,
-                 aggr_mode="mean"):
-        super().__init__()
-        self.num_features = num_features
-        self.eps = eps
-        self.sub_mean = sub_mean
-        self.divide_std = divide_std
-        self.affine_weight = affine_weight
-        self.affine_bias = affine_bias
-        self.mean = ops.ReduceMean(keep_dims=True)
-        self.aggregate = AggregateNodeToGlobal(mode=aggr_mode)
-        self.lift = LiftGlobalToNode(mode="multi_graph")
-
-        if affine_weight:
-            self.gamma = Parameter(ops.ones((self.num_features,), ms.float32), name="gamma", requires_grad=True)
-        if affine_bias:
-            self.beta = Parameter(ops.zeros((self.num_features,), ms.float32), name="beta", requires_grad=True)
-
-    def construct(self, x, batch, mask, dim_size, scale=None):
-        """construct"""
-        if x.shape[1] != self.num_features:
-            raise ValueError(f"x.shape[1] {x.shape[1]} is not equal to num_features {self.num_features}")
-        if x.shape[0] != mask.shape[0]:
-            raise ValueError(f"x.shape[0] {x.shape[0]} is not equal to mask.shape[0] {mask.shape[0]}")
-        if x.shape[0] != batch.shape[0]:
-            raise ValueError(f"x.shape[0] {x.shape[0]} is not equal to batch.shape[0] {batch.shape[0]}")
-
-        if x.ndim != mask.ndim:
-            if mask.size != mask.shape[0]:
-                raise ValueError("mask.ndim dose not match src.ndim, and cannot be broadcasted to the same")
-            shape = [1] * x.ndim
-            shape[0] = -1
-            mask = ops.reshape(mask, shape).astype(x.dtype)
-        x = ops.mul(x, mask)
-        gamma = None  # 后来添加,防止未定义报错
-        if x.ndim > 2:
-            norm_axis = []
-            shape = [-1]
-            for i in range(2, x.ndim):
-                norm_axis.append(i)
-                shape.append(1)
-
-            if self.affine_weight:
-                gamma = ops.reshape(self.gamma.astype(x.dtype), shape)
-            if self.affine_bias:
-                beta = ops.reshape(self.beta.astype(x.dtype), shape)
-        else:
-            if self.affine_weight:
-                gamma = self.gamma.astype(x.dtype)
-            if self.affine_bias:
-                beta = self.beta.astype(x.dtype)
-
-        if self.sub_mean:
-            mean = self.aggregate(x, batch, dim_size=dim_size, mask=mask)
-            if x.ndim > 2:
-                mean = self.mean(mean, norm_axis)
-            mean = self.lift(mean, batch)
-            mean = ops.mul(mean, mask)
-            x = x - mean
-
-        if self.divide_std:
-            var = self.aggregate(ops.square(x), batch, dim_size=dim_size, mask=mask)
-            if x.ndim > 2:
-                var = self.mean(var, norm_axis)
-            if scale is not None:
-                var = var * scale
-            std = ops.sqrt(var + self.eps)
-            std = self.lift(std, batch)
-            x = ops.true_divide(x, std)
-
-        if self.affine_weight:
-            x = x * gamma
-
-        if self.affine_bias:
-            beta = ops.mul(beta, mask)
-            x = x + beta
-
-        return x
diff --git a/MindChem/applications/orb/mindchemistry/so2_conv/__init__.py b/MindChem/applications/orb/mindchemistry/so2_conv/__init__.py
deleted file mode 100644
index e5542a477db973beb0c65aeea02ca720de7c8ddd..0000000000000000000000000000000000000000
--- a/MindChem/applications/orb/mindchemistry/so2_conv/__init__.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""
-init file
-"""
-from .so3 import SO3Rotation
-from .so2 import SO2Convolution
diff --git a/MindChem/applications/orb/mindchemistry/so2_conv/init_edge_rot_mat.py b/MindChem/applications/orb/mindchemistry/so2_conv/init_edge_rot_mat.py
deleted file mode 100644
index a05a2264b37bc3174f058b87c3ba7d6dfb965034..0000000000000000000000000000000000000000
--- a/MindChem/applications/orb/mindchemistry/so2_conv/init_edge_rot_mat.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""
-file to get rotating matrix from edge distance vector
-"""
-from mindspore import ops
-import mindspore.numpy as ms_np
-
-
-def init_edge_rot_mat(edge_distance_vec):
-    """
-    get rotating matrix from edge distance vector
-    """
-    epsilon = 0.00000001
-    edge_vec_0 = edge_distance_vec
-    edge_vec_0_distance = ops.sqrt(ops.maximum(ops.sum(edge_vec_0 ** 2, dim=1), epsilon))
-    # Make sure the atoms are far enough apart
-    norm_x = ops.div(edge_vec_0, edge_vec_0_distance.view(-1, 1))
-    edge_vec_2 = ops.rand_like(edge_vec_0) - 0.5
-
-    edge_vec_2 = ops.div(edge_vec_2, ops.sqrt(ops.maximum(ops.sum(edge_vec_2 ** 2, dim=1), epsilon)).view(-1, 1))
-    # Create two rotated copies of the random vectors in case the random vector is aligned with norm_x
-    # With two 90 degree rotated vectors, at least one should not be aligned with norm_x
-    edge_vec_2b = edge_vec_2.copy()
-    edge_vec_2b[:, 0] = -edge_vec_2[:, 1]
-    edge_vec_2b[:, 1] = edge_vec_2[:, 0]
-    edge_vec_2c = edge_vec_2.copy()
-    edge_vec_2c[:, 1] = -edge_vec_2[:, 2]
-    edge_vec_2c[:, 2] = edge_vec_2[:, 1]
-    vec_dot_b = ops.abs(ops.sum(edge_vec_2b * norm_x, dim=1)).view(-1, 1)
-    vec_dot_c = ops.abs(ops.sum(edge_vec_2c * norm_x, dim=1)).view(-1, 1)
-    vec_dot = ops.abs(ops.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1)
-    edge_vec_2 = ops.where(ops.broadcast_to(ops.gt(vec_dot, vec_dot_b), edge_vec_2b.shape), edge_vec_2b, edge_vec_2)
-    vec_dot = ops.abs(ops.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1)
-    edge_vec_2 = ops.where(ops.broadcast_to(ops.gt(vec_dot, vec_dot_c), edge_vec_2c.shape), edge_vec_2c, edge_vec_2)
-    vec_dot = ops.abs(ops.sum(edge_vec_2 * norm_x, dim=1))
-    # Check the vectors aren't aligned
-
-    norm_z = ms_np.cross(norm_x, edge_vec_2, axis=1)
-    norm_z = ops.div(norm_z, ops.sqrt(ops.maximum(ops.sum(norm_z ** 2, dim=1, keepdim=True), epsilon)))
-    norm_z = ops.div(norm_z, ops.sqrt(ops.maximum(ops.sum(norm_z ** 2, dim=1), epsilon)).view(-1, 1))
-
-    norm_y = ms_np.cross(norm_x, norm_z, axis=1)
-    norm_y = ops.div(norm_y, ops.sqrt(ops.maximum(ops.sum(norm_y ** 2, dim=1, keepdim=True), epsilon)))
-    # Construct the 3D rotation matrix
-    norm_x = norm_x.view(-1, 3, 1)
-    norm_y = -norm_y.view(-1, 3, 1)
-    norm_z = norm_z.view(-1, 3, 1)
-    edge_rot_mat_inv = ops.cat([norm_z, norm_x, norm_y], axis=2)
-
-    edge_rot_mat = ops.swapaxes(edge_rot_mat_inv, 1, 2)
-
-    return edge_rot_mat
diff --git a/MindChem/applications/orb/mindchemistry/so2_conv/jd.pkl b/MindChem/applications/orb/mindchemistry/so2_conv/jd.pkl
deleted file mode 100644
index 1b762ad4369564e2f21b808850cc8013e6e41385..0000000000000000000000000000000000000000
Binary files a/MindChem/applications/orb/mindchemistry/so2_conv/jd.pkl and /dev/null differ
diff --git a/MindChem/applications/orb/mindchemistry/so2_conv/so2.py b/MindChem/applications/orb/mindchemistry/so2_conv/so2.py
deleted file mode 100644
index 0c23b319e8170344c6b003e8edfa9207f27965d9..0000000000000000000000000000000000000000
--- a/MindChem/applications/orb/mindchemistry/so2_conv/so2.py
+++ /dev/null
@@ -1,260 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""
-so2 file
-"""
-import mindspore as ms
-from mindspore import ops, nn
-from mindscience.e3nn.o3 import Irreps
-
-
-class Silu(nn.Cell):
-    """
-    silu activation class
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.sigmoid = nn.Sigmoid()
-
-    def construct(self, x):
-        """
-        silu activation class construct process
-        """
-        return ops.mul(x, self.sigmoid(x))
-
-
-class SO2MConvolution(nn.Cell):
-    """
-    SO2 Convolution subnetwork
-    """
-
-    def __init__(self, in_channels, out_channels):
-        super().__init__()
-        self.fc = nn.Dense(in_channels // 2, out_channels,
-                           has_bias=False).to_float(ms.float16)
-        self.out_channels = out_channels
-
-    def construct(self, x_m):
-        """
-        SO2 Convolution sub network construct process
-        """
-        x_m = self.fc(x_m).astype(ms.float32)
-        x_i = ops.narrow(x_m, 2, 0, self.out_channels // 2)
-        x_r = ops.narrow(x_m, 2, self.out_channels // 2,
-                         self.out_channels // 2)
-
-        x_m_r = ops.narrow(x_r, 1, 1, 1) - ops.narrow(
-            x_i, 1, 0, 1)  # x_r[:, 1] - x_i[:, 0]
-        x_m_i = ops.narrow(x_i, 1, 1, 1) + ops.narrow(
-            x_r, 1, 0, 1)  # x_i[:, 1] + x_r[:, 0]
-
-        x_out = ops.cat((x_m_i, x_m_r), axis=1)
-        return x_out
-
-
-class SO2Convolution(nn.Cell):
-    """
-    SO2 Convolution network
-    """
-
-    def __init__(self, irreps_in, irreps_out):
-        super().__init__()
-
-        self.irreps_in1 = Irreps(irreps_in)
-        self.irreps_out = Irreps(irreps_out)
-
-        self.max_order_in = -1
-        for mulir in self.irreps_in1:
-            self.max_order_in = max(self.max_order_in, mulir.ir.l)
-        self.max_order_out = -1
-        for mulir in self.irreps_out:
-            self.max_order_out = max(self.max_order_out, mulir.ir.l)
-
-        self.m_shape_dict_in, self.irreps_in1_length = self.get_m_info(
-            self.irreps_in1, self.max_order_in)
-        self.m_shape_dict_out, self.irreps_out_length = self.get_m_info(
-            self.irreps_out, self.max_order_out)
-
-        self.fc_m0 = nn.Dense(self.m_shape_dict_in.get(0, None),
-                              self.m_shape_dict_out.get(0, None)).to_float(ms.float16)
-
-        self.global_max_order = min(self.max_order_in + 1,
-                                    self.max_order_out + 1)
-
-        self.so2_m_conv = nn.CellList([])
-        for i in range(self.global_max_order):
-            if i == 0:
-                continue
-            so2_m_convolution = SO2MConvolution(self.m_shape_dict_in.get(i, None),
-                                                self.m_shape_dict_out.get(i, None))
-            self.so2_m_conv.append(so2_m_convolution)
-
-        self.max_m_in = 2 * self.max_order_in + 1
-
-        self.irreps_out_data = []
-        for mulir in self.irreps_out:
-            key = mulir.ir.l
-            value = mulir.mul
-            self.irreps_out_data.append((key, value))
-
-    def get_m_info(self, irreps, max_order):
-        """
-        helper function to get m_info
-        """
-        m_shape_dict = {}
-        m_mul_ir_l_dict = {}
-        for i in range(max_order + 1):
-            m_shape_dict[i] = 0
-
-        for mulir in irreps:
-            mul = mulir.mul
-            ir_l = mulir.ir.l
-            if ir_l not in m_mul_ir_l_dict:
-                m_mul_ir_l_dict[ir_l] = mul
-            else:
-                m_mul_ir_l_dict[ir_l] = m_mul_ir_l_dict[ir_l] + mul
-            for j in range(mulir.ir.l + 1):
-                if j == 0:
-                    m_shape_dict[j] = m_shape_dict[j] + mul
-                else:
-                    m_shape_dict[j] = m_shape_dict[j] + 2 * mul
-
-        return m_shape_dict, len(irreps)
-
-    def get_m_list_merge(self, x):
-        """
-        helper function to get m_list_merge
-        """
-        m_list = []
-        for _ in range(self.max_m_in):
-            m_list.append([])
-
-        index_shifting = int((self.max_m_in - 1) / 2)
-        for tmp in x:
-            m_length = tmp.shape[-1]
-            m_shift = int((m_length - 1) / 2)
-            for j in range(m_length):
-                m_list[j - m_shift + index_shifting].append(tmp[:, :, j])
-
-        m_list_merge = []
-        for i in range(index_shifting + 1):
-            if i == 0:
-                m_list_merge.append(ops.cat(m_list[index_shifting - i], -1))
-            else:
-                m_list_merge.append(
-                    ops.cat((ops.cat(m_list[index_shifting - i], -1),
-                             ops.cat(m_list[index_shifting + i], -1)), -1))
-        return m_list_merge
-
-    def construct(self, x, x_edge):
-        """
-        SO2 Convolution network construct process
-        """
-        ##################### _m_primary #########################
-        num_edges = ops.shape(x_edge)[0]
-        m_list_merge = self.get_m_list_merge(x)
-        # ##################### finish _m_primary #########################
-        # radial function
-        out = []
-
-        ### Compute m=0 coefficients separately since they only have real values
-        x_0 = m_list_merge[0]
-
-        x_0 = self.fc_m0(x_0).astype(ms.float32)
-        out.append(x_0)
-
-        #### Compute the values for the m > 0 coefficients
-        for m in range(self.global_max_order):
-            if m == 0:
-                continue
-            x_m = m_list_merge[m]
-            x_m = x_m.reshape(num_edges, 2, -1)
-            x_m = self.so2_m_conv[m - 1](x_m)
-            out.append(x_m)
-
-        ###################### start fill 0 ######################
-        if self.max_order_out + 1 > len(m_list_merge):
-            for m in range(len(m_list_merge), self.max_order_out + 1):
-                extra_zero = ops.zeros(
-                    (num_edges, 2, int(self.m_shape_dict_out.get(m, None) / 2)))
-                out.append(extra_zero)
-        ###################### finish fill 0 ######################
-
-        ###################### start _l_primary #########################
-        l_primary_list_0 = []
-        l_primary_list_left = []
-        l_primary_list_right = []
-
-        for _ in range(self.irreps_out_length):
-            l_primary_list_0.append([])
-            l_primary_list_left.append([])
-            l_primary_list_right.append([])
-
-        m_0 = out[0]
-        offset = 0
-        index = 0
-
-        for key_val in self.irreps_out_data:
-            key = key_val[0]
-            value = key_val[1]
-            if key >= 0:
-                l_primary_list_0[index].append(
-                    ops.unsqueeze(m_0[:, offset:offset + value], -1))
-                offset = offset + value
-            index = index + 1
-
-        for m in range(1, len(out)):
-            right = out[m][:, 1]
-            offset = 0
-            index = 0
-
-            for key_val in self.irreps_out_data:
-                key = key_val[0]
-                value = key_val[1]
-                if key >= m:
-                    l_primary_list_right[index].append(
-                        ops.unsqueeze(right[:, offset:offset + value], -1))
-                    offset = offset + value
-                index = index + 1
-
-        for m in range(len(out) - 1, 0, -1):
-            left = out[m][:, 0]
-            offset = 0
-            index = 0
-
-            for key_val in self.irreps_out_data:
-                key = key_val[0]
-                value = key_val[1]
-                if key >= m:
-                    l_primary_list_left[index].append(
-                        ops.unsqueeze(left[:, offset:offset + value], -1))
-                    offset = offset + value
-                index = index + 1
-
-        l_primary_list = []
-        for i in range(self.irreps_out_length):
-            if i == 0:
-                tmp = ops.cat(l_primary_list_0[i], -1)
-                l_primary_list.append(tmp)
-            else:
-                tmp = ops.cat(
-                    (ops.cat((ops.cat(l_primary_list_left[i], -1),
-                              ops.cat(l_primary_list_0[i], -1)), -1),
-                     ops.cat(l_primary_list_right[i], -1)), -1)
-                l_primary_list.append(tmp)
-
-        ##################### finish _l_primary #########################
-        return tuple(l_primary_list)
diff --git a/MindChem/applications/orb/mindchemistry/so2_conv/so3.py b/MindChem/applications/orb/mindchemistry/so2_conv/so3.py
deleted file mode 100644
index 0ffae5a5f89042b85d5bbe9d780b79a656ed912a..0000000000000000000000000000000000000000
--- a/MindChem/applications/orb/mindchemistry/so2_conv/so3.py
+++ /dev/null
@@ -1,156 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""
-so3 file
-"""
-import mindspore as ms
-from mindspore import nn, ops, vmap, jit_class
-from mindspore.numpy import tensordot
-from mindscience.e3nn import o3
-from mindscience.e3nn.o3 import Irreps
-
-from .wigner import wigner_D
-
-
-class SO3Embedding(nn.Cell):
-    """
-    SO3Embedding class
-    """
-
-    def __init__(self):
-        self.embedding = None
-
-    def _rotate(self, so3rotation, lmax_list, max_list):
-        """
-        SO3Embedding rotate
-        """
-        embedding_rotate = so3rotation[0].rotate(self.embedding, lmax_list[0],
-                                                 max_list[0])
-        self.embedding = embedding_rotate
-
-    def _rotate_inv(self, so3rotation):
-        """
-        SO3Embedding rotate inverse
-        """
-        embedding_rotate = so3rotation[0].rotate_inv(self.embedding,
-                                                     self.lmax_list[0],
-                                                     self.mmax_list[0])
-        self.embedding = embedding_rotate
-
-
-@jit_class
-class SO3Rotation:
-    """
-    SO3_Rotation class
-    """
-
-    def __init__(self, lmax, irreps_in, irreps_out):
-        self.lmax = lmax
-        self.irreps_in1 = Irreps(irreps_in)
-        self.irreps_out = Irreps(irreps_out)
-        self.tensordot_vmap = vmap(tensordot, (0, 0, None), 0)
-
-    @staticmethod
-    def narrow(inputs, axis, start, length):
-        """
-        SO3_Rotation narrow class
-        """
-        begins = [0] * inputs.ndim
-        begins[axis] = start
-
-        sizes = list(inputs.shape)
-
-        sizes[axis] = length
-        res = ops.slice(inputs, begins, sizes)
-        return res
-
-    @staticmethod
-    def rotation_to_wigner_d_matrix(edge_rot_mat, start_lmax, end_lmax):
-        """
-        SO3_Rotation rotation_to_wigner_d_matrix
-        """
-        x = edge_rot_mat @ ms.Tensor([0.0, 1.0, 0.0])
-        alpha, beta = o3.xyz_to_angles(x)
-        rvalue = (ops.swapaxes(
-            o3.angles_to_matrix(alpha, beta, ops.zeros_like(alpha)), -1, -2)
-                  @ edge_rot_mat)
-        gamma = ops.atan2(rvalue[..., 0, 2], rvalue[..., 0, 0])
-
-        block_list = []
-        for lmax in range(start_lmax, end_lmax + 1):
-            block = wigner_D(lmax, alpha, beta, gamma).astype(ms.float32)
-            block_list.append(block)
-        return block_list
-
-    def set_wigner(self, rot_mat3x3):
-        """
-        SO3_Rotation set_wigner
-        """
-        wigner = self.rotation_to_wigner_d_matrix(rot_mat3x3, 0, self.lmax)
-        wigner_inv = []
-        length = len(wigner)
-        for i in range(length):
-            wigner_inv.append(ops.swapaxes(wigner[i], 1, 2))
-        return tuple(wigner), tuple(wigner_inv)
-
-    def rotate(self, embedding, wigner):
-        """
-        SO3_Rotation rotate
-        """
-        res = []
-        batch_shape = embedding.shape[:-1]
-        for (s, l), mir in zip(self.irreps_in1.slice_tuples,
-                               self.irreps_in1.data):
-            v_slice = self.narrow(embedding, -1, s, l)
-            if embedding.ndim == 1:
-                res.append((v_slice.reshape((1,) + batch_shape +
-                                            (mir.mul, mir.ir.dim)), mir.ir))
-            else:
-                res.append(
-                    (v_slice.reshape(batch_shape + (mir.mul, mir.ir.dim)),
-                     mir.ir))
-        rotate_data_list = []
-        for data, ir in res:
-            self.tensordot_vmap(data.astype(ms.float16),
-                                wigner[ir.l].astype(ms.float16), ([1], [1]))
-            rotate_data = self.tensordot_vmap(data.astype(ms.float16),
-                                              wigner[ir.l].astype(ms.float16),
-                                              ((1), (1))).astype(ms.float32)
-            rotate_data_list.append(rotate_data)
-        return tuple(rotate_data_list)
-
-    def rotate_inv(self, embedding, wigner_inv):
-        """
-        SO3_Rotation rotate_inv
-        """
-        res = []
-        batch_shape = embedding[0].shape[0:1]
-        index = 0
-        for (_, _), mir in zip(self.irreps_out.slice_tuples,
-                               self.irreps_out.data):
-            v_slice = embedding[index]
-            if embedding[0].ndim == 1:
-                res.append((v_slice, mir.ir))
-            else:
-                res.append((v_slice, mir.ir))
-            index = index + 1
-        rotate_back_data_list = []
-        for data, ir in res:
-            rotate_back_data = self.tensordot_vmap(
-                data.astype(ms.float16), wigner_inv[ir.l].astype(ms.float16),
-                ((1), (1))).astype(ms.float32)
-            rotate_back_data_list.append(
-                rotate_back_data.view(batch_shape + (-1,)))
-        return ops.cat(rotate_back_data_list, -1)
diff --git a/MindChem/applications/orb/mindchemistry/so2_conv/wigner.py b/MindChem/applications/orb/mindchemistry/so2_conv/wigner.py
deleted file mode 100644
index c3e08615c7a3da1ab757a86d28ccebdf3aa25ea0..0000000000000000000000000000000000000000
--- a/MindChem/applications/orb/mindchemistry/so2_conv/wigner.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Copyright 2024 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""
-wigner file
-"""
-
-# pylint: disable=C0103
-import pickle
-from mindspore import ops
-import mindspore as ms
-from mindscience.e3nn.utils.func import broadcast_args
-
-
-def wigner_D(lv, alpha, beta, gamma):
-    """
-    # Borrowed from e3nn @ 0.4.0:
-    # https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L10
-    # jd is a list of tensors of shape (2l+1, 2l+1)
-
-    # Borrowed from e3nn @ 0.4.0:
-    # https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L37
-    #
-    # In 0.5.0, e3nn shifted to torch.matrix_exp which is significantly slower:
-    # https://github.com/e3nn/e3nn/blob/0.5.0/e3nn/o3/_wigner.py#L92
-    """
-    jd = None
-    with open("jd.pkl", "rb") as f:
-        jd = pickle.load(f)
-    if not lv < len(jd):
-        raise NotImplementedError(
-            f"wigner D maximum l implemented is {len(jd) - 1}, send us an email to ask for more"
-        )
-    alpha, beta, gamma = broadcast_args(alpha, beta, gamma)
-    j = jd[lv]
-    xa = _z_rot_mat(alpha, lv)
-    xb = _z_rot_mat(beta, lv)
-    xc = _z_rot_mat(gamma, lv)
-    return xa @ j.astype(ms.float16) @ xb @ j.astype(ms.float16) @ xc
-
-
-def _z_rot_mat(angle, lv):
-    shape = angle.shape
-    m = ops.zeros((shape[0], 2 * lv + 1, 2 * lv + 1))
-    inds = ops.arange(0, 2 * lv + 1, 1)
-    reversed_inds = ops.arange(2 * lv, -1, -1)
-    frequencies = ops.arange(lv, -lv - 1, -1)
-    m[..., inds, reversed_inds] = ops.sin(frequencies * angle[..., None])
-    m[..., inds, inds] = ops.cos(frequencies * angle[..., None])
-    return m.astype(ms.float16)
diff --git a/MindChem/applications/orb/mindchemistry/utils/__init__.py b/MindChem/applications/orb/mindchemistry/utils/__init__.py
deleted file mode 100644
index 3f15063d7fe04bc33a33343e0a478392b3dbe093..0000000000000000000000000000000000000000
--- a/MindChem/applications/orb/mindchemistry/utils/__init__.py
+++ /dev/null
@@ -1,18 +0,0 @@
-# Copyright 2022 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""init"""
-from .load_config import load_yaml_config
-
-__all__ = ['load_yaml_config']
diff --git a/MindChem/applications/orb/mindchemistry/utils/check_func.py b/MindChem/applications/orb/mindchemistry/utils/check_func.py
deleted file mode 100644
index 711a441fe34c8f70eaa5a43bc8f021ccb7e2a495..0000000000000000000000000000000000000000
--- a/MindChem/applications/orb/mindchemistry/utils/check_func.py
+++ /dev/null
@@ -1,128 +0,0 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""functions"""
-from __future__ import absolute_import
-
-from mindspore import context
-
-_SPACE = " "
-
-
-def _convert_to_tuple(params):
-    if params is None:
-        return params
-    if not isinstance(params, (list, tuple)):
-        params = (params,)
-    if isinstance(params, list):
-        params_out = tuple(params)
-    else:
-        params_out = params  # already a tuple; avoids leaving params_out undefined
-    return params_out
-
-
-def check_param_type(param, param_name, data_type=None, exclude_type=None):
-    """Check parameter's data type"""
-    data_type = _convert_to_tuple(data_type)
-    exclude_type = _convert_to_tuple(exclude_type)
-
-    if data_type and not isinstance(param, data_type):
-        raise TypeError(
-            f"The type of {param_name} should be instance of {data_type}, but got {param} with type {type(param)}"
-        )
-    if exclude_type and type(param) in exclude_type:
-        raise TypeError(
-            f"The type of {param_name} should not be instance of {exclude_type}, but got {param} with type {type(param)}"
-        )
-
-
-def check_param_value(param, param_name, valid_value):
-    """check parameter's value"""
-    valid_value = _convert_to_tuple(valid_value)
-    if param not in valid_value:
-        raise ValueError(f"The value of {param_name} should be in {valid_value}, but got {param}")
-
-
-def check_param_type_value(param, param_name, valid_value, data_type=None, exclude_type=None):
-    """check both data type and value"""
-    check_param_type(param, param_name, data_type=data_type, exclude_type=exclude_type)
-    check_param_value(param, param_name, valid_value)
-
-
-def check_dict_type(param_dict, param_name, key_type=None, value_type=None):
-    """check data type for key and value of the specified dict"""
-    check_param_type(param_dict, param_name, data_type=dict)
-
-    for key in param_dict.keys():
-        if key_type:
-            check_param_type(key, _SPACE.join(("key of", param_name)), data_type=key_type)
-        if value_type:
-            values = _convert_to_tuple(param_dict[key])
-            for value in values:
-                check_param_type(value, _SPACE.join(("value of", param_name)), data_type=value_type)
-
-
-def check_dict_value(param_dict, param_name, key_value=None, value_value=None):
-    """check values for key and value of the specified dict"""
-    check_param_type(param_dict, param_name, data_type=dict)
-
-    for key in param_dict.keys():
-        if key_value:
-            check_param_value(key, _SPACE.join(("key of", param_name)), key_value)
-        if value_value:
-            values = _convert_to_tuple(param_dict[key])
-            for value in values:
-                check_param_value(value, _SPACE.join(("value of", param_name)), value_value)
-
-
-def check_dict_type_value(param_dict, param_name, key_type=None, value_type=None, key_value=None, value_value=None):
-    """check both types and values for key and value of the specified dict"""
-    check_dict_type(param_dict, param_name, key_type=key_type, value_type=value_type)
-    check_dict_value(param_dict, param_name, key_value=key_value, value_value=value_value)
-
-
-def check_mode(api_name):
-    """check running mode"""
-    if context.get_context("mode") == context.PYNATIVE_MODE:
-        raise RuntimeError(f"{api_name} only supports GRAPH_MODE now, but got PYNATIVE_MODE")
-
-
-def check_param_no_greater(param, param_name, compared_value):
-    """Check whether the param is no greater than the given compared_value"""
-    if param > compared_value:
-        raise ValueError(f"The value of {param_name} should be no greater than {compared_value}, but got {param}")
-
-
-def check_param_odd(param, param_name):
-    """Check whether the param is an odd number"""
-    if param % 2 == 0:
-        raise ValueError(f"The value of {param_name} should be an odd number, but got {param}")
-
-
-def check_param_even(param, param_name):
-    """Check whether each element of the param is an even number"""
-    for value in param:
-        if value % 2 != 0:
-            raise ValueError(f"The value of {param_name} should be an even number, but got {param}")
-
-
-def check_lr_param_type_value(param, param_name, param_type, thresh_hold=0, restrict=False, exclude=None):
-    if (exclude and isinstance(param, exclude)) or not isinstance(param, param_type):
-        raise TypeError(f"the type of {param_name} should be {param_type}, but got {type(param)}")
-    if restrict:
-        if param <= thresh_hold:
-            raise ValueError(f"the value of {param_name} should be > {thresh_hold}, but got: {param}")
-    else:
-        if param < thresh_hold:
-            raise ValueError(f"the value of {param_name} should be >= {thresh_hold}, but got: {param}")
diff --git a/MindChem/applications/orb/mindchemistry/utils/load_config.py b/MindChem/applications/orb/mindchemistry/utils/load_config.py
deleted file mode 100644
index 3ddc76e429db1b851f912c1538e923eec87f26ef..0000000000000000000000000000000000000000
--- a/MindChem/applications/orb/mindchemistry/utils/load_config.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright 2022 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""
-utility functions
-"""
-import os
-import yaml
-
-
-def _make_paths_absolute(dir_, config):
-    """
-    Make all values for keys ending with `_path` absolute to dir_.
-
-    Args:
-        dir_ (str): The path of yaml configuration file.
-        config (dict): The yaml for configuration file.
-
-    Returns:
-        Dict. The configuration information in dict format.
-    """
-    for key in config.keys():
-        if key.endswith("_path"):
-            config[key] = os.path.join(dir_, config[key])
-            config[key] = os.path.abspath(config[key])
-        if isinstance(config[key], dict):
-            config[key] = _make_paths_absolute(dir_, config[key])
-    return config
-
-
-def load_yaml_config(file_path):
-    """
-    Load a YAML configuration file.
-
-    Args:
-        file_path (str): The path of yaml configuration file.
-
-    Returns:
-        Dict. The configuration information in dict format.
-
-    Supported Platforms:
-        ``Ascend`` ``CPU`` ``GPU``
-
-    Examples:
-        >>> from mindchemistry.utils import load_yaml_config
-        >>> config_file_path = 'xxx' # 'xxx' is the file_path
-        >>> configs = load_yaml_config(config_file_path)
-    """
-    # Read YAML experiment definition file
-    with open(file_path, 'r', encoding='utf-8') as stream:
-        config = yaml.safe_load(stream)
-    config = _make_paths_absolute(os.path.join(
-        os.path.dirname(file_path), ".."), config)
-    return config
-
-
-def load_yaml_config_from_path(file_path):
-    """
-    Load a YAML configuration file.
-
-    Args:
-        file_path (str): The path of yaml configuration file.
-
-    Returns:
-        Dict. The configuration information in dict format.
-
-    Supported Platforms:
-        ``Ascend`` ``CPU`` ``GPU``
-    """
-    # Read YAML experiment definition file
-    with open(file_path, 'r', encoding='utf-8') as stream:
-        config = yaml.safe_load(stream)
-
-    return config
diff --git a/MindChem/applications/orb/mindchemistry/__init__.py b/MindChem/applications/orb/models/__init__.py
similarity index 95%
rename from MindChem/applications/orb/mindchemistry/__init__.py
rename to MindChem/applications/orb/models/__init__.py
index 29f929195cd630338b6a4ceddbcaed459d20aecd..e0521e5da263f2b7f01eb3263b5a2367d7d493d9 100644
--- a/MindChem/applications/orb/mindchemistry/__init__.py
+++ b/MindChem/applications/orb/models/__init__.py
@@ -19,13 +19,11 @@
 import mindspore as ms
 from mindspore import log as logger
 from mindscience.e3nn import *
 from .cell import *
-from .utils import *
-from .graph import *
-from .so2_conv import *
+

 __all__ = []
 __all__.extend(cell.__all__)
-__all__.extend(utils.__all__)
+
diff --git a/MindChem/applications/orb/mindchemistry/cell/__init__.py b/MindChem/applications/orb/models/cell/__init__.py
similarity index 100%
rename from MindChem/applications/orb/mindchemistry/cell/__init__.py
rename to MindChem/applications/orb/models/cell/__init__.py
diff --git a/MindChem/applications/orb/mindchemistry/cell/activation.py b/MindChem/applications/orb/models/cell/activation.py
similarity index 100%
rename from MindChem/applications/orb/mindchemistry/cell/activation.py
rename to MindChem/applications/orb/models/cell/activation.py
diff --git a/MindChem/applications/orb/mindchemistry/cell/basic_block.py b/MindChem/applications/orb/models/cell/basic_block.py
similarity index 100%
rename from MindChem/applications/orb/mindchemistry/cell/basic_block.py
rename to MindChem/applications/orb/models/cell/basic_block.py
diff --git a/MindChem/applications/orb/mindchemistry/cell/convolution.py b/MindChem/applications/orb/models/cell/convolution.py
similarity index 100%
rename from MindChem/applications/orb/mindchemistry/cell/convolution.py
rename to MindChem/applications/orb/models/cell/convolution.py
diff --git a/MindChem/applications/orb/mindchemistry/cell/embedding.py b/MindChem/applications/orb/models/cell/embedding.py
similarity index 100%
rename from MindChem/applications/orb/mindchemistry/cell/embedding.py
rename to MindChem/applications/orb/models/cell/embedding.py
diff --git a/MindChem/applications/orb/mindchemistry/cell/message_passing.py b/MindChem/applications/orb/models/cell/message_passing.py
similarity index 100%
rename from MindChem/applications/orb/mindchemistry/cell/message_passing.py
rename to MindChem/applications/orb/models/cell/message_passing.py
diff --git a/MindChem/applications/orb/mindchemistry/cell/orb/__init__.py b/MindChem/applications/orb/models/cell/orb/__init__.py
similarity index 100%
rename from MindChem/applications/orb/mindchemistry/cell/orb/__init__.py
rename to MindChem/applications/orb/models/cell/orb/__init__.py
diff --git a/MindChem/applications/orb/mindchemistry/cell/orb/gns.py b/MindChem/applications/orb/models/cell/orb/gns.py
similarity index 99%
rename from MindChem/applications/orb/mindchemistry/cell/orb/gns.py
rename to MindChem/applications/orb/models/cell/orb/gns.py
index ab53083f8605dbe0463a5d9a0554d5e5990428f9..940fd37e33b3ff9dc0a81b895db4a36c7239c9c0 100644
--- a/MindChem/applications/orb/mindchemistry/cell/orb/gns.py
+++ b/MindChem/applications/orb/models/cell/orb/gns.py
@@ -24,7 +24,7 @@
 from mindspore import nn, ops, Tensor, mint
 from mindspore.common.initializer import Uniform
 import mindspore.ops.operations as P

-from mindchemistry.cell.orb.utils import build_mlp
+from models.cell.orb.utils import build_mlp

 _KEY = "feat"
diff --git a/MindChem/applications/orb/mindchemistry/cell/orb/orb.py b/MindChem/applications/orb/models/cell/orb/orb.py
similarity index 99%
rename from MindChem/applications/orb/mindchemistry/cell/orb/orb.py
rename to MindChem/applications/orb/models/cell/orb/orb.py
index 8afc55dbf5e931505ad6747b53ff9fcd4e1fbe23..c0e856ca96d91f8dfae6b741b7847922e3e4dbd7 100644
--- a/MindChem/applications/orb/mindchemistry/cell/orb/orb.py
+++ b/MindChem/applications/orb/models/cell/orb/orb.py
@@ -21,8 +21,8 @@
 import numpy
 import mindspore as ms
 from mindspore import Parameter, ops, Tensor, mint
-from mindchemistry.cell.orb.gns import _KEY, MoleculeGNS
-from mindchemistry.cell.orb.utils import (
+from models.cell.orb.gns import _KEY, MoleculeGNS
+from models.cell.orb.utils import (
     aggregate_nodes,
     build_mlp,
     REFERENCE_ENERGIES,
@@ -123,7 +123,7 @@ class ScalarNormalizer(ms.nn.Cell):
         self.bn = mint.nn.BatchNorm1d(1, affine=False, momentum=None)
         self.bn.running_mean = Parameter(Tensor([0], ms.float32))
         self.bn.running_var = Parameter(Tensor([1], ms.float32))
-        self.bn.num_batches_tracked = Parameter(Tensor([1000], ms.float32))
+        self.bn.num_batches_tracked = Parameter(Tensor([1000], ms.float32), requires_grad=False)
         self.stastics = {
             "running_mean": init_mean if init_mean is not None else 0.0,
             "running_var": init_std**2 if init_std is not None else 1.0,
@@ -138,7 +138,7 @@ class ScalarNormalizer(ms.nn.Cell):
         if hasattr(self, "running_mean"):
             return (x - self.running_mean) / mint.sqrt(self.running_var)
         return (x - self.bn.running_mean) / mint.sqrt(self.bn.running_var)
-    
+
     def inverse(self, x: Tensor):
         """Reverse the construct normalization.

diff --git a/MindChem/applications/orb/mindchemistry/cell/orb/utils.py b/MindChem/applications/orb/models/cell/orb/utils.py
similarity index 100%
rename from MindChem/applications/orb/mindchemistry/cell/orb/utils.py
rename to MindChem/applications/orb/models/cell/orb/utils.py
diff --git a/MindChem/applications/orb/src/pretrained.py b/MindChem/applications/orb/src/pretrained.py
index cb66db19fe55e2694a492e836c78d0a175bca67d..8f58ddde61894eb71f7275ac01ee3556adf3b171 100644
--- a/MindChem/applications/orb/src/pretrained.py
+++ b/MindChem/applications/orb/src/pretrained.py
@@ -20,7 +20,7 @@
 from typing import Optional

 from mindspore import nn, load_checkpoint, load_param_into_net

-from mindchemistry.cell import (
+from models.cell import (
     EnergyHead,
     GraphHead,
     Orb,