From b1b659d559df350702e89f60a15a6be1f6a1bf76 Mon Sep 17 00:00:00 2001
From: birfied
Date: Thu, 18 Sep 2025 10:59:31 +0800
Subject: [PATCH 01/21] move from MindChemistry to MindChem

---
 .../applications/crystalflow/.gitignore        |   0
 MindChem/applications/diffcsp/README.md        | 127 +++++++
 .../applications/diffcsp/compute_metric.py     | 327 ++++++++++++++++++
 MindChem/applications/diffcsp/config.yaml      |  41 +++
 .../applications/diffcsp/data/crysloader.py    |   0
 .../applications/diffcsp/data/data_utils.py    |   0
 MindChem/applications/diffcsp/data/dataset.py  | 135 ++++++++
 .../applications/diffcsp/evaluate.py           |   0
 .../applications/diffcsp/models/cspnet.py      |   0
 .../applications/diffcsp/models/diff_utils.py  |   0
 .../applications/diffcsp/models/diffusion.py   |   0
 .../diffcsp/models/infer_utils.py              |   0
 .../diffcsp/models/train_utils.py              |   0
 .../applications/diffcsp/requirement.txt       |   0
 MindChem/applications/diffcsp/train.py         | 208 +++++++++
 15 files changed, 838 insertions(+)
 rename {MindChemistry => MindChem}/applications/crystalflow/.gitignore (100%)
 create mode 100644 MindChem/applications/diffcsp/README.md
 create mode 100644 MindChem/applications/diffcsp/compute_metric.py
 create mode 100644 MindChem/applications/diffcsp/config.yaml
 rename {MindChemistry => MindChem}/applications/diffcsp/data/crysloader.py (100%)
 rename {MindChemistry => MindChem}/applications/diffcsp/data/data_utils.py (100%)
 create mode 100644 MindChem/applications/diffcsp/data/dataset.py
 rename {MindChemistry => MindChem}/applications/diffcsp/evaluate.py (100%)
 rename {MindChemistry => MindChem}/applications/diffcsp/models/cspnet.py (100%)
 rename {MindChemistry => MindChem}/applications/diffcsp/models/diff_utils.py (100%)
 rename {MindChemistry => MindChem}/applications/diffcsp/models/diffusion.py (100%)
 rename {MindChemistry => MindChem}/applications/diffcsp/models/infer_utils.py (100%)
 rename {MindChemistry => MindChem}/applications/diffcsp/models/train_utils.py (100%)
 rename {MindChemistry => MindChem}/applications/diffcsp/requirement.txt (100%)
 create mode 100644 MindChem/applications/diffcsp/train.py

diff --git a/MindChemistry/applications/crystalflow/.gitignore b/MindChem/applications/crystalflow/.gitignore
similarity index 100%
rename from MindChemistry/applications/crystalflow/.gitignore
rename to MindChem/applications/crystalflow/.gitignore
diff --git a/MindChem/applications/diffcsp/README.md b/MindChem/applications/diffcsp/README.md
new file mode 100644
index 000000000..858aa417c
--- /dev/null
+++ b/MindChem/applications/diffcsp/README.md
@@ -0,0 +1,127 @@

# Model Name

> DiffCSP

## Introduction

> DiffCSP is a diffusion-based deep learning framework for crystal structure prediction, a fundamental problem in materials science. Its core idea is to recast the search for stable crystal structures as a generative problem: by learning the distribution underlying a large corpus of known crystals, the model can generate plausible three-dimensional atomic structures (both the lattice and the atomic coordinates) directly and quickly from the chemical composition alone (atom types and their ratios). Compared with traditional approaches that rely on extensive quantum-mechanical computation, the key innovation of DiffCSP is an SE(3)-equivariant graph neural network combined with periodic boundary conditions, which ensures that the generated structures strictly respect physical symmetries. This allows the polymorphism of materials to be explored with very high efficiency and provides a powerful tool for accelerating the discovery and design of new materials.

## Requirements

> 1. Install `mindspore (2.3.0)`
> 2. Install the dependencies: `pip install -r requirement.txt`

## Quick Start

> 1. Download the Mindchemistry/mindchemistry package into the current directory
> 2. Download the corresponding dataset from the [dataset link](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/)
> 3. Install the dependencies: `pip install -r requirement.txt`
> 4. Training command: `python train.py`
> 5. Prediction command: `python evaluate.py`
> 6. Evaluation command: `python compute_metric.py`
> 7. The evaluation results are written to a JSON file under the `metric_dir` path specified in `config.yaml`; a consolidated example of the whole workflow is sketched below
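A typical end-to-end run therefore looks like the following shell session (a sketch assuming the default paths in `config.yaml`; `evaluate.py` is invoked without flags, as in the quick start above):

```bash
pip install -r requirement.txt                 # install dependencies once
python train.py --config config.yaml           # checkpoints go to train.ckpt_dir
python evaluate.py                             # samples crystals with checkpoint.last_path
python compute_metric.py --config config.yaml  # writes <metric_dir>/eval_metrics_<num_eval>.json
```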
### Code directory structure

```txt
diffcsp
 │  README.md             README file
 │  config.yaml           configuration file
 │  train.py              training entry script
 │  evaluate.py           inference entry script
 │  compute_metric.py     evaluation entry script
 │  requirement.txt       environment dependencies
 │
 └─data
      data_utils.py       dataset preprocessing utilities
      dataset.py          dataset reading
      crysloader.py       dataset loader
 └─models
      cspnet.py           graph-neural-network denoiser module
      diffusion.py        diffusion model module
      diff_utils.py       utility module
      infer_utils.py      inference utilities
      train_utils.py      training utilities

```

## Downloading the dataset

From the [dataset link](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/), download the corresponding dataset folders and the dataset_prop.txt dataset property file, and place them in the dataset folder under the current path (create it manually if it does not exist). The expected layout is:

```txt
diffcsp
 ...
 └─dataset
      perov_5            perovskite dataset
      carbon_24          carbon crystal dataset
      mp_20              MP dataset with at most 20 atoms per unit cell
      mpts_52            MP dataset with at most 52 atoms per unit cell
      dataset_prop.txt   dataset property file
 ...
```

## Training process

### Training

Download the Mindchemistry/mindchemistry package into the current directory.

Edit the config file to set the training parameters:
> 1. Set the training dataset; see the dataset field
> 2. Set the configuration of the denoiser model; see the model field
> 3. Set where training checkpoints are saved by changing the train.ckpt_dir folder name and the checkpoint.last_path checkpoint file name
> 4. For other training settings, see the train field

```bash
pip install -r requirement.txt
python train.py
```

### Inference

Write the checkpoint path into checkpoint.last_path in the config file. Pretrained models are available from the [pretrained model link](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/pre-train).

Edit the test field of the config file to change the inference parameters, in particular test.num_eval, which **determines how many samples are generated for each composition** and matters for the subsequent evaluation stage.

```bash
python evaluate.py
```

The crystals obtained by inference are saved in the file specified by test.eval_save_path.

The file stores a Python dictionary with the following format:

```python
{
    'pred': [
        [crystal A sample 1, crystal A sample 2, crystal A sample 3, ... crystal A sample num_eval],
        [crystal B sample 1, crystal B sample 2, crystal B sample 3, ... crystal B sample num_eval]
        ...
    ]
    'gt': [
        crystal A ground truth,
        crystal B ground truth,
        ...
    ]
}
```

### Evaluation

Write the path of the crystal file produced by inference into test.eval_save_path in the config file.

Make sure num_eval is equal to or smaller than the number of samples generated per composition during inference. For example, if num_eval was set to 1 during inference, it can only be set to 1 during evaluation; if it was set to 20 during inference, then any value from 1 to 20 may be used for evaluation.

Set the save path of the evaluation results via the test.metric_dir field of the config file.

```bash
python compute_metric.py
```

Example of the resulting evaluation file:

```json
{"match_rate": 0.985997357992074, "rms_dist": 0.013073775170360118}
```
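To sanity-check both artifacts from Python, a minimal sketch along these lines should work (the key names follow the dictionary format above; the paths are the defaults from `config.yaml` and should be adjusted to your configuration):

```python
import json
import pickle

# predictions written by evaluate.py (test.eval_save_path)
with open('./ckpt/mp_20/predict_crys.pkl', 'rb') as f:
    eval_dict = pickle.load(f)
print(len(eval_dict['gt']))        # number of ground-truth crystals
print(len(eval_dict['pred'][0]))   # samples for the first composition, i.e. num_eval

# metrics written by compute_metric.py (test.metric_dir)
with open('./ckpt/mp_20/eval_metrics_1.json', 'r') as f:
    print(json.load(f))            # e.g. {"match_rate": ..., "rms_dist": ...}
```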
diff --git a/MindChem/applications/diffcsp/compute_metric.py b/MindChem/applications/diffcsp/compute_metric.py
new file mode 100644
index 000000000..df5539aad
--- /dev/null
+++ b/MindChem/applications/diffcsp/compute_metric.py
@@ -0,0 +1,327 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""compute metric file"""
import itertools
import json
import os
import pickle
from collections import Counter
from pathlib import Path
import argparse
import yaml

import numpy as np
from matminer.featurizers.composition.composite import ElementProperty
from matminer.featurizers.site.fingerprint import CrystalNNFingerprint
from p_tqdm import p_map
from pymatgen.analysis.structure_matcher import StructureMatcher
from pymatgen.core.composition import Composition
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
import smact
from smact.screening import pauling_test
from tqdm import trange

from models.infer_utils import chemical_symbols

matcher = StructureMatcher(stol=0.5, angle_tol=10, ltol=0.3)
crystalnn_fp = CrystalNNFingerprint.from_preset("ops")
comp_fp = ElementProperty.from_preset('magpie')


def smact_validity(comp, count, use_pauling_test=True, include_alloys=True):
    """SMACT validity check. See details in the paper Crystal Diffusion
    Variational Autoencoder (CDVAE) and its codebase.
    """
    elem_symbols = tuple([chemical_symbols[elem] for elem in comp])
    space = smact.element_dictionary(elem_symbols)
    smact_elems = [e[1] for e in space.items()]
    electronegs = [e.pauling_eneg for e in smact_elems]
    ox_combos = [e.oxidation_states for e in smact_elems]
    if len(set(elem_symbols)) == 1:
        return True
    if include_alloys:
        is_metal_list = [elem_s in smact.metals for elem_s in elem_symbols]
        if all(is_metal_list):
            return True

    threshold = np.max(count)
    oxn = 1
    for oxc in ox_combos:
        oxn *= len(oxc)
    if oxn > 1e7:
        return False
    for ox_states in itertools.product(*ox_combos):
        stoichs = [(c,) for c in count]
        # Test for charge balance
        cn_e, _ = smact.neutral_ratios(ox_states,
                                       stoichs=stoichs,
                                       threshold=threshold)
        # Electronegativity test
        if cn_e:
            if use_pauling_test:
                try:
                    electroneg_ok = pauling_test(ox_states, electronegs)
                except TypeError:
                    # if no electronegativity data, assume it is okay
                    electroneg_ok = True
            else:
                electroneg_ok = True
            if electroneg_ok:
                return True
    return False


def structure_validity(crystal, cutoff=0.5):
    """Structure validity check. See details in the paper Crystal Diffusion
    Variational Autoencoder (CDVAE) and its codebase.
    """
    dist_mat = crystal.distance_matrix
    # Pad diagonal with a large number
    dist_mat = dist_mat + np.diag(np.ones(dist_mat.shape[0]) * (cutoff + 10.))
    if dist_mat.min() < cutoff or crystal.volume < 0.1 or max(
            crystal.lattice.abc) > 40:
        return False

    return True
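
# Illustrative usage of the validity helpers above (a hypothetical toy cell,
# not part of the evaluation pipeline). For a cubic NaCl-like structure the
# shortest interatomic distance is well above the 0.5 A cutoff:
#
#     lattice = Lattice.from_parameters(5.0, 5.0, 5.0, 90.0, 90.0, 90.0)
#     toy = Structure(lattice, ["Na", "Cl"], [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])
#     structure_validity(toy)           # True
#     smact_validity((11, 17), (1, 1))  # True: charge-balanced Na+/Cl-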
+ """ + + def __init__(self, crys_array_dict): + self.frac_coords = crys_array_dict['frac_coords'] + self.atom_types = crys_array_dict['atom_types'] + self.lengths = crys_array_dict['lengths'] + self.angles = crys_array_dict['angles'] + self.dict = crys_array_dict + if len(self.atom_types.shape) > 1: + self.dict['atom_types'] = np.argmax(self.atom_types, axis=-1) + 1 + self.atom_types = np.argmax(self.atom_types, axis=-1) + 1 + + self.get_structure() + self.get_composition() + self.get_validity() + self.get_fingerprints() + + def get_structure(self): + """get_structure + """ + if min(self.lengths.tolist()) < 0: + self.constructed = False + self.invalid_reason = 'non_positive_lattice' + if np.isnan(self.lengths).any() or np.isnan( + self.angles).any() or np.isnan(self.frac_coords).any(): + self.constructed = False + self.invalid_reason = 'nan_value' + else: + try: + self.structure = Structure(lattice=Lattice.from_parameters( + *(self.lengths.tolist() + self.angles.tolist())), + species=self.atom_types, + coords=self.frac_coords, + coords_are_cartesian=False) + self.constructed = True + # pylint: disable=W0703 + except Exception: + self.constructed = False + self.invalid_reason = 'construction_raises_exception' + if self.structure.volume < 0.1: + self.constructed = False + self.invalid_reason = 'unrealistically_small_lattice' + + def get_composition(self): + """get_composition + """ + elem_counter = Counter(self.atom_types) # pylint: disable=E1121 + composition = [(elem, elem_counter[elem]) + for elem in sorted(elem_counter.keys())] + elems, counts = list(zip(*composition)) + counts = np.array(counts) + counts = counts / np.gcd.reduce(counts) + self.elems = elems + self.comps = tuple(counts.astype('int').tolist()) + + def get_validity(self): + """get_validity + """ + self.comp_valid = smact_validity(self.elems, self.comps) + if self.constructed: + self.struct_valid = structure_validity(self.structure) + else: + self.struct_valid = False + self.valid = self.comp_valid and self.struct_valid + + def get_fingerprints(self): + """get_fingerprints + """ + elem_counter = Counter(self.atom_types) # pylint: disable=E1121 + comp = Composition(elem_counter) + self.comp_fp = comp_fp.featurize(comp) + try: + site_fps = [ + crystalnn_fp.featurize(self.structure, i) + for i in range(len(self.structure)) + ] + # pylint: disable=W0703 + except Exception: + # counts crystal as invalid if fingerprint cannot be constructed. + self.valid = False + self.comp_fp = None + self.struct_fp = None + return + self.struct_fp = np.array(site_fps).mean(axis=0) + + +def get_rms(pred_struc_list, gt_struc: Structure, num_eval, np_list): + """Calculate the rms distance between the ground truth and predicted crystal structures. + + Args: + pred_struc_list (List[Structure]): The crystals generated by diffution model + in the form of Structure. + gt_struc (Structure): The ground truth crystal. + num_eval (int): Specify that the first N items in the predicted List of crystal structures + participate in the evaluationo. + np_list (List[Dict]): The crystals generated by diffution model in the form of Dict. 
+ """ + + def process_one(pred_struc: Structure): + try: + if not pred_struc.is_valid(): + return None + rms_dist = matcher.get_rms_dist(pred_struc, gt_struc) + rms_dist = None if rms_dist is None else rms_dist[0] + tune_rms = rms_dist + # pylint: disable=W0703 + except Exception: + tune_rms = None + return tune_rms + + min_rms = None + min_struc = None + for i, struct in enumerate(pred_struc_list): + if i == num_eval: + break + rms = process_one(struct) + if rms is not None and (min_rms is None or min_rms > rms): + min_rms = rms + min_struc = np_list[i] + return min_rms, min_struc + + +def get_struc_from_np_list(np_list): + """convert the crystal in the form of Dict to pymatgen.Structure + """ + result = [] + for cry_array in np_list: + try: + struct = Structure(lattice=Lattice.from_parameters( + *(cry_array['lengths'].tolist() + + cry_array['angles'].tolist())), + species=cry_array['atom_types'], + coords=cry_array['frac_coords'], + coords_are_cartesian=False) + # pylint: disable=W0703 + except Exception: + print('Warning: One anomalous crystal structure has captured and removed. ') + struct = None + + result.append(struct) + return result + +def main(args): + """main + """ + with open(args.config, 'r') as stream: + config = yaml.safe_load(stream) + + eval_file = config['test']['eval_save_path'] + num_eval = config['test']['num_eval'] + output_path = config['test']['metric_dir'] + + with open(eval_file, 'rb') as f: + eval_dict = pickle.load(f) + + pred_list = eval_dict['pred'] + gt_list = eval_dict['gt'] + gt_list = get_struc_from_np_list(gt_list) + rms = [] + + # calculate rmsd + for i in trange(len(gt_list)): + pred_struc = get_struc_from_np_list(pred_list[i]) + gt_struc = gt_list[i] + rms_single, struc_single = get_rms(pred_struc, gt_struc, num_eval, + pred_list[i]) + rms.append((rms_single, struc_single)) + + rms, struc_list = zip(*rms) + + # Remove the ones with RMSD as None, and store the valid structures in the list valid_crys. + rms_np = [] + valid_crys = [] + for i, rms_per in enumerate(rms): + if rms_per is not None: + rms_np.append(rms_per) + valid_crys.append(struc_list[i]) + + # Conduct rigorous structural verification, specifically through verification using the Crystal class. + print('Using the Crystal class for validity checks') + valid_list = p_map(lambda x: Crystal(x).valid, valid_crys) + rms_np_strict = [] + for i, is_valid in enumerate(valid_list): + if is_valid: + rms_np_strict.append(rms_np[i]) + + rms_np = np.array(rms_np_strict) + rms_valid_index = np.array([x is not None for x in rms_np_strict]) + + match_rate = rms_valid_index.sum() / len(gt_list) + rms = rms_np[rms_valid_index].mean() + + print('match_rate: ', match_rate) + print('rms: ', rms) + + all_metrics = {'match_rate': match_rate, 'rms_dist': rms} + + if Path(output_path).exists(): + metrics_out_file = f'eval_metrics_{num_eval}.json' + metrics_out_file = os.path.join(output_path, metrics_out_file) + + # only overwrite metrics computed in the new run. 
+ if Path(metrics_out_file).exists(): + with open(metrics_out_file, 'r') as f: + written_metrics = json.load(f) + if isinstance(written_metrics, dict): + written_metrics.update(all_metrics) + else: + with open(metrics_out_file, 'w') as f: + json.dump(all_metrics, f) + if isinstance(written_metrics, dict): + with open(metrics_out_file, 'w') as f: + json.dump(written_metrics, f) + else: + with open(metrics_out_file, 'w') as f: + json.dump(all_metrics, f) + else: + print('Warning: The metric result file path is not specified') + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', default='config.yaml') + main_args = parser.parse_args() + main(main_args) diff --git a/MindChem/applications/diffcsp/config.yaml b/MindChem/applications/diffcsp/config.yaml new file mode 100644 index 000000000..6c2899a15 --- /dev/null +++ b/MindChem/applications/diffcsp/config.yaml @@ -0,0 +1,41 @@ +dataset: + data_name: 'mp_20' + train: + path: './dataset/mp_20/train.csv' + save_path: './dataset/mp_20/train.npy' + val: + path: './dataset/mp_20/val.csv' + save_path: './dataset/mp_20/val.npy' + test: + path: './dataset/mp_20/test.csv' + save_path: './dataset/mp_20/test.npy' + +model: + # For dataset carbon, mp, mpts + hidden_dim: 512 + num_layers: 6 + num_freqs: 128 + # # For dataset perov + # hidden_dim: 256 + # num_layers: 4 + # num_freqs: 10 + +train: + ckpt_dir: "./ckpt/mp_20" + # 3500, 4000, 1000, 1000 epochs for Perov-5, Carbon-24, MP-20 and MPTS-52 respectively. + epoch_size: 1000 + # 512, 512, 128, 128 for Perov-5, Carbon-24, MP-20 and MPTS-52 respectively. + batch_size: 256 + seed: 1234 + +checkpoint: + last_path: "./ckpt/mp_20/last_test.ckpt" + +test: + # 1024 for perov, 512 for carbon and mp, 256 for mpts + batch_size: 512 + num_eval: 1 + # 1e-5 for mp and mpts, 5e-7 for perov, 5e-6 for carbon num_eval=1 and 5e-7 for carbon num_eval=20 + step_lr: 1e-5 + eval_save_path: './ckpt/mp_20/predict_crys.pkl' + metric_dir: './ckpt/mp_20/' diff --git a/MindChemistry/applications/diffcsp/data/crysloader.py b/MindChem/applications/diffcsp/data/crysloader.py similarity index 100% rename from MindChemistry/applications/diffcsp/data/crysloader.py rename to MindChem/applications/diffcsp/data/crysloader.py diff --git a/MindChemistry/applications/diffcsp/data/data_utils.py b/MindChem/applications/diffcsp/data/data_utils.py similarity index 100% rename from MindChemistry/applications/diffcsp/data/data_utils.py rename to MindChem/applications/diffcsp/data/data_utils.py diff --git a/MindChem/applications/diffcsp/data/dataset.py b/MindChem/applications/diffcsp/data/dataset.py new file mode 100644 index 000000000..964a9a957 --- /dev/null +++ b/MindChem/applications/diffcsp/data/dataset.py @@ -0,0 +1,135 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ============================================================================
"""dataset file"""
import os
from pathlib import Path

import numpy as np

from data.data_utils import StandardScaler, preprocess

def fullconnect_dataset(name,
                        path,
                        niggli=True,
                        primitive=False,
                        graph_method='none',
                        preprocess_workers=30,
                        save_path='',
                        nrows=-1):
    """
    Read crystal data from a CSV file and convert each crystal into a fully connected graph,
    where the nodes represent atoms within the unit cell and the edges connect every pair of nodes.

    Args:
        name (str): The name of the dataset, mainly used to read the dataset
            property from './dataset/dataset_prop.txt'.
            It does not matter for the crystal structure prediction task.
            Choices: [perov_5, carbon_24, mp_20, mpts_52].
            Users can also create custom datasets by modifying
            './dataset/dataset_prop.txt'.
        path (str): The path of the CSV file of the dataset.
        niggli (bool): Whether to use the Niggli algorithm to
            preprocess the choice of lattice. Default:
            ``True``.
        primitive (bool): Whether to represent the crystal in its primitive cell. Default:
            ``False``.
        graph_method (str): If 'crystalnn', construct the graph by the crystalnn
            algorithm, which mainly affects the construction of edges.
            If 'none', do not construct any edges. Default: ``none``.
        preprocess_workers (int): The number of CPUs used for
            preprocessing the crystals. Default: ``30``.
        save_path (str): The path for saving the preprocessed data,
            so that the dataset can be loaded more quickly next time.
        nrows (int): If nrows > 0, read the first 'nrows' lines of the CSV file.
            If nrows = -1, read the whole CSV file.
            This arg is mainly for debugging, to quickly load a few crystals.
    Returns:
        x (list): List of atom types. Shape of each element (a numpy array): (num_atoms, 1)
        frac_coord_list (list): List of fractional coordinates of atoms.
            Shape of each element (a numpy array): (num_atoms, 3)
        edge_attr (list): List of numpy arrays filled with ones,
            only used to better construct the dataloader,
            without numerical significance. Shape of each element
            (a numpy array): (num_edges, 3)
        edge_index (list): List of the indices of the beginning and end
            of edges. Each element is composed as [src, dst], where
            src and dst are numpy arrays with shape (num_edges,).
        lengths_list (list): List of lattice lengths. Shape of
            each element (a numpy array): (3,)
        angles_list (list): List of lattice angles. Shape of
            each element (a numpy array): (3,)
        labels (list): List of crystal properties. Shape of
            each element (a numpy array):
            (1,)
    """
    x = []
    frac_coord_list = []
    edge_index = []
    edge_attr = []
    labels = []
    lengths_list = []
    angles_list = []

    if Path('./dataset/dataset_prop.txt').exists():
        with open('./dataset/dataset_prop.txt', 'r') as file:
            data = file.read()
            # pylint: disable=W0123
            scalar_dict = eval(data)
    else:
        scalar_dict = {}

    if name in scalar_dict.keys():
        prop = scalar_dict[name]['prop']
        scaler = StandardScaler(scalar_dict[name]['scaler.means'],
                                scalar_dict[name]['scaler.stds'])
    else:
        print('No dataset property is specified, so no property reading is performed')
        prop = "None"
        scaler = None

    if os.path.exists(save_path):
        cached_data = np.load(save_path, allow_pickle=True)
    else:
        cached_data = preprocess(path,
                                 preprocess_workers,
                                 niggli=niggli,
                                 primitive=primitive,
                                 graph_method=graph_method,
                                 prop_list=[prop],
                                 nrows=nrows)

        np.save(save_path, cached_data)

    for idx in range(len(cached_data)):
        data_dict = cached_data[idx]
        (frac_coords, atom_types, lengths, angles, _, _,
         num_atoms) = data_dict['graph_arrays']

        indices = np.arange(num_atoms)
        dst, src = np.meshgrid(indices, indices)
        src = src.reshape(-1)
        dst = dst.reshape(-1)

        x.append(atom_types.reshape(-1, 1))
        frac_coord_list.append(frac_coords)
        edge_index.append(np.array([src, dst]))
        edge_attr.append(np.ones((num_atoms * num_atoms, 3)))
        lengths_list.append(lengths)
        angles_list.append(angles)
        if scaler is not None:
            labels.append(scaler.transform(data_dict[prop]))
        else:
            labels.append(0.0)

    return x, frac_coord_list, edge_attr, edge_index, lengths_list, angles_list, labels
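
# A minimal usage sketch (illustrative; the paths follow the defaults in
# config.yaml and the dataset layout described in the README):
#
#     x, frac, e_attr, e_idx, lengths, angles, labels = fullconnect_dataset(
#         name='mp_20',
#         path='./dataset/mp_20/train.csv',
#         save_path='./dataset/mp_20/train.npy')
#     # x[0]: (num_atoms, 1) atom types; e_idx[0]: [src, dst] of the full graph.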
diff --git a/MindChemistry/applications/diffcsp/evaluate.py b/MindChem/applications/diffcsp/evaluate.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/evaluate.py
rename to MindChem/applications/diffcsp/evaluate.py
diff --git a/MindChemistry/applications/diffcsp/models/cspnet.py b/MindChem/applications/diffcsp/models/cspnet.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/models/cspnet.py
rename to MindChem/applications/diffcsp/models/cspnet.py
diff --git a/MindChemistry/applications/diffcsp/models/diff_utils.py b/MindChem/applications/diffcsp/models/diff_utils.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/models/diff_utils.py
rename to MindChem/applications/diffcsp/models/diff_utils.py
diff --git a/MindChemistry/applications/diffcsp/models/diffusion.py b/MindChem/applications/diffcsp/models/diffusion.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/models/diffusion.py
rename to MindChem/applications/diffcsp/models/diffusion.py
diff --git a/MindChemistry/applications/diffcsp/models/infer_utils.py b/MindChem/applications/diffcsp/models/infer_utils.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/models/infer_utils.py
rename to MindChem/applications/diffcsp/models/infer_utils.py
diff --git a/MindChemistry/applications/diffcsp/models/train_utils.py b/MindChem/applications/diffcsp/models/train_utils.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/models/train_utils.py
rename to MindChem/applications/diffcsp/models/train_utils.py
diff --git a/MindChemistry/applications/diffcsp/requirement.txt b/MindChem/applications/diffcsp/requirement.txt
similarity index 100%
rename from MindChemistry/applications/diffcsp/requirement.txt
rename to MindChem/applications/diffcsp/requirement.txt
diff --git a/MindChem/applications/diffcsp/train.py b/MindChem/applications/diffcsp/train.py
new file mode 100644
index 000000000..fd21606a6
--- /dev/null
+++ b/MindChem/applications/diffcsp/train.py
@@ -0,0 +1,208 @@
# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train file"""
import os
import time
import logging
import argparse
import yaml
import numpy as np
import mindspore as ms
from mindspore import nn, set_seed
from mindspore.amp import all_finite
from mindchemistry.graph.loss import L2LossMask
from models.cspnet import CSPNet
from models.diffusion import CSPDiffusion
from models.train_utils import LossRecord
from data.dataset import fullconnect_dataset
from data.crysloader import Crysloader as DataLoader

logging.basicConfig(level=logging.INFO)

def parse_args():
    '''Parse input args'''
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='config.yaml', help="The config file path")
    parser.add_argument('--device_id', type=int, default=0,
                        help="ID of the target device")
    parser.add_argument('--device_target', type=str, default='Ascend', choices=["GPU", "Ascend"],
                        help="The target device to run, support 'Ascend', 'GPU'")
    input_args = parser.parse_args()
    return input_args

def main():
    args = parse_args()
    ms.set_context(device_target=args.device_target, device_id=args.device_id)

    with open(args.config, 'r') as stream:
        config = yaml.safe_load(stream)

    ckpt_dir = config['train']["ckpt_dir"]

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    set_seed(config['train']["seed"])

    batch_size_max = config['train']['batch_size']

    cspnet = CSPNet(num_layers=config['model']['num_layers'], hidden_dim=config['model']['hidden_dim'],
                    num_freqs=config['model']['num_freqs'])

    if os.path.exists(config['checkpoint']['last_path']):
        logging.info("loading from existing checkpoint ...")
        param_dict = ms.load_checkpoint(config['checkpoint']['last_path'])
        ms.load_param_into_net(cspnet, param_dict)
        logging.info("finished loading from existing checkpoint")
    else:
        logging.info("Starting new training process")

    diffcsp = CSPDiffusion(cspnet)

    model_parameters = filter(lambda p: p.requires_grad, diffcsp.get_parameters())
    params = sum(np.prod(p.shape) for p in model_parameters)
    logging.info("The model you built has %s parameters.", params)

    optimizer = nn.Adam(params=diffcsp.trainable_params())
    loss_func_mse = L2LossMask(reduction='mean')

    def forward(atom_types_step, frac_coords_step, _, lengths_step, angles_step, edge_index_step, batch_node2graph, \
                node_mask_step, edge_mask_step, batch_mask, node_num_valid, batch_size_valid):
        """forward"""
        pred_l, rand_l, pred_x, tar_x = diffcsp(batch_size_valid, atom_types_step, lengths_step,
                                                angles_step, frac_coords_step, batch_node2graph, edge_index_step,
                                                node_mask_step, edge_mask_step, batch_mask)
        mseloss_l = loss_func_mse(pred_l, rand_l, mask=batch_mask,
                                  num=batch_size_valid)
        mseloss_x = loss_func_mse(pred_x, tar_x, mask=node_mask_step, num=node_num_valid)
        mseloss = mseloss_l + mseloss_x

        return mseloss, mseloss_l, mseloss_x

    backward = ms.value_and_grad(forward, None, weights=diffcsp.trainable_params(), has_aux=True)
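
    # Note on the construction above: with `has_aux=True`, ms.value_and_grad
    # differentiates only the first output of `forward` (the total loss);
    # `mseloss_l` and `mseloss_x` are carried along purely for logging, so a
    # call unpacks as `(mseloss, mseloss_l, mseloss_x), grads = backward(...)`.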
    @ms.jit
    def train_step(atom_types_step, frac_coords_step, property_step, lengths_step, angles_step,
                   edge_index_step, batch_node2graph, node_mask_step, edge_mask_step, batch_mask,
                   node_num_valid, batch_size_valid):
        """train step"""
        (mseloss, mseloss_l, mseloss_x), grads = backward(atom_types_step, frac_coords_step, property_step,
                                                          lengths_step, angles_step, edge_index_step, batch_node2graph,
                                                          node_mask_step, edge_mask_step, batch_mask, node_num_valid,
                                                          batch_size_valid)

        is_finite = all_finite(grads)
        if is_finite:
            optimizer(grads)

        return mseloss, is_finite, mseloss_l, mseloss_x

    @ms.jit
    def eval_step(atom_types_step, frac_coords_step, property_step, lengths_step, angles_step,
                  edge_index_step, batch_node2graph,
                  node_mask_step, edge_mask_step, batch_mask, node_num_valid, batch_size_valid):
        """eval step"""
        mseloss, mseloss_l, mseloss_x = forward(atom_types_step, frac_coords_step, property_step, lengths_step,
                                                angles_step, edge_index_step, batch_node2graph,
                                                node_mask_step, edge_mask_step, batch_mask, node_num_valid,
                                                batch_size_valid)
        return mseloss, mseloss_l, mseloss_x

    epoch = 0
    epoch_size = config['train']["epoch_size"]

    logging.info("Start to initialise train_loader")
    train_dataset = fullconnect_dataset(name=config['dataset']["data_name"], path=config['dataset']["train"]["path"],
                                        save_path=config['dataset']["train"]["save_path"])
    train_loader = DataLoader(batch_size_max, *train_dataset, shuffle_dataset=True)
    logging.info("Start to initialise eval_loader")
    val_dataset = fullconnect_dataset(name=config['dataset']["data_name"], path=config['dataset']["val"]["path"],
                                      save_path=config['dataset']["val"]["save_path"])
    eval_loader = DataLoader(batch_size_max, *val_dataset,
                             dynamic_batch_size=False, shuffle_dataset=True)

    while epoch < epoch_size:
        epoch_starttime = time.time()

        train_mseloss_record = LossRecord()
        eval_mseloss_record = LossRecord()

        #################################################### train ###################################################
        logging.info("+++++++++++++++ start training +++++++++++++++++++++")
        diffcsp.set_train(True)

        starttime = time.time()
        record_iter = 0
        for atom_types_batch, frac_coords_batch, property_batch, lengths_batch, angles_batch,\
            edge_index_batch, batch_node2graph_, node_mask_batch, edge_mask_batch, batch_mask_batch,\
            node_num_valid_, batch_size_valid_ in train_loader:

            result = train_step(atom_types_batch, frac_coords_batch, property_batch,
                                lengths_batch, angles_batch, edge_index_batch, batch_node2graph_,
                                node_mask_batch, edge_mask_batch, batch_mask_batch, node_num_valid_,
                                batch_size_valid_)

            mseloss_step, _, mseloss_l_, mseloss_x_ = result

            if record_iter % 50 == 0:
                logging.info("==============================step: %s ,epoch: %s", train_loader.step - 1, epoch)
                logging.info("learning rate: %s", optimizer.learning_rate.value())
                logging.info("train mse loss: %s", mseloss_step)
                logging.info("train mse_lattice loss: %s", mseloss_l_)
                logging.info("train mse_coords loss: %s", mseloss_x_)
                starttime0 = starttime
                starttime = time.time()
                logging.info("training time: %s", starttime - starttime0)

            record_iter += 1

            train_mseloss_record.update(mseloss_step)
+ + #################################################### finish train ######################################## + epoch_endtime = time.time() + logging.info("epoch %s running time: %s", epoch, epoch_endtime - epoch_starttime) + logging.info("epoch %s average train mse loss: %s", epoch, train_mseloss_record.avg) + + ms.save_checkpoint(diffcsp.decoder, config['checkpoint']['last_path']) + + if epoch % 5 == 0: + #################################################### validation ########################################## + logging.info("+++++++++++++++ start validation +++++++++++++++++++++") + diffcsp.set_train(False) + + starttime = time.time() + for atom_types_batch, frac_coords_batch, property_batch, lengths_batch, angles_batch,\ + edge_index_batch, batch_node2graph_, node_mask_batch, edge_mask_batch, batch_mask_batch,\ + node_num_valid_, batch_size_valid_ in eval_loader: + + result_e = eval_step(atom_types_batch, frac_coords_batch, property_batch, + lengths_batch, angles_batch, edge_index_batch, batch_node2graph_, + node_mask_batch, edge_mask_batch, batch_mask_batch, node_num_valid_, + batch_size_valid_) + + mseloss_step, mseloss_l_, mseloss_x_ = result_e + + eval_mseloss_record.update(mseloss_step) + + #################################################### finish validation ################################# + + starttime0 = starttime + starttime = time.time() + logging.info("validation time: %s", starttime - starttime0) + logging.info("epoch %s average validation mse loss: %s", epoch, eval_mseloss_record.avg) + + epoch = epoch + 1 + +if __name__ == '__main__': + main() -- Gitee From 36f291b9b0bdedf1bd0ca9b326f7a6532c638783 Mon Sep 17 00:00:00 2001 From: birfied Date: Thu, 18 Sep 2025 11:01:03 +0800 Subject: [PATCH 02/21] move files to MindChem --- MindChemistry/applications/diffcsp/README.md | 127 ------- .../applications/diffcsp/compute_metric.py | 327 ------------------ .../applications/diffcsp/config.yaml | 41 --- .../applications/diffcsp/data/dataset.py | 135 -------- MindChemistry/applications/diffcsp/train.py | 208 ----------- 5 files changed, 838 deletions(-) delete mode 100644 MindChemistry/applications/diffcsp/README.md delete mode 100644 MindChemistry/applications/diffcsp/compute_metric.py delete mode 100644 MindChemistry/applications/diffcsp/config.yaml delete mode 100644 MindChemistry/applications/diffcsp/data/dataset.py delete mode 100644 MindChemistry/applications/diffcsp/train.py diff --git a/MindChemistry/applications/diffcsp/README.md b/MindChemistry/applications/diffcsp/README.md deleted file mode 100644 index 858aa417c..000000000 --- a/MindChemistry/applications/diffcsp/README.md +++ /dev/null @@ -1,127 +0,0 @@ - -# 模型名称 - -> DiffCSP - -## 介绍 - -> DiffCSP是一种基于扩散模型的深度学习框架,用于解决晶体结构预测这一基础科学难题。其核心思想是将寻找稳定晶体结构的过程转化为一个生成问题:模型通过学习海量已知晶体数据中的分布规律,能够仅根据材料的化学成分(原子种类与比例),直接、快速地生成合理的三维原子结构(包括晶格和原子坐标)。与传统依赖大量量子力学计算的方法相比,DiffCSP的关键创新在于采用了​​SE(3)-等变图神经网络​​并融入了​​周期性边界条件​​,确保了生成的结构严格遵守物理对称性,从而能以极高的效率探索材料的多态性,为新材料的加速发现与设计提供了强大工具。 - -## 环境要求 - -> 1. 安装`mindspore(2.3.0)` -> 2. 安装依赖包:`pip install -r requirement.txt` - -## 快速入门 - -> 1. 将Mindchemistry/mindchemistry文件包下载到当前目录 -> 2. 在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/)下载相应的数据集 -> 3. 安装依赖包:`pip install -r requirement.txt` -> 4. 训练命令: `python train.py` -> 5. 预测命令: `python evaluate.py` -> 6. 评估命令: `python compute_metric.py` -> 7. 
评估结果放在`config.yaml`中指定的`metric_dir`路径的json文件中 - -### 代码目录结构 - -```txt -diffcsp - │ README.md README文件 - │ config.yaml 配置文件 - │ train.py 训练启动脚本 - │ evaluate.py 推理启动脚本 - │ compute_metric.py 评估启动脚本 - │ requirement.txt 环境依赖 - │ - └─data - data_utils.py 数据集处理工具 - dataset.py 读取数据集 - crysloader.py 数据集载入器 - └─models - cspnet.py 基于图神经网络的去噪器模块 - diffusion.py 扩散模型模块 - diff_utils.py 工具模块 - infer_utils.py 推理工具模块 - train_utils.py 训练工具模块 - -``` - -## 下载数据集 - -在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/)中下载相应的数据集文件夹和dataset_prop.txt数据集属性文件放置于当前路径的dataset文件夹下(如果没有需要自己手动创建),文件路径参考: - -```txt -diffcsp - ... - └─dataset - perov_5 钙钛矿数据集 - carbon_24 碳晶体数据集 - mp_20 晶胞内原子数最多为20的MP数据集 - mpts_52 晶胞内原子数最多为52的MP数据集 - dataset_prop.txt 数据集属性文件 - ... -``` - -## 训练过程 - -### 训练 - -将Mindchemistry/mindchemistry文件包下载到当前目录; - -更改config文件,设置训练参数: -> 1. 设置训练的dataset,见dataset字段 -> 2. 设置去噪器模型的配置,见model字段 -> 3. 设置训练保存的权重文件,更改train.ckpt_dir文件夹名称和checkpoint.last_path权重文件名称 -> 4. 其它训练设置见train字段 - -```bash -pip install -r requirement.txt -python train.py -``` - -### 推理 - -将权重的path写入config文件的checkpoint.last_path中。预训练模型可以从[预训练模型链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/pre-train)中获取。 - -更改config文件中的test字段来更改推理参数,特别是test.num_eval,它**决定了对于每个组分生成多少个样本**,对于后续的评估阶段很重要。 - -```bash -python evaluate.py -``` - -推理得到的晶体将保存在test.eval_save_path指定的文件中 - -文件中存储的内容为python字典,格式为: - -```python -{ - 'pred': [ - [晶体A sample 1, 晶体A sample 2, 晶体A sample 3, ... 晶体A sample num_eval], - [晶体B sample 1, 晶体B sample 2, 晶体B sample 3, ... 晶体B sample num_eval] - ... - ] - 'gt': [ - 晶体A ground truth, - 晶体B ground truth, - ... - ] -} -``` - -### 评估 - -将推理得到的晶体文件的path写入config文件的test.eval_save_path中; - -确保num_evals与进行推理时设置的对于每个组分生成样本的数量一致或更小。比如进行推理时,num_evals设置为1,那么评估时,num_evals只能设置为1;推理时,num_evals设置为20,那么评估时,num_evals可以设置为1-20的数字来进行评估。 - -更改config文件中的test.metric_dir字段来设置评估结果的保存路径 - -```bash -python compute_metric.py -``` - -得到的评估结果文件示例: - -```json -{"match_rate": 0.985997357992074, "rms_dist": 0.013073775170360118} -``` diff --git a/MindChemistry/applications/diffcsp/compute_metric.py b/MindChemistry/applications/diffcsp/compute_metric.py deleted file mode 100644 index df5539aad..000000000 --- a/MindChemistry/applications/diffcsp/compute_metric.py +++ /dev/null @@ -1,327 +0,0 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -"""compute metric file""" -import itertools -import json -import os -import pickle -from collections import Counter -from pathlib import Path -import argparse -import yaml - -import numpy as np -from matminer.featurizers.composition.composite import ElementProperty -from matminer.featurizers.site.fingerprint import CrystalNNFingerprint -from p_tqdm import p_map -from pymatgen.analysis.structure_matcher import StructureMatcher -from pymatgen.core.composition import Composition -from pymatgen.core.lattice import Lattice -from pymatgen.core.structure import Structure -import smact -from smact.screening import pauling_test -from tqdm import trange - -from models.infer_utils import chemical_symbols - -matcher = StructureMatcher(stol=0.5, angle_tol=10, ltol=0.3) -crystalnn_fp = CrystalNNFingerprint.from_preset("ops") -comp_fp = ElementProperty.from_preset('magpie') - - -def smact_validity(comp, count, use_pauling_test=True, include_alloys=True): - """Smact validity. See details in the paper Crystal Diffution Variational Autoencoder and - its codebase. - """ - elem_symbols = tuple([chemical_symbols[elem] for elem in comp]) - space = smact.element_dictionary(elem_symbols) - smact_elems = [e[1] for e in space.items()] - electronegs = [e.pauling_eneg for e in smact_elems] - ox_combos = [e.oxidation_states for e in smact_elems] - if len(set(elem_symbols)) == 1: - return True - if include_alloys: - is_metal_list = [elem_s in smact.metals for elem_s in elem_symbols] - if all(is_metal_list): - return True - - threshold = np.max(count) - oxn = 1 - for oxc in ox_combos: - oxn *= len(oxc) - if oxn > 1e7: - return False - for ox_states in itertools.product(*ox_combos): - stoichs = [(c,) for c in count] - # Test for charge balance - cn_e, _ = smact.neutral_ratios(ox_states, - stoichs=stoichs, - threshold=threshold) - # Electronegativity test - if cn_e: - if use_pauling_test: - try: - electroneg_ok = pauling_test(ox_states, electronegs) - except TypeError: - # if no electronegativity data, assume it is okay - electroneg_ok = True - else: - electroneg_ok = True - if electroneg_ok: - return True - return False - - -def structure_validity(crystal, cutoff=0.5): - """Structure validity. See details in the paper Crystal Diffution Variational Autoencoder and - its codebase. - """ - dist_mat = crystal.distance_matrix - # Pad diagonal with a large number - dist_mat = dist_mat + np.diag(np.ones(dist_mat.shape[0]) * (cutoff + 10.)) - if dist_mat.min() < cutoff or crystal.volume < 0.1 or max( - crystal.lattice.abc) > 40: - return False - - return True - -class Crystal: - """Strict crystal validity. See details in the paper CDVAE `Crystal - Diffution Variational Autoencoder` and - its codebase. We adopt the same evaluation metric criteria as CDVAE. 
- """ - - def __init__(self, crys_array_dict): - self.frac_coords = crys_array_dict['frac_coords'] - self.atom_types = crys_array_dict['atom_types'] - self.lengths = crys_array_dict['lengths'] - self.angles = crys_array_dict['angles'] - self.dict = crys_array_dict - if len(self.atom_types.shape) > 1: - self.dict['atom_types'] = np.argmax(self.atom_types, axis=-1) + 1 - self.atom_types = np.argmax(self.atom_types, axis=-1) + 1 - - self.get_structure() - self.get_composition() - self.get_validity() - self.get_fingerprints() - - def get_structure(self): - """get_structure - """ - if min(self.lengths.tolist()) < 0: - self.constructed = False - self.invalid_reason = 'non_positive_lattice' - if np.isnan(self.lengths).any() or np.isnan( - self.angles).any() or np.isnan(self.frac_coords).any(): - self.constructed = False - self.invalid_reason = 'nan_value' - else: - try: - self.structure = Structure(lattice=Lattice.from_parameters( - *(self.lengths.tolist() + self.angles.tolist())), - species=self.atom_types, - coords=self.frac_coords, - coords_are_cartesian=False) - self.constructed = True - # pylint: disable=W0703 - except Exception: - self.constructed = False - self.invalid_reason = 'construction_raises_exception' - if self.structure.volume < 0.1: - self.constructed = False - self.invalid_reason = 'unrealistically_small_lattice' - - def get_composition(self): - """get_composition - """ - elem_counter = Counter(self.atom_types) # pylint: disable=E1121 - composition = [(elem, elem_counter[elem]) - for elem in sorted(elem_counter.keys())] - elems, counts = list(zip(*composition)) - counts = np.array(counts) - counts = counts / np.gcd.reduce(counts) - self.elems = elems - self.comps = tuple(counts.astype('int').tolist()) - - def get_validity(self): - """get_validity - """ - self.comp_valid = smact_validity(self.elems, self.comps) - if self.constructed: - self.struct_valid = structure_validity(self.structure) - else: - self.struct_valid = False - self.valid = self.comp_valid and self.struct_valid - - def get_fingerprints(self): - """get_fingerprints - """ - elem_counter = Counter(self.atom_types) # pylint: disable=E1121 - comp = Composition(elem_counter) - self.comp_fp = comp_fp.featurize(comp) - try: - site_fps = [ - crystalnn_fp.featurize(self.structure, i) - for i in range(len(self.structure)) - ] - # pylint: disable=W0703 - except Exception: - # counts crystal as invalid if fingerprint cannot be constructed. - self.valid = False - self.comp_fp = None - self.struct_fp = None - return - self.struct_fp = np.array(site_fps).mean(axis=0) - - -def get_rms(pred_struc_list, gt_struc: Structure, num_eval, np_list): - """Calculate the rms distance between the ground truth and predicted crystal structures. - - Args: - pred_struc_list (List[Structure]): The crystals generated by diffution model - in the form of Structure. - gt_struc (Structure): The ground truth crystal. - num_eval (int): Specify that the first N items in the predicted List of crystal structures - participate in the evaluationo. - np_list (List[Dict]): The crystals generated by diffution model in the form of Dict. 
- """ - - def process_one(pred_struc: Structure): - try: - if not pred_struc.is_valid(): - return None - rms_dist = matcher.get_rms_dist(pred_struc, gt_struc) - rms_dist = None if rms_dist is None else rms_dist[0] - tune_rms = rms_dist - # pylint: disable=W0703 - except Exception: - tune_rms = None - return tune_rms - - min_rms = None - min_struc = None - for i, struct in enumerate(pred_struc_list): - if i == num_eval: - break - rms = process_one(struct) - if rms is not None and (min_rms is None or min_rms > rms): - min_rms = rms - min_struc = np_list[i] - return min_rms, min_struc - - -def get_struc_from_np_list(np_list): - """convert the crystal in the form of Dict to pymatgen.Structure - """ - result = [] - for cry_array in np_list: - try: - struct = Structure(lattice=Lattice.from_parameters( - *(cry_array['lengths'].tolist() + - cry_array['angles'].tolist())), - species=cry_array['atom_types'], - coords=cry_array['frac_coords'], - coords_are_cartesian=False) - # pylint: disable=W0703 - except Exception: - print('Warning: One anomalous crystal structure has captured and removed. ') - struct = None - - result.append(struct) - return result - -def main(args): - """main - """ - with open(args.config, 'r') as stream: - config = yaml.safe_load(stream) - - eval_file = config['test']['eval_save_path'] - num_eval = config['test']['num_eval'] - output_path = config['test']['metric_dir'] - - with open(eval_file, 'rb') as f: - eval_dict = pickle.load(f) - - pred_list = eval_dict['pred'] - gt_list = eval_dict['gt'] - gt_list = get_struc_from_np_list(gt_list) - rms = [] - - # calculate rmsd - for i in trange(len(gt_list)): - pred_struc = get_struc_from_np_list(pred_list[i]) - gt_struc = gt_list[i] - rms_single, struc_single = get_rms(pred_struc, gt_struc, num_eval, - pred_list[i]) - rms.append((rms_single, struc_single)) - - rms, struc_list = zip(*rms) - - # Remove the ones with RMSD as None, and store the valid structures in the list valid_crys. - rms_np = [] - valid_crys = [] - for i, rms_per in enumerate(rms): - if rms_per is not None: - rms_np.append(rms_per) - valid_crys.append(struc_list[i]) - - # Conduct rigorous structural verification, specifically through verification using the Crystal class. - print('Using the Crystal class for validity checks') - valid_list = p_map(lambda x: Crystal(x).valid, valid_crys) - rms_np_strict = [] - for i, is_valid in enumerate(valid_list): - if is_valid: - rms_np_strict.append(rms_np[i]) - - rms_np = np.array(rms_np_strict) - rms_valid_index = np.array([x is not None for x in rms_np_strict]) - - match_rate = rms_valid_index.sum() / len(gt_list) - rms = rms_np[rms_valid_index].mean() - - print('match_rate: ', match_rate) - print('rms: ', rms) - - all_metrics = {'match_rate': match_rate, 'rms_dist': rms} - - if Path(output_path).exists(): - metrics_out_file = f'eval_metrics_{num_eval}.json' - metrics_out_file = os.path.join(output_path, metrics_out_file) - - # only overwrite metrics computed in the new run. 
- if Path(metrics_out_file).exists(): - with open(metrics_out_file, 'r') as f: - written_metrics = json.load(f) - if isinstance(written_metrics, dict): - written_metrics.update(all_metrics) - else: - with open(metrics_out_file, 'w') as f: - json.dump(all_metrics, f) - if isinstance(written_metrics, dict): - with open(metrics_out_file, 'w') as f: - json.dump(written_metrics, f) - else: - with open(metrics_out_file, 'w') as f: - json.dump(all_metrics, f) - else: - print('Warning: The metric result file path is not specified') - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--config', default='config.yaml') - main_args = parser.parse_args() - main(main_args) diff --git a/MindChemistry/applications/diffcsp/config.yaml b/MindChemistry/applications/diffcsp/config.yaml deleted file mode 100644 index 6c2899a15..000000000 --- a/MindChemistry/applications/diffcsp/config.yaml +++ /dev/null @@ -1,41 +0,0 @@ -dataset: - data_name: 'mp_20' - train: - path: './dataset/mp_20/train.csv' - save_path: './dataset/mp_20/train.npy' - val: - path: './dataset/mp_20/val.csv' - save_path: './dataset/mp_20/val.npy' - test: - path: './dataset/mp_20/test.csv' - save_path: './dataset/mp_20/test.npy' - -model: - # For dataset carbon, mp, mpts - hidden_dim: 512 - num_layers: 6 - num_freqs: 128 - # # For dataset perov - # hidden_dim: 256 - # num_layers: 4 - # num_freqs: 10 - -train: - ckpt_dir: "./ckpt/mp_20" - # 3500, 4000, 1000, 1000 epochs for Perov-5, Carbon-24, MP-20 and MPTS-52 respectively. - epoch_size: 1000 - # 512, 512, 128, 128 for Perov-5, Carbon-24, MP-20 and MPTS-52 respectively. - batch_size: 256 - seed: 1234 - -checkpoint: - last_path: "./ckpt/mp_20/last_test.ckpt" - -test: - # 1024 for perov, 512 for carbon and mp, 256 for mpts - batch_size: 512 - num_eval: 1 - # 1e-5 for mp and mpts, 5e-7 for perov, 5e-6 for carbon num_eval=1 and 5e-7 for carbon num_eval=20 - step_lr: 1e-5 - eval_save_path: './ckpt/mp_20/predict_crys.pkl' - metric_dir: './ckpt/mp_20/' diff --git a/MindChemistry/applications/diffcsp/data/dataset.py b/MindChemistry/applications/diffcsp/data/dataset.py deleted file mode 100644 index 964a9a957..000000000 --- a/MindChemistry/applications/diffcsp/data/dataset.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""dataset file""" -import os -from pathlib import Path - -import numpy as np - -from data.data_utils import StandardScaler, preprocess - -def fullconnect_dataset(name, - path, - niggli=True, - primitive=False, - graph_method='none', - preprocess_workers=30, - save_path='', - nrows=-1): - """ - Read crystal data from a CSV file and convert each into a fully connected graph, - where the nodes represent atoms within the unit cell and the edges connect every pair of nodes. - - Args: - name (str): The name of dataset, mainly used to read the dataset - property in './dataset/dataset_prop.txt'. 
- It doesn't matter for crystal structure prediction task. - Choices: [perov_5, carbon_24, mp_20, mpts_52]. - Users can also create custom datasets, by modify the - './dataset/dataset_prop.txt'. - path (str): The path of csv file of dataset. - niggli (bool): Whether to use niggli algorithom to - preprocess the choice of lattice. Default: - ``True``. - primitive (bool): Whether to represent the crystal in primitive cell. Default: - ``False``. - graph_method (str): If 'crystalnn', construct the graph by crystalnn - algorithm, mainly effect the construct of edges. - If 'none', don't construct any edge. Default: ``none``. - preprocess_workers (int): The numbers of cpus used for - preprocessing the crystals. Default: ``None``. - save_path (str): The path for saving the preprocessed data, - aiming to load the dataset more quickly next time. - nrows (int): If nrows > 0, read the first 'nrows' lines of csv file. - If nrows = -1, read the whole csv file. - This arg is mainly for debugging to quickly load a few crystals. - Returns: - x (list): List of Atom types. Shape of each element i.e. numpy array: (num_atoms, 1) - frac_coord_list (list): List of Fractional Coordinates of atoms. - Shape of each element i.e. numpy array: (num_atoms, 3) - edge_attr (list): List of numpy arrays filled with ones, - just used to better construct the dataloader, - without numerical significance. Shape of each element - i.e. numpy array: (num_edges, 3) - edge_index (list): List of index of the beginning and end - of edges. Each element is composed as [src, dst], where - src and dst is numpy arrays with Shape (num_edges,). - lengths_list (list): List of lengths of lattice. Shape of - each element i.e. numpy array: (3,) - angles_list (list): List of angles of lattice. Shape of - each element i.e. numpy array: (3,) - labels (list): List of property of crystal. Shape of - each element i.e. 
numpy array: (1,) - """ - x = [] - frac_coord_list = [] - edge_index = [] - edge_attr = [] - labels = [] - lengths_list = [] - angles_list = [] - - if Path('./dataset/dataset_prop.txt').exists(): - with open('./dataset/dataset_prop.txt', 'r') as file: - data = file.read() - # pylint: disable=W0123 - scalar_dict = eval(data) - else: - scalar_dict = {} - - if name in scalar_dict.keys(): - prop = scalar_dict[name]['prop'] - scaler = StandardScaler(scalar_dict[name]['scaler.means'], - scalar_dict[name]['scaler.stds']) - else: - print('No dataset property is specified, so no property reading is performed') - prop = "None" - scaler = None - - if os.path.exists(save_path): - cached_data = np.load(save_path, allow_pickle=True) - else: - cached_data = preprocess(path, - preprocess_workers, - niggli=niggli, - primitive=primitive, - graph_method=graph_method, - prop_list=[prop], - nrows=nrows) - - np.save(save_path, cached_data) - - for idx in range(len(cached_data)): - data_dict = cached_data[idx] - (frac_coords, atom_types, lengths, angles, _, _, - num_atoms) = data_dict['graph_arrays'] - - indices = np.arange(num_atoms) - dst, src = np.meshgrid(indices, indices) - src = src.reshape(-1) - dst = dst.reshape(-1) - - x.append(atom_types.reshape(-1, 1)) - frac_coord_list.append(frac_coords) - edge_index.append(np.array([src, dst])) - edge_attr.append(np.ones((num_atoms * num_atoms, 3))) - lengths_list.append(lengths) - angles_list.append(angles) - if scaler is not None: - labels.append(scaler.transform(data_dict[prop])) - else: - labels.append(0.0) - - return x, frac_coord_list, edge_attr, edge_index, lengths_list, angles_list, labels diff --git a/MindChemistry/applications/diffcsp/train.py b/MindChemistry/applications/diffcsp/train.py deleted file mode 100644 index fd21606a6..000000000 --- a/MindChemistry/applications/diffcsp/train.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -"""train file""" -import os -import time -import logging -import argparse -import yaml -import numpy as np -import mindspore as ms -from mindspore import nn, set_seed -from mindspore.amp import all_finite -from mindchemistry.graph.loss import L2LossMask -from models.cspnet import CSPNet -from models.diffusion import CSPDiffusion -from models.train_utils import LossRecord -from data.dataset import fullconnect_dataset -from data.crysloader import Crysloader as DataLoader - -logging.basicConfig(level=logging.INFO) - -def parse_args(): - '''Parse input args''' - parser = argparse.ArgumentParser() - parser.add_argument('--config', default='config.yaml', help="The config file path") - parser.add_argument('--device_id', type=int, default=0, - help="ID of the target device") - parser.add_argument('--device_target', type=str, default='Ascend', choices=["GPU", "Ascend"], - help="The target device to run, support 'Ascend', 'GPU'") - input_args = parser.parse_args() - return input_args - -def main(): - args = parse_args() - ms.set_context(device_target=args.device_target, device_id=args.device_id) - - with open(args.config, 'r') as stream: - config = yaml.safe_load(stream) - - ckpt_dir = config['train']["ckpt_dir"] - - if not os.path.exists(ckpt_dir): - os.makedirs(ckpt_dir) - - set_seed(config['train']["seed"]) - - batch_size_max = config['train']['batch_size'] - - cspnet = CSPNet(num_layers=config['model']['num_layers'], hidden_dim=config['model']['hidden_dim'], - num_freqs=config['model']['num_freqs']) - - if os.path.exists(config['checkpoint']['last_path']): - logging.info("load from existing check point................") - param_dict = ms.load_checkpoint(config['checkpoint']['last_path']) - ms.load_param_into_net(cspnet, param_dict) - logging.info("finish load from existing checkpoint") - else: - logging.info("Starting new training process") - - diffcsp = CSPDiffusion(cspnet) - - model_parameters = filter(lambda p: p.requires_grad, diffcsp.get_parameters()) - params = sum(np.prod(p.shape) for p in model_parameters) - logging.info("The model you built has %s parameters.", params) - - optimizer = nn.Adam(params=diffcsp.trainable_params()) - loss_func_mse = L2LossMask(reduction='mean') - - def forward(atom_types_step, frac_coords_step, _, lengths_step, angles_step, edge_index_step, batch_node2graph, \ - node_mask_step, edge_mask_step, batch_mask, node_num_valid, batch_size_valid): - """forward""" - pred_l, rand_l, pred_x, tar_x = diffcsp(batch_size_valid, atom_types_step, lengths_step, - angles_step, frac_coords_step, batch_node2graph, edge_index_step, - node_mask_step, edge_mask_step, batch_mask) - mseloss_l = loss_func_mse(pred_l, rand_l, mask=batch_mask, num=batch_size_valid) - mseloss_x = loss_func_mse(pred_x, tar_x, mask=node_mask_step, num=node_num_valid) - mseloss = mseloss_l + mseloss_x - - return mseloss, mseloss_l, mseloss_x - - backward = ms.value_and_grad(forward, None, weights=diffcsp.trainable_params(), has_aux=True) - - @ms.jit - def train_step(atom_types_step, frac_coords_step, property_step, lengths_step, angles_step, - edge_index_step, batch_node2graph, node_mask_step, edge_mask_step, batch_mask, - node_num_valid, batch_size_valid): - """train step""" - (mseloss, mseloss_l, mseloss_x), grads = backward(atom_types_step, frac_coords_step, property_step, - lengths_step, angles_step, edge_index_step, batch_node2graph, - node_mask_step, edge_mask_step, batch_mask, node_num_valid, - batch_size_valid) - - 
is_finite = all_finite(grads) - if is_finite: - optimizer(grads) - - return mseloss, is_finite, mseloss_l, mseloss_x - - @ms.jit - def eval_step(atom_types_step, frac_coords_step, property_step, lengths_step, angles_step, - edge_index_step, batch_node2graph, - node_mask_step, edge_mask_step, batch_mask, node_num_valid, batch_size_valid): - """eval step""" - mseloss, mseloss_l, mseloss_x = forward(atom_types_step, frac_coords_step, property_step, lengths_step, - angles_step, edge_index_step, batch_node2graph, - node_mask_step, edge_mask_step, batch_mask, node_num_valid, - batch_size_valid) - return mseloss, mseloss_l, mseloss_x - - epoch = 0 - epoch_size = config['train']["epoch_size"] - - logging.info("Start to initialise train_loader") - train_datatset = fullconnect_dataset(name=config['dataset']["data_name"], path=config['dataset']["train"]["path"], - save_path=config['dataset']["train"]["save_path"]) - train_loader = DataLoader(batch_size_max, *train_datatset, shuffle_dataset=True) - logging.info("Start to initialise eval_loader") - val_datatset = fullconnect_dataset(name=config['dataset']["data_name"], path=config['dataset']["val"]["path"], - save_path=config['dataset']["val"]["save_path"]) - eval_loader = DataLoader(batch_size_max, *val_datatset, - dynamic_batch_size=False, shuffle_dataset=True) - - while epoch < epoch_size: - epoch_starttime = time.time() - - train_mseloss_record = LossRecord() - eval_mseloss_record = LossRecord() - - #################################################### train ################################################### - logging.info("+++++++++++++++ start traning +++++++++++++++++++++") - diffcsp.set_train(True) - - starttime = time.time() - record_iter = 0 - for atom_types_batch, frac_coords_batch, property_batch, lengths_batch, angles_batch,\ - edge_index_batch, batch_node2graph_, node_mask_batch, edge_mask_batch, batch_mask_batch,\ - node_num_valid_, batch_size_valid_ in train_loader: - - result = train_step(atom_types_batch, frac_coords_batch, property_batch, - lengths_batch, angles_batch, edge_index_batch, batch_node2graph_, - node_mask_batch, edge_mask_batch, batch_mask_batch, node_num_valid_, - batch_size_valid_) - - mseloss_step, _, mseloss_l_, mseloss_x_ = result - - if record_iter % 50 == 0: - logging.info("==============================step: %s ,epoch: %s", train_loader.step - 1, epoch) - logging.info("learning rate: %s", optimizer.learning_rate.value()) - logging.info("train mse loss: %s", mseloss_step) - logging.info("train mse_lattice loss: %s", mseloss_l_) - logging.info("train mse_coords loss: %s", mseloss_x_) - starttime0 = starttime - starttime = time.time() - logging.info("traning time: %s", starttime - starttime0) - - record_iter += 1 - - train_mseloss_record.update(mseloss_step) - - #################################################### finish train ######################################## - epoch_endtime = time.time() - logging.info("epoch %s running time: %s", epoch, epoch_endtime - epoch_starttime) - logging.info("epoch %s average train mse loss: %s", epoch, train_mseloss_record.avg) - - ms.save_checkpoint(diffcsp.decoder, config['checkpoint']['last_path']) - - if epoch % 5 == 0: - #################################################### validation ########################################## - logging.info("+++++++++++++++ start validation +++++++++++++++++++++") - diffcsp.set_train(False) - - starttime = time.time() - for atom_types_batch, frac_coords_batch, property_batch, lengths_batch, angles_batch,\ - edge_index_batch, 
batch_node2graph_, node_mask_batch, edge_mask_batch, batch_mask_batch,\ - node_num_valid_, batch_size_valid_ in eval_loader: - - result_e = eval_step(atom_types_batch, frac_coords_batch, property_batch, - lengths_batch, angles_batch, edge_index_batch, batch_node2graph_, - node_mask_batch, edge_mask_batch, batch_mask_batch, node_num_valid_, - batch_size_valid_) - - mseloss_step, mseloss_l_, mseloss_x_ = result_e - - eval_mseloss_record.update(mseloss_step) - - #################################################### finish validation ################################# - - starttime0 = starttime - starttime = time.time() - logging.info("validation time: %s", starttime - starttime0) - logging.info("epoch %s average validation mse loss: %s", epoch, eval_mseloss_record.avg) - - epoch = epoch + 1 - -if __name__ == '__main__': - main() -- Gitee From 461988f077689f5c81f341c248dc69868875d276 Mon Sep 17 00:00:00 2001 From: birfied Date: Thu, 18 Sep 2025 11:53:59 +0800 Subject: [PATCH 03/21] add e3nn --- mindscience/e3nn/__init__.py | 14 +- mindscience/e3nn/nn/__init__.py | 35 + mindscience/e3nn/nn/activation.py | 147 ++++ mindscience/e3nn/nn/batchnorm.py | 181 +++++ mindscience/e3nn/nn/fc.py | 110 +++ mindscience/e3nn/nn/gate.py | 182 +++++ mindscience/e3nn/nn/normact.py | 128 ++++ mindscience/e3nn/nn/one_hot.py | 238 +++++++ mindscience/e3nn/nn/scatter.py | 74 ++ mindscience/e3nn/o3/__init__.py | 51 ++ mindscience/e3nn/o3/irreps.py | 761 ++++++++++++++++++++ mindscience/e3nn/o3/norm.py | 81 +++ mindscience/e3nn/o3/rotation.py | 387 +++++++++++ mindscience/e3nn/o3/spherical_harmonics.py | 679 ++++++++++++++++++ mindscience/e3nn/o3/sub.py | 503 ++++++++++++++ mindscience/e3nn/o3/tensor_product.py | 768 +++++++++++++++++++++ mindscience/e3nn/o3/wigner.py | 336 +++++++++ mindscience/e3nn/utils/__init__.py | 26 + mindscience/e3nn/utils/batch_dot.py | 152 ++++ mindscience/e3nn/utils/func.py | 160 +++++ mindscience/e3nn/utils/initializer.py | 63 ++ mindscience/e3nn/utils/linalg.py | 33 + mindscience/e3nn/utils/ncon.py | 699 +++++++++++++++++++ mindscience/e3nn/utils/perm.py | 140 ++++ mindscience/e3nn/utils/radius.py | 248 +++++++ 25 files changed, 6191 insertions(+), 5 deletions(-) create mode 100644 mindscience/e3nn/nn/__init__.py create mode 100644 mindscience/e3nn/nn/activation.py create mode 100644 mindscience/e3nn/nn/batchnorm.py create mode 100644 mindscience/e3nn/nn/fc.py create mode 100644 mindscience/e3nn/nn/gate.py create mode 100644 mindscience/e3nn/nn/normact.py create mode 100644 mindscience/e3nn/nn/one_hot.py create mode 100644 mindscience/e3nn/nn/scatter.py create mode 100644 mindscience/e3nn/o3/__init__.py create mode 100644 mindscience/e3nn/o3/irreps.py create mode 100644 mindscience/e3nn/o3/norm.py create mode 100644 mindscience/e3nn/o3/rotation.py create mode 100644 mindscience/e3nn/o3/spherical_harmonics.py create mode 100644 mindscience/e3nn/o3/sub.py create mode 100644 mindscience/e3nn/o3/tensor_product.py create mode 100644 mindscience/e3nn/o3/wigner.py create mode 100644 mindscience/e3nn/utils/__init__.py create mode 100644 mindscience/e3nn/utils/batch_dot.py create mode 100644 mindscience/e3nn/utils/func.py create mode 100644 mindscience/e3nn/utils/initializer.py create mode 100644 mindscience/e3nn/utils/linalg.py create mode 100644 mindscience/e3nn/utils/ncon.py create mode 100644 mindscience/e3nn/utils/perm.py create mode 100644 mindscience/e3nn/utils/radius.py diff --git a/mindscience/e3nn/__init__.py b/mindscience/e3nn/__init__.py index 69a14b29e..5ba0a5f68 100644 --- 
a/mindscience/e3nn/__init__.py +++ b/mindscience/e3nn/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -""" -init -""" +"""init for e3 module""" +from .o3 import * +from .nn import * +from .utils import * -__all__ = [] \ No newline at end of file +__all__ = [] +__all__.extend(o3.__all__) +__all__.extend(nn.__all__) +__all__.extend(utils.__all__) diff --git a/mindscience/e3nn/nn/__init__.py b/mindscience/e3nn/nn/__init__.py new file mode 100644 index 000000000..17d4a7118 --- /dev/null +++ b/mindscience/e3nn/nn/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" +from .activation import Activation +from .gate import Gate +from .fc import FullyConnectedNet +from .normact import NormActivation +from .scatter import Scatter +from .one_hot import SoftOneHotLinspace, soft_one_hot_linspace, soft_unit_step, OneHot +from .batchnorm import BatchNorm + +__all__ = [ + "Activation", + "Gate", + "FullyConnectedNet", + "NormActivation", + "Scatter", + "SoftOneHotLinspace", + "soft_one_hot_linspace", + "soft_unit_step", + "OneHot", + "BatchNorm" +] \ No newline at end of file diff --git a/mindscience/e3nn/nn/activation.py b/mindscience/e3nn/nn/activation.py new file mode 100644 index 000000000..8553fd84e --- /dev/null +++ b/mindscience/e3nn/nn/activation.py @@ -0,0 +1,147 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""activation"""
+import numpy as np
+
+from mindspore import Tensor, nn, ops, float32
+from ..o3.irreps import Irreps
+
+identity = ops.Identity()
+NTOL = 1e-5
+
+
+def _moment(f, n, dtype=float32):
+    x = Tensor(np.random.randn(1000000), dtype=dtype)
+    y = f(x).pow(n).mean().pow(-0.5)
+
+    return y
+
+
+def _parity_function(f, dtype=float32):
+    x = Tensor(np.linspace(.0, 10., 256), dtype=dtype)
+    y1, y2 = f(x).asnumpy(), f(-x).asnumpy()
+    if np.max(np.abs(y1 - y2)) < NTOL:
+        return 1
+    if np.max(np.abs(y1 + y2)) < NTOL:
+        return -1
+    return 0
+
+
+class _Normalize(nn.Cell):
+    """_Normalize"""
+
+    def __init__(self, f, dtype=float32):
+        super().__init__()
+        self.f = f
+        self.factor = _moment(f, 2, dtype)
+        if ops.abs(self.factor - 1.) < 1e-4:
+            self._is_id = True
+        else:
+            self._is_id = False
+
+    def construct(self, x):
+        if self._is_id:
+            return self.f(x)
+        return self.f(x).mul(self.factor)
+
+
+class Activation(nn.Cell):
+    r"""
+    Activation function for scalar-tensors. The parities of irreps may be changed according to the parity of each
+    activation function.
+    Odd scalars require the corresponding activation functions to be odd or even.
+
+    Args:
+        irreps_in (Union[str, Irrep, Irreps]): the input irreps.
+        acts (List[Func]): a list of activation functions for each part of `irreps_in`.
+            The length of the `acts` will be clipped or filled by identity functions to match the length of `irreps_in`.
+        dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32``.
+
+    Inputs:
+        - **inputs** (Tensor) - The shape of Tensor is :math:`(*, irreps\_in.dim)`.
+
+    Outputs:
+        - **outputs** (Tensor) - The shape of Tensor is :math:`(*, irreps\_in.dim)`.
+
+    Raises:
+        ValueError: If `irreps_in` contains a non-scalar irrep.
+        ValueError: If an irrep in `irreps_in` is odd, but the corresponding activation function is neither even nor odd.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.nn import Activation
+        >>> from mindspore import ops, Tensor
+        >>> act = Activation('3x0o+2x0e+1x0o', [ops.abs, ops.tanh])
+        >>> print(act)
+        Activation [xx-] (3x0o+2x0e+1x0o -> 3x0e+2x0e+1x0o)
+        >>> inputs = Tensor(ops.ones((4,6)))
+        >>> outputs = act(inputs)
+        >>> print(outputs.shape)
+        (4, 6)
+    """
+
+    def __init__(self, irreps_in, acts, dtype=float32):
+        super().__init__()
+        irreps_in = Irreps(irreps_in)
+        while len(acts) < len(irreps_in):
+            acts.append(None)
+        irreps_out = []
+        acts_out = []
+        for (mul, (l_in, p_in)), act in zip(irreps_in.data, acts):
+            if act is not None:
+                if l_in != 0:
+                    raise ValueError("Activation cannot apply an activation function to a non-scalar input.")
+
+                acts_out.append(_Normalize(act, dtype=dtype))
+                p_out = _parity_function(acts_out[-1]) if p_in == -1 else p_in
+
+                if p_out == 0:
+                    raise ValueError(
+                        "Parity does not match. The input scalar is odd but the activation is neither even nor odd."
+                    )
+
+                irreps_out.append((mul, (0, p_out)))
+
+            else:
+                acts_out.append(identity)
+                irreps_out.append((mul, (l_in, p_in)))
+
+        self.irreps_in = irreps_in
+        self.irreps_out = Irreps(irreps_out)
+        self.acts = acts_out[:len(irreps_in)]
+
+    def construct(self, v):
+        """Implement the activation function for the input tensor."""
+        vs = self.irreps_in.decompose(v)
+        batch_shape = v.shape[:-1]
+        out_list = []
+        i = 0
+        for act in self.acts:
+            out_list.append(act(vs[i]).reshape(batch_shape + (self.irreps_in.data[i].dim,)))
+            i += 1
+
+        if len(out_list) > 1:
+            out = ops.concat(out_list, axis=-1)
+        elif len(out_list) == 1:
+            out = out_list[0]
+        else:
+            out = ops.zeros_like(v)
+        return out
+
+    def __repr__(self):
+        acts = "".join(["x" if a is not identity else "-" for a in self.acts])
+        return f"{self.__class__.__name__} [{acts}] ({self.irreps_in} -> {self.irreps_out})"
diff --git a/mindscience/e3nn/nn/batchnorm.py b/mindscience/e3nn/nn/batchnorm.py
new file mode 100644
index 000000000..9a98cb0ba
--- /dev/null
+++ b/mindscience/e3nn/nn/batchnorm.py
@@ -0,0 +1,181 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""batchnorm"""
+
+from mindspore import nn, Parameter, ops, float32
+
+from ..o3.irreps import Irreps
+
+
+class BatchNorm(nn.Cell):
+    r"""
+    Batch normalization for orthonormal representations.
+    It normalizes by the norm of the representations.
+    Note that the norm is invariant only for orthonormal representations.
+    Irreducible representations `wigner_D` are orthonormal.
+
+    Args:
+        irreps (Union[str, Irrep, Irreps]): the input irreps.
+        eps (float): avoid division by zero when we normalize by the variance. Default: ``1e-5``.
+        momentum (float): momentum of the running average. Default: ``0.1``.
+        affine (bool): whether to apply learnable weight and bias parameters. Default: ``True``.
+        reduce (str): {'mean', 'max'}, method used to reduce. Default: ``'mean'``.
+        instance (bool): apply instance norm instead of batch norm. Default: ``False``.
+        normalization (str): {'component', 'norm'}, normalization method. Default: ``'component'``.
+        dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32``.
+
+    Inputs:
+        - **input** (Tensor) - The shape of Tensor is :math:`(batch, ..., irreps.dim)`.
+
+    Outputs:
+        - **output** (Tensor) - The shape of Tensor is :math:`(batch, ..., irreps.dim)`.
+
+    Raises:
+        ValueError: If `reduce` is not in ['mean', 'max'].
+        ValueError: If `normalization` is not in ['component', 'norm'].
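+
+    Note:
+        In training mode (when `instance` is ``False``), the running mean of the scalar channels and the
+        running variance of all channels are updated with momentum `momentum` after every batch; in
+        inference mode the stored running statistics are used instead.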
+ + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.nn import BatchNorm + >>> from mindspore import ops, Tensor + >>> bn = BatchNorm('3x0o+2x0e+1x0o') + >>> print(bn) + BatchNorm (3x0o+2x0e+1x0o, eps=1e-05, momentum=0.1) + >>> inputs = Tensor(ops.ones((4, 6))) + >>> outputs = bn(inputs) + >>> print(outputs.shape) + (4, 6) + """ + + def __init__(self, irreps, eps=1e-5, momentum=0.1, affine=True, reduce='mean', instance=False, + normalization='component', dtype=float32): + super().__init__() + self.irreps = Irreps(irreps) + self.eps = eps + self.momentum = momentum + self.affine = affine + self.instance = instance + self.reduce = reduce + self.normalization = normalization + self.training = True + + num_scalar = sum(mul for mul, ir in self.irreps if ir.is_scalar()) + num_features = self.irreps.num_irreps + + self.running_mean = None if self.instance else Parameter(ops.zeros(num_scalar, dtype=dtype), + requires_grad=False) + self.running_var = None if self.instance else Parameter(ops.ones(num_features, dtype=dtype), + requires_grad=False) + + self.weight = Parameter(ops.ones(num_features, dtype=dtype)) if affine else None + self.bias = Parameter(ops.zeros(num_scalar, dtype=dtype)) if affine else None + + def _roll_avg(self, curr, update): + return (1 - self.momentum) * curr + self.momentum * update + + def __repr__(self): + return f"{self.__class__.__name__} ({self.irreps}, eps={self.eps}, momentum={self.momentum})" + + def construct(self, inputs): + """construct""" + inputs_shape = inputs.shape + batch = inputs_shape[0] + dim = inputs_shape[-1] + inputs = inputs.reshape(batch, -1, dim) + + new_means = [] + new_vars = [] + + fields = [] + ix = 0 + irm = 0 + irv = 0 + iw = 0 + ib = 0 + + for mir in self.irreps.data: + mul = mir.mul + ir = mir.ir + + d = ir.dim + field = inputs[:, :, ix: ix + mul * d] # [batch, sample, mul * repr] + ix += mul * d + + # (batch, sample, mul, repr) + field = field.reshape(batch, -1, mul, d) + + if ir.is_scalar(): # scalars + if self.training or self.instance: + if self.instance: + field_mean = field.mean(1).reshape(batch, mul) # [batch, mul] + else: + field_mean = field.mean([0, 1]).reshape(mul) # [mul] + new_means.append( + self._roll_avg(self.running_mean[irm:irm + mul], field_mean) + ) + else: + field_mean = self.running_mean[irm: irm + mul] + irm += mul + + # (batch, sample, mul, repr) + field = field - field_mean.reshape(-1, 1, mul, 1) + + if self.training or self.instance: + if self.normalization == 'norm': + field_norm = field.pow(2).sum(3) # [batch, sample, mul] + elif self.normalization == 'component': + field_norm = field.pow(2).mean(3) # [batch, sample, mul] + else: + raise ValueError(f"Invalid normalization option {self.normalization}") + + if self.reduce == 'mean': + field_norm = field_norm.mean(1) # [batch, mul] + elif self.reduce == 'max': + field_norm = ops.amax(field_norm, 1) # [batch, mul] + else: + raise ValueError(f"Invalid reduce option {self.reduce}") + + if not self.instance: + field_norm = field_norm.mean(0) # [mul] + new_vars.append(self._roll_avg(self.running_var[irv: irv + mul], field_norm)) + else: + field_norm = self.running_var[irv: irv + mul] + irv += mul + + field_norm = (field_norm + self.eps).pow(-0.5) # [(batch,) mul] + + if self.affine: + weight = self.weight[iw: iw + mul] # [mul] + iw += mul + + field_norm = field_norm * weight # [(batch,) mul] + + field = field * field_norm.reshape(-1, 1, mul, 1) # [batch, sample, mul, repr] + + if self.affine and ir.is_scalar(): # scalars + bias = self.bias[ib: ib 
+ mul] # [mul] + ib += mul + field += bias.reshape(mul, 1) # [batch, sample, mul, repr] + + fields.append(field.reshape(batch, -1, mul * d)) # [batch, sample, mul * repr] + + if self.training and not self.instance: + ops.assign(self.running_mean, ops.cat(new_means)) + ops.assign(self.running_var, ops.cat(new_vars)) + + output = ops.cat(fields, 2) + return output.reshape(inputs_shape) diff --git a/mindscience/e3nn/nn/fc.py b/mindscience/e3nn/nn/fc.py new file mode 100644 index 000000000..4d85dc4e2 --- /dev/null +++ b/mindscience/e3nn/nn/fc.py @@ -0,0 +1,110 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""FullyConnectedNet""" +from mindspore import Tensor, nn, Parameter, float32, ops +from mindspore.common.initializer import initializer + +from .activation import _Normalize +from ..utils.initializer import renormal_initializer + +identity = ops.Identity() + + +class _Layer(nn.Cell): + r"""Single simple dense layer with parameter w.""" + + def __init__(self, h_in, h_out, act, init_method='normal', dtype=float32): + super().__init__() + + init_method = renormal_initializer(init_method) + + self.weight = Parameter(initializer( + init_method, (h_in, h_out), dtype), name='Layer') + self.act = act if act is not None else identity + self.h_in = h_in + self.h_out = h_out + self.weight_numel = self.weight.numel() + self.sqrt_h_in = ops.sqrt(Tensor(self.h_in, self.weight.dtype)) + + def construct(self, x): + w = self.weight / self.sqrt_h_in + x = ops.matmul(x, w) + x = self.act(x) + return x + + def __repr__(self): + return f"Layer ({self.h_in}->{self.h_out})" + + +class FullyConnectedNet(nn.SequentialCell): + r""" + Fully-connected Neural Network with normalized activation on scalars. + + Args: + h_list (List[int]): a list of input, internal and output dimensions for dense layers. + act (Func): activation function which will be automatically normalized. Default: ``None``. + out_act (bool): whether apply the activation function on the output. Default: ``False``. + init_method (Union[str, mindspore.common.initializer]): initialize parameters. Default: ``'normal'``. + dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32``. + + Inputs: + - **input** (Tensor) - The shape of Tensor is :math:`(h\_list[0])`. + + Outputs: + - **output** (Tensor) - The shape of Tensor is :math:`(h\_list[-1])`. + + Raises: + TypeError: If the elements `h_list` are not `int`. 
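+
+    Note:
+        At run time each layer divides its weight matrix by :math:`\sqrt{h_{in}}`, so together with the
+        normalized activation the scale of the features is approximately preserved from layer to layer.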
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindspore import ops
+        >>> from mindchemistry.e3.nn import FullyConnectedNet
+        >>> fc = FullyConnectedNet([4,10,20,12,6], ops.tanh)
+        >>> fc
+        FullyConnectedNet ([4, 10, 20, 12, 6] | 552 weights)
+        >>> v = ms.Tensor([.1,.2,.3,.4])
+        >>> grad = ops.grad(fc, weights=fc.trainable_params())
+        >>> fc(v).shape
+        (6,)
+        >>> [x.shape for x in grad(v)[1]]
+        [(4, 10), (10, 20), (20, 12), (12, 6)]
+
+    """
+
+    def __init__(self, h_list, act=None, out_act=False, init_method='normal', dtype=float32):
+        self.h_list = list(h_list)
+        if act is not None:
+            act = _Normalize(act, dtype=dtype)
+
+        self.layer_list = []
+
+        for i, (h1, h2) in enumerate(zip(self.h_list, self.h_list[1:])):
+            if not isinstance(h1, int) or not isinstance(h2, int):
+                raise TypeError
+
+            if i == len(self.h_list) - 2 and (not out_act):
+                a = identity
+            else:
+                a = act
+            layer = _Layer(h1, h2, a, init_method, dtype=dtype)
+            self.layer_list.append(layer)
+
+        super().__init__(self.layer_list)
+        self.weight_numel = sum([lay.weight_numel for lay in self.layer_list])
+
+    def __repr__(self):
+        return f"{self.__class__.__name__} ({self.h_list} | {self.weight_numel} weights)"
diff --git a/mindscience/e3nn/nn/gate.py b/mindscience/e3nn/nn/gate.py
new file mode 100644
index 000000000..f67a35d55
--- /dev/null
+++ b/mindscience/e3nn/nn/gate.py
@@ -0,0 +1,182 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""gate"""
+from mindspore import nn, ops, float32
+
+from .activation import Activation
+from ..o3.irreps import Irreps
+from ..o3.tensor_product import TensorProduct
+from ..utils.func import narrow
+
+
+class _Extract(nn.Cell):
+    """Extract a tuple of tensors from irreps_in into irreps_outs, according to instructions."""
+
+    def __init__(self, irreps_in, irreps_outs, instructions):
+        super().__init__()
+        self.irreps_in = Irreps(irreps_in)
+        self.irreps_outs = tuple(Irreps(irreps) for irreps in irreps_outs)
+        self.instr = instructions
+
+        if not len(self.irreps_outs) == len(self.instr):
+            raise ValueError('inputs are illegal')
+        for irreps_out, ins in zip(self.irreps_outs, self.instr):
+            if not len(irreps_out) == len(ins):
+                raise ValueError('inputs are illegal')
+
+    def construct(self, x):
+        """construct"""
+        out = []
+        for i in range(len(self.irreps_outs)):
+            if self.instr[i] == tuple(range(len(self.irreps_in.data))):
+                out.append(x)
+            else:
+                out_i = []
+                for i_in in self.instr[i]:
+                    out_i.append(narrow(x, -1, *self.irreps_in.slice_tuples[i_in]))
+                if out_i:
+                    out.append(ops.concat(out_i, -1))
+        return out
+
+
+class _Sortcut(nn.Cell):
+    """Sort and cut a tensor by irreps_outs."""
+
+    def __init__(self, *irreps_outs):
+        super().__init__()
+        self.irreps_outs = tuple(Irreps(irreps).simplify() for irreps in irreps_outs)
+        irreps_in = sum(self.irreps_outs, Irreps([]))
+
+        i = 0
+        instructions = []
+        for irreps_out in self.irreps_outs:
+            instructions.append(tuple(range(i, i + len(irreps_out))))
+            i += len(irreps_out)
+
+        irreps_in, p, _ = irreps_in.sort()
+        instructions = [tuple(p[i] for i in x) for x in instructions]
+
+        self.cut = _Extract(irreps_in, self.irreps_outs, instructions)
+        self.irreps_in = irreps_in.simplify()
+
+    def construct(self, x):
+        return self.cut(x)
+
+
+class Gate(nn.Cell):
+    r"""
+    Gate activation function. The input contains three parts: the first part, `irreps_scalars`, holds scalars
+    that are only passed through the activation functions `acts`;
+    the second part, `irreps_gates`, holds scalars that are passed through the activation functions `act_gates`
+    and then multiplied onto the third part, `irreps_gated`.
+
+    .. math::
+        \left(\bigoplus_i \phi_i(x_i) \right) \oplus \left(\bigoplus_j \phi_j(g_j) y_j \right)
+
+    where :math:`x_i` and :math:`\phi_i` are from `irreps_scalars` and `acts`, and :math:`g_j`, :math:`\phi_j`,
+    and :math:`y_j` are from `irreps_gates`, `act_gates`, and `irreps_gated`.
+
+    Args:
+        irreps_scalars (Union[str, Irrep, Irreps]): the input scalar irreps that will be passed through the
+            activation functions `acts`.
+        acts (List[Func]): a list of activation functions for each part of `irreps_scalars`.
+            The length of the `acts` will be clipped or filled by identity functions to match the length of
+            `irreps_scalars`.
+        irreps_gates (Union[str, Irrep, Irreps]): the input scalar irreps that will be passed through the
+            activation functions `act_gates` and multiplied by `irreps_gated`.
+        act_gates (List[Func]): a list of activation functions for each part of `irreps_gates`.
+            The length of the `acts` will be clipped or filled by identity functions to match the length of
+            `irreps_gates`.
+        irreps_gated (Union[str, Irrep, Irreps]): the input irreps that will be gated.
+        dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32``.
+        ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module.
+            Default: ``mindspore.float32``.
+ + Inputs: + - **input** (Tensor) - The shape of Tensor is :math:`(..., irreps\_in.dim)`. + + Outputs: + - **output** (Tensor) - The shape of Tensor is :math:`(..., irreps\_out.dim)`. + + Raises: + ValueError: If `irreps_scalars` or `irreps_gates` contain non-scalar irrep. + ValueError: If the total multiplication of `irreps_gates` do not match the total multiplication of + `irreps_gated`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindspore import ops + >>> from mindchemistry.e3.nn import Gate + >>> Gate('2x0e', [ops.tanh], '1x0o+2x0e', [ops.abs], '2x1o+1x2e') + Gate (2x0e+1x0o+2x0e+2x1o+1x2e -> 2x0e+2x1o+1x2e) + """ + + def __init__(self, irreps_scalars, acts, irreps_gates, act_gates, irreps_gated, dtype=float32, ncon_dtype=float32): + super().__init__() + irreps_scalars = Irreps(irreps_scalars) + irreps_gates = Irreps(irreps_gates) + irreps_gated = Irreps(irreps_gated) + + # pylint: disable=C1801 + if len(irreps_gates) > 0 and irreps_gates.lmax > 0: + raise ValueError(f"Gate scalars must be scalars, instead got irreps_gates = {irreps_gates}") + # pylint: disable=C1801 + if len(irreps_scalars) > 0 and irreps_scalars.lmax > 0: + raise ValueError(f"Scalars must be scalars, instead got irreps_scalars = {irreps_scalars}") + if not irreps_gates.num_irreps == irreps_gated.num_irreps: + raise ValueError(f"There are {irreps_gated.num_irreps} irreps in irreps_gated, \ + but a different number ({irreps_gates.num_irreps}) of gate scalars in irreps_gates") + + self.sc = _Sortcut(irreps_scalars, irreps_gates, irreps_gated) + self.irreps_scalars, self.irreps_gates, self.irreps_gated = self.sc.irreps_outs + + if self.irreps_scalars.num_irreps == 0: + self._has_scalar = False + else: + self._has_scalar = True + self.act_pass = Activation(irreps_scalars, acts, dtype=dtype) + irreps_scalars = self.act_pass.irreps_out + self.act_gates = Activation(irreps_gates, act_gates, dtype=dtype) + irreps_gates = self.act_gates.irreps_out + + self.tp = TensorProduct(irreps_gated, irreps_gates, instructions='element', dtype=dtype, ncon_dtype=ncon_dtype) + irreps_gated = self.tp.irreps_out + + self.irreps_in = self.sc.irreps_in + self.irreps_out = irreps_scalars + irreps_gated + + def construct(self, x): + """Implement the gate activation function for the input tensor.""" + + scalars, gates, gated = self.sc(x) + if self._has_scalar: + scalars = self.act_pass(scalars) + + if gates.shape[-1] > 0: + gates = self.act_gates(gates) + gated = self.tp(gated, gates) + if self._has_scalar: + x = ops.concat([scalars, gated], axis=-1) + else: + x = gated + else: + x = scalars + + return x + + def __repr__(self): + return f"{self.__class__.__name__} ({self.irreps_in} -> {self.irreps_out})" diff --git a/mindscience/e3nn/nn/normact.py b/mindscience/e3nn/nn/normact.py new file mode 100644 index 000000000..a4adee62d --- /dev/null +++ b/mindscience/e3nn/nn/normact.py @@ -0,0 +1,128 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""normact""" +from mindspore import nn, Parameter, float32, ops +from mindspore.common.initializer import initializer + +from ..o3.irreps import Irreps +from ..o3.tensor_product import TensorProduct +from ..o3.norm import Norm + + +class NormActivation(nn.Cell): + r"""Activation function for the norm of irreps. + Applies a scalar activation to the norm of each irrep and outputs a (normalized) version of that irrep multiplied + by the scalar output of the scalar activation. + + Args: + irreps_in (Union[str, Irrep, Irreps]): the input irreps. + act (Func): an activation function for each part of the norm of `irreps_in`. + normalize (bool): whether to normalize the input features before multiplying them by the scalars from the + nonlinearity. Default: True. + epsilon (float): when ``normalize``, norms smaller than ``epsilon`` will be clamped up to ``epsilon`` + to avoid division by zero. Not allowed when `normalize` is False. Default: None. + bias (bool): whether to apply a learnable additive bias to the inputs of the `act`. Default: False. + init_method (Union[str, float, mindspore.common.initializer]): initialize parameters. + Default: ``'normal'``. + dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32``. + ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module. + Default: ``mindspore.float32``. + + Inputs: + - **input** (Tensor) - The shape of Tensor is :math:`(..., irreps\_in.dim)`. + + Outputs: + - **output** (Tensor) - The shape of Tensor is :math:`(..., irreps\_in.dim)`. + + Raises: + ValueError: If `epsilon` is not None and `normalize` is False. + ValueError: If `epsilon` is not positive. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.nn import NormActivation + >>> from mindspore import ops, Tensor + >>> set_context(device_id=6) + >>> norm_activation = NormActivation("2x1e", ops.sigmoid, bias=True) + >>> print(norm_activation) + NormActivation [sigmoid] (2x1e -> 2x1e) + >>> inputs = Tensor(ops.ones((4, 6))) + >>> outputs = norm_activation(inputs) + >>> print(outputs.shape) + (4, 6) + """ + + def __init__(self, + irreps_in, + act, + normalize=True, + epsilon=None, + bias=False, + init_method='zeros', + dtype=float32, + ncon_dtype=float32): + super().__init__() + + self.irreps_in = Irreps(irreps_in) + self.irreps_out = Irreps(irreps_in) + + if epsilon is None and normalize: + epsilon = 1e-8 + elif epsilon is not None and not normalize: + raise ValueError("`epsilon` and `normalize = False` don't make sense together.") + elif not epsilon > 0: + raise ValueError(f"epsilon {epsilon} is invalid, must be strictly positive.") + self.epsilon = epsilon + if self.epsilon is not None: + self._eps_squared = epsilon * epsilon + else: + self._eps_squared = 0.0 + + self.norm = Norm(irreps_in, squared=(epsilon is not None), dtype=dtype) + self.act = act + self.normalize = normalize + if bias: + self.bias = Parameter(initializer(init_method, (self.irreps_in.num_irreps,), dtype), + name=self.__class__.__name__) + else: + self.bias = None + + self.scalar_multiplier = TensorProduct(irreps_in1=self.norm.irreps_out, + irreps_in2=irreps_in, + instructions='element', + dtype=dtype, + ncon_dtype=ncon_dtype) + + def construct(self, v): + """Implement the norm-activation function for the input tensor.""" + norms = self.norm(v) + if self._eps_squared > 0: + norms[norms < self._eps_squared] = self._eps_squared + norms = ops.sqrt(norms) + + 
nonlin_arg = norms + if self.bias is not None: + nonlin_arg = nonlin_arg + self.bias + + scalings = self.act(nonlin_arg) + if self.normalize: + scalings = scalings / norms + + return self.scalar_multiplier(scalings, v) + + def __repr__(self): + return f"{self.__class__.__name__} [{self.act.__name__}] ({self.irreps_in} -> {self.irreps_in})" diff --git a/mindscience/e3nn/nn/one_hot.py b/mindscience/e3nn/nn/one_hot.py new file mode 100644 index 000000000..262b4863b --- /dev/null +++ b/mindscience/e3nn/nn/one_hot.py @@ -0,0 +1,238 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""OneHot""" +import math + +import numpy as np + +from mindspore import Tensor, ops, nn, float32, float16 +from mindspore import numpy as mnp + +from ..o3.irreps import Irreps + +TMAP = {"MixedPrecisionType.FP16": float16, "MixedPrecisionType.FP32": float32} + +def soft_unit_step(x): + r""" + Smooth version of the unit step function. + + .. math:: + x \mapsto \theta(x) e^{-1/x} + + Args: + x (Tensor): the input tensor. + + Returns: + Tensor, the output of the unit step function. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.nn import soft_unit_step + >>> from mindspore import ops, set_context, Tensor + >>> x = Tensor(ops.linspace(-1.0, 10.0, 1000)) + >>> outputs = soft_unit_step(x) + >>> print(outputs.shape) + (1000,) + + """ + return ops.relu(x) * ops.exp(- 1 / x) / x + + +class OneHot(nn.Cell): + r""" + One-hot embedding. + """ + + def __init__(self, num_types, dtype=float32): + super().__init__() + self.num_types = num_types + self.irreps_output = Irreps([(self.num_types, (0, 1))]) + + self.one_hot = ops.OneHot() + self.on_off = (Tensor(1., dtype=dtype), Tensor(0., dtype=dtype)) + + def construct(self, atom_type): + type_numbers = atom_type + one_hot = self.one_hot(type_numbers, self.num_types, *self.on_off) + return one_hot + + def __repr__(self): + return f'OneHot [num_types: {self.num_types}] ( -> {self.irreps_output})' + + +# pylint: disable=C0103 +# pylint: disable=R1705 +class SoftOneHotLinspace(nn.Cell): + r""" + Projection on a basis of functions. Returns a set of :math:`\{y_i(x)\}_{i=1}^N`, + + .. math:: + y_i(x) = \frac{1}{Z} f_i(x) + + where :math:`x` is the input and :math:`f_i` is the ith basis function. + :math:`Z` is a constant defined (if possible) such that, + + .. math:: + \langle \sum_{i=1}^N y_i(x)^2 \rangle_x \approx 1 + + Note that `bessel` basis cannot be normalized. + + Args: + start (float): minimum value span by the basis. + end (float): maximum value span by the basis. + number (int): number of basis functions :math:`N`. + basis (str): {'gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel'}, the basis family. + Default: ``'smooth_finite'``. + cutoff (bool): whether require the :math:`y_i(x)` from the outside domain of (`start`, `end`) to be + vanished. Default: ``True``. + dtype (mindspore.dtype): The type of input tensor. 
Default: ``mindspore.float32``. + + Inputs: + - **x** (Tensor) - The shape of Tensor is :math:`(...)`. + + Outputs: + - **output** (Tensor) - The shape of Tensor is :math:`(..., N)`. + + Raises: + ValueError: If `basis` is not in {'gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel'}. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.nn import SoftOneHotLinspace + >>> from mindspore import ops, Tensor + >>> soft_one_hot_linspace = SoftOneHotLinspace(-0.5, 1.5, number=4) + >>> x = Tensor(ops.ones((4, 6))) + >>> outputs = soft_one_hot_linspace(x) + >>> print(outputs.shape) + (4, 6, 4) + + """ + + def __init__(self, start, end, number, basis='smooth_finite', cutoff=True, dtype=float32): + super().__init__() + + self.start = Tensor(start, dtype=dtype) + self.end = Tensor(end, dtype=dtype) + self.number = number + self.basis = basis + self.cutoff = cutoff + + if self.cutoff: + self.values = Tensor(np.linspace(start, end, number), dtype=dtype) + self.step = self.values[1] - self.values[0] + else: + self.values = Tensor(np.linspace(start, end, number + 2), dtype=dtype) + self.step = self.values[1] - self.values[0] + self.values = self.values[1:-1] + + self.PI = Tensor(math.pi, dtype=dtype) + self.c = self.end - self.start + self.consts = [ + ops.exp(Tensor(2.0, dtype=dtype)), + ops.sqrt(Tensor(0.25 + self.number / 2, dtype=dtype)), + ops.sqrt(Tensor(2. / self.c, dtype=dtype)) + ] + self.bessel_roots = mnp.arange(1, self.number + 1) * self.PI + + def construct(self, x): + """construct""" + diff = (x.expand_dims(-1) - self.values) / self.step + + if self.basis == 'gaussian': + return ops.exp(-diff.pow(2)) / 1.12 + + elif self.basis == 'cosine': + return ops.cos(self.PI / 2 * diff) * (diff < 1) * (-1 < diff) + + elif self.basis == 'smooth_finite': + return 1.14136 * self.consts[0] * soft_unit_step(diff + 1.) * soft_unit_step(1. - diff) + + elif self.basis == 'fourier': + x = (x.expand_dims(-1) - self.start) / (self.end - self.start) + if not self.cutoff: + i = mnp.arange(0, self.number) + return ops.cos(self.PI * i * x) / self.consts[1] + else: + i = mnp.arange(1, self.number + 1) + return ops.sin(self.PI * i * x) / self.consts[1] * (x > 0) * (x < 1) + + if self.basis == 'bessel': + x = x.expand_dims(-1) - self.start + out = self.consts[2] * ops.sin(self.bessel_roots * x / self.c) / x + + if not self.cutoff: + return out + else: + return out * ((x / self.c) < 1) * (x > 0) + + else: + raise ValueError(f"Unsupported basis: {self.basis}.") + + def _set_mixed_precision_type_recursive(self, dst_type): + super()._set_mixed_precision_type_recursive(dst_type) + self.values = self.values.astype(TMAP[dst_type.__str__()]) + for i in range(len(self.consts)): + self.consts[i] = self.consts[i].astype(TMAP[dst_type.__str__()]) + + +def soft_one_hot_linspace(x, start, end, number, basis='smooth_finite', cutoff=True): + r""" + Projection on a basis of functions. Returns a set of :math:`\{y_i(x)\}_{i=1}^N`, + + .. math:: + y_i(x) = \frac{1}{Z} f_i(x) + + where :math:`x` is the input and :math:`f_i` is the ith basis function. + :math:`Z` is a constant defined (if possible) such that, + + .. math:: + \langle \sum_{i=1}^N y_i(x)^2 \rangle_x \approx 1 + + Note that `bessel` basis cannot be normalized. + + Args: + x (Tensor): The shape of Tensor is :math:`(...)`. + start (float): minimum value span by the basis. + end (float): maximum value span by the basis. + number (int): number of basis functions :math:`N`. 
+ basis (str): {'gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel'}, the basis family. + Default: ``'smooth_finite'``. + cutoff (bool): whether require the :math:`y_i(x)` from the outside domain of (`start`, `end`) to be + vanished. Default: ``True``. + + Returns: + Tensor, shape is :math:`(..., N)`. + + Raises: + ValueError: If `basis` is not in {'gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel'}. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.nn import soft_one_hot_linspace + >>> from mindspore import ops, Tensor + >>> x = Tensor(ops.ones((4, 6))) + >>> outputs = soft_one_hot_linspace(x, -0.5, 1.5, number=4) + >>> print(outputs.shape) + (4, 6, 4) + + """ + soft = SoftOneHotLinspace(start, end, number, basis=basis, cutoff=cutoff, dtype=x.dtype) + return soft(x) diff --git a/mindscience/e3nn/nn/scatter.py b/mindscience/e3nn/nn/scatter.py new file mode 100644 index 000000000..922ac15ef --- /dev/null +++ b/mindscience/e3nn/nn/scatter.py @@ -0,0 +1,74 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""scatter""" +from mindspore import ops, nn +from mindspore.ops import operations as P + + +class Scatter(nn.Cell): + r""" + Easy-use version of scatter. + + Args: + mode (str): {'add', 'sum', 'div', 'max', 'min', 'mul'}, scatter mode. + + Raises: + ValueError: If `mode` is not legal. + + Supported Platforms: + ``CPU`` ``GPU`` ``Ascend`` + + """ + + def __init__(self, mode='add'): + super().__init__() + self.mode = mode + if mode in ('add', 'sum'): + self.scatter = P.TensorScatterAdd() + elif mode == 'div': + self.scatter = P.TensorScatterDiv() + elif mode == 'max': + self.scatter = P.TensorScatterMax() + elif mode == 'min': + self.scatter = P.TensorScatterMin() + elif mode == 'mul': + self.scatter = P.TensorScatterMul() + else: + raise ValueError(f"Unexpected scatter mode {mode}") + + self.zeros = ops.Zeros() + + def construct(self, src, index, out=None, dim_size=None): + r""" + Args: + src (Tensor): The source tensor. + index (Tensor): The indices of elements to scatter. + out (Tensor): The destination tensor. Default: None. + dim_size (int): If `out` is not given, automatically create output with size `dim_size`. + If `dim_size` is not given, a minimal sized output tensor is returned. Default: None. + + Returns: + Tensor. 
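+
+        Examples:
+            A minimal sketch of a sum-scatter, assuming the package layout added in this patch:
+
+            >>> import mindspore as ms
+            >>> from mindspore import Tensor
+            >>> from mindscience.e3nn.nn import Scatter
+            >>> scatter = Scatter('add')
+            >>> src = Tensor([[1., 1.], [2., 2.], [3., 3.]])
+            >>> index = Tensor([0, 1, 0], ms.int32)
+            >>> print(scatter(src, index, dim_size=2))
+            [[4. 4.]
+             [2. 2.]]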
+ """ + if index.ndim < 2: + index = index.unsqueeze(-1) + if out is not None: + return self.scatter(out, index, src) + dim_size = src.shape[0] if dim_size is None else dim_size + zero = self.zeros((dim_size, src.shape[1]), src.dtype) + return self.scatter(zero, index, src) + + def __repr__(self): + return f'Scatter [{self.mode}]' diff --git a/mindscience/e3nn/o3/__init__.py b/mindscience/e3nn/o3/__init__.py new file mode 100644 index 000000000..4f9b6d853 --- /dev/null +++ b/mindscience/e3nn/o3/__init__.py @@ -0,0 +1,51 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" +from .irreps import Irrep, Irreps +from .rotation import * +from .wigner import change_basis_real_to_complex, su2_generators, so3_generators, wigner_D, wigner_3j +from .spherical_harmonics import SphericalHarmonics, spherical_harmonics +from .tensor_product import TensorProduct +from .sub import * +from .norm import Norm + +__all__ = [ + "Irrep", + "Irreps", + "identity_angles", + "rand_angles", + "compose_angles", + "matrix_x", + "matrix_y", + "matrix_z", + "angles_to_matrix", + "matrix_to_angles", + "angles_to_xyz", + "xyz_to_angles", + "change_basis_real_to_complex", + "su2_generators", + "so3_generators", + "wigner_D", + "wigner_3j", + "TensorProduct", + "SphericalHarmonics", + "spherical_harmonics", + "FullyConnectedTensorProduct", + "FullTensorProduct", + "ElementwiseTensorProduct", + "Linear", + "TensorSquare", + "Norm", +] \ No newline at end of file diff --git a/mindscience/e3nn/o3/irreps.py b/mindscience/e3nn/o3/irreps.py new file mode 100644 index 000000000..01273bf9d --- /dev/null +++ b/mindscience/e3nn/o3/irreps.py @@ -0,0 +1,761 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import itertools +import collections +import dataclasses + +import numpy as np + +from mindspore import jit_class, Tensor, ops + +from .wigner import wigner_D +from .rotation import matrix_to_angles +from ..utils.func import broadcast_args, _to_tensor, norm_keep, _expand_last_dims, narrow +from ..utils.perm import _inverse +from ..utils.linalg import _direct_sum + +# pylint: disable=C0111 + +@jit_class +@dataclasses.dataclass(init=False, frozen=True) +class Irrep: + r""" + Irreducible representation of O(3). 
This class does not contain any data, it is a structure that describe the representation. + It is typically used as argument of other classes of the library to define the input and output representations of functions. + + Args: + l (Union[int, str]): non-negative integer, the degree of the representation, :math:`l = 0, 1, \dots`. Or string to indicate the degree and parity. + p (int): {1, -1}, the parity of the representation. Default: ``None``. + + Raises: + NotImplementedError: If method is not implemented. + ValueError: If `l` is negative or `p` is not in {1, -1}. + ValueError: If `l` cannot be converted to an `Irrep`. + TypeError: If `l` is not int or str. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import Irrep + >>> Irrep(0, 1) + 0e + >>> Irrep("1y") + 1o + >>> Irrep("2o").dim + 5 + >>> Irrep("2e") in Irrep("1o") * Irrep("1o") + True + >>> Irrep("1o") + Irrep("2o") + 1x1o+1x2o + """ + l: int + p: int + + def __init__(self, l, p=None): + if p is None: + if isinstance(l, Irrep): + p = l.p + l = l.l + + if isinstance(l, _MulIr): + p = l.ir.p + l = l.ir.l + + if isinstance(l, str): + try: + name = l.strip() + l = int(name[:-1]) + if l < 0: + raise ValueError + p = { + 'e': 1, + 'o': -1, + 'y': (-1) ** l, + }[name[-1]] + except Exception: + raise ValueError + elif isinstance(l, tuple): + l, p = l + + if not isinstance(l, int): + raise TypeError + elif l < 0: + raise ValueError + if p not in [-1, 1]: + raise ValueError + object.__setattr__(self, "l", l) + object.__setattr__(self, "p", p) + + def __repr__(self): + """Representation of the Irrep.""" + p = {+1: 'e', -1: 'o'}[self.p] + return f"{self.l}{p}" + + @classmethod + def iterator(cls, lmax=None): + for l in itertools.count(): + yield Irrep(l, (-1) ** l) + yield Irrep(l, -(-1) ** l) + + if l == lmax: + break + + def wigD_from_angles(self, alpha, beta, gamma, k=None): + r""" + Representation wigner D matrices of O(3) from Euler angles. + + Args: + alpha (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\alpha` around Y axis, applied third. + beta (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\beta` around X axis, applied second. + gamma (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\gamma` around Y axis, applied first. + k (Union[None, Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): How many times the parity is applied. Default: ``None`` . + + Returns: + Tensor, representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., 2l+1, 2l+1)` . + + Examples: + >>> m = Irrep(1, -1).wigD_from_angles(0, 0 ,0, 1) + >>> print(m) + [[-1, 0, 0], + [ 0, -1, 0], + [ 0, 0, -1]] + """ + if k is None: + k = ops.zeros_like(_to_tensor(alpha)) + + alpha, beta, gamma, k = broadcast_args(alpha, beta, gamma, k) + return wigner_D(self.l, alpha, beta, gamma) * self.p ** _expand_last_dims(k) + + def wigD_from_matrix(self, R): + r""" + Representation wigner D matrices of O(3) from rotation matrices. + + Args: + R (Tensor): Rotation matrices. The shape of Tensor is :math:`(..., 3, 3)`. + + Returns: + Tensor, representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., 2l+1, 2l+1)`. + + Raises: + TypeError: If `R` is not a Tensor. 
+
+    Examples:
+        >>> from mindspore import ops
+        >>> m = Irrep(1, -1).wigD_from_matrix(-ops.eye(3))
+        >>> print(m)
+        [[-1, 0, 0],
+        [ 0, -1, 0],
+        [ 0, 0, -1]]
+        """
+        if not isinstance(R, Tensor):
+            raise TypeError
+        d = Tensor(np.sign(np.linalg.det(R.asnumpy())))
+        R = _expand_last_dims(d) * R
+        k = (1. - d) / 2
+        return self.wigD_from_angles(*matrix_to_angles(R), k)
+
+    @property
+    def dim(self) -> int:
+        return 2 * self.l + 1
+
+    def is_scalar(self) -> bool:
+        return self.l == 0 and self.p == 1
+
+    def __mul__(self, other):
+        r"""
+        Generate the irreps from the product of two irreps.
+
+        Returns:
+            generator of `Irrep`.
+        """
+        other = Irrep(other)
+        p = self.p * other.p
+        lmin = abs(self.l - other.l)
+        lmax = self.l + other.l
+        for l in range(lmin, lmax + 1):
+            yield Irrep(l, p)
+
+    def __rmul__(self, other):
+        r"""
+        Return `Irreps` of multiple `Irrep`.
+
+        Args:
+            other (int): multiple number of the `Irrep`.
+
+        Returns:
+            `Irreps` - corresponding multiple `Irrep`.
+
+        Raises:
+            TypeError: If `other` is not an int.
+        """
+        if not isinstance(other, int):
+            raise TypeError
+        return Irreps([(other, self)])
+
+    def __add__(self, other):
+        r"""Sum of two irreps."""
+        return Irreps(self) + Irreps(other)
+
+    def __radd__(self, other):
+        r"""Sum of two irreps."""
+        return Irreps(other) + Irreps(self)
+
+    def __iter__(self):
+        r"""Deconstruct the irrep into ``l`` and ``p``."""
+        yield self.l
+        yield self.p
+
+    def __lt__(self, other):
+        r"""Compare the order of two irreps."""
+        return (self.l, self.p) < (other.l, other.p)
+
+    def __eq__(self, other):
+        """Compare two irreps."""
+        other = Irrep(other)
+        return (self.l, self.p) == (other.l, other.p)
+
+
+@jit_class
+@dataclasses.dataclass(init=False, frozen=True)
+class _MulIr:
+    """Multiple Irrep."""
+    mul: int
+    ir: Irrep
+
+    def __init__(self, mul, ir=None):
+        if ir is None:
+            mul, ir = mul
+
+        if not (isinstance(mul, int) and isinstance(ir, Irrep)):
+            raise TypeError
+        object.__setattr__(self, "mul", mul)
+        object.__setattr__(self, "ir", ir)
+
+    @property
+    def dim(self):
+        return self.mul * self.ir.dim
+
+    def __repr__(self):
+        """Representation of the irrep."""
+        return f"{self.mul}x{self.ir}"
+
+    def __iter__(self):
+        """Deconstruct the mulirrep into `mul` and `ir`."""
+        yield self.mul
+        yield self.ir
+
+    def __lt__(self, other):
+        """Compare the order of two mulirreps."""
+        return (self.ir, self.mul) < (other.ir, other.mul)
+
+    def __eq__(self, other):
+        """Compare two mulirreps."""
+        return (self.mul, self.ir) == (other.mul, other.ir)
+
+
+@jit_class
+@dataclasses.dataclass(init=False, frozen=False)
+class Irreps:
+    r"""
+    Direct sum of irreducible representations of O(3). This class does not contain any data, it is a structure
+    that describes the representation.
+    It is typically used as an argument of other classes of the library to define the input and output
+    representations of functions.
+
+    Args:
+        irreps (Union[str, Irrep, Irreps, List[Tuple[int]]]): a description of the direct sum of irreducible
+            representations, e.g. a string such as ``"100x0e+50x1e"``.
+
+    Raises:
+        ValueError: If `irreps` cannot be converted to an `Irreps`.
+        ValueError: If the mul part of `irreps` is negative.
+        TypeError: If the mul part of `irreps` is not an int.
+ + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import Irreps + >>> x = Irreps([(100, (0, 1)), (50, (1, 1))]) + 100x0e+50x1e + >>> x.dim + 250 + >>> Irreps("100x0e+50x1e+0x2e") + 100x0e+50x1e+0x2e + >>> Irreps("100x0e+50x1e+0x2e").lmax + 1 + >>> Irrep("2e") in Irreps("0e+2e") + True + >>> Irreps(), Irreps("") + (, ) + >>> Irreps('2x1o+1x0o') * Irreps('2x1o+1x0e') + 4x0e+1x0o+2x1o+4x1e+2x1e+4x2e + """ + __slots__ = ('data', 'dim', 'slice', 'slice_tuples') + + def __init__(self, irreps=None): + if isinstance(irreps, Irreps): + self.data = irreps.data + self.dim = irreps.dim + self.slice = irreps.slice + self.slice_tuples = irreps.slice_tuples + else: + out = () + if isinstance(irreps, Irrep): + out += (_MulIr(1, Irrep(irreps)),) + elif isinstance(irreps, _MulIr): + out += (irreps,) + elif isinstance(irreps, str): + try: + if irreps.strip() != "": + for mir in irreps.split('+'): + if 'x' in mir: + mul, ir = mir.split('x') + mul = int(mul) + ir = Irrep(ir) + else: + mul = 1 + ir = Irrep(mir) + + if not isinstance(mul, int): + raise TypeError + elif mul < 0: + raise ValueError + out += (_MulIr(mul, ir),) + except Exception: + raise ValueError + elif irreps is None: + pass + else: + out = self.handle_irreps(irreps, out) + self.data = out + self.dim = self._dim() + self.slice = self._slices() + self.slice_tuples = [(s.start, s.stop - s.start) for s in self.slice] + + def handle_irreps(self, irreps, out): + for mir in irreps: + + if isinstance(mir, str): + if 'x' in mir: + mul, ir = mir.split('x') + mul = int(mul) + ir = Irrep(ir) + else: + mul = 1 + ir = Irrep(mir) + elif isinstance(mir, Irrep): + mul = 1 + ir = mir + elif isinstance(mir, _MulIr): + mul, ir = mir + elif isinstance(mir, int): + mul, ir = 1, Irrep(l=mir, p=1) + elif len(mir) == 2: + mul, ir = mir + ir = Irrep(ir) + + if not (isinstance(mul, int) and mul >= 0 and ir is not None): + raise ValueError + + out += (_MulIr(mul, ir),) + return out + + def __iter__(self): + return iter(self.data) + + def __hash__(self): + return hash(self.data) + + def __len__(self): + return len(self.data) + + def __repr__(self): + """Representation of the irreps.""" + return "+".join(f"{mir}" for mir in self.data) + + def __eq__(self, other): + """Compare two irreps.""" + other = Irreps(other) + if not len(self) == len(other): + return False + for m_1, m_2 in zip(self.data, other.data): + if not m_1 == m_2: + return False + return True + + def __contains__(self, ir): + """Check if an irrep or an irreps is in the representation.""" + try: + ir = Irrep(ir) + return ir in (irrep for _, irrep in self.data) + except: + irreps = Irreps(ir) + m, n = len(irreps), len(self) + mask = [False] * n + + def dfs(i): + if i == m: + return True + for j in range(n): + if not mask[j]: + if irreps.data[i].mul <= self.data[j].mul and irreps.data[i].ir == self.data[j].ir: + mask[j] = True + found = dfs(i + 1) + if found: + return True + mask[j] = False + return False + + return dfs(0) + + def __add__(self, irreps): + irreps = Irreps(irreps) + return Irreps(self.data.__add__(irreps.data)) + + def __mul__(self, other): + r""" + Return `Irreps` of multiple `Irreps`. + + Args: + other (int): multiple number of the `Irreps`. + + Returns: + `Irreps` - corresponding multiple `Irreps`. + + Raises: + NotImplementedError: If `other` is `Irreps`, please use `o3.TensorProduct`. 
+ """ + if isinstance(other, Irreps): + res = Irreps() + for mir_1 in self.data: + for mir_2 in other.data: + out_ir = mir_1.ir * mir_2.ir + for ir in out_ir: + res += mir_1.mul * mir_2.mul * ir + res, p, _ = res.simplify().sort() + return res + return Irreps([(mul * other, ir) for mul, ir in self.data]) + + def __rmul__(self, other): + r""" + Return repeated `Irreps` of multiple `Irreps`. + + Args: + other (int): multiple number of the `Irreps`. + + Returns: + `Irreps` - repeated multiple `Irreps`. + """ + return self * other + + def _dim(self): + """The dimension of the representation, :math:`2 l + 1`.""" + return sum(mul * ir.dim for mul, ir in self.data) + + @property + def num_irreps(self): + return sum(mul for mul, _ in self.data) + + @property + def ls(self): + res = [] + for mul, (l, _) in self.data: + res.extend([l] * mul) + return res + + @property + def lmax(self): + if len(self) == 0: + raise ValueError("Cannot get lmax of empty Irreps") + return max(self.ls) + + def count(self, ir): + r""" + Multiplicity of `ir`. + + Args: + ir (Irrep): `Irrep` + + Returns: + int, total multiplicity of `ir`. + + Examples: + >>> Irreps("1o + 3x2e").count("2e") + 3 + """ + ir = Irrep(ir) + res = 0 + for mul, irrep in self.data: + if ir == irrep: + res += mul + return res + + def simplify(self): + """ + Simplify the representations. + + Returns: + `Irreps` + + Examples: + >>> Irreps("1e + 1e + 0e").simplify() + 2x1e+1x0e + >>> Irreps("1e + 1e + 0e + 1e").simplify() + 2x1e+1x0e+1x1e + """ + out = [] + for mul, ir in self.data: + if out and out[-1][1] == ir: + out[-1] = (out[-1][0] + mul, ir) + elif mul > 0: + out.append((mul, ir)) + return Irreps(out) + + def remove_zero_multiplicities(self): + """ + Remove any irreps with multiplicities of zero. + + Returns: + `Irreps` + + Examples: + >>> Irreps("4x0e + 0x1o + 2x3e").remove_zero_multiplicities() + 4x0e+2x3e + """ + out = [(mul, ir) for mul, ir in self.data if mul > 0] + return Irreps(out) + + def _slices(self): + r""" + List of slices corresponding to indices for each irrep. + + Examples: + >>> Irreps('2x0e + 1e').slices() + [slice(0, 2, None), slice(2, 5, None)] + """ + s = [] + i = 0 + for mir in self.data: + s.append(slice(i, i + mir.dim)) + i += mir.dim + return s + + def sort(self): + r""" + Sort the representations by increasing degree. + + Returns: + irreps (`Irreps`) - sorted `Irreps` + + p (tuple[int]) - permute orders. `p[old_index] = new_index` + + inv (tuple[int]) - inversed permute orders. `p[new_index] = old_index` + + Examples: + >>> Irreps("1e + 0e + 1e").sort().irreps + 1x0e+1x1e+1x1e + >>> Irreps("2o + 1e + 0e + 1e").sort().p + (3, 1, 0, 2) + >>> Irreps("2o + 1e + 0e + 1e").sort().inv + (2, 1, 3, 0) + """ + Ret = collections.namedtuple("sort", ["irreps", "p", "inv"]) + out = [(ir, i, mul) for i, (mul, ir) in enumerate(self.data)] + out = sorted(out) + inv = tuple(i for _, i, _ in out) + p = _inverse(inv) + irreps = Irreps([(mul, ir) for ir, _, mul in out]) + return Ret(irreps, p, inv) + + def filter(self, keep=None, drop=None): + r""" + Filter the `Irreps` by either `keep` or `drop`. + + Args: + keep (Union[str, Irrep, Irreps, List[str, Irrep]]): list of irrep to keep. Default: None. + drop (Union[str, Irrep, Irreps, List[str, Irrep]]): list of irrep to drop. Default: None. + + Returns: + `Irreps`, filtered irreps. + + Raises: + ValueError: If both `keep` and `drop` are not `None`. 
+ + Examples: + >>> Irreps("1o + 2e").filter(keep="1o") + 1x1o + >>> Irreps("1o + 2e").filter(drop="1o") + 1x2e + """ + if keep is None and drop is None: + return self + if keep is not None and drop is not None: + raise ValueError("Cannot specify both keep and drop") + if keep is not None: + keep = Irreps(keep).data + keep = {mir.ir for mir in keep} + return Irreps([(mul, ir) for mul, ir in self.data if ir in keep]) + if drop is not None: + drop = Irreps(drop).data + drop = {mir.ir for mir in drop} + return Irreps([(mul, ir) for mul, ir in self.data if not ir in drop]) + return None + + def decompose(self, v, batch=False): + r""" + Decompose a vector by `Irreps`. + + Args: + v (Tensor): the vector to be decomposed. + batch (bool): whether reshape the result such that there is at least a batch dimension. Default: `False`. + + Returns: + List of Tensors, the decomposed vectors by `Irreps`. + + Raises: + TypeError: If v is not Tensor. + ValueError: If length of the vector `v` is not matching with dimension of `Irreps`. + + Examples: + >>> import mindspore as ms + >>> input = ms.Tensor([1, 2, 3]) + >>> m = Irreps("1o").decompose(input) + >>> print(m) + [Tensor(shape=[1,3], dtype=Int64, value= + [[1,2,3]])] + """ + if not isinstance(v, Tensor): + raise TypeError( + f"The input for decompose should be Tensor, but got {type(v)}.") + len_v = v.shape[-1] + if not self.dim == len_v: + raise ValueError( + f"the shape of input {v.shape[-1]} do not match irreps dimension {self.dim}.") + + res = [] + batch_shape = v.shape[:-1] + for (s, l), mir in zip(self.slice_tuples, self.data): + v_slice = narrow(v, -1, s, l) + if v.ndim == 1 and batch: + res.append(v_slice.reshape( + (1,) + batch_shape + (mir.mul, mir.ir.dim))) + else: + res.append(v_slice.reshape( + batch_shape + (mir.mul, mir.ir.dim))) + + return res + + @staticmethod + def spherical_harmonics(lmax, p=-1): + r""" + Representation of the spherical harmonics. + + Args: + lmax (int): maximum of `l`. + p (int): {1, -1}, the parity of the representation. + + Returns: + `Irreps`, representation of :math:`(Y^0, Y^1, \dots, Y^{\mathrm{lmax}})`. + + Examples: + >>> Irreps.spherical_harmonics(3) + 1x0e+1x1o+1x2e+1x3o + >>> Irreps.spherical_harmonics(4, p=1) + 1x0e+1x1e+1x2e+1x3e+1x4e + """ + return Irreps([(1, (l, p ** l)) for l in range(lmax + 1)]) + + def randn(self, *size, normalization='component'): + r""" + Random tensor. + + Args: + *size (List[int]): size of the output tensor, needs to contains a `-1`. + normalization (str): {'component', 'norm'}, type of normalization method. + + Returns: + Tensor, the shape is `size` where `-1` is replaced by `self.dim`. + + Examples: + >>> Irreps("5x0e + 10x1o").randn(5, -1, 5, normalization='norm').shape + (5, 35, 5) + """ + di = size.index(-1) + lsize = size[:di] + rsize = size[di + 1:] + + if normalization == 'component': + return ops.standard_normal((*lsize, self.dim, *rsize)) + elif normalization == 'norm': + x_list = [] + for s, (mul, ir) in zip(self.slice, self.data): + if mul < 1: + continue + r = ops.standard_normal((*lsize, mul, ir.dim, *rsize)) + r = r / norm_keep(r, axis=di + 1) + + x_list.append(r.reshape((*lsize, -1, *rsize))) + return ops.concat(x_list, axis=di) + else: + raise ValueError("Normalization needs to be 'norm' or 'component'") + + def wigD_from_angles(self, alpha, beta, gamma, k=None): + r""" + Representation wigner D matrices of O(3) from Euler angles. 
+ + Args: + alpha (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\alpha` around Y axis, applied third. + beta (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\beta` around X axis, applied second. + gamma (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\gamma` around Y axis, applied first. + k (Union[None, Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): How many times the parity is applied. Default: None. + + Returns: + Tensor, representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., 2l+1, 2l+1)` + + Examples: + >>> m = Irreps("1o").wigD_from_angles(0, 0 ,0, 1) + >>> print(m) + [[-1, 0, 0], + [ 0, -1, 0], + [ 0, 0, -1]] + """ + return _direct_sum(*[ir.wigD_from_angles(alpha, beta, gamma, k) for mul, ir in self for _ in range(mul)]) + + def wigD_from_matrix(self, R): + r""" + Representation wigner D matrices of O(3) from rotation matrices. + + Args: + R (Tensor): Rotation matrices. The shape of Tensor is :math:`(..., 3, 3)`. + + Returns: + Tensor, representation wigner D matrix of O(3). The shape of Tensor is :math:`(..., 2l+1, 2l+1)` + + Raises: + TypeError: If `R` is not a Tensor. + + Examples: + >>> m = Irreps("1o").wigD_from_matrix(-ops.eye(3)) + >>> print(m) + [[-1, 0, 0], + [ 0, -1, 0], + [ 0, 0, -1]] + """ + if not isinstance(R, Tensor): + raise TypeError + d = Tensor(np.sign(np.linalg.det(R.asnumpy()))) + R = _expand_last_dims(d) * R + k = (1 - d) / 2 + return self.wigD_from_angles(*matrix_to_angles(R), k) diff --git a/mindscience/e3nn/o3/norm.py b/mindscience/e3nn/o3/norm.py new file mode 100644 index 000000000..150e52178 --- /dev/null +++ b/mindscience/e3nn/o3/norm.py @@ -0,0 +1,81 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""norm""" +from mindspore import nn, ops, float32 + +from .irreps import Irreps +from .tensor_product import TensorProduct + + +class Norm(nn.Cell): + r""" + Norm of each irrep in a direct sum of irreps. + + Args: + irreps_in (Union[str, Irrep, Irreps]): Irreps for the input. + squared (bool): whether to return the squared norm. Default: False. + dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32`` . + ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module. + Default: ``mindspore.float32`` . + + Inputs: + - **v** (Tensor) - The shape of Tensor is :math:`(..., irreps\_in.dim)` . + + Outputs: + - **output** (Tensor) - The shape of Tensor is :math:`(..., irreps\_out.dim)` . 
+ + Supported Platforms: + ``Ascend`` + + Examples: + >>> import mindspore as ms + >>> import numpy as np + >>> from mindchemistry.e3.o3 import Norm + >>> n = Norm('3x1o') + >>> v = ms.Tensor(np.linspace(1., 2., n.irreps_in.dim), dtype=ms.float32) + >>> n(v).shape + (1, 3) + + """ + + def __init__(self, irreps_in, squared=False, dtype=float32, ncon_dtype=float32): + super().__init__() + + self.squared = squared + irreps_in = Irreps(irreps_in).simplify() + irreps_out = Irreps([(mul, "0e") for mul, _ in irreps_in]) + + instr = [(i, i, i, "uuu", False, ir.dim) for i, (mul, ir) in enumerate(irreps_in)] + + self.tp = TensorProduct(irreps_in, + irreps_in, + irreps_out, + instr, + irrep_norm="component", + dtype=dtype, + ncon_dtype=ncon_dtype) + + self.irreps_in = irreps_in + self.irreps_out = irreps_out.simplify() + + def construct(self, v): + """Implement the norm-activation function for the input tensor.""" + out = self.tp(v, v) + if self.squared: + return out + return ops.sqrt(ops.relu(out)) + + def __repr__(self): + return f"{self.__class__.__name__} ({self.irreps_in})" diff --git a/mindscience/e3nn/o3/rotation.py b/mindscience/e3nn/o3/rotation.py new file mode 100644 index 000000000..96bbe21cc --- /dev/null +++ b/mindscience/e3nn/o3/rotation.py @@ -0,0 +1,387 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""rotation""" +import math +import random + +import numpy as np + +from mindspore import Tensor, float32, ops + +from ..utils.func import broadcast_args, _to_tensor, norm_keep + +seed = int(random.random() * 10000) +zeros = ops.Zeros() +cos = ops.Cos() +sin = ops.Sin() +rand = ops.UniformReal(seed=seed) + + +def identity_angles(*shape, dtype=float32): + r""" + Give the identity set of Euler angles. + + Args: + shape (Tuple[int]): The shape of additional dimensions. + dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32`` . + + Returns: + alpha (Tensor) - The alpha Euler angles. + + beta (Tensor) - The beta Euler angles. + + gamma (Tensor) - The gamma Euler angles. + + Raises: + TypeError: If dtype of 'shape' is not tuple. + TypeError: If dtype of the element of 'shape' is not int. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import identity_angles + >>> m = identity_angles((1)) + >>> print(m) + (Tensor(shape=[1], dtype=Float32, value= [ 0.00000000e+00]), Tensor(shape=[1], dtype=Float32, + value= [ 0.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 0.00000000e+00])) + """ + if not isinstance(shape, tuple): + raise TypeError + if not all(map(lambda x: isinstance(x, int), shape)): + raise TypeError + abc = zeros((3,) + shape, dtype) + return abc[0], abc[1], abc[2] + + +def rand_angles(*shape): + r""" + Give a random set of Euler angles. + + Args: + shape (Tuple[int]): The shape of additional dimensions. + + Returns: + alpha (Tensor) - The alpha Euler angles. 
+
+        beta (Tensor) - The beta Euler angles.
+
+        gamma (Tensor) - The gamma Euler angles.
+
+    Raises:
+        TypeError: If dtype of 'shape' is not tuple.
+        TypeError: If dtype of the element of 'shape' is not int.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import rand_angles
+        >>> m = rand_angles((1))
+        >>> print(m)
+        (Tensor(shape=[1], dtype=Float32, value= [ 4.00494671e+00]), Tensor(shape=[1], dtype=Float32,
+        value= [ 1.29240000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 5.71690750e+00]))
+    """
+    if not isinstance(shape, tuple):
+        raise TypeError
+    if not all(map(lambda x: isinstance(x, int), shape)):
+        raise TypeError
+    alpha, gamma = 2 * math.pi * rand((2,) + shape)
+    beta = ops.acos(2 * rand(shape) - 1)
+    return alpha, beta, gamma
+
+
+def compose_angles(a1, b1, c1, a2, b2, c2):
+    r"""
+    Compute the composed Euler angles of two sets of Euler angles.
+
+    .. math::
+
+        R(a, b, c) = R(a_1, b_1, c_1) \circ R(a_2, b_2, c_2)
+
+    Note:
+        The second set of Euler angles 'a2, b2, c2' is applied first, while the first set of
+        Euler angles 'a1, b1, c1' is applied second.
+        The elements of Euler angles should be one of the following types: float, float32, np.float32.
+
+    Args:
+        a1 (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The second applied alpha Euler angles.
+        b1 (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The second applied beta Euler angles.
+        c1 (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The second applied gamma Euler angles.
+        a2 (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The first applied alpha Euler angles.
+        b2 (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The first applied beta Euler angles.
+        c2 (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The first applied gamma Euler angles.
+
+    Returns:
+        - alpha (Tensor), The composed alpha Euler angles.
+        - beta (Tensor), The composed beta Euler angles.
+        - gamma (Tensor), The composed gamma Euler angles.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import compose_angles
+        >>> m = compose_angles(0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
+        >>> print(m)
+        (Tensor(shape=[], dtype=Float32, value= 1.34227), Tensor(shape=[], dtype=Float32, value= 1.02462),
+        Tensor(shape=[], dtype=Float32, value= 1.47115))
+    """
+
+    a1, b1, c1, a2, b2, c2 = broadcast_args(a1, b1, c1, a2, b2, c2)
+    return matrix_to_angles(
+        ops.matmul(angles_to_matrix(a1, b1, c1), angles_to_matrix(a2, b2, c2)))
+
+
+def matrix_x(angle):
+    r"""
+    Give the rotation matrices around x axis for given angle.
+
+    Args:
+        angle (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The rotation angles around x axis.
+            The shape of 'angle' is :math:`(...)`.
+
+    Returns:
+        Tensor, the rotation matrices around x axis. The shape of output is :math:`(..., 3, 3)`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import matrix_x
+        >>> m = matrix_x(0.4)
+        >>> print(m)
+        [[ 1.          0.          0.        ]
+         [ 0.          0.92106086 -0.38941833]
+         [ 0.
0.38941833 0.92106086]] + """ + angle = _to_tensor(angle) + o = ops.ones_like(angle) + z = ops.zeros_like(angle) + return ops.stack([ + ops.stack([o, z, z], axis=-1), + ops.stack([z, cos(angle), -sin(angle)], axis=-1), + ops.stack([z, sin(angle), cos(angle)], axis=-1), + ], + axis=-2) + + +def matrix_y(angle): + r""" + Give the rotation matrices around y axis for given angle. + + Args: + angle (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): + The rotation angles around y axis. + + Returns: + Tensor, the rotation matrices around y axis. The shape of output is :math:`(..., 3, 3)` + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import matrix_y + >>> m = matrix_y(0.5) + >>> print(m) + [[ 0.87758255 0. 0.47942555] + [ 0. 1. 0. ] + [-0.47942555 0. 0.87758255]] + """ + angle = _to_tensor(angle) + o = ops.ones_like(angle) + z = ops.zeros_like(angle) + return ops.stack([ + ops.stack([cos(angle), z, sin(angle)], axis=-1), + ops.stack([z, o, z], axis=-1), + ops.stack([-sin(angle), z, cos(angle)], axis=-1), + ], + axis=-2) + + +def matrix_z(angle): + r""" + Give the rotation matrices around z axis for given angle. + + Args: + angle (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): + The rotation angles around z axis. + The shape of 'angle' is :math:`(...)`. + + Returns: + Tensor, the rotation matrices around z axis. The shape of output is :math:`(..., 3, 3)`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import matrix_z + >>> m = matrix_z(0.6) + >>> print(m) + [[ 0.8253357 -0.5646425 0. ] + [ 0.5646425 0.8253357 0. ] + [ 0. 0. 1. ]] + """ + angle = _to_tensor(angle) + o = ops.ones_like(angle) + z = ops.zeros_like(angle) + return ops.stack([ + ops.stack([cos(angle), -sin(angle), z], axis=-1), + ops.stack([sin(angle), cos(angle), z], axis=-1), + ops.stack([z, z, o], axis=-1), + ], + axis=-2) + + +def angles_to_matrix(alpha, beta, gamma): + r""" + Conversion from angles to matrix. + + Args: + alpha (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): + The alpha Euler angles. The shape of Tensor is :math:`(...)`. + beta (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): + The beta Euler angles. The shape of Tensor is :math:`(...)`. + gamma (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): + The gamma Euler angles. The shape of Tensor is :math:`(...)`. + + Returns: + Tensor, the rotation matrices. Matrices of shape :math:`(..., 3, 3)`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import angles_to_matrix + >>> m = angles_to_matrix(0.4, 0.5, 0.6) + >>> print(m) + [[ 0.5672197 0.1866971 0.8021259 ] + [ 0.27070403 0.87758255 -0.395687 ] + [-0.77780527 0.44158012 0.4472424 ]] + """ + alpha, beta, gamma = broadcast_args(alpha, beta, gamma) + return ops.matmul(ops.matmul(matrix_y(alpha), matrix_x(beta)), + matrix_y(gamma)) + + +def matrix_to_angles(r_param): + r""" + Conversion from matrix to angles. + + Args: + r_param (Tensor): The rotation matrices. Matrices of shape :math:`(..., 3, 3)`. + + Returns: + - alpha (Tensor), The alpha Euler angles. The shape of Tensor is :math:`(...)`. + - beta (Tensor), The beta Euler angles. The shape of Tensor is :math:`(...)`. + - gamma (Tensor), The gamma Euler angles. The shape of Tensor is :math:`(...)`. + + Raise: + ValueError: If the det(R) is not equal to 1. 
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindchemistry.e3.o3 import matrix_to_angles
+        >>> input = ms.Tensor([[0.5672197, 0.1866971, 0.8021259], [0.27070403, 0.87758255, -0.395687],
+        ...                    [-0.77780527, 0.44158012, 0.4472424]])
+        >>> m = matrix_to_angles(input)
+        >>> print(m)
+        (Tensor(shape=[], dtype=Float32, value= 0.4), Tensor(shape=[], dtype=Float32, value= 0.5),
+        Tensor(shape=[], dtype=Float32, value= 0.6))
+    """
+    if not np.allclose(np.linalg.det(r_param.asnumpy()), 1., 1e-3, 1e-5):
+        raise ValueError
+
+    x = ops.matmul(r_param, Tensor([0.0, 1.0, 0.0]))
+    a, b = xyz_to_angles(x)
+    tmp_r_param = angles_to_matrix(a, b, ops.zeros_like(a))
+    perm = tuple(range(len(tmp_r_param.shape)))
+    r_param = ops.matmul(
+        tmp_r_param.transpose(perm[:-2] + (perm[-1],) + (perm[-2],)),
+        r_param)
+    c = ops.atan2(r_param[..., 0, 2], r_param[..., 0, 0])
+    return a, b, c
+
+
+def angles_to_xyz(alpha, beta):
+    r"""
+    Convert :math:`(\alpha, \beta)` into a point :math:`(x, y, z)` on the sphere.
+
+    Args:
+        alpha (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The alpha Euler angles. The shape of Tensor is :math:`(...)`.
+        beta (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]):
+            The beta Euler angles. The shape of Tensor is :math:`(...)`.
+
+    Returns:
+        Tensor, the point :math:`(x, y, z)` on the sphere. The shape of Tensor is :math:`(..., 3)`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindchemistry.e3.o3 import angles_to_xyz
+        >>> print(angles_to_xyz(ms.Tensor(1.7), ms.Tensor(0.0)).abs())
+        [0., 1., 0.]
+    """
+    alpha, beta = broadcast_args(alpha, beta)
+    x = sin(beta) * sin(alpha)
+    y = cos(beta)
+    z = sin(beta) * cos(alpha)
+    return ops.stack([x, y, z], axis=-1)
+
+
+def xyz_to_angles(xyz):
+    r"""
+    Convert a point :math:`\vec r = (x, y, z)` on the sphere into angles :math:`(\alpha, \beta)`.
+
+    .. math::
+
+        \vec r = R(\alpha, \beta, 0) \vec e_z
+
+    Args:
+        xyz (Tensor): The point :math:`(x, y, z)` on the sphere. The shape of Tensor is :math:`(..., 3)`.
+
+    Returns:
+        alpha (Tensor) - The alpha Euler angles. The shape of Tensor is :math:`(...)`.
+
+        beta (Tensor) - The beta Euler angles. The shape of Tensor is :math:`(...)`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindchemistry.e3.o3 import xyz_to_angles
+        >>> input = ms.Tensor([3, 3, 3])
+        >>> m = xyz_to_angles(input)
+        >>> print(m)
+        (Tensor(shape=[], dtype=Float32, value= 0.785398), Tensor(shape=[], dtype=Float32, value= 0.955318))
+    """
+    xyz = xyz / norm_keep(xyz, axis=-1)
+    xyz = ops.nan_to_num(ops.clamp(xyz, -1, 1), 1.0)
+
+    beta = ops.acos(xyz[..., 1])
+    alpha = ops.atan2(xyz[..., 0], xyz[..., 2])
+    return alpha, beta
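+
+
+# A minimal round-trip sanity sketch for the converters above (illustrative
+# only, not part of the original patch):
+#
+#     >>> import mindspore as ms
+#     >>> a, b, c = ms.Tensor(0.4), ms.Tensor(0.5), ms.Tensor(0.6)
+#     >>> r = angles_to_matrix(a, b, c)    # (3, 3) rotation matrix, det(r) == 1
+#     >>> matrix_to_angles(r)              # recovers approximately (0.4, 0.5, 0.6)
+#     >>> r2 = angles_to_matrix(*compose_angles(a, b, c, a, b, c))
+#     >>> # r2 should match ops.matmul(r, r), i.e. the rotation applied twice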
diff --git a/mindscience/e3nn/o3/spherical_harmonics.py b/mindscience/e3nn/o3/spherical_harmonics.py
new file mode 100644
index 000000000..392736735
--- /dev/null
+++ b/mindscience/e3nn/o3/spherical_harmonics.py
@@ -0,0 +1,679 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""SphericalHarmonics"""
+from mindspore import Tensor, nn, ops, float32
+from .irreps import Irreps
+
+
+def _sqrt(x, dtype=float32):
+    sqrt = ops.Sqrt()
+    return sqrt(Tensor(x, dtype=dtype))
+
+
+class SphericalHarmonics(nn.Cell):
+    r"""
+    Spherical harmonics layer.
+
+    Args:
+        irreps_out (Union[str, `Irreps`]): irreducible representations of output for spherical harmonics.
+        normalize (bool): whether to normalize the input Tensor to unit vectors that lie on the sphere before
+            projecting onto the spherical harmonics.
+        normalization (str): {'integral', 'component', 'norm'}, normalization method of the output tensors.
+            Default: ``'integral'``.
+        irreps_in (Union[str, `Irreps`, None]): irreducible representations of input for spherical harmonics.
+            Default: ``None``.
+        dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32`` .
+
+    Inputs:
+        - **x** (Tensor) - Tensor for constructing spherical harmonics. The shape of Tensor is :math:`(..., 3)`.
+
+    Outputs:
+        - **output** (Tensor) - the spherical harmonics :math:`Y^l(x)`. The shape of Tensor is :math:`(..., 2l+1)`.
+
+    Raises:
+        ValueError: If `normalization` is not in {'integral', 'component', 'norm'}.
+        ValueError: If `irreps_in` for SphericalHarmonics is neither a vector (`1x1o`) nor a pseudovector (`1x1e`).
+        ValueError: If the `l` and `p` of `irreps_out` are not consistent with `irreps_in` for spherical harmonics.
+            The output parity should have been p = {input_p**l}.
+        NotImplementedError: If `l` is larger than 11.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import SphericalHarmonics
+        >>> from mindspore import ops
+        >>> sh = SphericalHarmonics(0, False, normalization='component')
+        >>> x = ops.rand(2, 3)
+        >>> m = sh(x)
+        >>> print(m)
+        [[1.]
+         [1.]]
+    """
+
+    def __init__(self, irreps_out, normalize, normalization='integral', irreps_in=None, dtype=float32):
+        super().__init__()
+        self.normalize = normalize
+        self.normalization = normalization
+        if normalization not in ['integral', 'component', 'norm']:
+            raise ValueError
+
+        if isinstance(irreps_out, str):
+            irreps_out = Irreps(irreps_out)
+        if isinstance(irreps_out, Irreps) and irreps_in is None:
+            for mul, (l, p) in irreps_out:
+                if l % 2 == 1 and p == 1:
+                    irreps_in = Irreps("1e")
+        if irreps_in is None:
+            irreps_in = Irreps("1o")
+
+        irreps_in = Irreps(irreps_in)
+        if irreps_in not in (Irreps("1x1o"), Irreps("1x1e")):
+            raise ValueError
+        self.irreps_in = irreps_in
+        input_p = irreps_in.data[0].ir.p
+
+        if isinstance(irreps_out, Irreps):
+            ls = []
+            for mul, (l, p) in irreps_out:
+                if p != input_p ** l:
+                    raise ValueError
+                ls.extend([l] * mul)
+        elif isinstance(irreps_out, int):
+            ls = [irreps_out]
+        else:
+            ls = list(irreps_out)
+
+        irreps_out = Irreps([(1, (l, input_p ** l)) for l in ls]).simplify()
+        self.irreps_out = irreps_out
+        self._ls_list = ls
+        self._lmax = max(ls)
+        self._is_range_lmax = ls == list(range(max(ls) + 1))
+        self._prof_str = f'spherical_harmonics({ls})'
+        self.ones = ops.Ones()
+
+        if self.normalization == 'integral':
+            self.norm_factors = [
+                (_sqrt(2 * l + 1., dtype) / 3.5449077018110318) *
+                self.ones(2 * l + 1, dtype)
+                for l in self._ls_list
+            ]
+        elif self.normalization == 'component':
+            self.norm_factors = [
+                _sqrt(2 * l + 1., dtype) * self.ones(2 * l + 1, dtype)
+                for l in self._ls_list
+            ]
+
+        self.l2_normalize = ops.L2Normalize(axis=-1, epsilon=0.000000000001)
+
+    def construct(self, x):
+        """
+        Compute spherical harmonics of vector `x`.
+
+        Args:
+            x (Tensor): Tensor for constructing spherical harmonics. The shape of Tensor is ``(..., 3)``.
+
+        Returns:
+            Tensor, the spherical harmonics :math:`Y^l(x)`. The shape of Tensor is ``(..., 2l+1)``.
+
+        Examples:
+            >>> sh = SphericalHarmonics(irreps_out="1o + 2x2e", normalize=True)
+            >>> input = ops.ones([1, 3])
+            >>> output = sh(input)
+            >>> print(output)
+            [[0.28209478 0.28209478 0.28209478 0.36418277 0.36418277 0.
+              0.36418277 0.         0.36418277 0.36418277 0.         0.36418277
+              0.        ]]
+        """
+        last_dim = x.shape[-1]
+        if not last_dim == 3:
+            raise ValueError
+
+        if self.normalize:
+            x = self.l2_normalize(x)
+
+        sh = _spherical_harmonics(self._lmax, x[..., 0], x[..., 1], x[..., 2])
+
+        if not self._is_range_lmax:
+            sh = ops.concat([
+                sh[..., l * l:(l + 1) * (l + 1)]
+                for l in self._ls_list
+            ], axis=-1)
+        if self.normalization != 'norm':
+            sh = ops.mul(sh, ops.concat(self.norm_factors))
+
+        return sh
+
+    def __repr__(self):
+        return f'SphericalHarmonics {self._ls_list} ({self.irreps_in} -> {self.irreps_out})'
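+
+
+# A minimal equivariance sanity sketch for the layer above (illustrative only,
+# not part of the original patch; the transpose convention for row vectors is
+# an assumption):
+#
+#     >>> from mindspore import ops
+#     >>> sh = SphericalHarmonics("1o + 2e", normalize=True, normalization='component')
+#     >>> x = ops.rand(4, 3)
+#     >>> rot = Irreps("1o").wigD_from_angles(0.1, 0.2, 0.3)
+#     >>> wigd = Irreps("1o + 2e").wigD_from_angles(0.1, 0.2, 0.3)
+#     >>> lhs = sh(ops.matmul(x, rot.T))    # Y(R x) for row vectors
+#     >>> rhs = ops.matmul(sh(x), wigd.T)   # D(R) Y(x)
+#     >>> # lhs and rhs should agree up to numerical precision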
+
+
+def spherical_harmonics(l, x, normalize=True, normalization='integral'):
+    r"""
+    Compute spherical harmonics.
+
+    Spherical harmonics are polynomials defined on the 3d space,
+    :math:`Y^l: \mathbb{R}^3 \longrightarrow \mathbb{R}^{2l+1}`,
+    usually restricted on the sphere (with ``normalize=True``),
+    :math:`Y^l: S^2 \longrightarrow \mathbb{R}^{2l+1}`,
+    which satisfy the following properties:
+
+    - they are polynomials of the cartesian coordinates ``x, y, z``;
+    - they are equivariant: :math:`Y^l(R x) = D^l(R) Y^l(x)`;
+    - they are orthogonal: :math:`\int_{S^2} Y^l_m(x) Y^j_n(x) dx = \text{cste} \; \delta_{lj} \delta_{mn}`.
+
+    The value of the constant depends on the choice of normalization.
math:: + Y^{l+1}_i(x) &= \text{cste}(l) \; & C_{ijk} Y^l_j(x) x_k + \partial_k Y^{l+1}_i(x) &= \text{cste}(l) \; (l+1) & C_{ijk} Y^l_j(x) + Where :math:`C` are the `wigner_3j`. + + Args: + l (Union[int, List[int]]): degree of the spherical harmonics. + x (Tensor): tensor for construct spherical harmonics. + The shape of Tensor is :math:`x` of shape ``(..., 3)`` + normalize (bool): whether to normalize the ``x`` to unit vectors that lie on the sphere before projecting onto + the spherical harmonics. + normalization (str): {'integral', 'component', 'norm'}, normalization method of the output tensors. + Default: 'intergral'. + 'component': :math:`\|Y^l(x)\|^2 = 2l+1, x \in S^2` + 'norm': :math:`\|Y^l(x)\| = 1, x \in S^2`, ``component / sqrt(2l+1)`` + 'integral': :math:`\int_{S^2} Y^l_m(x)^2 dx = 1`, ``component / sqrt(4pi)`` + + Returns: + Tensor, the spherical harmonics :math:`Y^l(x)`. The shape of Tensor is ``(..., 2l+1)``. + + Raise: + ValueError: If `normalization` is not in {'integral', 'component', 'norm'}. + ValueError: If `irreps_in` for SphericalHarmonics is not neither a vector (`1x1o`) nor a pseudovector (`1x1e`). + ValueError: If the `l` and `p` of `irreps_out` are not consistent with `irreps_in` for spherical harmonics. + The output parity should have been p = {input_p**l}. + ValueError: If the tensor `x` is not the shape of ``(..., 3)``. + NotImplementedError: If `l` is larger than 11. + + """ + sh = SphericalHarmonics(l, normalize, normalization, dtype=x.dtype) + return sh(x) + + +def _spherical_harmonics(lmax: int, x, y, z): + """core functions of spherical harmonics""" + + sh_0_0 = ops.ones_like(x) + if lmax == 0: + return ops.stack([ + sh_0_0, + ], axis=-1) + + sh_1_0 = x + sh_1_1 = y + sh_1_2 = z + if lmax == 1: + return ops.stack([ + sh_0_0, + sh_1_0, sh_1_1, sh_1_2 + ], axis=-1) + + sh_2_0 = 1.7320508075688772 * x * z + sh_2_1 = 1.7320508075688772 * x * y + y2 = y.pow(2) + x2z2 = x.pow(2) + z.pow(2) + sh_2_2 = y2 - 0.5 * x2z2 + sh_2_3 = 1.7320508075688772 * y * z + sh_2_4 = 1.7320508075688772 / 2.0 * (z.pow(2) - x.pow(2)) + + if lmax == 2: + return ops.stack([ + sh_0_0, + sh_1_0, sh_1_1, sh_1_2, + sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4 + ], axis=-1) + + sh_3_0 = 0.9128709291752769 * (sh_2_0 * z + sh_2_4 * x) + sh_3_1 = 2.23606797749979 * sh_2_0 * y + sh_3_2 = 0.6123724356957945 * (4.0 * y2 - x2z2) * x + sh_3_3 = 0.5 * y * (2.0 * y2 - 3.0 * x2z2) + sh_3_4 = 0.6123724356957945 * z * (4.0 * y2 - x2z2) + sh_3_5 = 2.23606797749979 * sh_2_4 * y + sh_3_6 = 0.9128709291752769 * (sh_2_4 * z - sh_2_0 * x) + + if lmax == 3: + return ops.stack([ + sh_0_0, + sh_1_0, sh_1_1, sh_1_2, + sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, + sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6 + ], axis=-1) + + sh_4_0 = 0.935414346693485 * sh_3_0 * z + 0.935414346693485 * sh_3_6 * x + sh_4_1 = 0.661437827766148 * sh_3_0 * y + 0.810092587300982 * \ + sh_3_1 * z + 0.810092587300983 * sh_3_5 * x + sh_4_2 = -0.176776695296637 * sh_3_0 * z + 0.866025403784439 * sh_3_1 * y + \ + 0.684653196881458 * sh_3_2 * z + 0.684653196881457 * \ + sh_3_4 * x + 0.176776695296637 * sh_3_6 * x + sh_4_3 = -0.306186217847897 * sh_3_1 * z + 0.968245836551855 * sh_3_2 * \ + y + 0.790569415042095 * sh_3_3 * x + 0.306186217847897 * sh_3_5 * x + sh_4_4 = -0.612372435695795 * sh_3_2 * x + \ + sh_3_3 * y - 0.612372435695795 * sh_3_4 * z + sh_4_5 = -0.306186217847897 * sh_3_1 * x + 0.790569415042096 * sh_3_3 * \ + z + 0.968245836551854 * sh_3_4 * y - 0.306186217847897 * sh_3_5 * z + sh_4_6 = -0.176776695296637 * sh_3_0 * x - 
0.684653196881457 * sh_3_2 * x + \ + 0.684653196881457 * sh_3_4 * z + 0.866025403784439 * \ + sh_3_5 * y - 0.176776695296637 * sh_3_6 * z + sh_4_7 = -0.810092587300982 * sh_3_1 * x + 0.810092587300982 * \ + sh_3_5 * z + 0.661437827766148 * sh_3_6 * y + sh_4_8 = -0.935414346693485 * sh_3_0 * x + 0.935414346693486 * sh_3_6 * z + if lmax == 4: + return ops.stack([ + sh_0_0, + sh_1_0, sh_1_1, sh_1_2, + sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, + sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, + sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8 + ], axis=-1) + + sh_5_0 = 0.948683298050513 * sh_4_0 * z + 0.948683298050513 * sh_4_8 * x + sh_5_1 = 0.6 * sh_4_0 * y + 0.848528137423857 * \ + sh_4_1 * z + 0.848528137423858 * sh_4_7 * x + sh_5_2 = -0.14142135623731 * sh_4_0 * z + 0.8 * sh_4_1 * y + 0.748331477354788 * \ + sh_4_2 * z + 0.748331477354788 * sh_4_6 * x + 0.14142135623731 * sh_4_8 * x + sh_5_3 = -0.244948974278318 * sh_4_1 * z + 0.916515138991168 * sh_4_2 * y + \ + 0.648074069840786 * sh_4_3 * z + 0.648074069840787 * \ + sh_4_5 * x + 0.244948974278318 * sh_4_7 * x + sh_5_4 = -0.346410161513776 * sh_4_2 * z + 0.979795897113272 * sh_4_3 * \ + y + 0.774596669241484 * sh_4_4 * x + 0.346410161513776 * sh_4_6 * x + sh_5_5 = -0.632455532033676 * sh_4_3 * x + \ + sh_4_4 * y - 0.632455532033676 * sh_4_5 * z + sh_5_6 = -0.346410161513776 * sh_4_2 * x + 0.774596669241483 * sh_4_4 * \ + z + 0.979795897113273 * sh_4_5 * y - 0.346410161513776 * sh_4_6 * z + sh_5_7 = -0.244948974278318 * sh_4_1 * x - 0.648074069840787 * sh_4_3 * x + \ + 0.648074069840786 * sh_4_5 * z + 0.916515138991169 * \ + sh_4_6 * y - 0.244948974278318 * sh_4_7 * z + sh_5_8 = -0.141421356237309 * sh_4_0 * x - 0.748331477354788 * sh_4_2 * x + \ + 0.748331477354788 * sh_4_6 * z + 0.8 * \ + sh_4_7 * y - 0.141421356237309 * sh_4_8 * z + sh_5_9 = -0.848528137423857 * sh_4_1 * x + \ + 0.848528137423857 * sh_4_7 * z + 0.6 * sh_4_8 * y + sh_5_10 = -0.948683298050513 * sh_4_0 * x + 0.948683298050513 * sh_4_8 * z + if lmax == 5: + return ops.stack([ + sh_0_0, + sh_1_0, sh_1_1, sh_1_2, + sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, + sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, + sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, + sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10 + ], axis=-1) + + sh_6_0 = 0.957427107756337 * sh_5_0 * z + 0.957427107756338 * sh_5_10 * x + sh_6_1 = 0.552770798392565 * sh_5_0 * y + 0.874007373475125 * \ + sh_5_1 * z + 0.874007373475125 * sh_5_9 * x + sh_6_2 = -0.117851130197757 * sh_5_0 * z + 0.745355992499929 * sh_5_1 * y + \ + 0.117851130197758 * sh_5_10 * x + 0.790569415042094 * \ + sh_5_2 * z + 0.790569415042093 * sh_5_8 * x + sh_6_3 = -0.204124145231931 * sh_5_1 * z + 0.866025403784437 * sh_5_2 * y + \ + 0.707106781186546 * sh_5_3 * z + 0.707106781186547 * \ + sh_5_7 * x + 0.204124145231931 * sh_5_9 * x + sh_6_4 = -0.288675134594813 * sh_5_2 * z + 0.942809041582062 * sh_5_3 * y + \ + 0.623609564462323 * sh_5_4 * z + 0.623609564462322 * \ + sh_5_6 * x + 0.288675134594812 * sh_5_8 * x + sh_6_5 = -0.372677996249965 * sh_5_3 * z + 0.986013297183268 * sh_5_4 * \ + y + 0.763762615825972 * sh_5_5 * x + 0.372677996249964 * sh_5_7 * x + sh_6_6 = -0.645497224367901 * sh_5_4 * x + \ + sh_5_5 * y - 0.645497224367902 * sh_5_6 * z + sh_6_7 = -0.372677996249964 * sh_5_3 * x + 0.763762615825972 * sh_5_5 * \ + z + 0.986013297183269 * sh_5_6 * y - 0.372677996249965 * sh_5_7 * z + sh_6_8 = -0.288675134594813 * sh_5_2 * x - 
0.623609564462323 * sh_5_4 * x + \ + 0.623609564462323 * sh_5_6 * z + 0.942809041582062 * \ + sh_5_7 * y - 0.288675134594812 * sh_5_8 * z + sh_6_9 = -0.20412414523193 * sh_5_1 * x - 0.707106781186546 * sh_5_3 * x + \ + 0.707106781186547 * sh_5_7 * z + 0.866025403784438 * \ + sh_5_8 * y - 0.204124145231931 * sh_5_9 * z + sh_6_10 = -0.117851130197757 * sh_5_0 * x - 0.117851130197757 * sh_5_10 * z - \ + 0.790569415042094 * sh_5_2 * x + 0.790569415042093 * \ + sh_5_8 * z + 0.745355992499929 * sh_5_9 * y + sh_6_11 = -0.874007373475124 * sh_5_1 * x + 0.552770798392566 * \ + sh_5_10 * y + 0.874007373475125 * sh_5_9 * z + sh_6_12 = -0.957427107756337 * sh_5_0 * x + 0.957427107756336 * sh_5_10 * z + if lmax == 6: + return ops.stack([ + sh_0_0, + sh_1_0, sh_1_1, sh_1_2, + sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, + sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, + sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, + sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, + sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12 + ], axis=-1) + + sh_7_0 = 0.963624111659433 * sh_6_0 * z + 0.963624111659432 * sh_6_12 * x + sh_7_1 = 0.515078753637713 * sh_6_0 * y + 0.892142571199771 * \ + sh_6_1 * z + 0.892142571199771 * sh_6_11 * x + sh_7_2 = -0.101015254455221 * sh_6_0 * z + 0.699854212223765 * sh_6_1 * y + \ + 0.82065180664829 * sh_6_10 * x + 0.101015254455222 * \ + sh_6_12 * x + 0.82065180664829 * sh_6_2 * z + sh_7_3 = -0.174963553055942 * sh_6_1 * z + 0.174963553055941 * sh_6_11 * x + \ + 0.82065180664829 * sh_6_2 * y + 0.749149177264394 * \ + sh_6_3 * z + 0.749149177264394 * sh_6_9 * x + sh_7_4 = 0.247435829652697 * sh_6_10 * x - 0.247435829652697 * sh_6_2 * z + \ + 0.903507902905251 * sh_6_3 * y + 0.677630927178938 * \ + sh_6_4 * z + 0.677630927178938 * sh_6_8 * x + sh_7_5 = -0.31943828249997 * sh_6_3 * z + 0.95831484749991 * sh_6_4 * y + \ + 0.606091526731326 * sh_6_5 * z + 0.606091526731326 * \ + sh_6_7 * x + 0.31943828249997 * sh_6_9 * x + sh_7_6 = -0.391230398217976 * sh_6_4 * z + 0.989743318610787 * sh_6_5 * \ + y + 0.755928946018454 * sh_6_6 * x + 0.391230398217975 * sh_6_8 * x + sh_7_7 = -0.654653670707977 * sh_6_5 * x + \ + sh_6_6 * y - 0.654653670707978 * sh_6_7 * z + sh_7_8 = -0.391230398217976 * sh_6_4 * x + 0.755928946018455 * sh_6_6 * \ + z + 0.989743318610787 * sh_6_7 * y - 0.391230398217975 * sh_6_8 * z + sh_7_9 = -0.31943828249997 * sh_6_3 * x - 0.606091526731327 * sh_6_5 * x + \ + 0.606091526731326 * sh_6_7 * z + 0.95831484749991 * \ + sh_6_8 * y - 0.31943828249997 * sh_6_9 * z + sh_7_10 = -0.247435829652697 * sh_6_10 * z - 0.247435829652697 * sh_6_2 * x - \ + 0.677630927178938 * sh_6_4 * x + 0.677630927178938 * \ + sh_6_8 * z + 0.903507902905251 * sh_6_9 * y + sh_7_11 = -0.174963553055942 * sh_6_1 * x + 0.820651806648289 * sh_6_10 * y - \ + 0.174963553055941 * sh_6_11 * z - 0.749149177264394 * \ + sh_6_3 * x + 0.749149177264394 * sh_6_9 * z + sh_7_12 = -0.101015254455221 * sh_6_0 * x + 0.82065180664829 * sh_6_10 * z + \ + 0.699854212223766 * sh_6_11 * y - 0.101015254455221 * \ + sh_6_12 * z - 0.82065180664829 * sh_6_2 * x + sh_7_13 = -0.892142571199772 * sh_6_1 * x + 0.892142571199772 * \ + sh_6_11 * z + 0.515078753637713 * sh_6_12 * y + sh_7_14 = -0.963624111659431 * sh_6_0 * x + 0.963624111659433 * sh_6_12 * z + if lmax == 7: + return ops.stack([ + sh_0_0, + sh_1_0, sh_1_1, sh_1_2, + sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, + sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, 
sh_3_5, sh_3_6, + sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, + sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, + sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, + sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, + sh_7_13, sh_7_14 + ], axis=-1) + + sh_8_0 = 0.968245836551854 * sh_7_0 * z + 0.968245836551853 * sh_7_14 * x + sh_8_1 = 0.484122918275928 * sh_7_0 * y + 0.90571104663684 * \ + sh_7_1 * z + 0.90571104663684 * sh_7_13 * x + sh_8_2 = -0.0883883476483189 * sh_7_0 * z + 0.661437827766148 * sh_7_1 * y + \ + 0.843171097702002 * sh_7_12 * x + 0.088388347648318 * \ + sh_7_14 * x + 0.843171097702003 * sh_7_2 * z + sh_8_3 = -0.153093108923948 * sh_7_1 * z + 0.7806247497998 * sh_7_11 * x + \ + 0.153093108923949 * sh_7_13 * x + 0.7806247497998 * \ + sh_7_2 * y + 0.780624749799799 * sh_7_3 * z + sh_8_4 = 0.718070330817253 * sh_7_10 * x + 0.21650635094611 * sh_7_12 * x - \ + 0.21650635094611 * sh_7_2 * z + 0.866025403784439 * \ + sh_7_3 * y + 0.718070330817254 * sh_7_4 * z + sh_8_5 = 0.279508497187474 * sh_7_11 * x - 0.279508497187474 * sh_7_3 * z + \ + 0.927024810886958 * sh_7_4 * y + 0.655505530106345 * \ + sh_7_5 * z + 0.655505530106344 * sh_7_9 * x + sh_8_6 = 0.342326598440729 * sh_7_10 * x - 0.342326598440729 * sh_7_4 * z + \ + 0.968245836551854 * sh_7_5 * y + 0.592927061281572 * \ + sh_7_6 * z + 0.592927061281571 * sh_7_8 * x + sh_8_7 = -0.405046293650492 * sh_7_5 * z + 0.992156741649221 * \ + sh_7_6 * y + 0.75 * sh_7_7 * x + 0.405046293650492 * sh_7_9 * x + sh_8_8 = -0.661437827766148 * sh_7_6 * x + \ + sh_7_7 * y - 0.661437827766148 * sh_7_8 * z + sh_8_9 = -0.405046293650492 * sh_7_5 * x + 0.75 * sh_7_7 * z + \ + 0.992156741649221 * sh_7_8 * y - 0.405046293650491 * sh_7_9 * z + sh_8_10 = -0.342326598440728 * sh_7_10 * z - 0.342326598440729 * sh_7_4 * x - \ + 0.592927061281571 * sh_7_6 * x + 0.592927061281571 * \ + sh_7_8 * z + 0.968245836551855 * sh_7_9 * y + sh_8_11 = 0.927024810886958 * sh_7_10 * y - 0.279508497187474 * sh_7_11 * z - \ + 0.279508497187474 * sh_7_3 * x - 0.655505530106345 * \ + sh_7_5 * x + 0.655505530106345 * sh_7_9 * z + sh_8_12 = 0.718070330817253 * sh_7_10 * z + 0.866025403784439 * sh_7_11 * y - \ + 0.216506350946109 * sh_7_12 * z - 0.216506350946109 * \ + sh_7_2 * x - 0.718070330817254 * sh_7_4 * x + sh_8_13 = -0.153093108923948 * sh_7_1 * x + 0.7806247497998 * sh_7_11 * z + \ + 0.7806247497998 * sh_7_12 * y - 0.153093108923948 * \ + sh_7_13 * z - 0.780624749799799 * sh_7_3 * x + sh_8_14 = -0.0883883476483179 * sh_7_0 * x + 0.843171097702002 * sh_7_12 * z + \ + 0.661437827766147 * sh_7_13 * y - 0.088388347648319 * \ + sh_7_14 * z - 0.843171097702002 * sh_7_2 * x + sh_8_15 = -0.90571104663684 * sh_7_1 * x + 0.90571104663684 * \ + sh_7_13 * z + 0.484122918275927 * sh_7_14 * y + sh_8_16 = -0.968245836551853 * sh_7_0 * x + 0.968245836551855 * sh_7_14 * z + if lmax == 8: + return ops.stack([ + sh_0_0, + sh_1_0, sh_1_1, sh_1_2, + sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, + sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, + sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, + sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, + sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, + sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, 
sh_7_11, sh_7_12, + sh_7_13, sh_7_14, + sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, + sh_8_13, sh_8_14, sh_8_15, sh_8_16 + ], axis=-1) + + sh_9_0 = 0.97182531580755 * sh_8_0 * z + 0.971825315807551 * sh_8_16 * x + sh_9_1 = 0.458122847290851 * sh_8_0 * y + 0.916245694581702 * \ + sh_8_1 * z + 0.916245694581702 * sh_8_15 * x + sh_9_2 = -0.078567420131839 * sh_8_0 * z + 0.62853936105471 * sh_8_1 * y + 0.86066296582387 * \ + sh_8_14 * x + 0.0785674201318385 * sh_8_16 * x + 0.860662965823871 * sh_8_2 * z + sh_9_3 = -0.136082763487955 * sh_8_1 * z + 0.805076485899413 * sh_8_13 * x + \ + 0.136082763487954 * sh_8_15 * x + 0.74535599249993 * \ + sh_8_2 * y + 0.805076485899413 * sh_8_3 * z + sh_9_4 = 0.749485420179558 * sh_8_12 * x + 0.192450089729875 * sh_8_14 * x - \ + 0.192450089729876 * sh_8_2 * z + 0.831479419283099 * \ + sh_8_3 * y + 0.749485420179558 * sh_8_4 * z + sh_9_5 = 0.693888666488711 * sh_8_11 * x + 0.248451997499977 * sh_8_13 * x - \ + 0.248451997499976 * sh_8_3 * z + 0.895806416477617 * \ + sh_8_4 * y + 0.69388866648871 * sh_8_5 * z + sh_9_6 = 0.638284738504225 * sh_8_10 * x + 0.304290309725092 * sh_8_12 * x - \ + 0.304290309725092 * sh_8_4 * z + 0.942809041582063 * \ + sh_8_5 * y + 0.638284738504225 * sh_8_6 * z + sh_9_7 = 0.360041149911548 * sh_8_11 * x - 0.360041149911548 * sh_8_5 * z + \ + 0.974996043043569 * sh_8_6 * y + 0.582671582316751 * \ + sh_8_7 * z + 0.582671582316751 * sh_8_9 * x + sh_9_8 = 0.415739709641549 * sh_8_10 * x - 0.415739709641549 * sh_8_6 * \ + z + 0.993807989999906 * sh_8_7 * y + 0.74535599249993 * sh_8_8 * x + sh_9_9 = -0.66666666666666666667 * sh_8_7 * x + \ + sh_8_8 * y - 0.66666666666666666667 * sh_8_9 * z + sh_9_10 = -0.415739709641549 * sh_8_10 * z - 0.415739709641549 * sh_8_6 * \ + x + 0.74535599249993 * sh_8_8 * z + 0.993807989999906 * sh_8_9 * y + sh_9_11 = 0.974996043043568 * sh_8_10 * y - 0.360041149911547 * sh_8_11 * z - \ + 0.360041149911548 * sh_8_5 * x - 0.582671582316751 * \ + sh_8_7 * x + 0.582671582316751 * sh_8_9 * z + sh_9_12 = 0.638284738504225 * sh_8_10 * z + 0.942809041582063 * sh_8_11 * y - \ + 0.304290309725092 * sh_8_12 * z - 0.304290309725092 * \ + sh_8_4 * x - 0.638284738504225 * sh_8_6 * x + sh_9_13 = 0.693888666488711 * sh_8_11 * z + 0.895806416477617 * sh_8_12 * y - \ + 0.248451997499977 * sh_8_13 * z - 0.248451997499977 * \ + sh_8_3 * x - 0.693888666488711 * sh_8_5 * x + sh_9_14 = 0.749485420179558 * sh_8_12 * z + 0.831479419283098 * sh_8_13 * y - \ + 0.192450089729875 * sh_8_14 * z - 0.192450089729875 * \ + sh_8_2 * x - 0.749485420179558 * sh_8_4 * x + sh_9_15 = -0.136082763487954 * sh_8_1 * x + 0.805076485899413 * sh_8_13 * z + \ + 0.745355992499929 * sh_8_14 * y - 0.136082763487955 * \ + sh_8_15 * z - 0.805076485899413 * sh_8_3 * x + sh_9_16 = -0.0785674201318389 * sh_8_0 * x + 0.86066296582387 * sh_8_14 * z + \ + 0.628539361054709 * sh_8_15 * y - 0.0785674201318387 * \ + sh_8_16 * z - 0.860662965823871 * sh_8_2 * x + sh_9_17 = -0.9162456945817 * sh_8_1 * x + 0.916245694581702 * \ + sh_8_15 * z + 0.458122847290851 * sh_8_16 * y + sh_9_18 = -0.97182531580755 * sh_8_0 * x + 0.97182531580755 * sh_8_16 * z + if lmax == 9: + return ops.stack([ + sh_0_0, + sh_1_0, sh_1_1, sh_1_2, + sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, + sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, + sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, + sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, + sh_6_0, sh_6_1, sh_6_2, 
sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, + sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, + sh_7_13, sh_7_14, + sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, + sh_8_13, sh_8_14, sh_8_15, sh_8_16, + sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, + sh_9_13, sh_9_14, sh_9_15, sh_9_16, sh_9_17, sh_9_18 + ], axis=-1) + + sh_10_0 = 0.974679434480897 * sh_9_0 * z + 0.974679434480897 * sh_9_18 * x + sh_10_1 = 0.435889894354067 * sh_9_0 * y + 0.924662100445347 * \ + sh_9_1 * z + 0.924662100445347 * sh_9_17 * x + sh_10_2 = -0.0707106781186546 * sh_9_0 * z + 0.6 * sh_9_1 * y + 0.874642784226796 * \ + sh_9_16 * x + 0.070710678118655 * sh_9_18 * x + 0.874642784226795 * sh_9_2 * z + sh_10_3 = -0.122474487139159 * sh_9_1 * z + 0.824621125123533 * sh_9_15 * x + \ + 0.122474487139159 * sh_9_17 * x + 0.714142842854285 * \ + sh_9_2 * y + 0.824621125123533 * sh_9_3 * z + sh_10_4 = 0.774596669241484 * sh_9_14 * x + 0.173205080756887 * sh_9_16 * x - \ + 0.173205080756888 * sh_9_2 * z + 0.8 * \ + sh_9_3 * y + 0.774596669241483 * sh_9_4 * z + sh_10_5 = 0.724568837309472 * sh_9_13 * x + 0.223606797749979 * sh_9_15 * x - \ + 0.223606797749979 * sh_9_3 * z + 0.866025403784438 * \ + sh_9_4 * y + 0.724568837309472 * sh_9_5 * z + sh_10_6 = 0.674536878161602 * sh_9_12 * x + 0.273861278752583 * sh_9_14 * x - \ + 0.273861278752583 * sh_9_4 * z + 0.916515138991168 * \ + sh_9_5 * y + 0.674536878161602 * sh_9_6 * z + sh_10_7 = 0.62449979983984 * sh_9_11 * x + 0.324037034920393 * sh_9_13 * x - \ + 0.324037034920393 * sh_9_5 * z + 0.953939201416946 * \ + sh_9_6 * y + 0.62449979983984 * sh_9_7 * z + sh_10_8 = 0.574456264653803 * sh_9_10 * x + 0.374165738677394 * sh_9_12 * x - \ + 0.374165738677394 * sh_9_6 * z + 0.979795897113272 * \ + sh_9_7 * y + 0.574456264653803 * sh_9_8 * z + sh_10_9 = 0.424264068711928 * sh_9_11 * x - 0.424264068711929 * sh_9_7 * \ + z + 0.99498743710662 * sh_9_8 * y + 0.741619848709567 * sh_9_9 * x + sh_10_10 = -0.670820393249937 * sh_9_10 * z - \ + 0.670820393249937 * sh_9_8 * x + sh_9_9 * y + sh_10_11 = 0.99498743710662 * sh_9_10 * y - 0.424264068711929 * sh_9_11 * \ + z - 0.424264068711929 * sh_9_7 * x + 0.741619848709567 * sh_9_9 * z + sh_10_12 = 0.574456264653803 * sh_9_10 * z + 0.979795897113272 * sh_9_11 * y - \ + 0.374165738677395 * sh_9_12 * z - 0.374165738677394 * \ + sh_9_6 * x - 0.574456264653803 * sh_9_8 * x + sh_10_13 = 0.62449979983984 * sh_9_11 * z + 0.953939201416946 * sh_9_12 * y - \ + 0.324037034920393 * sh_9_13 * z - 0.324037034920393 * \ + sh_9_5 * x - 0.62449979983984 * sh_9_7 * x + sh_10_14 = 0.674536878161602 * sh_9_12 * z + 0.916515138991168 * sh_9_13 * y - \ + 0.273861278752583 * sh_9_14 * z - 0.273861278752583 * \ + sh_9_4 * x - 0.674536878161603 * sh_9_6 * x + sh_10_15 = 0.724568837309472 * sh_9_13 * z + 0.866025403784439 * sh_9_14 * y - \ + 0.223606797749979 * sh_9_15 * z - 0.223606797749979 * \ + sh_9_3 * x - 0.724568837309472 * sh_9_5 * x + sh_10_16 = 0.774596669241484 * sh_9_14 * z + 0.8 * sh_9_15 * y - 0.173205080756888 * \ + sh_9_16 * z - 0.173205080756887 * sh_9_2 * x - 0.774596669241484 * sh_9_4 * x + sh_10_17 = -0.12247448713916 * sh_9_1 * x + 0.824621125123532 * sh_9_15 * z + \ + 0.714142842854285 * sh_9_16 * y - 0.122474487139158 * \ + sh_9_17 * z - 0.824621125123533 * sh_9_3 * x + sh_10_18 = -0.0707106781186548 * sh_9_0 * x + 0.874642784226796 * 
sh_9_16 * z + \ + 0.6 * sh_9_17 * y - 0.0707106781186546 * \ + sh_9_18 * z - 0.874642784226796 * sh_9_2 * x + sh_10_19 = -0.924662100445348 * sh_9_1 * x + 0.924662100445347 * \ + sh_9_17 * z + 0.435889894354068 * sh_9_18 * y + sh_10_20 = -0.974679434480898 * sh_9_0 * x + 0.974679434480896 * sh_9_18 * z + if lmax == 10: + return ops.stack([ + sh_0_0, + sh_1_0, sh_1_1, sh_1_2, + sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, + sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, + sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, + sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, + sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, + sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, + sh_7_13, sh_7_14, + sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, + sh_8_13, sh_8_14, sh_8_15, sh_8_16, + sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, + sh_9_13, sh_9_14, sh_9_15, sh_9_16, sh_9_17, sh_9_18, + sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, sh_10_7, sh_10_8, sh_10_9, sh_10_10, + sh_10_11, sh_10_12, sh_10_13, sh_10_14, sh_10_15, sh_10_16, sh_10_17, sh_10_18, sh_10_19, sh_10_20 + ], axis=-1) + + sh_11_0 = 0.977008420918394 * sh_10_0 * z + 0.977008420918394 * sh_10_20 * x + sh_11_1 = 0.416597790450531 * sh_10_0 * y + 0.9315409787236 * \ + sh_10_1 * z + 0.931540978723599 * sh_10_19 * x + sh_11_2 = -0.0642824346533223 * sh_10_0 * z + 0.574959574576069 * sh_10_1 * y + \ + 0.88607221316445 * sh_10_18 * x + 0.886072213164452 * \ + sh_10_2 * z + 0.0642824346533226 * sh_10_20 * x + sh_11_3 = -0.111340442853781 * sh_10_1 * z + 0.84060190949577 * sh_10_17 * x + \ + 0.111340442853781 * sh_10_19 * x + 0.686348585024614 * \ + sh_10_2 * y + 0.840601909495769 * sh_10_3 * z + sh_11_4 = 0.795129803842541 * sh_10_16 * x + 0.157459164324444 * sh_10_18 * x - \ + 0.157459164324443 * sh_10_2 * z + 0.771389215839871 * \ + sh_10_3 * y + 0.795129803842541 * sh_10_4 * z + sh_11_5 = 0.74965556829412 * sh_10_15 * x + 0.203278907045435 * sh_10_17 * x - \ + 0.203278907045436 * sh_10_3 * z + 0.838140405208444 * \ + sh_10_4 * y + 0.74965556829412 * sh_10_5 * z + sh_11_6 = 0.70417879021953 * sh_10_14 * x + 0.248964798865985 * sh_10_16 * x - \ + 0.248964798865985 * sh_10_4 * z + 0.890723542830247 * \ + sh_10_5 * y + 0.704178790219531 * sh_10_6 * z + sh_11_7 = 0.658698943008611 * sh_10_13 * x + 0.294579122654903 * sh_10_15 * x - \ + 0.294579122654903 * sh_10_5 * z + 0.9315409787236 * \ + sh_10_6 * y + 0.658698943008611 * sh_10_7 * z + sh_11_8 = 0.613215343783275 * sh_10_12 * x + 0.340150671524904 * sh_10_14 * x - \ + 0.340150671524904 * sh_10_6 * z + 0.962091385841669 * \ + sh_10_7 * y + 0.613215343783274 * sh_10_8 * z + sh_11_9 = 0.567727090763491 * sh_10_11 * x + 0.385694607919935 * sh_10_13 * x - \ + 0.385694607919935 * sh_10_7 * z + 0.983332166035633 * \ + sh_10_8 * y + 0.56772709076349 * sh_10_9 * z + sh_11_10 = 0.738548945875997 * sh_10_10 * x + 0.431219680932052 * sh_10_12 * \ + x - 0.431219680932052 * sh_10_8 * z + 0.995859195463938 * sh_10_9 * y + sh_11_11 = sh_10_10 * y - 0.674199862463242 * \ + sh_10_11 * z - 0.674199862463243 * sh_10_9 * x + sh_11_12 = 0.738548945875996 * sh_10_10 * z + 0.995859195463939 * sh_10_11 * \ + y - 0.431219680932052 * sh_10_12 * z - 0.431219680932053 * sh_10_8 * x + sh_11_13 = 
0.567727090763491 * sh_10_11 * z + 0.983332166035634 * sh_10_12 * y - \ + 0.385694607919935 * sh_10_13 * z - 0.385694607919935 * \ + sh_10_7 * x - 0.567727090763491 * sh_10_9 * x + sh_11_14 = 0.613215343783275 * sh_10_12 * z + 0.96209138584167 * sh_10_13 * y - \ + 0.340150671524904 * sh_10_14 * z - 0.340150671524904 * \ + sh_10_6 * x - 0.613215343783274 * sh_10_8 * x + sh_11_15 = 0.658698943008611 * sh_10_13 * z + 0.9315409787236 * sh_10_14 * y - \ + 0.294579122654903 * sh_10_15 * z - 0.294579122654903 * \ + sh_10_5 * x - 0.65869894300861 * sh_10_7 * x + sh_11_16 = 0.70417879021953 * sh_10_14 * z + 0.890723542830246 * sh_10_15 * y - \ + 0.248964798865985 * sh_10_16 * z - 0.248964798865985 * \ + sh_10_4 * x - 0.70417879021953 * sh_10_6 * x + sh_11_17 = 0.749655568294121 * sh_10_15 * z + 0.838140405208444 * sh_10_16 * y - \ + 0.203278907045436 * sh_10_17 * z - 0.203278907045435 * \ + sh_10_3 * x - 0.749655568294119 * sh_10_5 * x + sh_11_18 = 0.79512980384254 * sh_10_16 * z + 0.77138921583987 * sh_10_17 * y - \ + 0.157459164324443 * sh_10_18 * z - 0.157459164324444 * \ + sh_10_2 * x - 0.795129803842541 * sh_10_4 * x + sh_11_19 = -0.111340442853782 * sh_10_1 * x + 0.84060190949577 * sh_10_17 * z + \ + 0.686348585024614 * sh_10_18 * y - 0.111340442853781 * \ + sh_10_19 * z - 0.840601909495769 * sh_10_3 * x + sh_11_20 = -0.0642824346533226 * sh_10_0 * x + 0.886072213164451 * sh_10_18 * z + \ + 0.57495957457607 * sh_10_19 * y - 0.886072213164451 * \ + sh_10_2 * x - 0.0642824346533228 * sh_10_20 * z + sh_11_21 = -0.9315409787236 * sh_10_1 * x + 0.931540978723599 * \ + sh_10_19 * z + 0.416597790450531 * sh_10_20 * y + sh_11_22 = -0.977008420918393 * sh_10_0 * x + 0.977008420918393 * sh_10_20 * z + return ops.stack([ + sh_0_0, + sh_1_0, sh_1_1, sh_1_2, + sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, + sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, + sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, + sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, + sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, + sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, + sh_7_13, sh_7_14, + sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, + sh_8_13, sh_8_14, sh_8_15, sh_8_16, + sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, + sh_9_13, sh_9_14, sh_9_15, sh_9_16, sh_9_17, sh_9_18, + sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, sh_10_7, sh_10_8, sh_10_9, sh_10_10, sh_10_11, + sh_10_12, sh_10_13, sh_10_14, sh_10_15, sh_10_16, sh_10_17, sh_10_18, sh_10_19, sh_10_20, + sh_11_0, sh_11_1, sh_11_2, sh_11_3, sh_11_4, sh_11_5, sh_11_6, sh_11_7, sh_11_8, sh_11_9, sh_11_10, sh_11_11, + sh_11_12, sh_11_13, sh_11_14, sh_11_15, sh_11_16, sh_11_17, sh_11_18, sh_11_19, sh_11_20, sh_11_21, sh_11_22 + ], axis=-1) diff --git a/mindscience/e3nn/o3/sub.py b/mindscience/e3nn/o3/sub.py new file mode 100644 index 000000000..03ebe60cf --- /dev/null +++ b/mindscience/e3nn/o3/sub.py @@ -0,0 +1,503 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""sub"""
+from typing import NamedTuple
+from mindspore.common.parameter import Parameter
+from mindspore.ops import operations as P
+from mindspore import ops, float32
+from .tensor_product import TensorProduct
+from .irreps import Irreps
+from ..utils.func import narrow
+
+
+class FullyConnectedTensorProduct(TensorProduct):
+    r"""
+    Fully-connected weighted tensor product. All the possible paths allowed
+    by :math:`|l_1 - l_2| \leq l_{out} \leq l_1 + l_2` are made.
+    Equivalent to `TensorProduct` with `instructions='connect'`.
+    For details, see :class:`mindchemistry.e3.o3.TensorProduct`.
+
+    Args:
+        irreps_in1 (Union[str, Irrep, Irreps]): Irreps for the first input.
+        irreps_in2 (Union[str, Irrep, Irreps]): Irreps for the second input.
+        irreps_out (Union[str, Irrep, Irreps]): Irreps for the output.
+        irrep_norm (str): {'component', 'norm'}, the assumed normalization of the input and output representations.
+            Default: 'component'.
+        path_norm (str): {'element', 'path'}, the normalization method of path weights. Default: 'element'.
+        weight_init (str): {'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal',
+            'xavier_uniform'}, the initial method of weights. Default: 'normal'.
+        ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module.
+            Default: ``mindspore.float32`` .
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import FullyConnectedTensorProduct
+        >>> FullyConnectedTensorProduct('2x1o', '1x1o+3x0e', '5x2e+4x1o')
+        TensorProduct [connect] (2x1o x 1x1o+3x0e -> 5x2e+4x1o)
+
+    """
+
+    def __init__(self,
+                 irreps_in1,
+                 irreps_in2,
+                 irreps_out,
+                 ncon_dtype=float32,
+                 **kwargs):
+        super().__init__(irreps_in1,
+                         irreps_in2,
+                         irreps_out,
+                         instructions='connect',
+                         ncon_dtype=ncon_dtype,
+                         **kwargs)
+
+
+class FullTensorProduct(TensorProduct):
+    r"""
+    Full tensor product between two irreps.
+
+    Equivalent to `TensorProduct` with `instructions='full'`.
+    For details, see :class:`mindchemistry.e3.o3.TensorProduct`.
+
+    Args:
+        irreps_in1 (Union[str, Irrep, Irreps]): Irreps for the first input.
+        irreps_in2 (Union[str, Irrep, Irreps]): Irreps for the second input.
+        filter_ir_out (Union[str, Irrep, Irreps, None]): Filter to select only specific `Irrep`
+            of the output. Default: None.
+        irrep_norm (str): {'component', 'norm'}, the assumed normalization of the input and output representations.
+            Default: 'component'.
+        path_norm (str): {'element', 'path'}, the normalization method of path weights. Default: 'element'.
+        weight_init (str): {'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal',
+            'xavier_uniform'}, the initial method of weights. Default: 'normal'.
+        ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module.
+            Default: ``mindspore.float32`` .
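+
+    Note:
+        As a worked example of the coupling rule :math:`|l_1 - l_2| \leq l_{out} \leq l_1 + l_2`:
+        ``1o x 1o`` yields ``0e + 1e + 2e``, since the output degrees run from :math:`|1 - 1| = 0`
+        to :math:`1 + 1 = 2` and the output parity is the product of the input parities,
+        :math:`(-1)\times(-1) = +1`.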
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import FullTensorProduct
+        >>> FullTensorProduct('2x1o+4x0o', '1x1o+3x0e')
+        TensorProduct [full] (2x1o+4x0o x 1x1o+3x0e -> 2x0e+12x0o+6x1o+2x1e+4x1e+2x2e)
+
+    """
+
+    def __init__(self,
+                 irreps_in1,
+                 irreps_in2,
+                 filter_ir_out=None,
+                 ncon_dtype=float32,
+                 **kwargs):
+        super().__init__(irreps_in1,
+                         irreps_in2,
+                         filter_ir_out,
+                         instructions='full',
+                         ncon_dtype=ncon_dtype,
+                         **kwargs)
+
+
+class ElementwiseTensorProduct(TensorProduct):
+    r"""
+    Elementwise connected tensor product.
+
+    Equivalent to `TensorProduct` with `instructions='element'`.
+    For details, see :class:`mindchemistry.e3.o3.TensorProduct`.
+
+    Args:
+        irreps_in1 (Union[str, Irrep, Irreps]): Irreps for the first input.
+        irreps_in2 (Union[str, Irrep, Irreps]): Irreps for the second input.
+        filter_ir_out (Union[str, Irrep, Irreps, None]): Filter to select only specific `Irrep` of the output.
+            Default: None.
+        irrep_norm (str): {'component', 'norm'}, the assumed normalization of the input and output representations.
+            Default: 'component'.
+        path_norm (str): {'element', 'path'}, the normalization method of path weights. Default: 'element'.
+        weight_init (str): {'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal',
+            'xavier_uniform'}, the initial method of weights. Default: 'normal'.
+        ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module.
+            Default: ``mindspore.float32`` .
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import ElementwiseTensorProduct
+        >>> ElementwiseTensorProduct('2x2e+4x1o', '3x1e+3x0o')
+        TensorProduct [element] (2x2e+1x1o+3x1o x 2x1e+1x1e+3x0o -> 2x1e+2x2e+2x3e+1x0o+1x1o+1x2o+3x1e)
+
+    """
+
+    def __init__(self,
+                 irreps_in1,
+                 irreps_in2,
+                 filter_ir_out=None,
+                 ncon_dtype=float32,
+                 **kwargs):
+        super().__init__(irreps_in1,
+                         irreps_in2,
+                         filter_ir_out,
+                         instructions='element',
+                         ncon_dtype=ncon_dtype,
+                         **kwargs)
+
+
+class Linear(TensorProduct):
+    r"""
+    Equivariant linear operation.
+
+    Equivalent to `TensorProduct` with `instructions='linear'`.
+    For details, see :class:`mindchemistry.e3.o3.TensorProduct`.
+
+    Args:
+        irreps_in (Union[str, Irrep, Irreps]): Irreps for the input.
+        irreps_out (Union[str, Irrep, Irreps]): Irreps for the output.
+        irrep_norm (str): {'component', 'norm'}, the assumed normalization of the input and output representations.
+            Default: ``'component'``.
+        path_norm (str): {'element', 'path'}, the normalization method of path weights. Default: ``'element'``.
+        weight_init (str): {'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal',
+            'xavier_uniform'}, the initial method of weights. Default: ``'normal'``.
+        ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module.
+            Default: ``mindspore.float32`` .
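+
+    Note:
+        Only irreps of equal degree and parity are mixed. As a concrete illustration (the count
+        follows from the 'uvw' path shapes), ``Linear('2x0e+3x1o', '4x0e+1x1o')`` carries
+        :math:`2 \times 4 + 3 \times 1 = 11` learnable weights.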
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import Linear
+        >>> Linear('2x2e+3x1o+3x0e', '3x2e+5x1o+2x0e')
+        TensorProduct [linear] (2x2e+3x1o+3x0e x 1x0e -> 3x2e+5x1o+2x0e)
+
+    """
+
+    def __init__(self, irreps_in, irreps_out, ncon_dtype=float32, **kwargs):
+        super().__init__(irreps_in,
+                         None,
+                         irreps_out,
+                         instructions='linear',
+                         ncon_dtype=ncon_dtype,
+                         **kwargs)
+
+
+class Instruction(NamedTuple):
+    i_in: int
+    i_out: int
+    path_shape: tuple
+    path_weight: float
+
+
+def _prod(x):
+    out = 1
+    for i in x:
+        out *= i
+    return out
+
+
+def prod(x):
+    """Compute the product of a sequence."""
+    out = 1
+    for a in x:
+        out *= a
+    return out
+
+
+def _sum_tensors_withbias(xs, shape, dtype):
+    """sum tensors of same irrep."""
+    if xs:
+        if len(xs[0].shape) == 1:
+            out = xs[0]
+        else:
+            out = xs[0].reshape(shape)
+        for x in xs[1:]:
+            if len(x.shape) == 1:
+                out = out + x
+            else:
+                out = out + x.reshape(shape)
+        return out
+    return ops.zeros(shape, dtype=dtype)
+
+
+def _compose(tensors, ir_data, instructions, batch_shape):
+    """compose list of tensor `tensors` into a 1d-tensor by `ir_data`."""
+    res = []
+    for i_out, mir_out in enumerate(ir_data):
+        if mir_out.mul > 0:
+            res.append(
+                _sum_tensors_withbias([
+                    out for ins, out in zip(instructions, tensors)
+                    if ins['i_out'] == i_out
+                ],
+                    shape=batch_shape + (mir_out.dim,),
+                    dtype=tensors[0].dtype))
+
+    if len(res) > 1:
+        res = ops.concat(res, axis=-1)
+    else:
+        res = res[0]
+    return res
+
+
+def _run_continue(ir1_data, ir2_data, irout_data, ins):
+    """check trivial computations"""
+    mir_in1 = ir1_data[ins['indice_one']]
+    mir_in2 = ir2_data[ins['indice_two']]
+    mir_out = irout_data[ins['i_out']]
+    if mir_in1.dim == 0 or mir_in2.dim == 0 or mir_out.dim == 0:
+        return True
+    return False
+
+
+class LinearBias(TensorProduct):
+    r"""
+    Equivariant linear operation with an optional bias.
+
+    Equivalent to `TensorProduct` with `instructions='linear'`, with an option to add a bias. For details,
+    see :class:`mindchemistry.e3.o3.TensorProduct`.
+
+    Args:
+        irreps_in (Union[str, Irrep, Irreps]): Irreps for the input.
+        irreps_out (Union[str, Irrep, Irreps]): Irreps for the output.
+        irrep_norm (str): {'component', 'norm'}, the assumed normalization of the input and output representations.
+            Default: ``'component'``.
+        path_norm (str): {'element', 'path'}, the normalization method of path weights. Default: ``'element'``.
+        weight_init (str): {'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal',
+            'xavier_uniform'}, the initial method of weights. Default: ``'normal'``.
+        has_bias (bool): whether to add a bias to the calculation.
+        ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module.
+            Default: ``mindspore.float32`` .
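+
+    Note:
+        The bias is only attached to scalar (``0e``) outputs, since a constant offset on any
+        non-scalar irrep would break equivariance; for all other irreps `has_bias` has no effect.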
+ + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import LinearBias + >>> LinearBias('2x2e+3x1o+3x0e', '3x2e+5x1o+2x0e') + TensorProduct [linear] (2x2e+3x1o+3x0e x 1x0e -> 3x2e+5x1o+2x0e) + + """ + + def __init__(self, + irreps_in, + irreps_out, + has_bias, + ncon_dtype=float32, + **kwargs): + super().__init__(irreps_in, + None, + irreps_out, + instructions='linear', + ncon_dtype=ncon_dtype, + **kwargs) + irreps_in = Irreps(irreps_in) + irreps_out = Irreps(irreps_out) + + biases = [has_bias and ir.is_scalar() for _, ir in irreps_out] + + is_scalar_num = biases.count(True) + + instructions = [ + Instruction(i_in=-1, + i_out=i_out, + path_shape=(mul_ir.dim,), + path_weight=1.0) + for i_out, (bias, mul_ir) in enumerate(zip(biases, irreps_out)) + if bias + ] + self.has_bias = has_bias + self.bias_numel = None + self.bias_instructions = None + if self.has_bias: + self.bias_instructions = [] + for i_out, (bias, mul_ir) in enumerate(zip(biases, self.irreps_out)): + if bias: + path_shape = (mul_ir.dim,) + path_weight = 1.0 + instruction = Instruction(i_in=-1, i_out=i_out, path_shape=path_shape, path_weight=path_weight) + self.bias_instructions.append(instruction) + + if is_scalar_num == 1: + self.bias_numel = sum(irreps_out.data[i.i_out].dim + for i in instructions if i.i_in == -1) + bias = ops.zeros((self.bias_numel)) + self.bias = Parameter(bias, name="bias") + self.instr.append({ + "i_out": self.bias_instructions[0].i_out, + "indice_one": self.bias_instructions[0].i_in + }) + else: + bias = ops.zeros((is_scalar_num, 1)) + self.bias = Parameter(bias, name="bias") + + for bias_instr in self.bias_instructions: + self.instr.append({ + "i_out": bias_instr.i_out, + "indice_one": bias_instr.i_in + }) + + self.bias_add = P.BiasAdd() + self.ncon_dtype = ncon_dtype + + def construct(self, v1, v2=None, weight=None): + """Implement tensor product for input tensors.""" + self._weight_check(weight) + + if self._in2_is_none: + if v2 is not None: + raise ValueError(f"This tensor product should input 1 tensor.") + + if self._mode == 'linear': + v2_shape = v1.shape[:-1] + (1,) + v2 = ops.ones(v2_shape, v1.dtype) + else: + v2 = v1.copy() + else: + if v2 is None: + raise ValueError( + f"This tensor product should input 2 tensors.") + if self._mode == 'linear': + v2_shape = v1.shape[:-1] + (1,) + v2 = ops.ones(v2_shape, v1.dtype) + + batch_shape = v1.shape[:-1] + + v2s = self.irreps_in2.decompose(v2, batch=True) + v1s = self.irreps_in1.decompose(v1, batch=True) + + weight = self._get_weights(weight) + + if not (v1.shape[-1] == self.irreps_in1.dim + and v2.shape[-1] == self.irreps_in2.dim): + raise ValueError(f"The shape of input tensors do not match.") + + v3_list = [] + weight_ind = 0 + fn = 0 + index_one = 'indice_one' + index_two = 'indice_two' + index_wigner = 'wigner_matrix' + + for ins in self.instr: + if ins[index_one] == -1 or _run_continue(self.irreps_in1.data, + self.irreps_in2.data, + self.irreps_out.data, ins): + continue + fn = self._ncons[ins['i_ncon']] + if ins['has_weight']: + l = _prod(ins['path_shape']) + w = narrow( + weight, -1, weight_ind, + l).reshape(( + (-1,) if self.weight_mode == 'custom' else ()) + + ins['path_shape']).astype(self.ncon_dtype) + weight_ind += l + if self.core_mode == 'einsum': + v3 = fn((ins[index_wigner].astype(self.ncon_dtype), + v1s[ins[index_one]].astype(self.ncon_dtype), + v2s[ins[index_two]].astype(self.ncon_dtype), w)) + else: + v3 = fn([ + ins[index_wigner].astype(self.ncon_dtype), + v1s[ins[index_one]].astype(self.ncon_dtype), + 
v2s[ins[index_two]].astype(self.ncon_dtype), w + ]) + else: + if self.core_mode == 'einsum': + v3 = fn((ins[index_wigner].astype(self.ncon_dtype), + v1s[ins[index_one]].astype(self.ncon_dtype), + v2s[ins[index_two]].astype(self.ncon_dtype))) + else: + v3 = fn([ + ins[index_wigner].astype(self.ncon_dtype), + v1s[ins[index_one]].astype(self.ncon_dtype), + v2s[ins[index_two]].astype(self.ncon_dtype) + ]) + v3_list.append(ins['path_weight'].astype(self.dtype) * + v3.astype(self.dtype)) + + if self.has_bias: + if len(self.bias_instructions) == 1: + v3_list.append(self.bias) + else: + for i in range(len(self.bias_instructions)): + v3_list.append(self.bias[i]) + + v_out = _compose(v3_list, self.irreps_out.data, self.instr, + batch_shape) + + return v_out + + +class TensorSquare(TensorProduct): + r""" + Compute the square tensor product of a tensor. + + Equivalent to `TensorProduct` with `irreps_in2=None and instructions='full' or 'connect'`. For details, + see :class:`mindchemistry.e3.o3.TensorProduct`. + + If `irreps_out` is given, this operation is fully connected. + If `irreps_out` is not given, the operation has no parameter and is like full tensor product. + + Args: + irreps_in (Union[str, Irrep, Irreps]): Irreps for the input. + irreps_out (Union[str, Irrep, Irreps, None]): Irreps for the output. Default: ``None``. + filter_ir_out (Union[str, Irrep, Irreps, None]): Filter to select only specific `Irrep` of the output. + Default: ``None``. + irrep_norm (str): {'component', 'norm'}, the assumed normalization of the input and output representations. + Default: ``'component'``. + path_norm (str): {'element', 'path'}, the normalization method of path weights. Default: ``'element'``. + weight_init (str): {'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal', + 'xavier_uniform'}, the initial method of weights. Default: 'normal'. + ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module. + Default: ``mindspore.float32`` . + + Raises: + ValueError: If both `irreps_out` and `filter_ir_out` are not None. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import TensorSquare + >>> TensorSquare('2x1o', irreps_out='5x2e+4x1e+7x1o') + TensorProduct [connect] (2x1o x 2x1o -> 5x2e+4x1e) + >>> TensorSquare('2x1o+3x0e', filter_ir_out='5x2o+4x1e+2x0e') + TensorProduct [full] (2x1o+3x0e x 2x1o+3x0e -> 4x0e+9x0e+4x1e) + + """ + + def __init__(self, + irreps_in, + irreps_out=None, + filter_ir_out=None, + ncon_dtype=float32, + **kwargs): + if irreps_out is None: + super().__init__(irreps_in, + None, + filter_ir_out, + instructions='full', + ncon_dtype=ncon_dtype, + **kwargs) + else: + if filter_ir_out is None: + super().__init__(irreps_in, + None, + irreps_out, + instructions='connect', + ncon_dtype=ncon_dtype, + **kwargs) + else: + raise ValueError( + "Both `irreps_out` and `filter_ir_out` are not None, this is ambiguous." + ) diff --git a/mindscience/e3nn/o3/tensor_product.py b/mindscience/e3nn/o3/tensor_product.py new file mode 100644 index 000000000..0f281fc11 --- /dev/null +++ b/mindscience/e3nn/o3/tensor_product.py @@ -0,0 +1,768 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore import Tensor, nn, ops, Parameter, get_context, float32, int32, vmap +from mindspore.common.initializer import initializer +import mindspore as ms +from .irreps import Irreps +from .wigner import wigner_3j +from ..utils.ncon import Ncon +from ..utils.func import narrow +from ..utils.initializer import renormal_initializer +import numpy as np +from mindspore.numpy import tensordot + +def _prod(x): + out = 1 + for i in x: + out *= i + return out + + +sqrt = ops.Sqrt() +zeros = ops.Zeros() + +def _sqrt(x, dtype=float32): + """sqrt operator with producing a tensor""" + return sqrt(Tensor(x, dtype=dtype)) + + +def _sum_tensors(xs, shape, dtype): + """sum tensors of same irrep.""" + if len(xs) > 0: + out = xs[0].reshape(shape) + for x in xs[1:]: + out = out + x.reshape(shape) + return out + return zeros(shape, dtype) + + +def _compose(tensors, ir_data, instructions, batch_shape): + """compose list of tensor `tensors` into a 1d-tensor by `ir_data`.""" + res = [] + for i_out, mir_out in enumerate(ir_data): + if mir_out.mul > 0: + res.append(_sum_tensors([out for ins, out in zip(instructions, tensors) + if ins['i_out'] == i_out], shape=batch_shape + (mir_out.dim,), + dtype=tensors[0].dtype)) + if len(res) > 1: + res = ops.concat(res, axis=-1) + else: + res = res[0] + return res + + +def _connect_init(irreps_in1, irreps_in2, irreps_out): + """Input initial for 'connect' mode.""" + full_out = (irreps_in1 * irreps_in2).simplify() + irreps_out = full_out if irreps_out is None else Irreps(irreps_out) + + instr = [] + for i_1, (_, ir_1) in enumerate(irreps_in1.data): + for i_2, (_, ir_2) in enumerate(irreps_in2.data): + ir_out_list = list(ir_1 * ir_2) + for i_out, (_, ir_out) in enumerate(irreps_out.data): + if ir_out in ir_out_list: + instr.append((i_1, i_2, i_out, 'uvw', True)) + + return irreps_out, instr + + +def _full_init(irreps_in1, irreps_in2, irreps_out): + """Input initial for 'full' mode.""" + full_out = irreps_in1 * irreps_in2 + irreps_out = full_out.filter(irreps_out) + + instr = [] + for i_1, (mul_1, ir_1) in enumerate(irreps_in1.data): + for i_2, (mul_2, ir_2) in enumerate(irreps_in2.data): + ir_out_list = list(ir_1 * ir_2) + for i_out, (mul_out, ir_out) in enumerate(irreps_out.data): + if ir_out in ir_out_list and mul_out == mul_1 * mul_2: + instr.append((i_1, i_2, i_out, 'uvuv', False)) + + return irreps_out, instr + + +def _element_init(irreps_in1, irreps_in2, irreps_out): + """Input initial for 'element' mode.""" + irreps_out = None if irreps_out is None else Irreps(irreps_out) + + if not irreps_in1.num_irreps == irreps_in2.num_irreps: + raise ValueError( + f"The total multiplicities of irreps_in1 {irreps_in1} and irreps_in2 {irreps_in2} should be equal.") + + irreps_in1_list = list(Irreps(irreps_in1).simplify().data) + irreps_in2_list = list(Irreps(irreps_in2).simplify().data) + + i = 0 + while i < len(irreps_in1_list): + mul_1, ir_1 = irreps_in1_list[i] + mul_2, ir_2 = irreps_in2_list[i] + + if mul_1 < mul_2: + irreps_in2_list[i] = (mul_1, ir_2) + irreps_in2_list.insert(i + 1, 
(mul_2 - mul_1, ir_2))
+
+        if mul_2 < mul_1:
+            irreps_in1_list[i] = (mul_2, ir_1)
+            irreps_in1_list.insert(i + 1, (mul_1 - mul_2, ir_1))
+        i += 1
+
+    out = []
+    instr = []
+    for i, ((mul, ir_1), (mul_2, ir_2)) in enumerate(zip(irreps_in1_list, irreps_in2_list)):
+        for ir in ir_1 * ir_2:
+            if irreps_out is not None and ir not in irreps_out:
+                continue
+
+            out.append((mul, ir))
+            instr.append((i, i, len(out) - 1, 'uuu', False))
+
+    return Irreps(irreps_in1_list), Irreps(irreps_in2_list), Irreps(out), instr
+
+
+def _linear_init(irreps_in1, irreps_out):
+    """Input initial for 'linear' mode."""
+    irreps_out = Irreps(irreps_out)
+
+    instr = []
+    for i_1, (_, ir_1) in enumerate(irreps_in1.data):
+        for i_out, (_, ir_out) in enumerate(irreps_out.data):
+            if ir_1 == ir_out:
+                instr.append((i_1, 0, i_out, 'uvw', True))
+
+    return irreps_out, instr
+
+
+def _merge_init(irreps_in1, irreps_in2, irreps_out_filter):
+    """Input initial for 'merge' mode."""
+    irreps_out_filter = Irreps(
+        irreps_out_filter) if irreps_out_filter is not None else irreps_in1 * irreps_in2
+
+    irreps_out_list = []
+    instr = []
+    for i_1, (mul, ir_1) in enumerate(irreps_in1.data):
+        for i_2, (_, ir_2) in enumerate(irreps_in2.data):
+            for ir in ir_1 * ir_2:
+                if ir in irreps_out_filter:
+                    k = len(irreps_out_list)
+                    irreps_out_list.append((mul, ir))
+                    instr.append((i_1, i_2, k, 'uvu', True))
+
+    irreps_out = Irreps(irreps_out_list)
+    irreps_out, p, _ = irreps_out.sort()
+
+    instr = [(i_1, i_2, p[i_out], mode, train)
+             for i_1, i_2, i_out, mode, train in instr]
+
+    return irreps_out, instr
+
+
+def _raw_ins_check(mir_in1, mir_in2, mir_out, raw_ins):
+    """Check raw input instructions."""
+    if not mir_in1.ir.p * mir_in2.ir.p == mir_out.ir.p:
+        raise ValueError(
+            f"The parity of inputs and output does not match. \n \
+            {mir_in1.ir.p} * {mir_in2.ir.p} should equal {mir_out.ir.p}.")
+    if not (abs(mir_in1.ir.l - mir_in2.ir.l) <= mir_out.ir.l and mir_out.ir.l <= mir_in1.ir.l + mir_in2.ir.l):
+        raise ValueError(
+            f"The degree of inputs and output does not match. \n \
+            The degrees should be |{mir_in1.ir.l} - {mir_in2.ir.l}| <= {mir_out.ir.l} <= |{mir_in1.ir.l} + {mir_in2.ir.l}|.")
+    if raw_ins[3] not in ['uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv']:
+        raise ValueError(
+            f"The connection mode should be in ['uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv']")
+
+
+def _mode_check(mul_in1, mul_in2, mul_out, ins):
+    """Consistency check for multiplicities."""
+    if ins['mode'] == 'uvw':
+        if not ins['has_weight']:
+            raise ValueError(f"The connection mode 'uvw' should have weights.")
+    elif ins['mode'] == 'uuu':
+        if not (mul_in1 == mul_in2 and mul_in2 == mul_out):
+            raise ValueError(
+                f"The multiplicity of inputs and output does not match. \
+                It should be {mul_in1} == {mul_in2} == {mul_out}.")
+    elif ins['mode'] == 'uuw':
+        if not mul_in1 == mul_in2:
+            raise ValueError(
+                f"The multiplicity of inputs does not match. \
+                It should be {mul_in1} == {mul_in2}.")
+        if not (ins['has_weight'] or mul_out == 1):
+            raise ValueError(
+                f"The multiplicity of the output or 'has_weight' does not match. \
+                If 'has_weight' == False, {mul_out} should equal 1.")
+    elif ins['mode'] == 'uvu':
+        if not mul_in1 == mul_out:
+            raise ValueError(
+                f"The multiplicity of input 1 and output does not match. \
+                It should be {mul_in1} == {mul_out}.")
+    elif ins['mode'] == 'uvv':
+        if not mul_in2 == mul_out:
+            raise ValueError(
+                f"The multiplicity of input 2 and output does not match. 
\ + It should be {mul_in2} == {mul_out}.") + elif ins['mode'] == 'uvuv': + if not mul_in1 * mul_in2 == mul_out: + raise ValueError( + f"The multiplicity of inputs and output do not match. \ + It should be {mul_in1} * {mul_in2} == {mul_out}.") + + +def _init_einsum(mode, ls): + """tensor graph contractions""" + if mode == 'uuu': + einsum = ops.Einsum("ijk,zui,zuj->zuk") + elif mode == 'uuw': + einsum = ops.Einsum("ijk,zui,zuj->zk") + elif mode == 'uvu': + einsum = ops.Einsum("ijk,zui,zvj->zuk") + elif mode == 'uvv': + einsum = ops.Einsum("ijk,zui,zvj->zvk") + elif mode == 'uvuv': + einsum = ops.Einsum("ijk,zui,zvj->zuvk") + return einsum + + +def _init_einsum_weight(mode, weight_mode, ls): + """tensor graph contractions with weights""" + z = "z" if weight_mode == 'custom' else "" + if mode == 'uvw': + einsum = ops.Einsum(f"ijk,zui,zvj,{z}uvw->zwk") + elif mode == 'uuu': + einsum = ops.Einsum(f"ijk,zui,zuj,{z}u->zuk") + elif mode == 'uuw': + einsum = ops.Einsum(f"ijk,zui,zuj,{z}uw->zwk") + elif mode == 'uvu': + einsum = ops.Einsum(f"ijk,zui,zvj,{z}uv->zuk") + elif mode == 'uvv': + einsum = ops.Einsum(f"ijk,zui,zvj,{z}uv->zvk") + elif mode == 'uvuv': + einsum = ops.Einsum(f"ijk,zui,zvj,{z}uv->zuvk") + return einsum + + +def _init_ncon(mode, ls): + """tensor graph contractions""" + if mode == 'uuu': + con_list = [[1, 2, -3], [-1, -2, 1], [-1, -2, 2]] + elif mode == 'uuw': + con_list = [[1, 2, -2], [-1, 3, 1], [-1, 3, 2]] + elif mode == 'uvu': + con_list = [[1, 2, -3], [-1, -2, 1], [-1, 3, 2]] + elif mode == 'uvv': + con_list = [[1, 2, -3], [-1, 3, 1], [-1, -2, 2]] + elif mode == 'uvuv': + con_list = [[1, 2, -4], [-1, -2, 1], [-1, -3, 2]] + ncon = Ncon(con_list) + return ncon + + +class uvw_ncon_v2(nn.Cell): + def __init__(self): + super(uvw_ncon_v2, self).__init__() + self.tensordot1 = tensordot + self.tensordot2 = tensordot + self.tensordot3 = vmap(tensordot, (0,0,None), 0) + def construct(self, m1, m2, m3, m4): + temp1 = self.tensordot1(m3, m1 , [2,1]) + temp2 = self.tensordot1(m2, m4 , [1,0]) + res = self.tensordot3(temp2, temp1, ([0,1],[1,0])) + return res + +def _init_ncon_weight(mode, weight_mode, ls): + """tensor graph contractions with weights""" + if mode == 'uvw': + con_list = [[1, 2, -3], [-1, 3, 1], [-1, 4, 2], [3, 4, -2]] + elif mode == 'uuu': + con_list = [[1, 2, -3], [-1, -2, 1], [-1, -2, 2], [-2]] + elif mode == 'uuw': + con_list = [[1, 2, -3], [-1, 3, 1], [-1, 3, 2], [3, -2]] + elif mode == 'uvu': + con_list = [[1, 2, -3], [-1, -2, 1], [-1, 3, 2], [-2, 3]] + elif mode == 'uvv': + con_list = [[1, 2, -3], [-1, 3, 1], [-1, -2, 2], [3, -2]] + elif mode == 'uvuv': + con_list = [[1, 2, -4], [-1, -2, 1], [-1, -3, 2], [-2, -3]] + if weight_mode == 'custom': + con_list[3] = [-1] + con_list[3] + ncon = Ncon(con_list) + return ncon + + +def _run_continue(ir1_data, ir2_data, irout_data, ins): + """check trivial computations""" + mir_in1 = ir1_data[ins['indice_one']] + mir_in2 = ir2_data[ins['indice_two']] + mir_out = irout_data[ins['i_out']] + if mir_in1.dim == 0 or mir_in2.dim == 0 or mir_out.dim == 0: + return True + return False + + +class TensorProduct(nn.Cell): + r""" + Versatile tensor product operator of two input `Irreps` and a output `Irreps`, that sends two tensors into a tensor + and keep the geometric tensor properties. + This class integrates different typical usages: `TensorSquare`, `FullTensorProduct`, `FullyConnectedTensorProduct`, + `ElementwiseTensorProduct` and `Linear`. + + A `TensorProduct` class defines an algebraic structure with equivariance. 
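+    Equivariance here means that rotating both inputs and then applying the operator gives the same
+    result as applying the operator first and then rotating the output:
+    :math:`f(D_1(g)x, D_2(g)y) = D_3(g)f(x, y)` for every rotation :math:`g`, where each
+    :math:`D_i` denotes the (Wigner-D) representation carried by the corresponding irreps.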
+    Once the `TensorProduct` object is created and initialized, the algorithm is determined. For any given two legal input
+    tensors, this object will provide an output tensor.
+    If the object does not have learnable weights, the output tensor is deterministic.
+    When the learnable weights are introduced, this operator will correspond to a general bilinear, equivariant operation,
+    as a generalization of the standard tensor product.
+
+    If `irreps_in2` is not specified, it will be assigned as `irreps_in1`, corresponding to `TensorSquare`.
+    If `irreps_out` is not specified, this operator will account for all possible output irreps.
+    If both `irreps_out` and `instructions` are not specified, this operator is the standard tensor product without
+    any learnable weights, corresponding to ``FullTensorProduct``.
+
+    Each output irrep should satisfy:
+
+    .. math::
+        |l_1 - l_2| \leq l_{out} \leq l_1 + l_2
+
+        p_1 p_2 = p_{out}
+
+    Args:
+        irreps_in1 (Union[str, Irrep, Irreps]): Irreps for the first input.
+        irreps_in2 (Union[str, Irrep, Irreps, None]): Irreps for the second input. Default: ``None``.
+            If `irreps_in2` is None, `irreps_in2` will be assigned as '0e' for 'linear' instructions, or be assigned as `irreps_in1` otherwise, corresponding to `TensorSquare`.
+        irreps_out (Union[str, Irrep, Irreps, None]): Irreps for the output in 'connect' and custom instructions, or filter irreps for the output otherwise.
+            If `irreps_out` is None, `irreps_out` will be the full tensor product irreps (including all possible paths). Default: ``None``.
+        instructions (Union[str, List[Tuple[int, int, int, str, bool, (float)]]]): List of tensor product path instructions. Default: ``'full'``.
+            For `str` in {'full', 'connect', 'element', 'linear', 'merge'}, the instructions are constructed automatically according to the different modes:
+
+            - 'full': each output irrep, for every pair of input irreps, is created and returned independently. The outputs are not mixed with each other.
+              Corresponding to the standard tensor product `FullTensorProduct` if `irreps_out` is not specified.
+            - 'connect': each output is a learned weighted sum of compatible paths. This allows the operator to produce outputs with any multiplicity.
+              Corresponding to `FullyConnectedTensorProduct`.
+            - 'element': the irreps are multiplied one-by-one. The inputs will be split so that the multiplicities of the outputs match the multiplicities of the inputs.
+              Corresponding to `ElementwiseTensorProduct`.
+            - 'linear': linear operation equivariant on the first irreps, while the second irreps is set to be '0e'. This can be regarded as the geometric-tensor version of the dense layer.
+              Corresponding to `Linear`.
+            - 'merge': Automatically build 'uvu' mode instructions with trainable parameters. The `irreps_out` here plays the role of output filters.
+
+            For `List[Tuple[int, int, int, str, bool, (float)]]`, the instructions are constructed manually.
+
+            Each instruction contains a tuple: (indice_one, indice_two, i_out, mode, has_weight, (optional: path_weight)).
+            Each instruction puts ``in1[indice_one]`` :math:`\otimes` ``in2[indice_two]`` into ``out[i_out]``.
+
+            - `indice_one`, `indice_two`, `i_out`: int, the index of the irrep in irreps for `irreps_in1`, `irreps_in2` and `irreps_out` correspondingly.
+            - `mode`: str in {'uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv'}, the way the multiplicities of each path are treated. 'uvw' is the fully mixed mode.
+            - `has_weight`: bool, `True` if this path should have learnable weights, otherwise `False`.
+            - `path_weight`: float, a multiplicative weight to apply to the output of this path. Default: 1.0.
+
+        irrep_norm (str): {'component', 'norm'}, the assumed normalization of the input and output representations. Default: ``'component'``.
+
+            - 'norm': :math:`\| x \| = \| y \| = 1 \Longrightarrow \| x \otimes y \| = 1`
+
+        path_norm (str): {'element', 'path'}, the normalization method of path weights. Default: ``'element'``.
+
+            - 'element': each output is normalized by the total number of elements (independently of their paths).
+            - 'path': each path is normalized by the total number of elements in the path, then each output is normalized by the number of paths.
+
+        weight_init (str): {'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal',
+            'xavier_uniform'}, the initial method of weights. Default: ``'normal'``.
+        weight_mode (str): {'inner', 'share', 'custom'} determines the weights' mode. Default: ``'inner'``.
+
+            - 'inner': weights will be initialized in the tensor product internally.
+            - 'share': weights should be given manually without batch dimension.
+            - 'custom': weights should be given manually with batch dimension.
+
+        core_mode (str): {'ncon', 'einsum'} determines the core computation mode. Default: ``'ncon'``.
+        dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32`` .
+        ncon_dtype (mindspore.dtype): The type of input tensors of ncon computation module.
+            Default: ``mindspore.float32`` .
+
+    Inputs:
+        - **x** (Tensor) - The shape of Tensor is ``(..., irreps_in1.dim)``
+        - **y** (Tensor) - The shape of Tensor is ``(..., irreps_in2.dim)``
+        - **weight** (Tensor) - `Tensor` or list of `Tensor`, optional.
+          Required if `weight_mode` is not ``'inner'``.
+          The shape of Tensor is ``(self.weight_numel,)`` if `weight_mode` is ``'share'``.
+          The shape of Tensor is ``(..., self.weight_numel)`` if `weight_mode` is ``'custom'``,
+          or a list of tensors of shapes ``weight_shape`` / ``(...) + weight_shape``.
+          Use ``self.instructions`` to know what the weights are used for.
+
+    Outputs:
+        - **outputs** (Tensor) - The shape of Tensor is ``(..., irreps_out.dim)``.
+
+    Raises:
+        ValueError: If `irreps_out` is not legal.
+        ValueError: If the connection mode is not in ['uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv'].
+        ValueError: If the degree of inputs and output does not match.
+        ValueError: If the parity of inputs and output does not match.
+        ValueError: If the multiplicity of inputs and output does not match.
+        ValueError: If the connection mode is 'uvw', but `has_weight` is `False`.
+        ValueError: If the connection mode is 'uuw' and `has_weight` is `False`, but the multiplicity is not equal to 1.
+        ValueError: If the initial method is not supported.
+        ValueError: If the number of input tensors does not match the number of input irreps.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> import numpy as np
+        >>> from mindchemistry.e3.o3 import TensorProduct
+        Standard tensor product:
+        >>> tp1 = TensorProduct('2x1o+4x0o', '1x1o+3x0e')
+        TensorProduct [full] (2x1o+4x0o x 1x1o+3x0e -> 2x0e+12x0o+6x1o+2x1e+4x1e+2x2e)
+        >>> v1 = ms.Tensor(np.linspace(1., 2., tp1.irreps_in1.dim), dtype=ms.float32)
+        >>> v2 = ms.Tensor(np.linspace(2., 3., tp1.irreps_in2.dim), dtype=ms.float32)
+        >>> tp1(v1, v2).shape
+        (1, 60)
+        Elementwise tensor product:
+        >>> tp2 = TensorProduct('2x2e+4x1o', '3x1e+3x0o')
+        TensorProduct [element] (2x2e+1x1o+3x1o x 2x1e+1x1e+3x0o -> 2x1e+2x2e+2x3e+1x0o+1x1o+1x2o+3x1e)
+        >>> tp2.instructions
+        [(0, 0, 0, 'uuu', False), (0, 0, 1, 'uuu', False), (0, 0, 2, 'uuu', False), (1, 1, 3, 'uuu', False),
+        (1, 1, 4, 'uuu', False), (1, 1, 5, 'uuu', False), (2, 2, 6, 'uuu', False)]
+        Custom tensor product with learnable weights:
+        >>> tp3 = TensorProduct(
+        ...     '3x2o+2x1o', '2x2e+4x1o+5x0e', '2x3o+8x1e+10x1o',
+        ...     [
+        ...         (0,0,0,'uvv',True),
+        ...         (1,0,0,'uuu',True),
+        ...         (1,1,1,'uvuv',True),
+        ...         (1,2,2,'uvw',True)
+        ...     ]
+        ... )
+        TensorProduct [custom] (3x2o+2x1o x 2x2e+4x1o+5x0e -> 2x3o+8x1e+10x1o)
+        >>> [w.shape for w in tp3.weights]
+        [(3, 2), (2,), (2, 4), (2, 5, 10)]
+        Fully connected tensor square with an output filter:
+        >>> tp4 = TensorProduct('2x1o', irreps_out='5x2e+4x1e+7x1o', instructions='connect')
+        TensorProduct [connect] (2x1o x 2x1o -> 5x2e+4x1e)
+        >>> v1 = ms.Tensor(np.linspace(1., 2., tp4.irreps_in1.dim), dtype=ms.float32)
+        >>> tp4(v1).shape
+        (1, 37)
+    """
+    __slots__ = ('irreps_in1', 'irreps_in2', 'irreps_out',
+                 'weights', '_in2_is_none', '_mode', '_device', 'output_mask', 'core_mode')
+
+    def __init__(
+            self,
+            irreps_in1,
+            irreps_in2=None,
+            irreps_out=None,
+            instructions='full',
+            dtype=float32,
+            irrep_norm='component',
+            path_norm='element',
+            weight_init='normal',
+            weight_mode='inner',
+            core_mode='ncon',
+            ncon_dtype=float32
+    ):
+        super().__init__()
+
+        if weight_mode not in ['inner', 'share', 'custom']:
+            raise ValueError(
+                f"`weight_mode` should be one of ['inner', 'share', 'custom'].")
+        if core_mode not in ['ncon', 'einsum']:
+            raise ValueError(
+                f"`core_mode` should be one of ['ncon', 'einsum'].")
+        elif core_mode == 'einsum' and get_context('device_target') != 'GPU':
+            raise ValueError(
+                f"The `core_mode`: einsum only supports GPU, but got {get_context('device_target')}.")
+        self.weight_mode = weight_mode
+        self.dtype = dtype
+        self.core_mode = core_mode
+        self.ones = ops.Ones()
+        self.zeros = ops.Zeros()
+
+        self.irreps_in1 = Irreps(irreps_in1).simplify()
+        if irreps_in2 is None:
+            self.irreps_in2 = Irreps(irreps_in1).simplify()
+            self._in2_is_none = True
+        else:
+            self.irreps_in2 = Irreps(irreps_in2).simplify()
+            self._in2_is_none = False
+
+        self.irreps_out, instructions = self._input_init(
+            self.irreps_in1, self.irreps_in2, irreps_out, instructions)
+
+        self.instr, self._ncons = self._ins_init(instructions)
+
+        self.weight_numel = sum(_prod(ins['path_shape'])
+                                for ins in self.instr if ins['has_weight'])
+
+        self.weights = self._weight_init(weight_init)
+
+        self.output_mask = self._init_mask()
+
+        self._normalization(irrep_norm=irrep_norm, path_norm=path_norm)
+
+        self.ncon_dtype = ncon_dtype
+
+    def construct(self, v1, v2=None, weight=None):
+        """Implement tensor product for input tensors."""
+        self._weight_check(weight)
+
+        if self._in2_is_none:
+            if v2 is not None:
+                raise ValueError(f"This tensor product should input 1 tensor.")
+
+            if self._mode == 
'linear': + v2_shape = v1.shape[:-1] + (1,) + v2 = self.ones(v2_shape, v1.dtype) + else: + v2 = v1.copy() + else: + if v2 is None: + raise ValueError( + f"This tensor product should input 2 tensors.") + if self._mode == 'linear': + v2_shape = v1.shape[:-1] + (1,) + v2 = self.ones(v2_shape, v1.dtype) + + batch_shape = v1.shape[:-1] + v1s = self.irreps_in1.decompose(v1, batch=True) + v2s = self.irreps_in2.decompose(v2, batch=True) + weight = self._get_weights(weight) + if not (v1.shape[-1] == self.irreps_in1.dim and v2.shape[-1] == self.irreps_in2.dim): + raise ValueError(f"The shape of input tensors do not match.") + + v3_list = [] + weight_ind = 0 + fn = 0 + + for ins in self.instr: + if _run_continue(self.irreps_in1.data, self.irreps_in2.data, self.irreps_out.data, ins): + continue + fn = self._ncons[ins['i_ncon']] + if ins['has_weight']: + l = _prod(ins['path_shape']) + w = narrow(weight, -1, weight_ind, l).reshape(((-1,) + if self.weight_mode == 'custom' else ()) + ins['path_shape']).astype(self.ncon_dtype) + weight_ind += l + if self.core_mode == 'einsum': + v3 = fn((ins['wigner_matrix'].astype(self.ncon_dtype), v1s[ins['indice_one']].astype(self.ncon_dtype), v2s[ins['indice_two']].astype(self.ncon_dtype), w)) + else: + v3 = fn([ins['wigner_matrix'].astype(self.ncon_dtype), v1s[ins['indice_one']].astype(self.ncon_dtype), v2s[ins['indice_two']].astype(self.ncon_dtype), w]) + else: + if self.core_mode == 'einsum': + v3 = fn((ins['wigner_matrix'].astype(self.ncon_dtype), v1s[ins['indice_one']].astype(self.ncon_dtype), v2s[ins['indice_two']].astype(self.ncon_dtype))) + else: + v3 = fn([ins['wigner_matrix'].astype(self.ncon_dtype), v1s[ins['indice_one']].astype(self.ncon_dtype), v2s[ins['indice_two']].astype(self.ncon_dtype)]) + v3_list.append(ins['path_weight'].astype(self.dtype) * v3.astype(self.dtype)) + + v_out = _compose(v3_list, self.irreps_out.data, self.instr, batch_shape) + return v_out + + def __repr__(self): + return f'TensorProduct [{self._mode}] ({self.irreps_in1.simplify().__repr__()} x {self.irreps_in2.simplify().__repr__()} -> {self.irreps_out.simplify().__repr__()} | {self.weight_numel} weights)' + + @property + def instructions(self): + return [tuple(ins.values())[:5] for ins in self.instr] + + def _input_init(self, irreps_in1, irreps_in2, irreps_out, instructions): + if not isinstance(instructions, str): + irreps_out = irreps_in1 * \ + irreps_in2 if irreps_out is None else Irreps(irreps_out) + self._mode = 'custom' + else: + if instructions == 'connect': + irreps_out, instructions = _connect_init( + irreps_in1, irreps_in2, irreps_out) + self._mode = 'connect' + + elif instructions == 'full': + irreps_out, instructions = _full_init( + irreps_in1, irreps_in2, irreps_out) + self._mode = 'full' + + elif instructions == 'element': + self.irreps_in1, self.irreps_in2, irreps_out, instructions = _element_init( + irreps_in1, irreps_in2, irreps_out) + self._mode = 'element' + + elif instructions == 'linear': + self.irreps_in2 = Irreps('0e') + irreps_out, instructions = _linear_init(irreps_in1, irreps_out) + self._mode = 'linear' + + elif instructions == 'merge': + irreps_out, instructions = _merge_init( + irreps_in1, irreps_in2, irreps_out) + self._mode = 'merge' + + else: + raise ValueError( + f"Unexpected instructions mode {instructions}") + + return irreps_out, instructions + + def _ins_init(self, raw_ins): + """reform instructions""" + raw_ins = [x if len(x) == 6 else x + (1.0,) for x in raw_ins] + res = [] + ncons = [] + + for ins in raw_ins: + indice_one = ins[0] + indice_two 
= ins[1] + i_out = ins[2] + mode = ins[3] + has_weight = ins[4] + path_weight = ins[5] + + mirs = ( + self.irreps_in1.data[indice_one], self.irreps_in2.data[indice_two], self.irreps_out.data[i_out]) + muls = (mirs[0].mul, mirs[1].mul, mirs[2].mul) + + _raw_ins_check(*mirs, ins) + + path_shape = { + 'uvw': (muls[0], muls[1], muls[2]), + 'uvu': (muls[0], muls[1]), + 'uvv': (muls[0], muls[1]), + 'uuw': (muls[0], muls[2]), + 'uuu': (muls[0],), + 'uvuv': (muls[0], muls[1]), + }[mode] + + num_elements = { + 'uvw': (muls[0] * muls[1]), + 'uvu': muls[1], + 'uvv': muls[0], + 'uuw': muls[0], + 'uuu': 1, + 'uvuv': 1, + }[mode] + + ls = (mirs[0].ir.l, mirs[1].ir.l, mirs[2].ir.l) + + d, op = self._ins_dict(indice_one, indice_two, i_out, mode, has_weight, + path_weight, path_shape, num_elements, wigner_3j(*ls, self.dtype), ls) + ncons.append(op) + d['i_ncon'] = len(ncons) - 1 + res.append(d) + + _mode_check(*muls, res[-1]) + + return res, ncons + + def _ins_dict(self, *args): + """generate reformed instructions""" + d = {} + keys = ['indice_one', 'indice_two', 'i_out', 'mode', 'has_weight', + 'path_weight', 'path_shape', 'num_elements', 'wigner_matrix', 'ls'] + for i, arg in enumerate(args): + d[keys[i]] = arg + + if d['has_weight']: + if self.core_mode == 'einsum': + operator = _init_einsum_weight( + d['mode'], self.weight_mode, d['ls']) + else: + operator = _init_ncon_weight( + d['mode'], self.weight_mode, d['ls']) + else: + if self.core_mode == 'einsum': + operator = _init_einsum(d['mode'], d['ls']) + else: + operator = _init_ncon(d['mode'], d['ls']) + + return d, operator + + def _weight_init(self, init_method): + """init weights""" + init_method = renormal_initializer(init_method) + + if self.weight_numel > 0 and self.weight_mode == 'inner': + weights = Parameter(initializer(init_method, (1, self.weight_numel), dtype=self.dtype).init_data().flatten()) + else: + weights = None + + return weights + + def _init_mask(self): + if self.irreps_out.dim > 0: + output_mask = ops.cat([ + self.ones(mul * ir.dim, int32) + if any( + (ins['i_out'] == i_out) and (ins['path_weight'] + != 0) and (0 not in ins['path_shape']) + for ins in self.instr + ) + else self.zeros(mul * ir.dim, int32) + for i_out, (mul, ir) in enumerate(self.irreps_out.data) + ]) + else: + output_mask = Tensor(0) + + return output_mask + + def _normalization(self, irrep_norm, path_norm): + """path normalization""" + for ins in self.instr: + mir_in1 = self.irreps_in1.data[ins['indice_one']] + mir_in2 = self.irreps_in2.data[ins['indice_two']] + mir_out = self.irreps_out.data[ins['i_out']] + + alpha = 1. + if irrep_norm == 'component': + alpha = mir_out.ir.dim + if irrep_norm == 'norm': + alpha = mir_in1.ir.dim * mir_in2.ir.dim + + x = 1. 
+            if path_norm == 'element':
+                x = sum(i['num_elements']
+                        for i in self.instr if i['i_out'] == ins['i_out'])
+            if path_norm == 'path':
+                x = ins['num_elements']
+                x *= len([i for i in self.instr if i['i_out']
+                          == ins['i_out']])
+
+            if x > 0.0:
+                alpha /= x
+
+            alpha *= ins['path_weight']
+            ins['path_weight'] = _sqrt(alpha, self.dtype)
+
+    def _weight_check(self, weight):
+        if self.weight_mode == 'inner':
+            if weight is None:
+                return True
+            raise ValueError(
+                f"For `weight_mode` {self.weight_mode}, the `weight` should not be given manually.")
+        elif self.weight_mode == 'share':
+            if weight is None:
+                raise ValueError(
+                    f"For `weight_mode` {self.weight_mode}, the `weight` should be given manually.")
+            if not weight.ndim == 1:
+                raise ValueError(
+                    f"The shape of the shared weight {weight.shape} is illegal.")
+        elif self.weight_mode == 'custom':
+            if weight is None:
+                raise ValueError(
+                    f"For `weight_mode` {self.weight_mode}, the `weight` should be given manually.")
+            if not weight.ndim > 1:
+                raise ValueError(
+                    f"Custom weight {weight} should have a batch dimension if `weight_mode` is `'custom'`.")
+        else:
+            raise ValueError(f"Unknown `weight_mode`: {self.weight_mode}.")
+        return True
+
+    def _get_weights(self, weight):
+        if weight is None:
+            return self.weights
+        else:
+            return weight.reshape(-1, self.weight_numel)
diff --git a/mindscience/e3nn/o3/wigner.py b/mindscience/e3nn/o3/wigner.py
new file mode 100644
index 000000000..bd086be33
--- /dev/null
+++ b/mindscience/e3nn/o3/wigner.py
@@ -0,0 +1,336 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import functools
+import math
+from fractions import Fraction
+from math import factorial
+
+import numpy as np
+
+from mindspore import Tensor, ops, float32, float64, complex64, complex128
+
+from ..utils.func import _ndexpm, broadcast_args, _expand_last_dims
+
+PI = Tensor(math.pi)
+
+
+def change_basis_real_to_complex(l, dtype=float32):
+    r"""
+    Convert a real basis of spherical harmonics to the corresponding complex basis.
+
+    Args:
+        l (int): degree of spherical harmonics.
+        dtype (dtype): {float32, float64} data type of the real basis. Default: float32.
+
+    Returns:
+        Tensor, the complex basis with dtype complex64 for `dtype` = float32 and complex128 for `dtype` = float64.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import change_basis_real_to_complex
+        >>> m = change_basis_real_to_complex(1)
+        >>> print(m)
+        [[-0.70710677+0.j          0.        +0.j          0.        -0.70710677j]
+         [ 0.        +0.j          0.        -1.j          0.        +0.j        ]
+         [-0.70710677+0.j          0.        +0.j          0.        
+0.70710677j]] + """ + q = np.zeros((2 * l + 1, 2 * l + 1), np.complex128) + for m in range(-l, 0): + q[l + m, l + abs(m)] = 1 / 2 ** 0.5 + q[l + m, l - abs(m)] = -1j / 2 ** 0.5 + q[l, l] = 1 + for m in range(1, l + 1): + q[l + m, l + abs(m)] = (-1) ** m / 2 ** 0.5 + q[l + m, l - abs(m)] = 1j * (-1) ** m / 2 ** 0.5 + q = (-1j) ** l * q + + dtype = { + float32: complex64, + float64: complex128, + }[dtype] + + q_new = Tensor(q, dtype=dtype) + return q_new + + +def su2_generators(j, dtype=complex64): + r""" + Compute the su(2) Lie algebra generators. + + Args: + j (int): degree of generators. + dtype (dtype): {complex64, complex128} data type of generators. Default: complex64. + + Returns: + Tensor, su(2) generators with the dtype is `dtype`. + + Raise: + TypeError: If `j` is not int. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import su2_generators + >>> m = su2_generators(1) + >>> print(m) + [[[ 0. +0.j 0.70710677+0.j + 0. +0.j ] + [-0.70710677+0.j 0. +0.j + 0.70710677+0.j ] + [ 0. +0.j -0.70710677+0.j + 0. +0.j ]] + [[-0. -1.j 0. +0.j + 0. +0.j ] + [ 0. +0.j 0. +0.j + 0. +0.j ] + [ 0. +0.j 0. +0.j + 0. +1.j ]] + [[ 0. -0.j 0. +0.70710677j + 0. -0.j ] + [ 0. +0.70710677j 0. -0.j + 0. +0.70710677j] + [ 0. -0.j 0. +0.70710677j + 0. -0.j ]]] + """ + if not isinstance(j, int): + raise TypeError + m = np.arange(-j, j) + raising = np.diag(-np.sqrt(j * (j + 1) - m * (m + 1)), k=-1) + + m = np.arange(-j + 1, j + 1) + lowering = np.diag(np.sqrt(j * (j + 1) - m * (m - 1)), k=1) + + m = np.arange(-j, j + 1) + res = np.stack([ + 0.5 * (raising + lowering), # x (usually) + np.diag(1j * m), # z (usually) + -0.5j * (raising - lowering), # -y (usually) + ], axis=0) + return Tensor(res, dtype=dtype) + + +def so3_generators(l, dtype=float32): + r""" + Compute the so(3) Lie algebra generators. + + Args: + l (int): degree of generators. + dtype (dtype): {float32, float64} data type of generators. Default: float32. + + Returns: + Tensor, so(3) generators with the dtype is `dtype`. + + Raise: + TypeError: If `l` is not int. + ValueError: If matrices data are inconsistent. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.o3 import so3_generators + >>> m = so3_generators(1) + >>> print(m) + [[[ 0. 0. 0. ] + [ 0. 0. -0.99999994] + [ 0. 0.99999994 0. ]] + [[ 0. 0. 0.99999994] + [ 0. 0. 0. ] + [-0.99999994 0. 0. ]] + [[ 0. -0.99999994 0. ] + [ 0.99999994 0. 0. ] + [ 0. 0. 0. ]]] + """ + if not isinstance(l, int): + raise TypeError + cdtype = { + float32: complex64, + float64: complex128, + }[dtype] + X = su2_generators(l, dtype=cdtype).asnumpy() + Q = change_basis_real_to_complex(l, dtype=dtype).asnumpy() + X = np.conj(Q.T) @ X @ Q + + if not np.all(np.abs(np.imag(X)) < 1e-5): + raise ValueError + X_real = np.real(X) + return Tensor(X_real, dtype=dtype) + + +def wigner_D(l, alpha, beta, gamma): + r""" + Wigner D matrix representation of SO(3). + + It satisfies the following properties: + * :math:`D(\text{identity rotation}) = \text{identity matrix}` + * :math:`D(R_1 \circ R_2) = D(R_1) \circ D(R_2)` + * :math:`D(R^{-1}) = D(R)^{-1} = D(R)^T` + + Args: + l (int): degree of representation. + alpha (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\alpha` around Y axis, applied third. + beta (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\beta` around X axis, applied second. 
+        gamma (Union[Tensor[float32], List[float], Tuple[float], ndarray[np.float32], float]): rotation :math:`\gamma` around Y axis, applied first.
+
+    Returns:
+        Tensor, Wigner D matrix :math:`D^l(\alpha, \beta, \gamma)`. The shape of Tensor is :math:`(2l+1, 2l+1)`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import wigner_D
+        >>> m = wigner_D(1,1,1)
+        >>> print(m)
+        [[-0.09064701  0.7080733   0.70029646]
+         [ 0.7080733   0.54030234 -0.45464867]
+         [-0.7002964   0.45464864 -0.5503447 ]]
+
+    """
+
+    alpha, beta, gamma = broadcast_args(alpha, beta, gamma)
+    alpha = _expand_last_dims(alpha) % (2 * PI)
+    beta = _expand_last_dims(beta) % (2 * PI)
+    gamma = _expand_last_dims(gamma) % (2 * PI)
+    X = so3_generators(l)
+    return ops.matmul(ops.matmul(_ndexpm(alpha * X[1]), _ndexpm(beta * X[0])), _ndexpm(gamma * X[1]))
+
+
+def wigner_3j(l1, l2, l3, dtype=float32):
+    r"""
+    Wigner 3j symbols :math:`C_{lmn}`.
+
+    It satisfies the following two properties:
+
+    .. math::
+        C_{lmn} = C_{ijk} D_{il}(g) D_{jm}(g) D_{kn}(g) \qquad \forall g \in SO(3)
+
+    where :math:`D` are given by `wigner_D`.
+
+    .. math::
+        C_{ijk} C_{ijk} = 1
+
+    Args:
+        l1 (int): :math:`l_1` parameter of ``wigner_3j``.
+        l2 (int): :math:`l_2` parameter of ``wigner_3j``.
+        l3 (int): :math:`l_3` parameter of ``wigner_3j``.
+        dtype (mindspore.dtype): The type of input tensor. Default: ``mindspore.float32`` .
+
+    Returns:
+        Tensor, Wigner 3j symbols :math:`C_{lmn}`. The shape of Tensor is :math:`(2l_1+1, 2l_2+1, 2l_3+1)`.
+
+    Raise:
+        TypeError: If `l1`, `l2` or `l3` are not int.
+        ValueError: If `l1`, `l2` and `l3` do not satisfy abs(l2 - l3) <= l1 <= l2 + l3.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.o3 import wigner_3j
+        >>> m = wigner_3j(1,1,1)
+        >>> print(m)
+        [[[ 0.         0.         0.       ]
+          [ 0.         0.         0.4082483]
+          [ 0.        -0.4082483  0.       ]]
+         [[ 0.         0.        -0.4082483]
+          [ 0.         0.         0.       ]
+          [ 0.4082483  0.         0.       ]]
+         [[ 0.         0.4082483  0.       ]
+          [-0.4082483  0.         0.       ]
+          [ 0.         0.         0.       ]]]
+    """
+    if not (isinstance(l1, int) and isinstance(l2, int) and isinstance(l3, int)):
+        raise TypeError
+    if not abs(l2 - l3) <= l1 <= l2 + l3:
+        raise ValueError(
+            f"The input degrees \"{l1}\" and \"{l2}\" do not match the output degree \"{l3}\". 
\nThe degrees should be |{l1} - {l2}| <= {l3} <= |{l1} + {l2}|.") + C = _so3_clebsch_gordan(l1, l2, l3) + + return Tensor(C, dtype=dtype) + + +@functools.lru_cache(maxsize=None) +def _so3_clebsch_gordan(l1, l2, l3, dtype=float64): + """Calculates the Clebsch-Gordon matrix for SO(3) coupling l1 and l2 to give l3.""" + Q1 = change_basis_real_to_complex(l1, dtype=dtype).asnumpy() + Q2 = change_basis_real_to_complex(l2, dtype=dtype).asnumpy() + Q3 = change_basis_real_to_complex(l3, dtype=dtype).asnumpy() + C = _su2_clebsch_gordan(l1, l2, l3) + + C = np.einsum('ij,kl,mn,ikn->jlm', Q1, Q2, np.conj(Q3.T), C) + + if not np.all(np.abs(np.imag(C)) < 1e-5): + raise ValueError + C = np.real(C) + + C = C / np.linalg.norm(C) + return C + + +@functools.lru_cache(maxsize=None) +def _su2_clebsch_gordan(j1, j2, j3): + """Calculates the Clebsch-Gordon matrix for SU(2) coupling j1 and j2 to give j3.""" + if not (isinstance(j1, (int, float)) and isinstance(j2, (int, float)) and isinstance(j3, (int, float))): + raise TypeError + mat = np.zeros((int(2 * j1 + 1), int(2 * j2 + 1), + int(2 * j3 + 1)), np.float64) + if int(2 * j3) in range(int(2 * abs(j1 - j2)), int(2 * (j1 + j2)) + 1, 2): + for m1 in (x / 2 for x in range(-int(2 * j1), int(2 * j1) + 1, 2)): + for m2 in (x / 2 for x in range(-int(2 * j2), int(2 * j2) + 1, 2)): + if abs(m1 + m2) <= j3: + mat[int(j1 + m1), int(j2 + m2), int(j3 + m1 + m2) + ] = _su2_clebsch_gordan_coeff((j1, m1), (j2, m2), (j3, m1 + m2)) + + return mat + + +def _su2_clebsch_gordan_coeff(idx1, idx2, idx3): + """core function of the Clebsch-Gordon coefficient for SU(2) coupling (j1,m1) and (j2,m2) to give (j3,m3).""" + + j1, m1 = idx1 + j2, m2 = idx2 + j3, m3 = idx3 + + if m3 != m1 + m2: + return 0 + vmin = int(max([-j1 + j2 + m3, -j1 + m1, 0])) + vmax = int(min([j2 + j3 + m1, j3 - j1 + j2, j3 + m3])) + + def f(n): + if not n == round(n): + raise ValueError + return factorial(round(n)) + + C = ( + (2.0 * j3 + 1.0) * Fraction( + f(j3 + j1 - j2) * f(j3 - j1 + j2) * + f(j1 + j2 - j3) * f(j3 + m3) * f(j3 - m3), + f(j1 + j2 + j3 + 1) * f(j1 - m1) * + f(j1 + m1) * f(j2 - m2) * f(j2 + m2) + ) + ) ** 0.5 + + S = 0 + for v in range(vmin, vmax + 1): + S += (-1) ** int(v + j2 + m2) * Fraction( + f(j2 + j3 + m1 - v) * f(j1 - m1 + v), + f(v) * f(j3 - j1 + j2 - v) * f(j3 + m3 - v) * f(v + j1 - j2 - m3) + ) + C = C * S + return C diff --git a/mindscience/e3nn/utils/__init__.py b/mindscience/e3nn/utils/__init__.py new file mode 100644 index 000000000..c67bd0f61 --- /dev/null +++ b/mindscience/e3nn/utils/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""init"""
+from .ncon import Ncon
+from .radius import radius, radius_graph, radius_full, radius_graph_full
+
+
+__all__ = [
+    "Ncon",
+    "radius",
+    "radius_graph",
+    "radius_full",
+    "radius_graph_full",
+]
\ No newline at end of file
diff --git a/mindscience/e3nn/utils/batch_dot.py b/mindscience/e3nn/utils/batch_dot.py
new file mode 100644
index 000000000..cf578e6bd
--- /dev/null
+++ b/mindscience/e3nn/utils/batch_dot.py
@@ -0,0 +1,152 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore.ops.primitive import constexpr
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+
+
+@constexpr
+def _get_batch_size(x1_shape, x2_shape, prim_name=None):
+    """
+    Get batch sizes from two inputs
+    """
+    return x1_shape[0], x2_shape[0]
+
+
+@constexpr
+def _calc_new_shape_batchdot(shape, axes, position=0):
+    """
+    Calculate transpose and reshape parameters for input transformations,
+    'position' refers to whether tensor is first or second in the op.
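+    For example (an illustrative trace of the logic below): for shape (B, 2, 3, 4) contracted
+    over axis 3 with position=0, the returned values are new_shape=(B, 6, 4),
+    transpose_perm=(0, 1, 2, 3) and free_dims=(2, 3).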
+ """ + axis = axes[position] + contraction_axes = tuple([axis]) + prod_contraction = 1 + for i in contraction_axes: + prod_contraction *= shape[i] + free_axes = tuple(i for i in range(1, len(shape)) if i not in contraction_axes) + free_dims = tuple(shape[i] for i in free_axes) + prod_free = 1 + for free_dim in free_dims: + prod_free *= free_dim + + transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes + transpose_perm = tuple([0]) + transpose_perm + new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction) + new_shape = tuple([shape[0]]) + new_shape + return new_shape, transpose_perm, free_dims + + +@constexpr +def _check_batch_size(x1_batch_size, x2_batch_size, prim_name=None): + """ + Check whether batch size of two inputs are the same + """ + msg_prefix = f"For '{prim_name}', the" if prim_name else "The" + if x1_batch_size != x2_batch_size: + raise ValueError(f"{msg_prefix} inputs 'x1', 'x2' should have the same batch sizes, but got " + f"'x1_batch_size': {x1_batch_size} and 'x2_batch_size': {x2_batch_size}.") + + +@constexpr +def _check_axes_for_batch_dot(x1_shape, x2_shape, axes): + """ + Check whether axes are valid and cast axes from tuple to list + """ + if axes is None: + if len(x2_shape) == 2: + axes = [len(x1_shape) - 1, len(x2_shape) - 1] + else: + axes = [len(x1_shape) - 1, len(x2_shape) - 2] + + if isinstance(axes, (list, tuple)): + if isinstance(axes, tuple): + axes = list(axes) + # Reverse if axis < 0 + if axes[0] < 0: + axes[0] += len(x1_shape) + if axes[1] < 0: + axes[1] += len(x2_shape) + elif isinstance(axes, int): + if axes < 0: + axes = [axes + len(x1_shape), axes + len(x2_shape)] + else: + axes = [axes, axes] + return axes + + +@constexpr +def _get_output_shape(batch_size, x1_ret, x2_ret): + """ + Compute output shape for batch dot + """ + output_shape = tuple([batch_size]) + x1_ret + x2_ret + return output_shape + + +def batch_dot(x1, x2, axes=None): + transpose_op = P.Transpose() + batch_matmul_op = P.BatchMatMul() + squeeze_one_op = P.Squeeze(1) + squeeze_minus_one_op = P.Squeeze(-1) + # input validity checks + x1_shape = F.shape(x1) + x2_shape = F.shape(x2) + x1_dim_num = len(x1_shape) + x2_dim_num = len(x2_shape) + + x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape, 'batch_dot') + + _check_batch_size(x1_batch_size, x2_batch_size, 'batch_dot') + axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes, 'batch_dot') + + if x1_dim_num == 2: + x1 = F.expand_dims(x1, 1) + axes[0] += 1 + if x2_dim_num == 2: + x2 = F.expand_dims(x2, 2) + + x1_shape = F.shape(x1) + x2_shape = F.shape(x2) + + x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape_batchdot(x1_shape, axes, 0) + x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape_batchdot(x2_shape, axes, 1) + output_shape = _get_output_shape(x1_batch_size, x1_ret, x2_ret) + + x1_transposed = transpose_op(x1, x1_transpose_fwd) + x2_transposed = transpose_op(x2, x2_transpose_fwd) + x1_reshaped = F.reshape(x1_transposed, x1_reshape_fwd) + x2_reshaped = F.reshape(x2_transposed, x2_reshape_fwd) + + # Batch matmal op part + mul_result = batch_matmul_op(x1_reshaped, x2_reshaped) + + final_result = F.reshape(mul_result, output_shape) + + # if the original dims are expanded, restore them from 3 to 2 + if x1_dim_num == 2: + final_result = squeeze_one_op(final_result) + elif x2_dim_num == 2: + final_result = squeeze_minus_one_op(final_result) + + return final_result diff --git a/mindscience/e3nn/utils/func.py 
b/mindscience/e3nn/utils/func.py new file mode 100644 index 000000000..bdd06099c --- /dev/null +++ b/mindscience/e3nn/utils/func.py @@ -0,0 +1,160 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +from scipy.linalg import expm + +from mindspore import Tensor, ops +from mindspore.ops import operations as P + + +def norm_keep(input_x, axis): + r""" + Compute the matrix norm or vector norm of a given tensor, and the output tensors have dimension retained. + + Args: + input_x (Tensor): Input tensor. The dtype must be float32 or float16. + axis (Union[int, list, tuple]): Specifies which dimension or dimensions of input to calculate the norm across. + + Returns: + Tensor, has the same dtype and shape as `input`. + """ + return ops.expand_dims(input_x.norm(None, axis, False), axis=axis) + + +def _to_tensor(arg): + if isinstance(arg, (int, float)): + return Tensor(arg) + elif isinstance(arg, (np.ndarray, list, tuple)): + return Tensor(arg) + elif isinstance(arg, Tensor): + return arg + else: + raise TypeError + + +def broadcast_shapes(*shapes): + r""" + Return the broadcast shape of the shapes of input tensors. + + Args: + shapes (tuple): Any number of shapes of tensors to be broadcasted. + + Returns: + Tuple, a shape compatible with all input shapes. + """ + max_len = 0 + for shape in shapes: + if isinstance(shape, int): + if max_len < 1: + max_len = 1 + elif isinstance(shape, tuple) or isinstance(shape, list): + s = len(shape) + if max_len < s: + max_len = s + result = [1] * max_len + for shape in shapes: + if isinstance(shape, int): + shape = (shape,) + if isinstance(shape, tuple) or isinstance(shape, list): + for i in range(-1, -1 - len(shape), -1): + if shape[i] < 0: + raise RuntimeError("Trying to create tensor with negative dimension ({}): ({})" + .format(shape[i], shape[i])) + if shape[i] == 1 or shape[i] == result[i]: + continue + if result[i] != 1: + raise RuntimeError( + "Shape mismatch: objects cannot be broadcast to a single shape") + result[i] = shape[i] + else: + raise RuntimeError( + "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape) + return tuple(result) + + +def broadcast_tensors(*tensors): + r""" + Broadcasts the given tensors. + + Args: + tensors (Tensor): Any number of tensors of the same type. + + Returns: + A list of tensors, tensors after broadcast. + """ + shapes = [] + for tensor in tensors: + shapes.append(tensor.shape) + shape = broadcast_shapes(*shapes) + res = [] + for tensor in tensors: + if len(shape): + res.append(ops.broadcast_to(tensor, shape)) + else: + res.append(tensor) + return res + + +def broadcast_args(*args): + r""" + Broadcasts the given data with multiple types. + + Args: + *arg (Union[Tensor[float32], list[float], tuple[float], ndarray[np.float32], float]): Any number of data to be broadcasted. + + Returns: + A list of tensors, tensors after broadcast. 
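+
+    Examples:
+        >>> from mindspore import ops
+        >>> res = broadcast_args(ops.ones((1, 3)), ops.ones((2, 1)))
+        >>> print(res[0].shape, res[1].shape)
+        (2, 3) (2, 3)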
+ """ + tensors = [] + for arg in args: + tensors.append(_to_tensor(arg)) + res = broadcast_tensors(*tensors) + return res + + +def _ndexpm(mat): + """Compute matrix-product exponential of matrices.""" + if isinstance(mat, Tensor): + mat = mat.asnumpy() + mat_shape = mat.shape + if len(mat_shape) < 2: + raise ValueError + elif len(mat_shape) == 2: + return Tensor(expm(mat)) + else: + mat = np.reshape(mat, (-1, mat_shape[-1], mat_shape[-1])) + n = mat.shape[0] + for i in range(n): + mat[i] = expm(mat[i]) + mat = np.reshape(mat, mat_shape) + return Tensor(mat) + + +def _expand_last_dims(x): + if isinstance(x, Tensor): + x = ops.expand_dims(x, -1) + x = ops.expand_dims(x, -1) + else: + x = x[..., None, None] + return x + + +def narrow(inputs, axis, start, length): + """tmp narrow API""" + begins = [0] * inputs.ndim + begins[axis] = start + sizes = [i for i in inputs.shape] + sizes[axis] = length + return P.Slice()(inputs, begins, sizes) diff --git a/mindscience/e3nn/utils/initializer.py b/mindscience/e3nn/utils/initializer.py new file mode 100644 index 000000000..e4f4d46a7 --- /dev/null +++ b/mindscience/e3nn/utils/initializer.py @@ -0,0 +1,63 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from mindspore.common.initializer import Initializer, _register, _init_random_uniform, _assignment, TruncatedNormal, \ + Normal, HeNormal, HeUniform, XavierUniform + + +@_register() +class Uniform(Initializer): + r""" + Generates an array with values sampled from Uniform distribution :math:`{U}(-\text{scale}, \text{scale})` in order + to initialize a tensor. + + Args: + scale (float): The bound of the Uniform distribution. Default: 1.0. + + + Examples: + >>> import mindspore + >>> from mindspore.common.initializer import initializer, Uniform + >>> tensor1 = initializer(Uniform(), [1, 2, 3], mindspore.float32) + >>> tensor2 = initializer('uniform', [1, 2, 3], mindspore.float32) + """ + + def __init__(self, scale=1.): + super(Uniform, self).__init__(scale=scale) + self.scale = scale + + def _initialize(self, arr): + tmp = _init_random_uniform(0., self.scale, arr.shape) + _assignment(arr, tmp) + + +def renormal_initializer(init_method): + name_list = ['zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal', 'xavier_uniform'] + if not init_method in name_list and not isinstance(init_method, Initializer): + raise ValueError( + f'initial method \"{init_method}\" is not supported.') + + if init_method == 'truncatedNormal': + init_method = TruncatedNormal(sigma=1.) + elif init_method == 'normal': + init_method = Normal(sigma=1.) 
+ elif init_method == 'uniform': + init_method = Uniform() + elif init_method == 'he_normal': + init_method = HeNormal() + elif init_method == 'he_uniform': + init_method = HeUniform() + + return init_method diff --git a/mindscience/e3nn/utils/linalg.py b/mindscience/e3nn/utils/linalg.py new file mode 100644 index 000000000..e43c7e9cb --- /dev/null +++ b/mindscience/e3nn/utils/linalg.py @@ -0,0 +1,33 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore import ops + + +def _direct_sum(*matrices): + r"""Direct sum of matrices, put them in the diagonal + """ + front_indices = matrices[0].shape[:-2] + m = sum(x.shape[-2] for x in matrices) + n = sum(x.shape[-1] for x in matrices) + total_shape = list(front_indices) + [m, n] + zeros = ops.Zeros() + out = zeros(tuple(total_shape), matrices[0].dtype) + i, j = 0, 0 + for x in matrices: + m, n = x.shape[-2:] + out[..., i: i + m, j: j + n] = x + i += m + j += n + return out diff --git a/mindscience/e3nn/utils/ncon.py b/mindscience/e3nn/utils/ncon.py new file mode 100644 index 000000000..f47d64bab --- /dev/null +++ b/mindscience/e3nn/utils/ncon.py @@ -0,0 +1,699 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ncon""" +from copy import deepcopy +import numpy as np + +from mindspore import ops, nn, vmap +from mindspore.numpy import tensordot, trace, expand_dims + + +def list_to_tuple(lst): + """list_to_tuple""" + return tuple(list_to_tuple(item) if isinstance(item, list) else item for item in lst) + + +def nest_vmap(fn, in_list, out_list, pt): + """nest vmap function""" + if pt == len(in_list) - 1: + return vmap(fn, in_list[pt], out_list[pt]) + return vmap(nest_vmap(fn, in_list, out_list, pt + 1), in_list[pt], out_list[pt]) + + +def _create_order(con_list): + """ Identify all unique, positive indices and return them sorted. 
""" + flat_con = np.concatenate(con_list) + return np.unique(flat_con[flat_con > 0]).tolist() + + +def _single_trace(con, leg): + """_single_trace""" + leg = np.where(np.array(con) == leg)[0] + con = np.delete(con, leg).tolist() + return con, leg.tolist() + + +def _find_sum(con_list): + """_find_sum + + Args: + con_list: con_list + + Returns: + legs + """ + flat = [] + for item in con_list: + flat += item + legs = [] + for leg in np.unique(flat): + if leg < 0: + continue + if np.sum(np.array(flat) == leg) == 1: + legs.append(leg) + return legs + + +def _find_trace(con_list): + """_find_trace + + Args: + con_list: con_list + + Returns: + legs_list + """ + legs_list = [] + for i in range(len(con_list)): + tr_num = len(con_list[i]) - len(np.unique(con_list[i])) + legs = [] + if tr_num: + for leg in np.unique(con_list[i]): + if sum(con_list[i] == leg) > 1 and leg > 0: + leg = np.where(con_list[i] == leg)[0].tolist() + legs.append(leg) + con_list[i] = np.delete(con_list[i], leg).tolist() + + legs_list.append(legs) + return legs_list + + +def _find_batch(con_list): + """_find_batch + + Args: + con_list: con_list + + Returns: + outer + """ + outer = [] + for i in con_list: + if not isinstance(i, np.ndarray): + i = np.array(i) + outer.extend(i[i < 0].tolist()) + if not outer: + return None + if -len(outer) == min(outer): + return None + + for leg in np.unique(outer): + if sum(outer == leg) == 1: + outer = np.delete(outer, outer.index(leg)).tolist() + + return outer + + +def _process_perm(con, batch_leg): + """_process_perm""" + p = list(range(len(con))) + for i, ind in enumerate(batch_leg): + j = con.index(ind) + if i == j: + continue + con[i], con[j] = con[j], con[i] + p[i], p[j] = p[j], p[i] + + return con, tuple(p) + + +def _make_dict(mode, + inds=None, + legs=None, + batch_leg=None, + p_list=None, + res_legs=None, + permute_index=None, + expand_axis=None): + """_summary_ + + Args: + mode: mode + inds: inds. Defaults to None. + legs: legs. Defaults to None. + batch_leg: batch_leg. Defaults to None. + p_list: p_list. Defaults to None. + res_legs: res_legs. Defaults to None. + permute_index: permute_index. Defaults to None. + expand_axis: expand_axis. Defaults to None. 
+ + Raises: + ValueError: ValueError + + Returns: + d + """ + d = {} + calculate_mode = 'mode' + indices = 'inds' + indices_legs = 'legs' + d[calculate_mode] = mode + + if d[calculate_mode] == 'permute': + d['perms'] = p_list + + elif d[calculate_mode] == 'outer': + d[indices] = inds + + elif d[calculate_mode] in ('diag', 'sum', 'trace'): + d[indices] = inds + d[indices_legs] = legs + + elif d[calculate_mode] == 'ndot': + d[indices] = inds + d[indices_legs] = legs + d['batch_leg'] = batch_leg + + elif d[calculate_mode] == 'hadamard': + d[indices] = inds + d[indices_legs] = legs + d['res_legs'] = res_legs + d['permute_index'] = permute_index + d['expand_axis'] = expand_axis + + else: + raise ValueError + + return d + + +def _process_commands(con_list): + """_process_commands + + Args: + con_list: con_list + + Returns: + conmmands, operators + """ + conmmands = [] + operators = [] + + # find sum index + sum_legs = _find_sum(con_list) + for leg in sum_legs: + for i, con in enumerate(con_list): + if leg in con: + leg_ind = con.index(leg) + con_list[i].remove(leg) + conmmands.append(_make_dict('sum', [i], [leg_ind])) + operators.append(ops.sum) + + # find trace + trace_legs = _find_trace(con_list) + for i, leg_list in enumerate(trace_legs): + if leg_list: + for legs in leg_list: + conmmands.append(_make_dict('trace', [i], legs)) + operators.append(trace) + + order = _create_order(con_list) + batch_legs = _find_batch(con_list) + + if not con_list[0]: + return conmmands, operators + + do_ndot(con_list, conmmands, operators, order, batch_legs) + + # do Hadamard(alike) product + do_hadamard(con_list, conmmands, operators) + + # do outer product + for i, con in enumerate(con_list): + if not i: + continue + if con: + inds = [0, i] + for leg in con: + con_list[0].append(leg) + con_list[i] = [] + conmmands.append(_make_dict('outer', inds)) + operators.append(tensordot) + + # do diagonal + min_leg = min(con_list[0]) + for leg in range(-1, min_leg - 1, -1): + num_leg = con_list[0].count(leg) + while num_leg > 1: + i = con_list[0].index(leg) + j = con_list[0].index(leg, i + 1) + conmmands.append(_make_dict('diag', [0], [i, j])) + operators.append(ops.diagonal) + con_list[0] = con_list[0][:i] + con_list[0][i + 1:j] + con_list[0][j + 1:] + [leg] + num_leg = con_list[0].count(leg) + + # do final permutation + fin_con = list(range(-1, -1 - len(con_list[0]), -1)) + con_list[0], p = _process_perm(con_list[0], fin_con) + conmmands.append(_make_dict('permute', p_list=[p])) + operators.append(ops.permute) + + return conmmands, operators + + +def do_ndot(con_list, conmmands, operators, order, batch_legs): + """do_ndot + + Args: + con_list: con_list + conmmands: conmmands + operators: operators + order: order + batch_legs: batch_legs + """ + while order: + leg_now = order[-1] + inds = [] + legs = [] + batch_legs_now = [] + + # find the two tensors' indices + for i, item in enumerate(con_list): + if leg_now in item: + inds.append(i) + + # check trace + if len(inds) == 1: + con_list[inds[0]], legs = _single_trace(con_list[inds[0]], leg_now) + conmmands.append(_make_dict('trace', inds, legs)) + operators.append(trace) + + else: + # find batch legs + batch_leg_inds = [] + if batch_legs is not None: + tmp = np.intersect1d(con_list[inds[0]], con_list[inds[1]]) + batch_legs_now = np.intersect1d(tmp, batch_legs, False).tolist() + + # find indices of batch legs + for batch_leg in batch_legs_now: + i_leg_0 = con_list[inds[0]].index(batch_leg) + i_leg_1 = con_list[inds[1]].index(batch_leg) + con_list[inds[0]].remove(batch_leg) 
+                con_list[inds[1]].remove(batch_leg)
+                batch_leg_inds.append((i_leg_0, i_leg_1, None))
+
+            ndot_legs = []
+            ndot_leg_inds = []
+            # find all ndot legs and their indices
+            for leg in con_list[inds[0]]:
+                if leg in con_list[inds[1]]:
+                    i_leg_0 = con_list[inds[0]].index(leg)
+                    i_leg_1 = con_list[inds[1]].index(leg)
+                    ndot_legs.append(leg)
+                    ndot_leg_inds.append([i_leg_0, i_leg_1])
+
+            # do ndot contraction and update order
+            for leg in ndot_legs:
+                con_list[inds[0]].remove(leg)
+                con_list[inds[1]].remove(leg)
+            for leg in ndot_legs:
+                if leg != leg_now:
+                    order.remove(leg)
+
+            ndot_leg_inds = ndot_leg_inds[0] if len(ndot_leg_inds) == 1 else np.array(
+                ndot_leg_inds).transpose().tolist()
+            conmmands.append(_make_dict('ndot', inds, list_to_tuple(ndot_leg_inds), batch_leg_inds))
+            operators.append(
+                nest_vmap(tensordot, batch_leg_inds, [0] * len(batch_leg_inds), 0) if batch_leg_inds else tensordot)
+
+            # merge two con_list
+            for leg in con_list[inds[1]]:
+                if leg not in batch_legs_now:
+                    con_list[inds[0]].append(leg)
+            con_list[inds[1]] = []
+            con_list[inds[0]] = batch_legs_now + con_list[inds[0]]
+
+        order = order[:-1]
+
+
+def do_hadamard(con_list, conmmands, operators):
+    """do_hadamard
+
+    Args:
+        con_list: per-tensor index lists, modified in place
+        conmmands: command list to append to
+        operators: operator list to append to
+    """
+    is_con_list_not_none = len(con_list) == 2 and con_list[1]
+    if is_con_list_not_none and not [i for i in con_list[0] if i > 0] and not [i for i in con_list[1] if i > 0]:
+        con_list_all = []
+        for con in con_list:
+            con_list_all.extend(con)
+        con_min_leg = min(con_list_all)
+        out_list = [i for i in range(-1, con_min_leg - 1, -1)]
+
+        res_legs = []
+        for ind in out_list:
+            for i, con in enumerate(con_list):
+                if ind in con:
+                    res_legs.append((i, con.index(ind)))
+                    break
+
+        hadamard_legs = [[], []]
+        con_raw = deepcopy(con_list)
+        handle_inds(con_list, out_list, hadamard_legs)
+
+        expand_axis = deepcopy(hadamard_legs)
+        for i, axis in enumerate(expand_axis):
+            if axis and len(axis) <= 1:
+                expand_axis[i] = axis[0]
+
+        # input permute
+        permute_index = [[], []]
+        con_sort = deepcopy(con_raw)
+        for i, con in enumerate(con_raw):
+            con_sort[i].sort(reverse=True)
+            _, permute_index[i] = _process_perm(con, con_sort[i])
+
+        conmmands.append(
+            _make_dict('hadamard',
+                       inds=[0, 1],
+                       legs=hadamard_legs,
+                       res_legs=res_legs,
+                       permute_index=permute_index,
+                       expand_axis=expand_axis))
+        operators.append([ops.permute, ops.tile, ops.mul, expand_dims])
+
+
+def handle_inds(con_list, out_list, hadamard_legs):
+    """handle_inds"""
+    for i, con in enumerate(con_list):
+        if con:
+            for ind in out_list:
+                if ind not in con:
+                    hadamard_legs[i].append((out_list.index(ind)))
+            if i:
+                con_list[i] = []
+            else:
+                con_list[i] = out_list
+
+
+class Ncon(nn.Cell):
+    r"""
+    Multiple-tensor contraction operator, similar in function to einsum.
+
+    Args:
+        con_list (List[List[int]]): lists of indices for each tensor.
+            The length of each list in `con_list` should coincide with the number of dimensions
+            of the corresponding tensor.
+            The positive indices indicate the dimensions to be contracted or summed.
+            The negative indices indicate the dimensions to be kept (as batch dimensions).
+
+    Inputs:
+        - **input** (List[Tensor]) - Tensor List.
+
+    Outputs:
+        - **output** (Tensor) - The shape of tensor depends on the input and the computation process.
+
+    Raises:
+        ValueError: If the number of commands does not match the number of operations.
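+
+    Note:
+        A positive index appearing in only one list is summed over; an index
+        repeated inside a single list is traced (positive) or taken as a
+        diagonal (negative) before contraction.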
+ + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindspore import ops + >>> from mindchemistry.e3.utils import Ncon + Trace of a matrix: + >>> a = ops.ones((3, 3)) + >>> Ncon([[1, 1]])([a]) + 3.0 + Diagonal of a matrix: + >>> Ncon([[-1, -1]])([a]) + [1. 1. 1.] + Outer product: + >>> b = ops.ones((2)) + >>> c = ops.ones((3)) + >>> Ncon([[-1], [-2]])([b, c]).shape + (2, 3) + Batch matrix multiplication + >>> d = ops.ones((2, 3, 4)) + >>> e = ops.ones((2, 4, 1)) + >>> Ncon([[-1, -2, 1], [-1, 1, -3]])([d, e]).shape + (2, 3, 1) + """ + + def __init__(self, con_list): + super().__init__() + self.con_list = tuple(con_list) + con_list_copy = deepcopy(con_list) + self.commands, self.ops = _process_commands(con_list_copy) + if len(self.commands) != len(self.ops): + raise ValueError(f'{self.commands} is not match {len(self.ops)}') + + def construct(self, ten_list): + """ + The list of tensors to be conctracted. + """ + i = 0 + for d in self.commands: + if d['mode'] == 'diag': + ten_list[0] = self.ops[i](ten_list[0], 0, *d['legs']) + elif d['mode'] == 'permute': + ten_list[0] = self.ops[i](ten_list[0], d['perms'][0]) + elif d['mode'] == 'sum': + i1 = d['inds'][0] + ten_list[i1] = self.ops[i](ten_list[i1], d['legs'][0]) + elif d['mode'] == 'trace': + i1 = d['inds'][0] + ten_list[i1] = self.ops[i](ten_list[i1], 0, d['legs'][0], d['legs'][1]) + elif d['mode'] == 'outer': + i1, i2 = d['inds'] + ten_list[i1] = self.ops[i](ten_list[i1], ten_list[i2], 0) + elif d['mode'] == 'ndot': + i1, i2 = d['inds'] + ten_list[i1] = self.ops[i](ten_list[i1], ten_list[i2], d['legs']) + elif d['mode'] == 'hadamard': + i1, i2 = d['inds'] + a = ten_list[i1] + b = ten_list[i2] + res_legs = d['res_legs'] + + a = ops.permute(a, d['permute_index'][i1]) + b = ops.permute(b, d['permute_index'][i2]) + + if d['expand_axis'][i1]: + a = expand_dims(a, d['expand_axis'][i1]) + if d['expand_axis'][i2]: + b = expand_dims(b, d['expand_axis'][i2]) + + tile_index = [[1 for _ in res_legs], [1 for _ in res_legs]] + for j in range(len(d['legs'][i1])): + tile_index[0][d['legs'][i1][j]] = ten_list[res_legs[d['legs'][i1][j]][0]].shape[res_legs[ + d['legs'][i1][j]][1]] + for j in range(len(d['legs'][i2])): + tile_index[1][d['legs'][i2][j]] = ten_list[res_legs[d['legs'][i2][j]][0]].shape[res_legs[ + d['legs'][i2][j]][1]] + a = ops.tile(a, tuple(tile_index[0])) + b = ops.tile(b, tuple(tile_index[1])) + + ten_list[i1] = ops.mul(a, b) + else: + i += 1 + continue + i += 1 + return ten_list[0] + + def __repr__(self): + s = f'Ncon: {self.con_list}\n' + for d in self.commands: + s += str(d) + '\n' + return s + + +def test_other(): + """test_other""" + ncon = Ncon([[5, -1, 1, 4, 3, -2], [3, -2, -1, 4, 2], [2, -3], [-3, -4]]) + v1 = ops.ones((3, 1, 3, 4, 5, 2)) + v2 = ops.ones((5, 2, 1, 4, 6)) + v3 = ops.ones((6, 3)) + v4 = ops.ones((3, 4)) + print(ncon) + out = ncon([v1, v2, v3, v4]) + print(out.shape) + + ncon = Ncon([[-1, 2], [-1, 1], [2, 1, -2]]) + v1 = ops.ones((20, 50)) + v2 = ops.ones((20, 2)) + v3 = ops.ones((50, 2, 7)) + print(ncon) + out = ncon([v1, v2, v3]) + print(out.shape) + + ncon = Ncon([[-1, -2, 1], [-1, 1]]) + v1 = ops.ones((3, 4, 5)) + v2 = ops.ones((3, 5)) + print(ncon) + out = ncon([v1, v2]) + print(out.shape) + + +def test_diagonal(): + """test_diagonal""" + ncon = Ncon([[-1, -1]]) + v1 = ops.ones((3, 3)) + print(ncon) + out = ncon([v1]) + print(out.shape) + print(out) + + +def test_outer(): + """test_other""" + ncon = Ncon([[-1], [-2]]) + v1 = ops.ones((2)) + v2 = ops.ones((3)) + print(ncon) + out = ncon([v1, v2]) + 
print(out.shape) + print(out) + + +def test_outer_multi_input(): + """test_other""" + ncon = Ncon([[-1], [-2], [-3]]) + v1 = ops.ones((2)) + v2 = ops.ones((3)) + v3 = ops.ones((4)) + print(ncon) + out = ncon([v1, v2, v3]) + print(out.shape) + print(out) + + +def test_ndot(): + """test_other""" + ncon = Ncon([[-1, -2, 1], [-1, 1]]) + v1 = ops.ones((3, 4, 5)) + v2 = ops.ones((3, 5)) + print(ncon) + out = ncon([v1, v2]) + print(out.shape) + print(out) + + +def test_ndot_2(): + """test_other""" + ncon = Ncon([[-1, -2, 1, 2], [-1, 1, 2]]) + v1 = ops.ones((3, 4, 5, 6)) + v2 = ops.ones((3, 5, 6)) + print(ncon) + out = ncon([v1, v2]) + print(out.shape) + print(out) + + +def test_hadamard(): + """test_hadamard""" + a = np.arange(6).reshape((2, 3)) + b = np.arange(6).reshape((2, 3)) + print(a) + print(b) + einstr = f"zu,zu->zu" + d = np.einsum(einstr, a, b) + print(d) + print(d.shape) + + ma = ms.Tensor(a, dtype=ms.float32) + mb = ms.Tensor(b, dtype=ms.float32) + ncon = Ncon([[-1, -2], [-1, -2]]) + print(ncon) + md = ncon([ma, mb]) + print(md.shape) + print(np.allclose(md.asnumpy(), d)) + + +def test_hadamard_alike(): + """test_hadamard_alike""" + a = np.arange(8).reshape((2, 4)) + b = np.arange(24).reshape((2, 3, 4)) + print(a) + print(b) + einstr = f"zi,zui->zui" + d = np.einsum(einstr, a, b) + print(d) + print(d.shape) + + ma = ms.Tensor(a, dtype=ms.float32) + mb = ms.Tensor(b, dtype=ms.float32) + ncon = Ncon([[-1, -3], [-1, -2, -3]]) + print(ncon) + md = ncon([ma, mb]) + print(md.shape) + print(np.allclose(md.asnumpy(), d)) + + +def test_hadamard_with_outer(): + """test_hadamard_with_outer""" + a = np.arange(24).reshape((2, 3, 4)) + b = np.arange(30).reshape((2, 3, 5)) + print(f"a:\n {a}") + print(f"b:\n {b}") + + einstr = f"zui,zuj->zuij" + + d = np.einsum(einstr, a, b) + print(f"d:\n {d}") + print(f"d.shape:\n {d.shape}") + + ma = ms.Tensor(a, dtype=ms.float32) + mb = ms.Tensor(b, dtype=ms.float32) + + ncon = Ncon([[-1, -2, -3], [-1, -2, -4]]) + print(ncon) + md = ncon([ma, mb]) + print(md.shape) + print(np.allclose(md.asnumpy(), d)) + + +def test_hadamard_outer_nosequential(): + """test_hadamard_outer_nosequential""" + a = np.arange(8).reshape((2, 4)) + b = np.arange(30).reshape((2, 5, 3)) + print(f"a:\n {a}") + print(f"b:\n {b}") + + einstr = f"ac,adb->abcd" + + d = np.einsum(einstr, a, b) + print(f"d:\n {d}") + print(f"d.shape:\n {d.shape}") + ma = ms.Tensor(a, dtype=ms.float32) + mb = ms.Tensor(b, dtype=ms.float32) + + ncon = Ncon([[-1, -3], [-1, -4, -2]]) + print(ncon) + md = ncon([ma, mb]) + print(md.shape) + print(np.allclose(md.asnumpy(), d)) + + +def test_sum(): + """test_other""" + ncon = Ncon([[1, 2]]) + v1 = ops.ones((2, 3)) + print(ncon) + out = ncon([v1]) + print(out.shape) + print(out) + + +if __name__ == '__main__': + import mindspore as ms + + ms.set_context(device_target="GPU", device_id=4, mode=ms.GRAPH_MODE, save_graphs=False) + np.random.seed(123) + + test_hadamard_outer_nosequential() diff --git a/mindscience/e3nn/utils/perm.py b/mindscience/e3nn/utils/perm.py new file mode 100644 index 000000000..5b8acd85c --- /dev/null +++ b/mindscience/e3nn/utils/perm.py @@ -0,0 +1,140 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""permutation operators""" +import random +import math + + +def _is_perm(p): + return sorted(set(p)) == list(range(len(p))) + + +def _identity(n): + return tuple(i for i in range(n)) + + +def _compose(p1, p2): + r""" + compute p1 . p2 + p: i |-> p[i] + [p1.p2](i) = p1(p2(i)) = p1[p2[i]] + """ + assert _is_perm(p1) and _is_perm(p2) + assert len(p1) == len(p2) + + return tuple(p1[p2[i]] for i in range(len(p1))) + + +def _inverse(p): + r""" + compute the inverse permutation + """ + return tuple(p.index(i) for i in range(len(p))) + + +def _rand(n): + i = random.randint(0, math.factorial(n) - 1) + return _from_int(i, n) + + +def _from_int(i, n): + pool = list(range(n)) + p = [] + for _ in range(n): + j = i % n + i = i // n + p.append(pool.pop(j)) + n -= 1 + return tuple(p) + + +def _to_int(p): + n = len(p) + pool = list(range(n)) + i = 0 + m = 1 + for j in p: + k = pool.index(j) + i += k * m + m *= len(pool) + pool.pop(k) + return i + + +def _group(n): + return {_from_int(i, n) for i in range(math.factorial(n))} + + +def _germinate(subset): + while True: + n = len(subset) + subset = subset.union([_inverse(p) for p in subset]) + subset = subset.union([ + _compose(p1, p2) + for p1 in subset + for p2 in subset + ]) + if len(subset) == n: + return subset + + +def _is__(g): + if len(g) == 0: + return False + + n = len(next(iter(g))) + + for p in g: + assert len(p) == n, p + + if not _identity(n) in g: + return False + + for p in g: + if not _inverse(p) in g: + return False + + for p1 in g: + for p2 in g: + if not _compose(p1, p2) in g: + return False + + return True + + +def _to_cycles(p): + n = len(p) + + cycles = set() + + for i in range(n): + c = [i] + while p[i] != c[0]: + i = p[i] + c.append(i) + if len(c) >= 2: + i = c.index(min(c)) + c = c[i:] + c[:i] + cycles.add(tuple(c)) + + return cycles + + +def _sign(p): + s = 1 + for c in _to_cycles(p): + if len(c) % 2 == 0: + s = -s + return s diff --git a/mindscience/e3nn/utils/radius.py b/mindscience/e3nn/utils/radius.py new file mode 100644 index 000000000..b6cf2cd5f --- /dev/null +++ b/mindscience/e3nn/utils/radius.py @@ -0,0 +1,248 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""radius"""
+from scipy.spatial import cKDTree
+import numpy as np
+
+
+def _reshape_and_batch(x, batch_x):
+    """_reshape_and_batch"""
+    if x.ndim > 2:
+        if batch_x is None:
+            batch_x = np.broadcast_to(np.arange(0, x.shape[0]).reshape(-1, 1), (x.shape[0], x.shape[1])).flatten()
+        x = x.reshape(-1, x.shape[-1])
+    else:
+        if batch_x is None:
+            batch_x = np.zeros(x.shape[0], dtype=x.dtype)
+        x = x.reshape((-1, 1)) if x.ndim == 1 else x
+
+    return x, batch_x.astype(np.int64)
+
+
+def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
+    r"""
+    Find all points in `x` for each element in `y` within distance `r`.
+
+    Args:
+        x (ndarray): node feature matrix of x.
+        y (ndarray): node feature matrix of y.
+        r (ndarray, float): the radius.
+        batch_x (ndarray): batch vector of x. If None, it is computed from `x` and returned. Default: ``None``.
+        batch_y (ndarray): batch vector of y. If None, it is computed from `y` and returned. Default: ``None``.
+        max_num_neighbors (int): The maximum number of neighbors to return for each element in `y`. Default: ``32``.
+
+    Returns:
+        edge_index (numpy.ndarray) - source and destination indices of the edges.
+
+        batch_x (numpy.ndarray) - batch vector of x.
+
+        batch_y (numpy.ndarray) - batch vector of y.
+
+    Raises:
+        ValueError: If the last dimensions of `x` and `y` do not match.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.utils import radius
+        >>> import numpy as np
+        >>> np.random.seed(1)
+        >>> x = np.random.random((5, 12, 3))
+        >>> r = 0.5
+        >>> edge_index, batch_x, batch_y = radius(x, x, r)
+        >>> print(edge_index.shape)
+        (2, 222)
+        >>> print(batch_x.shape)
+        (60,)
+        >>> print(batch_y.shape)
+        (60,)
+
+    """
+    if x.shape[-1] != y.shape[-1]:
+        raise ValueError("Feature sizes do not match.")
+    if max_num_neighbors < 1:
+        raise ValueError(f"'max_num_neighbors' must be positive, but got {max_num_neighbors}.")
+
+    x, batch_x = _reshape_and_batch(x, batch_x)
+    y, batch_y = _reshape_and_batch(y, batch_y)
+
+    x = np.concatenate((x, 2 * r * batch_x.reshape(-1, 1).astype(x.dtype)), axis=-1)
+    y = np.concatenate((y, 2 * r * batch_y.reshape(-1, 1).astype(y.dtype)), axis=-1)
+
+    tree = cKDTree(x)
+    _, col = tree.query(y, k=max_num_neighbors, distance_upper_bound=r + 1e-8)
+    row = [np.full_like(c, i) for i, c in enumerate(col)]
+    col = col.flatten()
+    row = np.concatenate(row, axis=0)
+    mask = col < int(tree.n)
+
+    return np.stack([row[mask], col[mask]], axis=0), batch_x, batch_y
+
+
+# pylint: disable=C0103
+# pylint: disable=W0612
+def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32, flow='source_to_target'):
+    r"""
+    Computes graph edges to all points within a given distance.
+
+    Args:
+        x (ndarray): node feature matrix.
+        r (ndarray, float): the radius.
+        batch (Tensor): batch vector. If None, it is computed and returned. Default: ``None``.
+        loop (bool): whether to include self-loops in the graph. Default: ``False``.
+        max_num_neighbors (int): The maximum number of neighbors to return for each element in `x`. Default: ``32``.
+        flow (str): {'source_to_target', 'target_to_source'}, the flow direction when using in combination with
+            message passing. Default: ``'source_to_target'``.
+
+    Returns:
+        edge_index (ndarray) - source and destination indices of the edges.
+
+        batch (ndarray) - batch vector.
+
+    Raises:
+        ValueError: If `flow` is not in {'source_to_target', 'target_to_source'}.
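+
+    Note:
+        Internally, ``max_num_neighbors + 1`` candidates are queried per node
+        so that self-matches can be removed when `loop` is ``False``.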
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.utils import radius_graph
+        >>> import numpy as np
+        >>> np.random.seed(1)
+        >>> x = np.random.random((5, 12, 3))
+        >>> r = 0.5
+        >>> edge_index, batch = radius_graph(x, r)
+        >>> print(edge_index.shape)
+        (2, 162)
+        >>> print(batch.shape)
+        (60,)
+    """
+
+    if flow not in ['source_to_target', 'target_to_source']:
+        raise ValueError('`flow` should be in ["source_to_target", "target_to_source"].')
+    (row, col), batch, _ = radius(x, x, r, batch, batch, max_num_neighbors + 1)
+    row, col = (col, row) if flow == 'source_to_target' else (row, col)
+    if not loop:
+        mask = row != col
+        row, col = row[mask], col[mask]
+    return np.stack([row, col], axis=0), batch
+
+
+def radius_full(x, y, batch_x=None, batch_y=None):
+    r"""
+    Find all points in `x` for each element in `y`.
+
+    Args:
+        x (Tensor): node feature matrix.
+        y (Tensor): node feature matrix.
+        batch_x (ndarray): batch vector of x. If None, it is computed from `x` and returned. Default: ``None``.
+        batch_y (ndarray): batch vector of y. If None, it is computed from `y` and returned. Default: ``None``.
+
+    Returns:
+        edge_index (numpy.ndarray) - source and destination indices of the edges.
+
+        batch_x (numpy.ndarray) - batch vector of x.
+
+        batch_y (numpy.ndarray) - batch vector of y.
+
+    Raises:
+        ValueError: If the last dimensions of `x` and `y` do not match.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindchemistry.e3.utils import radius_full
+        >>> from mindspore import ops, Tensor
+        >>> x = Tensor(ops.ones((5, 12, 3)))
+        >>> edge_index, batch_x, batch_y = radius_full(x, x)
+        >>> print(edge_index.shape)
+        (2, 720)
+        >>> print(batch_x.shape)
+        (60,)
+        >>> print(batch_y.shape)
+        (60,)
+
+    """
+    if x.shape[-1] != y.shape[-1]:
+        raise ValueError("Feature sizes do not match.")
+
+    if x.ndim > 2 and y.ndim > 2:
+        b_x, b_y = x.shape[0], y.shape[0]
+        len_x, len_y = x.shape[1], y.shape[1]
+    else:
+        b_x, b_y = 1, 1
+        len_x, len_y = x.shape[0], y.shape[0]
+
+    x, batch_x = _reshape_and_batch(x, batch_x)
+    y, batch_y = _reshape_and_batch(y, batch_y)
+
+    batch_unique = np.unique(batch_x)
+    _row = []
+    edge_dst = []
+    for i in batch_unique:
+        _row.extend(np.arange(len_y) + i * len_y)
+        _col = np.arange(len_x) + i * len_x
+        edge_dst.extend(np.broadcast_to(_col, (len_y, len_x)).flatten())
+    edge_src = np.broadcast_to(np.array(_row).reshape(-1, 1), (len(_row), len_x)).flatten()
+    edge_dst = np.array(edge_dst)
+
+    return np.stack([edge_src, edge_dst]), batch_x, batch_y
+
+
+def radius_graph_full(x, batch=None, loop=False, flow='source_to_target'):
+    r"""
+    Computes graph edges between all pairs of points within each batch element.
+
+    Args:
+        x (Tensor): node feature matrix.
+        batch (Tensor): batch vector. If None, it is computed and returned. Default: ``None``.
+        loop (bool): whether to include self-loops in the graph. Default: ``False``.
+        flow (str): {'source_to_target', 'target_to_source'}, the flow direction when using in combination with
+            message passing. Default: ``'source_to_target'``.
+
+    Returns:
+        edge_index (ndarray) - source and destination indices of the edges.
+
+        batch (ndarray) - batch vector.
+
+    Raises:
+        ValueError: If `flow` is not in {'source_to_target', 'target_to_source'}.
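+
+    Note:
+        Unlike `radius_graph`, the graph is fully connected within each batch
+        element, so no radius argument is required.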
+ + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindchemistry.e3.utils import radius_graph_full + >>> from mindspore import ops, Tensor + >>> x = Tensor(ops.ones((5, 12, 3))) + >>> edge_index, batch = radius_graph_full(x) + >>> print(edge_index.shape) + (2, 660) + >>> print(batch.shape) + (60,) + + """ + if flow not in ['source_to_target', 'target_to_source']: + raise ValueError(f'`flow` should be in ["source_to_target", "target_to_source"].') + + (row, col), batch, _ = radius_full(x, x, batch, batch) + row, col = (col, row) if flow == 'source_to_target' else (row, col) + if not loop: + mask = row != col + row, col = row[mask], col[mask] + + return np.stack([row, col], axis=0), batch -- Gitee From 83331c6def282b6aa27e5362775678cba7ab3c9e Mon Sep 17 00:00:00 2001 From: birfied Date: Thu, 18 Sep 2025 15:56:48 +0800 Subject: [PATCH 04/21] refactor: split spherical_harmonics into per-order helper functions --- mindscience/e3nn/o3/spherical_harmonics.py | 259 ++++++++++----------- 1 file changed, 125 insertions(+), 134 deletions(-) diff --git a/mindscience/e3nn/o3/spherical_harmonics.py b/mindscience/e3nn/o3/spherical_harmonics.py index 392736735..c8986b45e 100644 --- a/mindscience/e3nn/o3/spherical_harmonics.py +++ b/mindscience/e3nn/o3/spherical_harmonics.py @@ -206,25 +206,17 @@ def spherical_harmonics(l, x, normalize=True, normalization='integral'): sh = SphericalHarmonics(l, normalize, normalization, dtype=x.dtype) return sh(x) - -def _spherical_harmonics(lmax: int, x, y, z): - """core functions of spherical harmonics""" - +def _sh0(x, y, z): sh_0_0 = ops.ones_like(x) - if lmax == 0: - return ops.stack([ - sh_0_0, - ], axis=-1) + return [sh_0_0] +def _sh1(x, y, z): sh_1_0 = x sh_1_1 = y sh_1_2 = z - if lmax == 1: - return ops.stack([ - sh_0_0, - sh_1_0, sh_1_1, sh_1_2 - ], axis=-1) + return [sh_1_0, sh_1_1, sh_1_2] +def _sh2(x, y, z): sh_2_0 = 1.7320508075688772 * x * z sh_2_1 = 1.7320508075688772 * x * y y2 = y.pow(2) @@ -232,14 +224,12 @@ def _spherical_harmonics(lmax: int, x, y, z): sh_2_2 = y2 - 0.5 * x2z2 sh_2_3 = 1.7320508075688772 * y * z sh_2_4 = 1.7320508075688772 / 2.0 * (z.pow(2) - x.pow(2)) + return [sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4] - if lmax == 2: - return ops.stack([ - sh_0_0, - sh_1_0, sh_1_1, sh_1_2, - sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4 - ], axis=-1) - +def _sh3(x, y, z, prev): + sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4 = prev + y2 = y.pow(2) + x2z2 = x.pow(2) + z.pow(2) sh_3_0 = 0.9128709291752769 * (sh_2_0 * z + sh_2_4 * x) sh_3_1 = 2.23606797749979 * sh_2_0 * y sh_3_2 = 0.6123724356957945 * (4.0 * y2 - x2z2) * x @@ -247,15 +237,10 @@ def _spherical_harmonics(lmax: int, x, y, z): sh_3_4 = 0.6123724356957945 * z * (4.0 * y2 - x2z2) sh_3_5 = 2.23606797749979 * sh_2_4 * y sh_3_6 = 0.9128709291752769 * (sh_2_4 * z - sh_2_0 * x) + return [sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6] - if lmax == 3: - return ops.stack([ - sh_0_0, - sh_1_0, sh_1_1, sh_1_2, - sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, - sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6 - ], axis=-1) - +def _sh4(x, y, z, prev): + sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6 = prev sh_4_0 = 0.935414346693485 * sh_3_0 * z + 0.935414346693485 * sh_3_6 * x sh_4_1 = 0.661437827766148 * sh_3_0 * y + 0.810092587300982 * \ sh_3_1 * z + 0.810092587300983 * sh_3_5 * x @@ -274,15 +259,11 @@ def _spherical_harmonics(lmax: int, x, y, z): sh_4_7 = -0.810092587300982 * sh_3_1 * x + 0.810092587300982 * \ sh_3_5 * z + 0.661437827766148 * sh_3_6 * y sh_4_8 = -0.935414346693485 * sh_3_0 
* x + 0.935414346693486 * sh_3_6 * z - if lmax == 4: - return ops.stack([ - sh_0_0, - sh_1_0, sh_1_1, sh_1_2, - sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, - sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, - sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8 - ], axis=-1) + return [sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8] +def _sh5(x, y, z, prev): + sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8 = prev + sh_5_0 = 0.948683298050513 * sh_4_0 * z + 0.948683298050513 * sh_4_8 * x sh_5_1 = 0.6 * sh_4_0 * y + 0.848528137423857 * \ sh_4_1 * z + 0.848528137423858 * sh_4_7 * x @@ -306,15 +287,10 @@ def _spherical_harmonics(lmax: int, x, y, z): sh_5_9 = -0.848528137423857 * sh_4_1 * x + \ 0.848528137423857 * sh_4_7 * z + 0.6 * sh_4_8 * y sh_5_10 = -0.948683298050513 * sh_4_0 * x + 0.948683298050513 * sh_4_8 * z - if lmax == 5: - return ops.stack([ - sh_0_0, - sh_1_0, sh_1_1, sh_1_2, - sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, - sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, - sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, - sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10 - ], axis=-1) + return [sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10] + +def _sh6(x, y, z, prev): + sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10 = prev sh_6_0 = 0.957427107756337 * sh_5_0 * z + 0.957427107756338 * sh_5_10 * x sh_6_1 = 0.552770798392565 * sh_5_0 * y + 0.874007373475125 * \ @@ -346,16 +322,10 @@ def _spherical_harmonics(lmax: int, x, y, z): sh_6_11 = -0.874007373475124 * sh_5_1 * x + 0.552770798392566 * \ sh_5_10 * y + 0.874007373475125 * sh_5_9 * z sh_6_12 = -0.957427107756337 * sh_5_0 * x + 0.957427107756336 * sh_5_10 * z - if lmax == 6: - return ops.stack([ - sh_0_0, - sh_1_0, sh_1_1, sh_1_2, - sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, - sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, - sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, - sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, - sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12 - ], axis=-1) + return [sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12] + +def _sh7(x, y, z, prev): + sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12 = prev sh_7_0 = 0.963624111659433 * sh_6_0 * z + 0.963624111659432 * sh_6_12 * x sh_7_1 = 0.515078753637713 * sh_6_0 * y + 0.892142571199771 * \ @@ -393,18 +363,11 @@ def _spherical_harmonics(lmax: int, x, y, z): sh_7_13 = -0.892142571199772 * sh_6_1 * x + 0.892142571199772 * \ sh_6_11 * z + 0.515078753637713 * sh_6_12 * y sh_7_14 = -0.963624111659431 * sh_6_0 * x + 0.963624111659433 * sh_6_12 * z - if lmax == 7: - return ops.stack([ - sh_0_0, - sh_1_0, sh_1_1, sh_1_2, - sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, - sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, - sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, - sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, - sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, - sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, - sh_7_13, sh_7_14 - ], 
axis=-1) + + return [sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14] + +def _sh8(x, y, z, prev): + sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14 = prev sh_8_0 = 0.968245836551854 * sh_7_0 * z + 0.968245836551853 * sh_7_14 * x sh_8_1 = 0.484122918275928 * sh_7_0 * y + 0.90571104663684 * \ @@ -448,20 +411,11 @@ def _spherical_harmonics(lmax: int, x, y, z): sh_8_15 = -0.90571104663684 * sh_7_1 * x + 0.90571104663684 * \ sh_7_13 * z + 0.484122918275927 * sh_7_14 * y sh_8_16 = -0.968245836551853 * sh_7_0 * x + 0.968245836551855 * sh_7_14 * z - if lmax == 8: - return ops.stack([ - sh_0_0, - sh_1_0, sh_1_1, sh_1_2, - sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, - sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, - sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, - sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, - sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, - sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, - sh_7_13, sh_7_14, - sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, - sh_8_13, sh_8_14, sh_8_15, sh_8_16 - ], axis=-1) + + return [sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, sh_8_15, sh_8_16] + +def _sh9(x, y, z, prev): + sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, sh_8_15, sh_8_16 = prev sh_9_0 = 0.97182531580755 * sh_8_0 * z + 0.971825315807551 * sh_8_16 * x sh_9_1 = 0.458122847290851 * sh_8_0 * y + 0.916245694581702 * \ @@ -510,22 +464,11 @@ def _spherical_harmonics(lmax: int, x, y, z): sh_9_17 = -0.9162456945817 * sh_8_1 * x + 0.916245694581702 * \ sh_8_15 * z + 0.458122847290851 * sh_8_16 * y sh_9_18 = -0.97182531580755 * sh_8_0 * x + 0.97182531580755 * sh_8_16 * z - if lmax == 9: - return ops.stack([ - sh_0_0, - sh_1_0, sh_1_1, sh_1_2, - sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, - sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, - sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, - sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, - sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, - sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, - sh_7_13, sh_7_14, - sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, - sh_8_13, sh_8_14, sh_8_15, sh_8_16, - sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, - sh_9_13, sh_9_14, sh_9_15, sh_9_16, sh_9_17, sh_9_18 - ], axis=-1) + + return [sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, sh_9_13, sh_9_14, sh_9_15, sh_9_16, sh_9_17, sh_9_18] + +def _sh10(x, y, z, prev): + sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, sh_9_13, sh_9_14, sh_9_15, sh_9_16, sh_9_17, sh_9_18 = prev sh_10_0 = 0.974679434480897 * sh_9_0 * z + 0.974679434480897 * sh_9_18 * x sh_10_1 = 0.435889894354067 * sh_9_0 * y + 
0.924662100445347 * \ @@ -579,24 +522,11 @@ def _spherical_harmonics(lmax: int, x, y, z): sh_10_19 = -0.924662100445348 * sh_9_1 * x + 0.924662100445347 * \ sh_9_17 * z + 0.435889894354068 * sh_9_18 * y sh_10_20 = -0.974679434480898 * sh_9_0 * x + 0.974679434480896 * sh_9_18 * z - if lmax == 10: - return ops.stack([ - sh_0_0, - sh_1_0, sh_1_1, sh_1_2, - sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, - sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, - sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, - sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, - sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, - sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, - sh_7_13, sh_7_14, - sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, - sh_8_13, sh_8_14, sh_8_15, sh_8_16, - sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, - sh_9_13, sh_9_14, sh_9_15, sh_9_16, sh_9_17, sh_9_18, - sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, sh_10_7, sh_10_8, sh_10_9, sh_10_10, - sh_10_11, sh_10_12, sh_10_13, sh_10_14, sh_10_15, sh_10_16, sh_10_17, sh_10_18, sh_10_19, sh_10_20 - ], axis=-1) + + return [sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, sh_10_7, sh_10_8, sh_10_9, sh_10_10, sh_10_11, sh_10_12, sh_10_13, sh_10_14, sh_10_15, sh_10_16, sh_10_17, sh_10_18, sh_10_19, sh_10_20] + +def _sh11(x, y, z, prev): + sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, sh_10_7, sh_10_8, sh_10_9, sh_10_10, sh_10_11, sh_10_12, sh_10_13, sh_10_14, sh_10_15, sh_10_16, sh_10_17, sh_10_18, sh_10_19, sh_10_20 = prev sh_11_0 = 0.977008420918394 * sh_10_0 * z + 0.977008420918394 * sh_10_20 * x sh_11_1 = 0.416597790450531 * sh_10_0 * y + 0.9315409787236 * \ @@ -658,22 +588,83 @@ def _spherical_harmonics(lmax: int, x, y, z): sh_11_21 = -0.9315409787236 * sh_10_1 * x + 0.931540978723599 * \ sh_10_19 * z + 0.416597790450531 * sh_10_20 * y sh_11_22 = -0.977008420918393 * sh_10_0 * x + 0.977008420918393 * sh_10_20 * z - return ops.stack([ - sh_0_0, - sh_1_0, sh_1_1, sh_1_2, - sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4, - sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6, - sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8, - sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10, - sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12, - sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, - sh_7_13, sh_7_14, - sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, - sh_8_13, sh_8_14, sh_8_15, sh_8_16, - sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, - sh_9_13, sh_9_14, sh_9_15, sh_9_16, sh_9_17, sh_9_18, - sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, sh_10_7, sh_10_8, sh_10_9, sh_10_10, sh_10_11, - sh_10_12, sh_10_13, sh_10_14, sh_10_15, sh_10_16, sh_10_17, sh_10_18, sh_10_19, sh_10_20, - sh_11_0, sh_11_1, sh_11_2, sh_11_3, sh_11_4, sh_11_5, sh_11_6, sh_11_7, sh_11_8, sh_11_9, sh_11_10, sh_11_11, - sh_11_12, sh_11_13, sh_11_14, sh_11_15, sh_11_16, sh_11_17, sh_11_18, sh_11_19, sh_11_20, sh_11_21, sh_11_22 - ], axis=-1) + + 
return [sh_11_0, sh_11_1, sh_11_2, sh_11_3, sh_11_4, sh_11_5, sh_11_6, sh_11_7, sh_11_8, sh_11_9, sh_11_10, sh_11_11, sh_11_12, sh_11_13, sh_11_14, sh_11_15, sh_11_16, sh_11_17, sh_11_18, sh_11_19, sh_11_20, sh_11_21, sh_11_22]
+
+def _spherical_harmonics(lmax: int, x, y, z):
+    results = []
+
+    # l = 0
+    sh0 = _sh0(x, y, z)
+    results.extend(sh0)
+    if lmax == 0:
+        return ops.stack(results, axis=-1)
+
+    # l = 1
+    sh1 = _sh1(x, y, z)
+    results.extend(sh1)
+    if lmax == 1:
+        return ops.stack(results, axis=-1)
+
+    # l = 2
+    sh2 = _sh2(x, y, z)
+    results.extend(sh2)
+    if lmax == 2:
+        return ops.stack(results, axis=-1)
+
+    # l = 3
+    sh3 = _sh3(x, y, z, sh2)
+    results.extend(sh3)
+    if lmax == 3:
+        return ops.stack(results, axis=-1)
+
+    # l = 4
+    sh4 = _sh4(x, y, z, sh3)
+    results.extend(sh4)
+    if lmax == 4:
+        return ops.stack(results, axis=-1)
+
+    # l = 5
+    sh5 = _sh5(x, y, z, sh4)
+    results.extend(sh5)
+    if lmax == 5:
+        return ops.stack(results, axis=-1)
+
+    # l = 6
+    sh6 = _sh6(x, y, z, sh5)
+    results.extend(sh6)
+    if lmax == 6:
+        return ops.stack(results, axis=-1)
+
+    # l = 7
+    sh7 = _sh7(x, y, z, sh6)
+    results.extend(sh7)
+    if lmax == 7:
+        return ops.stack(results, axis=-1)
+
+    # l = 8
+    sh8 = _sh8(x, y, z, sh7)
+    results.extend(sh8)
+    if lmax == 8:
+        return ops.stack(results, axis=-1)
+
+    # l = 9
+    sh9 = _sh9(x, y, z, sh8)
+    results.extend(sh9)
+    if lmax == 9:
+        return ops.stack(results, axis=-1)
+
+    # l = 10
+    sh10 = _sh10(x, y, z, sh9)
+    results.extend(sh10)
+    if lmax == 10:
+        return ops.stack(results, axis=-1)
+
+    # l = 11
+    sh11 = _sh11(x, y, z, sh10)
+    results.extend(sh11)
+    if lmax == 11:
+        return ops.stack(results, axis=-1)
+
+    # default: return the highest supported order (l = 11)
+    return ops.stack(results, axis=-1)
-- 
Gitee

From 9e34d807376360539c4c97aef5742cbb5b6a4814 Mon Sep 17 00:00:00 2001
From: birfied
Date: Thu, 18 Sep 2025 16:12:18 +0800
Subject: [PATCH 05/21] correct pylint errors

---
 mindscience/e3nn/nn/__init__.py       |  2 +-
 mindscience/e3nn/o3/__init__.py       |  2 +-
 mindscience/e3nn/utils/__init__.py    |  2 +-
 mindscience/e3nn/utils/batch_dot.py   | 33 +++++++++++++++-------
 mindscience/e3nn/utils/func.py        | 23 +++++++++------
 mindscience/e3nn/utils/initializer.py | 24 +++++++++++++++-
 mindscience/e3nn/utils/linalg.py      |  7 +++++
 mindscience/e3nn/utils/perm.py        | 40 ++++++++++++++++++++++++++-
 8 files changed, 110 insertions(+), 23 deletions(-)

diff --git a/mindscience/e3nn/nn/__init__.py b/mindscience/e3nn/nn/__init__.py
index 17d4a7118..63d292a9c 100644
--- a/mindscience/e3nn/nn/__init__.py
+++ b/mindscience/e3nn/nn/__init__.py
@@ -32,4 +32,4 @@ __all__ = [
     "soft_unit_step",
     "OneHot",
     "BatchNorm"
-] \ No newline at end of file
+]
diff --git a/mindscience/e3nn/o3/__init__.py b/mindscience/e3nn/o3/__init__.py
index 4f9b6d853..44b27fe6d 100644
--- a/mindscience/e3nn/o3/__init__.py
+++ b/mindscience/e3nn/o3/__init__.py
@@ -48,4 +48,4 @@ __all__ = [
     "Linear",
     "TensorSquare",
     "Norm",
-] \ No newline at end of file
+]
diff --git a/mindscience/e3nn/utils/__init__.py b/mindscience/e3nn/utils/__init__.py
index c67bd0f61..2161cda3e 100644
--- a/mindscience/e3nn/utils/__init__.py
+++ b/mindscience/e3nn/utils/__init__.py
@@ -23,4 +23,4 @@ __all__ = [
     "radius_graph",
     "radius_full",
     "radius_graph_full",
-] \ No newline at end of file
+]
diff --git a/mindscience/e3nn/utils/batch_dot.py b/mindscience/e3nn/utils/batch_dot.py
index cf578e6bd..8a7b781ea 100644
--- a/mindscience/e3nn/utils/batch_dot.py
+++ b/mindscience/e3nn/utils/batch_dot.py
@@ -12,19 +12,17 @@
 # See the License for the specific language governing permissions and
 #
limitations under the License. # ============================================================================ +""" +Batch dot product operations for tensor computations. + +This module provides utilities for performing batch-wise dot products +between tensors with support for various axis configurations. +""" from mindspore.ops.primitive import constexpr from mindspore.ops import functional as F from mindspore.ops import operations as P -@constexpr -def _get_batch_size(x1_shape, x2_shape): - """ - Get batch sizes from two inputs - """ - return x1_shape[0], x2_shape[0] - - @constexpr def _get_batch_size(x1_shape, x2_shape): """ @@ -105,6 +103,21 @@ def _get_output_shape(batch_size, x1_ret, x2_ret): def batch_dot(x1, x2, axes=None): + """ + Compute batch-wise dot product of two tensors. + + Args: + x1 (Tensor): First input tensor with shape (batch_size, ...). + x2 (Tensor): Second input tensor with shape (batch_size, ...). + axes (int, list, tuple, optional): Axes to perform dot product along. + If None, defaults to the last axis of x1 and second-to-last axis of x2. + + Returns: + Tensor: The batch dot product result. + + Raises: + ValueError: If batch sizes of x1 and x2 don't match. + """ transpose_op = P.Transpose() batch_matmul_op = P.BatchMatMul() squeeze_one_op = P.Squeeze(1) @@ -115,10 +128,10 @@ def batch_dot(x1, x2, axes=None): x1_dim_num = len(x1_shape) x2_dim_num = len(x2_shape) - x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape, 'batch_dot') + x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape) _check_batch_size(x1_batch_size, x2_batch_size, 'batch_dot') - axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes, 'batch_dot') + axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes) if x1_dim_num == 2: x1 = F.expand_dims(x1, 1) diff --git a/mindscience/e3nn/utils/func.py b/mindscience/e3nn/utils/func.py index bdd06099c..07bf6a57d 100644 --- a/mindscience/e3nn/utils/func.py +++ b/mindscience/e3nn/utils/func.py @@ -12,6 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +""" +Utility functions for tensor operations and broadcasting. + +This module provides various utility functions for tensor manipulation, +broadcasting operations, and mathematical computations commonly used +in the e3nn library. 
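+Key helpers include broadcast_shapes, broadcast_tensors, broadcast_args,
+norm_keep and narrow.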
+""" import numpy as np from scipy.linalg import expm @@ -36,12 +43,11 @@ def norm_keep(input_x, axis): def _to_tensor(arg): if isinstance(arg, (int, float)): return Tensor(arg) - elif isinstance(arg, (np.ndarray, list, tuple)): + if isinstance(arg, (np.ndarray, list, tuple)): return Tensor(arg) - elif isinstance(arg, Tensor): + if isinstance(arg, Tensor): return arg - else: - raise TypeError + raise TypeError def broadcast_shapes(*shapes): @@ -59,7 +65,7 @@ def broadcast_shapes(*shapes): if isinstance(shape, int): if max_len < 1: max_len = 1 - elif isinstance(shape, tuple) or isinstance(shape, list): + elif isinstance(shape, (list, tuple)): s = len(shape) if max_len < s: max_len = s @@ -67,7 +73,7 @@ def broadcast_shapes(*shapes): for shape in shapes: if isinstance(shape, int): shape = (shape,) - if isinstance(shape, tuple) or isinstance(shape, list): + if isinstance(shape, (list, tuple)): for i in range(-1, -1 - len(shape), -1): if shape[i] < 0: raise RuntimeError("Trying to create tensor with negative dimension ({}): ({})" @@ -100,7 +106,7 @@ def broadcast_tensors(*tensors): shape = broadcast_shapes(*shapes) res = [] for tensor in tensors: - if len(shape): + if shape: res.append(ops.broadcast_to(tensor, shape)) else: res.append(tensor) @@ -112,7 +118,8 @@ def broadcast_args(*args): Broadcasts the given data with multiple types. Args: - *arg (Union[Tensor[float32], list[float], tuple[float], ndarray[np.float32], float]): Any number of data to be broadcasted. + *arg (Union[Tensor[float32], list[float], tuple[float], + ndarray[np.float32], float]): Any number of data to be broadcasted. Returns: A list of tensors, tensors after broadcast. diff --git a/mindscience/e3nn/utils/initializer.py b/mindscience/e3nn/utils/initializer.py index e4f4d46a7..3e9a75f25 100644 --- a/mindscience/e3nn/utils/initializer.py +++ b/mindscience/e3nn/utils/initializer.py @@ -12,9 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +""" +Custom initializers for neural network parameters. + +This module provides custom weight initialization methods including +uniform distribution initializers and renormalization utilities +for various initialization schemes. +""" from mindspore.common.initializer import Initializer, _register, _init_random_uniform, _assignment, TruncatedNormal, \ - Normal, HeNormal, HeUniform, XavierUniform + Normal, HeNormal, HeUniform @_register() @@ -44,6 +51,21 @@ class Uniform(Initializer): def renormal_initializer(init_method): + """ + Normalize and convert initialization method to proper initializer instance. + + Args: + init_method (str or Initializer): The initialization method name or + an Initializer instance. Supported string values are: + 'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', + 'he_uniform', 'he_normal', 'xavier_uniform'. + + Returns: + Initializer: The corresponding initializer instance. + + Raises: + ValueError: If the initialization method is not supported. 
+ """ name_list = ['zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal', 'xavier_uniform'] if not init_method in name_list and not isinstance(init_method, Initializer): raise ValueError( diff --git a/mindscience/e3nn/utils/linalg.py b/mindscience/e3nn/utils/linalg.py index e43c7e9cb..62ee37f2e 100644 --- a/mindscience/e3nn/utils/linalg.py +++ b/mindscience/e3nn/utils/linalg.py @@ -12,6 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +""" +Linear algebra utilities for matrix operations. + +This module provides utility functions for linear algebra operations, +including matrix direct sum and other matrix manipulation functions +commonly used in the e3nn library. +""" from mindspore import ops diff --git a/mindscience/e3nn/utils/perm.py b/mindscience/e3nn/utils/perm.py index 5b8acd85c..6fcee00f9 100644 --- a/mindscience/e3nn/utils/perm.py +++ b/mindscience/e3nn/utils/perm.py @@ -61,6 +61,15 @@ def _from_int(i, n): def _to_int(p): + """ + Convert a permutation to its integer representation. + + Args: + p (tuple): A permutation represented as a tuple. + + Returns: + int: The integer representation of the permutation. + """ n = len(p) pool = list(range(n)) i = 0 @@ -78,6 +87,16 @@ def _group(n): def _germinate(subset): + """ + Generate the group closure of a subset of permutations. + + Args: + subset (set): A set of permutations. + + Returns: + set: The group closure containing all permutations that can be + generated from the input subset through composition and inversion. + """ while True: n = len(subset) subset = subset.union([_inverse(p) for p in subset]) @@ -91,7 +110,16 @@ def _germinate(subset): def _is__(g): - if len(g) == 0: + """ + Check if a set of permutations forms a group. + + Args: + g (set): A set of permutations to check. + + Returns: + bool: True if the set forms a group, False otherwise. + """ + if not g: return False n = len(next(iter(g))) @@ -115,6 +143,16 @@ def _is__(g): def _to_cycles(p): + """ + Convert a permutation to its cycle representation. + + Args: + p (tuple): A permutation represented as a tuple. + + Returns: + set: A set of tuples representing the cycles of the permutation. + Only cycles of length >= 2 are included. + """ n = len(p) cycles = set() -- Gitee From 05915e608dec86826575565ced184550a7bab97d Mon Sep 17 00:00:00 2001 From: birfied Date: Thu, 18 Sep 2025 16:30:56 +0800 Subject: [PATCH 06/21] correct pylint errors --- mindscience/e3nn/o3/spherical_harmonics.py | 78 ++++++++++++++++++---- mindscience/e3nn/utils/batch_dot.py | 6 +- mindscience/e3nn/utils/func.py | 17 +++-- mindscience/e3nn/utils/initializer.py | 10 +-- mindscience/e3nn/utils/perm.py | 18 ++--- 5 files changed, 91 insertions(+), 38 deletions(-) diff --git a/mindscience/e3nn/o3/spherical_harmonics.py b/mindscience/e3nn/o3/spherical_harmonics.py index c8986b45e..5243a35ea 100644 --- a/mindscience/e3nn/o3/spherical_harmonics.py +++ b/mindscience/e3nn/o3/spherical_harmonics.py @@ -206,17 +206,44 @@ def spherical_harmonics(l, x, normalize=True, normalization='integral'): sh = SphericalHarmonics(l, normalize, normalization, dtype=x.dtype) return sh(x) -def _sh0(x, y, z): +def _sh0(x, _y, _z): + """ + Compute spherical harmonics of degree 0. + + Args: + x (Tensor): Tensor for construct spherical harmonics. The shape of Tensor is :math:`x` of shape ``(..., 3)`` + + Returns: + Tensor, the spherical harmonics :math:`Y^0(x)`. 
The shape of Tensor is ``(..., 1)``. + """ sh_0_0 = ops.ones_like(x) return [sh_0_0] def _sh1(x, y, z): + """ + Compute spherical harmonics of degree 1. + + Args: + x (Tensor): Tensor for construct spherical harmonics. The shape of Tensor is :math:`x` of shape ``(..., 3)`` + + Returns: + Tensor, the spherical harmonics :math:`Y^1(x)`. The shape of Tensor is ``(..., 3)``. + """ sh_1_0 = x sh_1_1 = y sh_1_2 = z return [sh_1_0, sh_1_1, sh_1_2] def _sh2(x, y, z): + """ + Compute spherical harmonics of degree 2. + + Args: + x (Tensor): Tensor for construct spherical harmonics. The shape of Tensor is :math:`x` of shape ``(..., 3)`` + + Returns: + Tensor, the spherical harmonics :math:`Y^2(x)`. The shape of Tensor is ``(..., 5)``. + """ sh_2_0 = 1.7320508075688772 * x * z sh_2_1 = 1.7320508075688772 * x * y y2 = y.pow(2) @@ -227,7 +254,8 @@ def _sh2(x, y, z): return [sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4] def _sh3(x, y, z, prev): - sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4 = prev + """Compute spherical harmonics of degree 3.""" + sh_2_0, _sh_2_1, _sh_2_2, _sh_2_3, sh_2_4 = prev y2 = y.pow(2) x2z2 = x.pow(2) + z.pow(2) sh_3_0 = 0.9128709291752769 * (sh_2_0 * z + sh_2_4 * x) @@ -240,6 +268,7 @@ def _sh3(x, y, z, prev): return [sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6] def _sh4(x, y, z, prev): + """Compute spherical harmonics of degree 4.""" sh_3_0, sh_3_1, sh_3_2, sh_3_3, sh_3_4, sh_3_5, sh_3_6 = prev sh_4_0 = 0.935414346693485 * sh_3_0 * z + 0.935414346693485 * sh_3_6 * x sh_4_1 = 0.661437827766148 * sh_3_0 * y + 0.810092587300982 * \ @@ -262,8 +291,8 @@ def _sh4(x, y, z, prev): return [sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8] def _sh5(x, y, z, prev): + """Compute spherical harmonics of degree 5.""" sh_4_0, sh_4_1, sh_4_2, sh_4_3, sh_4_4, sh_4_5, sh_4_6, sh_4_7, sh_4_8 = prev - sh_5_0 = 0.948683298050513 * sh_4_0 * z + 0.948683298050513 * sh_4_8 * x sh_5_1 = 0.6 * sh_4_0 * y + 0.848528137423857 * \ sh_4_1 * z + 0.848528137423858 * sh_4_7 * x @@ -290,6 +319,7 @@ def _sh5(x, y, z, prev): return [sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10] def _sh6(x, y, z, prev): + """Compute spherical harmonics of degree 6.""" sh_5_0, sh_5_1, sh_5_2, sh_5_3, sh_5_4, sh_5_5, sh_5_6, sh_5_7, sh_5_8, sh_5_9, sh_5_10 = prev sh_6_0 = 0.957427107756337 * sh_5_0 * z + 0.957427107756338 * sh_5_10 * x @@ -325,6 +355,7 @@ def _sh6(x, y, z, prev): return [sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12] def _sh7(x, y, z, prev): + """Compute spherical harmonics of degree 7.""" sh_6_0, sh_6_1, sh_6_2, sh_6_3, sh_6_4, sh_6_5, sh_6_6, sh_6_7, sh_6_8, sh_6_9, sh_6_10, sh_6_11, sh_6_12 = prev sh_7_0 = 0.963624111659433 * sh_6_0 * z + 0.963624111659432 * sh_6_12 * x @@ -367,9 +398,12 @@ def _sh7(x, y, z, prev): return [sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14] def _sh8(x, y, z, prev): - sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14 = prev + """Compute spherical harmonics of degree 8.""" + sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, \ + sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14 = prev - sh_8_0 = 0.968245836551854 * sh_7_0 * z + 0.968245836551853 * sh_7_14 * x + sh_8_0 = 0.968245836551854 * sh_7_0 * z + \ + 0.968245836551853 * sh_7_14 * x sh_8_1 = 0.484122918275928 * sh_7_0 * y + 0.90571104663684 
* \ sh_7_1 * z + 0.90571104663684 * sh_7_13 * x sh_8_2 = -0.0883883476483189 * sh_7_0 * z + 0.661437827766148 * sh_7_1 * y + \ @@ -412,10 +446,15 @@ def _sh8(x, y, z, prev): sh_7_13 * z + 0.484122918275927 * sh_7_14 * y sh_8_16 = -0.968245836551853 * sh_7_0 * x + 0.968245836551855 * sh_7_14 * z - return [sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, sh_8_15, sh_8_16] + return [sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, + sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, + sh_8_15, sh_8_16] def _sh9(x, y, z, prev): - sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, sh_8_15, sh_8_16 = prev + """Compute spherical harmonics of degree 9.""" + sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, sh_8_8, \ + sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, sh_8_15, \ + sh_8_16 = prev sh_9_0 = 0.97182531580755 * sh_8_0 * z + 0.971825315807551 * sh_8_16 * x sh_9_1 = 0.458122847290851 * sh_8_0 * y + 0.916245694581702 * \ @@ -465,10 +504,15 @@ def _sh9(x, y, z, prev): sh_8_15 * z + 0.458122847290851 * sh_8_16 * y sh_9_18 = -0.97182531580755 * sh_8_0 * x + 0.97182531580755 * sh_8_16 * z - return [sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, sh_9_13, sh_9_14, sh_9_15, sh_9_16, sh_9_17, sh_9_18] + return [sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, + sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, sh_9_13, sh_9_14, + sh_9_15, sh_9_16, sh_9_17, sh_9_18] def _sh10(x, y, z, prev): - sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, sh_9_13, sh_9_14, sh_9_15, sh_9_16, sh_9_17, sh_9_18 = prev + """Compute spherical harmonics of degree 10.""" + sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, sh_9_8, \ + sh_9_9, sh_9_10, sh_9_11, sh_9_12, sh_9_13, sh_9_14, sh_9_15, \ + sh_9_16, sh_9_17, sh_9_18 = prev sh_10_0 = 0.974679434480897 * sh_9_0 * z + 0.974679434480897 * sh_9_18 * x sh_10_1 = 0.435889894354067 * sh_9_0 * y + 0.924662100445347 * \ @@ -523,10 +567,16 @@ def _sh10(x, y, z, prev): sh_9_17 * z + 0.435889894354068 * sh_9_18 * y sh_10_20 = -0.974679434480898 * sh_9_0 * x + 0.974679434480896 * sh_9_18 * z - return [sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, sh_10_7, sh_10_8, sh_10_9, sh_10_10, sh_10_11, sh_10_12, sh_10_13, sh_10_14, sh_10_15, sh_10_16, sh_10_17, sh_10_18, sh_10_19, sh_10_20] + return [sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, + sh_10_7, sh_10_8, sh_10_9, sh_10_10, sh_10_11, sh_10_12, + sh_10_13, sh_10_14, sh_10_15, sh_10_16, sh_10_17, sh_10_18, + sh_10_19, sh_10_20] def _sh11(x, y, z, prev): - sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, sh_10_7, sh_10_8, sh_10_9, sh_10_10, sh_10_11, sh_10_12, sh_10_13, sh_10_14, sh_10_15, sh_10_16, sh_10_17, sh_10_18, sh_10_19, sh_10_20 = prev + """Compute spherical harmonics of degree 11.""" + sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, sh_10_7, \ + sh_10_8, sh_10_9, sh_10_10, sh_10_11, sh_10_12, sh_10_13, sh_10_14, \ + sh_10_15, sh_10_16, sh_10_17, sh_10_18, sh_10_19, sh_10_20 = prev sh_11_0 = 0.977008420918394 * sh_10_0 * z + 0.977008420918394 * sh_10_20 * x sh_11_1 = 0.416597790450531 * sh_10_0 * y + 0.9315409787236 * \ @@ -589,9 +639,13 @@ def _sh11(x, y, z, prev): sh_10_19 * z + 0.416597790450531 * sh_10_20 * y sh_11_22 = 
-0.977008420918393 * sh_10_0 * x + 0.977008420918393 * sh_10_20 * z - return [sh_11_0, sh_11_1, sh_11_2, sh_11_3, sh_11_4, sh_11_5, sh_11_6, sh_11_7, sh_11_8, sh_11_9, sh_11_10, sh_11_11, sh_11_12, sh_11_13, sh_11_14, sh_11_15, sh_11_16, sh_11_17, sh_11_18, sh_11_19, sh_11_20, sh_11_21, sh_11_22] + return [sh_11_0, sh_11_1, sh_11_2, sh_11_3, sh_11_4, sh_11_5, sh_11_6, + sh_11_7, sh_11_8, sh_11_9, sh_11_10, sh_11_11, sh_11_12, + sh_11_13, sh_11_14, sh_11_15, sh_11_16, sh_11_17, sh_11_18, + sh_11_19, sh_11_20, sh_11_21, sh_11_22] def _spherical_harmonics(lmax: int, x, y, z): + """Compute spherical harmonics up to degree lmax.""" results = [] # l = 0 diff --git a/mindscience/e3nn/utils/batch_dot.py b/mindscience/e3nn/utils/batch_dot.py index 8a7b781ea..06dfbd2ce 100644 --- a/mindscience/e3nn/utils/batch_dot.py +++ b/mindscience/e3nn/utils/batch_dot.py @@ -105,16 +105,16 @@ def _get_output_shape(batch_size, x1_ret, x2_ret): def batch_dot(x1, x2, axes=None): """ Compute batch-wise dot product of two tensors. - + Args: x1 (Tensor): First input tensor with shape (batch_size, ...). x2 (Tensor): Second input tensor with shape (batch_size, ...). axes (int, list, tuple, optional): Axes to perform dot product along. If None, defaults to the last axis of x1 and second-to-last axis of x2. - + Returns: Tensor: The batch dot product result. - + Raises: ValueError: If batch sizes of x1 and x2 don't match. """ diff --git a/mindscience/e3nn/utils/func.py b/mindscience/e3nn/utils/func.py index 07bf6a57d..22bdf38ad 100644 --- a/mindscience/e3nn/utils/func.py +++ b/mindscience/e3nn/utils/func.py @@ -118,7 +118,7 @@ def broadcast_args(*args): Broadcasts the given data with multiple types. Args: - *arg (Union[Tensor[float32], list[float], tuple[float], + *arg (Union[Tensor[float32], list[float], tuple[float], ndarray[np.float32], float]): Any number of data to be broadcasted. Returns: @@ -138,15 +138,14 @@ def _ndexpm(mat): mat_shape = mat.shape if len(mat_shape) < 2: raise ValueError - elif len(mat_shape) == 2: + if len(mat_shape) == 2: return Tensor(expm(mat)) - else: - mat = np.reshape(mat, (-1, mat_shape[-1], mat_shape[-1])) - n = mat.shape[0] - for i in range(n): - mat[i] = expm(mat[i]) - mat = np.reshape(mat, mat_shape) - return Tensor(mat) + mat = np.reshape(mat, (-1, mat_shape[-1], mat_shape[-1])) + n = mat.shape[0] + for i in range(n): + mat[i] = expm(mat[i]) + mat = np.reshape(mat, mat_shape) + return Tensor(mat) def _expand_last_dims(x): diff --git a/mindscience/e3nn/utils/initializer.py b/mindscience/e3nn/utils/initializer.py index 3e9a75f25..755538ea5 100644 --- a/mindscience/e3nn/utils/initializer.py +++ b/mindscience/e3nn/utils/initializer.py @@ -53,16 +53,16 @@ class Uniform(Initializer): def renormal_initializer(init_method): """ Normalize and convert initialization method to proper initializer instance. - + Args: - init_method (str or Initializer): The initialization method name or + init_method (str or Initializer): The initialization method name or an Initializer instance. Supported string values are: - 'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', + 'zeros', 'ones', 'truncatedNormal', 'normal', 'uniform', 'he_uniform', 'he_normal', 'xavier_uniform'. - + Returns: Initializer: The corresponding initializer instance. - + Raises: ValueError: If the initialization method is not supported. 
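+
+    Note (illustrative sketch, not part of the patch): the `_ndexpm` refactor
+    earlier in this diff applies `scipy.linalg.expm` slice-by-slice over the
+    leading batch dimensions; under that assumption a NumPy reference is
+    `np.stack([expm(m) for m in mat.reshape(-1, k, k)]).reshape(mat.shape)`.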
""" diff --git a/mindscience/e3nn/utils/perm.py b/mindscience/e3nn/utils/perm.py index 6fcee00f9..ca94ae317 100644 --- a/mindscience/e3nn/utils/perm.py +++ b/mindscience/e3nn/utils/perm.py @@ -63,10 +63,10 @@ def _from_int(i, n): def _to_int(p): """ Convert a permutation to its integer representation. - + Args: p (tuple): A permutation represented as a tuple. - + Returns: int: The integer representation of the permutation. """ @@ -89,12 +89,12 @@ def _group(n): def _germinate(subset): """ Generate the group closure of a subset of permutations. - + Args: subset (set): A set of permutations. - + Returns: - set: The group closure containing all permutations that can be + set: The group closure containing all permutations that can be generated from the input subset through composition and inversion. """ while True: @@ -112,10 +112,10 @@ def _germinate(subset): def _is__(g): """ Check if a set of permutations forms a group. - + Args: g (set): A set of permutations to check. - + Returns: bool: True if the set forms a group, False otherwise. """ @@ -145,10 +145,10 @@ def _is__(g): def _to_cycles(p): """ Convert a permutation to its cycle representation. - + Args: p (tuple): A permutation represented as a tuple. - + Returns: set: A set of tuples representing the cycles of the permutation. Only cycles of length >= 2 are included. -- Gitee From f28c54454a4eba9e950e9ee7b0e0db211d550a9e Mon Sep 17 00:00:00 2001 From: birfied Date: Thu, 18 Sep 2025 16:38:50 +0800 Subject: [PATCH 07/21] correct pylint errors --- mindscience/e3nn/o3/spherical_harmonics.py | 41 +++++++++++----------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/mindscience/e3nn/o3/spherical_harmonics.py b/mindscience/e3nn/o3/spherical_harmonics.py index 5243a35ea..f6725802b 100644 --- a/mindscience/e3nn/o3/spherical_harmonics.py +++ b/mindscience/e3nn/o3/spherical_harmonics.py @@ -206,13 +206,13 @@ def spherical_harmonics(l, x, normalize=True, normalization='integral'): sh = SphericalHarmonics(l, normalize, normalization, dtype=x.dtype) return sh(x) -def _sh0(x, _y, _z): +def _sh0(x): """ Compute spherical harmonics of degree 0. - + Args: x (Tensor): Tensor for construct spherical harmonics. The shape of Tensor is :math:`x` of shape ``(..., 3)`` - + Returns: Tensor, the spherical harmonics :math:`Y^0(x)`. The shape of Tensor is ``(..., 1)``. """ @@ -222,10 +222,10 @@ def _sh0(x, _y, _z): def _sh1(x, y, z): """ Compute spherical harmonics of degree 1. - + Args: x (Tensor): Tensor for construct spherical harmonics. The shape of Tensor is :math:`x` of shape ``(..., 3)`` - + Returns: Tensor, the spherical harmonics :math:`Y^1(x)`. The shape of Tensor is ``(..., 3)``. """ @@ -237,10 +237,10 @@ def _sh1(x, y, z): def _sh2(x, y, z): """ Compute spherical harmonics of degree 2. - + Args: x (Tensor): Tensor for construct spherical harmonics. The shape of Tensor is :math:`x` of shape ``(..., 3)`` - + Returns: Tensor, the spherical harmonics :math:`Y^2(x)`. The shape of Tensor is ``(..., 5)``. 
""" @@ -255,7 +255,7 @@ def _sh2(x, y, z): def _sh3(x, y, z, prev): """Compute spherical harmonics of degree 3.""" - sh_2_0, _sh_2_1, _sh_2_2, _sh_2_3, sh_2_4 = prev + sh_2_0, sh_2_4 = prev[0], prev[4] y2 = y.pow(2) x2z2 = x.pow(2) + z.pow(2) sh_3_0 = 0.9128709291752769 * (sh_2_0 * z + sh_2_4 * x) @@ -395,7 +395,8 @@ def _sh7(x, y, z, prev): sh_6_11 * z + 0.515078753637713 * sh_6_12 * y sh_7_14 = -0.963624111659431 * sh_6_0 * x + 0.963624111659433 * sh_6_12 * z - return [sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14] + return [sh_7_0, sh_7_1, sh_7_2, sh_7_3, sh_7_4, sh_7_5, sh_7_6, sh_7_7, sh_7_8, + sh_7_9, sh_7_10, sh_7_11, sh_7_12, sh_7_13, sh_7_14] def _sh8(x, y, z, prev): """Compute spherical harmonics of degree 8.""" @@ -446,8 +447,8 @@ def _sh8(x, y, z, prev): sh_7_13 * z + 0.484122918275927 * sh_7_14 * y sh_8_16 = -0.968245836551853 * sh_7_0 * x + 0.968245836551855 * sh_7_14 * z - return [sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, - sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, + return [sh_8_0, sh_8_1, sh_8_2, sh_8_3, sh_8_4, sh_8_5, sh_8_6, sh_8_7, + sh_8_8, sh_8_9, sh_8_10, sh_8_11, sh_8_12, sh_8_13, sh_8_14, sh_8_15, sh_8_16] def _sh9(x, y, z, prev): @@ -504,8 +505,8 @@ def _sh9(x, y, z, prev): sh_8_15 * z + 0.458122847290851 * sh_8_16 * y sh_9_18 = -0.97182531580755 * sh_8_0 * x + 0.97182531580755 * sh_8_16 * z - return [sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, - sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, sh_9_13, sh_9_14, + return [sh_9_0, sh_9_1, sh_9_2, sh_9_3, sh_9_4, sh_9_5, sh_9_6, sh_9_7, + sh_9_8, sh_9_9, sh_9_10, sh_9_11, sh_9_12, sh_9_13, sh_9_14, sh_9_15, sh_9_16, sh_9_17, sh_9_18] def _sh10(x, y, z, prev): @@ -567,9 +568,9 @@ def _sh10(x, y, z, prev): sh_9_17 * z + 0.435889894354068 * sh_9_18 * y sh_10_20 = -0.974679434480898 * sh_9_0 * x + 0.974679434480896 * sh_9_18 * z - return [sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, - sh_10_7, sh_10_8, sh_10_9, sh_10_10, sh_10_11, sh_10_12, - sh_10_13, sh_10_14, sh_10_15, sh_10_16, sh_10_17, sh_10_18, + return [sh_10_0, sh_10_1, sh_10_2, sh_10_3, sh_10_4, sh_10_5, sh_10_6, + sh_10_7, sh_10_8, sh_10_9, sh_10_10, sh_10_11, sh_10_12, + sh_10_13, sh_10_14, sh_10_15, sh_10_16, sh_10_17, sh_10_18, sh_10_19, sh_10_20] def _sh11(x, y, z, prev): @@ -639,9 +640,9 @@ def _sh11(x, y, z, prev): sh_10_19 * z + 0.416597790450531 * sh_10_20 * y sh_11_22 = -0.977008420918393 * sh_10_0 * x + 0.977008420918393 * sh_10_20 * z - return [sh_11_0, sh_11_1, sh_11_2, sh_11_3, sh_11_4, sh_11_5, sh_11_6, - sh_11_7, sh_11_8, sh_11_9, sh_11_10, sh_11_11, sh_11_12, - sh_11_13, sh_11_14, sh_11_15, sh_11_16, sh_11_17, sh_11_18, + return [sh_11_0, sh_11_1, sh_11_2, sh_11_3, sh_11_4, sh_11_5, sh_11_6, + sh_11_7, sh_11_8, sh_11_9, sh_11_10, sh_11_11, sh_11_12, + sh_11_13, sh_11_14, sh_11_15, sh_11_16, sh_11_17, sh_11_18, sh_11_19, sh_11_20, sh_11_21, sh_11_22] def _spherical_harmonics(lmax: int, x, y, z): @@ -649,7 +650,7 @@ def _spherical_harmonics(lmax: int, x, y, z): results = [] # l = 0 - sh0 = _sh0(x, y, z) + sh0 = _sh0(x) results.extend(sh0) if lmax == 0: return ops.stack(results, axis=-1) -- Gitee From 7b4547608e3d1caa8b37c13ce78dcf76b09f4d76 Mon Sep 17 00:00:00 2001 From: birfied Date: Fri, 19 Sep 2025 11:53:25 +0800 Subject: [PATCH 08/21] add e3nn.nn test cases --- mindscience/e3nn/nn/normact.py | 2 +- tests/e3nn/nn/__init__.py | 1 + tests/e3nn/nn/test_activation.py | 109 ++++++++++++++++++++++ 
tests/e3nn/nn/test_batchnorm.py | 142 +++++++++++++++++++++++++++++ tests/e3nn/nn/test_fc.py | 94 +++++++++++++++++++ tests/e3nn/nn/test_gate.py | 45 +++++++++ tests/e3nn/nn/test_normact.py | 43 +++++++++ tests/e3nn/nn/test_one_hot.py | 151 +++++++++++++++++++++++++++++++ tests/e3nn/nn/test_scatter.py | 64 +++++++++++++ 9 files changed, 650 insertions(+), 1 deletion(-) create mode 100644 tests/e3nn/nn/__init__.py create mode 100644 tests/e3nn/nn/test_activation.py create mode 100644 tests/e3nn/nn/test_batchnorm.py create mode 100644 tests/e3nn/nn/test_fc.py create mode 100644 tests/e3nn/nn/test_gate.py create mode 100644 tests/e3nn/nn/test_normact.py create mode 100644 tests/e3nn/nn/test_one_hot.py create mode 100644 tests/e3nn/nn/test_scatter.py diff --git a/mindscience/e3nn/nn/normact.py b/mindscience/e3nn/nn/normact.py index a4adee62d..2080931fa 100644 --- a/mindscience/e3nn/nn/normact.py +++ b/mindscience/e3nn/nn/normact.py @@ -84,7 +84,7 @@ class NormActivation(nn.Cell): epsilon = 1e-8 elif epsilon is not None and not normalize: raise ValueError("`epsilon` and `normalize = False` don't make sense together.") - elif not epsilon > 0: + elif epsilon is not None and not epsilon > 0: raise ValueError(f"epsilon {epsilon} is invalid, must be strictly positive.") self.epsilon = epsilon if self.epsilon is not None: diff --git a/tests/e3nn/nn/__init__.py b/tests/e3nn/nn/__init__.py new file mode 100644 index 000000000..db172509a --- /dev/null +++ b/tests/e3nn/nn/__init__.py @@ -0,0 +1 @@ +# E3NN neural network module tests \ No newline at end of file diff --git a/tests/e3nn/nn/test_activation.py b/tests/e3nn/nn/test_activation.py new file mode 100644 index 000000000..1babc9031 --- /dev/null +++ b/tests/e3nn/nn/test_activation.py @@ -0,0 +1,109 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
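+#
+# Note (illustrative, not part of the patch): the normact.py guard changed
+# above avoids evaluating `None > 0`, which raises TypeError in Python 3, so
+# a cell such as
+#     NormActivation('1x1o', ops.sigmoid, normalize=False)
+# can now be constructed while leaving `epsilon` as None.
+#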
+# ============================================================================ +"""Test cases for e3nn.nn.activation module - Core functionality only""" + +import pytest +import numpy as np +from mindspore import Tensor, ops, float32 +from mindscience.e3nn.nn.activation import Activation, _Normalize, _moment, _parity_function +from mindscience.e3nn.o3 import Irreps + + +class TestActivation: + """Core tests for Activation class""" + + def test_activation_basic_creation(self): + """Test basic Activation creation and forward pass""" + act = Activation('2x0e+1x0o', [ops.tanh, ops.abs]) + + x = Tensor(np.random.randn(3, 3), dtype=float32) + output = act(x) + + assert output.shape == (3, 3) + assert act.irreps_in == Irreps('2x0e+1x0o') + assert not np.any(np.isnan(output.asnumpy())) + + def test_activation_parity_change(self): + """Test activation function changes parity correctly""" + # abs function should change odd to even + act = Activation('2x0o', [ops.abs]) + + x = Tensor(np.random.randn(2, 2), dtype=float32) + output = act(x) + + assert act.irreps_out == Irreps('2x0e') # odd -> even + assert output.shape == (2, 2) + + def test_activation_invalid_non_scalar(self): + """Test activation with non-scalar irrep raises error""" + with pytest.raises(ValueError, match="non-scalar input"): + Activation('1x1e', [ops.tanh]) + + +class TestNormalize: + """Core tests for _Normalize class""" + + def test_normalize_basic(self): + """Test _Normalize normalizes activation function""" + norm_tanh = _Normalize(ops.tanh) + + x = Tensor(np.random.randn(100), dtype=float32) + output = norm_tanh(x) + + assert output.shape == x.shape + assert hasattr(norm_tanh, 'factor') + + def test_normalize_scaling_function(self): + """Test _Normalize correctly handles scaling functions""" + def scale_func(x): + return x * 2.0 # This will have second moment = 4.0 + + norm_func = _Normalize(scale_func) + + # Verify factor is approximately correct (should be around 1/sqrt(4) = 0.5) + expected_factor = 1.0 / np.sqrt(4.0) + assert abs(float(norm_func.factor) - expected_factor) < 5e-3 + + # Test normalization effect + x = Tensor(np.ones(5), dtype=float32) + output = norm_func(x) + expected_output = scale_func(x) * norm_func.factor + assert np.allclose(output.asnumpy(), expected_output.asnumpy(), atol=1e-4) + + +class TestUtilityFunctions: + """Core tests for utility functions""" + + def test_moment_calculation(self): + """Test _moment function calculates moments correctly""" + moment = _moment(ops.tanh, 2) + + assert isinstance(moment, Tensor) + assert moment.shape == () # scalar + assert moment.asnumpy() > 0 + + def test_parity_function_detection(self): + """Test _parity_function detects function parity""" + # Test even function + parity_even = _parity_function(lambda x: x**2) + assert parity_even == 1 # even function + + # Test odd function + parity_odd = _parity_function(lambda x: x) + assert parity_odd == -1 # odd function + + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/e3nn/nn/test_batchnorm.py b/tests/e3nn/nn/test_batchnorm.py new file mode 100644 index 000000000..228fd1269 --- /dev/null +++ b/tests/e3nn/nn/test_batchnorm.py @@ -0,0 +1,142 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
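+#
+# Note (illustrative, assumptions labeled): the parity probe exercised above
+# presumably samples f(x) against f(-x); odd maps (x, tanh) give -1 and even
+# maps (x**2, abs) give +1, which is why Activation('2x0o', [ops.abs]) lands
+# on irreps_out == '2x0e' in test_activation_parity_change above.
+#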
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test cases for e3nn.nn.batchnorm module - Core functionality""" + +import pytest +import numpy as np +from mindspore import Tensor, float32 +from mindscience.e3nn.nn.batchnorm import BatchNorm +from mindscience.e3nn.o3 import Irreps + + +class TestBatchNorm: + """Core tests for BatchNorm class""" + + def test_batchnorm_basic_creation(self): + """Test basic BatchNorm creation and forward pass""" + bn = BatchNorm('2x0e+1x0o') + + x = Tensor(np.random.randn(4, 3), dtype=float32) + output = bn(x) + + assert output.shape == (4, 3) + assert bn.irreps == Irreps('2x0e+1x0o') + assert not np.any(np.isnan(output.asnumpy())) + + def test_batchnorm_normalization_correctness(self): + """Test that BatchNorm actually normalizes the data correctly""" + bn = BatchNorm('2x0e', eps=1e-8, affine=False) + + # Create data with known statistics + x = Tensor(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=np.float32)) + output = bn(x) + + # Manually compute expected normalized output + x_np = x.asnumpy() + x_mean = np.mean(x_np, axis=0) # [4.0, 5.0] + x_var = np.var(x_np, axis=0, ddof=0) # [5.0, 5.0] + expected_output = (x_np - x_mean) / np.sqrt(x_var + 1e-8) + + # Check that actual output matches manual calculation + output_np = output.asnumpy() + assert np.allclose(output_np, expected_output, atol=1e-6), \ + f"Normalization calculation incorrect" + + # Verify normalized output has zero mean and unit variance + assert abs(np.mean(output_np)) < 1e-6, "Mean should be close to 0" + assert abs(np.var(output_np, ddof=0) - 1.0) < 1e-5, "Variance should be close to 1" + + def test_batchnorm_affine_parameters(self): + """Test affine parameters (weight and bias) effect""" + x = Tensor(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=np.float32)) + + # Test with affine=True + bn_affine = BatchNorm('2x0e', affine=True, eps=1e-8) + weight = Tensor([2.0, 0.5], dtype=float32) + bias = Tensor([1.0, -1.0], dtype=float32) + bn_affine.weight.set_data(weight) + bn_affine.bias.set_data(bias) + + output_affine = bn_affine(x) + + # Verify computation: output = weight * normalized_input + bias + x_np = x.asnumpy() + x_mean = np.mean(x_np, axis=0) + x_var = np.var(x_np, axis=0, ddof=0) + x_normalized = (x_np - x_mean) / np.sqrt(x_var + 1e-8) + expected_output = x_normalized * weight.asnumpy() + bias.asnumpy() + + assert np.allclose(output_affine.asnumpy(), expected_output, atol=1e-5), \ + "Affine transformation calculation incorrect" + + # Test with affine=False + bn_no_affine = BatchNorm('2x0e', affine=False, eps=1e-8) + output_no_affine = bn_no_affine(x) + + assert np.allclose(output_no_affine.asnumpy(), x_normalized, atol=1e-5), \ + "Non-affine normalization calculation incorrect" + + def test_batchnorm_training_inference_modes(self): + """Test difference between training and inference modes""" + bn = BatchNorm('2x0e', momentum=0.1, instance=False, affine=False) + x = Tensor(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=np.float32)) + + # Training mode - should update running statistics + 
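+        # (Note, illustrative: with momentum m and initial stats of mean 0 and
+        #  var 1, the update asserted below is
+        #  running <- (1 - m) * running + m * batch_stat, here with m = 0.1.)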
bn.training = True + output_train = bn(x) + + # Verify running statistics update follows momentum formula + x_np = x.asnumpy() + batch_mean = np.mean(x_np, axis=0) + batch_var = np.var(x_np, axis=0, ddof=0) + expected_running_mean = 0.9 * 0.0 + 0.1 * batch_mean # initial mean is 0 + expected_running_var = 0.9 * 1.0 + 0.1 * batch_var # initial var is 1 + + assert np.allclose(bn.running_mean.asnumpy(), expected_running_mean, atol=1e-6), \ + "Running mean update calculation incorrect" + assert np.allclose(bn.running_var.asnumpy(), expected_running_var, atol=1e-6), \ + "Running var update calculation incorrect" + + # Inference mode - should not update running statistics + running_mean_before = bn.running_mean.asnumpy().copy() + running_var_before = bn.running_var.asnumpy().copy() + + bn.training = False + output_inference = bn(x) + + assert np.allclose(bn.running_mean.asnumpy(), running_mean_before), \ + "Running mean should not change in inference mode" + assert np.allclose(bn.running_var.asnumpy(), running_var_before), \ + "Running var should not change in inference mode" + assert not np.any(np.isnan(output_train.asnumpy())) + assert not np.any(np.isnan(output_inference.asnumpy())) + + def test_batchnorm_invalid_parameters(self): + """Test error handling for invalid parameters""" + # Test invalid normalization + with pytest.raises(ValueError, match="Invalid normalization option"): + bn = BatchNorm('2x0e', normalization='invalid') + x = Tensor(np.random.randn(4, 2), dtype=float32) + bn(x) + + # Test invalid reduce + with pytest.raises(ValueError, match="Invalid reduce option"): + bn = BatchNorm('2x0e', reduce='invalid') + x = Tensor(np.random.randn(4, 2), dtype=float32) + bn(x) + + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/e3nn/nn/test_fc.py b/tests/e3nn/nn/test_fc.py new file mode 100644 index 000000000..b6e5f5f84 --- /dev/null +++ b/tests/e3nn/nn/test_fc.py @@ -0,0 +1,94 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
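+#
+# Note (illustrative inference): the invalid-option tests above build the cell
+# first and trigger pytest.raises only on the forward call, which suggests
+# BatchNorm validates `normalization` and `reduce` lazily at the first
+# construct() rather than in __init__.
+#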
+# ============================================================================ +"""Test cases for FullyConnectedNet""" +import pytest +import numpy as np +from mindspore import Tensor, float32, ops +from mindscience.e3nn.nn.fc import FullyConnectedNet, _Layer + + +class TestFullyConnectedNet: + """Test cases for FullyConnectedNet""" + + def test_fc_basic_creation(self): + """Test basic creation and parameter initialization""" + h_list = [4, 10, 6] + fc = FullyConnectedNet(h_list) + + assert fc.h_list == h_list + assert len(fc.layer_list) == 2 + assert fc.layer_list[0].h_in == 4 and fc.layer_list[0].h_out == 10 + assert fc.layer_list[1].h_in == 10 and fc.layer_list[1].h_out == 6 + assert fc.weight_numel == 4*10 + 10*6 + + def test_fc_forward_computation(self): + """Test forward propagation computation correctness""" + h_list = [3, 4, 2] + fc = FullyConnectedNet(h_list, act=None, out_act=False) + + x = Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + + # Set fixed weights for verification + fc.layer_list[0].weight.set_data(Tensor(np.array([ + [0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.9, 1.0, 1.1, 1.2] + ], dtype=np.float32))) + + fc.layer_list[1].weight.set_data(Tensor(np.array([ + [0.1, 0.2], + [0.3, 0.4], + [0.5, 0.6], + [0.7, 0.8] + ], dtype=np.float32))) + + output = fc(x) + + # Manual calculation verification + w1_norm = fc.layer_list[0].weight.asnumpy() / np.sqrt(3) + hidden = np.dot(x.asnumpy(), w1_norm) + w2_norm = fc.layer_list[1].weight.asnumpy() / np.sqrt(4) + expected_output = np.dot(hidden, w2_norm) + + assert output.shape == (2,) + assert np.allclose(output.asnumpy(), expected_output, atol=1e-6) + + def test_fc_activation_function(self): + """Test activation function""" + h_list = [2, 3, 2] + fc_with_act = FullyConnectedNet(h_list, act=ops.tanh, out_act=True) + fc_without_act = FullyConnectedNet(h_list, act=ops.tanh, out_act=False) + + x = Tensor(np.array([1.0, -1.0], dtype=np.float32)) + output_with_act = fc_with_act(x) + output_without_act = fc_without_act(x) + + assert output_with_act.shape == (2,) + assert output_without_act.shape == (2,) + assert not np.allclose(output_with_act.asnumpy(), output_without_act.asnumpy()) + + def test_fc_error_handling(self): + """Test error handling""" + # Test invalid h_list + with pytest.raises(TypeError): + FullyConnectedNet([3.5, 4, 2]) + + # Test minimum valid case + fc_minimal = FullyConnectedNet([2, 1]) + assert len(fc_minimal.layer_list) == 1 + + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/e3nn/nn/test_gate.py b/tests/e3nn/nn/test_gate.py new file mode 100644 index 000000000..15046e427 --- /dev/null +++ b/tests/e3nn/nn/test_gate.py @@ -0,0 +1,45 @@ +"""Test Gate module""" +import pytest +from mindspore import Tensor, ops, float32 +import numpy as np +from mindscience.e3nn.nn import Gate + + +class TestGate: + def test_gate_creation(self): + """测试Gate创建和基本属性""" + gate = Gate('2x0e', [ops.tanh], '1x0e', [ops.sigmoid], '1x1o') + assert isinstance(gate, Gate) + assert gate.irreps_in.dim > 0 + assert gate.irreps_out.dim > 0 + + def test_gate_forward(self): + """测试前向传播""" + gate = Gate('1x0e', [ops.tanh], '2x0e', [ops.sigmoid, ops.abs], '2x1o') + x = Tensor(np.random.randn(3, gate.irreps_in.dim), dtype=float32) + output = gate(x) + + assert output.shape == (3, gate.irreps_out.dim) + assert not np.isnan(output.asnumpy()).any() + + def test_gate_activations(self): + """测试不同激活函数""" + gate1 = Gate('1x0e', [ops.tanh], '1x0e', [ops.sigmoid], '1x1o') + gate2 = Gate('1x0e', 
[ops.relu], '1x0e', [ops.abs], '1x1o') + + x = Tensor(np.random.randn(2, gate1.irreps_in.dim), dtype=float32) + output1, output2 = gate1(x), gate2(x) + + assert output1.shape == output2.shape + assert not np.allclose(output1.asnumpy(), output2.asnumpy(), atol=1e-6) + + def test_gate_errors(self): + """测试错误处理""" + with pytest.raises(ValueError, match="Scalars must be scalars"): + Gate('1x1o', [ops.tanh], '1x0e', [ops.sigmoid], '1x1o') + + with pytest.raises(ValueError, match="Gate scalars must be scalars"): + Gate('1x0e', [ops.tanh], '1x1o', [ops.sigmoid], '1x1o') + + with pytest.raises(ValueError, match="different number"): + Gate('1x0e', [ops.tanh], '2x0e', [ops.sigmoid, ops.abs], '1x1o') \ No newline at end of file diff --git a/tests/e3nn/nn/test_normact.py b/tests/e3nn/nn/test_normact.py new file mode 100644 index 000000000..7cbbc103e --- /dev/null +++ b/tests/e3nn/nn/test_normact.py @@ -0,0 +1,43 @@ +import pytest +from mindspore import Tensor, ops, float32 +import numpy as np +from mindscience.e3nn.nn import NormActivation + + +class TestNormActivation: + def test_creation_and_forward(self): + normact = NormActivation('2x1e', ops.sigmoid) + assert normact.irreps_in.dim > 0 + assert normact.irreps_out.dim == normact.irreps_in.dim + assert normact.normalize is True + assert normact.epsilon == 1e-8 + + x = Tensor(np.random.randn(3, normact.irreps_in.dim), dtype=float32) + output = normact(x) + assert output.shape == x.shape + assert not np.isnan(output.asnumpy()).any() + + def test_normalize_and_epsilon(self): + normact_norm = NormActivation('1x1o', ops.sigmoid, normalize=True) + normact_no_norm = NormActivation('1x1o', ops.sigmoid, normalize=False) + normact_eps = NormActivation('1x1o', ops.sigmoid, epsilon=1e-6) + + assert normact_norm.normalize and normact_norm.epsilon == 1e-8 + assert not normact_no_norm.normalize and normact_no_norm.epsilon is None + assert normact_eps.epsilon == 1e-6 and normact_eps._eps_squared == 1e-12 + + def test_activations_and_bias(self): + normact1 = NormActivation('1x1o', ops.sigmoid, bias=True) + normact2 = NormActivation('1x1o', ops.tanh, bias=False) + + x = Tensor(np.random.randn(2, 3), dtype=float32) + output1, output2 = normact1(x), normact2(x) + + assert output1.shape == output2.shape + assert normact1.bias is not None and normact2.bias is None + + def test_errors(self): + with pytest.raises(ValueError, match="epsilon.*normalize = False.*don't make sense"): + NormActivation('1x1o', ops.sigmoid, normalize=False, epsilon=1e-6) + with pytest.raises(ValueError, match="epsilon.*invalid.*strictly positive"): + NormActivation('1x1o', ops.sigmoid, epsilon=-1e-6) \ No newline at end of file diff --git a/tests/e3nn/nn/test_one_hot.py b/tests/e3nn/nn/test_one_hot.py new file mode 100644 index 000000000..b7ccbd2ae --- /dev/null +++ b/tests/e3nn/nn/test_one_hot.py @@ -0,0 +1,151 @@ +"""Test cases for one_hot module""" +import pytest +import numpy as np + +from mindspore import Tensor, ops, float32, int32 +from mindscience.e3nn.nn.one_hot import OneHot, SoftOneHotLinspace, soft_one_hot_linspace, soft_unit_step + + +class TestSoftUnitStep: + """Test soft_unit_step function""" + + def test_soft_unit_step_basic(self): + """Test soft_unit_step with basic functionality""" + # Test positive values + x_pos = Tensor([1.0, 2.0], dtype=float32) + result_pos = soft_unit_step(x_pos) + expected_pos = ops.exp(-1.0 / x_pos) + assert np.allclose(result_pos.asnumpy(), expected_pos.asnumpy(), atol=1e-6) + + # Test negative values (should be zero due to relu) + x_neg = Tensor([-1.0, 
-2.0], dtype=float32) + result_neg = soft_unit_step(x_neg) + expected_neg = Tensor([0.0, 0.0], dtype=float32) + assert np.allclose(result_neg.asnumpy(), expected_neg.asnumpy(), atol=1e-6) + + # Test zero (may be NaN or 0 due to division by zero) + x_zero = Tensor([0.0], dtype=float32) + result_zero = soft_unit_step(x_zero) + result_np = result_zero.asnumpy() + assert result_np[0] == 0.0 or np.isnan(result_np[0]) + + +class TestOneHot: + """Test OneHot class""" + + def test_onehot_basic(self): + """Test OneHot basic functionality""" + num_types = 4 + onehot = OneHot(num_types) + + # Test creation + assert onehot.num_types == num_types + assert str(onehot.irreps_output) == "4x0e" + + # Test single input + atom_type = Tensor([2], dtype=int32) + result = onehot(atom_type) + expected = Tensor([[0., 0., 1., 0.]], dtype=float32) + assert np.allclose(result.asnumpy(), expected.asnumpy()) + assert result.shape == (1, 4) + + # Test batch input + atom_types = Tensor([0, 1, 2], dtype=int32) + result_batch = onehot(atom_types) + expected_batch = Tensor([ + [1., 0., 0., 0.], + [0., 1., 0., 0.], + [0., 0., 1., 0.] + ], dtype=float32) + assert np.allclose(result_batch.asnumpy(), expected_batch.asnumpy()) + assert result_batch.shape == (3, 4) + + +class TestSoftOneHotLinspace: + """Test SoftOneHotLinspace class""" + + def test_soft_onehot_basic(self): + """Test SoftOneHotLinspace basic functionality""" + start, end, number = 0.0, 2.0, 4 + soft_onehot = SoftOneHotLinspace(start, end, number) + + # Test creation + assert soft_onehot.start.asnumpy() == start + assert soft_onehot.end.asnumpy() == end + assert soft_onehot.number == number + + # Test forward pass + x = Tensor([1.0], dtype=float32) + result = soft_onehot(x) + assert result.shape == (1, 4) + + # Test batch input + x_batch = Tensor([[0.5, 1.0], [1.5, 2.0]], dtype=float32) + result_batch = soft_onehot(x_batch) + assert result_batch.shape == (2, 2, 4) + + def test_soft_onehot_different_basis(self): + """Test SoftOneHotLinspace with different basis functions""" + start, end, number = 0.0, 2.0, 3 + x = Tensor([1.0], dtype=float32) + + for basis in ['gaussian', 'cosine', 'smooth_finite']: + soft_onehot = SoftOneHotLinspace(start, end, number, basis=basis) + result = soft_onehot(x) + assert result.shape == (1, 3) + # Some basis functions may produce NaN at boundaries, which is expected + + def test_soft_onehot_cutoff(self): + """Test SoftOneHotLinspace cutoff behavior""" + start, end, number = 0.0, 2.0, 3 + + # Test with and without cutoff + soft_onehot_cutoff = SoftOneHotLinspace(start, end, number, cutoff=True) + soft_onehot_no_cutoff = SoftOneHotLinspace(start, end, number, cutoff=False) + + x = Tensor([3.0], dtype=float32) # Outside domain + result_cutoff = soft_onehot_cutoff(x) + result_no_cutoff = soft_onehot_no_cutoff(x) + + assert result_cutoff.shape == (1, 3) + assert result_no_cutoff.shape == (1, 3) + + +class TestSoftOneHotLinspaceFunction: + """Test soft_one_hot_linspace function""" + + def test_function_basic(self): + """Test soft_one_hot_linspace function interface""" + x = Tensor([1.0, 1.5, 2.0], dtype=float32) + start, end, number = 0.0, 3.0, 4 + + result = soft_one_hot_linspace(x, start, end, number) + assert result.shape == (3, 4) + + # Test with different basis + result_gaussian = soft_one_hot_linspace(x, start, end, number, basis='gaussian') + assert result_gaussian.shape == (3, 4) + + +class TestEdgeCases: + """Test edge cases and error handling""" + + def test_edge_cases(self): + """Test various edge cases""" + # OneHot with single 
type + onehot = OneHot(1) + atom_type = Tensor([0], dtype=int32) + result = onehot(atom_type) + assert result.shape == (1, 1) + assert np.allclose(result.asnumpy(), Tensor([[1.0]], dtype=float32).asnumpy()) + + # SoftOneHotLinspace with small number + soft_onehot = SoftOneHotLinspace(0.0, 1.0, 2) + x = Tensor([0.5], dtype=float32) + result = soft_onehot(x) + assert result.shape == (1, 2) + + # Invalid basis should raise error + soft_onehot_invalid = SoftOneHotLinspace(0.0, 1.0, 3, basis='invalid') + with pytest.raises(ValueError, match="Unsupported basis"): + soft_onehot_invalid(x) \ No newline at end of file diff --git a/tests/e3nn/nn/test_scatter.py b/tests/e3nn/nn/test_scatter.py new file mode 100644 index 000000000..c10628a3c --- /dev/null +++ b/tests/e3nn/nn/test_scatter.py @@ -0,0 +1,64 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test scatter""" +import numpy as np +import pytest +from mindspore import Tensor, float32, int32 +from mindscience.e3nn.nn import Scatter + + +class TestScatter: + """Test Scatter class core functionality""" + + def test_scatter_add(self): + """Test scatter add operation""" + scatter = Scatter(mode='add') + + src = Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=float32) + index = Tensor([0, 1], dtype=int32) + + result = scatter(src, index, dim_size=2) + expected = Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=float32) + + assert np.allclose(result.asnumpy(), expected.asnumpy()) + + def test_scatter_max(self): + """Test scatter max operation""" + scatter = Scatter(mode='max') + + src = Tensor([[1.0, 5.0], [3.0, 2.0], [2.0, 4.0]], dtype=float32) + index = Tensor([0, 1, 0], dtype=int32) + + result = scatter(src, index, dim_size=2) + expected = Tensor([[2.0, 5.0], [3.0, 2.0]], dtype=float32) + + assert np.allclose(result.asnumpy(), expected.asnumpy()) + + def test_scatter_with_out_parameter(self): + """Test scatter with out parameter for proper initialization""" + scatter = Scatter(mode='mul') + + src = Tensor([[2.0, 3.0], [4.0, 5.0]], dtype=float32) + index = Tensor([0, 1], dtype=int32) + out = Tensor([[1.0, 1.0], [1.0, 1.0]], dtype=float32) + + result = scatter(src, index, out=out) + expected = Tensor([[2.0, 3.0], [4.0, 5.0]], dtype=float32) + + assert np.allclose(result.asnumpy(), expected.asnumpy()) + + def test_scatter_invalid_mode(self): + """Test scatter with invalid mode""" + with pytest.raises(ValueError, match="Unexpected scatter mode"): + Scatter(mode='invalid') \ No newline at end of file -- Gitee From be060b0b0f77bb7d1fa782dd1364814dc8ac037f Mon Sep 17 00:00:00 2001 From: birfied Date: Fri, 19 Sep 2025 14:05:48 +0800 Subject: [PATCH 09/21] pylint errors --- tests/e3nn/nn/__init__.py | 2 +- tests/e3nn/nn/test_activation.py | 26 ++++++++--------- tests/e3nn/nn/test_batchnorm.py | 40 +++++++++++++------------- tests/e3nn/nn/test_fc.py | 24 ++++++++-------- tests/e3nn/nn/test_gate.py | 12 ++++---- tests/e3nn/nn/test_normact.py | 10 +++---- 
tests/e3nn/nn/test_one_hot.py | 48 ++++++++++++++++---------------- tests/e3nn/nn/test_scatter.py | 20 ++++++------- 8 files changed, 91 insertions(+), 91 deletions(-) diff --git a/tests/e3nn/nn/__init__.py b/tests/e3nn/nn/__init__.py index db172509a..2b2810bbc 100644 --- a/tests/e3nn/nn/__init__.py +++ b/tests/e3nn/nn/__init__.py @@ -1 +1 @@ -# E3NN neural network module tests \ No newline at end of file +# E3NN neural network module tests diff --git a/tests/e3nn/nn/test_activation.py b/tests/e3nn/nn/test_activation.py index 1babc9031..9b585d109 100644 --- a/tests/e3nn/nn/test_activation.py +++ b/tests/e3nn/nn/test_activation.py @@ -27,10 +27,10 @@ class TestActivation: def test_activation_basic_creation(self): """Test basic Activation creation and forward pass""" act = Activation('2x0e+1x0o', [ops.tanh, ops.abs]) - + x = Tensor(np.random.randn(3, 3), dtype=float32) output = act(x) - + assert output.shape == (3, 3) assert act.irreps_in == Irreps('2x0e+1x0o') assert not np.any(np.isnan(output.asnumpy())) @@ -39,10 +39,10 @@ class TestActivation: """Test activation function changes parity correctly""" # abs function should change odd to even act = Activation('2x0o', [ops.abs]) - + x = Tensor(np.random.randn(2, 2), dtype=float32) output = act(x) - + assert act.irreps_out == Irreps('2x0e') # odd -> even assert output.shape == (2, 2) @@ -58,10 +58,10 @@ class TestNormalize: def test_normalize_basic(self): """Test _Normalize normalizes activation function""" norm_tanh = _Normalize(ops.tanh) - + x = Tensor(np.random.randn(100), dtype=float32) output = norm_tanh(x) - + assert output.shape == x.shape assert hasattr(norm_tanh, 'factor') @@ -69,13 +69,13 @@ class TestNormalize: """Test _Normalize correctly handles scaling functions""" def scale_func(x): return x * 2.0 # This will have second moment = 4.0 - + norm_func = _Normalize(scale_func) - + # Verify factor is approximately correct (should be around 1/sqrt(4) = 0.5) expected_factor = 1.0 / np.sqrt(4.0) assert abs(float(norm_func.factor) - expected_factor) < 5e-3 - + # Test normalization effect x = Tensor(np.ones(5), dtype=float32) output = norm_func(x) @@ -89,7 +89,7 @@ class TestUtilityFunctions: def test_moment_calculation(self): """Test _moment function calculates moments correctly""" moment = _moment(ops.tanh, 2) - + assert isinstance(moment, Tensor) assert moment.shape == () # scalar assert moment.asnumpy() > 0 @@ -99,11 +99,11 @@ class TestUtilityFunctions: # Test even function parity_even = _parity_function(lambda x: x**2) assert parity_even == 1 # even function - - # Test odd function + + # Test odd function parity_odd = _parity_function(lambda x: x) assert parity_odd == -1 # odd function if __name__ == "__main__": - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/tests/e3nn/nn/test_batchnorm.py b/tests/e3nn/nn/test_batchnorm.py index 228fd1269..3636bd718 100644 --- a/tests/e3nn/nn/test_batchnorm.py +++ b/tests/e3nn/nn/test_batchnorm.py @@ -27,10 +27,10 @@ class TestBatchNorm: def test_batchnorm_basic_creation(self): """Test basic BatchNorm creation and forward pass""" bn = BatchNorm('2x0e+1x0o') - + x = Tensor(np.random.randn(4, 3), dtype=float32) output = bn(x) - + assert output.shape == (4, 3) assert bn.irreps == Irreps('2x0e+1x0o') assert not np.any(np.isnan(output.asnumpy())) @@ -38,22 +38,22 @@ class TestBatchNorm: def test_batchnorm_normalization_correctness(self): """Test that BatchNorm actually normalizes the data correctly""" bn = BatchNorm('2x0e', eps=1e-8, affine=False) - + # Create 
data with known statistics x = Tensor(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=np.float32)) output = bn(x) - + # Manually compute expected normalized output x_np = x.asnumpy() x_mean = np.mean(x_np, axis=0) # [4.0, 5.0] x_var = np.var(x_np, axis=0, ddof=0) # [5.0, 5.0] expected_output = (x_np - x_mean) / np.sqrt(x_var + 1e-8) - + # Check that actual output matches manual calculation output_np = output.asnumpy() assert np.allclose(output_np, expected_output, atol=1e-6), \ f"Normalization calculation incorrect" - + # Verify normalized output has zero mean and unit variance assert abs(np.mean(output_np)) < 1e-6, "Mean should be close to 0" assert abs(np.var(output_np, ddof=0) - 1.0) < 1e-5, "Variance should be close to 1" @@ -61,30 +61,30 @@ class TestBatchNorm: def test_batchnorm_affine_parameters(self): """Test affine parameters (weight and bias) effect""" x = Tensor(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=np.float32)) - + # Test with affine=True bn_affine = BatchNorm('2x0e', affine=True, eps=1e-8) weight = Tensor([2.0, 0.5], dtype=float32) bias = Tensor([1.0, -1.0], dtype=float32) bn_affine.weight.set_data(weight) bn_affine.bias.set_data(bias) - + output_affine = bn_affine(x) - + # Verify computation: output = weight * normalized_input + bias x_np = x.asnumpy() x_mean = np.mean(x_np, axis=0) x_var = np.var(x_np, axis=0, ddof=0) x_normalized = (x_np - x_mean) / np.sqrt(x_var + 1e-8) expected_output = x_normalized * weight.asnumpy() + bias.asnumpy() - + assert np.allclose(output_affine.asnumpy(), expected_output, atol=1e-5), \ "Affine transformation calculation incorrect" - + # Test with affine=False bn_no_affine = BatchNorm('2x0e', affine=False, eps=1e-8) output_no_affine = bn_no_affine(x) - + assert np.allclose(output_no_affine.asnumpy(), x_normalized, atol=1e-5), \ "Non-affine normalization calculation incorrect" @@ -92,30 +92,30 @@ class TestBatchNorm: """Test difference between training and inference modes""" bn = BatchNorm('2x0e', momentum=0.1, instance=False, affine=False) x = Tensor(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=np.float32)) - + # Training mode - should update running statistics bn.training = True output_train = bn(x) - + # Verify running statistics update follows momentum formula x_np = x.asnumpy() batch_mean = np.mean(x_np, axis=0) batch_var = np.var(x_np, axis=0, ddof=0) expected_running_mean = 0.9 * 0.0 + 0.1 * batch_mean # initial mean is 0 expected_running_var = 0.9 * 1.0 + 0.1 * batch_var # initial var is 1 - + assert np.allclose(bn.running_mean.asnumpy(), expected_running_mean, atol=1e-6), \ "Running mean update calculation incorrect" assert np.allclose(bn.running_var.asnumpy(), expected_running_var, atol=1e-6), \ "Running var update calculation incorrect" - + # Inference mode - should not update running statistics running_mean_before = bn.running_mean.asnumpy().copy() running_var_before = bn.running_var.asnumpy().copy() - + bn.training = False output_inference = bn(x) - + assert np.allclose(bn.running_mean.asnumpy(), running_mean_before), \ "Running mean should not change in inference mode" assert np.allclose(bn.running_var.asnumpy(), running_var_before), \ @@ -130,7 +130,7 @@ class TestBatchNorm: bn = BatchNorm('2x0e', normalization='invalid') x = Tensor(np.random.randn(4, 2), dtype=float32) bn(x) - + # Test invalid reduce with pytest.raises(ValueError, match="Invalid reduce option"): bn = BatchNorm('2x0e', reduce='invalid') @@ -139,4 +139,4 @@ class TestBatchNorm: if __name__ == "__main__": 
- pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/tests/e3nn/nn/test_fc.py b/tests/e3nn/nn/test_fc.py index b6e5f5f84..bf16a9c0c 100644 --- a/tests/e3nn/nn/test_fc.py +++ b/tests/e3nn/nn/test_fc.py @@ -26,7 +26,7 @@ class TestFullyConnectedNet: """Test basic creation and parameter initialization""" h_list = [4, 10, 6] fc = FullyConnectedNet(h_list) - + assert fc.h_list == h_list assert len(fc.layer_list) == 2 assert fc.layer_list[0].h_in == 4 and fc.layer_list[0].h_out == 10 @@ -37,31 +37,31 @@ class TestFullyConnectedNet: """Test forward propagation computation correctness""" h_list = [3, 4, 2] fc = FullyConnectedNet(h_list, act=None, out_act=False) - + x = Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32)) - + # Set fixed weights for verification fc.layer_list[0].weight.set_data(Tensor(np.array([ [0.1, 0.2, 0.3, 0.4], - [0.5, 0.6, 0.7, 0.8], + [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2] ], dtype=np.float32))) - + fc.layer_list[1].weight.set_data(Tensor(np.array([ [0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8] ], dtype=np.float32))) - + output = fc(x) - + # Manual calculation verification w1_norm = fc.layer_list[0].weight.asnumpy() / np.sqrt(3) hidden = np.dot(x.asnumpy(), w1_norm) w2_norm = fc.layer_list[1].weight.asnumpy() / np.sqrt(4) expected_output = np.dot(hidden, w2_norm) - + assert output.shape == (2,) assert np.allclose(output.asnumpy(), expected_output, atol=1e-6) @@ -70,11 +70,11 @@ class TestFullyConnectedNet: h_list = [2, 3, 2] fc_with_act = FullyConnectedNet(h_list, act=ops.tanh, out_act=True) fc_without_act = FullyConnectedNet(h_list, act=ops.tanh, out_act=False) - + x = Tensor(np.array([1.0, -1.0], dtype=np.float32)) output_with_act = fc_with_act(x) output_without_act = fc_without_act(x) - + assert output_with_act.shape == (2,) assert output_without_act.shape == (2,) assert not np.allclose(output_with_act.asnumpy(), output_without_act.asnumpy()) @@ -84,11 +84,11 @@ class TestFullyConnectedNet: # Test invalid h_list with pytest.raises(TypeError): FullyConnectedNet([3.5, 4, 2]) - + # Test minimum valid case fc_minimal = FullyConnectedNet([2, 1]) assert len(fc_minimal.layer_list) == 1 if __name__ == "__main__": - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/tests/e3nn/nn/test_gate.py b/tests/e3nn/nn/test_gate.py index 15046e427..0203b38a2 100644 --- a/tests/e3nn/nn/test_gate.py +++ b/tests/e3nn/nn/test_gate.py @@ -18,7 +18,7 @@ class TestGate: gate = Gate('1x0e', [ops.tanh], '2x0e', [ops.sigmoid, ops.abs], '2x1o') x = Tensor(np.random.randn(3, gate.irreps_in.dim), dtype=float32) output = gate(x) - + assert output.shape == (3, gate.irreps_out.dim) assert not np.isnan(output.asnumpy()).any() @@ -26,10 +26,10 @@ class TestGate: """测试不同激活函数""" gate1 = Gate('1x0e', [ops.tanh], '1x0e', [ops.sigmoid], '1x1o') gate2 = Gate('1x0e', [ops.relu], '1x0e', [ops.abs], '1x1o') - + x = Tensor(np.random.randn(2, gate1.irreps_in.dim), dtype=float32) output1, output2 = gate1(x), gate2(x) - + assert output1.shape == output2.shape assert not np.allclose(output1.asnumpy(), output2.asnumpy(), atol=1e-6) @@ -37,9 +37,9 @@ class TestGate: """测试错误处理""" with pytest.raises(ValueError, match="Scalars must be scalars"): Gate('1x1o', [ops.tanh], '1x0e', [ops.sigmoid], '1x1o') - + with pytest.raises(ValueError, match="Gate scalars must be scalars"): Gate('1x0e', [ops.tanh], '1x1o', [ops.sigmoid], '1x1o') - + with pytest.raises(ValueError, match="different number"): - Gate('1x0e', [ops.tanh], '2x0e', [ops.sigmoid, 
ops.abs], '1x1o') \ No newline at end of file + Gate('1x0e', [ops.tanh], '2x0e', [ops.sigmoid, ops.abs], '1x1o') diff --git a/tests/e3nn/nn/test_normact.py b/tests/e3nn/nn/test_normact.py index 7cbbc103e..34ccaee70 100644 --- a/tests/e3nn/nn/test_normact.py +++ b/tests/e3nn/nn/test_normact.py @@ -11,7 +11,7 @@ class TestNormActivation: assert normact.irreps_out.dim == normact.irreps_in.dim assert normact.normalize is True assert normact.epsilon == 1e-8 - + x = Tensor(np.random.randn(3, normact.irreps_in.dim), dtype=float32) output = normact(x) assert output.shape == x.shape @@ -21,7 +21,7 @@ class TestNormActivation: normact_norm = NormActivation('1x1o', ops.sigmoid, normalize=True) normact_no_norm = NormActivation('1x1o', ops.sigmoid, normalize=False) normact_eps = NormActivation('1x1o', ops.sigmoid, epsilon=1e-6) - + assert normact_norm.normalize and normact_norm.epsilon == 1e-8 assert not normact_no_norm.normalize and normact_no_norm.epsilon is None assert normact_eps.epsilon == 1e-6 and normact_eps._eps_squared == 1e-12 @@ -29,10 +29,10 @@ class TestNormActivation: def test_activations_and_bias(self): normact1 = NormActivation('1x1o', ops.sigmoid, bias=True) normact2 = NormActivation('1x1o', ops.tanh, bias=False) - + x = Tensor(np.random.randn(2, 3), dtype=float32) output1, output2 = normact1(x), normact2(x) - + assert output1.shape == output2.shape assert normact1.bias is not None and normact2.bias is None @@ -40,4 +40,4 @@ class TestNormActivation: with pytest.raises(ValueError, match="epsilon.*normalize = False.*don't make sense"): NormActivation('1x1o', ops.sigmoid, normalize=False, epsilon=1e-6) with pytest.raises(ValueError, match="epsilon.*invalid.*strictly positive"): - NormActivation('1x1o', ops.sigmoid, epsilon=-1e-6) \ No newline at end of file + NormActivation('1x1o', ops.sigmoid, epsilon=-1e-6) diff --git a/tests/e3nn/nn/test_one_hot.py b/tests/e3nn/nn/test_one_hot.py index b7ccbd2ae..81dee981e 100644 --- a/tests/e3nn/nn/test_one_hot.py +++ b/tests/e3nn/nn/test_one_hot.py @@ -8,7 +8,7 @@ from mindscience.e3nn.nn.one_hot import OneHot, SoftOneHotLinspace, soft_one_hot class TestSoftUnitStep: """Test soft_unit_step function""" - + def test_soft_unit_step_basic(self): """Test soft_unit_step with basic functionality""" # Test positive values @@ -16,13 +16,13 @@ class TestSoftUnitStep: result_pos = soft_unit_step(x_pos) expected_pos = ops.exp(-1.0 / x_pos) assert np.allclose(result_pos.asnumpy(), expected_pos.asnumpy(), atol=1e-6) - + # Test negative values (should be zero due to relu) x_neg = Tensor([-1.0, -2.0], dtype=float32) result_neg = soft_unit_step(x_neg) expected_neg = Tensor([0.0, 0.0], dtype=float32) assert np.allclose(result_neg.asnumpy(), expected_neg.asnumpy(), atol=1e-6) - + # Test zero (may be NaN or 0 due to division by zero) x_zero = Tensor([0.0], dtype=float32) result_zero = soft_unit_step(x_zero) @@ -32,23 +32,23 @@ class TestSoftUnitStep: class TestOneHot: """Test OneHot class""" - + def test_onehot_basic(self): """Test OneHot basic functionality""" num_types = 4 onehot = OneHot(num_types) - + # Test creation assert onehot.num_types == num_types assert str(onehot.irreps_output) == "4x0e" - + # Test single input atom_type = Tensor([2], dtype=int32) result = onehot(atom_type) expected = Tensor([[0., 0., 1., 0.]], dtype=float32) assert np.allclose(result.asnumpy(), expected.asnumpy()) assert result.shape == (1, 4) - + # Test batch input atom_types = Tensor([0, 1, 2], dtype=int32) result_batch = onehot(atom_types) @@ -63,65 +63,65 @@ class TestOneHot: class 
TestSoftOneHotLinspace: """Test SoftOneHotLinspace class""" - + def test_soft_onehot_basic(self): """Test SoftOneHotLinspace basic functionality""" start, end, number = 0.0, 2.0, 4 soft_onehot = SoftOneHotLinspace(start, end, number) - + # Test creation assert soft_onehot.start.asnumpy() == start assert soft_onehot.end.asnumpy() == end assert soft_onehot.number == number - + # Test forward pass x = Tensor([1.0], dtype=float32) result = soft_onehot(x) assert result.shape == (1, 4) - + # Test batch input x_batch = Tensor([[0.5, 1.0], [1.5, 2.0]], dtype=float32) result_batch = soft_onehot(x_batch) assert result_batch.shape == (2, 2, 4) - + def test_soft_onehot_different_basis(self): """Test SoftOneHotLinspace with different basis functions""" start, end, number = 0.0, 2.0, 3 x = Tensor([1.0], dtype=float32) - + for basis in ['gaussian', 'cosine', 'smooth_finite']: soft_onehot = SoftOneHotLinspace(start, end, number, basis=basis) result = soft_onehot(x) assert result.shape == (1, 3) # Some basis functions may produce NaN at boundaries, which is expected - + def test_soft_onehot_cutoff(self): """Test SoftOneHotLinspace cutoff behavior""" start, end, number = 0.0, 2.0, 3 - + # Test with and without cutoff soft_onehot_cutoff = SoftOneHotLinspace(start, end, number, cutoff=True) soft_onehot_no_cutoff = SoftOneHotLinspace(start, end, number, cutoff=False) - + x = Tensor([3.0], dtype=float32) # Outside domain result_cutoff = soft_onehot_cutoff(x) result_no_cutoff = soft_onehot_no_cutoff(x) - + assert result_cutoff.shape == (1, 3) assert result_no_cutoff.shape == (1, 3) class TestSoftOneHotLinspaceFunction: """Test soft_one_hot_linspace function""" - + def test_function_basic(self): """Test soft_one_hot_linspace function interface""" x = Tensor([1.0, 1.5, 2.0], dtype=float32) start, end, number = 0.0, 3.0, 4 - + result = soft_one_hot_linspace(x, start, end, number) assert result.shape == (3, 4) - + # Test with different basis result_gaussian = soft_one_hot_linspace(x, start, end, number, basis='gaussian') assert result_gaussian.shape == (3, 4) @@ -129,7 +129,7 @@ class TestSoftOneHotLinspaceFunction: class TestEdgeCases: """Test edge cases and error handling""" - + def test_edge_cases(self): """Test various edge cases""" # OneHot with single type @@ -138,14 +138,14 @@ class TestEdgeCases: result = onehot(atom_type) assert result.shape == (1, 1) assert np.allclose(result.asnumpy(), Tensor([[1.0]], dtype=float32).asnumpy()) - + # SoftOneHotLinspace with small number soft_onehot = SoftOneHotLinspace(0.0, 1.0, 2) x = Tensor([0.5], dtype=float32) result = soft_onehot(x) assert result.shape == (1, 2) - + # Invalid basis should raise error soft_onehot_invalid = SoftOneHotLinspace(0.0, 1.0, 3, basis='invalid') with pytest.raises(ValueError, match="Unsupported basis"): - soft_onehot_invalid(x) \ No newline at end of file + soft_onehot_invalid(x) diff --git a/tests/e3nn/nn/test_scatter.py b/tests/e3nn/nn/test_scatter.py index c10628a3c..5c21243ce 100644 --- a/tests/e3nn/nn/test_scatter.py +++ b/tests/e3nn/nn/test_scatter.py @@ -24,41 +24,41 @@ class TestScatter: def test_scatter_add(self): """Test scatter add operation""" scatter = Scatter(mode='add') - + src = Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=float32) index = Tensor([0, 1], dtype=int32) - + result = scatter(src, index, dim_size=2) expected = Tensor([[1.0, 2.0], [3.0, 4.0]], dtype=float32) - + assert np.allclose(result.asnumpy(), expected.asnumpy()) def test_scatter_max(self): """Test scatter max operation""" scatter = Scatter(mode='max') - + src = 
Tensor([[1.0, 5.0], [3.0, 2.0], [2.0, 4.0]], dtype=float32) index = Tensor([0, 1, 0], dtype=int32) - + result = scatter(src, index, dim_size=2) expected = Tensor([[2.0, 5.0], [3.0, 2.0]], dtype=float32) - + assert np.allclose(result.asnumpy(), expected.asnumpy()) def test_scatter_with_out_parameter(self): """Test scatter with out parameter for proper initialization""" scatter = Scatter(mode='mul') - + src = Tensor([[2.0, 3.0], [4.0, 5.0]], dtype=float32) index = Tensor([0, 1], dtype=int32) out = Tensor([[1.0, 1.0], [1.0, 1.0]], dtype=float32) - + result = scatter(src, index, out=out) expected = Tensor([[2.0, 3.0], [4.0, 5.0]], dtype=float32) - + assert np.allclose(result.asnumpy(), expected.asnumpy()) def test_scatter_invalid_mode(self): """Test scatter with invalid mode""" with pytest.raises(ValueError, match="Unexpected scatter mode"): - Scatter(mode='invalid') \ No newline at end of file + Scatter(mode='invalid') -- Gitee From a5d887a9c3ec2fd5322cc70601c2e63cad978cec Mon Sep 17 00:00:00 2001 From: birfied Date: Fri, 19 Sep 2025 14:15:23 +0800 Subject: [PATCH 10/21] pylint errors --- tests/e3nn/nn/test_fc.py | 4 ++-- tests/e3nn/nn/test_gate.py | 10 ++++++---- tests/e3nn/nn/test_normact.py | 9 ++++++++- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/e3nn/nn/test_fc.py b/tests/e3nn/nn/test_fc.py index bf16a9c0c..104832691 100644 --- a/tests/e3nn/nn/test_fc.py +++ b/tests/e3nn/nn/test_fc.py @@ -15,8 +15,8 @@ """Test cases for FullyConnectedNet""" import pytest import numpy as np -from mindspore import Tensor, float32, ops -from mindscience.e3nn.nn.fc import FullyConnectedNet, _Layer +from mindspore import Tensor, ops +from mindscience.e3nn.nn.fc import FullyConnectedNet class TestFullyConnectedNet: diff --git a/tests/e3nn/nn/test_gate.py b/tests/e3nn/nn/test_gate.py index 0203b38a2..d8e6eea41 100644 --- a/tests/e3nn/nn/test_gate.py +++ b/tests/e3nn/nn/test_gate.py @@ -6,15 +6,17 @@ from mindscience.e3nn.nn import Gate class TestGate: + """Test cases for Gate module""" + def test_gate_creation(self): - """测试Gate创建和基本属性""" + """Test Gate creation and basic properties""" gate = Gate('2x0e', [ops.tanh], '1x0e', [ops.sigmoid], '1x1o') assert isinstance(gate, Gate) assert gate.irreps_in.dim > 0 assert gate.irreps_out.dim > 0 def test_gate_forward(self): - """测试前向传播""" + """Test forward propagation""" gate = Gate('1x0e', [ops.tanh], '2x0e', [ops.sigmoid, ops.abs], '2x1o') x = Tensor(np.random.randn(3, gate.irreps_in.dim), dtype=float32) output = gate(x) @@ -23,7 +25,7 @@ class TestGate: assert not np.isnan(output.asnumpy()).any() def test_gate_activations(self): - """测试不同激活函数""" + """Test different activation functions""" gate1 = Gate('1x0e', [ops.tanh], '1x0e', [ops.sigmoid], '1x1o') gate2 = Gate('1x0e', [ops.relu], '1x0e', [ops.abs], '1x1o') @@ -34,7 +36,7 @@ class TestGate: assert not np.allclose(output1.asnumpy(), output2.asnumpy(), atol=1e-6) def test_gate_errors(self): - """测试错误处理""" + """Test error handling""" with pytest.raises(ValueError, match="Scalars must be scalars"): Gate('1x1o', [ops.tanh], '1x0e', [ops.sigmoid], '1x1o') diff --git a/tests/e3nn/nn/test_normact.py b/tests/e3nn/nn/test_normact.py index 34ccaee70..9f8a5fe0b 100644 --- a/tests/e3nn/nn/test_normact.py +++ b/tests/e3nn/nn/test_normact.py @@ -1,3 +1,4 @@ +"""Test cases for NormActivation module""" import pytest from mindspore import Tensor, ops, float32 import numpy as np @@ -5,7 +6,10 @@ from mindscience.e3nn.nn import NormActivation class TestNormActivation: + """Test cases for NormActivation 
class""" + def test_creation_and_forward(self): + """Test NormActivation creation and forward pass""" normact = NormActivation('2x1e', ops.sigmoid) assert normact.irreps_in.dim > 0 assert normact.irreps_out.dim == normact.irreps_in.dim @@ -18,15 +22,17 @@ class TestNormActivation: assert not np.isnan(output.asnumpy()).any() def test_normalize_and_epsilon(self): + """Test normalize parameter and epsilon configuration""" normact_norm = NormActivation('1x1o', ops.sigmoid, normalize=True) normact_no_norm = NormActivation('1x1o', ops.sigmoid, normalize=False) normact_eps = NormActivation('1x1o', ops.sigmoid, epsilon=1e-6) assert normact_norm.normalize and normact_norm.epsilon == 1e-8 assert not normact_no_norm.normalize and normact_no_norm.epsilon is None - assert normact_eps.epsilon == 1e-6 and normact_eps._eps_squared == 1e-12 + assert normact_eps.epsilon == 1e-6 and normact_eps.epsilon * normact_eps.epsilon == 1e-12 def test_activations_and_bias(self): + """Test different activation functions and bias parameter""" normact1 = NormActivation('1x1o', ops.sigmoid, bias=True) normact2 = NormActivation('1x1o', ops.tanh, bias=False) @@ -37,6 +43,7 @@ class TestNormActivation: assert normact1.bias is not None and normact2.bias is None def test_errors(self): + """Test error handling for invalid parameter combinations""" with pytest.raises(ValueError, match="epsilon.*normalize = False.*don't make sense"): NormActivation('1x1o', ops.sigmoid, normalize=False, epsilon=1e-6) with pytest.raises(ValueError, match="epsilon.*invalid.*strictly positive"): -- Gitee From 5b31f04056bbd5b60998314b290aa53735ee6c65 Mon Sep 17 00:00:00 2001 From: birfied Date: Fri, 19 Sep 2025 14:21:34 +0800 Subject: [PATCH 11/21] pylint errors --- tests/e3nn/nn/test_gate.py | 2 +- tests/e3nn/nn/test_normact.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/e3nn/nn/test_gate.py b/tests/e3nn/nn/test_gate.py index d8e6eea41..3cbc06447 100644 --- a/tests/e3nn/nn/test_gate.py +++ b/tests/e3nn/nn/test_gate.py @@ -7,7 +7,7 @@ from mindscience.e3nn.nn import Gate class TestGate: """Test cases for Gate module""" - + def test_gate_creation(self): """Test Gate creation and basic properties""" gate = Gate('2x0e', [ops.tanh], '1x0e', [ops.sigmoid], '1x1o') diff --git a/tests/e3nn/nn/test_normact.py b/tests/e3nn/nn/test_normact.py index 9f8a5fe0b..24a34c3f9 100644 --- a/tests/e3nn/nn/test_normact.py +++ b/tests/e3nn/nn/test_normact.py @@ -7,7 +7,7 @@ from mindscience.e3nn.nn import NormActivation class TestNormActivation: """Test cases for NormActivation class""" - + def test_creation_and_forward(self): """Test NormActivation creation and forward pass""" normact = NormActivation('2x1e', ops.sigmoid) -- Gitee From 878d008e4a6cb5baf87046f6292af0d227b2e3ec Mon Sep 17 00:00:00 2001 From: birfied Date: Fri, 19 Sep 2025 14:56:32 +0800 Subject: [PATCH 12/21] remove __init__.py
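Drop the test-package __init__.py files and add standalone pytest cases for the o3 irreps module. For orientation, the API the new tests exercise looks roughly like the sketch below; it is illustrative only, assembled from the assertions in this patch, and is not part of the patched files.

```python
# Illustrative only: mirrors assertions from tests/e3nn/o3/test_irreps.py.
from mindscience.e3nn.o3 import Irrep, Irreps

vec = Irrep("1o")                    # l=1, odd parity
assert vec.dim == 3                  # dim == 2*l + 1

feats = Irreps("2x0e+3x1o+1x2e")     # direct sum with multiplicities
assert feats.dim == 2 * 1 + 3 * 3 + 1 * 5   # == 16
assert feats.lmax == 2
assert Irrep(0, 1) in feats          # 0e is contained in the direct sum
```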
--- tests/e3nn/__init__.py | 14 -- tests/e3nn/nn/__init__.py | 1 - tests/e3nn/o3/test_irreps.py | 430 +++++++++++++++++++++++++++++++++++ 3 files changed, 430 insertions(+), 15 deletions(-) delete mode 100644 tests/e3nn/__init__.py delete mode 100644 tests/e3nn/nn/__init__.py create mode 100644 tests/e3nn/o3/test_irreps.py diff --git a/tests/e3nn/__init__.py b/tests/e3nn/__init__.py deleted file mode 100644 index 83b15297d..000000000 --- a/tests/e3nn/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2025 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ diff --git a/tests/e3nn/nn/__init__.py b/tests/e3nn/nn/__init__.py deleted file mode 100644 index 2b2810bbc..000000000 --- a/tests/e3nn/nn/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# E3NN neural network module tests diff --git a/tests/e3nn/o3/test_irreps.py b/tests/e3nn/o3/test_irreps.py new file mode 100644 index 000000000..6e7c9e0eb --- /dev/null +++ b/tests/e3nn/o3/test_irreps.py @@ -0,0 +1,430 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test cases for irreps module""" + +import pytest +import numpy as np +from mindspore import Tensor, ops, float32 +from mindscience.e3nn.o3 import Irrep, Irreps + + +class TestIrrep: + """Test cases for Irrep class""" + + def test_irrep_basic_creation(self): + """Test basic Irrep creation with different parameters""" + # Test creation with l and p parameters + irrep1 = Irrep(0, 1) + assert irrep1.l == 0 + assert irrep1.p == 1 + assert str(irrep1) == "0e" + + irrep2 = Irrep(1, -1) + assert irrep2.l == 1 + assert irrep2.p == -1 + assert str(irrep2) == "1o" + + # Test creation with string + irrep3 = Irrep("2e") + assert irrep3.l == 2 + assert irrep3.p == 1 + + irrep4 = Irrep("3o") + assert irrep4.l == 3 + assert irrep4.p == -1 + + # Test creation with 'y' notation + irrep5 = Irrep("2y") + assert irrep5.l == 2 + assert irrep5.p == 1 # (-1)^2 = 1 + + irrep6 = Irrep("3y") + assert irrep6.l == 3 + assert irrep6.p == -1 # (-1)^3 = -1 + + def test_irrep_properties(self): + """Test Irrep properties like dim and is_scalar""" + # Test dimension property + assert Irrep(0, 1).dim == 1 + assert Irrep(1, 1).dim == 3 + assert Irrep(2, 1).dim == 5 + assert Irrep(3, -1).dim == 7 + + # Test is_scalar method + assert Irrep(0, 1).is_scalar() is True + assert Irrep(0, -1).is_scalar() is False + assert Irrep(1, 1).is_scalar() is False + assert Irrep(2, 1).is_scalar() is False + + def test_irrep_multiplication(self): + """Test Irrep multiplication (tensor product)""" + irrep1 = Irrep(1, 1) + irrep2 = Irrep(1, 1) + + # Test multiplication result + products = list(irrep1 * irrep2) + expected = [Irrep(0, 1), Irrep(1, 1), Irrep(2, 1)] + assert products == expected + + # Test with different parities + irrep3 = Irrep(1, -1) + products2 = list(irrep1 * irrep3) + expected2 = [Irrep(0, -1), Irrep(1, -1), Irrep(2, -1)] + assert products2 == expected2 + + def test_irrep_arithmetic_operations(self): + """Test Irrep arithmetic operations""" +
irrep1 = Irrep(1, 1) + irrep2 = Irrep(2, -1) + + # Test right multiplication with integer + result = 3 * irrep1 + assert isinstance(result, Irreps) + assert len(result) == 1 + assert result.data[0].mul == 3 + assert result.data[0].ir == irrep1 + + # Test addition + result_add = irrep1 + irrep2 + assert isinstance(result_add, Irreps) + assert len(result_add) == 2 + + def test_irrep_comparison_operations(self): + """Test Irrep comparison operations""" + irrep1 = Irrep(1, 1) + irrep2 = Irrep(1, 1) + irrep3 = Irrep(2, 1) + irrep4 = Irrep(1, -1) + + # Test equality + assert irrep1 == irrep2 + assert irrep1 != irrep3 + assert irrep1 != irrep4 + + # Test ordering + assert irrep1 < irrep3 # lower l comes first + assert irrep4 < irrep1 # same l, negative p comes first + + def test_irrep_iteration(self): + """Test Irrep iteration (deconstruction)""" + irrep = Irrep(2, -1) + l, p = irrep + assert l == 2 + assert p == -1 + + def test_irrep_error_handling(self): + """Test Irrep error handling for invalid inputs""" + # Test negative l + with pytest.raises(ValueError): + Irrep(-1, 1) + + # Test invalid parity + with pytest.raises(ValueError): + Irrep(1, 0) + + with pytest.raises(ValueError): + Irrep(1, 2) + + # Test invalid string format + with pytest.raises(ValueError): + Irrep("invalid") + + # Test invalid type for l + with pytest.raises(TypeError): + Irrep(1.5, 1) + + def test_irrep_wigner_d_matrix(self): + """Test Irrep Wigner D matrix computation""" + irrep = Irrep(1, -1) + + # Test with identity matrix + R = ops.eye(3) + d_matrix = irrep.wigD_from_matrix(R) + assert d_matrix.shape == (3, 3) + + # Test with negative identity (inversion) + R_neg = -ops.eye(3) + d_matrix_neg = irrep.wigD_from_matrix(R_neg) + assert d_matrix_neg.shape == (3, 3) + + # Test error handling for non-tensor input + with pytest.raises(TypeError): + irrep.wigD_from_matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + +class TestIrreps: + """Test cases for Irreps class""" + + def test_irreps_basic_creation(self): + """Test basic Irreps creation with different input formats""" + # Test creation from string + irreps1 = Irreps("1x0e+2x1o") + assert len(irreps1) == 2 + assert irreps1.data[0].mul == 1 + assert irreps1.data[0].ir == Irrep(0, 1) + assert irreps1.data[1].mul == 2 + assert irreps1.data[1].ir == Irrep(1, -1) + + # Test creation from list of tuples + irreps2 = Irreps([(1, (0, 1)), (2, (1, -1))]) + assert irreps1 == irreps2 + + # Test creation from single Irrep + irreps3 = Irreps(Irrep(1, 1)) + assert len(irreps3) == 1 + assert irreps3.data[0].mul == 1 + assert irreps3.data[0].ir == Irrep(1, 1) + + # Test empty creation + irreps4 = Irreps() + assert len(irreps4) == 0 + assert irreps4.dim == 0 + + irreps5 = Irreps("") + assert len(irreps5) == 0 + + def test_irreps_properties(self): + """Test Irreps properties""" + irreps = Irreps("2x0e+3x1o+1x2e") + + # Test dimension + expected_dim = 2 * 1 + 3 * 3 + 1 * 5 # 2 + 9 + 5 = 16 + assert irreps.dim == expected_dim + + # Test slices + assert len(irreps.slice) == 3 + assert irreps.slice[0] == slice(0, 2) + assert irreps.slice[1] == slice(2, 11) + assert irreps.slice[2] == slice(11, 16) + + # Test slice_tuples + assert irreps.slice_tuples == [(0, 2), (2, 9), (11, 5)] + + def test_irreps_string_operations(self): + """Test Irreps string representation and parsing""" + # Test string representation + irreps = Irreps("2x0e+3x1o") + assert str(irreps) == "2x0e+3x1o" + + # Test complex string parsing + irreps_complex = Irreps("100x0e+50x1e+0x2e") + assert len(irreps_complex) == 3 + assert 
irreps_complex.data[2].mul == 0 + + # Test single irrep without multiplicity + irreps_single = Irreps("1o") + assert len(irreps_single) == 1 + assert irreps_single.data[0].mul == 1 + assert irreps_single.data[0].ir == Irrep(1, -1) + + def test_irreps_arithmetic_operations(self): + """Test Irreps arithmetic operations""" + irreps1 = Irreps("1x0e+1x1o") + irreps2 = Irreps("2x0e+1x2e") + + # Test addition + result_add = irreps1 + irreps2 + assert len(result_add) == 4 + expected = Irreps("1x0e+1x1o+2x0e+1x2e") + assert result_add == expected + + # Test multiplication with integer + result_mul = irreps1 * 2 + expected_mul = Irreps("2x0e+2x1o") + assert result_mul == expected_mul + + # Test tensor product with another Irreps + irreps_small1 = Irreps("1x0e") + irreps_small2 = Irreps("1x1o") + result_tensor = irreps_small1 * irreps_small2 + expected_tensor = Irreps("1x1o") + assert result_tensor == expected_tensor + + def test_irreps_contains_operation(self): + """Test Irreps contains operation""" + irreps = Irreps("1x0e+2x1o+1x2e") + + # Test single Irrep containment + assert Irrep(0, 1) in irreps + assert Irrep(1, -1) in irreps + assert Irrep(2, 1) in irreps + assert Irrep(3, 1) not in irreps + + # Test Irreps containment + subset = Irreps("1x0e+1x1o") + assert subset in irreps + + subset_not = Irreps("3x1o") # More than available + assert subset_not not in irreps + + def test_irreps_comparison_operations(self): + """Test Irreps comparison operations""" + irreps1 = Irreps("1x0e+2x1o") + irreps2 = Irreps("1x0e+2x1o") + irreps3 = Irreps("2x0e+1x1o") + + # Test equality + assert irreps1 == irreps2 + assert irreps1 != irreps3 + + # Test hash (for use in sets/dicts) + irreps_set = {irreps1, irreps2, irreps3} + assert len(irreps_set) == 2 # irreps1 and irreps2 are the same + + def test_irreps_iteration(self): + """Test Irreps iteration""" + irreps = Irreps("2x0e+1x1o") + + # Test iteration over _MulIr objects + mul_irs = list(irreps) + assert len(mul_irs) == 2 + assert mul_irs[0].mul == 2 + assert mul_irs[0].ir == Irrep(0, 1) + assert mul_irs[1].mul == 1 + assert mul_irs[1].ir == Irrep(1, -1) + + def test_irreps_advanced_operations(self): + """Test advanced Irreps operations like simplify and sort""" + # Test with repeated irreps that should be simplified + irreps_unsorted = Irreps("1x1o+2x0e+1x1o+3x0e") + + # The simplify and sort operations are internal, + # but we can test the final result + assert len(irreps_unsorted) == 4 # Before simplification + + def test_irreps_error_handling(self): + """Test Irreps error handling for invalid inputs""" + # Test invalid string format + with pytest.raises(ValueError): + Irreps("invalid_format") + + # Test negative multiplicity should raise ValueError + with pytest.raises(ValueError): + Irreps([(-1, (0, 1))]) + + # Test invalid multiplicity type should raise ValueError (not TypeError) + with pytest.raises(ValueError): + Irreps([(1.5, (0, 1))]) + + def test_irreps_complex_scenarios(self): + """Test complex scenarios and edge cases""" + # Test large irreps + large_irreps = Irreps("100x0e+50x1e+25x2e") + assert large_irreps.dim == 100 * 1 + 50 * 3 + 25 * 5 + assert len(large_irreps) == 3 + + # Test zero multiplicity + zero_irreps = Irreps("0x1o+2x0e") + assert len(zero_irreps) == 2 + assert zero_irreps.data[0].mul == 0 + + # Test mixed parity + mixed_irreps = Irreps("1x0e+1x0o+1x1e+1x1o") + assert len(mixed_irreps) == 4 + + def test_irreps_lmax_property(self): + """Test lmax property of Irreps""" + irreps1 = Irreps("1x0e+1x1o+1x2e") + assert irreps1.lmax == 2 + + 
irreps2 = Irreps("1x0e") + assert irreps2.lmax == 0 + + # Empty Irreps should raise ValueError when accessing lmax + irreps3 = Irreps("") + with pytest.raises(ValueError): + _ = irreps3.lmax + + def test_irreps_num_irreps_property(self): + """Test num_irreps property of Irreps""" + irreps1 = Irreps("2x0e+3x1o") + assert irreps1.num_irreps == 5 # 2 + 3 + + irreps2 = Irreps("1x2e") + assert irreps2.num_irreps == 1 + + irreps3 = Irreps("") + assert irreps3.num_irreps == 0 + + +class TestMulIr: + """Test cases for _MulIr class""" + + def test_mulir_basic_creation(self): + """Test basic _MulIr creation""" + from mindscience.e3nn.o3.irreps import _MulIr + + # Test creation with mul and Irrep + irrep = Irrep(1, 1) + mulir = _MulIr(3, irrep) + assert mulir.mul == 3 + assert mulir.ir == irrep + assert mulir.dim == 3 * 3 # mul * irrep.dim + + def test_mulir_properties(self): + """Test _MulIr properties""" + from mindscience.e3nn.o3.irreps import _MulIr + + irrep = Irrep(2, -1) + mulir = _MulIr(2, irrep) + + # Test dimension + assert mulir.dim == 2 * 5 # 2 * (2*2+1) + + # Test string representation + assert str(mulir) == "2x2o" + + def test_mulir_iteration(self): + """Test _MulIr iteration""" + from mindscience.e3nn.o3.irreps import _MulIr + + irrep = Irrep(1, -1) + mulir = _MulIr(4, irrep) + + # Test deconstruction + mul, ir = mulir + assert mul == 4 + assert ir == irrep + + def test_mulir_comparison_operations(self): + """Test _MulIr comparison operations""" + from mindscience.e3nn.o3.irreps import _MulIr + + irrep1 = Irrep(1, 1) + irrep2 = Irrep(2, 1) + mulir1 = _MulIr(2, irrep1) + mulir2 = _MulIr(2, irrep1) + mulir3 = _MulIr(3, irrep1) + mulir4 = _MulIr(2, irrep2) + + # Test equality + assert mulir1 == mulir2 + assert mulir1 != mulir3 + assert mulir1 != mulir4 + + # Test ordering (by irrep first, then by multiplicity) + assert mulir1 < mulir4 # irrep1 < irrep2 + assert mulir1 < mulir3 # same irrep, but mul1 < mul3 + + def test_mulir_error_handling(self): + """Test _MulIr error handling""" + from mindscience.e3nn.o3.irreps import _MulIr + + # Test invalid types + with pytest.raises(TypeError): + _MulIr(1.5, Irrep(1, 1)) # mul should be int + + with pytest.raises(TypeError): + _MulIr(2, "1e") # ir should be Irrep instance \ No newline at end of file -- Gitee From 4cf2439a11f768c9a84407c1f7d1f454c00bea04 Mon Sep 17 00:00:00 2001 From: birfied Date: Fri, 19 Sep 2025 15:02:03 +0800 Subject: [PATCH 13/21] pylint --- tests/e3nn/o3/test_irreps.py | 344 +++++++++-------------------------- 1 file changed, 81 insertions(+), 263 deletions(-) diff --git a/tests/e3nn/o3/test_irreps.py b/tests/e3nn/o3/test_irreps.py index 6e7c9e0eb..234ebbe09 100644 --- a/tests/e3nn/o3/test_irreps.py +++ b/tests/e3nn/o3/test_irreps.py @@ -23,57 +23,44 @@ from mindscience.e3nn.o3 import Irrep, Irreps class TestIrrep: """Test cases for Irrep class""" - def test_irrep_basic_creation(self): - """Test basic Irrep creation with different parameters""" + def test_irrep_creation_and_properties(self): + """Test Irrep creation, properties and basic operations""" # Test creation with l and p parameters irrep1 = Irrep(0, 1) assert irrep1.l == 0 assert irrep1.p == 1 assert str(irrep1) == "0e" + assert irrep1.dim == 1 + assert irrep1.is_scalar() is True irrep2 = Irrep(1, -1) assert irrep2.l == 1 assert irrep2.p == -1 assert str(irrep2) == "1o" + assert irrep2.dim == 3 + assert irrep2.is_scalar() is False - # Test creation with string + # Test creation with string notation irrep3 = Irrep("2e") assert irrep3.l == 2 assert irrep3.p == 1 + 
assert irrep3.dim == 5 - irrep4 = Irrep("3o") + irrep4 = Irrep("3y") assert irrep4.l == 3 - assert irrep4.p == -1 - - # Test creation with 'y' notation - irrep5 = Irrep("2y") - assert irrep5.l == 2 - assert irrep5.p == 1 # (-1)^2 = 1 - - irrep6 = Irrep("3y") - assert irrep6.l == 3 - assert irrep6.p == -1 # (-1)^3 = -1 - - def test_irrep_properties(self): - """Test Irrep properties like dim and is_scalar""" - # Test dimension property - assert Irrep(0, 1).dim == 1 - assert Irrep(1, 1).dim == 3 - assert Irrep(2, 1).dim == 5 - assert Irrep(3, -1).dim == 7 - - # Test is_scalar method - assert Irrep(0, 1).is_scalar() is True - assert Irrep(0, -1).is_scalar() is False - assert Irrep(1, 1).is_scalar() is False - assert Irrep(2, 1).is_scalar() is False - - def test_irrep_multiplication(self): - """Test Irrep multiplication (tensor product)""" + assert irrep4.p == -1 # (-1)^3 = -1 + + # Test comparison operations + assert irrep1 == Irrep(0, 1) + assert irrep1 != irrep2 + assert irrep1 < irrep2 # Compare by l first, then p + + def test_irrep_multiplication_and_arithmetic(self): + """Test Irrep multiplication and arithmetic operations""" irrep1 = Irrep(1, 1) irrep2 = Irrep(1, 1) - # Test multiplication result + # Test tensor product products = list(irrep1 * irrep2) expected = [Irrep(0, 1), Irrep(1, 1), Irrep(2, 1)] assert products == expected @@ -84,82 +71,35 @@ class TestIrrep: expected2 = [Irrep(0, -1), Irrep(1, -1), Irrep(2, -1)] assert products2 == expected2 - def test_irrep_arithmetic_operations(self): - """Test Irrep arithmetic operations""" - irrep1 = Irrep(1, 1) - irrep2 = Irrep(2, -1) - - # Test right multiplication with integer + # Test arithmetic operations result = 3 * irrep1 assert isinstance(result, Irreps) - assert len(result) == 1 assert result.data[0].mul == 3 assert result.data[0].ir == irrep1 - # Test addition - result_add = irrep1 + irrep2 + result_add = irrep1 + irrep3 assert isinstance(result_add, Irreps) assert len(result_add) == 2 - def test_irrep_comparison_operations(self): - """Test Irrep comparison operations""" - irrep1 = Irrep(1, 1) - irrep2 = Irrep(1, 1) - irrep3 = Irrep(2, 1) - irrep4 = Irrep(1, -1) - - # Test equality - assert irrep1 == irrep2 - assert irrep1 != irrep3 - assert irrep1 != irrep4 - - # Test ordering - assert irrep1 < irrep3 # lower l comes first - assert irrep4 < irrep1 # same l, negative p comes first - - def test_irrep_iteration(self): - """Test Irrep iteration (deconstruction)""" - irrep = Irrep(2, -1) - l, p = irrep - assert l == 2 - assert p == -1 - - def test_irrep_error_handling(self): - """Test Irrep error handling for invalid inputs""" - # Test negative l + def test_irrep_error_handling_and_wigner(self): + """Test Irrep error handling and Wigner D matrix""" + # Test error handling with pytest.raises(ValueError): - Irrep(-1, 1) + Irrep(-1, 1) # Negative l - # Test invalid parity with pytest.raises(ValueError): - Irrep(1, 0) + Irrep(1, 2) # Invalid parity - with pytest.raises(ValueError): - Irrep(1, 2) - - # Test invalid string format with pytest.raises(ValueError): Irrep("invalid") - # Test invalid type for l - with pytest.raises(TypeError): - Irrep(1.5, 1) - - def test_irrep_wigner_d_matrix(self): - """Test Irrep Wigner D matrix computation""" + # Test Wigner D matrix irrep = Irrep(1, -1) - - # Test with identity matrix R = ops.eye(3) d_matrix = irrep.wigD_from_matrix(R) assert d_matrix.shape == (3, 3) - # Test with negative identity (inversion) - R_neg = -ops.eye(3) - d_matrix_neg = irrep.wigD_from_matrix(R_neg) - assert d_matrix_neg.shape == 
(3, 3) - - # Test error handling for non-tensor input + # Test error for non-tensor input with pytest.raises(TypeError): irrep.wigD_from_matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) @@ -167,15 +107,14 @@ class TestIrrep: class TestIrreps: """Test cases for Irreps class""" - def test_irreps_basic_creation(self): - """Test basic Irreps creation with different input formats""" + def test_irreps_creation_and_basic_operations(self): + """Test Irreps creation and basic operations""" # Test creation from string irreps1 = Irreps("1x0e+2x1o") assert len(irreps1) == 2 assert irreps1.data[0].mul == 1 assert irreps1.data[0].ir == Irrep(0, 1) - assert irreps1.data[1].mul == 2 - assert irreps1.data[1].ir == Irrep(1, -1) + assert str(irreps1) == "1x0e+2x1o" # Test creation from list of tuples irreps2 = Irreps([(1, (0, 1)), (2, (1, -1))]) @@ -184,19 +123,19 @@ class TestIrreps: # Test creation from single Irrep irreps3 = Irreps(Irrep(1, 1)) assert len(irreps3) == 1 - assert irreps3.data[0].mul == 1 - assert irreps3.data[0].ir == Irrep(1, 1) # Test empty creation irreps4 = Irreps() assert len(irreps4) == 0 assert irreps4.dim == 0 - irreps5 = Irreps("") - assert len(irreps5) == 0 + # Test single irrep without multiplicity + irreps_single = Irreps("1o") + assert irreps_single.data[0].mul == 1 + assert irreps_single.data[0].ir == Irrep(1, -1) - def test_irreps_properties(self): - """Test Irreps properties""" + def test_irreps_properties_and_slicing(self): + """Test Irreps properties and slicing operations""" irreps = Irreps("2x0e+3x1o+1x2e") # Test dimension @@ -209,222 +148,101 @@ class TestIrreps: assert irreps.slice[1] == slice(2, 11) assert irreps.slice[2] == slice(11, 16) - # Test slice_tuples - assert irreps.slice_tuples == [(0, 2), (2, 9), (11, 5)] - - def test_irreps_string_operations(self): - """Test Irreps string representation and parsing""" - # Test string representation - irreps = Irreps("2x0e+3x1o") - assert str(irreps) == "2x0e+3x1o" + # Test lmax and num_irreps properties + assert irreps.lmax == 2 + assert irreps.num_irreps == 6 # 2 + 3 + 1 - # Test complex string parsing - irreps_complex = Irreps("100x0e+50x1e+0x2e") - assert len(irreps_complex) == 3 - assert irreps_complex.data[2].mul == 0 - - # Test single irrep without multiplicity - irreps_single = Irreps("1o") - assert len(irreps_single) == 1 - assert irreps_single.data[0].mul == 1 - assert irreps_single.data[0].ir == Irrep(1, -1) + # Test contains operation + assert Irrep(0, 1) in irreps + assert Irrep(3, 1) not in irreps - def test_irreps_arithmetic_operations(self): - """Test Irreps arithmetic operations""" + def test_irreps_arithmetic_and_operations(self): + """Test Irreps arithmetic operations and advanced features""" irreps1 = Irreps("1x0e+1x1o") irreps2 = Irreps("2x0e+1x2e") # Test addition result_add = irreps1 + irreps2 assert len(result_add) == 4 - expected = Irreps("1x0e+1x1o+2x0e+1x2e") - assert result_add == expected # Test multiplication with integer result_mul = irreps1 * 2 expected_mul = Irreps("2x0e+2x1o") assert result_mul == expected_mul - # Test tensor product with another Irreps - irreps_small1 = Irreps("1x0e") - irreps_small2 = Irreps("1x1o") - result_tensor = irreps_small1 * irreps_small2 - expected_tensor = Irreps("1x1o") - assert result_tensor == expected_tensor + # Test comparison operations + assert irreps1 == Irreps("1x0e+1x1o") + assert irreps1 != irreps2 - def test_irreps_contains_operation(self): - """Test Irreps contains operation""" - irreps = Irreps("1x0e+2x1o+1x2e") - - # Test single Irrep containment - 
assert Irrep(0, 1) in irreps - assert Irrep(1, -1) in irreps - assert Irrep(2, 1) in irreps - assert Irrep(3, 1) not in irreps + # Test iteration + for i, (mul, ir) in enumerate(irreps1): + if i == 0: + assert mul == 1 and ir == Irrep(0, 1) + elif i == 1: + assert mul == 1 and ir == Irrep(1, -1) - # Test Irreps containment - subset = Irreps("1x0e+1x1o") - assert subset in irreps - - subset_not = Irreps("3x1o") # More than available - assert subset_not not in irreps - - def test_irreps_comparison_operations(self): - """Test Irreps comparison operations""" - irreps1 = Irreps("1x0e+2x1o") - irreps2 = Irreps("1x0e+2x1o") - irreps3 = Irreps("2x0e+1x1o") - - # Test equality - assert irreps1 == irreps2 - assert irreps1 != irreps3 - - # Test hash (for use in sets/dicts) - irreps_set = {irreps1, irreps2, irreps3} - assert len(irreps_set) == 2 # irreps1 and irreps2 are the same - - def test_irreps_iteration(self): - """Test Irreps iteration""" - irreps = Irreps("2x0e+1x1o") - - # Test iteration over _MulIr objects - mul_irs = list(irreps) - assert len(mul_irs) == 2 - assert mul_irs[0].mul == 2 - assert mul_irs[0].ir == Irrep(0, 1) - assert mul_irs[1].mul == 1 - assert mul_irs[1].ir == Irrep(1, -1) - - def test_irreps_advanced_operations(self): - """Test advanced Irreps operations like simplify and sort""" - # Test with repeated irreps that should be simplified - irreps_unsorted = Irreps("1x1o+2x0e+1x1o+3x0e") - - # The simplify and sort operations are internal, - # but we can test the final result - assert len(irreps_unsorted) == 4 # Before simplification - - def test_irreps_error_handling(self): - """Test Irreps error handling for invalid inputs""" + def test_irreps_error_handling_and_edge_cases(self): + """Test Irreps error handling and edge cases""" # Test invalid string format with pytest.raises(ValueError): Irreps("invalid_format") - # Test negative multiplicity should raise ValueError + # Test negative multiplicity with pytest.raises(ValueError): Irreps([(-1, (0, 1))]) - # Test invalid multiplicity type should raise ValueError (not TypeError) + # Test invalid multiplicity type with pytest.raises(ValueError): Irreps([(1.5, (0, 1))]) - def test_irreps_complex_scenarios(self): - """Test complex scenarios and edge cases""" - # Test large irreps - large_irreps = Irreps("100x0e+50x1e+25x2e") - assert large_irreps.dim == 100 * 1 + 50 * 3 + 25 * 5 - assert len(large_irreps) == 3 + # Test empty Irreps lmax property + irreps_empty = Irreps("") + with pytest.raises(ValueError): + _ = irreps_empty.lmax # Test zero multiplicity zero_irreps = Irreps("0x1o+2x0e") assert len(zero_irreps) == 2 assert zero_irreps.data[0].mul == 0 - # Test mixed parity - mixed_irreps = Irreps("1x0e+1x0o+1x1e+1x1o") - assert len(mixed_irreps) == 4 - - def test_irreps_lmax_property(self): - """Test lmax property of Irreps""" - irreps1 = Irreps("1x0e+1x1o+1x2e") - assert irreps1.lmax == 2 - - irreps2 = Irreps("1x0e") - assert irreps2.lmax == 0 - - # Empty Irreps should raise ValueError when accessing lmax - irreps3 = Irreps("") - with pytest.raises(ValueError): - _ = irreps3.lmax - - def test_irreps_num_irreps_property(self): - """Test num_irreps property of Irreps""" - irreps1 = Irreps("2x0e+3x1o") - assert irreps1.num_irreps == 5 # 2 + 3 - - irreps2 = Irreps("1x2e") - assert irreps2.num_irreps == 1 - - irreps3 = Irreps("") - assert irreps3.num_irreps == 0 + # Test large irreps + large_irreps = Irreps("100x0e+50x1e") + assert large_irreps.dim == 100 * 1 + 50 * 3 + assert len(large_irreps) == 2 class TestMulIr: """Test cases for 
_MulIr class""" - def test_mulir_basic_creation(self): - """Test basic _MulIr creation""" + def test_mulir_comprehensive(self): + """Test _MulIr creation, properties and operations""" from mindscience.e3nn.o3.irreps import _MulIr - # Test creation with mul and Irrep + # Test creation and properties irrep = Irrep(1, 1) mulir = _MulIr(3, irrep) assert mulir.mul == 3 assert mulir.ir == irrep assert mulir.dim == 3 * 3 # mul * irrep.dim + assert str(mulir) == "3x1e" - def test_mulir_properties(self): - """Test _MulIr properties""" - from mindscience.e3nn.o3.irreps import _MulIr - - irrep = Irrep(2, -1) - mulir = _MulIr(2, irrep) - - # Test dimension - assert mulir.dim == 2 * 5 # 2 * (2*2+1) - - # Test string representation - assert str(mulir) == "2x2o" - - def test_mulir_iteration(self): - """Test _MulIr iteration""" - from mindscience.e3nn.o3.irreps import _MulIr - - irrep = Irrep(1, -1) - mulir = _MulIr(4, irrep) - - # Test deconstruction + # Test iteration/deconstruction mul, ir = mulir - assert mul == 4 + assert mul == 3 assert ir == irrep - def test_mulir_comparison_operations(self): - """Test _MulIr comparison operations""" - from mindscience.e3nn.o3.irreps import _MulIr + # Test comparison operations + mulir2 = _MulIr(3, irrep) + mulir3 = _MulIr(2, irrep) + mulir4 = _MulIr(3, Irrep(2, 1)) - irrep1 = Irrep(1, 1) - irrep2 = Irrep(2, 1) - mulir1 = _MulIr(2, irrep1) - mulir2 = _MulIr(2, irrep1) - mulir3 = _MulIr(3, irrep1) - mulir4 = _MulIr(2, irrep2) - - # Test equality - assert mulir1 == mulir2 - assert mulir1 != mulir3 - assert mulir1 != mulir4 - - # Test ordering (by irrep first, then by multiplicity) - assert mulir1 < mulir4 # irrep1 < irrep2 - assert mulir1 < mulir3 # same irrep, but mul1 < mul3 - - def test_mulir_error_handling(self): - """Test _MulIr error handling""" - from mindscience.e3nn.o3.irreps import _MulIr + assert mulir == mulir2 + assert mulir != mulir3 + assert mulir < mulir4 # Compare by irrep first - # Test invalid types + # Test error handling with pytest.raises(TypeError): - _MulIr(1.5, Irrep(1, 1)) # mul should be int + _MulIr(1.5, irrep) # mul should be int with pytest.raises(TypeError): _MulIr(2, "1e") # ir should be Irrep instance \ No newline at end of file -- Gitee From d7cabaa7e01322bf3a123ad91ad8e02487443f9d Mon Sep 17 00:00:00 2001 From: birfied Date: Fri, 19 Sep 2025 15:05:35 +0800 Subject: [PATCH 14/21] pylint --- tests/e3nn/o3/test_irreps.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/e3nn/o3/test_irreps.py b/tests/e3nn/o3/test_irreps.py index 234ebbe09..6a81937f4 100644 --- a/tests/e3nn/o3/test_irreps.py +++ b/tests/e3nn/o3/test_irreps.py @@ -15,8 +15,7 @@ """Test cases for irreps module""" import pytest -import numpy as np -from mindspore import Tensor, ops, float32 +from mindspore import ops from mindscience.e3nn.o3 import Irrep, Irreps @@ -95,8 +94,8 @@ class TestIrrep: # Test Wigner D matrix irrep = Irrep(1, -1) - R = ops.eye(3) - d_matrix = irrep.wigD_from_matrix(R) + rotation_matrix = ops.eye(3) + d_matrix = irrep.wigD_from_matrix(rotation_matrix) assert d_matrix.shape == (3, 3) # Test error for non-tensor input @@ -126,7 +125,7 @@ class TestIrreps: # Test empty creation irreps4 = Irreps() - assert len(irreps4) == 0 + assert not irreps4 # Check if empty assert irreps4.dim == 0 # Test single irrep without multiplicity @@ -245,4 +244,4 @@ class TestMulIr: _MulIr(1.5, irrep) # mul should be int with pytest.raises(TypeError): - _MulIr(2, "1e") # ir should be Irrep instance \ No newline at end of file + _MulIr(2, 
"1e") # ir should be Irrep instance -- Gitee From a49974bca29bdfd6376da284e9f395fbe13c7fe5 Mon Sep 17 00:00:00 2001 From: birfied Date: Fri, 19 Sep 2025 15:08:04 +0800 Subject: [PATCH 15/21] pylint --- tests/e3nn/o3/test_irreps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e3nn/o3/test_irreps.py b/tests/e3nn/o3/test_irreps.py index 6a81937f4..f382f3541 100644 --- a/tests/e3nn/o3/test_irreps.py +++ b/tests/e3nn/o3/test_irreps.py @@ -244,4 +244,4 @@ class TestMulIr: _MulIr(1.5, irrep) # mul should be int with pytest.raises(TypeError): - _MulIr(2, "1e") # ir should be Irrep instance + _MulIr(2, "1e") # ir should be Irrep instance -- Gitee From 89fbdd47c5dd7d39decbb732b20425e26fd4c8e8 Mon Sep 17 00:00:00 2001 From: birfied Date: Fri, 19 Sep 2025 15:25:51 +0800 Subject: [PATCH 16/21] pytest --- tests/e3nn/o3/test_norm.py | 129 +++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 tests/e3nn/o3/test_norm.py diff --git a/tests/e3nn/o3/test_norm.py b/tests/e3nn/o3/test_norm.py new file mode 100644 index 000000000..787e4be54 --- /dev/null +++ b/tests/e3nn/o3/test_norm.py @@ -0,0 +1,129 @@ +"""Test cases for e3nn.o3.norm module - Streamlined core functionality""" +import pytest +import numpy as np +import mindspore as ms +from mindspore import Tensor, float32, float64 + +from mindscience.e3nn.o3 import Norm, Irreps + + +class TestNorm: + """Streamlined tests for Norm class""" + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_norm_creation_and_basic_properties(self): + """Test Norm creation with different irreps and basic properties""" + # Test basic creation with string irreps + norm1 = Norm('1x0e') + assert norm1.irreps_in == Irreps('1x0e') + assert norm1.irreps_out == Irreps('1x0e') + assert not norm1.squared + + # Test creation with Irreps object and squared parameter + irreps_in = Irreps('2x1o + 3x0e') + norm2 = Norm(irreps_in, squared=True) + assert norm2.irreps_in == irreps_in.simplify() + assert norm2.irreps_out == Irreps('2x0e + 3x0e').simplify() + assert norm2.squared + + # Test string representation + repr_str = repr(norm2) + assert 'Norm' in repr_str + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_norm_forward_pass_comprehensive(self): + """Test forward pass with various irrep types and configurations""" + # Test scalar irrep (0e) + norm_scalar = Norm('2x0e') + scalar_input = Tensor([1.0, -2.0], dtype=float32) + scalar_output = norm_scalar(scalar_input) + np.testing.assert_allclose(scalar_output.asnumpy(), [1.0, 2.0], rtol=1e-5) + + # Test vector irrep (1o) with batch processing + norm_vector = Norm('1x1o') + vector_batch = Tensor([[3.0, 4.0, 0.0], [0.0, 0.0, 0.0]], dtype=float32) + vector_output = norm_vector(vector_batch) + expected = np.array([[5.0], [0.0]]) + np.testing.assert_allclose(vector_output.asnumpy(), expected, rtol=1e-5) + + # Test mixed irreps + norm_mixed = Norm('1x0e + 1x1o') + mixed_input = Tensor([2.0, 3.0, 4.0, 0.0], dtype=float32) + mixed_output = norm_mixed(mixed_input) + expected_mixed = np.array([2.0, 5.0]) # scalar norm + vector norm + np.testing.assert_allclose(mixed_output.asnumpy(), expected_mixed, rtol=1e-5) + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_norm_squared_and_dtype_consistency(self): + """Test squared parameter and dtype consistency""" + # Test squared vs regular norm + norm_regular = Norm('1x1o', 
--- tests/e3nn/o3/test_norm.py | 129 +++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 tests/e3nn/o3/test_norm.py diff --git a/tests/e3nn/o3/test_norm.py b/tests/e3nn/o3/test_norm.py new file mode 100644 index 000000000..787e4be54 --- /dev/null +++ b/tests/e3nn/o3/test_norm.py @@ -0,0 +1,129 @@ +"""Test cases for e3nn.o3.norm module - Streamlined core functionality""" +import pytest +import numpy as np +import mindspore as ms +from mindspore import Tensor, float32, float64 + +from mindscience.e3nn.o3 import Norm, Irreps + + +class TestNorm: + """Streamlined tests for Norm class""" + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_norm_creation_and_basic_properties(self): + """Test Norm creation with different irreps and basic properties""" + # Test basic creation with string irreps + norm1 = Norm('1x0e') + assert norm1.irreps_in == Irreps('1x0e') + assert norm1.irreps_out == Irreps('1x0e') + assert not norm1.squared + + # Test creation with Irreps object and squared parameter + irreps_in = Irreps('2x1o + 3x0e') + norm2 = Norm(irreps_in, squared=True) + assert norm2.irreps_in == irreps_in.simplify() + assert norm2.irreps_out == Irreps('2x0e + 3x0e').simplify() + assert norm2.squared + + # Test string representation + repr_str = repr(norm2) + assert 'Norm' in repr_str + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_norm_forward_pass_comprehensive(self): + """Test forward pass with various irrep types and configurations""" + # Test scalar irrep (0e) + norm_scalar = Norm('2x0e') + scalar_input = Tensor([1.0, -2.0], dtype=float32) + scalar_output = norm_scalar(scalar_input) + np.testing.assert_allclose(scalar_output.asnumpy(), [1.0, 2.0], rtol=1e-5) + + # Test vector irrep (1o) with batch processing + norm_vector = Norm('1x1o') + vector_batch = Tensor([[3.0, 4.0, 0.0], [0.0, 0.0, 0.0]], dtype=float32) + vector_output = norm_vector(vector_batch) + expected = np.array([[5.0], [0.0]]) + np.testing.assert_allclose(vector_output.asnumpy(), expected, rtol=1e-5) + + # Test mixed irreps + norm_mixed = Norm('1x0e + 1x1o') + mixed_input = Tensor([2.0, 3.0, 4.0, 0.0], dtype=float32) + mixed_output = norm_mixed(mixed_input) + expected_mixed = np.array([2.0, 5.0]) # scalar norm + vector norm + np.testing.assert_allclose(mixed_output.asnumpy(), expected_mixed, rtol=1e-5) + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_norm_squared_and_dtype_consistency(self): + """Test squared parameter and dtype consistency""" + # Test squared vs regular norm + norm_regular = Norm('1x1o', squared=False, dtype=float64) + norm_squared = Norm('1x1o', squared=True, dtype=float64) + + input_vec = Tensor([3.0, 4.0, 0.0], dtype=float64) + output_regular = norm_regular(input_vec) + output_squared = norm_squared(input_vec) + + # Verify squared relationship and dtype consistency + np.testing.assert_allclose(output_regular.asnumpy(), [5.0], rtol=1e-5) + np.testing.assert_allclose(output_squared.asnumpy(), [25.0], rtol=1e-5) + assert output_regular.dtype == float64 + assert output_squared.dtype == float64 + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_norm_mathematical_properties_and_edge_cases(self): + """Test mathematical properties and edge cases""" + norm = Norm('1x1o') + + # Test scaling property: ||k*v|| = |k| * ||v|| + vector = Tensor([3.0, 4.0, 0.0], dtype=float32) + scaled_vector = Tensor([6.0, 8.0, 0.0], dtype=float32) + + norm_original = norm(vector) + norm_scaled = norm(scaled_vector) + np.testing.assert_allclose(norm_scaled.asnumpy(), 2.0 * norm_original.asnumpy(), rtol=1e-5) + + # Test with very small values + small_input = Tensor([1e-10, 1e-10, 1e-10], dtype=float32) + small_output = norm(small_input) + expected_small = np.sqrt(3) * 1e-10 + np.testing.assert_allclose(small_output.asnumpy(), [expected_small], rtol=1e-5) + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_norm_higher_order_and_mixed_parity(self): + """Test higher order irreps and mixed parity""" + # Test l=2 irrep + norm_l2 = Norm('1x2e') + l2_input = Tensor([1.0, 1.0, 1.0, 1.0, 1.0], dtype=float32) + l2_output = norm_l2(l2_input) + expected_l2 = np.sqrt(5.0) + np.testing.assert_allclose(l2_output.asnumpy(), [expected_l2], rtol=1e-5) + + # Test mixed parity + norm_mixed_parity = Norm('1x0e + 1x1o + 1x0o') + mixed_parity_input = Tensor([2.0, 1.0, 1.0, 1.0, 3.0], dtype=float32) + mixed_parity_output = norm_mixed_parity(mixed_parity_input) + expected_mixed_parity = np.array([2.0, np.sqrt(3.0), 3.0]) + np.testing.assert_allclose(mixed_parity_output.asnumpy(), expected_mixed_parity, rtol=1e-5) + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard + def test_norm_error_handling(self): + """Test error handling for invalid inputs""" + norm = Norm('1x1o') + + # Test with wrong input dimension + with pytest.raises((ValueError, RuntimeError)): + wrong_dim_input = Tensor([1.0, 2.0], dtype=float32) # Should be 3D for 1x1o + norm(wrong_dim_input) -- Gitee From 45315cd7d0e0cee390c367878aa71ce8e1607455 Mon Sep 17 00:00:00 2001 From: birfied Date: Fri, 19 Sep 2025 15:28:58 +0800 Subject: [PATCH 17/21] pylint --- tests/e3nn/o3/test_norm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/e3nn/o3/test_norm.py b/tests/e3nn/o3/test_norm.py index 787e4be54..3ba91be44 100644 --- a/tests/e3nn/o3/test_norm.py +++ b/tests/e3nn/o3/test_norm.py @@ -1,7 +1,6 @@ """Test cases for e3nn.o3.norm module - Streamlined core functionality""" import pytest import numpy as np -import mindspore as ms from mindspore import Tensor, float32, float64 from mindscience.e3nn.o3 import Norm, Irreps -- Gitee From 15afa0314dc22db4c0ac89fb7f77fa93b24b6f67 Mon Sep 17 00:00:00 2001 From: birfied Date: Fri, 19 Sep 2025 16:31:31 +0800 Subject: [PATCH 18/21] add test cases
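Extend the o3 coverage with rotation, spherical-harmonics, sub-module, tensor-product, Wigner, and utils tests, and pin test_norm.py to float32. A core invariant these tests lean on is the Euler-angle/matrix round trip; the sketch below is illustrative only (it paraphrases test_angles_matrix_roundtrip in test_rotation.py and is not part of the patched files).

```python
# Illustrative only: mirrors test_angles_matrix_roundtrip in test_rotation.py.
import numpy as np
from mindspore import Tensor
from mindscience.e3nn.o3.rotation import angles_to_matrix, matrix_to_angles

mat = angles_to_matrix(Tensor([0.1]), Tensor([0.4]), Tensor([0.7]))
alpha, beta, gamma = matrix_to_angles(mat)   # angles may differ from the inputs...
mat_rec = angles_to_matrix(alpha, beta, gamma)
np.testing.assert_allclose(mat_rec.asnumpy(), mat.asnumpy(), atol=1e-5)  # ...but the rotation matches
```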
--- tests/e3nn/o3/test_norm.py | 12 +- tests/e3nn/o3/test_rotation.py | 179 ++++++++++++++++++++ tests/e3nn/o3/test_spherical_harmonics.py | 171 +++++++++++++++ tests/e3nn/o3/test_sub.py | 145 +++++++++++++++++ tests/e3nn/o3/test_tensor_product.py | 115 +++++++++++++ tests/e3nn/o3/test_wigner.py | 113 +++++++++++++ tests/e3nn/utils/test_utils.py | 190 ++++++++++++++++++++ 7 files changed, 919 insertions(+), 6 deletions(-) create mode 100644 tests/e3nn/o3/test_rotation.py create mode 100644 tests/e3nn/o3/test_spherical_harmonics.py create mode 100644 tests/e3nn/o3/test_sub.py create mode 100644 tests/e3nn/o3/test_tensor_product.py create mode 100644 tests/e3nn/o3/test_wigner.py create mode 100644 tests/e3nn/utils/test_utils.py diff --git a/tests/e3nn/o3/test_norm.py b/tests/e3nn/o3/test_norm.py index 3ba91be44..d356356c8 100644 --- a/tests/e3nn/o3/test_norm.py +++ b/tests/e3nn/o3/test_norm.py @@ -1,7 +1,7 @@ """Test cases for e3nn.o3.norm module - Streamlined core functionality""" import pytest import numpy as np -from mindspore import Tensor, float32, float64 +from mindspore import Tensor, float32 from mindscience.e3nn.o3 import Norm, Irreps @@ -62,18 +62,18 @@ class TestNorm: def test_norm_squared_and_dtype_consistency(self): """Test squared parameter and dtype consistency""" # Test squared vs regular norm - norm_regular = Norm('1x1o', squared=False, dtype=float64) - norm_squared = Norm('1x1o', squared=True, dtype=float64) + norm_regular = Norm('1x1o', squared=False, dtype=float32) + norm_squared = Norm('1x1o', squared=True, dtype=float32) - input_vec = Tensor([3.0, 4.0, 0.0], dtype=float64) + input_vec = Tensor([3.0, 4.0, 0.0], dtype=float32) output_regular = norm_regular(input_vec) output_squared = norm_squared(input_vec) # Verify squared relationship and dtype consistency np.testing.assert_allclose(output_regular.asnumpy(), [5.0], rtol=1e-5) np.testing.assert_allclose(output_squared.asnumpy(), [25.0], rtol=1e-5) - assert output_regular.dtype == float64 - assert output_squared.dtype == float64 + assert output_regular.dtype == float32 + assert output_squared.dtype == float32 @pytest.mark.level0 @pytest.mark.platform_arm_ascend910b_training @pytest.mark.env_onecard diff --git a/tests/e3nn/o3/test_rotation.py b/tests/e3nn/o3/test_rotation.py new file mode 100644 index 000000000..f3d581d77 --- /dev/null +++ b/tests/e3nn/o3/test_rotation.py @@ -0,0 +1,179 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================ +"""Test cases for rotation module.""" + +import math +import pytest +import numpy as np +from mindspore import Tensor, float32 + +from mindscience.e3nn.o3.rotation import ( + identity_angles, rand_angles, compose_angles, + matrix_x, matrix_y, matrix_z, + angles_to_matrix, matrix_to_angles, + angles_to_xyz, xyz_to_angles +) + + +class TestRotation: + """Test class for rotation functions.""" + + def test_identity_angles(self): + """Test identity_angles function comprehensively.""" + # Test basic functionality and shapes + alpha, beta, gamma = identity_angles(2, 3) + assert alpha.shape == (2, 3) + assert beta.shape == (2, 3) + assert gamma.shape == (2, 3) + assert np.allclose(alpha.asnumpy(), 0.0) + assert np.allclose(beta.asnumpy(), 0.0) + assert np.allclose(gamma.asnumpy(), 0.0) + + # Test dtype + alpha, beta, gamma = identity_angles(2, dtype=float32) + assert alpha.dtype == float32 + + # Test error handling + with pytest.raises(TypeError): + identity_angles(1.5) # Should be int + + def test_rand_angles(self): + """Test rand_angles function comprehensively.""" + # Test shapes and angle ranges + alpha, beta, gamma = rand_angles(2, 3) + assert alpha.shape == (2, 3) + assert beta.shape == (2, 3) + assert gamma.shape == (2, 3) + assert np.all(alpha.asnumpy() >= 0) and np.all(alpha.asnumpy() <= 2 * math.pi) + assert np.all(beta.asnumpy() >= 0) and np.all(beta.asnumpy() <= math.pi) + assert np.all(gamma.asnumpy() >= 0) and np.all(gamma.asnumpy() <= 2 * math.pi) + + # Test error handling + with pytest.raises(TypeError): + rand_angles(1.5) # Should be int + + def test_rotation_matrices(self): + """Test rotation matrix functions (matrix_x, matrix_y, matrix_z).""" + # Test identity matrices with zero angle + for matrix_func in [matrix_x, matrix_y, matrix_z]: + mat = matrix_func(0.0) + assert np.allclose(mat.asnumpy(), np.eye(3), atol=1e-6) + + # Test specific rotations + mat_x = matrix_x(math.pi / 2) + expected_x = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) + assert np.allclose(mat_x.asnumpy(), expected_x, atol=1e-6) + + mat_y = matrix_y(math.pi / 2) + expected_y = np.array([[0, 0, 1], [0, 1, 0], [-1, 0, 0]]) + assert np.allclose(mat_y.asnumpy(), expected_y, atol=1e-6) + + mat_z = matrix_z(math.pi / 2) + expected_z = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]]) + assert np.allclose(mat_z.asnumpy(), expected_z, atol=1e-6) + + # Test batch operations + angles = Tensor([0.1, 0.2, 0.3]) + mat = matrix_x(angles) + assert mat.shape == (3, 3, 3) + + def test_rotation_matrices_orthogonal(self): + """Test that rotation matrices are orthogonal.""" + angle = 0.5 + for matrix_func in [matrix_x, matrix_y, matrix_z]: + mat = matrix_func(angle) + # Check orthogonality: R @ R.T = I + identity = np.matmul(mat.asnumpy(), mat.asnumpy().T) + assert np.allclose(identity, np.eye(3), atol=1e-6) + # Check determinant = 1 + assert np.allclose(np.linalg.det(mat.asnumpy()), 1.0, atol=1e-6) + + def test_angle_matrix_conversion(self): + """Test angles_to_matrix and matrix_to_angles functions.""" + # Test identity conversion + mat = angles_to_matrix(0.0, 0.0, 0.0) + assert np.allclose(mat.asnumpy(), np.eye(3), atol=1e-6) + + # Test roundtrip conversion + alpha_orig = Tensor([0.1, 0.2, 0.3]) + beta_orig = Tensor([0.4, 0.5, 0.6]) + gamma_orig = Tensor([0.7, 0.8, 0.9]) + + mat = angles_to_matrix(alpha_orig, beta_orig, gamma_orig) + assert mat.shape == (3, 3, 3) + alpha_new, beta_new, gamma_new = matrix_to_angles(mat) + assert 
np.allclose(alpha_orig.asnumpy(), alpha_new.asnumpy(), atol=1e-5) + assert np.allclose(beta_orig.asnumpy(), beta_new.asnumpy(), atol=1e-5) + assert np.allclose(gamma_orig.asnumpy(), gamma_new.asnumpy(), atol=1e-5) + + def test_angles_matrix_roundtrip(self): + """Test roundtrip conversion between angles and matrix.""" + # Test multiple angle sets + test_angles = [ + (0.1, 0.2, 0.3), + (0.4, 0.5, 0.6), + (1.0, 1.5, 2.0), + (math.pi/4, math.pi/3, math.pi/6) + ] + + for alpha, beta, gamma in test_angles: + # Convert angles to matrix and back + mat = angles_to_matrix(alpha, beta, gamma) + alpha_rec, beta_rec, gamma_rec = matrix_to_angles(mat) + + # Check if we get back the same angles (within tolerance) + # Note: Euler angles may have multiple representations + mat_rec = angles_to_matrix(alpha_rec, beta_rec, gamma_rec) + assert np.allclose(mat.asnumpy(), mat_rec.asnumpy(), atol=1e-5) + + def test_matrix_to_angles_error(self): + """Test matrix_to_angles error handling.""" + # Test with non-rotation matrix (determinant != 1) + invalid_matrix = Tensor(np.array([[2, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)) + with pytest.raises(ValueError): + matrix_to_angles(invalid_matrix) + + def test_angle_operations(self): + """Test compose_angles, angles_to_xyz, and xyz_to_angles functions.""" + # Test compose_angles with identity + alpha_comp, beta_comp, gamma_comp = compose_angles(0.0, 0.0, 0.0, 0.1, 0.2, 0.3) + assert np.allclose(alpha_comp.asnumpy(), 0.1, atol=1e-6) + assert np.allclose(beta_comp.asnumpy(), 0.2, atol=1e-6) + assert np.allclose(gamma_comp.asnumpy(), 0.3, atol=1e-6) + + # Test angles_to_xyz and xyz_to_angles roundtrip + xyz = angles_to_xyz(0.0, 0.0) + assert np.allclose(xyz.asnumpy(), [0.0, 1.0, 0.0], atol=1e-6) + alpha, beta = xyz_to_angles(xyz) + assert np.allclose(alpha.asnumpy(), 0.0, atol=1e-6) + assert np.allclose(beta.asnumpy(), 0.0, atol=1e-6) + + def test_batch_and_edge_cases(self): + """Test batch operations and edge cases.""" + # Test batch operations + alphas = Tensor(np.array([[0.1, 0.2], [0.3, 0.4]]).astype(np.float32)) + betas = Tensor(np.array([[0.5, 0.6], [0.7, 0.8]]).astype(np.float32)) + gammas = Tensor(np.array([[0.9, 1.0], [1.1, 1.2]]).astype(np.float32)) + matrices = angles_to_matrix(alphas, betas, gammas) + assert matrices.shape == (2, 2, 3, 3) + # Test edge case: small angles + mat = angles_to_matrix(1e-8, 1e-8, 1e-8) + assert np.allclose(mat.asnumpy(), np.eye(3), atol=1e-6) + + # Test edge case: pi angles (should still be valid rotation matrix) + mat = angles_to_matrix(math.pi, math.pi, math.pi) + identity = np.matmul(mat.asnumpy(), mat.asnumpy().T) + assert np.allclose(identity, np.eye(3), atol=1e-5) + assert np.allclose(np.linalg.det(mat.asnumpy()), 1.0, atol=1e-5) diff --git a/tests/e3nn/o3/test_spherical_harmonics.py b/tests/e3nn/o3/test_spherical_harmonics.py new file mode 100644 index 000000000..f11042a03 --- /dev/null +++ b/tests/e3nn/o3/test_spherical_harmonics.py @@ -0,0 +1,171 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test spherical harmonics module.""" + +import numpy as np +from mindspore import Tensor, float32 +from mindscience.e3nn.o3 import spherical_harmonics, SphericalHarmonics + + +class TestSphericalHarmonicsFunction: + """Test spherical_harmonics function.""" + + def test_core_functionality(self): + """Test core spherical harmonics functionality including degrees and normalization.""" + x = Tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], float32) + + # Test l=0 (constant function) + result_l0 = spherical_harmonics(0, x) + assert result_l0.shape == (2, 1) + np.testing.assert_allclose(result_l0.asnumpy(), [[0.28209479], [0.28209479]], rtol=1e-5) + + # Test different degrees + result_l1 = spherical_harmonics(1, x[:1]) + assert result_l1.shape == (1, 3) + result_l2 = spherical_harmonics(2, x[:1]) + assert result_l2.shape == (1, 5) + + # Test multiple degrees + result_multi = spherical_harmonics([0, 1, 2], x[:1]) + assert result_multi.shape == (1, 9) # 1 + 3 + 5 + + def test_normalization_and_parameters(self): + """Test normalization methods and normalize parameter.""" + x = Tensor([[1.0, 0.0, 0.0]], float32) + x_unnorm = Tensor([[2.0, 0.0, 0.0]], float32) + + # Test different normalization methods + result_integral = spherical_harmonics(1, x, normalization='integral') + result_component = spherical_harmonics(1, x, normalization='component') + result_norm = spherical_harmonics(1, x, normalization='norm') + + # Results should be different for different normalizations + assert not np.allclose(result_integral.asnumpy(), result_component.asnumpy()) + assert not np.allclose(result_integral.asnumpy(), result_norm.asnumpy()) + + # Test normalize parameter + result_normalized = spherical_harmonics(1, x_unnorm, normalize=True) + result_unnormalized = spherical_harmonics(1, x_unnorm, normalize=False) + assert not np.allclose(result_normalized.asnumpy(), result_unnormalized.asnumpy()) + + def test_batch_and_shapes(self): + """Test batch processing and different input shapes.""" + # Multiple vectors + x = Tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], float32) + result = spherical_harmonics(2, x) + assert result.shape == (3, 5) + + # Higher dimensional batch + x_batch = Tensor(np.random.randn(2, 3, 3).astype(np.float32)) + result_batch = spherical_harmonics(1, x_batch) + assert result_batch.shape == (2, 3, 3) + + # 3D input + x_3d = Tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], float32) + result_3d = spherical_harmonics(1, x_3d) + assert result_3d.shape == (1, 2, 3) + + +class TestSphericalHarmonicsClass: + """Test the SphericalHarmonics class.""" + + def test_class_initialization_and_forward(self): + """Test class initialization and forward computation.""" + # Test initialization + sh = SphericalHarmonics(2, normalize=True) + assert sh._lmax == 2 + assert sh.irreps_out.dim == 5 + + # Test forward computation + x = Tensor([[1.0, 0.0, 0.0]], float32) + result = sh(x) + assert result.shape == (1, 5) + + # Compare with function version + result_func = spherical_harmonics(2, x) + np.testing.assert_allclose(result.asnumpy(), result_func.asnumpy(), rtol=1e-5) + + def test_consistency_and_parity(self): + """Test normalization consistency and parity.""" + x = Tensor([[1.0, 0.0, 0.0]], float32) + + # Test normalization consistency + sh_integral = SphericalHarmonics(1, normalize=True, normalization='integral') + sh_component = 
SphericalHarmonics(1, normalize=True, normalization='component') + result_integral = sh_integral(x) + result_component = sh_component(x) + assert not np.allclose(result_integral.asnumpy(), result_component.asnumpy()) + + # Test parity consistency + sh = SphericalHarmonics(2, normalize=True) + x_pos = Tensor([[1.0, 0.0, 0.0]], float32) + x_neg = Tensor([[-1.0, 0.0, 0.0]], float32) + result_pos = sh(x_pos) + result_neg = sh(x_neg) + # For even l, parity should be preserved + np.testing.assert_allclose(result_pos.asnumpy(), result_neg.asnumpy(), rtol=1e-5) + + +class TestMathematicalProperties: + """Test mathematical properties of spherical harmonics.""" + + def test_mathematical_properties(self): + """Test basic mathematical properties and rotation equivariance.""" + # Test basic properties for l=1 + x = Tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], float32) + result = spherical_harmonics(1, x) + + # Test that results are finite and have correct shape + assert result.shape == (3, 3) + assert np.all(np.isfinite(result.asnumpy())) + + # Test rotation equivariance (simplified) + x_original = Tensor([[1.0, 0.0, 0.0]], float32) + x_rotated = Tensor([[0.0, 1.0, 0.0]], float32) # 90° rotation around z + sh_original = spherical_harmonics(1, x_original) + sh_rotated = spherical_harmonics(1, x_rotated) + # Results should be different for different orientations + assert sh_original.shape == sh_rotated.shape + assert not np.allclose(sh_original.asnumpy(), sh_rotated.asnumpy()) + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_edge_cases(self): + """Test zero vectors, high degrees, and error conditions.""" + # Test zero vector + x_zero = Tensor([[0.0, 0.0, 0.0]], float32) + result_zero = spherical_harmonics(1, x_zero) + assert result_zero.shape == (1, 3) + + # Test high degree + x = Tensor([[1.0, 0.0, 0.0]], float32) + result_high = spherical_harmonics(5, x) + assert result_high.shape == (1, 11) # 2*5+1 + + # Test invalid degree (should raise error) + try: + spherical_harmonics(-1, x) + assert False, "Should raise error for negative degree" + except (ValueError, TypeError): + pass + + # Test invalid normalization + try: + spherical_harmonics(1, x, normalization='invalid') + assert False, "Should raise error for invalid normalization" + except (ValueError, TypeError): + pass diff --git a/tests/e3nn/o3/test_sub.py b/tests/e3nn/o3/test_sub.py new file mode 100644 index 000000000..8bfdb3915 --- /dev/null +++ b/tests/e3nn/o3/test_sub.py @@ -0,0 +1,145 @@ +# Copyright 2021-2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test cases for o3.sub module. 
+ +This module contains comprehensive tests for all classes and functions in the o3.sub module, +including tensor product operations, linear operations, and utility functions. +""" + +import pytest +import numpy as np +import mindspore as ms +from mindspore import Tensor + +from mindscience.e3nn.o3.sub import ( + FullyConnectedTensorProduct, + FullTensorProduct, + ElementwiseTensorProduct, + Linear, + LinearBias, + TensorSquare, + prod, + _prod, + _sum_tensors_withbias, + _compose, + _run_continue, + Instruction +) +from mindscience.e3nn.o3.irreps import Irreps + + +class TestTensorProductClasses: + """Test tensor product classes functionality.""" + + def test_tensor_product_operations(self): + """Test core tensor product operations.""" + # Test FullyConnectedTensorProduct + tp_fc = FullyConnectedTensorProduct('1x1o', '1x0e', '1x1o') + x1 = Tensor(np.random.randn(2, 3), ms.float32) + x2 = Tensor(np.random.randn(2, 1), ms.float32) + output_fc = tp_fc(x1, x2) + assert output_fc.shape == (2, 3) + + # Test FullTensorProduct + tp_full = FullTensorProduct('1x1o', '1x0e') + output_full = tp_full(x1, x2) + assert output_full.ndim == 2 + + # Test ElementwiseTensorProduct + tp_elem = ElementwiseTensorProduct('1x1o', '1x1o') + output_elem = tp_elem(x1, x1) + assert output_elem.ndim == 2 + + def test_linear_operations(self): + """Test linear operations with and without bias.""" + # Test Linear + linear = Linear('1x1o+1x0e', '1x1o') + x = Tensor(np.random.randn(2, 4), ms.float32) + output = linear(x) + assert output.shape == (2, 3) + + # Test LinearBias + linear_bias = LinearBias('1x1o+1x0e', '1x1o+1x0e', has_bias=True) + output_bias = linear_bias(x) + assert output_bias.shape == (2, 4) + + def test_tensor_square(self): + """Test TensorSquare operation.""" + ts = TensorSquare('1x1o', irreps_out='1x0e+1x2e') + x = Tensor(np.random.randn(2, 3), ms.float32) + output = ts(x) + assert output.shape == (2, 6) # 1x0e+1x2e has dim 6 + + +class TestUtilityFunctions: + """Test utility functions.""" + + def test_prod_functions(self): + """Test product computation functions.""" + # Test prod function + assert prod([2, 3, 4]) == 24 + assert prod([]) == 1 + + # Test _prod function + assert _prod((2, 3, 4)) == 24 + assert _prod(()) == 1 + + def test_tensor_utilities(self): + """Test tensor utility functions.""" + # Test _sum_tensors_withbias + t1 = Tensor(np.array([1, 2, 3]), ms.float32) + t2 = Tensor(np.array([4, 5, 6]), ms.float32) + + result = _sum_tensors_withbias([t1, t2], (3,), ms.float32) + expected = np.array([5, 7, 9]) + assert np.allclose(result.asnumpy(), expected) + + # Test Instruction NamedTuple + instr = Instruction(i_in=0, i_out=1, path_shape=(2, 3), path_weight=1.5) + assert instr.i_in == 0 and instr.i_out == 1 + + +class TestEdgeCases: + """Test edge cases and error handling.""" + + def test_error_handling(self): + """Test error handling for invalid inputs.""" + # Test invalid irreps + with pytest.raises((ValueError, TypeError)): + FullyConnectedTensorProduct('invalid', '1x0e', '1x0e') + + # Test mismatched dimensions + tp = FullyConnectedTensorProduct('1x0e', '1x0e', '1x0e') + x1 = Tensor(np.random.randn(2, 5), ms.float32) # Wrong dimension + x2 = Tensor(np.random.randn(2, 1), ms.float32) + + with pytest.raises(ValueError): + tp(x1, x2) + + def test_scalar_operations(self): + """Test operations with scalar irreps.""" + linear = Linear('1x0e', '1x0e') + x = Tensor(np.random.randn(2, 1), ms.float32) + output = linear(x) + assert output.shape == (2, 1) \ No newline at end of file diff --git 
a/tests/e3nn/o3/test_tensor_product.py b/tests/e3nn/o3/test_tensor_product.py new file mode 100644 index 000000000..3d57f2308 --- /dev/null +++ b/tests/e3nn/o3/test_tensor_product.py @@ -0,0 +1,115 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test cases for tensor_product module.""" + +import numpy as np +from mindspore import Tensor, float32 + +from mindscience.e3nn.o3.tensor_product import TensorProduct +from mindscience.e3nn.o3.sub import ( + FullTensorProduct, FullyConnectedTensorProduct, + ElementwiseTensorProduct, TensorSquare, Linear +) +from mindscience.e3nn.o3.irreps import Irreps + + +class TestTensorProduct: + """Test class for TensorProduct and related classes.""" + + def test_tensor_product_basic(self): + """Test basic TensorProduct functionality.""" + # Test standard tensor product + tp = TensorProduct('2x1o+1x0e', '1x1o+1x0e') + assert tp.irreps_in1.dim == 7 # 2*3 + 1*1 = 7 + assert tp.irreps_in2.dim == 4 # 1*3 + 1*1 = 4 + + # Test with input tensors + x1 = Tensor(np.random.randn(2, tp.irreps_in1.dim), dtype=float32) + x2 = Tensor(np.random.randn(2, tp.irreps_in2.dim), dtype=float32) + output = tp(x1, x2) + assert output.shape == (2, tp.irreps_out.dim) + + def test_full_tensor_product(self): + """Test FullTensorProduct functionality.""" + # Test full tensor product + ftp = FullTensorProduct('1x1o+1x0e', '1x1o+1x0e') + x1 = Tensor(np.random.randn(2, ftp.irreps_in1.dim), dtype=float32) + x2 = Tensor(np.random.randn(2, ftp.irreps_in2.dim), dtype=float32) + output = ftp(x1, x2) + assert output.shape == (2, ftp.irreps_out.dim) + + def test_fully_connected_tensor_product(self): + """Test FullyConnectedTensorProduct functionality.""" + # Test fully connected tensor product + fctp = FullyConnectedTensorProduct('1x1o', '1x1o', '1x2e+1x0e') + x1 = Tensor(np.random.randn(2, fctp.irreps_in1.dim), dtype=float32) + x2 = Tensor(np.random.randn(2, fctp.irreps_in2.dim), dtype=float32) + output = fctp(x1, x2) + assert output.shape == (2, fctp.irreps_out.dim) + assert fctp.weight_numel > 0 # Should have learnable weights + + def test_elementwise_tensor_product(self): + """Test ElementwiseTensorProduct functionality.""" + # Test elementwise tensor product + etp = ElementwiseTensorProduct('2x1o+1x0e', '2x1o+1x0e') + x1 = Tensor(np.random.randn(2, etp.irreps_in1.dim), dtype=float32) + x2 = Tensor(np.random.randn(2, etp.irreps_in2.dim), dtype=float32) + output = etp(x1, x2) + assert output.shape == (2, etp.irreps_out.dim) + + def test_tensor_square(self): + """Test TensorSquare functionality.""" + # Test tensor square without output specification + ts = TensorSquare('1x1o+1x0e') + x = Tensor(np.random.randn(2, ts.irreps_in1.dim), dtype=float32) + output = ts(x) + assert output.shape == (2, ts.irreps_out.dim) + + # Test tensor square with output specification + ts_out = TensorSquare('1x1o', irreps_out='1x2e+1x0e') + x = Tensor(np.random.randn(2, 
ts_out.irreps_in1.dim), dtype=float32) + output = ts_out(x) + assert output.shape == (2, ts_out.irreps_out.dim) + assert ts_out.weight_numel > 0 # Should have learnable weights + + def test_linear_operation(self): + """Test Linear operation functionality.""" + # Test linear operation + linear = Linear('1x1o+1x0e', '2x1o+1x0e') + x = Tensor(np.random.randn(2, linear.irreps_in1.dim), dtype=float32) + output = linear(x) + assert output.shape == (2, linear.irreps_out.dim) + assert linear.weight_numel > 0 # Should have learnable weights + + def test_tensor_product_properties(self): + """Test tensor product properties and edge cases.""" + # Test properties + tp = TensorProduct('1x1o', '1x1o', '1x2e+1x0e', instructions='connect') + assert isinstance(tp.irreps_in1, Irreps) + assert isinstance(tp.irreps_in2, Irreps) + assert isinstance(tp.irreps_out, Irreps) + assert isinstance(tp.instructions, list) + assert tp.weight_numel >= 0 + + # Test string representation + repr_str = repr(tp) + assert 'TensorProduct' in repr_str + assert 'connect' in repr_str + + # Test with single batch + x1 = Tensor(np.random.randn(tp.irreps_in1.dim), dtype=float32) + x2 = Tensor(np.random.randn(tp.irreps_in2.dim), dtype=float32) + output = tp(x1, x2) + assert output.shape == (tp.irreps_out.dim,) diff --git a/tests/e3nn/o3/test_wigner.py b/tests/e3nn/o3/test_wigner.py new file mode 100644 index 000000000..a4d332a26 --- /dev/null +++ b/tests/e3nn/o3/test_wigner.py @@ -0,0 +1,113 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Test cases for o3.wigner module.""" + +import pytest +import numpy as np +from mindspore import float32, float64, complex64, complex128 + +from mindscience.e3nn.o3.wigner import ( + change_basis_real_to_complex, + su2_generators, + so3_generators, + wigner_D, + wigner_3j +) + +class TestWigner: + """Test wigner module functions.""" + + def test_change_basis_real_to_complex(self): + """Test change_basis_real_to_complex function.""" + # Test basic functionality + result = change_basis_real_to_complex(1) + assert result.shape == (3, 3) + assert result.dtype == complex64 + + # Test dtype conversion + result = change_basis_real_to_complex(1, dtype=float64) + assert result.dtype == complex128 + + # Test unitarity property + Q = change_basis_real_to_complex(1) + Q_np = Q.asnumpy() + I = np.eye(3) + np.testing.assert_allclose(Q_np @ Q_np.conj().T, I, atol=1e-6) + + def test_su2_generators(self): + """Test su2_generators function.""" + # Test basic functionality + result = su2_generators(1) + assert result.shape == (3, 3, 3) + assert result.dtype == complex64 + + # Test dtype + result = su2_generators(1, dtype=complex128) + assert result.dtype == complex128 + + # Test invalid input + with pytest.raises(TypeError): + su2_generators(1.5) + + def test_so3_generators(self): + """Test so3_generators function.""" + # Test basic functionality + result = so3_generators(1) + assert result.shape == (3, 3, 3) + assert result.dtype == float32 + + # Test dtype + result = so3_generators(1, dtype=float64) + assert result.dtype == float64 + + # Test invalid input + with pytest.raises(TypeError): + so3_generators(1.5) + + def test_wigner_D(self): + """Test wigner_D function.""" + # Test identity rotation + result = wigner_D(1, 0, 0, 0) + assert result.shape == (3, 3) + expected = np.eye(3) + np.testing.assert_allclose(result.asnumpy(), expected, atol=1e-6) + + # Test orthogonality property + D = wigner_D(1, 0.5, 0.3, 0.7) + I = np.eye(3) + np.testing.assert_allclose((D @ D.T).asnumpy(), I, atol=1e-5) + + def test_wigner_3j(self): + """Test wigner_3j function.""" + # Test basic functionality + result = wigner_3j(1, 1, 1) + assert result.shape == (3, 3, 3) + assert result.dtype == float32 + + # Test dtype + result = wigner_3j(1, 1, 0, dtype=float64) + assert result.dtype == float64 + + # Test normalization property + C = wigner_3j(1, 1, 1) + norm_squared = np.sum(C.asnumpy() ** 2) + np.testing.assert_allclose(norm_squared, 1.0, atol=1e-6) + + # Test invalid combinations + with pytest.raises(ValueError): + wigner_3j(1, 1, 3) + + with pytest.raises(TypeError): + wigner_3j(1.5, 1, 1) diff --git a/tests/e3nn/utils/test_utils.py b/tests/e3nn/utils/test_utils.py new file mode 100644 index 000000000..0bf1a345a --- /dev/null +++ b/tests/e3nn/utils/test_utils.py @@ -0,0 +1,190 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +Test cases for e3nn.utils module. + +This module contains comprehensive test cases for all utility functions +in the e3nn.utils package, including tensor operations, linear algebra, +tensor contractions, radius computations, and initialization utilities. +""" + +import pytest +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops +from mindspore.common.initializer import TruncatedNormal + +from mindscience.e3nn.utils.func import broadcast_args, _ndexpm, narrow +from mindscience.e3nn.utils.linalg import _direct_sum +from mindscience.e3nn.utils.ncon import Ncon +from mindscience.e3nn.utils.radius import radius, radius_graph, radius_full, radius_graph_full +from mindscience.e3nn.utils.initializer import Uniform, renormal_initializer +from mindscience.e3nn.utils.perm import _from_int, _to_int, _inverse, _compose, _group, _germinate + +class TestFuncModule: + """Test cases for func.py module.""" + + def test_broadcast_and_operations(self): + """Test broadcasting, matrix exponential, and tensor slicing.""" + # Test broadcasting + a = Tensor([1.0, 2.0]) + b = Tensor([[3.0], [4.0]]) + result = broadcast_args(a, b) + assert len(result) == 2 and result[0].shape == (2, 2) + + # Test matrix exponential + mat = Tensor([[0.0, 1.0], [-1.0, 0.0]], dtype=ms.float32) + exp_result = _ndexpm(mat) + assert exp_result.shape == (2, 2) + + # Test tensor slicing + x = Tensor(np.arange(24).reshape(2, 3, 4), dtype=ms.float32) + sliced = narrow(x, axis=0, start=0, length=1) + assert sliced.shape == (1, 3, 4) + +class TestLinalgModule: + """Test cases for linalg.py module.""" + + def test_direct_sum(self): + """Test direct sum of matrices.""" + a = Tensor([[1.0, 2.0], [3.0, 4.0]]) + b = Tensor([[5.0]]) + result = _direct_sum(a, b) + assert result.shape == (3, 3) + + # Test with batch dimensions + batch_a = Tensor(np.random.randn(2, 3, 3).astype(np.float32)) + batch_b = Tensor(np.random.randn(2, 2, 2).astype(np.float32)) + batch_result = _direct_sum(batch_a, batch_b) + assert batch_result.shape == (2, 5, 5) + +class TestNconModule: + """Test cases for ncon.py module.""" + + def test_ncon_operations(self): + """Test various Ncon tensor contraction operations.""" + # Test trace + ncon_trace = Ncon([[1, 1]]) + a = Tensor([[1.0, 2.0], [3.0, 4.0]]) + trace_result = ncon_trace([a]) + assert np.isclose(trace_result.asnumpy(), 5.0) + + # Test outer product + ncon_outer = Ncon([[-1], [-2]]) + b = Tensor([1.0, 2.0]) + c = Tensor([3.0, 4.0, 5.0]) + outer_result = ncon_outer([b, c]) + assert outer_result.shape == (2, 3) + + # Test batch matrix multiplication + ncon_matmul = Ncon([[-1, -2, 1], [-1, 1, -3]]) + d = Tensor(np.random.randn(2, 3, 4).astype(np.float32)) + e = Tensor(np.random.randn(2, 4, 5).astype(np.float32)) + matmul_result = ncon_matmul([d, e]) + assert matmul_result.shape == (2, 3, 5) + +class TestRadiusModule: + """Test cases for radius.py module.""" + + def test_radius_functions(self): + """Test radius computation functions.""" + np.random.seed(42) + x = np.random.random((8, 3)).astype(np.float32) + y = np.random.random((5, 3)).astype(np.float32) + + # Test basic radius + edge_index, batch_x, batch_y = radius(x, y, 0.5, max_num_neighbors=10) + assert edge_index.shape[0] == 2 and len(batch_x) == len(x) + + # Test radius_graph + edge_index, batch = radius_graph(x, 0.8, loop=False) + assert edge_index.shape[0] == 2 and len(batch) == len(x) + + # Test radius_full + x_batch = np.ones((2, 3, 3), dtype=np.float32) 
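+        # Reviewer annotation (assumption, not in the original patch): reading the
+        # (2, 3, 3) input as 2 batches of 3 points in 3-D, radius_full appears to
+        # pair points across the flattened batch, so batch_x_full below should
+        # cover all 2 * 3 = 6 points, matching the length-6 assertion.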
+ edge_index_full, batch_x_full, batch_y_full = radius_full(x_batch, x_batch) + assert edge_index_full.shape[0] == 2 and len(batch_x_full) == 6 + +class TestInitializerModule: + """Test cases for initializer.py module.""" + + def test_initializers(self): + """Test custom initializers.""" + # Test Uniform initializer + from mindspore.common.initializer import initializer + uniform_init = Uniform(scale=2.0) + tensor = initializer(uniform_init, [3, 4], ms.float32) + values = tensor.asnumpy() + assert np.all(values >= 0.0) and np.all(values <= 2.0) + + # Test renormal_initializer + init1 = renormal_initializer('uniform') + assert isinstance(init1, Uniform) + + init2 = renormal_initializer('truncatedNormal') + assert isinstance(init2, TruncatedNormal) + + # Test invalid input + with pytest.raises(ValueError): + renormal_initializer('invalid_method') + +class TestPermModule: + """Test cases for perm.py module.""" + + def test_permutation_operations(self): + """Test permutation conversion and operations.""" + # Test conversion functions + n = 3 + for i in range(6): # 3! = 6 + perm = _from_int(i, n) + assert len(perm) == n and _to_int(perm) == i + + # Test permutation operations + perm1 = (0, 2, 1) + perm2 = (1, 0, 2) + inv_perm1 = _inverse(perm1) + composed = _compose(perm1, inv_perm1) + assert composed == (0, 1, 2) # identity + + # Test group operations + group3 = _group(3) + assert len(group3) == 6 # 3! = 6 + + subset = {(0, 1, 2), (1, 0, 2)} + closure = _germinate(subset) + assert len(closure) >= len(subset) + +class TestInputValidation: + """Test input validation and error handling.""" + + def test_error_handling(self): + """Test various error conditions.""" + # Test radius with mismatched dimensions + x = np.random.random((5, 3)) + y = np.random.random((5, 4)) # Different last dimension + with pytest.raises(ValueError): + radius(x, y, 1.0) + + # Test radius_graph with invalid flow + with pytest.raises(ValueError): + radius_graph(x, 1.0, flow='invalid_flow') + + # Test _ndexpm with invalid input + invalid_mat = Tensor([1.0]) # 1D tensor + with pytest.raises(ValueError): + _ndexpm(invalid_mat) + +if __name__ == "__main__": + pytest.main([__file__]) -- Gitee From 3ee0595bc5572f252b3d1b8dc8583f822870f4ac Mon Sep 17 00:00:00 2001 From: birfied Date: Fri, 19 Sep 2025 16:49:13 +0800 Subject: [PATCH 19/21] pylint --- tests/e3nn/o3/test_sub.py | 5 +---- tests/e3nn/o3/test_wigner.py | 20 ++++++++++---------- tests/e3nn/utils/test_utils.py | 9 ++++----- 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/tests/e3nn/o3/test_sub.py b/tests/e3nn/o3/test_sub.py index 8bfdb3915..fd8d1aabb 100644 --- a/tests/e3nn/o3/test_sub.py +++ b/tests/e3nn/o3/test_sub.py @@ -42,11 +42,8 @@ from mindscience.e3nn.o3.sub import ( prod, _prod, _sum_tensors_withbias, - _compose, - _run_continue, Instruction ) -from mindscience.e3nn.o3.irreps import Irreps class TestTensorProductClasses: @@ -142,4 +139,4 @@ class TestEdgeCases: linear = Linear('1x0e', '1x0e') x = Tensor(np.random.randn(2, 1), ms.float32) output = linear(x) - assert output.shape == (2, 1) \ No newline at end of file + assert output.shape == (2, 1) diff --git a/tests/e3nn/o3/test_wigner.py b/tests/e3nn/o3/test_wigner.py index a4d332a26..d3db13129 100644 --- a/tests/e3nn/o3/test_wigner.py +++ b/tests/e3nn/o3/test_wigner.py @@ -41,10 +41,10 @@ class TestWigner: assert result.dtype == complex128 # Test unitarity property - Q = change_basis_real_to_complex(1) - Q_np = Q.asnumpy() - I = np.eye(3) - np.testing.assert_allclose(Q_np @ Q_np.conj().T, I, 
atol=1e-6) + q_matrix = change_basis_real_to_complex(1) + q_np = q_matrix.asnumpy() + identity = np.eye(3) + np.testing.assert_allclose(q_np @ q_np.conj().T, identity, atol=1e-6) def test_su2_generators(self): """Test su2_generators function.""" @@ -76,7 +76,7 @@ class TestWigner: with pytest.raises(TypeError): so3_generators(1.5) - def test_wigner_D(self): + def test_wigner_d(self): """Test wigner_D function.""" # Test identity rotation result = wigner_D(1, 0, 0, 0) @@ -85,9 +85,9 @@ class TestWigner: np.testing.assert_allclose(result.asnumpy(), expected, atol=1e-6) # Test orthogonality property - D = wigner_D(1, 0.5, 0.3, 0.7) - I = np.eye(3) - np.testing.assert_allclose((D @ D.T).asnumpy(), I, atol=1e-5) + d_matrix = wigner_D(1, 0.5, 0.3, 0.7) + identity = np.eye(3) + np.testing.assert_allclose((d_matrix @ d_matrix.T).asnumpy(), identity, atol=1e-5) def test_wigner_3j(self): """Test wigner_3j function.""" @@ -101,8 +101,8 @@ class TestWigner: assert result.dtype == float64 # Test normalization property - C = wigner_3j(1, 1, 1) - norm_squared = np.sum(C.asnumpy() ** 2) + coeffs = wigner_3j(1, 1, 1) + norm_squared = np.sum(coeffs.asnumpy() ** 2) np.testing.assert_allclose(norm_squared, 1.0, atol=1e-6) # Test invalid combinations diff --git a/tests/e3nn/utils/test_utils.py b/tests/e3nn/utils/test_utils.py index 0bf1a345a..94825bdd0 100644 --- a/tests/e3nn/utils/test_utils.py +++ b/tests/e3nn/utils/test_utils.py @@ -23,13 +23,13 @@ tensor contractions, radius computations, and initialization utilities. import pytest import numpy as np import mindspore as ms -from mindspore import Tensor, ops +from mindspore import Tensor from mindspore.common.initializer import TruncatedNormal from mindscience.e3nn.utils.func import broadcast_args, _ndexpm, narrow from mindscience.e3nn.utils.linalg import _direct_sum from mindscience.e3nn.utils.ncon import Ncon -from mindscience.e3nn.utils.radius import radius, radius_graph, radius_full, radius_graph_full +from mindscience.e3nn.utils.radius import radius, radius_graph, radius_full from mindscience.e3nn.utils.initializer import Uniform, renormal_initializer from mindscience.e3nn.utils.perm import _from_int, _to_int, _inverse, _compose, _group, _germinate @@ -105,7 +105,7 @@ class TestRadiusModule: y = np.random.random((5, 3)).astype(np.float32) # Test basic radius - edge_index, batch_x, batch_y = radius(x, y, 0.5, max_num_neighbors=10) + edge_index, batch_x, _ = radius(x, y, 0.5, max_num_neighbors=10) assert edge_index.shape[0] == 2 and len(batch_x) == len(x) # Test radius_graph @@ -114,7 +114,7 @@ class TestRadiusModule: # Test radius_full x_batch = np.ones((2, 3, 3), dtype=np.float32) - edge_index_full, batch_x_full, batch_y_full = radius_full(x_batch, x_batch) + edge_index_full, batch_x_full, _ = radius_full(x_batch, x_batch) assert edge_index_full.shape[0] == 2 and len(batch_x_full) == 6 class TestInitializerModule: @@ -153,7 +153,6 @@ class TestPermModule: # Test permutation operations perm1 = (0, 2, 1) - perm2 = (1, 0, 2) inv_perm1 = _inverse(perm1) composed = _compose(perm1, inv_perm1) assert composed == (0, 1, 2) # identity -- Gitee From c17ccc9dc454a1670e78d5f021a185548c979a42 Mon Sep 17 00:00:00 2001 From: birfied Date: Fri, 19 Sep 2025 16:53:58 +0800 Subject: [PATCH 20/21] pylint --- tests/e3nn/o3/test_spherical_harmonics.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/e3nn/o3/test_spherical_harmonics.py b/tests/e3nn/o3/test_spherical_harmonics.py index f11042a03..0d41024d9 100644 --- 
a/tests/e3nn/o3/test_spherical_harmonics.py +++ b/tests/e3nn/o3/test_spherical_harmonics.py @@ -85,8 +85,10 @@ class TestSphericalHarmonicsClass: """Test class initialization and forward computation.""" # Test initialization sh = SphericalHarmonics(2, normalize=True) - assert sh._lmax == 2 + # Verify the output dimension instead of accessing protected member assert sh.irreps_out.dim == 5 + # Verify the irreps_out contains the expected l=2 representation + assert str(sh.irreps_out) == "1x2e" # Test forward computation x = Tensor([[1.0, 0.0, 0.0]], float32) -- Gitee From 81ca52526aa8de1015be58cf40557f87ae8c8297 Mon Sep 17 00:00:00 2001 From: birfied Date: Fri, 19 Sep 2025 17:22:19 +0800 Subject: [PATCH 21/21] ascend test cases --- tests/e3nn/nn/test_fc.py | 6 ++++++ tests/e3nn/nn/test_gate.py | 6 ++++++ tests/e3nn/o3/test_irreps.py | 6 ++++++ tests/e3nn/o3/test_norm.py | 18 ------------------ tests/e3nn/o3/test_rotation.py | 6 ++++++ tests/e3nn/o3/test_spherical_harmonics.py | 7 +++++++ tests/e3nn/o3/test_tensor_product.py | 7 +++++++ tests/e3nn/o3/test_wigner.py | 6 ++++++ 8 files changed, 44 insertions(+), 18 deletions(-) diff --git a/tests/e3nn/nn/test_fc.py b/tests/e3nn/nn/test_fc.py index 104832691..ee3d9780b 100644 --- a/tests/e3nn/nn/test_fc.py +++ b/tests/e3nn/nn/test_fc.py @@ -22,6 +22,9 @@ from mindscience.e3nn.nn.fc import FullyConnectedNet class TestFullyConnectedNet: """Test cases for FullyConnectedNet""" + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard def test_fc_basic_creation(self): """Test basic creation and parameter initialization""" h_list = [4, 10, 6] @@ -33,6 +36,9 @@ class TestFullyConnectedNet: assert fc.layer_list[1].h_in == 10 and fc.layer_list[1].h_out == 6 assert fc.weight_numel == 4*10 + 10*6 + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard def test_fc_forward_computation(self): """Test forward propagation computation correctness""" h_list = [3, 4, 2] diff --git a/tests/e3nn/nn/test_gate.py b/tests/e3nn/nn/test_gate.py index 3cbc06447..89ef2ca22 100644 --- a/tests/e3nn/nn/test_gate.py +++ b/tests/e3nn/nn/test_gate.py @@ -8,6 +8,9 @@ from mindscience.e3nn.nn import Gate class TestGate: """Test cases for Gate module""" + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard def test_gate_creation(self): """Test Gate creation and basic properties""" gate = Gate('2x0e', [ops.tanh], '1x0e', [ops.sigmoid], '1x1o') @@ -15,6 +18,9 @@ class TestGate: assert gate.irreps_in.dim > 0 assert gate.irreps_out.dim > 0 + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard def test_gate_forward(self): """Test forward propagation""" gate = Gate('1x0e', [ops.tanh], '2x0e', [ops.sigmoid, ops.abs], '2x1o') diff --git a/tests/e3nn/o3/test_irreps.py b/tests/e3nn/o3/test_irreps.py index f382f3541..cfe93092c 100644 --- a/tests/e3nn/o3/test_irreps.py +++ b/tests/e3nn/o3/test_irreps.py @@ -22,6 +22,9 @@ from mindscience.e3nn.o3 import Irrep, Irreps class TestIrrep: """Test cases for Irrep class""" + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard def test_irrep_creation_and_properties(self): """Test Irrep creation, properties and basic operations""" # Test creation with l and p parameters @@ -106,6 +109,9 @@ class TestIrrep: class TestIrreps: """Test cases for Irreps class""" + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + 
@pytest.mark.env_onecard def test_irreps_creation_and_basic_operations(self): """Test Irreps creation and basic operations""" # Test creation from string diff --git a/tests/e3nn/o3/test_norm.py b/tests/e3nn/o3/test_norm.py index d356356c8..84118cbdb 100644 --- a/tests/e3nn/o3/test_norm.py +++ b/tests/e3nn/o3/test_norm.py @@ -9,9 +9,6 @@ from mindscience.e3nn.o3 import Norm, Irreps class TestNorm: """Streamlined tests for Norm class""" - @pytest.mark.level0 - @pytest.mark.platform_arm_ascend910b_training - @pytest.mark.env_onecard def test_norm_creation_and_basic_properties(self): """Test Norm creation with different irreps and basic properties""" # Test basic creation with string irreps @@ -31,9 +28,6 @@ class TestNorm: repr_str = repr(norm2) assert 'Norm' in repr_str - @pytest.mark.level0 - @pytest.mark.platform_arm_ascend910b_training - @pytest.mark.env_onecard def test_norm_forward_pass_comprehensive(self): """Test forward pass with various irrep types and configurations""" # Test scalar irrep (0e) @@ -56,9 +50,6 @@ class TestNorm: expected_mixed = np.array([2.0, 5.0]) # scalar norm + vector norm np.testing.assert_allclose(mixed_output.asnumpy(), expected_mixed, rtol=1e-5) - @pytest.mark.level0 - @pytest.mark.platform_arm_ascend910b_training - @pytest.mark.env_onecard def test_norm_squared_and_dtype_consistency(self): """Test squared parameter and dtype consistency""" # Test squared vs regular norm @@ -75,9 +66,6 @@ class TestNorm: assert output_regular.dtype == float32 assert output_squared.dtype == float32 - @pytest.mark.level0 - @pytest.mark.platform_arm_ascend910b_training - @pytest.mark.env_onecard def test_norm_mathematical_properties_and_edge_cases(self): """Test mathematical properties and edge cases""" norm = Norm('1x1o') @@ -96,9 +84,6 @@ class TestNorm: expected_small = np.sqrt(3) * 1e-10 np.testing.assert_allclose(small_output.asnumpy(), [expected_small], rtol=1e-5) - @pytest.mark.level0 - @pytest.mark.platform_arm_ascend910b_training - @pytest.mark.env_onecard def test_norm_higher_order_and_mixed_parity(self): """Test higher order irreps and mixed parity""" # Test l=2 irrep @@ -115,9 +100,6 @@ class TestNorm: expected_mixed_parity = np.array([2.0, np.sqrt(3.0), 3.0]) np.testing.assert_allclose(mixed_parity_output.asnumpy(), expected_mixed_parity, rtol=1e-5) - @pytest.mark.level0 - @pytest.mark.platform_arm_ascend910b_training - @pytest.mark.env_onecard def test_norm_error_handling(self): """Test error handling for invalid inputs""" norm = Norm('1x1o') diff --git a/tests/e3nn/o3/test_rotation.py b/tests/e3nn/o3/test_rotation.py index f3d581d77..53ed0c410 100644 --- a/tests/e3nn/o3/test_rotation.py +++ b/tests/e3nn/o3/test_rotation.py @@ -30,6 +30,9 @@ from mindscience.e3nn.o3.rotation import ( class TestRotation: """Test class for rotation functions.""" + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard def test_identity_angles(self): """Test identity_angles function comprehensively.""" # Test basic functionality and shapes @@ -100,6 +103,9 @@ class TestRotation: # Check determinant = 1 assert np.allclose(np.linalg.det(mat.asnumpy()), 1.0, atol=1e-6) + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard def test_angle_matrix_conversion(self): """Test angles_to_matrix and matrix_to_angles functions.""" # Test identity conversion diff --git a/tests/e3nn/o3/test_spherical_harmonics.py b/tests/e3nn/o3/test_spherical_harmonics.py index 0d41024d9..f3a78940f 100644 --- 
a/tests/e3nn/o3/test_spherical_harmonics.py +++ b/tests/e3nn/o3/test_spherical_harmonics.py @@ -14,6 +14,7 @@ # ============================================================================ """Test spherical harmonics module.""" +import pytest import numpy as np from mindspore import Tensor, float32 from mindscience.e3nn.o3 import spherical_harmonics, SphericalHarmonics @@ -22,6 +23,9 @@ from mindscience.e3nn.o3 import spherical_harmonics, SphericalHarmonics class TestSphericalHarmonicsFunction: """Test spherical_harmonics function.""" + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard def test_core_functionality(self): """Test core spherical harmonics functionality including degrees and normalization.""" x = Tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], float32) @@ -81,6 +85,9 @@ class TestSphericalHarmonicsFunction: class TestSphericalHarmonicsClass: """Test the SphericalHarmonics class.""" + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard def test_class_initialization_and_forward(self): """Test class initialization and forward computation.""" # Test initialization diff --git a/tests/e3nn/o3/test_tensor_product.py b/tests/e3nn/o3/test_tensor_product.py index 3d57f2308..e5793ee51 100644 --- a/tests/e3nn/o3/test_tensor_product.py +++ b/tests/e3nn/o3/test_tensor_product.py @@ -14,6 +14,7 @@ # ============================================================================ """Test cases for tensor_product module.""" +import pytest import numpy as np from mindspore import Tensor, float32 @@ -28,6 +29,9 @@ from mindscience.e3nn.o3.irreps import Irreps class TestTensorProduct: """Test class for TensorProduct and related classes.""" + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard def test_tensor_product_basic(self): """Test basic TensorProduct functionality.""" # Test standard tensor product @@ -41,6 +45,9 @@ class TestTensorProduct: output = tp(x1, x2) assert output.shape == (2, tp.irreps_out.dim) + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard def test_full_tensor_product(self): """Test FullTensorProduct functionality.""" # Test full tensor product diff --git a/tests/e3nn/o3/test_wigner.py b/tests/e3nn/o3/test_wigner.py index d3db13129..a1d0b5f45 100644 --- a/tests/e3nn/o3/test_wigner.py +++ b/tests/e3nn/o3/test_wigner.py @@ -29,6 +29,9 @@ from mindscience.e3nn.o3.wigner import ( class TestWigner: """Test wigner module functions.""" + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard def test_change_basis_real_to_complex(self): """Test change_basis_real_to_complex function.""" # Test basic functionality @@ -89,6 +92,9 @@ class TestWigner: identity = np.eye(3) np.testing.assert_allclose((d_matrix @ d_matrix.T).asnumpy(), identity, atol=1e-5) + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_onecard def test_wigner_3j(self): """Test wigner_3j function.""" # Test basic functionality -- Gitee
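Editor's note, appended after the series: PATCH 21 tags the e3nn test cases with the
custom pytest marks `level0`, `platform_arm_ascend910b_training`, and `env_onecard`,
so a marker expression is the natural way to exercise just these Ascend cases. The
command below is a hedged sketch, not part of the patches themselves: it assumes the
marks are registered (for example in a `pytest.ini` or `conftest.py` that this series
does not show) and that the tests sit under `tests/e3nn` as the diff paths indicate.

```bash
# Hypothetical invocation: run only the single-card Ascend 910B level0 cases
# tagged in this series. Registering the custom marks beforehand avoids
# PytestUnknownMarkWarning for unknown markers.
pytest tests/e3nn -m "level0 and platform_arm_ascend910b_training and env_onecard" -v
```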