From 4bf33da67cb3674da71c7d07372890e3128a49c5 Mon Sep 17 00:00:00 2001 From: xukunming2025 <1652262991@qq.com> Date: Tue, 2 Dec 2025 13:51:35 +0800 Subject: [PATCH 1/7] [SPONGE] refactor orb app structure and models Signed-off-by: xukunming2025 <1652262991@qq.com> --- MindChem/applications/orb/README.md | 85 ++-- MindChem/applications/orb/README_CN.md | 85 ++-- .../applications/orb/configs/config_eval.yaml | 2 +- .../orb/configs/config_parallel.yaml | 4 +- .../orb/mindchemistry/graph/__init__.py | 15 - .../orb/mindchemistry/graph/dataloader.py | 408 ------------------ .../orb/mindchemistry/graph/graph.py | 294 ------------- .../orb/mindchemistry/graph/loss.py | 78 ---- .../orb/mindchemistry/graph/normlization.py | 278 ------------ .../orb/mindchemistry/so2_conv/__init__.py | 19 - .../so2_conv/init_edge_rot_mat.py | 64 --- .../orb/mindchemistry/so2_conv/jd.pkl | Bin 9925 -> 0 bytes .../orb/mindchemistry/so2_conv/so2.py | 260 ----------- .../orb/mindchemistry/so2_conv/so3.py | 156 ------- .../orb/mindchemistry/so2_conv/wigner.py | 61 --- .../orb/mindchemistry/utils/__init__.py | 18 - .../orb/mindchemistry/utils/check_func.py | 128 ------ .../orb/mindchemistry/utils/load_config.py | 85 ---- .../orb/{mindchemistry => models}/__init__.py | 6 +- .../cell/__init__.py | 0 .../cell/activation.py | 0 .../cell/basic_block.py | 0 .../cell/convolution.py | 0 .../cell/embedding.py | 0 .../cell/message_passing.py | 0 .../cell/orb/__init__.py | 0 .../{mindchemistry => models}/cell/orb/gns.py | 2 +- .../{mindchemistry => models}/cell/orb/orb.py | 24 +- .../cell/orb/utils.py | 0 MindChem/applications/orb/src/pretrained.py | 2 +- 30 files changed, 143 insertions(+), 1931 deletions(-) delete mode 100644 MindChem/applications/orb/mindchemistry/graph/__init__.py delete mode 100644 MindChem/applications/orb/mindchemistry/graph/dataloader.py delete mode 100644 MindChem/applications/orb/mindchemistry/graph/graph.py delete mode 100644 MindChem/applications/orb/mindchemistry/graph/loss.py delete mode 100644 MindChem/applications/orb/mindchemistry/graph/normlization.py delete mode 100644 MindChem/applications/orb/mindchemistry/so2_conv/__init__.py delete mode 100644 MindChem/applications/orb/mindchemistry/so2_conv/init_edge_rot_mat.py delete mode 100644 MindChem/applications/orb/mindchemistry/so2_conv/jd.pkl delete mode 100644 MindChem/applications/orb/mindchemistry/so2_conv/so2.py delete mode 100644 MindChem/applications/orb/mindchemistry/so2_conv/so3.py delete mode 100644 MindChem/applications/orb/mindchemistry/so2_conv/wigner.py delete mode 100644 MindChem/applications/orb/mindchemistry/utils/__init__.py delete mode 100644 MindChem/applications/orb/mindchemistry/utils/check_func.py delete mode 100644 MindChem/applications/orb/mindchemistry/utils/load_config.py rename MindChem/applications/orb/{mindchemistry => models}/__init__.py (95%) rename MindChem/applications/orb/{mindchemistry => models}/cell/__init__.py (100%) rename MindChem/applications/orb/{mindchemistry => models}/cell/activation.py (100%) rename MindChem/applications/orb/{mindchemistry => models}/cell/basic_block.py (100%) rename MindChem/applications/orb/{mindchemistry => models}/cell/convolution.py (100%) rename MindChem/applications/orb/{mindchemistry => models}/cell/embedding.py (100%) rename MindChem/applications/orb/{mindchemistry => models}/cell/message_passing.py (100%) rename MindChem/applications/orb/{mindchemistry => models}/cell/orb/__init__.py (100%) rename MindChem/applications/orb/{mindchemistry => models}/cell/orb/gns.py (99%) rename MindChem/applications/orb/{mindchemistry => models}/cell/orb/orb.py (97%) rename MindChem/applications/orb/{mindchemistry => models}/cell/orb/utils.py (100%) diff --git a/MindChem/applications/orb/README.md b/MindChem/applications/orb/README.md index a6598369f..8c85888e5 100644 --- a/MindChem/applications/orb/README.md +++ b/MindChem/applications/orb/README.md @@ -10,7 +10,7 @@ ## Environment Requirements -> 1. Install `mindspore (2.5.0)` +> 1. Install `mindspore (2.7.0)` > 2. Install dependencies: `pip install -r requirement.txt` ## Quick Start @@ -28,32 +28,59 @@ ```text The main code modules are in the src folder, with the dataset folder containing the datasets, the orb_ckpts folder containing pre-trained models and trained model weight files, and the configs folder containing parameter configuration files for each code. -orb_models # Model name +orb_models # Project root directory (ORB pretraining/finetuning project) ├── dataset - ├── train_mptrj_ase.db # Training dataset for fine-tuning stage - └── val_mptrj_ase.db # Test dataset for fine-tuning stage -├── orb_ckpts - └── orb-mptraj-only-v2.ckpt # Pre-trained model checkpoint +│ ├── train_mptrj_ase.db # Training dataset for finetuning (ASE trajectory in SQLite format) +│ └── val_mptrj_ase.db # Validation/test dataset for finetuning +│ +├── orb_ckpts # Directory to store checkpoints +│ └── orb-mptraj-only-v2.ckpt # Pretrained ORB model checkpoint (mptraj-only task) +│ # After training, additional finetuned checkpoints will be saved here +│ ├── configs - ├── config.yaml # Single-card training parameter configuration file - ├── config_parallel.yaml # Multi-card parallel training parameter configuration file - └── config_eval.yaml # Inference parameter configuration file -├── src - ├── __init__.py - ├── ase_dataset.py # Process and load datasets - ├── atomic_system.py # Define data structure for atomic systems - ├── base.py # Base class definitions - ├── featurization_utilities.py # Provide tools to convert atomic systems into feature vectors - ├── pretrained.py # Pre-trained model related functions - ├── property_definitions.py # Define calculation methods and naming rules for various physical properties in atomic systems - ├── trainer.py # Model loss class definitions - ├── segment_ops.py # Provide tools for segmenting data - └── utils.py # Utility module -├── finetune.py # Model fine-tuning code -├── evaluate.py # Model inference code -├── run.sh # Single-card training startup script -├── run_parallel.sh # Multi-card parallel training startup script -└── requirement.txt # Environment +│ ├── config.yaml # Single-device training config (lr, batch_size, etc.) +│ ├── config_parallel.yaml # Multi-device data-parallel training config +│ └── config_eval.yaml # Inference/evaluation config +│ +├── src # Core source code for training and data processing +│ ├── __init__.py # Package initializer for src (for convenient imports) +│ ├── ase_dataset.py # ASE dataset wrapper and loader (read SQLite, build atomic graphs) +│ ├── atomic_system.py # Atomic system data structure (coordinates, atom types, cell info, etc.) +│ ├── base.py # Base classes and common utilities (e.g., batch_graphs for graph batching) +│ ├── featurization_utilities.py # Tools to convert atomic systems into feature tensors +│ ├── pretrained.py # Pretrained ORB model builders and loading helpers +│ ├── property_definitions.py # Config and naming rules for physical properties (energy, forces, stress, etc.) +│ ├── trainer.py # OrbLoss and other loss / training-related wrappers +│ ├── segment_ops.py # Segmented operations (segment_sum/mean/max) for graph reductions +│ └── utils.py # General utilities (seeding, logging, optimizer & LR scheduler, etc.) +│ +├── finetune.py # Main finetuning entry script (loads data/model and starts training) +├── evaluate.py # Inference/evaluation script (run model with finetuned checkpoints) +│ +├── run.sh # Single-device training launcher (calls finetune.py + config.yaml) +├── run_parallel.sh # Multi-device training launcher (msrun + config_parallel.yaml) +├── requirement.txt # Python dependency list (for environment setup) +│ +├── models # Model definition modules (GNN/ORB network structures) +│ ├── __init__.py # Package initializer for models (expose unified model interfaces) +│ │ +│ ├── cell # Basic “cell” building blocks for NN/GNN +│ │ ├── activation.py # Activation functions and non-linear modules +│ │ ├── basic_block.py # Common basic network blocks (MLP blocks, residual blocks, etc.) +│ │ ├── convolution.py # Graph convolutions or spatial convolution layers/operators +│ │ ├── embedding.py # Embedding layers for atom types, edge features, etc. +│ │ ├── __init__.py # Package initializer for cell +│ │ ├── message_passing.py # Message passing layers and main GNN logic +│ │ │ +│ │ ├── orb # ORB-specific network architectures +│ │ │ ├── gns.py # GNS (Graph Network Simulator) related structures/interfaces +│ │ │ ├── __init__.py # Package initializer for orb submodule +│ │ │ ├── orb.py # Main ORB model definition (encoder + heads, etc.) +│ │ │ └── utils.py # Helper functions and internal utilities for ORB models +│ │ +│ └── ... # Reserved: other model architectures can be added here +│ +└── __init__.py ``` ## Download Dataset @@ -134,6 +161,14 @@ Training time: 2375.89474 seconds Training time: 2377.02413 seconds Training time: 2377.22778 seconds Training time: 2376.63176 seconds +[INFO] PS(365484,ffff13fff120,python):2025-12-02-13:43:10.997.606 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! +[INFO] PS(365484,ffff137ef120,python):2025-12-02-13:43:10.997.583 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! +[INFO] PS(365478,ffff36fdf120,python):2025-12-02-13:43:13.013.568 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! +[INFO] PS(365478,ffff377ef120,python):2025-12-02-13:43:13.013.575 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! +[INFO] PS(365488,ffff21fbf120,python):2025-12-02-13:43:15.029.782 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! +[INFO] PS(365488,ffff217af120,python):2025-12-02-13:43:15.029.782 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! +[INFO] PS(365481,ffff2d659120,python):2025-12-02-13:43:15.061.968 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! +[INFO] PS(365481,ffff2de69120,python):2025-12-02-13:43:15.061.956 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! ``` Under the same training configuration, parallel training achieved significant performance improvement compared to single-card training: diff --git a/MindChem/applications/orb/README_CN.md b/MindChem/applications/orb/README_CN.md index 6812131cf..8c1f18a0a 100644 --- a/MindChem/applications/orb/README_CN.md +++ b/MindChem/applications/orb/README_CN.md @@ -10,7 +10,7 @@ ## 环境要求 -> 1. 安装`mindspore(2.5.0)` +> 1. 安装`mindspore(2.7.0)` > 2. 安装依赖包:`pip install -r requirement.txt` ## 快速入门 @@ -28,32 +28,59 @@ ```text 代码主要模块在src文件夹下,其中dataset文件夹下是数据集,orb_ckpts文件夹下是预训练模型和训练好的模型权重文件,configs文件夹下是各代码的参数配置文件。 -orb_models # 模型名 +orb_models # 项目根目录(ORB 预训练/微调工程) ├── dataset - ├── train_mptrj_ase.db # 微调阶段训练数据集 - └── val_mptrj_ase.db # 微调阶段测试数据集 -├── orb_ckpts - └── orb-mptraj-only-v2.ckpt # 预训练模型checkpoint +│ ├── train_mptrj_ase.db # 微调阶段训练数据集(ASE 轨迹 SQLite 格式) +│ └── val_mptrj_ase.db # 微调阶段验证/测试数据集 +│ +├── orb_ckpts # checkpoint 存放目录 +│ └── orb-mptraj-only-v2.ckpt # 预训练 ORB 模型 checkpoint(仅 mptraj 任务) +│ # 训练完成后,会在此目录下额外生成微调后的 ckpt +│ ├── configs - ├── config.yaml # 单卡训练参数配置文件 - ├── config_parallel.yaml # 多卡并行训练参数配置文件 - └── config_eval.yaml # 推理参数配置文件 -├── src - ├── __init__.py - ├── ase_dataset.py # 处理和加载数据集 - ├── atomic_system.py # 定义原子系统的数据结构 - ├── base.py # 基础类定义 - ├── featurization_utilities.py # 提供将原子系统转换为特征向量的工具 - ├── pretrained.py # 预训练模型相关函数 - ├── property_definitions.py # 定义原子系统中各种物理性质的计算方式和命名规则 - ├── trainer.py # 模型loss类定义 - ├── segment_ops.py # 提供对数据进行分段处理的工具 - └── utils.py # 工具模块 -├── finetune.py # 模型微调代码 -├── evaluate.py # 模型推理代码 -├── run.sh # 单卡训练启动脚本 -├── run_parallel.sh # 多卡并行训练启动脚本 -└── requirement.txt # 环境 +│ ├── config.yaml # 单卡训练参数配置(学习率、batch_size 等) +│ ├── config_parallel.yaml # 多卡数据并行训练参数配置 +│ └── config_eval.yaml # 推理/评估阶段参数配置 +│ +├── src # 训练与数据处理的核心源码 +│ ├── __init__.py # src 包初始化,方便外部按模块导入 +│ ├── ase_dataset.py # ASE 数据集封装与加载(读 SQLite、组装原子图) +│ ├── atomic_system.py # 原子系统数据结构定义(坐标、原子种类、晶胞信息等) +│ ├── base.py # 基础类与通用函数(batch_graphs 等图数据打包工具) +│ ├── featurization_utilities.py # 原子系统 → 特征张量的特征化工具 +│ ├── pretrained.py # 预训练 ORB 模型构造与加载接口 +│ ├── property_definitions.py # 能量、力、应力等物理性质的配置与命名规则 +│ ├── trainer.py # OrbLoss 等 loss 类与训练相关封装 +│ ├── segment_ops.py # segment_sum/mean/max 等分段运算算子(图归约用) +│ └── utils.py # 通用工具函数(seed、日志、优化器与 LR scheduler 等) +│ +├── finetune.py # 模型微调入口脚本(加载数据、模型、开始训练) +├── evaluate.py # 模型推理/评估脚本(使用微调好的 ckpt 做 inference) +│ +├── run.sh # 单卡训练启动脚本(调用 finetune.py + config.yaml) +├── run_parallel.sh # 多卡并行训练启动脚本(msrun + config_parallel.yaml) +├── requirement.txt # Python 依赖列表(用于搭建运行环境) +│ +├── models # 模型结构定义模块(GNN/ORB 相关网络) +│ ├── __init__.py # models 包初始化,统一暴露各类模型接口 +│ │ +│ ├── cell # 神经网络/图网络的基础“细胞”模块 +│ │ ├── activation.py # 激活函数与非线性模块封装 +│ │ ├── basic_block.py # 通用基础网络模块(MLP Block、ResBlock 等) +│ │ ├── convolution.py # 图卷积或空间卷积相关层与算子 +│ │ ├── embedding.py # 原子类型、边特征等嵌入层定义 +│ │ ├── __init__.py # cell 子包初始化 +│ │ ├── message_passing.py # 消息传递层与 GNN 主体结构(message passing 核心逻辑) +│ │ │ +│ │ ├── orb # ORB 专用网络结构 +│ │ │ ├── gns.py # GNS(Graph Network Simulator)相关结构/接口 +│ │ │ ├── __init__.py # orb 子包初始化 +│ │ │ ├── orb.py # ORB 主模型结构定义(encoder + heads 等) +│ │ │ └── utils.py # ORB 模型内部使用的工具函数/辅助模块 +│ │ +│ └── ... # 预留:可以扩展其他模型结构 +│ +└── __init__.py ``` ## 下载数据集 @@ -134,6 +161,14 @@ Training time: 2375.89474 seconds Training time: 2377.02413 seconds Training time: 2377.22778 seconds Training time: 2376.63176 seconds +[INFO] PS(365484,ffff13fff120,python):2025-12-02-13:43:10.997.606 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! +[INFO] PS(365484,ffff137ef120,python):2025-12-02-13:43:10.997.583 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! +[INFO] PS(365478,ffff36fdf120,python):2025-12-02-13:43:13.013.568 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! +[INFO] PS(365478,ffff377ef120,python):2025-12-02-13:43:13.013.575 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! +[INFO] PS(365488,ffff21fbf120,python):2025-12-02-13:43:15.029.782 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! +[INFO] PS(365488,ffff217af120,python):2025-12-02-13:43:15.029.782 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! +[INFO] PS(365481,ffff2d659120,python):2025-12-02-13:43:15.061.968 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! +[INFO] PS(365481,ffff2de69120,python):2025-12-02-13:43:15.061.956 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! ``` 在相同的训练配置下,并行训练相比单卡训练取得了显著的性能提升: diff --git a/MindChem/applications/orb/configs/config_eval.yaml b/MindChem/applications/orb/configs/config_eval.yaml index 1e98c5f0b..e0ffea036 100644 --- a/MindChem/applications/orb/configs/config_eval.yaml +++ b/MindChem/applications/orb/configs/config_eval.yaml @@ -6,7 +6,7 @@ device_id: 0 val_data_path: dataset/val_mptrj_ase.db num_workers: 8 batch_size: 64 -checkpoint_path: orb_ckpts/orb-ft-checkpoint_epoch99.ckpt +checkpoint_path: orb_ckpts/orb-mptraj-only-v2.ckpt random_seed: 1234 output_dir: results/ diff --git a/MindChem/applications/orb/configs/config_parallel.yaml b/MindChem/applications/orb/configs/config_parallel.yaml index c6a5e0857..51436b345 100644 --- a/MindChem/applications/orb/configs/config_parallel.yaml +++ b/MindChem/applications/orb/configs/config_parallel.yaml @@ -2,9 +2,9 @@ train_data_path: dataset/train_mptrj_ase.db val_data_path: dataset/val_mptrj_ase.db num_workers: 8 -batch_size: 256 +batch_size: 64 gradient_clip_val: 0.5 -max_epochs: 100 +max_epochs: 2 checkpoint_path: orb_ckpts/ lr: 3.0e-4 random_seed: 666 diff --git a/MindChem/applications/orb/mindchemistry/graph/__init__.py b/MindChem/applications/orb/mindchemistry/graph/__init__.py deleted file mode 100644 index 1ae7d9a34..000000000 --- a/MindChem/applications/orb/mindchemistry/graph/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""graph""" diff --git a/MindChem/applications/orb/mindchemistry/graph/dataloader.py b/MindChem/applications/orb/mindchemistry/graph/dataloader.py deleted file mode 100644 index 6af294afb..000000000 --- a/MindChem/applications/orb/mindchemistry/graph/dataloader.py +++ /dev/null @@ -1,408 +0,0 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""dataloader -""" -import random -import numpy as np -from mindspore import Tensor -import mindspore as ms - - -class DataLoaderBase: - r""" - DataLoader that stacks a batch of graph data to fixed-size Tensors - - For specific dataset, usually the following functions should be customized to include different fields: - __init__, shuffle_action, __iter__ - - """ - - def __init__(self, - batch_size, - edge_index, - label=None, - node_attr=None, - edge_attr=None, - padding_std_ratio=3.5, - dynamic_batch_size=True, - shuffle_dataset=True, - max_node=None, - max_edge=None): - self.batch_size = batch_size - self.edge_index = edge_index - self.index = 0 - self.step = 0 - self.padding_std_ratio = padding_std_ratio - self.batch_change_num = 0 - self.batch_exceeding_num = 0 - self.dynamic_batch_size = dynamic_batch_size - self.shuffle_dataset = shuffle_dataset - - ### can be customized to specific dataset - self.label = label - self.node_attr = node_attr - self.edge_attr = edge_attr - self.sample_num = len(self.node_attr) - batch_size_div = self.batch_size - if batch_size_div != 0: - self.step_num = int(self.sample_num / batch_size_div) - else: - raise ValueError - - if dynamic_batch_size: - self.max_start_sample = self.sample_num - else: - self.max_start_sample = self.sample_num - self.batch_size + 1 - - self.set_global_max_node_edge_num(self.node_attr, self.edge_attr, max_node, max_edge, shuffle_dataset, - dynamic_batch_size) - ####### - - def __len__(self): - return self.sample_num - - ### example of generating data of each step, can be customized to specific dataset - def __iter__(self): - if self.shuffle_dataset: - self.shuffle() - else: - self.restart() - - while self.index < self.max_start_sample: - # pylint: disable=W0612 - edge_index_step, node_batch_step, node_mask, edge_mask, batch_size_mask, node_num, edge_num, batch_size \ - = self.gen_common_data(self.node_attr, self.edge_attr) - - ### can be customized to generate different attributes or labels according to specific dataset - node_attr_step = self.gen_node_attr(self.node_attr, batch_size, node_num) - edge_attr_step = self.gen_edge_attr(self.edge_attr, batch_size, edge_num) - label_step = self.gen_global_attr(self.label, batch_size) - - self.add_step_index(batch_size) - - ### make number to Tensor, if it is used as a Tensor in the network - node_num = Tensor(node_num) - batch_size = Tensor(batch_size) - - yield node_attr_step, edge_attr_step, label_step, edge_index_step, node_batch_step, \ - node_mask, edge_mask, node_num, batch_size - - @staticmethod - def pad_zero_to_end(src, axis, zeros_len): - """pad_zero_to_end""" - pad_shape = [] - for i in range(src.ndim): - if i == axis: - pad_shape.append((0, zeros_len)) - else: - pad_shape.append((0, 0)) - return np.pad(src, pad_shape) - - @staticmethod - def gen_mask(total_len, real_len): - """gen_mask""" - mask = np.concatenate((np.full((real_len,), np.float32(1)), np.full((total_len - real_len,), np.float32(0)))) - return mask - - ### example of computing global max length of node_attr and edge_attr, can be customized to specific dataset - def set_global_max_node_edge_num(self, - node_attr, - edge_attr, - max_node=None, - max_edge=None, - shuffle_dataset=True, - dynamic_batch_size=True): - """set_global_max_node_edge_num - - Args: - node_attr: node_attr - edge_attr: edge_attr - max_node: max_node. Defaults to None. - max_edge: max_edge. Defaults to None. - shuffle_dataset: shuffle_dataset. Defaults to True. - dynamic_batch_size: dynamic_batch_size. Defaults to True. - - Raises: - ValueError: ValueError - """ - if not shuffle_dataset: - max_node_num, max_edge_num = self.get_max_node_edge_num(node_attr, edge_attr, dynamic_batch_size) - self.max_node_num_global = max_node_num if max_node is None else max(max_node, max_node_num) - self.max_edge_num_global = max_edge_num if max_edge is None else max(max_edge, max_edge_num) - return - - sum_node = 0 - sum_edge = 0 - count = 0 - max_node_single = 0 - max_edge_single = 0 - for step in range(self.sample_num): - node_len = len(node_attr[step]) - edge_len = len(edge_attr[step]) - sum_node += node_len - sum_edge += edge_len - max_node_single = max(max_node_single, node_len) - max_edge_single = max(max_edge_single, edge_len) - count += 1 - if count != 0: - mean_node = sum_node / count - mean_edge = sum_edge / count - else: - raise ValueError - - if max_node is not None and max_edge is not None: - if max_node < max_node_single: - raise ValueError( - f"the max_node {max_node} is less than the max length of a single sample {max_node_single}") - if max_edge < max_edge_single: - raise ValueError( - f"the max_edge {max_edge} is less than the max length of a single sample {max_edge_single}") - - self.max_node_num_global = max_node - self.max_edge_num_global = max_edge - elif max_node is None and max_edge is None: - sum_node = 0 - sum_edge = 0 - for step in range(self.sample_num): - sum_node += (len(node_attr[step]) - mean_node) ** 2 - sum_edge += (len(edge_attr[step]) - mean_edge) ** 2 - - if count != 0: - std_node = np.sqrt(sum_node / count) - std_edge = np.sqrt(sum_edge / count) - else: - raise ValueError - - self.max_node_num_global = int(self.batch_size * mean_node + - self.padding_std_ratio * np.sqrt(self.batch_size) * std_node) - self.max_edge_num_global = int(self.batch_size * mean_edge + - self.padding_std_ratio * np.sqrt(self.batch_size) * std_edge) - self.max_node_num_global = max(self.max_node_num_global, max_node_single) - self.max_edge_num_global = max(self.max_edge_num_global, max_edge_single) - elif max_node is None: - if max_edge < max_edge_single: - raise ValueError( - f"the max_edge {max_edge} is less than the max length of a single sample {max_edge_single}") - - if mean_edge != 0: - self.max_node_num_global = int(max_edge * mean_node / mean_edge) - else: - raise ValueError - self.max_node_num_global = max(self.max_node_num_global, max_node_single) - self.max_edge_num_global = max_edge - else: - if max_node < max_node_single: - raise ValueError( - f"the max_node {max_node} is less than the max length of a single sample {max_node_single}") - - self.max_node_num_global = max_node - if mean_node != 0: - self.max_edge_num_global = int(max_node * mean_edge / mean_node) - else: - raise ValueError - self.max_edge_num_global = max(self.max_edge_num_global, max_edge_single) - - def get_max_node_edge_num(self, node_attr, edge_attr, remainder=True): - """get_max_node_edge_num - - Args: - node_attr: node_attr - edge_attr: edge_attr - remainder (bool, optional): remainder. Defaults to True. - - Returns: - max_node_num, max_edge_num - """ - max_node_num = 0 - max_edge_num = 0 - index = 0 - for _ in range(self.step_num): - node_num = 0 - edge_num = 0 - for _ in range(self.batch_size): - node_num += len(node_attr[index]) - edge_num += len(edge_attr[index]) - index += 1 - max_node_num = max(max_node_num, node_num) - max_edge_num = max(max_edge_num, edge_num) - - if remainder: - remain_num = self.sample_num - index - 1 - node_num = 0 - edge_num = 0 - for _ in range(remain_num): - node_num += len(node_attr[index]) - edge_num += len(edge_attr[index]) - index += 1 - max_node_num = max(max_node_num, node_num) - max_edge_num = max(max_edge_num, edge_num) - - return max_node_num, max_edge_num - - def shuffle_index(self): - """shuffle_index""" - indices = list(range(self.sample_num)) - random.shuffle(indices) - return indices - - ### example of shuffling the input dataset, can be customized to specific dataset - def shuffle_action(self): - """shuffle_action""" - indices = self.shuffle_index() - self.edge_index = [self.edge_index[i] for i in indices] - self.label = [self.label[i] for i in indices] - self.node_attr = [self.node_attr[i] for i in indices] - self.edge_attr = [self.edge_attr[i] for i in indices] - - ### example of generating the final shuffled dataset, can be customized to specific dataset - def shuffle(self): - """shuffle""" - self.shuffle_action() - if not self.dynamic_batch_size: - max_node_num, max_edge_num = self.get_max_node_edge_num(self.node_attr, self.edge_attr, remainder=False) - while max_node_num > self.max_node_num_global or max_edge_num > self.max_edge_num_global: - self.shuffle_action() - max_node_num, max_edge_num = self.get_max_node_edge_num(self.node_attr, self.edge_attr, remainder=False) - - self.step = 0 - self.index = 0 - - def restart(self): - """restart""" - self.step = 0 - self.index = 0 - - ### example of calculating dynamic batch size to avoid exceeding the max length of node and edge, can be customized to specific dataset - def get_batch_size(self, node_attr, edge_attr, start_batch_size): - """get_batch_size - - Args: - node_attr: node_attr - edge_attr: edge_attr - start_batch_size: start_batch_size - - Returns: - batch_size - """ - node_num = 0 - edge_num = 0 - for i in range(start_batch_size): - index = self.index + i - node_num += len(node_attr[index]) - edge_num += len(edge_attr[index]) - - exceeding = False - while node_num > self.max_node_num_global or edge_num > self.max_edge_num_global: - node_num -= len(node_attr[index]) - edge_num -= len(edge_attr[index]) - index -= 1 - exceeding = True - self.batch_exceeding_num += 1 - if exceeding: - self.batch_change_num += 1 - - return index - self.index + 1 - - def gen_common_data(self, node_attr, edge_attr): - """gen_common_data - - Args: - node_attr: node_attr - edge_attr: edge_attr - - Returns: - common_data - """ - if self.dynamic_batch_size: - if self.step >= self.step_num: - batch_size = self.get_batch_size(node_attr, edge_attr, - min((self.sample_num - self.index), self.batch_size)) - else: - batch_size = self.get_batch_size(node_attr, edge_attr, self.batch_size) - else: - batch_size = self.batch_size - - ######################## node_batch - node_batch_step = [] - sample_num = 0 - for i in range(self.index, self.index + batch_size): - node_batch_step.extend([sample_num] * node_attr[i].shape[0]) - sample_num += 1 - node_batch_step = np.array(node_batch_step) - node_num = node_batch_step.shape[0] - - ######################## edge_index - edge_index_step = np.array([[], []], dtype=np.int64) - max_edge_index = 0 - for i in range(self.index, self.index + batch_size): - edge_index_step = np.concatenate((edge_index_step, self.edge_index[i] + max_edge_index), 1) - max_edge_index = np.max(edge_index_step) + 1 - edge_num = edge_index_step.shape[1] - - ######################### padding - edge_index_step = self.pad_zero_to_end(edge_index_step, 1, self.max_edge_num_global - edge_num) - node_batch_step = self.pad_zero_to_end(node_batch_step, 0, self.max_node_num_global - node_num) - - ######################### mask - node_mask = self.gen_mask(self.max_node_num_global, node_num) - edge_mask = self.gen_mask(self.max_edge_num_global, edge_num) - batch_size_mask = self.gen_mask(self.batch_size, batch_size) - - ######################### make Tensor - edge_index_step = Tensor(edge_index_step, ms.int32) - node_batch_step = Tensor(node_batch_step, ms.int32) - node_mask = Tensor(node_mask) - edge_mask = Tensor(edge_mask) - batch_size_mask = Tensor(batch_size_mask) - - return CommonData(edge_index_step, node_batch_step, node_mask, edge_mask, batch_size_mask, node_num, edge_num, - batch_size).get_tuple_data() - - def gen_node_attr(self, node_attr, batch_size, node_num): - """gen_node_attr""" - node_attr_step = np.concatenate(node_attr[self.index:self.index + batch_size], 0) - node_attr_step = self.pad_zero_to_end(node_attr_step, 0, self.max_node_num_global - node_num) - node_attr_step = Tensor(node_attr_step) - return node_attr_step - - def gen_edge_attr(self, edge_attr, batch_size, edge_num): - """gen_edge_attr""" - edge_attr_step = np.concatenate(edge_attr[self.index:self.index + batch_size], 0) - edge_attr_step = self.pad_zero_to_end(edge_attr_step, 0, self.max_edge_num_global - edge_num) - edge_attr_step = Tensor(edge_attr_step) - return edge_attr_step - - def gen_global_attr(self, global_attr, batch_size): - """gen_global_attr""" - global_attr_step = np.stack(global_attr[self.index:self.index + batch_size], 0) - global_attr_step = self.pad_zero_to_end(global_attr_step, 0, self.batch_size - batch_size) - global_attr_step = Tensor(global_attr_step) - return global_attr_step - - def add_step_index(self, batch_size): - """add_step_index""" - self.index = self.index + batch_size - self.step += 1 - -class CommonData: - """CommonData""" - def __init__(self, edge_index_step, node_batch_step, node_mask, edge_mask, batch_size_mask, node_num, edge_num, - batch_size): - self.tuple_data = (edge_index_step, node_batch_step, node_mask, edge_mask, batch_size_mask, node_num, edge_num, - batch_size) - - def get_tuple_data(self): - """get_tuple_data""" - return self.tuple_data diff --git a/MindChem/applications/orb/mindchemistry/graph/graph.py b/MindChem/applications/orb/mindchemistry/graph/graph.py deleted file mode 100644 index 1f13b8869..000000000 --- a/MindChem/applications/orb/mindchemistry/graph/graph.py +++ /dev/null @@ -1,294 +0,0 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""graph""" -import mindspore as ms -from mindspore import ops, nn - - -def degree(index, dim_size, mask=None): - r""" - Computes the degree of a one-dimensional index tensor. - """ - if index.ndim != 1: - raise ValueError(f"the dimension of index {index.ndim} is not equal to 1") - - if mask is not None: - if mask.shape[0] != index.shape[0]: - raise ValueError(f"mask.shape[0] {mask.shape[0]} is not equal to index.shape[0] {index.shape[0]}") - if mask.ndim != 1: - st = [0] * mask.ndim - slice_size = [1] * mask.ndim - slice_size[0] = mask.shape[0] - mask = ops.slice(mask, st, slice_size).squeeze() - src = mask.astype(ms.int32) - else: - src = ops.ones(index.shape, ms.int32) - - index = index.unsqueeze(-1) - out = ops.zeros((dim_size,), ms.int32) - - return ops.tensor_scatter_add(out, index, src) - - -class Aggregate(nn.Cell): - r""" - Easy-use version of scatter. - - Args: - mode (str): {'add', 'sum', 'mean', 'avg'}, scatter mode. - - Raises: - ValueError: If `mode` is not legal. - - Supported Platforms: - ``CPU`` ``GPU`` ``Ascend`` - - """ - - def __init__(self, mode='add'): - super().__init__() - self.mode = mode - if mode in ('add', 'sum'): - self.scatter = self.scatter_sum - elif mode in ('mean', 'avg'): - self.scatter = self.scatter_mean - else: - raise ValueError(f"Unexpected scatter mode {mode}") - - @staticmethod - def scatter_sum(src, index, out=None, dim_size=None, mask=None): - r""" - Computes the scatter sum of a source tensor. The index should be one-dimensional - """ - if index.ndim != 1: - raise ValueError(f"the dimension of index {index.ndim} is not equal to 1") - if index.shape[0] != src.shape[0]: - raise ValueError(f"index.shape[0] {index.shape[0]} is not equal to src.shape[0] {src.shape[0]}") - if out is None and dim_size is None: - raise ValueError("the out Tensor and out dim_size cannot be both None") - - index = index.unsqueeze(-1) - - if out is None: - out = ops.zeros((dim_size,) + src.shape[1:], dtype=src.dtype) - elif dim_size is not None and out.shape[0] != dim_size: - raise ValueError(f"the out.shape[0] {out.shape[0]} is not equal to dim_size {dim_size}") - - if mask is not None: - if mask.shape[0] != src.shape[0]: - raise ValueError(f"mask.shape[0] {mask.shape[0]} is not equal to src.shape[0] {src.shape[0]}") - if src.ndim != mask.ndim: - if mask.size != mask.shape[0]: - raise ValueError("mask.ndim dose not match src.ndim, and cannot be broadcasted to the same") - shape = [1] * src.ndim - shape[0] = -1 - mask = ops.reshape(mask, shape) - src = ops.mul(src, mask.astype(src.dtype)) - - return ops.tensor_scatter_add(out, index, src) - - @staticmethod - def scatter_mean(src, index, out=None, dim_size=None, mask=None): - r""" - Computes the scatter mean of a source tensor. The index should be one-dimensional - """ - if out is None and dim_size is None: - raise ValueError("the out Tensor and out dim_size cannot be both None") - - if dim_size is None: - dim_size = out.shape[0] - elif out is not None and out.shape[0] != dim_size: - raise ValueError(f"the out.shape[0] {out.shape[0]} is not equal to dim_size {dim_size}") - - count = degree(index, dim_size, mask=mask) - eps = 1e-5 - count = ops.maximum(count, eps) - - scatter_sum = Aggregate.scatter_sum(src, index, dim_size=dim_size, mask=mask) - - shape = [1] * scatter_sum.ndim - shape[0] = -1 - count = ops.reshape(count, shape).astype(scatter_sum.dtype) - res = ops.true_divide(scatter_sum, count) - - if out is not None: - res = res + out - - return res - - -class AggregateNodeToGlobal(Aggregate): - """AggregateNodeToGlobal""" - - def __init__(self, mode='add'): - super().__init__(mode=mode) - - def construct(self, node_attr, batch, out=None, dim_size=None, mask=None): - r""" - Args: - node_attr (Tensor): The source tensor of node attributes. - batch (Tensor): The indices of sample to scatter to. - out (Tensor): The destination tensor. Default: None. - dim_size (int): If `out` is not given, automatically create output with size `dim_size`. Default: None. - out and dim_size cannot be both None. - mask (Tensor): The mask of the node_attr tensor - Returns: - Tensor. - """ - return self.scatter(node_attr, batch, out=out, dim_size=dim_size, mask=mask) - - -class AggregateEdgeToGlobal(Aggregate): - """AggregateEdgeToGlobal""" - - def __init__(self, mode='add'): - super().__init__(mode=mode) - - def construct(self, edge_attr, batch_edge, out=None, dim_size=None, mask=None): - r""" - Args: - edge_attr (Tensor): The source tensor of edge attributes. - batch_edge (Tensor): The indices of sample to scatter to. - out (Tensor): The destination tensor. Default: None. - dim_size (int): If `out` is not given, automatically create output with size `dim_size`. Default: None. - out and dim_size cannot be both None. - mask (Tensor): The mask of the node_attr tensor - Returns: - Tensor. - """ - return self.scatter(edge_attr, batch_edge, out=out, dim_size=dim_size, mask=mask) - - -class AggregateEdgeToNode(Aggregate): - """AggregateEdgeToNode""" - - def __init__(self, mode='add', dim=0): - super().__init__(mode=mode) - self.dim = dim - - def construct(self, edge_attr, edge_index, out=None, dim_size=None, mask=None): - r""" - Args: - edge_attr (Tensor): The source tensor of edge attributes. - edge_index (Tensor): The indices of nodes in each edge. - out (Tensor): The destination tensor. Default: None. - dim_size (int): If `out` is not given, automatically create output with size `dim_size`. Default: None. - out and dim_size cannot be both None. - mask (Tensor): The mask of the node_attr tensor - Returns: - Tensor. - """ - return self.scatter(edge_attr, edge_index[self.dim], out=out, dim_size=dim_size, mask=mask) - - -class Lift(nn.Cell): - """Lift""" - - def __init__(self, mode="multi_graph"): - super().__init__() - self.mode = mode - if mode not in ["multi_graph", "single_graph"]: - raise ValueError(f"Unexpected lift mode {mode}") - - @staticmethod - def lift(src, index, axis=0, mask=None): - """lift""" - res = ops.index_select(src, axis, index) - - if mask is not None: - if mask.shape[0] != res.shape[0]: - raise ValueError(f"mask.shape[0] {mask.shape[0]} is not equal to res.shape[0] {res.shape[0]}") - if res.ndim != mask.ndim: - if mask.size != mask.shape[0]: - raise ValueError("mask.ndim dose not match src.ndim, and cannot be broadcasted to the same") - shape = [1] * res.ndim - shape[0] = -1 - mask = ops.reshape(mask, shape) - res = ops.mul(res, mask.astype(res.dtype)) - - return res - - @staticmethod - def repeat(src, num, axis=0, max_len=None): - res = ops.repeat_elements(src, num, axis) - - if (max_len is not None) and (max_len > num): - padding = ops.zeros((max_len - num,) + res.shape[1:], dtype=res.dtype) - res = ops.cat((res, padding), axis=0) - - return res - - -class LiftGlobalToNode(Lift): - """LiftGlobalToNode""" - - def __init__(self, mode="multi_graph"): - super().__init__(mode=mode) - - def construct(self, global_attr, batch=None, num_node=None, mask=None, max_len=None): - r""" - Args: - global_attr (Tensor): The source tensor of global attributes. - batch (Tensor): The indices of samples to get. - num_node (Int): The number of node in the graph, when there is only 1 graph. - mask (Tensor): The mask of the output tensor. - max_len (Int): The output length. - Returns: - Tensor. - """ - if global_attr.shape[0] > 1 or self.mode == "multi_graph": - return self.lift(global_attr, batch, mask=mask) - return self.repeat(global_attr, num_node, max_len=max_len) - - -class LiftGlobalToEdge(Lift): - """LiftGlobalToEdge""" - - def __init__(self, mode="multi_graph"): - super().__init__(mode=mode) - - def construct(self, global_attr, batch_edge=None, num_edge=None, mask=None, max_len=None): - r""" - Args: - global_attr (Tensor): The source tensor of global attributes. - batch_edge (Tensor): The indices of samples to get. - num_edge (Int): The number of edge in the graph, when there is only 1 graph. - mask (Tensor): The mask of the output tensor. - max_len (Int): The output length. - Returns: - Tensor. - """ - if global_attr.shape[0] > 1 or self.mode == "multi_graph": - return self.lift(global_attr, batch_edge, mask=mask) - return self.repeat(global_attr, num_edge, max_len=max_len) - - -class LiftNodeToEdge(Lift): - """LiftNodeToEdge""" - - def __init__(self, dim=0): - super().__init__(mode="multi_graph") - self.dim = dim - - def construct(self, global_attr, edge_index, mask=None): - r""" - Args: - global_attr (Tensor): The source tensor of global attributes. - edge_index (Tensor): The indices of nodes for each edge. - mask (Tensor): The mask of the output tensor. - Returns: - Tensor. - """ - return self.lift(global_attr, edge_index[self.dim], mask=mask) diff --git a/MindChem/applications/orb/mindchemistry/graph/loss.py b/MindChem/applications/orb/mindchemistry/graph/loss.py deleted file mode 100644 index 28e241e6f..000000000 --- a/MindChem/applications/orb/mindchemistry/graph/loss.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""loss""" -import mindspore as ms -from mindspore import ops, nn - - -class LossMaskBase(nn.Cell): - """LossMaskBase""" - - def __init__(self, reduction='mean', dtype=None): - super().__init__() - self.reduction = reduction - if reduction not in ["mean", "sum"]: - raise ValueError(f"Unexpected reduction mode {reduction}") - - self.dtype = dtype if dtype is not None else ms.float32 - - def construct(self, logits, labels, mask=None, num=None): - """construct""" - if logits.shape != labels.shape: - raise ValueError(f"logits.shape {logits.shape} is not equal to labels.shape {labels.shape}") - - x = self.loss(logits.astype(self.dtype), labels.astype(self.dtype)) - - if mask is not None: - if mask.shape[0] != x.shape[0]: - raise ValueError(f"mask.shape[0] {mask.shape[0]} is not equal to input.shape[0] {x.shape[0]}") - if x.ndim != mask.ndim: - if mask.size != mask.shape[0]: - raise ValueError("mask.ndim dose not match src.ndim, and cannot be broadcasted to the same") - shape = [1] * x.ndim - shape[0] = -1 - mask = ops.reshape(mask, shape) - x = ops.mul(x, mask.astype(x.dtype)) - - # pylint: disable=W0622 - sum = ops.sum(x) - if self.reduction == "sum": - return sum - if num is None: - num = x.size - else: - num_div = x.shape[0] - if num_div != 0: - num = x.size / num_div * num - else: - raise ValueError - return ops.true_divide(sum, num) - -class L1LossMask(LossMaskBase): - - def __init__(self, reduction='mean'): - super().__init__(reduction) - - def loss(self, logits, labels): - return ops.abs(logits - labels) - - -class L2LossMask(LossMaskBase): - - def __init__(self, reduction='mean'): - super().__init__(reduction) - - def loss(self, logits, labels): - return ops.square(logits - labels) diff --git a/MindChem/applications/orb/mindchemistry/graph/normlization.py b/MindChem/applications/orb/mindchemistry/graph/normlization.py deleted file mode 100644 index caccfdb42..000000000 --- a/MindChem/applications/orb/mindchemistry/graph/normlization.py +++ /dev/null @@ -1,278 +0,0 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""norm""" -import mindspore as ms -from mindspore import ops, Parameter, nn -from .graph import AggregateNodeToGlobal, LiftGlobalToNode - - -class BatchNormMask(nn.Cell): - """BatchNormMask""" - - def __init__(self, num_features, eps=1e-5, momentum=0.9, affine=True): - super().__init__() - self.num_features = num_features - self.eps = eps - self.momentum = momentum - self.affine = affine - self.moving_mean = Parameter(ops.zeros((num_features,), ms.float32), name="moving_mean", requires_grad=False) - self.moving_variance = Parameter(ops.ones((num_features,), ms.float32), - name="moving_variance", - requires_grad=False) - if affine: - self.gamma = Parameter(ops.ones((num_features,), ms.float32), name="gamma", requires_grad=True) - self.beta = Parameter(ops.zeros((num_features,), ms.float32), name="beta", requires_grad=True) - - def construct(self, x, mask, num): - """construct""" - if x.shape[1] != self.num_features: - raise ValueError(f"x.shape[1] {x.shape[1]} is not equal to num_features {self.num_features}") - if x.shape[0] != mask.shape[0]: - raise ValueError(f"x.shape[0] {x.shape[0]} is not equal to mask.shape[0] {mask.shape[0]}") - - if x.ndim != mask.ndim: - if mask.size != mask.shape[0]: - raise ValueError("mask.ndim dose not match src.ndim, and cannot be broadcasted to the same") - shape = [1] * x.ndim - shape[0] = -1 - mask = ops.reshape(mask, shape).astype(x.dtype) - x = ops.mul(x, mask) - - # pylint: disable=R1705 - if x.ndim > 2: - norm_axis = [] - shape = [-1] - for i in range(2, x.ndim): - norm_axis.append(i) - shape.append(1) - - if self.training: - mean = ops.div(ops.sum(x, 0), num) - mean = ops.mean(mean, norm_axis) - self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean - mean = ops.reshape(mean, shape) - mean = ops.mul(mean, mask) - x = x - mean - - var = ops.div(ops.sum(ops.pow(x, 2), 0), num) - var = ops.mean(var, norm_axis) - self.moving_variance = self.momentum * self.moving_variance + (1 - self.momentum) * var - std = ops.sqrt(ops.add(var, self.eps)) - std = ops.reshape(std, shape) - y = ops.true_divide(x, std) - else: - mean = ops.reshape(self.moving_mean.astype(x.dtype), shape) - mean = ops.mul(mean, mask) - std = ops.sqrt(ops.add(self.moving_variance.astype(x.dtype), self.eps)) - std = ops.reshape(std, shape) - y = ops.true_divide(ops.sub(x, mean), std) - - if self.affine: - gamma = ops.reshape(self.gamma.astype(x.dtype), shape) - beta = ops.reshape(self.beta.astype(x.dtype), shape) * mask - y = y * gamma + beta - - return y - else: - if self.training: - mean = ops.div(ops.sum(x, 0), num) - self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean - mean = ops.mul(mean, mask) - x = x - mean - - var = ops.div(ops.sum(ops.pow(x, 2), 0), num) - self.moving_variance = self.momentum * self.moving_variance + (1 - self.momentum) * var - std = ops.sqrt(ops.add(var, self.eps)) - y = ops.true_divide(x, std) - else: - mean = ops.mul(self.moving_mean.astype(x.dtype), mask) - std = ops.sqrt(ops.add(self.moving_variance.astype(x.dtype), self.eps)) - y = ops.true_divide(ops.sub(x, mean), std) - - if self.affine: - beta = self.beta.astype(x.dtype) * mask - y = y * self.gamma.astype(x.dtype) + beta - - return y - - -class GraphLayerNormMask(nn.Cell): - """GraphLayerNormMask""" - - def __init__(self, - normalized_shape, - begin_norm_axis=-1, - eps=1e-5, - sub_mean=True, - divide_std=True, - affine_weight=True, - affine_bias=True, - aggr_mode="mean"): - super().__init__() - self.normalized_shape = normalized_shape - self.begin_norm_axis = begin_norm_axis - self.eps = eps - self.sub_mean = sub_mean - self.divide_std = divide_std - self.affine_weight = affine_weight - self.affine_bias = affine_bias - self.mean = ops.ReduceMean(keep_dims=True) - self.aggregate = AggregateNodeToGlobal(mode=aggr_mode) - self.lift = LiftGlobalToNode(mode="multi_graph") - - if affine_weight: - self.gamma = Parameter(ops.ones(normalized_shape, ms.float32), name="gamma", requires_grad=True) - if affine_bias: - self.beta = Parameter(ops.zeros(normalized_shape, ms.float32), name="beta", requires_grad=True) - - def construct(self, x, batch, mask, dim_size, scale=None): - """construct""" - begin_norm_axis = self.begin_norm_axis if self.begin_norm_axis >= 0 else self.begin_norm_axis + x.ndim - if begin_norm_axis not in range(1, x.ndim): - raise ValueError(f"begin_norm_axis {begin_norm_axis} is not in range 1 to {x.ndim}") - - norm_axis = [] - for i in range(begin_norm_axis, x.ndim): - norm_axis.append(i) - if self.normalized_shape[i - begin_norm_axis] != x.shape[i]: - raise ValueError(f"x.shape[{i}] {x.shape[i]} is not equal to normalized_shape[{i - begin_norm_axis}] " - f"{self.normalized_shape[i - begin_norm_axis]}") - - if x.shape[0] != mask.shape[0]: - raise ValueError(f"x.shape[0] {x.shape[0]} is not equal to mask.shape[0] {mask.shape[0]}") - if x.shape[0] != batch.shape[0]: - raise ValueError(f"x.shape[0] {x.shape[0]} is not equal to batch.shape[0] {batch.shape[0]}") - - if x.ndim != mask.ndim: - if mask.size != mask.shape[0]: - raise ValueError("mask.ndim dose not match src.ndim, and cannot be broadcasted to the same") - shape = [1] * x.ndim - shape[0] = -1 - mask = ops.reshape(mask, shape).astype(x.dtype) - x = ops.mul(x, mask) - - if self.sub_mean: - mean = self.aggregate(x, batch, dim_size=dim_size, mask=mask) - mean = self.mean(mean, norm_axis) - mean = self.lift(mean, batch) - mean = ops.mul(mean, mask) - x = x - mean - - if self.divide_std: - var = self.aggregate(ops.square(x), batch, dim_size=dim_size, mask=mask) - var = self.mean(var, norm_axis) - if scale is not None: - var = var * scale - std = ops.sqrt(var + self.eps) - std = self.lift(std, batch) - x = ops.true_divide(x, std) - - if self.affine_weight: - x = x * self.gamma.astype(x.dtype) - - if self.affine_bias: - beta = ops.mul(self.beta.astype(x.dtype), mask) - x = x + beta - - return x - - -class GraphInstanceNormMask(nn.Cell): - """GraphInstanceNormMask""" - - def __init__(self, - num_features, - eps=1e-5, - sub_mean=True, - divide_std=True, - affine_weight=True, - affine_bias=True, - aggr_mode="mean"): - super().__init__() - self.num_features = num_features - self.eps = eps - self.sub_mean = sub_mean - self.divide_std = divide_std - self.affine_weight = affine_weight - self.affine_bias = affine_bias - self.mean = ops.ReduceMean(keep_dims=True) - self.aggregate = AggregateNodeToGlobal(mode=aggr_mode) - self.lift = LiftGlobalToNode(mode="multi_graph") - - if affine_weight: - self.gamma = Parameter(ops.ones((self.num_features,), ms.float32), name="gamma", requires_grad=True) - if affine_bias: - self.beta = Parameter(ops.zeros((self.num_features,), ms.float32), name="beta", requires_grad=True) - - def construct(self, x, batch, mask, dim_size, scale=None): - """construct""" - if x.shape[1] != self.num_features: - raise ValueError(f"x.shape[1] {x.shape[1]} is not equal to num_features {self.num_features}") - if x.shape[0] != mask.shape[0]: - raise ValueError(f"x.shape[0] {x.shape[0]} is not equal to mask.shape[0] {mask.shape[0]}") - if x.shape[0] != batch.shape[0]: - raise ValueError(f"x.shape[0] {x.shape[0]} is not equal to batch.shape[0] {batch.shape[0]}") - - if x.ndim != mask.ndim: - if mask.size != mask.shape[0]: - raise ValueError("mask.ndim dose not match src.ndim, and cannot be broadcasted to the same") - shape = [1] * x.ndim - shape[0] = -1 - mask = ops.reshape(mask, shape).astype(x.dtype) - x = ops.mul(x, mask) - gamma = None # 后来添加,防止未定义报错 - if x.ndim > 2: - norm_axis = [] - shape = [-1] - for i in range(2, x.ndim): - norm_axis.append(i) - shape.append(1) - - if self.affine_weight: - gamma = ops.reshape(self.gamma.astype(x.dtype), shape) - if self.affine_bias: - beta = ops.reshape(self.beta.astype(x.dtype), shape) - else: - if self.affine_weight: - gamma = self.gamma.astype(x.dtype) - if self.affine_bias: - beta = self.beta.astype(x.dtype) - - if self.sub_mean: - mean = self.aggregate(x, batch, dim_size=dim_size, mask=mask) - if x.ndim > 2: - mean = self.mean(mean, norm_axis) - mean = self.lift(mean, batch) - mean = ops.mul(mean, mask) - x = x - mean - - if self.divide_std: - var = self.aggregate(ops.square(x), batch, dim_size=dim_size, mask=mask) - if x.ndim > 2: - var = self.mean(var, norm_axis) - if scale is not None: - var = var * scale - std = ops.sqrt(var + self.eps) - std = self.lift(std, batch) - x = ops.true_divide(x, std) - - if self.affine_weight: - x = x * gamma - - if self.affine_bias: - beta = ops.mul(beta, mask) - x = x + beta - - return x diff --git a/MindChem/applications/orb/mindchemistry/so2_conv/__init__.py b/MindChem/applications/orb/mindchemistry/so2_conv/__init__.py deleted file mode 100644 index e5542a477..000000000 --- a/MindChem/applications/orb/mindchemistry/so2_conv/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -init file -""" -from .so3 import SO3Rotation -from .so2 import SO2Convolution diff --git a/MindChem/applications/orb/mindchemistry/so2_conv/init_edge_rot_mat.py b/MindChem/applications/orb/mindchemistry/so2_conv/init_edge_rot_mat.py deleted file mode 100644 index a05a2264b..000000000 --- a/MindChem/applications/orb/mindchemistry/so2_conv/init_edge_rot_mat.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -file to get rotating matrix from edge distance vector -""" -from mindspore import ops -import mindspore.numpy as ms_np - - -def init_edge_rot_mat(edge_distance_vec): - """ - get rotating matrix from edge distance vector - """ - epsilon = 0.00000001 - edge_vec_0 = edge_distance_vec - edge_vec_0_distance = ops.sqrt(ops.maximum(ops.sum(edge_vec_0 ** 2, dim=1), epsilon)) - # Make sure the atoms are far enough apart - norm_x = ops.div(edge_vec_0, edge_vec_0_distance.view(-1, 1)) - edge_vec_2 = ops.rand_like(edge_vec_0) - 0.5 - - edge_vec_2 = ops.div(edge_vec_2, ops.sqrt(ops.maximum(ops.sum(edge_vec_2 ** 2, dim=1), epsilon)).view(-1, 1)) - # Create two rotated copies of the random vectors in case the random vector is aligned with norm_x - # With two 90 degree rotated vectors, at least one should not be aligned with norm_x - edge_vec_2b = edge_vec_2.copy() - edge_vec_2b[:, 0] = -edge_vec_2[:, 1] - edge_vec_2b[:, 1] = edge_vec_2[:, 0] - edge_vec_2c = edge_vec_2.copy() - edge_vec_2c[:, 1] = -edge_vec_2[:, 2] - edge_vec_2c[:, 2] = edge_vec_2[:, 1] - vec_dot_b = ops.abs(ops.sum(edge_vec_2b * norm_x, dim=1)).view(-1, 1) - vec_dot_c = ops.abs(ops.sum(edge_vec_2c * norm_x, dim=1)).view(-1, 1) - vec_dot = ops.abs(ops.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) - edge_vec_2 = ops.where(ops.broadcast_to(ops.gt(vec_dot, vec_dot_b), edge_vec_2b.shape), edge_vec_2b, edge_vec_2) - vec_dot = ops.abs(ops.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) - edge_vec_2 = ops.where(ops.broadcast_to(ops.gt(vec_dot, vec_dot_c), edge_vec_2c.shape), edge_vec_2c, edge_vec_2) - vec_dot = ops.abs(ops.sum(edge_vec_2 * norm_x, dim=1)) - # Check the vectors aren't aligned - - norm_z = ms_np.cross(norm_x, edge_vec_2, axis=1) - norm_z = ops.div(norm_z, ops.sqrt(ops.maximum(ops.sum(norm_z ** 2, dim=1, keepdim=True), epsilon))) - norm_z = ops.div(norm_z, ops.sqrt(ops.maximum(ops.sum(norm_z ** 2, dim=1), epsilon)).view(-1, 1)) - - norm_y = ms_np.cross(norm_x, norm_z, axis=1) - norm_y = ops.div(norm_y, ops.sqrt(ops.maximum(ops.sum(norm_y ** 2, dim=1, keepdim=True), epsilon))) - # Construct the 3D rotation matrix - norm_x = norm_x.view(-1, 3, 1) - norm_y = -norm_y.view(-1, 3, 1) - norm_z = norm_z.view(-1, 3, 1) - edge_rot_mat_inv = ops.cat([norm_z, norm_x, norm_y], axis=2) - - edge_rot_mat = ops.swapaxes(edge_rot_mat_inv, 1, 2) - return edge_rot_mat diff --git a/MindChem/applications/orb/mindchemistry/so2_conv/jd.pkl b/MindChem/applications/orb/mindchemistry/so2_conv/jd.pkl deleted file mode 100644 index 1b762ad4369564e2f21b808850cc8013e6e41385..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 9925 zcmcIq4RBP|6}}+^l0qOM4QmT_sGvls0WH-G4ZA=R&}>7NA2T`k zy?5?+zWd#C&U<;GqVt*^_b41GRj#5L#rdVPeI*5{(|komzT(ufg5pwNiB;8Qq8Y5V z?tRvJ#;S^$mpf~2fmM}UJhy0ex%SpmissHLn~_^ml3Q+7b)Q;NFwIw7T2?Z5TA8(4 zPk^IU)wMX^xU9CkYN?eGm1ixtCRi!nDEYU{DvYa&$uBFPT_BS>O&?}et}yqbtD@4a zoSdBO8~VTxuPU?hDlbMx8t09} zf=6dY-5hoNd>__ux$|Ga$9RpzwbsY(4aBj?n`{J_Bxak&AWNK8O@v8(0swVA43mc))3(9eGk#z za1Kdqq2TCPdnyLFbLw9K-V2+&z?aD(c<^xq7;w6J`fK1(hj|-&0fXycT>5ncTl-?q zhvyxHedFY>Jk*8jVBC)25>Kgd(BS1QNBz{9x^NvG!4-buFL-8N#Rvy=81IcYIkBpNlq=+WL9;&AWlW>OT)4-sBBy-QufzaY$~D#I?bL`@<)0k_U6z!gO-|;C+{Y z`Kjwg$vaS2kJteZDY27hw+p$?{4!hBtqlEeOlejGwwV0fx5Jh-nIwV$9PKp2wG3xaEG6H z%{=42b06A9ZU2twDPFQZu@XvJ zq;@|+eQn$l@~4-*fgMg9l-!69$mVOZ|f$G(0Z<`1t?Kk{#XrR{q#zV^i0 zA3XR+@Ze18Kg?g|5BHDy(f4*N*KfyP@4lqBP7!6(6 zUv=H#`l7T3Smq)3ojP+}#IaW#j?q)dE&Xoq7PNDtt6?wg{SvnSkntYY1@n;mPMx`~ zaIhf;`6%@Tc)YSd347O&wa@VL#MZG+Sr^Ph?mKlRW;iVMAs?O3;Is1ahp@M2GpD%O zH&~~v3+5sBJsf<9!Tq8>>I38)&=vBnBlOJKZ0{f@0VE7N38G> ztKUO^>fk*z*_&*7dQ%rO%jwSn%|Aoj><7Dn$6H;?{W3SL$8zMo{`MEr+Y@l@p@PTY z!#VpK%=5T)?Aze(GEm8eSmTzN z=9}y<795Wl&pj=#f~Qj7LCjr!ew!|w?(v#$0{a_w1nw(E{m4(L7TnXCV=V)&nI0qa z!+mr)HWsly{LI8Gxpp1rDb+i6NvurQP-rJ(?l1KtKczOi<$gl;*RV1U_whvQ9`M=M z@I3fki*174I4BAFe!l$&uo9Q~$Ni;#Hoxl;;<*5v#x>T$ZXIRbk17||y|U0mWkz1MtqS=RSXiz6j><6qj<}>pp5GT)kyt~B3Jw3cr6 zQ}!$N1J)b!ITAdXf81Z{M}A7pa!cI#zd5;+{hs}l{fhk{wBEwbRo91LO8$wJ{S_SEGchyPOmh0Ovi!S@aebUL&T3Cdl&d@Y3YkP(cjUF@ag`d`TL934H$c-*nbbO=Y)5(s@Z%W)p~(L?rS{|d)!BNVLhk2Rpa{NoEhLR z(aBd2;j^yrZp%xcMJmgmdww zzBb=<+#6Wm2#l6(rdNn>q>k-VYZ0sQz^kxsD-Il%9x{J7^d0&K&SzF9g1_)!K9Cdn z>T~!kx%>xU@qQ$1O}6>QmN~{~_3In65z{pw1JBw6JKqEEMaPB(bAkEGe4xHI-;{{+ zZhWZ*F&cgUhJCr_9PD3OPs0vq{s4BvoY!Gl$E*wH8TXAk5H}K9+n6cPR<+!K&3gX{ zzua3 0 coefficients - for m in range(self.global_max_order): - if m == 0: - continue - x_m = m_list_merge[m] - x_m = x_m.reshape(num_edges, 2, -1) - x_m = self.so2_m_conv[m - 1](x_m) - out.append(x_m) - - ###################### start fill 0 ###################### - if self.max_order_out + 1 > len(m_list_merge): - for m in range(len(m_list_merge), self.max_order_out + 1): - extra_zero = ops.zeros( - (num_edges, 2, int(self.m_shape_dict_out.get(m, None) / 2))) - out.append(extra_zero) - ###################### finish fill 0 ###################### - - ###################### start _l_primary ######################### - l_primary_list_0 = [] - l_primary_list_left = [] - l_primary_list_right = [] - - for _ in range(self.irreps_out_length): - l_primary_list_0.append([]) - l_primary_list_left.append([]) - l_primary_list_right.append([]) - - m_0 = out[0] - offset = 0 - index = 0 - - for key_val in self.irreps_out_data: - key = key_val[0] - value = key_val[1] - if key >= 0: - l_primary_list_0[index].append( - ops.unsqueeze(m_0[:, offset:offset + value], -1)) - offset = offset + value - index = index + 1 - - for m in range(1, len(out)): - right = out[m][:, 1] - offset = 0 - index = 0 - - for key_val in self.irreps_out_data: - key = key_val[0] - value = key_val[1] - if key >= m: - l_primary_list_right[index].append( - ops.unsqueeze(right[:, offset:offset + value], -1)) - offset = offset + value - index = index + 1 - - for m in range(len(out) - 1, 0, -1): - left = out[m][:, 0] - offset = 0 - index = 0 - - for key_val in self.irreps_out_data: - key = key_val[0] - value = key_val[1] - if key >= m: - l_primary_list_left[index].append( - ops.unsqueeze(left[:, offset:offset + value], -1)) - offset = offset + value - index = index + 1 - - l_primary_list = [] - for i in range(self.irreps_out_length): - if i == 0: - tmp = ops.cat(l_primary_list_0[i], -1) - l_primary_list.append(tmp) - else: - tmp = ops.cat( - (ops.cat((ops.cat(l_primary_list_left[i], - -1), ops.cat(l_primary_list_0[i], -1)), - -1), ops.cat(l_primary_list_right[i], -1)), -1) - l_primary_list.append(tmp) - - ##################### finish _l_primary ######################### - return tuple(l_primary_list) diff --git a/MindChem/applications/orb/mindchemistry/so2_conv/so3.py b/MindChem/applications/orb/mindchemistry/so2_conv/so3.py deleted file mode 100644 index 0ffae5a5f..000000000 --- a/MindChem/applications/orb/mindchemistry/so2_conv/so3.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -so3 file -""" -import mindspore as ms -from mindspore import nn, ops, vmap, jit_class -from mindspore.numpy import tensordot -from mindscience.e3nn import o3 -from mindscience.e3nn.o3 import Irreps - -from .wigner import wigner_D - - -class SO3Embedding(nn.Cell): - """ - SO3Embedding class - """ - - def __init__(self): - self.embedding = None - - def _rotate(self, so3rotation, lmax_list, max_list): - """ - SO3Embedding rotate - """ - embedding_rotate = so3rotation[0].rotate(self.embedding, lmax_list[0], - max_list[0]) - self.embedding = embedding_rotate - - def _rotate_inv(self, so3rotation): - """ - SO3Embedding rotate inverse - """ - embedding_rotate = so3rotation[0].rotate_inv(self.embedding, - self.lmax_list[0], - self.mmax_list[0]) - self.embedding = embedding_rotate - - -@jit_class -class SO3Rotation: - """ - SO3_Rotation class - """ - - def __init__(self, lmax, irreps_in, irreps_out): - self.lmax = lmax - self.irreps_in1 = Irreps(irreps_in) - self.irreps_out = Irreps(irreps_out) - self.tensordot_vmap = vmap(tensordot, (0, 0, None), 0) - - @staticmethod - def narrow(inputs, axis, start, length): - """ - SO3_Rotation narrow class - """ - begins = [0] * inputs.ndim - begins[axis] = start - - sizes = list(inputs.shape) - - sizes[axis] = length - res = ops.slice(inputs, begins, sizes) - return res - - @staticmethod - def rotation_to_wigner_d_matrix(edge_rot_mat, start_lmax, end_lmax): - """ - SO3_Rotation rotation_to_wigner_d_matrix - """ - x = edge_rot_mat @ ms.Tensor([0.0, 1.0, 0.0]) - alpha, beta = o3.xyz_to_angles(x) - rvalue = (ops.swapaxes( - o3.angles_to_matrix(alpha, beta, ops.zeros_like(alpha)), -1, -2) - @ edge_rot_mat) - gamma = ops.atan2(rvalue[..., 0, 2], rvalue[..., 0, 0]) - - block_list = [] - for lmax in range(start_lmax, end_lmax + 1): - block = wigner_D(lmax, alpha, beta, gamma).astype(ms.float32) - block_list.append(block) - return block_list - - def set_wigner(self, rot_mat3x3): - """ - SO3_Rotation set_wigner - """ - wigner = self.rotation_to_wigner_d_matrix(rot_mat3x3, 0, self.lmax) - wigner_inv = [] - length = len(wigner) - for i in range(length): - wigner_inv.append(ops.swapaxes(wigner[i], 1, 2)) - return tuple(wigner), tuple(wigner_inv) - - def rotate(self, embedding, wigner): - """ - SO3_Rotation rotate - """ - res = [] - batch_shape = embedding.shape[:-1] - for (s, l), mir in zip(self.irreps_in1.slice_tuples, - self.irreps_in1.data): - v_slice = self.narrow(embedding, -1, s, l) - if embedding.ndim == 1: - res.append((v_slice.reshape((1,) + batch_shape + - (mir.mul, mir.ir.dim)), mir.ir)) - else: - res.append( - (v_slice.reshape(batch_shape + (mir.mul, mir.ir.dim)), - mir.ir)) - rotate_data_list = [] - for data, ir in res: - self.tensordot_vmap(data.astype(ms.float16), - wigner[ir.l].astype(ms.float16), ([1], [1])) - rotate_data = self.tensordot_vmap(data.astype(ms.float16), - wigner[ir.l].astype(ms.float16), - ((1), (1))).astype(ms.float32) - rotate_data_list.append(rotate_data) - return tuple(rotate_data_list) - - def rotate_inv(self, embedding, wigner_inv): - """ - SO3_Rotation rotate_inv - """ - res = [] - batch_shape = embedding[0].shape[0:1] - index = 0 - for (_, _), mir in zip(self.irreps_out.slice_tuples, - self.irreps_out.data): - v_slice = embedding[index] - if embedding[0].ndim == 1: - res.append((v_slice, mir.ir)) - else: - res.append((v_slice, mir.ir)) - index = index + 1 - rotate_back_data_list = [] - for data, ir in res: - rotate_back_data = self.tensordot_vmap( - data.astype(ms.float16), wigner_inv[ir.l].astype(ms.float16), - ((1), (1))).astype(ms.float32) - rotate_back_data_list.append( - rotate_back_data.view(batch_shape + (-1,))) - return ops.cat(rotate_back_data_list, -1) diff --git a/MindChem/applications/orb/mindchemistry/so2_conv/wigner.py b/MindChem/applications/orb/mindchemistry/so2_conv/wigner.py deleted file mode 100644 index c3e08615c..000000000 --- a/MindChem/applications/orb/mindchemistry/so2_conv/wigner.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -wigner file -""" - -# pylint: disable=C0103 -import pickle -from mindspore import ops -import mindspore as ms -from mindscience.e3nn.utils.func import broadcast_args - - -def wigner_D(lv, alpha, beta, gamma): - """ - # Borrowed from e3nn @ 0.4.0: - # https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L10 - # jd is a list of tensors of shape (2l+1, 2l+1) - - # Borrowed from e3nn @ 0.4.0: - # https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L37 - # - # In 0.5.0, e3nn shifted to torch.matrix_exp which is significantly slower: - # https://github.com/e3nn/e3nn/blob/0.5.0/e3nn/o3/_wigner.py#L92 - """ - jd = None - with open("jd.pkl", "rb") as f: - jd = pickle.load(f) - if not lv < len(jd): - raise NotImplementedError( - f"wigner D maximum l implemented is {len(jd) - 1}, send us an email to ask for more" - ) - alpha, beta, gamma = broadcast_args(alpha, beta, gamma) - j = jd[lv] - xa = _z_rot_mat(alpha, lv) - xb = _z_rot_mat(beta, lv) - xc = _z_rot_mat(gamma, lv) - return xa @ j.astype(ms.float16) @ xb @ j.astype(ms.float16) @ xc - - -def _z_rot_mat(angle, lv): - shape = angle.shape - m = ops.zeros((shape[0], 2 * lv + 1, 2 * lv + 1)) - inds = ops.arange(0, 2 * lv + 1, 1) - reversed_inds = ops.arange(2 * lv, -1, -1) - frequencies = ops.arange(lv, -lv - 1, -1) - m[..., inds, reversed_inds] = ops.sin(frequencies * angle[..., None]) - m[..., inds, inds] = ops.cos(frequencies * angle[..., None]) - return m.astype(ms.float16) diff --git a/MindChem/applications/orb/mindchemistry/utils/__init__.py b/MindChem/applications/orb/mindchemistry/utils/__init__.py deleted file mode 100644 index 3f15063d7..000000000 --- a/MindChem/applications/orb/mindchemistry/utils/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this filepio[] except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""init""" -from .load_config import load_yaml_config - -__all__ = ['load_yaml_config'] diff --git a/MindChem/applications/orb/mindchemistry/utils/check_func.py b/MindChem/applications/orb/mindchemistry/utils/check_func.py deleted file mode 100644 index 711a441fe..000000000 --- a/MindChem/applications/orb/mindchemistry/utils/check_func.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""functions""" -from __future__ import absolute_import - -from mindspore import context - -_SPACE = " " - - -def _convert_to_tuple(params): - if params is None: - return params - if not isinstance(params, (list, tuple)): - params = (params,) - if isinstance(params, list): - params_out = tuple(params) - else: - params_out = params # ✅ 防止未定义 - return params_out - - -def check_param_type(param, param_name, data_type=None, exclude_type=None): - """Check parameter's data type""" - data_type = _convert_to_tuple(data_type) - exclude_type = _convert_to_tuple(exclude_type) - - if data_type and not isinstance(param, data_type): - raise TypeError( - f"The type of {param_name} should be instance of {data_type}, but got {param} with type {type(param)}" - ) - if exclude_type and type(param) in exclude_type: - raise TypeError( - f"The type of {param_name} should not be instance of {exclude_type},but got {param} with type {type(param)}" - ) - - -def check_param_value(param, param_name, valid_value): - """check parameter's value""" - valid_value = _convert_to_tuple(valid_value) - if param not in valid_value: - raise ValueError(f"The value of {param_name} should be in {valid_value}, but got {param}") - - -def check_param_type_value(param, param_name, valid_value, data_type=None, exclude_type=None): - """check both data type and value""" - check_param_type(param, param_name, data_type=data_type, exclude_type=exclude_type) - check_param_value(param, param_name, valid_value) - - -def check_dict_type(param_dict, param_name, key_type=None, value_type=None): - """check data type for key and value of the specified dict""" - check_param_type(param_dict, param_name, data_type=dict) - - for key in param_dict.keys(): - if key_type: - check_param_type(key, _SPACE.join(("key of", param_name)), data_type=key_type) - if value_type: - values = _convert_to_tuple(param_dict[key]) - for value in values: - check_param_type(value, _SPACE.join(("value of", param_name)), data_type=value_type) - - -def check_dict_value(param_dict, param_name, key_value=None, value_value=None): - """check values for key and value of specified dict""" - check_param_type(param_dict, param_name, data_type=dict) - - for key in param_dict.keys(): - if key_value: - check_param_value(key, _SPACE.join(("key of", param_name)), key_value) - if value_value: - values = _convert_to_tuple(param_dict[key]) - for value in values: - check_param_value(value, _SPACE.join(("value of", param_name)), value_value) - - -def check_dict_type_value(param_dict, param_name, key_type=None, value_type=None, key_value=None, value_value=None): - """check values for key and value of specified dict""" - check_dict_type(param_dict, param_name, key_type=key_type, value_type=value_type) - check_dict_value(param_dict, param_name, key_value=key_value, value_value=value_value) - - -def check_mode(api_name): - """check running mode""" - if context.get_context("mode") == context.PYNATIVE_MODE: - raise RuntimeError(f"{api_name} is only supported GRAPH_MODE now but got PYNATIVE_MODE") - - -def check_param_no_greater(param, param_name, compared_value): - """ Check whether the param less than the given compared_value""" - if param > compared_value: - raise ValueError(f"The value of {param_name} should be no greater than {compared_value}, but got {param}") - - -def check_param_odd(param, param_name): - """ Check whether the param is an odd number""" - if param % 2 == 0: - raise ValueError(f"The value of {param_name} should be an odd number, but got {param}") - - -def check_param_even(param, param_name): - """ Check whether the param is an even number""" - for value in param: - if value % 2 != 0: - raise ValueError(f"The value of {param_name} should be an even number, but got {param}") - - -def check_lr_param_type_value(param, param_name, param_type, thresh_hold=0, restrict=False, exclude=None): - if (exclude and isinstance(param, exclude)) or not isinstance(param, param_type): - raise TypeError(f"the type of {param_name} should be {param_type}, but got {type(param)}") - if restrict: - if param <= thresh_hold: - raise ValueError(f"the value of {param_name} should be > {thresh_hold}, but got: {param}") - else: - if param < thresh_hold: - raise ValueError(f"the value of {param_name} should be >= {thresh_hold}, but got: {param}") diff --git a/MindChem/applications/orb/mindchemistry/utils/load_config.py b/MindChem/applications/orb/mindchemistry/utils/load_config.py deleted file mode 100644 index 3ddc76e42..000000000 --- a/MindChem/applications/orb/mindchemistry/utils/load_config.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -utility functions -""" -import os -import yaml - - -def _make_paths_absolute(dir_, config): - """ - Make all values for keys ending with `_path` absolute to dir_. - - Args: - dir_ (str): The path of yaml configuration file. - config (dict): The yaml for configuration file. - - Returns: - Dict. The configuration information in dict format. - """ - for key in config.keys(): - if key.endswith("_path"): - config[key] = os.path.join(dir_, config[key]) - config[key] = os.path.abspath(config[key]) - if isinstance(config[key], dict): - config[key] = _make_paths_absolute(dir_, config[key]) - return config - - -def load_yaml_config(file_path): - """ - Load a YAML configuration file. - - Args: - file_path (str): The path of yaml configuration file. - - Returns: - Dict. The configuration information in dict format. - - Supported Platforms: - ``Ascend`` ``CPU`` ``GPU`` - - Examples: - >>> from mindchemistry.utils import load_yaml_config - >>> config_file_path = 'xxx' # 'xxx' is the file_path - >>> configs = load_yaml_config(config_file_path) - """ - # Read YAML experiment definition file - with open(file_path, 'r', encoding='utf-8') as stream: - config = yaml.safe_load(stream) - config = _make_paths_absolute(os.path.join( - os.path.dirname(file_path), ".."), config) - return config - - -def load_yaml_config_from_path(file_path): - """ - Load a YAML configuration file. - - Args: - file_path (str): The path of yaml configuration file. - - Returns: - Dict. The configuration information in dict format. - - Supported Platforms: - ``Ascend`` ``CPU`` ``GPU`` - """ - # Read YAML experiment definition file - with open(file_path, 'r', encoding='utf-8') as stream: - config = yaml.safe_load(stream) - - return config diff --git a/MindChem/applications/orb/mindchemistry/__init__.py b/MindChem/applications/orb/models/__init__.py similarity index 95% rename from MindChem/applications/orb/mindchemistry/__init__.py rename to MindChem/applications/orb/models/__init__.py index 29f929195..e0521e5da 100644 --- a/MindChem/applications/orb/mindchemistry/__init__.py +++ b/MindChem/applications/orb/models/__init__.py @@ -19,13 +19,11 @@ import mindspore as ms from mindspore import log as logger from mindscience.e3nn import * from .cell import * -from .utils import * -from .graph import * -from .so2_conv import * + __all__ = [] __all__.extend(cell.__all__) -__all__.extend(utils.__all__) + diff --git a/MindChem/applications/orb/mindchemistry/cell/__init__.py b/MindChem/applications/orb/models/cell/__init__.py similarity index 100% rename from MindChem/applications/orb/mindchemistry/cell/__init__.py rename to MindChem/applications/orb/models/cell/__init__.py diff --git a/MindChem/applications/orb/mindchemistry/cell/activation.py b/MindChem/applications/orb/models/cell/activation.py similarity index 100% rename from MindChem/applications/orb/mindchemistry/cell/activation.py rename to MindChem/applications/orb/models/cell/activation.py diff --git a/MindChem/applications/orb/mindchemistry/cell/basic_block.py b/MindChem/applications/orb/models/cell/basic_block.py similarity index 100% rename from MindChem/applications/orb/mindchemistry/cell/basic_block.py rename to MindChem/applications/orb/models/cell/basic_block.py diff --git a/MindChem/applications/orb/mindchemistry/cell/convolution.py b/MindChem/applications/orb/models/cell/convolution.py similarity index 100% rename from MindChem/applications/orb/mindchemistry/cell/convolution.py rename to MindChem/applications/orb/models/cell/convolution.py diff --git a/MindChem/applications/orb/mindchemistry/cell/embedding.py b/MindChem/applications/orb/models/cell/embedding.py similarity index 100% rename from MindChem/applications/orb/mindchemistry/cell/embedding.py rename to MindChem/applications/orb/models/cell/embedding.py diff --git a/MindChem/applications/orb/mindchemistry/cell/message_passing.py b/MindChem/applications/orb/models/cell/message_passing.py similarity index 100% rename from MindChem/applications/orb/mindchemistry/cell/message_passing.py rename to MindChem/applications/orb/models/cell/message_passing.py diff --git a/MindChem/applications/orb/mindchemistry/cell/orb/__init__.py b/MindChem/applications/orb/models/cell/orb/__init__.py similarity index 100% rename from MindChem/applications/orb/mindchemistry/cell/orb/__init__.py rename to MindChem/applications/orb/models/cell/orb/__init__.py diff --git a/MindChem/applications/orb/mindchemistry/cell/orb/gns.py b/MindChem/applications/orb/models/cell/orb/gns.py similarity index 99% rename from MindChem/applications/orb/mindchemistry/cell/orb/gns.py rename to MindChem/applications/orb/models/cell/orb/gns.py index ab53083f8..940fd37e3 100644 --- a/MindChem/applications/orb/mindchemistry/cell/orb/gns.py +++ b/MindChem/applications/orb/models/cell/orb/gns.py @@ -24,7 +24,7 @@ from mindspore import nn, ops, Tensor, mint from mindspore.common.initializer import Uniform import mindspore.ops.operations as P -from mindchemistry.cell.orb.utils import build_mlp +from models.cell.orb.utils import build_mlp _KEY = "feat" diff --git a/MindChem/applications/orb/mindchemistry/cell/orb/orb.py b/MindChem/applications/orb/models/cell/orb/orb.py similarity index 97% rename from MindChem/applications/orb/mindchemistry/cell/orb/orb.py rename to MindChem/applications/orb/models/cell/orb/orb.py index 8afc55dbf..a0922bdbb 100644 --- a/MindChem/applications/orb/mindchemistry/cell/orb/orb.py +++ b/MindChem/applications/orb/models/cell/orb/orb.py @@ -21,8 +21,8 @@ import numpy import mindspore as ms from mindspore import Parameter, ops, Tensor, mint -from mindchemistry.cell.orb.gns import _KEY, MoleculeGNS -from mindchemistry.cell.orb.utils import ( +from models.cell.orb.gns import _KEY, MoleculeGNS +from models.cell.orb.utils import ( aggregate_nodes, build_mlp, REFERENCE_ENERGIES, @@ -133,12 +133,20 @@ class ScalarNormalizer(ms.nn.Cell): def construct(self, x: Tensor): """construct """ - if self.training: - self.bn(x.view(-1, 1)) - if hasattr(self, "running_mean"): - return (x - self.running_mean) / mint.sqrt(self.running_var) - return (x - self.bn.running_mean) / mint.sqrt(self.bn.running_var) - + # if self.training: + # self.bn(x.view(-1, 1)) + # if hasattr(self, "running_mean"): + # return (x - self.running_mean) / mint.sqrt(self.running_var) + # return (x - self.bn.running_mean) / mint.sqrt(self.bn.running_var) + # 修改 + if hasattr(self, "running mean"): + mean = self.running_mean + var = self.running_var + else: + mean = self.bn.running_mean + var = self.bn.running_var + return(x-mean)/mint.sqrt(var) + def inverse(self, x: Tensor): """Reverse the construct normalization. diff --git a/MindChem/applications/orb/mindchemistry/cell/orb/utils.py b/MindChem/applications/orb/models/cell/orb/utils.py similarity index 100% rename from MindChem/applications/orb/mindchemistry/cell/orb/utils.py rename to MindChem/applications/orb/models/cell/orb/utils.py diff --git a/MindChem/applications/orb/src/pretrained.py b/MindChem/applications/orb/src/pretrained.py index cb66db19f..8f58ddde6 100644 --- a/MindChem/applications/orb/src/pretrained.py +++ b/MindChem/applications/orb/src/pretrained.py @@ -20,7 +20,7 @@ from typing import Optional from mindspore import nn, load_checkpoint, load_param_into_net -from mindchemistry.cell import ( +from models.cell import ( EnergyHead, GraphHead, Orb, -- Gitee From ab0cf043e22f6cfb8272bd2b000ce40f6ed69790 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=95=8F=E6=B6=9B?= Date: Tue, 2 Dec 2025 14:09:37 +0800 Subject: [PATCH 2/7] [SPONGE] refactor orb app structure and models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 刘敏涛 --- MindChem/applications/orb/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/MindChem/applications/orb/README.md b/MindChem/applications/orb/README.md index 8c85888e5..fe577e2f7 100644 --- a/MindChem/applications/orb/README.md +++ b/MindChem/applications/orb/README.md @@ -169,6 +169,8 @@ Training time: 2376.63176 seconds [INFO] PS(365488,ffff217af120,python):2025-12-02-13:43:15.029.782 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! [INFO] PS(365481,ffff2d659120,python):2025-12-02-13:43:15.061.968 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! [INFO] PS(365481,ffff2de69120,python):2025-12-02-13:43:15.061.956 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! + + ``` Under the same training configuration, parallel training achieved significant performance improvement compared to single-card training: -- Gitee From 1db92041549bc1fbeff6c71fbb3208cf637d5606 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=95=8F=E6=B6=9B?= Date: Tue, 2 Dec 2025 14:24:00 +0800 Subject: [PATCH 3/7] [SPONGE] refactor orb app structure and models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 刘敏涛 --- MindChem/applications/orb/README_CN.md | 38 +++++++++---------- .../applications/orb/models/cell/orb/orb.py | 2 +- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/MindChem/applications/orb/README_CN.md b/MindChem/applications/orb/README_CN.md index 8c1f18a0a..411f5ac49 100644 --- a/MindChem/applications/orb/README_CN.md +++ b/MindChem/applications/orb/README_CN.md @@ -62,26 +62,24 @@ orb_models # 项目根目录(ORB 预 ├── requirement.txt # Python 依赖列表(用于搭建运行环境) │ ├── models # 模型结构定义模块(GNN/ORB 相关网络) -│ ├── __init__.py # models 包初始化,统一暴露各类模型接口 -│ │ -│ ├── cell # 神经网络/图网络的基础“细胞”模块 -│ │ ├── activation.py # 激活函数与非线性模块封装 -│ │ ├── basic_block.py # 通用基础网络模块(MLP Block、ResBlock 等) -│ │ ├── convolution.py # 图卷积或空间卷积相关层与算子 -│ │ ├── embedding.py # 原子类型、边特征等嵌入层定义 -│ │ ├── __init__.py # cell 子包初始化 -│ │ ├── message_passing.py # 消息传递层与 GNN 主体结构(message passing 核心逻辑) -│ │ │ -│ │ ├── orb # ORB 专用网络结构 -│ │ │ ├── gns.py # GNS(Graph Network Simulator)相关结构/接口 -│ │ │ ├── __init__.py # orb 子包初始化 -│ │ │ ├── orb.py # ORB 主模型结构定义(encoder + heads 等) -│ │ │ └── utils.py # ORB 模型内部使用的工具函数/辅助模块 -│ │ -│ └── ... # 预留:可以扩展其他模型结构 -│ -└── __init__.py -``` + ├── __init__.py # models 包初始化,统一暴露各类模型接口 + │ + ├── cell # 神经网络/图网络的基础“细胞”模块 + │ ├── activation.py # 激活函数与非线性模块封装 + │ ├── basic_block.py # 通用基础网络模块(MLP Block、ResBlock 等) + │ ├── convolution.py # 图卷积或空间卷积相关层与算子 + │ ├── embedding.py # 原子类型、边特征等嵌入层定义 + │ ├── __init__.py # cell 子包初始化 + │ ├── message_passing.py # 消息传递层与 GNN 主体结构(message passing 核心逻辑) + │ │ + │ ├── orb # ORB 专用网络结构 + │ │ ├── gns.py # GNS(Graph Network Simulator)相关结构/接口 + │ │ ├── __init__.py # orb 子包初始化 + │ │ ├── orb.py # ORB 主模型结构定义(encoder + heads 等) + │ │ └── utils.py # ORB 模型内部使用的工具函数/辅助模块 + │ + └── ... # 预留:可以扩展其他模型结构 + ## 下载数据集 diff --git a/MindChem/applications/orb/models/cell/orb/orb.py b/MindChem/applications/orb/models/cell/orb/orb.py index a0922bdbb..0a4450bf8 100644 --- a/MindChem/applications/orb/models/cell/orb/orb.py +++ b/MindChem/applications/orb/models/cell/orb/orb.py @@ -146,7 +146,7 @@ class ScalarNormalizer(ms.nn.Cell): mean = self.bn.running_mean var = self.bn.running_var return(x-mean)/mint.sqrt(var) - + def inverse(self, x: Tensor): """Reverse the construct normalization. -- Gitee From 0f85937a6051a002dc8c0e9e42cc788fb87a1dc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=95=8F=E6=B6=9B?= Date: Tue, 2 Dec 2025 14:34:06 +0800 Subject: [PATCH 4/7] [SPONGE] refactor orb app structure and models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 刘敏涛 --- MindChem/applications/orb/README_CN.md | 1 - 1 file changed, 1 deletion(-) diff --git a/MindChem/applications/orb/README_CN.md b/MindChem/applications/orb/README_CN.md index 411f5ac49..acc3619bf 100644 --- a/MindChem/applications/orb/README_CN.md +++ b/MindChem/applications/orb/README_CN.md @@ -79,7 +79,6 @@ orb_models # 项目根目录(ORB 预 │ │ └── utils.py # ORB 模型内部使用的工具函数/辅助模块 │ └── ... # 预留:可以扩展其他模型结构 - ## 下载数据集 -- Gitee From e3b02cfca5e7d3938b5335690ce32b8c686e0cbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=95=8F=E6=B6=9B?= Date: Tue, 2 Dec 2025 15:22:48 +0800 Subject: [PATCH 5/7] [SPONGE] refactor orb app structure and models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 刘敏涛 --- MindChem/applications/orb/README.md | 2 +- MindChem/applications/orb/README_CN.md | 6 +++--- MindChem/applications/orb/models/cell/orb/orb.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/MindChem/applications/orb/README.md b/MindChem/applications/orb/README.md index fe577e2f7..acfd7d6e8 100644 --- a/MindChem/applications/orb/README.md +++ b/MindChem/applications/orb/README.md @@ -34,7 +34,7 @@ orb_models # Project root directory (ORB pretrain │ └── val_mptrj_ase.db # Validation/test dataset for finetuning │ ├── orb_ckpts # Directory to store checkpoints -│ └── orb-mptraj-only-v2.ckpt # Pretrained ORB model checkpoint (mptraj-only task) +│ └── orb-mptraj-only-v2.ckpt # Pretrained ORB model checkpoint (MPTraj-only task) │ # After training, additional finetuned checkpoints will be saved here │ ├── configs diff --git a/MindChem/applications/orb/README_CN.md b/MindChem/applications/orb/README_CN.md index acc3619bf..1b3881a5e 100644 --- a/MindChem/applications/orb/README_CN.md +++ b/MindChem/applications/orb/README_CN.md @@ -28,7 +28,7 @@ ```text 代码主要模块在src文件夹下,其中dataset文件夹下是数据集,orb_ckpts文件夹下是预训练模型和训练好的模型权重文件,configs文件夹下是各代码的参数配置文件。 -orb_models # 项目根目录(ORB 预训练/微调工程) +orb_models # 项目根目录(ORB 预训练/微调工程) ├── dataset │ ├── train_mptrj_ase.db # 微调阶段训练数据集(ASE 轨迹 SQLite 格式) │ └── val_mptrj_ase.db # 微调阶段验证/测试数据集 @@ -73,12 +73,12 @@ orb_models # 项目根目录(ORB 预 │ ├── message_passing.py # 消息传递层与 GNN 主体结构(message passing 核心逻辑) │ │ │ ├── orb # ORB 专用网络结构 - │ │ ├── gns.py # GNS(Graph Network Simulator)相关结构/接口 + │ │ ├── gns.py # GNS(Graph Network Simulator)相关结构/接口 │ │ ├── __init__.py # orb 子包初始化 │ │ ├── orb.py # ORB 主模型结构定义(encoder + heads 等) │ │ └── utils.py # ORB 模型内部使用的工具函数/辅助模块 │ - └── ... # 预留:可以扩展其他模型结构 + └── ... # 预留:可以扩展其他模型结构。 ## 下载数据集 diff --git a/MindChem/applications/orb/models/cell/orb/orb.py b/MindChem/applications/orb/models/cell/orb/orb.py index 0a4450bf8..293f9d16f 100644 --- a/MindChem/applications/orb/models/cell/orb/orb.py +++ b/MindChem/applications/orb/models/cell/orb/orb.py @@ -139,7 +139,7 @@ class ScalarNormalizer(ms.nn.Cell): # return (x - self.running_mean) / mint.sqrt(self.running_var) # return (x - self.bn.running_mean) / mint.sqrt(self.bn.running_var) # 修改 - if hasattr(self, "running mean"): + if hasattr(self, "running_mean"): mean = self.running_mean var = self.running_var else: -- Gitee From fb8c2cfc9a43d22c50dd916c7ffedbc64cb9f0d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=95=8F=E6=B6=9B?= Date: Tue, 2 Dec 2025 15:53:11 +0800 Subject: [PATCH 6/7] [SPONGE] refactor orb app structure and models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 刘敏涛 --- MindChem/applications/orb/README.md | 96 +++++++++++--------------- MindChem/applications/orb/README_CN.md | 87 ++++++++++------------- 2 files changed, 78 insertions(+), 105 deletions(-) diff --git a/MindChem/applications/orb/README.md b/MindChem/applications/orb/README.md index acfd7d6e8..f0c953d2b 100644 --- a/MindChem/applications/orb/README.md +++ b/MindChem/applications/orb/README.md @@ -28,59 +28,53 @@ ```text The main code modules are in the src folder, with the dataset folder containing the datasets, the orb_ckpts folder containing pre-trained models and trained model weight files, and the configs folder containing parameter configuration files for each code. -orb_models # Project root directory (ORB pretraining/finetuning project) +orb_models # ORB pre-training / fine-tuning project ├── dataset -│ ├── train_mptrj_ase.db # Training dataset for finetuning (ASE trajectory in SQLite format) -│ └── val_mptrj_ase.db # Validation/test dataset for finetuning +│ ├── train_mptrj_ase.db # Training dataset for fine-tuning (ASE trajectories, SQLite) +│ └── val_mptrj_ase.db # Validation / test dataset for fine-tuning │ -├── orb_ckpts # Directory to store checkpoints -│ └── orb-mptraj-only-v2.ckpt # Pretrained ORB model checkpoint (MPTraj-only task) -│ # After training, additional finetuned checkpoints will be saved here +├── orb_ckpts # Directory for pre-trained & fine-tuned checkpoints +│ └── orb-mptraj-only-v2.ckpt # Pre-trained ORB checkpoint (mptraj-only task) │ -├── configs -│ ├── config.yaml # Single-device training config (lr, batch_size, etc.) -│ ├── config_parallel.yaml # Multi-device data-parallel training config -│ └── config_eval.yaml # Inference/evaluation config +├── configs # Config files for training / inference +│ ├── config.yaml # Single-card training configuration (lr, batch_size, etc.) +│ ├── config_parallel.yaml # Multi-card data-parallel training configuration +│ └── config_eval.yaml # Inference / evaluation configuration │ -├── src # Core source code for training and data processing -│ ├── __init__.py # Package initializer for src (for convenient imports) -│ ├── ase_dataset.py # ASE dataset wrapper and loader (read SQLite, build atomic graphs) -│ ├── atomic_system.py # Atomic system data structure (coordinates, atom types, cell info, etc.) -│ ├── base.py # Base classes and common utilities (e.g., batch_graphs for graph batching) -│ ├── featurization_utilities.py # Tools to convert atomic systems into feature tensors -│ ├── pretrained.py # Pretrained ORB model builders and loading helpers -│ ├── property_definitions.py # Config and naming rules for physical properties (energy, forces, stress, etc.) -│ ├── trainer.py # OrbLoss and other loss / training-related wrappers -│ ├── segment_ops.py # Segmented operations (segment_sum/mean/max) for graph reductions -│ └── utils.py # General utilities (seeding, logging, optimizer & LR scheduler, etc.) +├── src # Core code for data processing and training +│ ├── __init__.py # Package initializer for src +│ ├── ase_dataset.py # Load and wrap ASE datasets (read SQLite, build atomic graphs) +│ ├── atomic_system.py # Data structures for atomic systems (positions, species, cell, etc.) +│ ├── base.py # Common base classes and utilities (e.g., batch_graphs) +│ ├── featurization_utilities.py # Tools to convert atomic systems into model input features +│ ├── pretrained.py # Interfaces for building and loading pre-trained ORB models +│ ├── property_definitions.py # Config and naming rules for energy / forces / stress, etc. +│ ├── trainer.py # Training loop and loss wrappers (e.g., OrbLoss) +│ ├── segment_ops.py # Segment-wise reduction ops (segment_sum / mean / max) +│ └── utils.py # Utility functions (seeding, logging, optimizer & LR scheduler) │ -├── finetune.py # Main finetuning entry script (loads data/model and starts training) -├── evaluate.py # Inference/evaluation script (run model with finetuned checkpoints) +├── models # Model definitions (GNN / ORB networks) +│ ├── __init__.py # Package initializer for models; export model entry points +│ └── cell # Basic building blocks for NN / GNN +│ ├── __init__.py # Package initializer for cell +│ ├── activation.py # Activation functions and non-linear modules +│ ├── basic_block.py # Basic network blocks (e.g., MLP block, ResBlock) +│ ├── convolution.py # Graph / spatial convolution layers +│ ├── embedding.py # Embedding layers for atom types, edge features, etc. +│ ├── message_passing.py # Message passing layers and GNN backbone +│ └── orb # ORB-specific submodules +│ ├── __init__.py # Package initializer for orb +│ ├── gns.py # GNS (Graph Network Simulator) related structures / APIs +│ ├── orb.py # Main ORB architecture (encoder + heads) +│ └── utils.py # Internal utilities and helper modules for ORB │ -├── run.sh # Single-device training launcher (calls finetune.py + config.yaml) -├── run_parallel.sh # Multi-device training launcher (msrun + config_parallel.yaml) -├── requirement.txt # Python dependency list (for environment setup) +├── finetune.py # Entry script for model fine-tuning +├── evaluate.py # Entry script for model inference / evaluation │ -├── models # Model definition modules (GNN/ORB network structures) -│ ├── __init__.py # Package initializer for models (expose unified model interfaces) -│ │ -│ ├── cell # Basic “cell” building blocks for NN/GNN -│ │ ├── activation.py # Activation functions and non-linear modules -│ │ ├── basic_block.py # Common basic network blocks (MLP blocks, residual blocks, etc.) -│ │ ├── convolution.py # Graph convolutions or spatial convolution layers/operators -│ │ ├── embedding.py # Embedding layers for atom types, edge features, etc. -│ │ ├── __init__.py # Package initializer for cell -│ │ ├── message_passing.py # Message passing layers and main GNN logic -│ │ │ -│ │ ├── orb # ORB-specific network architectures -│ │ │ ├── gns.py # GNS (Graph Network Simulator) related structures/interfaces -│ │ │ ├── __init__.py # Package initializer for orb submodule -│ │ │ ├── orb.py # Main ORB model definition (encoder + heads, etc.) -│ │ │ └── utils.py # Helper functions and internal utilities for ORB models -│ │ -│ └── ... # Reserved: other model architectures can be added here -│ -└── __init__.py +├── run.sh # Single-card training launcher (wraps finetune.py + config.yaml) +├── run_parallel.sh # Multi-card training launcher (msrun + config_parallel.yaml) +└── requirement.txt # Python dependency list for environment setup + ``` ## Download Dataset @@ -161,16 +155,6 @@ Training time: 2375.89474 seconds Training time: 2377.02413 seconds Training time: 2377.22778 seconds Training time: 2376.63176 seconds -[INFO] PS(365484,ffff13fff120,python):2025-12-02-13:43:10.997.606 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! -[INFO] PS(365484,ffff137ef120,python):2025-12-02-13:43:10.997.583 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! -[INFO] PS(365478,ffff36fdf120,python):2025-12-02-13:43:13.013.568 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! -[INFO] PS(365478,ffff377ef120,python):2025-12-02-13:43:13.013.575 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! -[INFO] PS(365488,ffff21fbf120,python):2025-12-02-13:43:15.029.782 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! -[INFO] PS(365488,ffff217af120,python):2025-12-02-13:43:15.029.782 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! -[INFO] PS(365481,ffff2d659120,python):2025-12-02-13:43:15.061.968 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! -[INFO] PS(365481,ffff2de69120,python):2025-12-02-13:43:15.061.956 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! - - ``` Under the same training configuration, parallel training achieved significant performance improvement compared to single-card training: diff --git a/MindChem/applications/orb/README_CN.md b/MindChem/applications/orb/README_CN.md index 1b3881a5e..b2ec65f1b 100644 --- a/MindChem/applications/orb/README_CN.md +++ b/MindChem/applications/orb/README_CN.md @@ -28,57 +28,54 @@ ```text 代码主要模块在src文件夹下,其中dataset文件夹下是数据集,orb_ckpts文件夹下是预训练模型和训练好的模型权重文件,configs文件夹下是各代码的参数配置文件。 -orb_models # 项目根目录(ORB 预训练/微调工程) +orb_models # ORB 预训练 / 微调工程 ├── dataset -│ ├── train_mptrj_ase.db # 微调阶段训练数据集(ASE 轨迹 SQLite 格式) -│ └── val_mptrj_ase.db # 微调阶段验证/测试数据集 +│ ├── train_mptrj_ase.db # 微调训练集(ASE 轨迹,SQLite 格式) +│ └── val_mptrj_ase.db # 微调验证 / 测试集 │ -├── orb_ckpts # checkpoint 存放目录 -│ └── orb-mptraj-only-v2.ckpt # 预训练 ORB 模型 checkpoint(仅 mptraj 任务) -│ # 训练完成后,会在此目录下额外生成微调后的 ckpt +├── orb_ckpts # 预训练 & 微调模型 ckpt 存放目录 +│ └── orb-mptraj-only-v2.ckpt # 仅 mptraj 任务的预训练 ORB 模型 │ -├── configs -│ ├── config.yaml # 单卡训练参数配置(学习率、batch_size 等) -│ ├── config_parallel.yaml # 多卡数据并行训练参数配置 -│ └── config_eval.yaml # 推理/评估阶段参数配置 +├── configs # 训练 / 推理配置 +│ ├── config.yaml # 单卡训练配置(学习率、batch_size 等) +│ ├── config_parallel.yaml # 多卡数据并行训练配置 +│ └── config_eval.yaml # 推理 / 评估配置 │ -├── src # 训练与数据处理的核心源码 -│ ├── __init__.py # src 包初始化,方便外部按模块导入 -│ ├── ase_dataset.py # ASE 数据集封装与加载(读 SQLite、组装原子图) +├── src # 数据处理与训练核心源码 +│ ├── __init__.py # src 包初始化 +│ ├── ase_dataset.py # ASE 数据集读取与封装(读 SQLite、组装原子图) │ ├── atomic_system.py # 原子系统数据结构定义(坐标、原子种类、晶胞信息等) -│ ├── base.py # 基础类与通用函数(batch_graphs 等图数据打包工具) -│ ├── featurization_utilities.py # 原子系统 → 特征张量的特征化工具 +│ ├── base.py # 通用基类与工具(batch_graphs 等图数据打包) +│ ├── featurization_utilities.py # 原子系统 → 模型输入特征张量的特征化工具 │ ├── pretrained.py # 预训练 ORB 模型构造与加载接口 -│ ├── property_definitions.py # 能量、力、应力等物理性质的配置与命名规则 -│ ├── trainer.py # OrbLoss 等 loss 类与训练相关封装 -│ ├── segment_ops.py # segment_sum/mean/max 等分段运算算子(图归约用) -│ └── utils.py # 通用工具函数(seed、日志、优化器与 LR scheduler 等) +│ ├── property_definitions.py # 能量 / 力 / 应力等物理量配置与命名 +│ ├── trainer.py # 训练循环与 OrbLoss 等损失封装 +│ ├── segment_ops.py # segment_sum / mean / max 等分段归约算子 +│ └── utils.py # 通用工具函数(随机种子、日志、优化器、LR scheduler 等) │ -├── finetune.py # 模型微调入口脚本(加载数据、模型、开始训练) -├── evaluate.py # 模型推理/评估脚本(使用微调好的 ckpt 做 inference) +├── models # 模型结构定义(GNN / ORB 等) +│ ├── __init__.py # models 包初始化,统一暴露模型接口 +│ └── cell # 神经网络 / 图网络基础模块 +│ ├── __init__.py # cell 子包初始化 +│ ├── activation.py # 激活函数与非线性模块 +│ ├── basic_block.py # 基础网络模块(MLP Block、ResBlock 等) +│ ├── convolution.py # 图卷积 / 空间卷积相关层 +│ ├── embedding.py # 原子类型、边特征等嵌入层 +│ ├── message_passing.py # 消息传递层与 GNN 主体结构 +│ └── orb # ORB 专用子模块 +│ ├── __init__.py # orb 子包初始化 +│ ├── gns.py # GNS(Graph Network Simulator) 相关结构 / 接口 +│ ├── orb.py # ORB 主体网络(encoder + heads) +│ └── utils.py # ORB 内部工具与辅助模块 +│ +├── finetune.py # 模型微调入口脚本 +├── evaluate.py # 推理 / 评估入口脚本 │ ├── run.sh # 单卡训练启动脚本(调用 finetune.py + config.yaml) ├── run_parallel.sh # 多卡并行训练启动脚本(msrun + config_parallel.yaml) -├── requirement.txt # Python 依赖列表(用于搭建运行环境) -│ -├── models # 模型结构定义模块(GNN/ORB 相关网络) - ├── __init__.py # models 包初始化,统一暴露各类模型接口 - │ - ├── cell # 神经网络/图网络的基础“细胞”模块 - │ ├── activation.py # 激活函数与非线性模块封装 - │ ├── basic_block.py # 通用基础网络模块(MLP Block、ResBlock 等) - │ ├── convolution.py # 图卷积或空间卷积相关层与算子 - │ ├── embedding.py # 原子类型、边特征等嵌入层定义 - │ ├── __init__.py # cell 子包初始化 - │ ├── message_passing.py # 消息传递层与 GNN 主体结构(message passing 核心逻辑) - │ │ - │ ├── orb # ORB 专用网络结构 - │ │ ├── gns.py # GNS(Graph Network Simulator)相关结构/接口 - │ │ ├── __init__.py # orb 子包初始化 - │ │ ├── orb.py # ORB 主模型结构定义(encoder + heads 等) - │ │ └── utils.py # ORB 模型内部使用的工具函数/辅助模块 - │ - └── ... # 预留:可以扩展其他模型结构。 +└── requirement.txt # Python 依赖列表(环境搭建用) + +``` ## 下载数据集 @@ -158,14 +155,6 @@ Training time: 2375.89474 seconds Training time: 2377.02413 seconds Training time: 2377.22778 seconds Training time: 2376.63176 seconds -[INFO] PS(365484,ffff13fff120,python):2025-12-02-13:43:10.997.606 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! -[INFO] PS(365484,ffff137ef120,python):2025-12-02-13:43:10.997.583 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! -[INFO] PS(365478,ffff36fdf120,python):2025-12-02-13:43:13.013.568 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! -[INFO] PS(365478,ffff377ef120,python):2025-12-02-13:43:13.013.575 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! -[INFO] PS(365488,ffff21fbf120,python):2025-12-02-13:43:15.029.782 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! -[INFO] PS(365488,ffff217af120,python):2025-12-02-13:43:15.029.782 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! -[INFO] PS(365481,ffff2d659120,python):2025-12-02-13:43:15.061.968 [mindspore/ccsrc/ps/core/communicator/tcp_client.cc:318] Start] Event base dispatch success! -[INFO] PS(365481,ffff2de69120,python):2025-12-02-13:43:15.061.956 [mindspore/ccsrc/ps/core/communicator/tcp_server.cc:220] Start] Event base dispatch success! ``` 在相同的训练配置下,并行训练相比单卡训练取得了显著的性能提升: -- Gitee From 3d3dbfd89fa96ff1714dbe33e736fe8720f4ec7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=95=8F=E6=B6=9B?= <1820220397@qq.com> Date: Tue, 2 Dec 2025 20:17:13 +0800 Subject: [PATCH 7/7] fix orb.py --- .../applications/orb/models/cell/orb/orb.py | 20 ++++++------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/MindChem/applications/orb/models/cell/orb/orb.py b/MindChem/applications/orb/models/cell/orb/orb.py index 293f9d16f..c0e856ca9 100644 --- a/MindChem/applications/orb/models/cell/orb/orb.py +++ b/MindChem/applications/orb/models/cell/orb/orb.py @@ -123,7 +123,7 @@ class ScalarNormalizer(ms.nn.Cell): self.bn = mint.nn.BatchNorm1d(1, affine=False, momentum=None) self.bn.running_mean = Parameter(Tensor([0], ms.float32)) self.bn.running_var = Parameter(Tensor([1], ms.float32)) - self.bn.num_batches_tracked = Parameter(Tensor([1000], ms.float32)) + self.bn.num_batches_tracked = Parameter(Tensor([1000], ms.float32), requires_grad=False) self.stastics = { "running_mean": init_mean if init_mean is not None else 0.0, "running_var": init_std**2 if init_std is not None else 1.0, @@ -133,20 +133,12 @@ class ScalarNormalizer(ms.nn.Cell): def construct(self, x: Tensor): """construct """ - # if self.training: - # self.bn(x.view(-1, 1)) - # if hasattr(self, "running_mean"): - # return (x - self.running_mean) / mint.sqrt(self.running_var) - # return (x - self.bn.running_mean) / mint.sqrt(self.bn.running_var) - # 修改 + if self.training: + self.bn(x.view(-1, 1)) if hasattr(self, "running_mean"): - mean = self.running_mean - var = self.running_var - else: - mean = self.bn.running_mean - var = self.bn.running_var - return(x-mean)/mint.sqrt(var) - + return (x - self.running_mean) / mint.sqrt(self.running_var) + return (x - self.bn.running_mean) / mint.sqrt(self.bn.running_var) + def inverse(self, x: Tensor): """Reverse the construct normalization. -- Gitee