From b1b659d559df350702e89f60a15a6be1f6a1bf76 Mon Sep 17 00:00:00 2001
From: birfied
Date: Thu, 18 Sep 2025 10:59:31 +0800
Subject: [PATCH 1/2] move from MindChemistry to MindChem

---
 .../applications/crystalflow/.gitignore       |   0
 MindChem/applications/diffcsp/README.md       | 127 +++++++
 .../applications/diffcsp/compute_metric.py    | 327 ++++++++++++++++++
 MindChem/applications/diffcsp/config.yaml     |  41 +++
 .../applications/diffcsp/data/crysloader.py   |   0
 .../applications/diffcsp/data/data_utils.py   |   0
 MindChem/applications/diffcsp/data/dataset.py | 135 ++++++++
 .../applications/diffcsp/evaluate.py          |   0
 .../applications/diffcsp/models/cspnet.py     |   0
 .../applications/diffcsp/models/diff_utils.py |   0
 .../applications/diffcsp/models/diffusion.py  |   0
 .../diffcsp/models/infer_utils.py             |   0
 .../diffcsp/models/train_utils.py             |   0
 .../applications/diffcsp/requirement.txt      |   0
 MindChem/applications/diffcsp/train.py        | 208 +++++++++++
 15 files changed, 838 insertions(+)
 rename {MindChemistry => MindChem}/applications/crystalflow/.gitignore (100%)
 create mode 100644 MindChem/applications/diffcsp/README.md
 create mode 100644 MindChem/applications/diffcsp/compute_metric.py
 create mode 100644 MindChem/applications/diffcsp/config.yaml
 rename {MindChemistry => MindChem}/applications/diffcsp/data/crysloader.py (100%)
 rename {MindChemistry => MindChem}/applications/diffcsp/data/data_utils.py (100%)
 create mode 100644 MindChem/applications/diffcsp/data/dataset.py
 rename {MindChemistry => MindChem}/applications/diffcsp/evaluate.py (100%)
 rename {MindChemistry => MindChem}/applications/diffcsp/models/cspnet.py (100%)
 rename {MindChemistry => MindChem}/applications/diffcsp/models/diff_utils.py (100%)
 rename {MindChemistry => MindChem}/applications/diffcsp/models/diffusion.py (100%)
 rename {MindChemistry => MindChem}/applications/diffcsp/models/infer_utils.py (100%)
 rename {MindChemistry => MindChem}/applications/diffcsp/models/train_utils.py (100%)
 rename {MindChemistry => MindChem}/applications/diffcsp/requirement.txt (100%)
 create mode 100644 MindChem/applications/diffcsp/train.py

diff --git a/MindChemistry/applications/crystalflow/.gitignore b/MindChem/applications/crystalflow/.gitignore
similarity index 100%
rename from MindChemistry/applications/crystalflow/.gitignore
rename to MindChem/applications/crystalflow/.gitignore
diff --git a/MindChem/applications/diffcsp/README.md b/MindChem/applications/diffcsp/README.md
new file mode 100644
index 000000000..858aa417c
--- /dev/null
+++ b/MindChem/applications/diffcsp/README.md
@@ -0,0 +1,127 @@
+
+# Model Name
+
+> DiffCSP
+
+## Introduction
+
+> DiffCSP is a diffusion-based deep learning framework for crystal structure prediction, a fundamental scientific challenge. Its core idea is to recast the search for stable crystal structures as a generative problem: by learning the distributions present in large collections of known crystals, the model can directly and rapidly generate plausible three-dimensional atomic structures (lattice plus atomic coordinates) from a material's chemical composition alone (atom types and ratios). Compared with traditional approaches that rely on heavy quantum-mechanical computation, DiffCSP's key innovation is an SE(3)-equivariant graph neural network combined with periodic boundary conditions, which ensures that the generated structures strictly respect physical symmetries. It can therefore explore a material's polymorphism with very high efficiency, offering a powerful tool for accelerating the discovery and design of new materials.
+
+## Requirements
+
+> 1. Install `mindspore(2.3.0)`
+> 2. Install the dependencies: `pip install -r requirement.txt`
+
+## Quick Start
+
+> 1. Download the Mindchemistry/mindchemistry package into the current directory
+> 2. Download the corresponding dataset from the [dataset link](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/)
+> 3. Install the dependencies: `pip install -r requirement.txt`
+> 4. Training command: `python train.py`
+> 5. Prediction command: `python evaluate.py`
+> 6. Evaluation command: `python compute_metric.py`
+> 7. The evaluation results are written to a JSON file under the `metric_dir` path specified in `config.yaml`
+
+### Code Directory Structure
+
+```txt
+diffcsp
+ │ README.md           README file
+ │ config.yaml         configuration file
+ │ train.py            training entry script
+ │ evaluate.py         inference entry script
+ │ compute_metric.py   evaluation entry script
+ │ requirement.txt     environment dependencies
+ │
+ └─data
+     data_utils.py     dataset processing utilities
+     dataset.py        dataset reading
+     crysloader.py     dataset loader
+ └─models
+     cspnet.py         GNN-based denoiser module
+     diffusion.py      diffusion model module
+     diff_utils.py     utility module
+     infer_utils.py    inference utilities
+     train_utils.py    training utilities
+
+```
+
+## Downloading the Dataset
+
+From the [dataset link](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/), download the required dataset folders together with the dataset property file dataset_prop.txt, and place them in the dataset folder under the current path (create it manually if it does not exist). Reference layout:
+
+```txt
+diffcsp
+ ...
+ └─dataset
+     perov_5            perovskite dataset
+     carbon_24          carbon crystal dataset
+     mp_20              MP dataset with at most 20 atoms per unit cell
+     mpts_52            MP dataset with at most 52 atoms per unit cell
+     dataset_prop.txt   dataset property file
+ ...
+```
+
+## Training Process
+
+### Training
+
+Download the Mindchemistry/mindchemistry package into the current directory.
+
+Edit the config file to set the training parameters:
+> 1. Set the training dataset; see the dataset field
+> 2. Configure the denoiser model; see the model field
+> 3. Set where the trained weights are saved by changing the train.ckpt_dir directory name and the checkpoint.last_path checkpoint file name
+> 4. Other training settings are under the train field
+
+```bash
+pip install -r requirement.txt
+python train.py
+```
+
+### Inference
+
+Write the checkpoint path into checkpoint.last_path in the config file. Pretrained models can be obtained from the [pretrained model link](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/pre-train).
+
+Adjust the inference parameters via the test field in the config file, in particular test.num_eval, which **determines how many samples are generated for each composition** and matters for the subsequent evaluation stage.
+
+```bash
+python evaluate.py
+```
+
+The crystals produced by inference are saved to the file specified by test.eval_save_path.
+
+The file stores a Python dictionary with the following format:
+
+```python
+{
+    'pred': [
+        [crystal A sample 1, crystal A sample 2, crystal A sample 3, ... crystal A sample num_eval],
+        [crystal B sample 1, crystal B sample 2, crystal B sample 3, ... crystal B sample num_eval]
+        ...
+    ],
+    'gt': [
+        crystal A ground truth,
+        crystal B ground truth,
+        ...
+    ]
+}
+```
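+
+A minimal sketch of how to inspect the saved predictions (it assumes the default `eval_save_path` from `config.yaml`):
+
+```python
+import pickle
+
+with open('./ckpt/mp_20/predict_crys.pkl', 'rb') as f:
+    eval_dict = pickle.load(f)
+
+print(len(eval_dict['gt']))       # number of ground-truth crystals
+print(len(eval_dict['pred'][0]))  # num_eval samples for the first composition
+```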
+
+### Evaluation
+
+Write the path of the crystal file produced by inference into test.eval_save_path in the config file.
+
+Make sure num_eval is equal to or smaller than the number of samples generated per composition during inference. For example, if num_eval was set to 1 during inference, it can only be set to 1 for evaluation; if it was set to 20 during inference, any value from 1 to 20 can be used for evaluation.
+
+Set the save path for the evaluation results via the test.metric_dir field in the config file.
+
+```bash
+python compute_metric.py
+```
+
+Example of the resulting metrics file:
+
+```json
+{"match_rate": 0.985997357992074, "rms_dist": 0.013073775170360118}
+```
diff --git a/MindChem/applications/diffcsp/compute_metric.py b/MindChem/applications/diffcsp/compute_metric.py
new file mode 100644
index 000000000..df5539aad
--- /dev/null
+++ b/MindChem/applications/diffcsp/compute_metric.py
@@ -0,0 +1,327 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""compute metric file"""
+import itertools
+import json
+import os
+import pickle
+from collections import Counter
+from pathlib import Path
+import argparse
+import yaml
+
+import numpy as np
+from matminer.featurizers.composition.composite import ElementProperty
+from matminer.featurizers.site.fingerprint import CrystalNNFingerprint
+from p_tqdm import p_map
+from pymatgen.analysis.structure_matcher import StructureMatcher
+from pymatgen.core.composition import Composition
+from pymatgen.core.lattice import Lattice
+from pymatgen.core.structure import Structure
+import smact
+from smact.screening import pauling_test
+from tqdm import trange
+
+from models.infer_utils import chemical_symbols
+
+matcher = StructureMatcher(stol=0.5, angle_tol=10, ltol=0.3)
+crystalnn_fp = CrystalNNFingerprint.from_preset("ops")
+comp_fp = ElementProperty.from_preset('magpie')
+
+
+def smact_validity(comp, count, use_pauling_test=True, include_alloys=True):
+    """Smact validity. See details in the paper Crystal Diffusion Variational Autoencoder and
+    its codebase.
+    """
+    elem_symbols = tuple([chemical_symbols[elem] for elem in comp])
+    space = smact.element_dictionary(elem_symbols)
+    smact_elems = [e[1] for e in space.items()]
+    electronegs = [e.pauling_eneg for e in smact_elems]
+    ox_combos = [e.oxidation_states for e in smact_elems]
+    if len(set(elem_symbols)) == 1:
+        return True
+    if include_alloys:
+        is_metal_list = [elem_s in smact.metals for elem_s in elem_symbols]
+        if all(is_metal_list):
+            return True
+
+    threshold = np.max(count)
+    oxn = 1
+    for oxc in ox_combos:
+        oxn *= len(oxc)
+    if oxn > 1e7:
+        return False
+    for ox_states in itertools.product(*ox_combos):
+        stoichs = [(c,) for c in count]
+        # Test for charge balance
+        cn_e, _ = smact.neutral_ratios(ox_states,
+                                       stoichs=stoichs,
+                                       threshold=threshold)
+        # Electronegativity test
+        if cn_e:
+            if use_pauling_test:
+                try:
+                    electroneg_ok = pauling_test(ox_states, electronegs)
+                except TypeError:
+                    # if no electronegativity data, assume it is okay
+                    electroneg_ok = True
+            else:
+                electroneg_ok = True
+            if electroneg_ok:
+                return True
+    return False
+
+
+def structure_validity(crystal, cutoff=0.5):
+    """Structure validity. See details in the paper Crystal Diffusion Variational Autoencoder and
+    its codebase.
+    """
+    dist_mat = crystal.distance_matrix
+    # Pad diagonal with a large number
+    dist_mat = dist_mat + np.diag(np.ones(dist_mat.shape[0]) * (cutoff + 10.))
+    if dist_mat.min() < cutoff or crystal.volume < 0.1 or max(
+            crystal.lattice.abc) > 40:
+        return False
+
+    return True
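+
+# Example (illustrative sketch): a structure whose two sites nearly overlap fails
+# the distance check, e.g.
+#
+#   s = Structure(Lattice.cubic(4.0), ["Si", "Si"],
+#                 [[0.0, 0.0, 0.0], [0.0, 0.0, 0.05]])
+#   structure_validity(s)  # False: the sites are ~0.2 A apart, below the 0.5 A cutoff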
+ """ + + def __init__(self, crys_array_dict): + self.frac_coords = crys_array_dict['frac_coords'] + self.atom_types = crys_array_dict['atom_types'] + self.lengths = crys_array_dict['lengths'] + self.angles = crys_array_dict['angles'] + self.dict = crys_array_dict + if len(self.atom_types.shape) > 1: + self.dict['atom_types'] = np.argmax(self.atom_types, axis=-1) + 1 + self.atom_types = np.argmax(self.atom_types, axis=-1) + 1 + + self.get_structure() + self.get_composition() + self.get_validity() + self.get_fingerprints() + + def get_structure(self): + """get_structure + """ + if min(self.lengths.tolist()) < 0: + self.constructed = False + self.invalid_reason = 'non_positive_lattice' + if np.isnan(self.lengths).any() or np.isnan( + self.angles).any() or np.isnan(self.frac_coords).any(): + self.constructed = False + self.invalid_reason = 'nan_value' + else: + try: + self.structure = Structure(lattice=Lattice.from_parameters( + *(self.lengths.tolist() + self.angles.tolist())), + species=self.atom_types, + coords=self.frac_coords, + coords_are_cartesian=False) + self.constructed = True + # pylint: disable=W0703 + except Exception: + self.constructed = False + self.invalid_reason = 'construction_raises_exception' + if self.structure.volume < 0.1: + self.constructed = False + self.invalid_reason = 'unrealistically_small_lattice' + + def get_composition(self): + """get_composition + """ + elem_counter = Counter(self.atom_types) # pylint: disable=E1121 + composition = [(elem, elem_counter[elem]) + for elem in sorted(elem_counter.keys())] + elems, counts = list(zip(*composition)) + counts = np.array(counts) + counts = counts / np.gcd.reduce(counts) + self.elems = elems + self.comps = tuple(counts.astype('int').tolist()) + + def get_validity(self): + """get_validity + """ + self.comp_valid = smact_validity(self.elems, self.comps) + if self.constructed: + self.struct_valid = structure_validity(self.structure) + else: + self.struct_valid = False + self.valid = self.comp_valid and self.struct_valid + + def get_fingerprints(self): + """get_fingerprints + """ + elem_counter = Counter(self.atom_types) # pylint: disable=E1121 + comp = Composition(elem_counter) + self.comp_fp = comp_fp.featurize(comp) + try: + site_fps = [ + crystalnn_fp.featurize(self.structure, i) + for i in range(len(self.structure)) + ] + # pylint: disable=W0703 + except Exception: + # counts crystal as invalid if fingerprint cannot be constructed. + self.valid = False + self.comp_fp = None + self.struct_fp = None + return + self.struct_fp = np.array(site_fps).mean(axis=0) + + +def get_rms(pred_struc_list, gt_struc: Structure, num_eval, np_list): + """Calculate the rms distance between the ground truth and predicted crystal structures. + + Args: + pred_struc_list (List[Structure]): The crystals generated by diffution model + in the form of Structure. + gt_struc (Structure): The ground truth crystal. + num_eval (int): Specify that the first N items in the predicted List of crystal structures + participate in the evaluationo. + np_list (List[Dict]): The crystals generated by diffution model in the form of Dict. 
+ """ + + def process_one(pred_struc: Structure): + try: + if not pred_struc.is_valid(): + return None + rms_dist = matcher.get_rms_dist(pred_struc, gt_struc) + rms_dist = None if rms_dist is None else rms_dist[0] + tune_rms = rms_dist + # pylint: disable=W0703 + except Exception: + tune_rms = None + return tune_rms + + min_rms = None + min_struc = None + for i, struct in enumerate(pred_struc_list): + if i == num_eval: + break + rms = process_one(struct) + if rms is not None and (min_rms is None or min_rms > rms): + min_rms = rms + min_struc = np_list[i] + return min_rms, min_struc + + +def get_struc_from_np_list(np_list): + """convert the crystal in the form of Dict to pymatgen.Structure + """ + result = [] + for cry_array in np_list: + try: + struct = Structure(lattice=Lattice.from_parameters( + *(cry_array['lengths'].tolist() + + cry_array['angles'].tolist())), + species=cry_array['atom_types'], + coords=cry_array['frac_coords'], + coords_are_cartesian=False) + # pylint: disable=W0703 + except Exception: + print('Warning: One anomalous crystal structure has captured and removed. ') + struct = None + + result.append(struct) + return result + +def main(args): + """main + """ + with open(args.config, 'r') as stream: + config = yaml.safe_load(stream) + + eval_file = config['test']['eval_save_path'] + num_eval = config['test']['num_eval'] + output_path = config['test']['metric_dir'] + + with open(eval_file, 'rb') as f: + eval_dict = pickle.load(f) + + pred_list = eval_dict['pred'] + gt_list = eval_dict['gt'] + gt_list = get_struc_from_np_list(gt_list) + rms = [] + + # calculate rmsd + for i in trange(len(gt_list)): + pred_struc = get_struc_from_np_list(pred_list[i]) + gt_struc = gt_list[i] + rms_single, struc_single = get_rms(pred_struc, gt_struc, num_eval, + pred_list[i]) + rms.append((rms_single, struc_single)) + + rms, struc_list = zip(*rms) + + # Remove the ones with RMSD as None, and store the valid structures in the list valid_crys. + rms_np = [] + valid_crys = [] + for i, rms_per in enumerate(rms): + if rms_per is not None: + rms_np.append(rms_per) + valid_crys.append(struc_list[i]) + + # Conduct rigorous structural verification, specifically through verification using the Crystal class. + print('Using the Crystal class for validity checks') + valid_list = p_map(lambda x: Crystal(x).valid, valid_crys) + rms_np_strict = [] + for i, is_valid in enumerate(valid_list): + if is_valid: + rms_np_strict.append(rms_np[i]) + + rms_np = np.array(rms_np_strict) + rms_valid_index = np.array([x is not None for x in rms_np_strict]) + + match_rate = rms_valid_index.sum() / len(gt_list) + rms = rms_np[rms_valid_index].mean() + + print('match_rate: ', match_rate) + print('rms: ', rms) + + all_metrics = {'match_rate': match_rate, 'rms_dist': rms} + + if Path(output_path).exists(): + metrics_out_file = f'eval_metrics_{num_eval}.json' + metrics_out_file = os.path.join(output_path, metrics_out_file) + + # only overwrite metrics computed in the new run. 
+ if Path(metrics_out_file).exists(): + with open(metrics_out_file, 'r') as f: + written_metrics = json.load(f) + if isinstance(written_metrics, dict): + written_metrics.update(all_metrics) + else: + with open(metrics_out_file, 'w') as f: + json.dump(all_metrics, f) + if isinstance(written_metrics, dict): + with open(metrics_out_file, 'w') as f: + json.dump(written_metrics, f) + else: + with open(metrics_out_file, 'w') as f: + json.dump(all_metrics, f) + else: + print('Warning: The metric result file path is not specified') + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', default='config.yaml') + main_args = parser.parse_args() + main(main_args) diff --git a/MindChem/applications/diffcsp/config.yaml b/MindChem/applications/diffcsp/config.yaml new file mode 100644 index 000000000..6c2899a15 --- /dev/null +++ b/MindChem/applications/diffcsp/config.yaml @@ -0,0 +1,41 @@ +dataset: + data_name: 'mp_20' + train: + path: './dataset/mp_20/train.csv' + save_path: './dataset/mp_20/train.npy' + val: + path: './dataset/mp_20/val.csv' + save_path: './dataset/mp_20/val.npy' + test: + path: './dataset/mp_20/test.csv' + save_path: './dataset/mp_20/test.npy' + +model: + # For dataset carbon, mp, mpts + hidden_dim: 512 + num_layers: 6 + num_freqs: 128 + # # For dataset perov + # hidden_dim: 256 + # num_layers: 4 + # num_freqs: 10 + +train: + ckpt_dir: "./ckpt/mp_20" + # 3500, 4000, 1000, 1000 epochs for Perov-5, Carbon-24, MP-20 and MPTS-52 respectively. + epoch_size: 1000 + # 512, 512, 128, 128 for Perov-5, Carbon-24, MP-20 and MPTS-52 respectively. + batch_size: 256 + seed: 1234 + +checkpoint: + last_path: "./ckpt/mp_20/last_test.ckpt" + +test: + # 1024 for perov, 512 for carbon and mp, 256 for mpts + batch_size: 512 + num_eval: 1 + # 1e-5 for mp and mpts, 5e-7 for perov, 5e-6 for carbon num_eval=1 and 5e-7 for carbon num_eval=20 + step_lr: 1e-5 + eval_save_path: './ckpt/mp_20/predict_crys.pkl' + metric_dir: './ckpt/mp_20/' diff --git a/MindChemistry/applications/diffcsp/data/crysloader.py b/MindChem/applications/diffcsp/data/crysloader.py similarity index 100% rename from MindChemistry/applications/diffcsp/data/crysloader.py rename to MindChem/applications/diffcsp/data/crysloader.py diff --git a/MindChemistry/applications/diffcsp/data/data_utils.py b/MindChem/applications/diffcsp/data/data_utils.py similarity index 100% rename from MindChemistry/applications/diffcsp/data/data_utils.py rename to MindChem/applications/diffcsp/data/data_utils.py diff --git a/MindChem/applications/diffcsp/data/dataset.py b/MindChem/applications/diffcsp/data/dataset.py new file mode 100644 index 000000000..964a9a957 --- /dev/null +++ b/MindChem/applications/diffcsp/data/dataset.py @@ -0,0 +1,135 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""dataset file"""
+import os
+from pathlib import Path
+
+import numpy as np
+
+from data.data_utils import StandardScaler, preprocess
+
+def fullconnect_dataset(name,
+                        path,
+                        niggli=True,
+                        primitive=False,
+                        graph_method='none',
+                        preprocess_workers=30,
+                        save_path='',
+                        nrows=-1):
+    """
+    Read crystal data from a CSV file and convert each crystal into a fully connected graph,
+    where the nodes represent atoms within the unit cell and the edges connect every pair of nodes.
+
+    Args:
+        name (str): The name of the dataset, mainly used to read the dataset
+            properties in './dataset/dataset_prop.txt'.
+            It does not matter for the crystal structure prediction task.
+            Choices: [perov_5, carbon_24, mp_20, mpts_52].
+            Users can also create custom datasets by modifying
+            './dataset/dataset_prop.txt'.
+        path (str): The path of the csv file of the dataset.
+        niggli (bool): Whether to use the Niggli algorithm to
+            preprocess the choice of lattice. Default: ``True``.
+        primitive (bool): Whether to represent the crystal in the primitive cell. Default:
+            ``False``.
+        graph_method (str): If 'crystalnn', construct the graph by the crystalnn
+            algorithm, which mainly affects the construction of edges.
+            If 'none', do not construct any edges. Default: ``none``.
+        preprocess_workers (int): The number of cpus used for
+            preprocessing the crystals. Default: ``30``.
+        save_path (str): The path for saving the preprocessed data,
+            so that the dataset can be loaded more quickly next time.
+        nrows (int): If nrows > 0, read the first 'nrows' lines of the csv file.
+            If nrows = -1, read the whole csv file.
+            This arg is mainly for debugging, to quickly load a few crystals.
+    Returns:
+        x (list): List of atom types. Shape of each element (a numpy array): (num_atoms, 1)
+        frac_coord_list (list): List of fractional coordinates of atoms.
+            Shape of each element (a numpy array): (num_atoms, 3)
+        edge_attr (list): List of numpy arrays filled with ones,
+            only used to better construct the dataloader,
+            without numerical significance. Shape of each element
+            (a numpy array): (num_edges, 3)
+        edge_index (list): List of indices of the beginning and end
+            of edges. Each element is composed as [src, dst], where
+            src and dst are numpy arrays with shape (num_edges,).
+        lengths_list (list): List of lattice lengths. Shape of
+            each element (a numpy array): (3,)
+        angles_list (list): List of lattice angles. Shape of
+            each element (a numpy array): (3,)
+        labels (list): List of crystal properties. Shape of
+            each element (a numpy array): (1,)
+    """
+    x = []
+    frac_coord_list = []
+    edge_index = []
+    edge_attr = []
+    labels = []
+    lengths_list = []
+    angles_list = []
+
+    if Path('./dataset/dataset_prop.txt').exists():
+        with open('./dataset/dataset_prop.txt', 'r') as file:
+            data = file.read()
+            # pylint: disable=W0123
+            scalar_dict = eval(data)
+    else:
+        scalar_dict = {}
+
+    if name in scalar_dict.keys():
+        prop = scalar_dict[name]['prop']
+        scaler = StandardScaler(scalar_dict[name]['scaler.means'],
+                                scalar_dict[name]['scaler.stds'])
+    else:
+        print('No dataset property is specified, so no property reading is performed')
+        prop = "None"
+        scaler = None
+
+    if os.path.exists(save_path):
+        cached_data = np.load(save_path, allow_pickle=True)
+    else:
+        cached_data = preprocess(path,
+                                 preprocess_workers,
+                                 niggli=niggli,
+                                 primitive=primitive,
+                                 graph_method=graph_method,
+                                 prop_list=[prop],
+                                 nrows=nrows)
+
+        np.save(save_path, cached_data)
+
+    for idx in range(len(cached_data)):
+        data_dict = cached_data[idx]
+        (frac_coords, atom_types, lengths, angles, _, _,
+         num_atoms) = data_dict['graph_arrays']
+
+        indices = np.arange(num_atoms)
+        dst, src = np.meshgrid(indices, indices)
+        src = src.reshape(-1)
+        dst = dst.reshape(-1)
+
+        x.append(atom_types.reshape(-1, 1))
+        frac_coord_list.append(frac_coords)
+        edge_index.append(np.array([src, dst]))
+        edge_attr.append(np.ones((num_atoms * num_atoms, 3)))
+        lengths_list.append(lengths)
+        angles_list.append(angles)
+        if scaler is not None:
+            labels.append(scaler.transform(data_dict[prop]))
+        else:
+            labels.append(0.0)
+
+    return x, frac_coord_list, edge_attr, edge_index, lengths_list, angles_list, labels
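+
+# Example (illustrative sketch): loading the MP-20 training split with the default
+# paths from config.yaml:
+#
+#   x, fracs, edge_attr, edge_index, lengths, angles, labels = fullconnect_dataset(
+#       name='mp_20',
+#       path='./dataset/mp_20/train.csv',
+#       save_path='./dataset/mp_20/train.npy')
+#   x[0].shape  # (num_atoms, 1): atom types of the first crystal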
diff --git a/MindChemistry/applications/diffcsp/evaluate.py b/MindChem/applications/diffcsp/evaluate.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/evaluate.py
rename to MindChem/applications/diffcsp/evaluate.py
diff --git a/MindChemistry/applications/diffcsp/models/cspnet.py b/MindChem/applications/diffcsp/models/cspnet.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/models/cspnet.py
rename to MindChem/applications/diffcsp/models/cspnet.py
diff --git a/MindChemistry/applications/diffcsp/models/diff_utils.py b/MindChem/applications/diffcsp/models/diff_utils.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/models/diff_utils.py
rename to MindChem/applications/diffcsp/models/diff_utils.py
diff --git a/MindChemistry/applications/diffcsp/models/diffusion.py b/MindChem/applications/diffcsp/models/diffusion.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/models/diffusion.py
rename to MindChem/applications/diffcsp/models/diffusion.py
diff --git a/MindChemistry/applications/diffcsp/models/infer_utils.py b/MindChem/applications/diffcsp/models/infer_utils.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/models/infer_utils.py
rename to MindChem/applications/diffcsp/models/infer_utils.py
diff --git a/MindChemistry/applications/diffcsp/models/train_utils.py b/MindChem/applications/diffcsp/models/train_utils.py
similarity index 100%
rename from MindChemistry/applications/diffcsp/models/train_utils.py
rename to MindChem/applications/diffcsp/models/train_utils.py
diff --git a/MindChemistry/applications/diffcsp/requirement.txt b/MindChem/applications/diffcsp/requirement.txt
similarity index 100%
rename from MindChemistry/applications/diffcsp/requirement.txt
rename to MindChem/applications/diffcsp/requirement.txt
diff --git a/MindChem/applications/diffcsp/train.py b/MindChem/applications/diffcsp/train.py
new file mode 100644
index 000000000..fd21606a6
--- /dev/null
+++ b/MindChem/applications/diffcsp/train.py
@@ -0,0 +1,208 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""train file"""
+import os
+import time
+import logging
+import argparse
+import yaml
+import numpy as np
+import mindspore as ms
+from mindspore import nn, set_seed
+from mindspore.amp import all_finite
+from mindchemistry.graph.loss import L2LossMask
+from models.cspnet import CSPNet
+from models.diffusion import CSPDiffusion
+from models.train_utils import LossRecord
+from data.dataset import fullconnect_dataset
+from data.crysloader import Crysloader as DataLoader
+
+logging.basicConfig(level=logging.INFO)
+
+def parse_args():
+    '''Parse input args'''
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config', default='config.yaml', help="The config file path")
+    parser.add_argument('--device_id', type=int, default=0,
+                        help="ID of the target device")
+    parser.add_argument('--device_target', type=str, default='Ascend', choices=["GPU", "Ascend"],
+                        help="The target device to run, support 'Ascend', 'GPU'")
+    input_args = parser.parse_args()
+    return input_args

+def main():
+    args = parse_args()
+    ms.set_context(device_target=args.device_target, device_id=args.device_id)
+
+    with open(args.config, 'r') as stream:
+        config = yaml.safe_load(stream)
+
+    ckpt_dir = config['train']["ckpt_dir"]
+
+    if not os.path.exists(ckpt_dir):
+        os.makedirs(ckpt_dir)
+
+    set_seed(config['train']["seed"])
+
+    batch_size_max = config['train']['batch_size']
+
+    cspnet = CSPNet(num_layers=config['model']['num_layers'], hidden_dim=config['model']['hidden_dim'],
+                    num_freqs=config['model']['num_freqs'])
+
+    if os.path.exists(config['checkpoint']['last_path']):
+        logging.info("loading from existing checkpoint ...")
+        param_dict = ms.load_checkpoint(config['checkpoint']['last_path'])
+        ms.load_param_into_net(cspnet, param_dict)
+        logging.info("finished loading from existing checkpoint")
+    else:
+        logging.info("Starting new training process")
+
+    diffcsp = CSPDiffusion(cspnet)
+
+    model_parameters = filter(lambda p: p.requires_grad, diffcsp.get_parameters())
+    params = sum(np.prod(p.shape) for p in model_parameters)
+    logging.info("The model you built has %s parameters.", params)
+
+    optimizer = nn.Adam(params=diffcsp.trainable_params())
+    loss_func_mse = L2LossMask(reduction='mean')
+
+    def forward(atom_types_step, frac_coords_step, _, lengths_step, angles_step, edge_index_step, batch_node2graph, \
+                node_mask_step, edge_mask_step, batch_mask, node_num_valid, batch_size_valid):
+        """forward"""
+        pred_l, rand_l, pred_x, tar_x = diffcsp(batch_size_valid, atom_types_step, lengths_step,
+                                                angles_step, frac_coords_step, batch_node2graph, edge_index_step,
+                                                node_mask_step, edge_mask_step, batch_mask)
+        mseloss_l = loss_func_mse(pred_l, rand_l, mask=batch_mask, num=batch_size_valid)
+        mseloss_x = loss_func_mse(pred_x, tar_x, mask=node_mask_step, num=node_num_valid)
+        mseloss = mseloss_l + mseloss_x
+
+        return mseloss, mseloss_l, mseloss_x
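+
+    # Note: with has_aux=True, ms.value_and_grad differentiates only the first
+    # output (mseloss); mseloss_l and mseloss_x are carried along for logging.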
+    backward = ms.value_and_grad(forward, None, weights=diffcsp.trainable_params(), has_aux=True)
+
+    @ms.jit
+    def train_step(atom_types_step, frac_coords_step, property_step, lengths_step, angles_step,
+                   edge_index_step, batch_node2graph, node_mask_step, edge_mask_step, batch_mask,
+                   node_num_valid, batch_size_valid):
+        """train step"""
+        (mseloss, mseloss_l, mseloss_x), grads = backward(atom_types_step, frac_coords_step, property_step,
+                                                          lengths_step, angles_step, edge_index_step, batch_node2graph,
+                                                          node_mask_step, edge_mask_step, batch_mask, node_num_valid,
+                                                          batch_size_valid)
+
+        is_finite = all_finite(grads)
+        if is_finite:
+            optimizer(grads)
+
+        return mseloss, is_finite, mseloss_l, mseloss_x
+
+    @ms.jit
+    def eval_step(atom_types_step, frac_coords_step, property_step, lengths_step, angles_step,
+                  edge_index_step, batch_node2graph,
+                  node_mask_step, edge_mask_step, batch_mask, node_num_valid, batch_size_valid):
+        """eval step"""
+        mseloss, mseloss_l, mseloss_x = forward(atom_types_step, frac_coords_step, property_step, lengths_step,
+                                                angles_step, edge_index_step, batch_node2graph,
+                                                node_mask_step, edge_mask_step, batch_mask, node_num_valid,
+                                                batch_size_valid)
+        return mseloss, mseloss_l, mseloss_x
+
+    epoch = 0
+    epoch_size = config['train']["epoch_size"]
+
+    logging.info("Start to initialise train_loader")
+    train_dataset = fullconnect_dataset(name=config['dataset']["data_name"], path=config['dataset']["train"]["path"],
+                                        save_path=config['dataset']["train"]["save_path"])
+    train_loader = DataLoader(batch_size_max, *train_dataset, shuffle_dataset=True)
+    logging.info("Start to initialise eval_loader")
+    val_dataset = fullconnect_dataset(name=config['dataset']["data_name"], path=config['dataset']["val"]["path"],
+                                      save_path=config['dataset']["val"]["save_path"])
+    eval_loader = DataLoader(batch_size_max, *val_dataset,
+                             dynamic_batch_size=False, shuffle_dataset=True)
+
+    while epoch < epoch_size:
+        epoch_starttime = time.time()
+
+        train_mseloss_record = LossRecord()
+        eval_mseloss_record = LossRecord()
+
+        #################################################### train ###################################################
+        logging.info("+++++++++++++++ start training +++++++++++++++++++++")
+        diffcsp.set_train(True)
+
+        starttime = time.time()
+        record_iter = 0
+        for atom_types_batch, frac_coords_batch, property_batch, lengths_batch, angles_batch,\
+            edge_index_batch, batch_node2graph_, node_mask_batch, edge_mask_batch, batch_mask_batch,\
+            node_num_valid_, batch_size_valid_ in train_loader:
+
+            result = train_step(atom_types_batch, frac_coords_batch, property_batch,
+                                lengths_batch, angles_batch, edge_index_batch, batch_node2graph_,
+                                node_mask_batch, edge_mask_batch, batch_mask_batch, node_num_valid_,
+                                batch_size_valid_)
+
+            mseloss_step, _, mseloss_l_, mseloss_x_ = result
+
+            if record_iter % 50 == 0:
+                logging.info("==============================step: %s ,epoch: %s", train_loader.step - 1, epoch)
+                logging.info("learning rate: %s", optimizer.learning_rate.value())
+                logging.info("train mse loss: %s", mseloss_step)
+                logging.info("train mse_lattice loss: %s", mseloss_l_)
+                logging.info("train mse_coords loss: %s", mseloss_x_)
+                starttime0 = starttime
+                starttime = time.time()
+                logging.info("training time: %s", starttime - starttime0)
+
+            record_iter += 1
+
+            train_mseloss_record.update(mseloss_step)
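+            # LossRecord presumably keeps a running average; its .avg is logged
+            # once per epoch below.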
+
+        #################################################### finish train ########################################
+        epoch_endtime = time.time()
+        logging.info("epoch %s running time: %s", epoch, epoch_endtime - epoch_starttime)
+        logging.info("epoch %s average train mse loss: %s", epoch, train_mseloss_record.avg)
+
+        ms.save_checkpoint(diffcsp.decoder, config['checkpoint']['last_path'])
+
+        if epoch % 5 == 0:
+            #################################################### validation ##########################################
+            logging.info("+++++++++++++++ start validation +++++++++++++++++++++")
+            diffcsp.set_train(False)
+
+            starttime = time.time()
+            for atom_types_batch, frac_coords_batch, property_batch, lengths_batch, angles_batch,\
+                edge_index_batch, batch_node2graph_, node_mask_batch, edge_mask_batch, batch_mask_batch,\
+                node_num_valid_, batch_size_valid_ in eval_loader:
+
+                result_e = eval_step(atom_types_batch, frac_coords_batch, property_batch,
+                                     lengths_batch, angles_batch, edge_index_batch, batch_node2graph_,
+                                     node_mask_batch, edge_mask_batch, batch_mask_batch, node_num_valid_,
+                                     batch_size_valid_)
+
+                mseloss_step, mseloss_l_, mseloss_x_ = result_e
+
+                eval_mseloss_record.update(mseloss_step)
+
+            #################################################### finish validation #################################
+
+            starttime0 = starttime
+            starttime = time.time()
+            logging.info("validation time: %s", starttime - starttime0)
+            logging.info("epoch %s average validation mse loss: %s", epoch, eval_mseloss_record.avg)
+
+        epoch = epoch + 1
+
+if __name__ == '__main__':
+    main()
--
Gitee

From 36f291b9b0bdedf1bd0ca9b326f7a6532c638783 Mon Sep 17 00:00:00 2001
From: birfied
Date: Thu, 18 Sep 2025 11:01:03 +0800
Subject: [PATCH 2/2] move files to MindChem

---
 MindChemistry/applications/diffcsp/README.md  | 127 -------
 .../applications/diffcsp/compute_metric.py    | 327 ------------------
 .../applications/diffcsp/config.yaml          |  41 ---
 .../applications/diffcsp/data/dataset.py      | 135 --------
 MindChemistry/applications/diffcsp/train.py   | 208 -----------
 5 files changed, 838 deletions(-)
 delete mode 100644 MindChemistry/applications/diffcsp/README.md
 delete mode 100644 MindChemistry/applications/diffcsp/compute_metric.py
 delete mode 100644 MindChemistry/applications/diffcsp/config.yaml
 delete mode 100644 MindChemistry/applications/diffcsp/data/dataset.py
 delete mode 100644 MindChemistry/applications/diffcsp/train.py

diff --git a/MindChemistry/applications/diffcsp/README.md b/MindChemistry/applications/diffcsp/README.md
deleted file mode 100644
index 858aa417c..000000000
--- a/MindChemistry/applications/diffcsp/README.md
+++ /dev/null
diff --git a/MindChemistry/applications/diffcsp/compute_metric.py b/MindChemistry/applications/diffcsp/compute_metric.py
deleted file mode 100644
index df5539aad..000000000
--- a/MindChemistry/applications/diffcsp/compute_metric.py
+++ /dev/null
diff --git a/MindChemistry/applications/diffcsp/config.yaml b/MindChemistry/applications/diffcsp/config.yaml
deleted file mode 100644
index 6c2899a15..000000000
--- a/MindChemistry/applications/diffcsp/config.yaml
+++ /dev/null
diff --git a/MindChemistry/applications/diffcsp/data/dataset.py b/MindChemistry/applications/diffcsp/data/dataset.py
deleted file mode 100644
index 964a9a957..000000000
--- a/MindChemistry/applications/diffcsp/data/dataset.py
+++ /dev/null
diff --git a/MindChemistry/applications/diffcsp/train.py b/MindChemistry/applications/diffcsp/train.py
deleted file mode 100644
index fd21606a6..000000000
--- a/MindChemistry/applications/diffcsp/train.py
+++ /dev/null
--
Gitee