# ============================================================================
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Orb GraphRegressor."""
from typing import Literal, Optional, Union
import numpy
import mindspore as ms
from mindspore import Parameter, ops, Tensor, mint
from mindchemistry.cell.orb.gns import _KEY, MoleculeGNS
from mindchemistry.cell.orb.utils import (
aggregate_nodes,
build_mlp,
REFERENCE_ENERGIES,
)


class LinearReferenceEnergy(ms.nn.Cell):
    r"""
    Linear reference energy (no bias term).

    This class implements a linear reference energy model that can be used
    to compute the reference energy for a given set of atomic numbers.

    Args:
        weight_init (numpy.ndarray, optional): Initial weights for the linear layer.
            If not provided, the weights will be initialized randomly.
        trainable (bool, optional): Whether the weights are trainable. If not provided,
            the weights are trainable when ``weight_init`` is ``None`` and frozen otherwise.

    Inputs:
        - **atom_types** (Tensor) - A tensor of atomic numbers of shape (n_atoms,).
        - **n_node** (Tensor) - A tensor of shape (n_graphs,) containing the number of nodes in each graph.

    Outputs:
        - **Tensor** - A tensor of shape (n_graphs, 1) containing the reference energy.

    Raises:
        ValueError: If the input tensor shapes are not compatible with the expected shapes.
        TypeError: If the input types are not compatible with the expected types.

    Supported Platforms:
        ``Ascend``
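
    Examples:
        A minimal editor-added sketch (hypothetical weights and inputs; the
        import path is assumed to match the package-level imports used in the
        ``Orb`` example below):

        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindchemistry.cell.orb import LinearReferenceEnergy
        >>> ref = LinearReferenceEnergy(
        ...     weight_init=np.ones(118, dtype=np.float32), trainable=False)
        >>> atom_types = Tensor(np.array([1, 6, 8], dtype=np.int32))
        >>> n_node = Tensor(np.array([3], dtype=np.int32))
        >>> energy = ref(atom_types, n_node)
        >>> print(energy.shape)
        (1, 1)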
"""
def __init__(
self,
weight_init: Optional[numpy.ndarray] = None,
trainable: Optional[bool] = None,
):
"""init
"""
super().__init__()
if trainable is None:
trainable = weight_init is None
self.linear = ms.nn.Dense(118, 1, has_bias=False)
if weight_init is not None:
self.linear.weight.set_data(Tensor(weight_init, dtype=ms.float32).reshape(1, 118))
if not trainable:
self.linear.weight.requires_grad = False
def construct(self, atom_types: Tensor, n_node: Tensor):
"""construct
"""
one_hot_atomic = ops.OneHot()(atom_types, 118, Tensor(1.0, ms.float32), Tensor(0.0, ms.float32))
reduced = aggregate_nodes(one_hot_atomic, n_node, reduction="sum")
return self.linear(reduced)


class ScalarNormalizer(ms.nn.Cell):
    r"""
    Scalar normalizer that learns mean and std from data.

    NOTE: Multi-dimensional tensors are flattened before updating
    the running mean/std. This is desired behaviour for force targets.

    Args:
        init_mean (Tensor or float, optional): Initial mean value for normalization.
            If not provided, defaults to 0.0.
        init_std (Tensor or float, optional): Initial standard deviation value for normalization.
            If not provided, defaults to 1.0.
        init_num_batches (int, optional): Initial number of batches for normalization.
            If not provided, defaults to 1000.

    Inputs:
        - **x** (Tensor) - A tensor of shape (n_samples, n_features) to normalize.

    Outputs:
        - **Tensor** - A tensor of the same shape as x, normalized by the running mean and std.

    Raises:
        ValueError: If the input tensor is not of the expected shape.
        TypeError: If the input types are not compatible with the expected types.

    Supported Platforms:
        ``Ascend``
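
    Examples:
        A minimal editor-added sketch with hypothetical data (the import path is
        assumed to match the package-level imports used elsewhere in this file):

        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindchemistry.cell.orb import ScalarNormalizer
        >>> normalizer = ScalarNormalizer()
        >>> normalizer.set_train(False)
        >>> x = Tensor(np.random.randn(5, 1).astype(np.float32))
        >>> y = normalizer(x)
        >>> print(y.shape)
        (5, 1)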
"""
def __init__(
self,
init_mean: Optional[Union[Tensor, float]] = None,
init_std: Optional[Union[Tensor, float]] = None,
init_num_batches: Optional[int] = 1000,
):
"""init
"""
super().__init__()
self.bn = mint.nn.BatchNorm1d(1, affine=False, momentum=None)
self.bn.running_mean = Parameter(Tensor([0], ms.float32))
self.bn.running_var = Parameter(Tensor([1], ms.float32))
self.bn.num_batches_tracked = Parameter(Tensor([1000], ms.float32))
self.stastics = {
"running_mean": init_mean if init_mean is not None else 0.0,
"running_var": init_std**2 if init_std is not None else 1.0,
"num_batches_tracked": init_num_batches if init_num_batches is not None else 1000,
}
def construct(self, x: Tensor):
"""construct
"""
if self.training:
self.bn(x.view(-1, 1))
if hasattr(self, "running_mean"):
return (x - self.running_mean) / mint.sqrt(self.running_var)
return (x - self.bn.running_mean) / mint.sqrt(self.bn.running_var)
def inverse(self, x: Tensor):
"""Reverse the construct normalization.
Args:
x: A tensor of shape (n_samples, n_features) to inverse normalize.
Returns:
A tensor of the same shape as x, inverse normalized by the running mean and std.
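
        Examples:
            An editor-added round-trip sketch, reusing the ``normalizer`` and
            ``x`` from the class-level Examples above:

            >>> restored = normalizer.inverse(normalizer(x))
            >>> print(restored.shape)
            (5, 1)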
"""
if hasattr(self, "running_mean"):
return x * mint.sqrt(self.running_var) + self.running_mean
return x * mint.sqrt(self.bn.running_var) + self.bn.running_mean


# pylint: disable=C0301
class NodeHead(ms.nn.Cell):
    r"""
    Node-level prediction head.

    Implements a neural network head for predicting node-level properties from node features. This head can be
    added to base models to enable auxiliary tasks during pretraining or added in fine-tuning steps.

    Args:
        latent_dim (int): Input feature dimension for each node.
        num_mlp_layers (int): Number of hidden layers in MLP.
        mlp_hidden_dim (int): Hidden dimension size of MLP.
        target_property_dim (int): Output dimension of node-level target property.
        dropout (Optional[float], optional): Dropout rate for MLP. Default: ``None``.
        remove_mean (bool, optional): If True, remove the per-graph mean from the output, typically used for
            force prediction. Default: ``True``.

    Inputs:
        - **node_features** (dict) - Node feature dictionary, must contain key "feat" with shape :math:`(n_{nodes}, latent\_dim)`.
        - **n_node** (Tensor) - Number of nodes in graph, shape :math:`(1,)`.

    Outputs:
        - **output** (dict) - Dictionary containing key "node_pred" with value of shape :math:`(n_{nodes}, target\_property\_dim)`.

    Raises:
        ValueError: If required feature keys are missing in `node_features`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> from mindchemistry.cell.orb import NodeHead
        >>> node_head = NodeHead(
        ...     latent_dim=256,
        ...     num_mlp_layers=1,
        ...     mlp_hidden_dim=256,
        ...     target_property_dim=3,
        ...     remove_mean=True,
        ... )
        >>> n_atoms = 4
        >>> n_node = Tensor([n_atoms], mindspore.int32)
        >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32))
        >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32)
        >>> for i, num in enumerate(atomic_numbers.asnumpy()):
        ...     atomic_numbers_embedding_np[i, num - 1] = 1.0
        >>> node_features = {
        ...     "atomic_numbers": atomic_numbers,
        ...     "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np),
        ...     "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)),
        ...     "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32))
        ... }
        >>> output = node_head(node_features, n_node)
        >>> print(output['node_pred'].shape)
        (4, 3)
    """

    def __init__(
        self,
        latent_dim: int,
        num_mlp_layers: int,
        mlp_hidden_dim: int,
        target_property_dim: int,
        dropout: Optional[float] = None,
        remove_mean: bool = True,
    ):
        """init
        """
        super().__init__()
        self.target_property_dim = target_property_dim
        self.normalizer = ScalarNormalizer()
        self.mlp = build_mlp(
            input_size=latent_dim,
            hidden_layer_sizes=[mlp_hidden_dim] * num_mlp_layers,
            output_size=self.target_property_dim,
            dropout=dropout,
        )
        self.remove_mean = remove_mean

    def construct(self, node_features, n_node):
        """construct
        """
        feat = node_features[_KEY]
        pred = self.mlp(feat)
        if self.remove_mean:
            # Subtract the per-graph mean so predictions sum to zero within
            # each graph (e.g., net-zero forces).
            system_means = aggregate_nodes(
                pred, n_node, reduction="mean"
            )
            node_broadcasted_means = mint.repeat_interleave(
                system_means, n_node, dim=0
            )
            pred = pred - node_broadcasted_means
        res = {"node_pred": pred}
        return res

    def predict(self, node_features, n_node):
        """Predict node-level attributes.

        Args:
            node_features: Node features tensor of shape (n_nodes, latent_dim).
            n_node: Number of nodes in the graph.

        Returns:
            node_pred: Node-level predictions of shape (n_nodes, target_property_dim).
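
        Examples:
            An editor-added sketch reusing ``node_head``, ``node_features`` and
            ``n_node`` from the class-level Examples above:

            >>> pred = node_head.predict(node_features, n_node)
            >>> print(pred.shape)
            (4, 3)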
"""
out = self(node_features, n_node)
pred = out["node_pred"]
return self.normalizer.inverse(pred)


# pylint: disable=C0301
class GraphHead(ms.nn.Cell):
    r"""
    Graph-level prediction head. Can be attached to base models to predict graph-level
    properties (e.g., the stress tensor) from node features using aggregation and an MLP.

    Args:
        latent_dim (int): Input feature dimension for each node.
        num_mlp_layers (int): Number of hidden layers in MLP.
        mlp_hidden_dim (int): Hidden dimension size of MLP.
        target_property_dim (int): Output dimension of graph-level property.
        node_aggregation (str, optional): Aggregation method for node predictions, e.g., ``"mean"`` or ``"sum"``. Default: ``"mean"``.
        dropout (Optional[float], optional): Dropout rate for MLP. Default: ``None``.
        compute_stress (bool, optional): Whether to compute and output stress tensor. Default: ``False``.

    Inputs:
        - **node_features** (dict) - Node feature dictionary, must contain key "feat" with shape :math:`(n_{nodes}, latent\_dim)`.
        - **n_node** (Tensor) - Number of nodes in graph, shape :math:`(1,)`.

    Outputs:
        - **output** (dict) - Dictionary containing key "stress_pred" (if ``compute_stress`` is True) or "graph_pred"
          (otherwise), with value of shape :math:`(1, target\_property\_dim)`.

    Raises:
        ValueError: If required feature keys are missing in `node_features`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> from mindchemistry.cell.orb import GraphHead
        >>> graph_head = GraphHead(
        ...     latent_dim=256,
        ...     num_mlp_layers=1,
        ...     mlp_hidden_dim=256,
        ...     target_property_dim=6,
        ...     compute_stress=True,
        ... )
        >>> n_atoms = 4
        >>> n_node = Tensor([n_atoms], mindspore.int32)
        >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32))
        >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32)
        >>> for i, num in enumerate(atomic_numbers.asnumpy()):
        ...     atomic_numbers_embedding_np[i, num - 1] = 1.0
        >>> node_features = {
        ...     "atomic_numbers": atomic_numbers,
        ...     "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np),
        ...     "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)),
        ...     "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32))
        ... }
        >>> output = graph_head(node_features, n_node)
        >>> print(output['stress_pred'].shape)
        (1, 6)
    """

    def __init__(
        self,
        latent_dim: int,
        num_mlp_layers: int,
        mlp_hidden_dim: int,
        target_property_dim: int,
        node_aggregation: Literal["sum", "mean"] = "mean",
        dropout: Optional[float] = None,
        compute_stress: Optional[bool] = False,
    ):
        """init
        """
        super().__init__()
        self.target_property_dim = target_property_dim
        self.normalizer = ScalarNormalizer()
        self.node_aggregation = node_aggregation
        self.mlp = build_mlp(
            input_size=latent_dim,
            hidden_layer_sizes=[mlp_hidden_dim] * num_mlp_layers,
            output_size=self.target_property_dim,
            dropout=dropout,
        )
        self.output_activation = ops.Identity()
        self.compute_stress = compute_stress

    def construct(self, node_features, n_node):
        """construct
        """
        feat = node_features[_KEY]
        # aggregate to get a tensor of shape (num_graphs, latent_dim)
        mlp_input = aggregate_nodes(
            feat,
            n_node,
            reduction=self.node_aggregation,
        )
        pred = self.mlp(mlp_input)
        if self.compute_stress:
            # name the stress prediction differently
            res = {"stress_pred": pred}
        else:
            res = {"graph_pred": pred}
        return res

    def predict(self, node_features, n_node, atomic_numbers=None):
        """Predict graph-level attributes.

        Args:
            node_features: Node features tensor.
            n_node: Number of nodes.
            atomic_numbers: Atomic numbers; required when ``compute_stress`` is False.

        Returns:
            probs: Graph-level predictions of shape (n_graphs, target_property_dim).
                If compute_stress is True, this will be the stress tensor.
                If compute_stress is False, this will be the graph-level property (e.g., energy).
        """
        pred = self(node_features, n_node)
        if self.compute_stress:
            logits = pred["stress_pred"].squeeze(-1)
        else:
            assert atomic_numbers is not None, "atomic_numbers must be provided for graph prediction"
            logits = pred["graph_pred"].squeeze(-1)
        probs = self.output_activation(logits)
        # Map normalized predictions back to the target's original scale.
        probs = self.normalizer.inverse(probs)
        return probs


# pylint: disable=C0301
class EnergyHead(GraphHead):
    r"""
    Graph-level energy prediction head.

    Implements a neural network head for predicting the total energy or per-atom average energy of molecular graphs.
    Supports node-level aggregation, a reference energy offset, and flexible output modes.

    Args:
        latent_dim (int): Input feature dimension for each node.
        num_mlp_layers (int): Number of hidden layers in MLP.
        mlp_hidden_dim (int): Hidden dimension size of MLP.
        target_property_dim (int): Output dimension of energy property (typically 1).
        predict_atom_avg (bool, optional): Whether to predict per-atom average energy instead of total energy. Default: ``True``.
        reference_energy_name (str, optional): Reference energy name for offset, e.g., ``"vasp-shifted"``. Default: ``"mp-traj-d3"``.
        train_reference (bool, optional): Whether to train reference energy as learnable parameter. Default: ``False``.
        dropout (Optional[float], optional): Dropout rate for MLP. Default: ``None``.
        node_aggregation (str, optional): Aggregation method for node predictions, e.g., ``"mean"`` or ``"sum"``. Default: ``"mean"``.

    Inputs:
        - **node_features** (dict) - Node feature dictionary, must contain key "feat" with shape :math:`(n_{nodes}, latent\_dim)`.
        - **n_node** (Tensor) - Number of nodes in graph, shape :math:`(1,)`.

    Outputs:
        - **output** (dict) - Dictionary containing key "graph_pred" with value of shape :math:`(1, target\_property\_dim)`.

    Raises:
        ValueError: If required feature keys are missing in `node_features`.
        ValueError: If `node_aggregation` is not a supported type.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> from mindchemistry.cell.orb import EnergyHead
        >>> energy_head = EnergyHead(
        ...     latent_dim=256,
        ...     num_mlp_layers=1,
        ...     mlp_hidden_dim=256,
        ...     target_property_dim=1,
        ...     node_aggregation="mean",
        ...     reference_energy_name="vasp-shifted",
        ...     train_reference=True,
        ...     predict_atom_avg=True,
        ... )
        >>> n_atoms = 4
        >>> n_node = Tensor([n_atoms], mindspore.int32)
        >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32))
        >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32)
        >>> for i, num in enumerate(atomic_numbers.asnumpy()):
        ...     atomic_numbers_embedding_np[i, num - 1] = 1.0
        >>> node_features = {
        ...     "atomic_numbers": atomic_numbers,
        ...     "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np),
        ...     "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)),
        ...     "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32))
        ... }
        >>> output = energy_head(node_features, n_node)
        >>> print(output['graph_pred'].shape)
        (1, 1)
    """

    def __init__(
        self,
        latent_dim: int,
        num_mlp_layers: int,
        mlp_hidden_dim: int,
        target_property_dim: int,
        predict_atom_avg: bool = True,
        reference_energy_name: str = "mp-traj-d3",
        train_reference: bool = False,
        dropout: Optional[float] = None,
        node_aggregation: Optional[str] = "mean",
    ):
        """init
        """
        ref = REFERENCE_ENERGIES[reference_energy_name]
        super().__init__(
            latent_dim=latent_dim,
            num_mlp_layers=num_mlp_layers,
            mlp_hidden_dim=mlp_hidden_dim,
            target_property_dim=target_property_dim,
            node_aggregation=node_aggregation,
            dropout=dropout,
        )
        self.reference = LinearReferenceEnergy(
            weight_init=ref.coefficients, trainable=train_reference
        )
        self.atom_avg = predict_atom_avg

    def predict(self, node_features, n_node, atomic_numbers=None):
        """Predict energy.

        Args:
            node_features: Node features tensor.
            n_node: Number of nodes.
            atomic_numbers: Atomic numbers for the reference energy calculation
                (required; a ValueError is raised if ``None``).

        Returns:
            graph_pred: Energy prediction.
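
        Examples:
            An editor-added sketch reusing ``energy_head``, ``node_features``,
            ``n_node`` and ``atomic_numbers`` from the class-level Examples above:

            >>> energy = energy_head.predict(node_features, n_node, atomic_numbers)
            >>> print(energy.shape)
            (1, 1)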
"""
if atomic_numbers is None:
raise ValueError("atomic_numbers is required for energy prediction")
pred = self(node_features, n_node)["graph_pred"]
pred = self.normalizer.inverse(pred).squeeze(-1)
if self.atom_avg:
pred = pred * n_node
pred = pred + self.reference(atomic_numbers, n_node)
return pred


# pylint: disable=C0301
class Orb(ms.nn.Cell):
    r"""
    Orb graph regressor.

    Combines a pretrained base model (e.g., MoleculeGNS) with optional node, graph, and stress regression heads, supporting
    fine-tuning or feature extraction workflows.

    Args:
        model (MoleculeGNS): Pretrained or randomly initialized base model for message passing and feature extraction.
        node_head (NodeHead, optional): Regression head for node-level property prediction. Default: ``None``.
        graph_head (GraphHead, optional): Regression head for graph-level property prediction (e.g., energy). Default: ``None``.
        stress_head (GraphHead, optional): Regression head for stress prediction. Default: ``None``.
        model_requires_grad (bool, optional): Whether to fine-tune the base model (True) or freeze its parameters (False). Default: ``True``.
        cutoff_layers (int, optional): If provided, only use the first ``cutoff_layers`` message passing layers of the base model.
            Default: ``None``.

    Inputs:
        - **edge_features** (dict) - Edge feature dictionary (e.g., `{"vectors": Tensor, "r": Tensor}`).
        - **node_features** (dict) - Node feature dictionary (e.g., `{"atomic_numbers": Tensor, ...}`).
        - **senders** (Tensor) - Sender node indices for each edge. Shape: :math:`(n_{edges},)`.
        - **receivers** (Tensor) - Receiver node indices for each edge. Shape: :math:`(n_{edges},)`.
        - **n_node** (Tensor) - Number of nodes for each graph in the batch. Shape: :math:`(n_{graphs},)`.

    Outputs:
        - **output** (dict) - Dictionary containing:

          - **edges** (dict) - Edge features after message passing, e.g., `{..., "feat": Tensor}`.
          - **nodes** (dict) - Node features after message passing, e.g., `{..., "feat": Tensor}`.
          - **graph_pred** (Tensor) - Graph-level predictions, e.g., energy. Shape: :math:`(n_{graphs}, target\_property\_dim)`.
          - **node_pred** (Tensor) - Node-level predictions. Shape: :math:`(n_{nodes}, target\_property\_dim)`.
          - **stress_pred** (Tensor) - Stress predictions (if stress_head is provided). Shape: :math:`(n_{graphs}, 6)`.

    Raises:
        ValueError: If neither node_head nor graph_head is provided.
        ValueError: If cutoff_layers exceeds the number of message passing steps in the base model.
        ValueError: If atomic_numbers is not provided when a graph-level prediction is requested.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> from mindchemistry.cell.orb import Orb, MoleculeGNS, EnergyHead, NodeHead, GraphHead
        >>> orb = Orb(
        ...     model=MoleculeGNS(
        ...         num_node_in_features=256,
        ...         num_node_out_features=3,
        ...         num_edge_in_features=23,
        ...         latent_dim=256,
        ...         interactions="simple_attention",
        ...         interaction_params={
        ...             "distance_cutoff": True,
        ...             "polynomial_order": 4,
        ...             "cutoff_rmax": 6,
        ...             "attention_gate": "sigmoid",
        ...         },
        ...         num_message_passing_steps=15,
        ...         num_mlp_layers=2,
        ...         mlp_hidden_dim=512,
        ...         use_embedding=True,
        ...         node_feature_names=["feat"],
        ...         edge_feature_names=["feat"],
        ...     ),
        ...     graph_head=EnergyHead(
        ...         latent_dim=256,
        ...         num_mlp_layers=1,
        ...         mlp_hidden_dim=256,
        ...         target_property_dim=1,
        ...         node_aggregation="mean",
        ...         reference_energy_name="vasp-shifted",
        ...         train_reference=True,
        ...         predict_atom_avg=True,
        ...     ),
        ...     node_head=NodeHead(
        ...         latent_dim=256,
        ...         num_mlp_layers=1,
        ...         mlp_hidden_dim=256,
        ...         target_property_dim=3,
        ...         remove_mean=True,
        ...     ),
        ...     stress_head=GraphHead(
        ...         latent_dim=256,
        ...         num_mlp_layers=1,
        ...         mlp_hidden_dim=256,
        ...         target_property_dim=6,
        ...         compute_stress=True,
        ...     ),
        ... )
        >>> n_atoms = 4
        >>> n_edges = 10
        >>> n_node = Tensor([n_atoms], mindspore.int32)
        >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32))
        >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32)
        >>> for i, num in enumerate(atomic_numbers.asnumpy()):
        ...     atomic_numbers_embedding_np[i, num - 1] = 1.0
        >>> node_features = {
        ...     "atomic_numbers": atomic_numbers,
        ...     "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np),
        ...     "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32))
        ... }
        >>> edge_features = {
        ...     "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)),
        ...     "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10))
        ... }
        >>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
        >>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32))
        >>> output = orb(edge_features, node_features, senders, receivers, n_node)
        >>> print(output['graph_pred'].shape, output['node_pred'].shape, output['stress_pred'].shape)
        (1, 1) (4, 3) (1, 6)
    """

    def __init__(
        self,
        model: MoleculeGNS,
        node_head: Optional[NodeHead] = None,
        graph_head: Optional[GraphHead] = None,
        stress_head: Optional[GraphHead] = None,
        model_requires_grad: bool = True,
        cutoff_layers: Optional[int] = None,
    ):
        """init
        """
        super().__init__()
        if (node_head is None) and (graph_head is None):
            raise ValueError("Must provide at least one node/graph head.")
        self.node_head = node_head
        self.graph_head = graph_head
        self.stress_head = stress_head
        self.cutoff_layers = cutoff_layers
        self.model = model
        if self.cutoff_layers is not None:
            if self.cutoff_layers > self.model.num_message_passing_steps:
                raise ValueError(
                    f"cutoff_layers ({self.cutoff_layers}) must be less than or equal to"
                    f" the number of message passing steps ({self.model.num_message_passing_steps})"
                )
            # Keep only the first `cutoff_layers` message passing layers.
            self.model.gnn_stacks = self.model.gnn_stacks[: self.cutoff_layers]
            self.model.num_message_passing_steps = self.cutoff_layers
        self.model_requires_grad = model_requires_grad
        if not model_requires_grad:
            # Freeze the base model for feature-extraction workflows.
            for param in self.model.parameters():
                param.requires_grad = False

    def predict(self, edge_features, node_features, senders, receivers, n_node, atomic_numbers):
        """Predict node and/or graph level attributes.

        Args:
            edge_features: A dictionary, e.g., `{"vectors": Tensor, "r": Tensor}`.
            node_features: A dictionary, e.g., `{"atomic_numbers": Tensor, "positions": Tensor,
                "atomic_numbers_embedding": Tensor}`.
            senders: A tensor of shape (n_edges,) containing the sender node indices.
            receivers: A tensor of shape (n_edges,) containing the receiver node indices.
            n_node: A tensor of shape (1,) containing the number of nodes.
            atomic_numbers: A tensor of atomic numbers for reference energy calculation.

        Returns:
            output_dict: A dictionary containing the predictions; each entry is present
                only if the corresponding head was provided:

                - `graph_pred`: Graph-level predictions (e.g., energy) of shape (n_graphs, graph_property_dim).
                - `stress_pred`: Stress predictions of shape (n_graphs, stress_dim).
                - `node_pred`: Node-level predictions of shape (n_nodes, node_property_dim).
        """
        _, nodes = self.model(edge_features, node_features, senders, receivers)
        output = {}
        # Guard each head: they are optional per the class contract.
        if self.graph_head is not None:
            output["graph_pred"] = self.graph_head.predict(nodes, n_node, atomic_numbers)
        if self.stress_head is not None:
            output["stress_pred"] = self.stress_head.predict(nodes, n_node)
        if self.node_head is not None:
            output["node_pred"] = self.node_head.predict(nodes, n_node)
        return output

    def construct(self, edge_features, node_features, senders, receivers, n_node):
        """construct
        """
        edges, nodes = self.model(edge_features, node_features, senders, receivers)
        res = {"edges": edges, "nodes": nodes}
        # Guard each head: they are optional per the class contract.
        if self.graph_head is not None:
            res.update(self.graph_head(nodes, n_node))
        if self.stress_head is not None:
            res.update(self.stress_head(nodes, n_node))
        if self.node_head is not None:
            res.update(self.node_head(nodes, n_node))
        return res