diff --git a/.gitee/ISSUE_TEMPLATE/.keep b/.gitee/ISSUE_TEMPLATE/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.gitee/ISSUE_TEMPLATE/1-documentation.yml b/.gitee/ISSUE_TEMPLATE/1-documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..128e7e8c88571b7c594c238dc28d5547cce3cda0 --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/1-documentation.yml @@ -0,0 +1,62 @@ +name: 📚 Documentation +description: Request updates or additions to MindScience documentation +title: "[Doc]: " +labels: ["documentation"] + +body: +- type: markdown + attributes: + value: | + Thanks for taking the time to help MindScience and improve our documentation! + - If this is your first time, please read [our contributor guidelines](https://gitee.com/mindspore/mindscience/blob/master/CONTRIBUTION.md). + - You also confirm that you have searched the [open documentation issues](https://gitee.com/mindspore/mindscience/issues) and have found no duplicates for this request +- type: dropdown + id: domain + attributes: + label: Which domain the issue belongs to? + options: + - MindSpore Science Core + - applications-SPONGE + - applications-Flow + - applications-Energy + - applications-Chemistry + - applications-Earth + - Others + validations: + required: true + +- type: dropdown + id: new_or_correction + attributes: + label: Is this for new documentation, or an update to existing docs? + options: + - New + - Update + validations: + required: true + +- type: textarea + attributes: + label: 📚 The doc issue + description: > + Describe the incorrect/future/missing documentation. + value: | + 1. 【Document Link】/【文档链接】 + + 2. 【Issues Section】/【问题文档片段】 + + 3. 【Existing Issues】/【存在的问题】 + + 4. 【Expected Result】【预期结果】 + + validations: + required: true +- type: textarea + attributes: + label: Suggest a potential alternative/fix + description: > + Tell us how we could improve the documentation in this regard. +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/2-installation.yml b/.gitee/ISSUE_TEMPLATE/2-installation.yml new file mode 100644 index 0000000000000000000000000000000000000000..5938cfe5dab419c99dfd0e2d08a5b9f191dccfe5 --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/2-installation.yml @@ -0,0 +1,68 @@ +name: 🛠️ Installation +description: Report an issue here when you hit errors during installation. +title: "[Installation]: " +labels: ["installation"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://gitee.com/mindspore/mindscience/issues). +- type: dropdown + id: domain + attributes: + label: Which domain the issue belongs to? 
+ options: + - MindSpore Science Core + - applications-SPONGE + - applications-Flow + - applications-Energy + - applications-Chemistry + - applications-Earth + - Others + validations: + required: true + +- type: textarea + attributes: + label: Your current environment + description: | + Environment / 环境信息 (Mandatory / 必填) + value: | + - **Hardware Environment / 硬件环境(Mandatory / 必填)**: + Hardware (e.g.`Atlas 800T A2`) + + 样例: + + | 后端类型| 硬件具体类别 | + | --- | --- | + | Server | Atlas 800T A2 | + | CPU| Mac CPU/Win CPU/Linux CPU| + + + - **Software Environment / 软件环境 (Mandatory / 必填)**: + 迭代版本新增问题样例:(根据实际修改和增删) + + | Software | Version(根据实际修改,必填)| + | --- | --- | + | MindSpore | MindSpore 2.4.0 | + | CANN | 8.0.0.beta1 | + | Python | Python XXXXXX | + | OS platform | Ubuntu XXXXXX | + | GCC/Compiler version | XXXXXX | + + validations: + required: true +- type: textarea + attributes: + label: How you are installing MindScience + description: | + Paste the full command you are trying to execute. + placeholder: | + ```sh + pip install mindsponge_*.whl + ``` +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/3-bug-report.yml b/.gitee/ISSUE_TEMPLATE/3-bug-report.yml new file mode 100644 index 0000000000000000000000000000000000000000..ae8c1bae76121286453adbc054f4affaeebc0039 --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/3-bug-report.yml @@ -0,0 +1,308 @@ +name: 🐛 Bug report +description: Raise an issue here if you find a bug. +title: "[Bug]: " +labels: ["bug"] + +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to help MindScience and fill out this bug report! + - If this is your first time, please read [our contributor guidelines](https://gitee.com/mindspore/mindscience/blob/master/CONTRIBUTION.md). + - You also confirm that you have searched the [open documentation issues](https://gitee.com/mindspore/mindscience/issues) and have found no duplicates for this request + - type: dropdown + id: domain + attributes: + label: Which domain the issue belongs to? + options: + - MindSpore Science Core + - applications-SPONGE + - applications-Flow + - applications-Energy + - applications-Chemistry + - applications-Earth + - Others + validations: + required: true + + - type: input + id: version + attributes: + label: Version + description: What version of MindScience are you running? + placeholder: "example: r0.7" + validations: + required: true + + - type: textarea + attributes: + label: installation-method + description: | + Paste the full command you are trying to execute. 
+ placeholder: | + ```sh + pip install mindsponge_*.whl + ``` + + - type: textarea + attributes: + label: Your current environment + description: | + Environment / 环境信息 (Mandatory / 必填) + value: | + - **Hardware Environment / 硬件环境(Mandatory / 必填)**: + Hardware (e.g.`Atlas 800T A2`) + + 样例: + + | 后端类型| 硬件具体类别 | + | --- | --- | + | Server | Atlas 800T A2 | + | CPU| Mac CPU/Win CPU/Linux CPU| + + + - **Software Environment / 软件环境 (Mandatory / 必填)**: + 迭代版本新增问题样例:(根据实际修改和增删) + + | Software | Version(根据实际修改,必填)| + | --- | --- | + | MindSpore | MindSpore 2.4.0 | + | CANN | 8.0.0.beta1 | + | Python | Python XXXXXX | + | OS platform | Ubuntu XXXXXX | + | GCC/Compiler version | XXXXXX | + + + bugfix版本问题引入样例:(根据实际修改和增删) + + | Software | Version(根据实际修改,必填)| + | --- | --- | + | MindSpore | MindSpore 2.4.0 (成功)master_202407131XXXXXX _a4230c71d(失败)| + | CANN | 8.0.0.beta1 | + | CANN 归档地址 | | + | Python | Python XXXXXX | + | OS platform | Ubuntu XXXXXX | + | GCC/Compiler version | XXXXXX | + + validations: + required: true + + + - type: textarea + id: description + attributes: + label: Describe the issue + description: | + Please provide a complete and succinct description of the problem, including what you expected to happen. + value: | + #### 1.Describe the current behavior / 问题描述 (Mandatory / 必填) + + 样例: (根据实际修改和增删) + + > sponge.colvar.Distance()报错,同时 sponge.metrics.Metric其子类的 .update() 报错 + + #### 2. / 关联用例 (Mandatory / 必填)Related testcase + + ```python + from mindspore import Tensor + from sponge.colvar import Distance + from sponge.metrics import MetricCV + cv = Distance([0,1]) + coordinate = Tensor([[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]]) + metric = MetricCV(cv) + metric.update(coordinate) + print(metric.eval()) + ``` + + #### 3.Steps to reproduce the issue / 重现步骤 (Mandatory / 必填) + + > 测试步骤:运行关联用例即可 + > 用例执行命令:来自CI日志或者用户执行命令 + + + #### 4.Describe the expected behavior / 预期结果 (Mandatory / 必填) + + > **【预期结果】**:MindSpore 1.10.1 版本下,可正常运行。预期输出为 [1.] + + #### 5.Related log / screenshot / 日志 / 截图 (Mandatory / 必填) + + ```shell + --------------------------------------------------------------------------- + ValueError Traceback (most recent call last) + Cell In[4], line 7 + 5 coordinate = Tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) + 6 metric = MetricCV(cv) + ----> 7 metric.update(coordinate) + 8 print(metric.eval()) + + File ~/mindscience/MindSPONGE/./src/sponge/metrics/metrics.py:190, in MetricCV.update(self, coordinate, pbc_box, energy, force, potentials, total_bias, biases) + 163 """ + 164 + 165 Args: + (...) + 186 V: Number of bias potential energies. 
+ 187 """ + 188 #pylint: disable=unused-argument + --> 190 colvar = self.colvar(coordinate, pbc_box) + 192 self._value = self._convert_data(colvar) + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:705, in Cell.__call__(self, *args, **kwargs) + 703 except Exception as err: + 704 _pynative_executor.clear_res() + --> 705 raise err + 707 if isinstance(output, Parameter): + 708 output = output.data + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:701, in Cell.__call__(self, *args, **kwargs) + 699 try: + 700 _pynative_executor.new_graph(self, *args, **kwargs) + --> 701 output = self._run_construct(args, kwargs) + 702 _pynative_executor.end_graph(self, output, *args, **kwargs) + 703 except Exception as err: + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:482, in Cell._run_construct(self, cast_inputs, kwargs) + 480 output = self._shard_fn(*cast_inputs, **kwargs) + 481 else: + --> 482 output = self.construct(*cast_inputs, **kwargs) + 483 if self._enable_forward_hook: + 484 output = self._run_forward_hook(cast_inputs, output) + + File ~/mindscience/MindSPONGE/./src/sponge/colvar/basic/distance.py:146, in Distance.construct(self, coordinate, pbc_box) + 131 r"""calculate distance. + 132 + 133 Args: + (...) + 142 + 143 """ + 145 # (B, ..., D) + --> 146 vector = self.vector(coordinate, pbc_box) + 148 # (B, ...) or (B, ..., 1) + 149 if self.norm_last_dim is None: + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:705, in Cell.__call__(self, *args, **kwargs) + 703 except Exception as err: + 704 _pynative_executor.clear_res() + --> 705 raise err + 707 if isinstance(output, Parameter): + 708 output = output.data + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:701, in Cell.__call__(self, *args, **kwargs) + 699 try: + 700 _pynative_executor.new_graph(self, *args, **kwargs) + --> 701 output = self._run_construct(args, kwargs) + 702 _pynative_executor.end_graph(self, output, *args, **kwargs) + 703 except Exception as err: + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:482, in Cell._run_construct(self, cast_inputs, kwargs) + 480 output = self._shard_fn(*cast_inputs, **kwargs) + 481 else: + --> 482 output = self.construct(*cast_inputs, **kwargs) + 483 if self._enable_forward_hook: + 484 output = self._run_forward_hook(cast_inputs, output) + + File ~/mindscience/MindSPONGE/./src/sponge/colvar/atoms/vector.py:183, in Vector.construct(self, coordinate, pbc_box) + 180 atoms1 = self.atoms1(coordinate, pbc_box) + 181 else: + 182 # (B, ..., 2, D) + --> 183 atoms = self.atoms(coordinate, pbc_box) + 184 # (B, ..., 1, D) <- (B, ..., 2, D) + 185 atoms0, atoms1 = self.split2(atoms) + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:705, in Cell.__call__(self, *args, **kwargs) + 703 except Exception as err: + 704 _pynative_executor.clear_res() + --> 705 raise err + 707 if isinstance(output, Parameter): + 708 output = output.data + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:701, in Cell.__call__(self, *args, **kwargs) + 699 try: + 700 _pynative_executor.new_graph(self, *args, **kwargs) + --> 701 output = self._run_construct(args, kwargs) + 702 _pynative_executor.end_graph(self, output, *args, **kwargs) + 703 except Exception as err: + + File ~/.local/lib/python3.8/site-packages/mindspore/nn/cell.py:482, in Cell._run_construct(self, cast_inputs, kwargs) + 480 output = self._shard_fn(*cast_inputs, **kwargs) + 481 else: + --> 482 output = self.construct(*cast_inputs, **kwargs) + 483 
if self._enable_forward_hook: + 484 output = self._run_forward_hook(cast_inputs, output) + + File ~/mindscience/MindSPONGE/./src/sponge/colvar/atoms/atoms.py:232, in Atoms.construct(self, coordinate, pbc_box) + 219 r"""get position coordinate(s) of specific atom(s) + 220 + 221 Args: + (...) + 229 + 230 """ + 231 # (B, a_1, a_2, ..., a_{n}, D) <- (B, A, D) + --> 232 atoms = func.gather_vector(coordinate, self.index) + 233 if self.keep_in_box: + 234 atoms = self.coordinate_in_pbc(atoms, pbc_box) + + File ~/.local/lib/python3.8/site-packages/mindspore/common/api.py:718, in jit..wrap_mindspore..staging_specialize(*args, **kwargs) + 716 if _is_pynative_parallel() and func.__name__ == _PYNATIVE_PARALLEL_FUNC_NAME: + 717 process_obj = hash_args + --> 718 out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj, jit_config)(*args, **kwargs) + 719 return out + + File ~/.local/lib/python3.8/site-packages/mindspore/common/api.py:121, in _wrap_func..wrapper(*arg, **kwargs) + 119 @wraps(fn) + 120 def wrapper(*arg, **kwargs): + --> 121 results = fn(*arg, **kwargs) + 122 return _convert_python_data(results) + + File ~/.local/lib/python3.8/site-packages/mindspore/common/api.py:350, in _MindsporeFunctionExecutor.__call__(self, *args, **kwargs) + 348 except Exception as err: + 349 _pynative_executor.clear_res() + --> 350 raise err + 352 if context.get_context("precompile_only"): + 353 return None + + File ~/.local/lib/python3.8/site-packages/mindspore/common/api.py:344, in _MindsporeFunctionExecutor.__call__(self, *args, **kwargs) + 342 if context.get_context("mode") == context.PYNATIVE_MODE: + 343 _pynative_executor.set_jit_compile_status(True, phase) + --> 344 phase = self.compile(self.fn.__name__, *args_list, **kwargs) + 345 _pynative_executor.set_jit_compile_status(False, phase) + 346 else: + + File ~/.local/lib/python3.8/site-packages/mindspore/common/api.py:435, in _MindsporeFunctionExecutor.compile(self, method_name, *args, **kwargs) + 433 else: + 434 setattr(self.fn, "__jit_function__", True) + --> 435 is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, True) + 436 if isinstance(self.fn, types.MethodType): + 437 delattr(self.fn.__func__, "__jit_function__") + + ValueError: For primitive[BroadcastTo], the attribute[x shape] must be less than or equal to 1, but got 2. + + ---------------------------------------------------- + - C++ Call Stack: (For framework developers) + ---------------------------------------------------- + mindspore/core/utils/check_convert_utils.cc:675 Check + ``` + + ### + + #### 6.Special notes for this issue/备注 (Optional / 选填) + + **【定位人】**吴某某(根据实际修改) + + validations: + required: true + + + - type: textarea + id: mvr + attributes: + label: Minimum reproducible example + description: Please supply a [minimum reproducible code example](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) here. + render: shell + + - type: textarea + id: logs + attributes: + label: Relevant log output + description: Please paste relevant error and log output here + render: shell + diff --git a/.gitee/ISSUE_TEMPLATE/4-ci-failure.yml b/.gitee/ISSUE_TEMPLATE/4-ci-failure.yml new file mode 100644 index 0000000000000000000000000000000000000000..b63c2480454460fe145fc4975a9dda124445f41e --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/4-ci-failure.yml @@ -0,0 +1,83 @@ +name: 🧪 CI failure report +description: Report a failing test. 
+title: "[CI Failure]: " +labels: ["ci-failure"] + +body: +- type: dropdown + id: domain + attributes: + label: Which domain the issue belongs to? + options: + - MindSpore Science Core + - applications-SPONGE + - applications-Flow + - applications-Energy + - applications-Chemistry + - applications-Earth + - Others + validations: + required: true + +- type: markdown + attributes: + value: > + #### Include the name of the failing Buildkite step and test file in the title. +- type: input + attributes: + label: Name of failing test + description: | + Paste in the fully-qualified name of the failing test from the logs. + placeholder: | + `path/to/test_file.py::test_name[params]` + validations: + required: true +- type: checkboxes + attributes: + label: Basic information + description: Select all items that apply to the failing test. + options: + - label: Flaky test + - label: Can reproduce locally + - label: Caused by external libraries (e.g. bug in `transformers`) +- type: textarea + attributes: + label: 🧪 Describe the failing test + description: | + Please provide a clear and concise description of the failing test. + placeholder: | + A clear and concise description of the failing test. + + ``` + The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present. + ``` + validations: + required: true +- type: textarea + attributes: + label: 📝 History of failing test + description: | + Since when did the test start to fail? + + If you have time, identify the PR that caused the test to fail on main. You can do so via the following methods: + + - Use Buildkite Test Suites to find the PR where the test failure first occurred, and reproduce the failure locally. + + - Run [`git bisect`](https://git-scm.com/docs/git-bisect) locally. + + - Manually unblock Buildkite steps for suspected PRs on main and check the results. (authorized users only) + placeholder: | + Approximate timeline and/or problematic PRs + + A link to the Buildkite analytics of the failing test (if available) + validations: + required: true +- type: textarea + attributes: + label: CC List. + description: > + The list of people you want to CC. Usually, this includes those who worked on the PR that failed the test. +- type: markdown + attributes: + value: > + Thanks for reporting 🙏! \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/5-feature-request.yml b/.gitee/ISSUE_TEMPLATE/5-feature-request.yml new file mode 100644 index 0000000000000000000000000000000000000000..a51b2526c0a5e9587e8c64e1fdb3bca24d8869aa --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/5-feature-request.yml @@ -0,0 +1,58 @@ +name: 🚀 Feature request +description: Submit a proposal/request for a new MindScience feature +title: "[Feature]: " +labels: ["feature"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://gitee.com/mindspore/mindscience/issues). +- type: dropdown + id: module + attributes: + label: Which module the issue belongs to? + options: + - MindScience data + - MindScience common + - MindScience e3nn + - MindScience models + - MindScience sciops + - MindScience solver + - MindScience sharker + - MindScience utils + - Others + validations: + required: true +- type: dropdown + id: new_or_improvement + attributes: + label: Is this a new feature, an improvement, or a change to existing functionality? 
+ options: + - New Feature + - Improvement + - Change + validations: + required: true + +- type: textarea + attributes: + label: 🚀 The feature, motivation and pitch + description: > + A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. For feature design, you can refer to [feature design template](https://gitee.com/mindspore/mindscience/blob/br_refactor/docs/template/feature_design.md). + validations: + required: true +- type: textarea + attributes: + label: Alternatives + description: > + A description of any alternative solutions or features you've considered, if any. +- type: textarea + attributes: + label: Additional context + description: > + Add any other context or screenshots about the feature request. +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/6-application-case.yml b/.gitee/ISSUE_TEMPLATE/6-application-case.yml new file mode 100644 index 0000000000000000000000000000000000000000..6ea117ca91cf112235559b7e17e151e9210c1562 --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/6-application-case.yml @@ -0,0 +1,45 @@ +name: 🤗 Support request for a new application case of AIForScience +description: Submit a proposal/request for a new application case of AIForScience +title: "[Application Case]: " +labels: ["application-case"] + +body: +- type: markdown + attributes: + value: > + #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://gitee.com/mindspore/mindscience/issues). +- type: dropdown + id: domain + attributes: + label: Which domain the issue belongs to? + options: + - MindSPONGE + - MindFlow + - MindEnergy + - MindChemistry + - MindEarth + - Others + validations: + required: true + +- type: textarea + attributes: + label: The model to consider. + description: > + A url, pointing to the model, e.g. https://huggingface.co/openai-community/gpt2 . + validations: + required: true +- type: textarea + attributes: + label: The closest model MindScience already supports. + description: > + Here is the list of models already supported by MindScience: https://gitee.com/mindspore/mindscience#%E6%A6%82%E8%BF%B0 . Which model is the most similar to the model you want to add support for? +- type: textarea + attributes: + label: What's your difficulty of supporting the model you want? + description: > + For example, any new operators or new architecture? +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/7-RFC.yml b/.gitee/ISSUE_TEMPLATE/7-RFC.yml new file mode 100644 index 0000000000000000000000000000000000000000..63a9035e773d522cb6109f71606521b4da24c005 --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/7-RFC.yml @@ -0,0 +1,87 @@ +name: 💬 Request for comments (RFC). +description: Ask for feedback on major architectural changes or design choices. +title: "[RFC]: " +labels: ["RFC"] + +body: +- type: markdown + attributes: + value: > + #### Please take a look at previous [RFCs](https://gitee.com/mindspore/mindscience/issues) for reference. +- type: dropdown + id: domain + attributes: + label: Which domain the issue belongs to? 
+ options: + - MindSpore Science Core + - applications-SPONGE + - applications-Flow + - applications-Energy + - applications-Chemistry + - applications-Earth + - Others + validations: + required: true + +- type: textarea + attributes: + label: Backgroud. + description: > + Backgroud(背景信息) + placeholder: | + - Describe/Explain the status of the problem you wish to solve. + - Attach relevant issues if there is any. + validations: + required: true +- type: textarea + attributes: + label: Origin + description: > + Origin(信息来源) + placeholder: | + - Explain which department/team made this request so that its priority can be given. + validations: + required: true +- type: textarea + attributes: + label: Benefit / Necessity + description: > + Benefit / Necessity (价值/作用) + placeholder: | + - Describe/Explain the key value by fulfilling the request. + validations: + required: true +- type: textarea + attributes: + label: Design + description: > + Design(设计方案) + placeholder: | + - Describe/Explain the general idea of the design. Pseudo-code is allowed + validations: + required: true +- type: textarea + attributes: + label: Feedback Period. + description: > + The feedback period of the RFC. Usually at least one week. + validations: + required: false +- type: textarea + attributes: + label: CC List. + description: > + The list of people you want to CC. + validations: + required: false +- type: textarea + attributes: + label: Any Other Things. + description: > + Any other things you would like to mention. + validations: + required: false +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/8-internship.yml b/.gitee/ISSUE_TEMPLATE/8-internship.yml new file mode 100644 index 0000000000000000000000000000000000000000..5bd065bcf087662ec146c951cf1271dfd3ec48f4 --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/8-internship.yml @@ -0,0 +1,81 @@ +name: 💻 Internship +description: This issue is intended for the MindScience open source internship project for college students +title: "[Internship]: " +labels: ["internship"] + + +body: +- type: markdown + attributes: + value: | + - This issue is intended for the MindSpore open source internship project for college students. Developers who do not participate in this project are not allowed to receive it. + - 本issue为面向高校学生的“MindSpore开源实习”项目的任务,非参加该项目的人员勿领。 +- type: dropdown + id: domain + attributes: + label: Which domain the issue belongs to? + options: + - MindSpore Science Core + - applications-SPONGE + - applications-Flow + - applications-Energy + - applications-Chemistry + - applications-Earth + - Others + validations: + required: true + +- type: textarea + attributes: + label: Your information. + description: > + Your information for intership. + value: | + 【Task score】 + 【Background description】 + 【Requirements】 + 【Development environment】 + - Hardware: + - Software: + + 【Programming language】 + 【Acceptance criteria】 + 【PR Submission address】 + 【Expected completion time】 + 【Development guide】 + 【Tutor & email】 + + Note: This issue is intended for the MindSpore open source internship project for college students. Developers who do not participate in this project are not allowed to receive it. + + --- + + 【任务分值】 + 【背景描述】 + 【需求描述】 + 【环境要求】 + - 硬件: + - 软件: + + 【编程语言】 + 【产出标准】 + 【PR提交地址】 + 【期望完成时间】 + 【开发指导】 + 【导师及邮箱】 + + 本issue为面向高校学生的“MindSpore开源实习”项目的任务,非参加该项目的人员勿领。 + + validations: + required: false + +- type: textarea + attributes: + label: Any Other Things. 
+ description: > + Any other things you would like to mention. + validations: + required: false +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/config.yaml b/.gitee/ISSUE_TEMPLATE/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b8ffbb4b2dff618963057deaa280bc530aaef7d4 --- /dev/null +++ b/.gitee/ISSUE_TEMPLATE/config.yaml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: Gitee 帮助中心 + url: https://help.gitee.com/ + about: 提供 Git 使用指南、教程、Gitee.com 平台基本功能使用、介绍和常见问题解答 \ No newline at end of file diff --git a/.gitee/PULL_REQUEST_TEMPLATE.en.md b/.gitee/PULL_REQUEST_TEMPLATE.en.md new file mode 100644 index 0000000000000000000000000000000000000000..9108f279047b6c2a0c9e198f539f447cd9aa7654 --- /dev/null +++ b/.gitee/PULL_REQUEST_TEMPLATE.en.md @@ -0,0 +1,30 @@ +### PR Source +- [ ] Issue (Please link related issue) +- [ ] Feature request +- [ ] Bug report +- [ ] Community contributor + +### Change Description +- **Reason for Modification:** + +- **Content Modified:** + +### Function Validation +- [ ] **Self-verification** +- [ ] **Screenshots of local test cases** + +### Checklist +- [ ] **Code reviewed** +- [ ] **UT test coverage** (If not, explain reason: ____________________) +- [ ] **Involves public API changes in MindSpore Science** +- [ ] **Documentation updated** + +### Code Review Requirements +- Changes over 1000 lines require organized review meeting with conclusions +- PR without function validation cannot be merged +- PR with incomplete checklist cannot be merged +- PR without clear source identification or change description cannot be merged + +### Change Notification +- [ ] **Documentation modified** +- [ ] **API change description** (If API changed, detail description): \ No newline at end of file diff --git a/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md b/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md new file mode 100644 index 0000000000000000000000000000000000000000..c166c30dc0a742dc9666ec760788ece4d5c61963 --- /dev/null +++ b/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md @@ -0,0 +1,30 @@ +### PR来源 +- [ ] issue单(请关联issue) +- [ ] 需求特性 +- [ ] 问题单 +- [ ] 社区开发者贡献 + +### 修改描述 +- **修改原因:** + +- **修改内容:** + +### 功能验证 +- [ ] **功能自验** +- [ ] **本地自验用例截图** + +### 检查清单 +- [ ] **是否经过代码检视** +- [ ] **是否具备UT测试用例看护**(如不符合,请说明原因:____________________) +- [ ] **是否涉及MindSpore Science公共接口变更** +- [ ] **是否涉及文档更新** + +### 代码检视要求 +- 合入代码超过1000行,需组织会议检视并附上结论 +- 未完成功能验证不允许合入 +- 未完成检查清单不允许合入 +- PR来源未标识或修改描述不清晰不允许合入 + +### 变更说明 +- [ ] **文档修改** +- [ ] **接口变更说明**(如涉及接口变更需详细描述): \ No newline at end of file diff --git a/.jenkins/check/config/filter_cppcheck.txt b/.jenkins/check/config/filter_cppcheck.txt index 5ceb4144cef8609817ce7d125debb1e05b53b12a..ba103f4eb1cf5f1aa19b44f371d2a577576eadeb 100644 --- a/.jenkins/check/config/filter_cppcheck.txt +++ b/.jenkins/check/config/filter_cppcheck.txt @@ -2,3 +2,12 @@ "mindscience/MindElec/mindelec/ccsrc/api/python/pybind_register.cc" "syntaxError" "mindscience/MindElec/mindelec/ccsrc/scientific_compute/pointcloud/material_analyse.cc" "useStlAlgorithm" "mindscience/MindElec/mindelec/ccsrc/scientific_compute/pointcloud/tensor_initializer.cc" "useStlAlgorithm" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.cc" "shadowFunction" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.cc" "useStlAlgorithm" 
+"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.cc" "variableScope" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.cc" "shadowVariable" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.cc" "useStlAlgorithm" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_lib.cc" "unsignedLessThanZero" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.cc" "shadowVariable" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.cc" "useStlAlgorithm" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.cc" "pointerSize" diff --git a/.jenkins/check/config/filter_cpplint.txt b/.jenkins/check/config/filter_cpplint.txt index 9ceab3d1e814758951d7532142012045a8162f50..205e7a725ec17744096d8a7997cb122794320e72 100644 --- a/.jenkins/check/config/filter_cpplint.txt +++ b/.jenkins/check/config/filter_cpplint.txt @@ -597,3 +597,22 @@ "mindscience/MindSPONGE/mindsponge/ccsrc/molecular_dynamics/barostats/MC_barostat.cu" "whitespace/parens" "mindscience/MindSPONGE/mindsponge/ccsrc/molecular_dynamics/thermostats/Andersen_thermostat.cu" "whitespace/parens" "mindscience/MindSPONGE/mindsponge/ccsrc/molecular_dynamics/common.cuh" "build/include_subdir" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_lib.cc" "whitespace/parens" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_lib.cc" "runtime/references" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_pybind.cc" "build/include" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.cc" "build/include_order" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.cc" "whitespace/ending_newline" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.cc" "build/include" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.cc" "whitespace/braces" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.cc" "whitespace/parens" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.cc" "build/include" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.cc" "build/include" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.cc" "whitespace/ending_newline" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.cc" "build/include" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site_pybind.cc" "build/include" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_pybind.cc" "build/include" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation_pybind.cc" "build/include" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.cc" "whitespace/braces" 
+"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_pybind.cc" "build/include" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership_pybind.cc" "build/include" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.cc" "runtime/references" \ No newline at end of file diff --git a/.jenkins/check/config/filter_linklint.txt b/.jenkins/check/config/filter_linklint.txt index cea87ddfc3cc48ebf5802b2e0f09574a1873b7e3..20e8fe38ee3ce4119727ba7758dce9921d4db57e 100644 --- a/.jenkins/check/config/filter_linklint.txt +++ b/.jenkins/check/config/filter_linklint.txt @@ -3,4 +3,8 @@ https://api.colabfold.com https://a3m.mmseqs.com -https://www.mindspore.cn/community/SIG/detail/?name=mindflow+SIG \ No newline at end of file +https://www.mindspore.cn/community/SIG/detail/?name=mindflow+SIG + +https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_entity.type.html* +https://doi.org/10.1002/prot.340200303* +https://arxiv.org/pdf/2006.14616.pdf* \ No newline at end of file diff --git a/.jenkins/check/config/filter_pylint.txt b/.jenkins/check/config/filter_pylint.txt index d680aadf0a819f0fe12907b2c74a0e7754743dac..9b31777820ff14a3b2a6b66b3b5db34c0f92df09 100644 --- a/.jenkins/check/config/filter_pylint.txt +++ b/.jenkins/check/config/filter_pylint.txt @@ -161,4 +161,141 @@ "mindscience/MindSPONGE/tutorials/basic/tutorial_p05.py" "wrong-import-position" "mindscience/tests/st/mindflow/cell/test_dft.py" "wrong-import-position" "mindscience/tests/st/mindflow/cell/test_fno1d.py" "wrong-import-position" -"mindscience/tests/st/mindflow/cell/attention/test_attention.py" "wrong-import-position" +"mindscience/tests/st/mindflow/cell/attention/test_attention.py" "wrong-import-position" + +# DeepMind AlphaFold3 +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_components.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/nhmmer.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/structure_cleaning.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/mmcif.py" "multiple-statements" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/base_config.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/__init__.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/resources.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/table.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_components.py" "assigning-non-slot" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/post_processing.py" "unused-variable" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmalign.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/jackhmmer.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/post_processing.py" "unexpected-keyword-arg" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_store.py" "bad-continuation" 
+"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/covalent_bond_cleaning.py" "no-else-return" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/scoring.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/mmcif.py" "bad-continuation" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_features.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_identifiers.py" "no-else-return" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/base_config.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/bonds.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_store.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/testing/data.py" "pointless-statement" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/parsers.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/mmcif.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_realign.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/resources.py" "function-redefined" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_components.py" "unexpected-keyword-arg" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/pipeline.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/protein_data_processing.py" "bad-continuation" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/merging_features.py" "no-else-return" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/confidence_types.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/scoring.py" "unused-argument" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/folding_input.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/templates.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/pipeline.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/chemical_components.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/resources.py" "unused-argument" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmsearch.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/featurisation.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/post_processing.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/resources.py" "pointless-statement" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/structure_stores.py" "syntax-error" 
+"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_config.py" "unexpected-keyword-arg" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/subprocess_utils.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/mmcif.py" "no-else-return" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/parsing.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/testing/data.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/alignment.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/mmcif.py" "invalid-name" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/test_utils.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/inter_chain_bonds.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/testing/data.py" "function-redefined" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmsearch.py" "bad-continuation" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmsearch.py" "useless-object-inheritance" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/post_processing.py" "bad-continuation" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/periodic_table.py" "unexpected-keyword-arg" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure_tables.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmbuild.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/mmcif_names.py" "no-else-return" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/testing/data.py" "unused-argument" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/msa_tool.py" "unexpected-keyword-arg" +"mindscience/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention_base.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/struct_of_array.py" "no-else-return" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/struct_of_array.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/feat_batch.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/mapping.py" "unused-argument" +"mindscience/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_head.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/mapping.py" "invalid-name" +"mindscience/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py" "pointless-statement" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/atom_layout/atom_layout.py" "syntax-error" 
+"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_modules.py" "unused-argument" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/featurization.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_modules.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_modules.py" "unused-import" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/params.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/atom_cross_attention.py" "unused-argument" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/ms_attention.py" "unused-argument" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/ms_attention.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/utils.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/atom_cross_attention.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/common/precision.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/__init__.py" "invalid-name" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/vector.py" "unused-import" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/mapping.py" "unused-import" +"mindscience/MindSPONGE/applications/research/AlphaFold3/run_alphafold_data_test.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/struct_of_array.py" "cell-var-from-loop" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/vector.py" "invalid-name" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_head.py" "unused-import" +"mindscience/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py" "unexpected-keyword-arg" +"mindscience/MindSPONGE/applications/research/AlphaFold3/run_alphafold_test_v2.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/featurization.py" "unused-argument" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/mapping.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/mapping.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py" "unused-import" +"mindscience/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py" "function-redefined" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/mapping.py" "len-as-condition" +"mindscience/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py" "unsupported-membership-test" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/featurization.py" "redefined-argument-from-local" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/distogram_head.py" "missing-docstring" 
+"mindscience/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py" "unused-argument" +"mindscience/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py" "unsupported-assignment-operation" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/struct_of_array.py" "unused-argument" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/__init__.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py" "multiple-statements" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/mapping.py" "no-else-return" +"mindscience/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py" "too-many-function-args" +"mindscience/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/mapping.py" "unused-variable" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/features.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/model.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/template_modules.py" "unused-argument" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/confidence_head.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_transformer.py" "syntax-error" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/utils.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mmcif_metadata.py" "bad-continuation" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/template_modules.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/load_ckpt.py" "unused-import" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/load_ckpt.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/rigid_matrix_vector.py" "bad-whitespace" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/load_ckpt.py" "no-value-for-parameter" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/model.py" "unused-argument" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/modules.py" "unused-argument" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit_base.py" "pointless-statement" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/modules.py" "missing-docstring" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/load_ckpt.py" "protected-access" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/model.py" "redefined-argument-from-local" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/triangle.py" "unused-argument" +"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit_base.py" "unused-argument" 
+"mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mmcif_metadata.py" "unused-argument" \ No newline at end of file diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt index 1d50c2b879f4acc6e2ceaed90ede7c2a502c9e8f..7503472fde7121229d8ba57d7cbe325ef07b7fcc 100644 --- a/.jenkins/check/config/whitelizard.txt +++ b/.jenkins/check/config/whitelizard.txt @@ -13,3 +13,22 @@ mindscience/MindChemistry/mindchemistry/so2_conv/wigner.py:wigner_D mindscience/MindSPONGE/src/sponge/system/modelling/mol2_parser.py:mol2parser mindscience/MindSPONGE/src/sponge/system/modelling/hadder.py:add_hydrogen mindscience/MindSPONGE/src/sponge/system/molecule/molecule.py:build_system +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/folding_input.py:from_alphafoldserver_fold_job +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/folding_input.py:from_json +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.cc:alphafold3::GetEscapeQuote +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.cc:alphafold3::CifDict::ToString +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.cc:alphafold3::Gather +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.cc:alphafold3::CifDictGetArray +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.cc:alphafold3::RegisterModuleCifDict +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.cc:alphafold3::FixArginine::Fix +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_lib.cc:alphafold3::MmcifLayout::Create +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_lib.cc:alphafold3::GetBondAtomIndices +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/pipeline.py:__init__ +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/folding_input.py:from_json +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.cc:alphafold3::CifDict::ToString +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.cc:alphafold3::RegisterModuleCifDict +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_lib.cc:alphafold3::MmcifLayout::Create +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_lib.cc:alphafold3::GetBondAtomIndices +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/parsing.py:from_res_arrays +mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/model.py:get_inference_result +mindscience/MindSPONGE/applications/research/AlphaFold3/run_alphafold_test_v2.py:test_inference diff --git a/.jenkins/test/config/flow_config/dependent_packages.yaml b/.jenkins/test/config/flow_config/dependent_packages.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9e70df642d476be60aeb9313c8938893f69eee28 --- /dev/null +++ b/.jenkins/test/config/flow_config/dependent_packages.yaml @@ -0,0 +1,2 @@ +mindspore: + 
'/mindspore/mindspore/version/202503/20250326/master_20250326010019_b91eca2945e61641319f9887aa76a1ccb38604d3_newest/' \ No newline at end of file diff --git a/MindChem/applications/orb/Parallel_Implementation.md b/MindChem/applications/orb/Parallel_Implementation.md new file mode 100644 index 0000000000000000000000000000000000000000..285183f0f6a68dc9179d6763412d6c080dc4e48b --- /dev/null +++ b/MindChem/applications/orb/Parallel_Implementation.md @@ -0,0 +1,121 @@ +# ORB模型并行训练说明文档 + +本文档说明了ORB模型从单卡训练到多卡并行训练的实现方案、启动方式以及性能提升结果。 + +## 一、并行实现 + +对比`finetune.py`和`finetune_parallel.py`,主要有以下几处改动: + +1、引入并行训练所需的mindspore通信模块: + +```python +from mindspore.communication import init +from mindspore.communication import get_rank, get_group_size +``` + +2、训练步骤中增加梯度聚合: + +```python +# 单卡版本 +grad_fn = ms.value_and_grad(model.loss, None, optimizer.parameters, has_aux=True) + +# 并行版本 +grad_fn = ms.value_and_grad(model.loss, None, optimizer.parameters, has_aux=True) +grad_reducer = nn.DistributedGradReducer(optimizer.parameters) # 新增梯度规约器 +``` + +3、数据加载时实现数据分片: + +```python +# 单卡版本 +dataloader = [base.batch_graphs([dataset[j] for j in range(i, min(i + batch_size, len(dataset)))]) + for i in range(0, len(dataset), batch_size)] + +# 并行版本 +rank_id = get_rank() +rank_size = get_group_size() +dataloader = [[dataset[j] for j in range(i, min(i + batch_size, len(dataset)))] + for i in range(0, len(dataset), batch_size)] +dataloader = [base.batch_graphs( + data[rank_id*len(data)//rank_size : (rank_id+1)*len(data)//rank_size] +) for data in dataloader] +``` + +4、初始化并行训练环境: + +```python +ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True) +init() +``` + +## 二、启动方式 + +设置训练参数 + +> 1. 修改`configs/config_parallel.yaml`中的参数: +> a. 设置`data_path`字段指定训练和测试数据集 +> b. 设置`checkpoint_path`指定预训练模型权重路径 +> c. 根据需要调整其他训练参数 +> 2. 修改`run_parallel.sh`中的并行数: +> a. 通过`--worker_num=4 --local_worker_num=4`设置使用卡的数量 + +启动训练 + +```bash +pip install -r requirement.txt +bash run_parallel.sh +``` + +## 三、性能提升 + +单卡训练结果如下所示: + +```log +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Model has 25213610 trainable parameters. +Epoch: 0/100, + train_metrics: {'data_time': 0.00010895108183224995, 'train_time': 386.58018293464556, 'energy_reference_mae': 5.598883946736653, 'energy_mae': 3.3611322244008384, 'energy_mae_raw': 103.14391835530598, 'stress_mae': 41.36046473185221, 'stress_mae_raw': 12.710869789123535, 'node_mae': 0.02808943825463454, 'node_mae_raw': 0.0228044210622708, 'node_cosine_sim': 0.7026202281316122, 'fwt_0.03': 0.23958333333333334, 'loss': 44.74968592325846} + val_metrics: {'energy_reference_mae': 5.316623687744141, 'energy_mae': 3.594848871231079, 'energy_mae_raw': 101.00129699707031, 'stress_mae': 30.630516052246094, 'stress_mae_raw': 9.707925796508789, 'node_mae': 0.017718862742185593, 'node_mae_raw': 0.014386476017534733, 'node_cosine_sim': 0.5506304502487183, 'fwt_0.03': 0.375, 'loss': 34.24308395385742} + +... 
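+```
+
+(补充说明)上文「一、并行实现」第 2 点只展示了 `grad_reducer` 的创建,没有给出它在训练步中的使用位置。下面是一个示意性的并行训练步写法(非仓库源码:`model`、`optimizer`、`batch`、`train_step` 等名称为假设,`model.loss` 的返回形式也是按上文 `has_aux=True` 的用法推测的),核心是梯度先经 `DistributedGradReducer` 在多卡间聚合,再交给优化器更新:
+
+```python
+import mindspore as ms
+from mindspore import nn
+from mindspore.communication import init
+
+# 与上文第 4 点一致:配置数据并行并初始化通信
+ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
+init()
+
+# model、optimizer 假定已由脚本其余部分构建
+grad_fn = ms.value_and_grad(model.loss, None, optimizer.parameters, has_aux=True)
+grad_reducer = nn.DistributedGradReducer(optimizer.parameters)
+
+def train_step(batch):
+    (loss, aux), grads = grad_fn(batch)
+    grads = grad_reducer(grads)  # 多卡间 AllReduce 聚合梯度
+    optimizer(grads)             # 用聚合后的梯度更新参数
+    return loss, aux
+```
+
+```log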
+ +Epoch: 99/100, + train_metrics: {'data_time': 7.802306208759546e-05, 'train_time': 59.67856075416785, 'energy_reference_mae': 5.5912095705668134, 'energy_mae': 0.007512244085470836, 'energy_mae_raw': 0.21813046435515085, 'stress_mae': 0.7020445863405863, 'stress_mae_raw': 2.222463607788086, 'node_mae': 0.04725319395462672, 'node_mae_raw': 0.042800972859064736, 'node_cosine_sim': 0.3720853428045909, 'fwt_0.03': 0.09895833333333333, 'loss': 0.7568100094795227} + val_metrics: {'energy_reference_mae': 5.308632850646973, 'energy_mae': 0.27756747603416443, 'energy_mae_raw': 3.251189708709717, 'stress_mae': 2.8720269203186035, 'stress_mae_raw': 9.094478607177734, 'node_mae': 0.05565642938017845, 'node_mae_raw': 0.05041291564702988, 'node_cosine_sim': 0.212838813662529, 'fwt_0.03': 0.19499999284744263, 'loss': 3.2052507400512695} +Checkpoint saved to orb_ckpts/ +Training time: 7333.08717 seconds + +``` + +四卡并行训练结果如下所示: + +```log +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Model has 25213607 trainable parameters. +Model has 25213607 trainable parameters. +Model has 25213607 trainable parameters. +Model has 25213607 trainable parameters. + +... + +Training time: 2375.89474 seconds +Training time: 2377.02413 seconds +Training time: 2377.22778 seconds +Training time: 2376.63176 seconds + +``` + +在相同的训练配置下,并行训练相比单卡训练取得了显著的性能提升: + +- 单卡训练耗时:7293.28995 seconds +- 4卡并行训练耗时:2377.22778 seconds +- 性能提升:67.40% +- 加速比:3.07倍 diff --git a/MindChem/applications/orb/README.md b/MindChem/applications/orb/README.md new file mode 100644 index 0000000000000000000000000000000000000000..afb66a8648eb9650b4104d0d7d7cf8cb986266a3 --- /dev/null +++ b/MindChem/applications/orb/README.md @@ -0,0 +1,171 @@ + +# 模型名称 + +> Orb + +## 介绍 + +> 材料科学中,设计新型功能材料一直是新兴技术的关键部分。然而,传统的从头算计算方法在设计新型无机材料时速度慢且难以扩展到实际规模的系统。近年来,深度学习方法在多个领域展示了其强大的能力,能够通过并行架构高效运行。ORB模型的核心创新在于将这种深度学习方法应用于材料建模,通过可扩展的图神经网络架构学习原子间相互作用的复杂性。ORB模型是一个基于图神经网络(GNN)的机器学习力场(MLFF),设计为通用的原子间势能模型,适用于多种模拟任务(几何优化、蒙特卡洛模拟和分子动力学模拟)。该模型的输入是一个图结构,包含原子的位置、类型以及系统配置(如晶胞尺寸和边界条件);输出包括系统的总能量、每个原子的力向量以及单元格应力。与现有的开源神经网络势能模型(如MACE)相比,ORB模型在大系统规模下的速度提高了3-6倍。在Matbench Discovery基准测试中,ORB模型的误差比其他方法降低了31%,并且在发布时成为该基准测试的最新最佳模型。ORB模型在零样本评估中表现出色,即使在没有针对特定任务进行微调的情况下,也能在高温度非周期分子的分子动力学模拟中保持稳定。 + +![Orb模型预测自由能](docs/orb.png) + +> 上图中:(a) 通过Widom插入法在Mg-MOF-74中获得的MACE + D3(左)和Orb-D3(右)自由能表面。开放金属位点附近的蓝色区域代表最低自由能,表明这些是CO2的优势吸附位点。(b) CO2在Mg-MOF-74中的吸附位置,展示了通过Widom插入法获得的两个最有利的吸附位点,其吸附能分别为-54.5 kJ/mol和-54.4 kJ/mol。虽然Orb和MACE预测的能量极小值位置相似,但ORB的自由能最小值与实验测得的吸附热(-44 kJ/mol)数值更为接近。 + +## 环境要求 + +> 1. 安装`mindspore(2.5.0)` +> 2. 安装`mindchemistry` +> 3. 安装依赖包:`pip install -r requirement.txt` + +## 快速入门 + +> 1. 在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/dataset/)下载相应的数据集并放在`dataset`目录下 +> 2. 在[模型链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/orb_ckpts/)下载orb预训练模型ckpt并放在`orb_ckpts`目录下 +> 3. 安装依赖包:`pip install -r requirement.txt` +> 4. 单卡训练命令: `bash run.sh` +> 5. 
多卡训练命令: `bash run_parallel.sh` +> 6. 评估命令: `python evaluate.py` +> 7. 模型预测结果会存在`results`目录下 + +### 代码目录结构 + +```text +代码主要模块在src文件夹下,其中dataset文件夹下是数据集,orb_ckpts文件夹下是预训练模型和训练好的模型权重文件,configs文件夹下是各代码的参数配置文件。 + +orb_models # 模型名 +├── dataset + ├── train_mptrj_ase.db # 微调阶段训练数据集 + └── val_mptrj_ase.db # 微调阶段测试数据集 +├── orb_ckpts + └── orb-mptraj-only-v2.ckpt # 预训练模型checkpoint +├── configs + ├── config.yaml # 单卡训练参数配置文件 + ├── config_parallel.yaml # 多卡并行训练参数配置文件 + └── config_eval.yaml # 推理参数配置文件 +├── src + ├── __init__.py + ├── ase_dataset.py # 处理和加载数据集 + ├── atomic_system.py # 定义原子系统的数据结构 + ├── base.py # 基础类定义 + ├── featurization_utilities.py # 提供将原子系统转换为特征向量的工具 + ├── pretrained.py # 预训练模型相关函数 + ├── property_definitions.py # 定义原子系统中各种物理性质的计算方式和命名规则 + ├── trainer.py # 模型loss类定义 + ├── segment_ops.py # 提供对数据进行分段处理的工具 + └── utils.py # 工具模块 +├── finetune.py # 模型微调代码 +├── evaluate.py # 模型推理代码 +├── run.sh # 单卡训练启动脚本 +├── run_parallel.sh # 多卡并行训练启动脚本 +└── requirement.txt # 环境 +``` + +## 下载数据集 + +在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/dataset/)下载训练和测试数据集放置于当前路径的dataset文件夹下(如果没有需要自己手动创建);在[模型链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/orb_ckpts/)下载orb预训练模型`orb-mptraj-only-v2.ckpt`放置于当前路径的orb_ckpts文件夹下(如果没有需要自己手动创建);文件路径参考[代码目录结构](#代码目录结构) + +## 训练过程 + +### 单卡训练 + +更改`configs/config.yaml`文件中训练参数: + +> 1. 设置微调阶段的训练和测试数据集,见`data_path`字段 +> 2. 设置训练加载的预训练模型权重文件,更改`checkpoint_path`路径字段 +> 3. 其它训练设置见Training Configuration部分 + +```bash +pip install -r requirement.txt +bash run.sh +``` + +代码运行结果如下所示: + +```log +============================================================================================================== +Please run the script as: +bash run.sh +============================================================================================================== +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Model has 25213610 trainable parameters. +Epoch: 0/100, + train_metrics: {'data_time': 0.00010895108183224995, 'train_time': 386.58018293464556, 'energy_reference_mae': 5.598883946736653, 'energy_mae': 3.3611322244008384, 'energy_mae_raw': 103.14391835530598, 'stress_mae': 41.36046473185221, 'stress_mae_raw': 12.710869789123535, 'node_mae': 0.02808943825463454, 'node_mae_raw': 0.0228044210622708, 'node_cosine_sim': 0.7026202281316122, 'fwt_0.03': 0.23958333333333334, 'loss': 44.74968592325846} + val_metrics: {'energy_reference_mae': 5.316623687744141, 'energy_mae': 3.594848871231079, 'energy_mae_raw': 101.00129699707031, 'stress_mae': 30.630516052246094, 'stress_mae_raw': 9.707925796508789, 'node_mae': 0.017718862742185593, 'node_mae_raw': 0.014386476017534733, 'node_cosine_sim': 0.5506304502487183, 'fwt_0.03': 0.375, 'loss': 34.24308395385742} + +... 
+ +Epoch: 99/100, + train_metrics: {'data_time': 7.802306208759546e-05, 'train_time': 59.67856075416785, 'energy_reference_mae': 5.5912095705668134, 'energy_mae': 0.007512244085470836, 'energy_mae_raw': 0.21813046435515085, 'stress_mae': 0.7020445863405863, 'stress_mae_raw': 2.222463607788086, 'node_mae': 0.04725319395462672, 'node_mae_raw': 0.042800972859064736, 'node_cosine_sim': 0.3720853428045909, 'fwt_0.03': 0.09895833333333333, 'loss': 0.7568100094795227} + val_metrics: {'energy_reference_mae': 5.308632850646973, 'energy_mae': 0.27756747603416443, 'energy_mae_raw': 3.251189708709717, 'stress_mae': 2.8720269203186035, 'stress_mae_raw': 9.094478607177734, 'node_mae': 0.05565642938017845, 'node_mae_raw': 0.05041291564702988, 'node_cosine_sim': 0.212838813662529, 'fwt_0.03': 0.19499999284744263, 'loss': 3.2052507400512695} +Checkpoint saved to orb_ckpts/ +Training time: 7333.08717 seconds +``` + +### 多卡并行训练 + +更改`configs/config_parallel.yaml`和`run_parallel.sh`文件中训练参数: + +> 1. 设置微调阶段的训练和测试数据集,见`data_path`字段 +> 2. 设置训练加载的预训练模型权重文件,更改`checkpoint_path`路径字段 +> 3. 其它训练设置见Training Configuration部分 +> 4. 修改`run_parallel.sh`文件中`--worker_num=4 --local_worker_num=4`来设置调用的卡的数量 + +```bash +pip install -r requirement.txt +bash run_parallel.sh +``` + +代码运行结果如下所示: + +```log +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/train_mptrj_ase.dbTotal train dataset size: 800 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Model has 25213607 trainable parameters. +Model has 25213607 trainable parameters. +Model has 25213607 trainable parameters. +Model has 25213607 trainable parameters. + +... + +Training time: 2375.89474 seconds +Training time: 2377.02413 seconds +Training time: 2377.22778 seconds +Training time: 2376.63176 seconds +``` + +### 推理 + +更改`configs/config_eval.yaml`文件中推理参数: + +> 1. 设置测试数据集,见`val_data_path`字段 +> 2. 设置推理加载的预训练模型权重文件,更改`checkpoint_path`路径字段 +> 3. 其它训练设置见Evaluating Configuration部分 + +```bash +python evaluate.py +``` + +代码运行结果如下所示: + +```log +Loading datasets: dataset/val_mptrj_ase.dbTotal train dataset size: 200 samples +Model has 25213607 trainable parameters. 
+.Validation loss: 0.89507836 + energy_reference_mae: 5.3159098625183105 + energy_mae: 0.541229784488678 + energy_mae_raw: 4.244375228881836 + stress_mae: 0.22862032055854797 + stress_mae_raw: 10.575761795043945 + node_mae: 0.12522821128368378 + node_mae_raw: 0.04024107754230499 + node_cosine_sim: 0.38037967681884766 + fwt_0.03: 0.22499999403953552 + loss: 0.8950783610343933 +``` diff --git a/MindChem/applications/orb/configs/config.yaml b/MindChem/applications/orb/configs/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cbd4189a3ffa819b92524d4a815f7c107cbf0432 --- /dev/null +++ b/MindChem/applications/orb/configs/config.yaml @@ -0,0 +1,39 @@ +# Training Configuration +train_data_path: dataset/train_mptrj_ase.db +val_data_path: dataset/val_mptrj_ase.db +num_workers: 8 +batch_size: 64 +gradient_clip_val: 0.5 +max_epochs: 100 +checkpoint_path: orb_ckpts/ +lr: 3.0e-4 +random_seed: 1234 + +# Model Configuration +model: + # Energy Head Configuration + energy_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "energy" + node_aggregation: "mean" + reference_energy_name: "vasp-shifted" + train_reference: true + predict_atom_avg: true + + # Node Head Configuration + node_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "forces" + remove_mean: true + + # Stress Head Configuration + stress_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "stress" + compute_stress: true diff --git a/MindChem/applications/orb/configs/config_eval.yaml b/MindChem/applications/orb/configs/config_eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e98c5f0b0d98f1c2dd87bcb095b0c19036ef2b4 --- /dev/null +++ b/MindChem/applications/orb/configs/config_eval.yaml @@ -0,0 +1,40 @@ +# Evaluating Configuration +mode: "PYNATIVE" +device_target: "Ascend" +device_id: 0 +# Dataset config +val_data_path: dataset/val_mptrj_ase.db +num_workers: 8 +batch_size: 64 +checkpoint_path: orb_ckpts/orb-ft-checkpoint_epoch99.ckpt +random_seed: 1234 +output_dir: results/ + +# Model Configuration +model: + # Energy Head Configuration + energy_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "energy" + node_aggregation: "mean" + reference_energy_name: "vasp-shifted" + train_reference: true + predict_atom_avg: true + + # Node Head Configuration + node_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "forces" + remove_mean: true + + # Stress Head Configuration + stress_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "stress" + compute_stress: true diff --git a/MindChem/applications/orb/configs/config_parallel.yaml b/MindChem/applications/orb/configs/config_parallel.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c6a5e0857b52d3cf543e0e0388160e6764184c1c --- /dev/null +++ b/MindChem/applications/orb/configs/config_parallel.yaml @@ -0,0 +1,39 @@ +# Training Configuration +train_data_path: dataset/train_mptrj_ase.db +val_data_path: dataset/val_mptrj_ase.db +num_workers: 8 +batch_size: 256 +gradient_clip_val: 0.5 +max_epochs: 100 +checkpoint_path: orb_ckpts/ +lr: 3.0e-4 +random_seed: 666 + +# Model Configuration +model: + # Energy Head Configuration + energy_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "energy" + node_aggregation: "mean" + reference_energy_name: "vasp-shifted" + train_reference: true + predict_atom_avg: true + + # Node Head Configuration + node_head: + latent_dim: 256 + 
num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "forces" + remove_mean: true + + # Stress Head Configuration + stress_head: + latent_dim: 256 + num_mlp_layers: 1 + mlp_hidden_dim: 256 + target: "stress" + compute_stress: true diff --git a/MindChem/applications/orb/docs/orb.png b/MindChem/applications/orb/docs/orb.png new file mode 100644 index 0000000000000000000000000000000000000000..6f9026b83e31dad2f48d626388e15262a831d02c Binary files /dev/null and b/MindChem/applications/orb/docs/orb.png differ diff --git a/MindChem/applications/orb/evaluate.py b/MindChem/applications/orb/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..4e7157a82100284ff1c35106b98bae9643b82f23 --- /dev/null +++ b/MindChem/applications/orb/evaluate.py @@ -0,0 +1,102 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Evaluate.""" + +import argparse +import logging +import os +import pickle + +import mindspore as ms +from mindspore import context + +from finetune import build_loader +from src import pretrained, utils, trainer + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +def evaluate(args): + """Evaluate the model.""" + # set seed + utils.seed_everything(args.random_seed) + + # load dataset + val_loader = build_loader( + dataset_path=args.val_data_path, + num_workers=args.num_workers, + batch_size=1000, + target_config={"graph": ["energy", "stress"], "node": ["forces"]}, + augmentation=False, # do not apply random augment + shuffle=False, + ) + + # load trained model + if args.checkpoint_path is None: + raise ValueError("Checkpoint path is not provided.") + model = pretrained.orb_mptraj_only_v2(args.checkpoint_path) + model_params = sum(p.size for p in model.trainable_params() if p.requires_grad) + logging.info("Model has %d trainable parameters.", model_params) + + # begin evaluation + model.set_train(False) + val_iter = iter(val_loader) + val_batch = next(val_iter) + + output = model( + val_batch.edge_features, + val_batch.node_features, + val_batch.senders, + val_batch.receivers, + val_batch.n_node, + ) + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir, exist_ok=True) + save_path = os.path.join(args.output_dir, "predictions.pkl") + with open(save_path, "wb") as f: + pickle.dump(output, f) + + loss_fn = trainer.OrbLoss(model) + loss, logs = loss_fn.loss(val_batch) + print(f"Validation loss: {loss}") + for key, value in logs.items(): + print(f" {key}: {value}") + + +def main(): + """Main.""" + parser = argparse.ArgumentParser( + description="Evaluate orb model", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--config", type=str, default="configs/config_eval.yaml", help="Path to config file" + ) + args = parser.parse_args() + args = utils.load_cfg(args.config) + 
ms.set_context( + mode=context.PYNATIVE_MODE, + device_target=args.device_target, + device_id=args.device_id, + pynative_synchronize=True, + ) + evaluate(args) + + +if __name__ == "__main__": + main() diff --git a/MindChem/applications/orb/finetune.py b/MindChem/applications/orb/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..cb025e9dfd248bb6b9c958e449fedb054fe12eef --- /dev/null +++ b/MindChem/applications/orb/finetune.py @@ -0,0 +1,311 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Finetuning loop.""" + +import argparse +import logging +import warnings +import os +import timeit +from typing import Dict, Optional + +import mindspore as ms +from mindspore import nn, ops, context +import mindspore.dataset as ds +from mindspore.communication import init +from mindspore.communication import get_rank, get_group_size + +from src import base, pretrained, utils +from src.ase_dataset import AseSqliteDataset, BufferData +from src.trainer import OrbLoss + + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +def finetune( + model: nn.Cell, + loss_fn: Optional[nn.Cell], + optimizer: nn.Optimizer, + train_dataloader: ds.GeneratorDataset, + val_dataloader: ds.GeneratorDataset, + lr_scheduler: Optional[ms.experimental.optim.lr_scheduler] = None, + clip_grad: Optional[float] = None, + log_freq: float = 10, + parallel_mode: str = "NONE", +): + """Train for a fixed number of steps. + + Args: + model: The model to optimize. + loss_fn: The loss function to use. + optimizer: The optimizer to use for the model. + train_dataloader: A Dataloader, which may be infinite if num_steps is passed. + val_dataloader: A Dataloader for validation. + lr_scheduler: Optional, a Learning rate scheduler for modifying the learning rate. + clip_grad: Optional, the gradient clipping threshold. + log_freq: The logging frequency for step metrics. + parallel_mode: The parallel mode to use, e.g., "DATA_PARALLEL" or "NONE". + + Returns + A dictionary of metrics. 
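+
+    Note:
+        When ``parallel_mode`` is "DATA_PARALLEL", gradients are aggregated across
+        devices with ``nn.DistributedGradReducer`` before the optimizer update.
+        When ``clip_grad`` is set, the clipping hooks registered through
+        ``utils.gradient_clipping`` are removed again at the end of the epoch.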
+ """ + if clip_grad is not None: + hook_handles = utils.gradient_clipping(model, clip_grad) + + train_metrics = utils.ScalarMetricTracker() + val_metrics = utils.ScalarMetricTracker() + + epoch_metrics = { + "data_time": 0.0, + "train_time": 0.0, + } + + # Get gradient function + grad_fn = ms.value_and_grad(loss_fn.loss, None, optimizer.parameters, has_aux=True) + if parallel_mode == "DATA_PARALLEL": + grad_reducer = nn.DistributedGradReducer(optimizer.parameters) + + # Define function of one-step training + def train_step(data, label=None): + (loss, val_logs), grads = grad_fn(data, label) + if parallel_mode == "DATA_PARALLEL": + grads = grad_reducer(grads) + optimizer(grads) + return loss, val_logs + + step_begin = timeit.default_timer() + for i, batch in enumerate(train_dataloader): + epoch_metrics["data_time"] += timeit.default_timer() - step_begin + # Reset metrics so that it reports raw values for each step but still do averages on + # the gradient accumulation. + if i % log_freq == 0: + train_metrics.reset() + + model.set_train() + loss, train_logs = train_step(batch) + + epoch_metrics["train_time"] += timeit.default_timer() - step_begin + train_metrics.update(epoch_metrics) + train_metrics.update(train_logs) + + if ops.isnan(loss): + raise ValueError("nan loss encountered") + + if lr_scheduler is not None: + lr_scheduler.step() + step_begin = timeit.default_timer() + + if clip_grad is not None: + for h in hook_handles: + h.remove() + + # begin evaluation + model.set_train(False) + val_iter = iter(val_dataloader) + val_batch = next(val_iter) + loss, val_logs = loss_fn.loss(val_batch) + val_metrics.update(val_logs) + + return train_metrics.get_metrics(), val_metrics.get_metrics() + + +def build_loader( + dataset_path: str, + num_workers: int, + batch_size: int, + augmentation: Optional[bool] = True, + target_config: Optional[Dict] = None, + shuffle: Optional[bool] = True, + parallel_mode: str = "NONE", + **kwargs, +) -> ds.GeneratorDataset: + """Builds the dataloader from a config file. + + Args: + dataset_path: Dataset path. + num_workers: The number of workers for each dataset. + batch_size: The batch_size config for each dataset. + augmentation: If rotation augmentation is used. + target_config: The target config. + shuffle: If the dataset should be shuffled. + parallel_mode: The parallel mode to use, e.g., "DATA_PARALLEL" or "NONE". + + Returns: + The Dataloader. + """ + log_loading = f"Loading datasets: {dataset_path} with {num_workers} workers. " + dataset = AseSqliteDataset( + dataset_path, target_config=target_config, augmentation=augmentation, **kwargs + ) + + log_loading += f"Total dataset size: {len(dataset)} samples" + logging.info(log_loading) + + dataset = BufferData(dataset, shuffle=shuffle) + if parallel_mode == "DATA_PARALLEL": + rank_id = get_rank() + rank_size = get_group_size() + dataloader = [ + [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))] \ + for i in range(0, len(dataset), batch_size) + ] + dataloader = [ + base.batch_graphs( + data[rank_id*len(data)//rank_size : (rank_id+1)*len(data)//rank_size] + ) for data in dataloader + ] + else: + dataloader = [ + base.batch_graphs( + [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))] + ) for i in range(0, len(dataset), batch_size) + ] + + return dataloader + + +def run(args, parallel_mode="NONE"): + """Training Loop. + + Args: + config (DictConfig): Config for training loop. + parallel_mode (str): The parallel mode to use, e.g., "DATA_PARALLEL" or "NONE". 
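+
+    The loop builds the train/val loaders, restores the pretrained
+    ``orb-mptraj-only-v2.ckpt`` weights from ``checkpoint_path``, fine-tunes for
+    ``max_epochs`` epochs, and saves a checkpoint after the final epoch.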
+ """ + utils.seed_everything(args.random_seed) + + # Load dataset + train_loader = build_loader( + dataset_path=args.train_data_path, + num_workers=args.num_workers, + batch_size=args.batch_size, + target_config={"graph": ["energy", "stress"], "node": ["forces"]}, + augmentation=True, + ) + val_loader = build_loader( + dataset_path=args.val_data_path, + num_workers=args.num_workers, + batch_size=1000, + target_config={"graph": ["energy", "stress"], "node": ["forces"]}, + augmentation=False, + shuffle=False, + ) + num_steps = len(train_loader) + + # Instantiate model + pretrained_weights_path = os.path.join(args.checkpoint_path, "orb-mptraj-only-v2.ckpt") + model = pretrained.orb_mptraj_only_v2(pretrained_weights_path) + loss_fn = OrbLoss(model) + model_params = sum(p.size for p in model.trainable_params() if p.requires_grad) + logging.info("Model has %d trainable parameters.", model_params) + + total_steps = args.max_epochs * num_steps + optimizer, lr_scheduler = utils.get_optim(args.lr, total_steps, model) + + # Fine-tuning loop + start_epoch = 0 + train_time = timeit.default_timer() + for epoch in range(start_epoch, args.max_epochs): + train_metrics, val_metrics = finetune( + model=model, + loss_fn=loss_fn, + optimizer=optimizer, + train_dataloader=train_loader, + val_dataloader=val_loader, + lr_scheduler=lr_scheduler, + clip_grad=args.gradient_clip_val, + parallel_mode=parallel_mode, + ) + print(f'Epoch: {epoch}/{args.max_epochs}, \n train_metrics: {train_metrics}\n val_metrics: {val_metrics}') + + # Save checkpoint from last epoch + if epoch == args.max_epochs - 1: + # create ckpts folder if it does not exist + if not os.path.exists(args.checkpoint_path): + os.makedirs(args.checkpoint_path) + if parallel_mode == "DATA_PARALLEL": + rank_id = get_rank() + rank_size = get_group_size() + ms.save_checkpoint( + model, + os.path.join( + args.checkpoint_path, + f"orb-ft-parallel[{rank_id}-{rank_size}]-checkpoint_epoch{epoch}.ckpt" + ), + ) + else: + ms.save_checkpoint( + model, + os.path.join(args.checkpoint_path, f"orb-ft-checkpoint_epoch{epoch}.ckpt"), + ) + logging.info("Checkpoint saved to %s", args.checkpoint_path) + logging.info("Training time: %.5f seconds", timeit.default_timer() - train_time) + + +def main(): + """Main.""" + parser = argparse.ArgumentParser( + description="Finetune orb model", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--config", type=str, default="configs/config.yaml", help="Path to config file" + ) + parser.add_argument( + "--device_target", + type=str, + default="Ascend", + help="The target device to run, support 'Ascend'" + ) + parser.add_argument( + "--device_id", default=0, type=int, help="device index to use." 
+ ) + parser.add_argument( + "--parallel_mode", + type=str, + default="NONE", + choices=["DATA_PARALLEL", "NONE"], + help="Parallel mode, support 'DATA_PARALLEL', 'NONE'" + ) + args = parser.parse_args() + + if args.parallel_mode.upper() == "DATA_PARALLEL": + ms.set_context( + mode=context.PYNATIVE_MODE, + device_target=args.device_target, + pynative_synchronize=True, + ) + # Set parallel context + ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True) + init() + ms.set_seed(1) + else: + ms.set_context( + mode=context.PYNATIVE_MODE, + device_target=args.device_target, + device_id=args.device_id, + pynative_synchronize=True, + ) + configs = utils.load_cfg(args.config) + warnings.filterwarnings("ignore") + + run(configs, args.parallel_mode) + + +if __name__ == "__main__": + main() diff --git a/MindChem/applications/orb/requirement.txt b/MindChem/applications/orb/requirement.txt new file mode 100644 index 0000000000000000000000000000000000000000..e3f1a95c78cbcf5fb4b875a23e8d8c3853b64ac0 --- /dev/null +++ b/MindChem/applications/orb/requirement.txt @@ -0,0 +1,8 @@ +python>=3.10 +cached_path>=1.6.7 +ase>=3.24.0 +numpy>=1.26.4 +scipy>=1.15.1 +dm-tree==0.1.8 +tqdm>=4.66.5 +mindspore==2.5.0 \ No newline at end of file diff --git a/MindChem/applications/orb/run.sh b/MindChem/applications/orb/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..e21df131ac10c208179285b99d65658116eb6b8a --- /dev/null +++ b/MindChem/applications/orb/run.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +export GLOG_v=3 +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run.sh" +echo "==============================================================================================================" + +python finetune.py --device_target Ascend --device_id 7 diff --git a/MindChem/applications/orb/run_parallel.sh b/MindChem/applications/orb/run_parallel.sh new file mode 100644 index 0000000000000000000000000000000000000000..e49c6ab71cef9b615b5fca800f53c79a3e355f1c --- /dev/null +++ b/MindChem/applications/orb/run_parallel.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +rm -rf msrun_log +mkdir msrun_log + +export GLOG_v=3 +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run_parallel.sh" +echo "==============================================================================================================" + +msrun --worker_num=4 --local_worker_num=4 --master_port=8118 --log_dir=msrun_log --join=True --cluster_time_out=300 finetune.py --config configs/config_parallel.yaml --parallel_mode DATA_PARALLEL \ No newline at end of file diff --git a/MindChem/applications/orb/src/__init__.py b/MindChem/applications/orb/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..328a08a650120b0d9aedfd04c203ecb52649a69d --- /dev/null +++ b/MindChem/applications/orb/src/__init__.py @@ -0,0 +1,16 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" diff --git a/MindChem/applications/orb/src/ase_dataset.py b/MindChem/applications/orb/src/ase_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f61b95aa33dab27abbdefc6650efdf429d7eb988 --- /dev/null +++ b/MindChem/applications/orb/src/ase_dataset.py @@ -0,0 +1,239 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ASE dataset""" + +import os +from typing import Dict, Optional, Tuple, Union + +import ase +import ase.db +import ase.db.row +import ase.stress +import numpy as np +import mindspore as ms +from mindspore import Tensor + +from src import atomic_system, property_definitions +from src.base import AtomGraphs +from src.utils import rand_matrix + + +class AseSqliteDataset: + """AseSqliteDataset. + + A MindSpore Dataset for reading ASE Sqlite serialized Atoms objects. + + Args: + dataset_path: Local path to read. + system_config: A config for controlling how an atomic system is represented. + target_config: A config for regression/classification targets. + augmentation: If random rotation augmentation is used. + + Returns: + An AseSqliteDataset. 
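+
+    Example (a sketch; assumes a local db such as ``dataset/train_mptrj_ase.db``)::
+
+        dataset = AseSqliteDataset(
+            "dataset/train_mptrj_ase.db",
+            target_config={"graph": ["energy", "stress"], "node": ["forces"]},
+        )
+        graph = dataset[0]  # AtomGraphs with positions, atomic numbers and targets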
+ """ + + def __init__( + self, + dataset_path: Union[str, os.PathLike], + system_config: Optional[atomic_system.SystemConfig] = None, + target_config: Optional[Dict] = None, + augmentation: Optional[bool] = True, + ): + super().__init__() + self.augmentation = augmentation + self.path = dataset_path + self.db = ase.db.connect(str(self.path), serial=True, type="db") + + self.feature_config = system_config + if target_config is None: + target_config = { + "graph": ["energy", "stress"], + "node": ["forces"], + "edge": [], + } + self.target_config = target_config + + def __getitem__(self, idx) -> AtomGraphs: + """Fetch an item from the db. + + Args: + idx: An index to fetch from the db file and convert to an AtomGraphs. + + Returns: + A AtomGraphs object containing everything the model needs as input, + positions and atom types and other auxiliary information, such as + fine tuning targets, or global graph features. + """ + # Sqlite db is 1 indexed. + row = self.db.get(idx + 1) + atoms = row.toatoms() + node_properties = property_definitions.get_property_from_row( + self.target_config["node"], row + ) + graph_property_dict = {} + for target_property in self.target_config["graph"]: + system_properties = property_definitions.get_property_from_row( + target_property, row + ) + # transform stress to voigt6 representation + if target_property == "stress" and len(system_properties.reshape(-1)) == 9: + system_properties = Tensor( + ase.stress.full_3x3_to_voigt_6_stress(system_properties.reshape(3, 3)), + dtype=ms.float32, + ).reshape(1, -1) + graph_property_dict[target_property] = system_properties + extra_targets = { + "node": {"forces": node_properties}, + "edge": {}, + "graph": graph_property_dict, + } + if self.augmentation: + atoms, extra_targets = random_rotations_with_properties(atoms, extra_targets) + + atom_graph = atomic_system.ase_atoms_to_atom_graphs( + atoms, + system_id=idx, + brute_force_knn=False, + ) + atom_graph = self._add_extra_targets(atom_graph, extra_targets) + + return atom_graph + + def get_atom(self, idx: int) -> ase.Atoms: + """Return the Atoms object for the dataset index.""" + row = self.db.get(idx + 1) + return row.toatoms() + + def get_atom_and_metadata(self, idx: int) -> Tuple[ase.Atoms, Dict]: + """Return the Atoms object plus a dict of metadata for the dataset index.""" + row = self.db.get(idx + 1) + return row.toatoms(), row.data + + def __len__(self) -> int: + """Return the dataset length.""" + return len(self.db) + + def __repr__(self) -> str: + """String representation of class.""" + return f"AseSqliteDataset(path={self.path})" + + def _add_extra_targets( + self, + atom_graph: AtomGraphs, + extra_targets: Dict[str, Dict], + ): + """Add extra features and targets to the AtomGraphs object. + + Args: + atom_graph: AtomGraphs object to add extra features and targets to. + extra_targets: Dictionary of extra targets to add. 
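+
+        Returns:
+            AtomGraphs: a copy of ``atom_graph`` with the extra node, edge and
+            graph targets merged into its target dictionaries.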
+ """ + node_targets = ( + atom_graph.node_targets if atom_graph.node_targets is not None else {} + ) + node_targets = {**node_targets, **extra_targets["node"]} + + edge_targets = ( + atom_graph.edge_targets if atom_graph.edge_targets is not None else {} + ) + edge_targets = {**edge_targets, **extra_targets["edge"]} + + system_targets = ( + atom_graph.system_targets if atom_graph.system_targets is not None else {} + ) + system_targets = {**system_targets, **extra_targets["graph"]} + + return atom_graph._replace( + node_targets=node_targets if node_targets != {} else None, + edge_targets=edge_targets if edge_targets != {} else None, + system_targets=system_targets if system_targets != {} else None, + ) + + +def random_rotations_with_properties( + atoms: ase.Atoms, properties: dict +) -> Tuple[ase.Atoms, dict]: + """Randomly rotate atoms in ase.Atoms object. + + This exists to handle the case where we also need to rotate properties. + Currently we only ever do this for random rotations, but it could be extended. + + Args: + atoms (ase.Atoms): Atoms object to rotate. + properties (dict): Dictionary of properties to rotate. + """ + rand_rotation = rand_matrix(1)[0] + atoms.positions = atoms.positions @ rand_rotation + if atoms.cell is not None: + atoms.set_cell(atoms.cell.array @ rand_rotation) + + new_node_properties = {} + for key, v in properties["node"].items(): + if tuple(v.shape) == tuple(atoms.positions.shape): + new_node_properties[key] = v @ rand_rotation + else: + new_node_properties[key] = v + properties["node"] = new_node_properties + + if "stress" in properties["graph"]: + # Transformation rule of stress tensor + stress = properties["graph"]["stress"] + full_stress = ase.stress.voigt_6_to_full_3x3_stress(stress) + + if full_stress.shape != (3, 3): + full_stress = full_stress.reshape(3, 3) + + transformed = np.dot(np.dot(rand_rotation, full_stress), rand_rotation.T) + # Back to voigt notation, and shape (1, 6) for consistency with batching + properties["graph"]["stress"] = Tensor( + [ + transformed[0, 0], + transformed[1, 1], + transformed[2, 2], + transformed[1, 2], + transformed[0, 2], + transformed[0, 1], + ], + dtype=ms.float32, + ).unsqueeze(0) + + return atoms, properties + +class BufferData: + """Wrapper for a dataset. Loads all data into memory.""" + + def __init__(self, dataset, shuffle: bool = True): + """BufferData. + Args: + dataset: The dataset to wrap. + shuffle: If True, shuffle the data. + """ + self.data_objects = [dataset[i] for i in range(len(dataset))] + if shuffle: + self.shuffle() + + def __len__(self): + return len(self.data_objects) + + def __getitem__(self, index): + return self.data_objects[index] + + def shuffle(self): + """Shuffle the data.""" + indices = np.arange(len(self.data_objects)) + np.random.shuffle(indices) + self.data_objects = [self.data_objects[i] for i in indices] diff --git a/MindChem/applications/orb/src/atomic_system.py b/MindChem/applications/orb/src/atomic_system.py new file mode 100644 index 0000000000000000000000000000000000000000..b9895e9bd78be394a1ee9fa09e0c5a0b054de787 --- /dev/null +++ b/MindChem/applications/orb/src/atomic_system.py @@ -0,0 +1,222 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""atomic system""" + +from dataclasses import dataclass +from typing import List, Optional + +import ase +from ase import constraints +from ase.calculators.singlepoint import SinglePointCalculator + +import mindspore as ms +from mindspore import Tensor, mint + +from src import featurization_utilities +from src.base import AtomGraphs + + +@dataclass +class SystemConfig: + """Config controlling how to featurize a system of atoms. + + Args: + radius: radius for edge construction + max_num_neighbors: maximum number of neighbours each node can send messages to. + use_timestep_0: (unused - purely for compatibility with internal models) + """ + + radius: float + max_num_neighbors: int + use_timestep_0: bool = True + + +def atom_graphs_to_ase_atoms( + graphs: AtomGraphs, + energy: Optional[Tensor] = None, + forces: Optional[Tensor] = None, + stress: Optional[Tensor] = None, +) -> List[ase.Atoms]: + """Converts a list of graphs to a list of ase.Atoms.""" + if "atomic_numbers_embedding" in graphs.node_features: + atomic_numbers = mint.argmax( + graphs.node_features["atomic_numbers_embedding"], dim=-1 + ) + else: + atomic_numbers = graphs.node_features["atomic_numbers"] + atomic_numbers_split = mint.split(atomic_numbers, graphs.n_node.tolist()) + positions_split = mint.split(graphs.positions, graphs.n_node.tolist()) + assert graphs.tags is not None and graphs.system_features is not None + tags = mint.split(graphs.tags, graphs.n_node.tolist()) + + calculations = {} + if energy is not None: + energy_list = mint.unbind(energy.detach()) + assert len(energy_list) == len(atomic_numbers_split) + calculations["energy"] = energy_list + if forces is not None: + forces_list = mint.split(forces.detach(), graphs.n_node.tolist()) + assert len(forces_list) == len(atomic_numbers_split) + calculations["forces"] = forces_list + if stress is not None: + stress_list = mint.unbind(stress.detach()) + assert len(stress_list) == len(atomic_numbers_split) + calculations["stress"] = stress_list + + atoms_list = [] + for index, (n, p, c, t) in enumerate( + zip(atomic_numbers_split, positions_split, graphs.cell, tags) + ): + atoms = ase.Atoms( + numbers=n.detach(), + positions=p.detach(), + cell=c.detach(), + tags=t.detach(), + pbc=mint.any(c != 0), + ) + if calculations != {}: + spc = SinglePointCalculator( + atoms=atoms, + **{ + key: ( + val[index].item() + if val[index].nelement() == 1 + else val[index].numpy() + ) + for key, val in calculations.items() + }, + ) + atoms.calc = spc + atoms_list.append(atoms) + + return atoms_list + + +def ase_atoms_to_atom_graphs( + atoms: ase.Atoms, + system_config: SystemConfig = SystemConfig( + radius=10.0, max_num_neighbors=20, use_timestep_0=True + ), + system_id: Optional[int] = None, + brute_force_knn: Optional[bool] = None, +) -> AtomGraphs: + """Generate AtomGraphs from an ase.Atoms object. 
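+
+    Atomic numbers are one-hot embedded over 118 elements, positions and the unit
+    cell are cast to float32, and edges are built from a periodic-boundary radius
+    graph controlled by ``system_config.radius`` and ``system_config.max_num_neighbors``.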
+ + Args: + atoms: ase.Atoms object + system_config: SystemConfig object + system_id: Optional system_id + brute_force_knn: whether to use a 'brute force' knn approach with torch.cdist for kdtree construction. + Defaults to None, in which case brute_force is used if we a GPU is available (2-6x faster), + but not on CPU (1.5x faster - 4x slower). For very large systems, brute_force may OOM on GPU, + so it is recommended to set to False in that case. + device: device to put the tensors on. + + Returns: + AtomGraphs object + """ + atomic_numbers = ms.from_numpy(atoms.numbers).long() + atom_type_embedding = mint.nn.functional.one_hot( + atomic_numbers, num_classes=118 + ).type(ms.float32) + + node_feats = { + "atomic_numbers": atomic_numbers.to(ms.int64), + "atomic_numbers_embedding": atom_type_embedding.to(ms.float32), + "positions": ms.from_numpy(atoms.positions).to(ms.float32), + } + system_feats = {"cell": Tensor(atoms.cell.array[None, ...]).to(ms.float32)} + edge_feats, senders, receivers = _get_edge_feats( + node_feats["positions"], + system_feats["cell"][0], + system_config.radius, + system_config.max_num_neighbors, + brute_force=brute_force_knn, + ) + + num_atoms = len(node_feats["positions"]) + atom_graph = AtomGraphs( + senders=senders, + receivers=receivers, + n_node=Tensor([num_atoms]), + n_edge=Tensor([len(senders)]), + node_features=node_feats, + edge_features=edge_feats, + system_features=system_feats, + system_id=Tensor([system_id]) if system_id is not None else system_id, + fix_atoms=ase_fix_atoms_to_tensor(atoms), + tags=_get_ase_tags(atoms), + radius=system_config.radius, + max_num_neighbors=system_config.max_num_neighbors, + ) + return atom_graph + + +def _get_edge_feats( + positions: Tensor, + cell: Tensor, + radius: float, + max_num_neighbours: int, + brute_force: Optional[bool] = None, +): + """Get edge features. + + Args: + positions: (n_nodes, 3) positions tensor + cell: 3x3 tensor unit cell for a system + radius: radius for edge construction + max_num_neighbours: maximum number of neighbours each node can send messages to. + n_kdtree_workers: number of workers to use for kdtree construction. + brute_force: whether to use brute force for kdtree construction. + """ + # Construct a graph from a 3x3 supercell (as opposed to an infinite supercell). 
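+    # compute_pbc_radius_graph returns the edge index pairs (senders, receivers)
+    # and the displacement vector of every edge; the edge length feature "r" is
+    # taken from the vector norms below.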
+ ( + edge_index, + edge_vectors, + ) = featurization_utilities.compute_pbc_radius_graph( + positions=positions, + periodic_boundaries=cell, + radius=radius, + max_number_neighbors=max_num_neighbours, + brute_force=brute_force, + ) + edge_feats = { + "vectors": edge_vectors.to(ms.float32), + "r": edge_vectors.norm(dim=-1), + } + senders, receivers = edge_index[0], edge_index[1] + return edge_feats, senders, receivers + + +def _get_ase_tags(atoms: ase.Atoms) -> Tensor: + """Get tags from ase.Atoms object.""" + tags = atoms.get_tags() + if tags is not None: + tags = Tensor(tags) + else: + tags = mint.zeros(len(atoms)) + return tags + + +def ase_fix_atoms_to_tensor(atoms: ase.Atoms) -> Optional[Tensor]: + """Get fixed atoms from ase.Atoms object.""" + fixed_atoms = None + if atoms.constraints is not None and atoms.constraints: + constraint = atoms.constraints[0] + if isinstance(constraint, constraints.FixAtoms): + fixed_atoms = mint.zeros((len(atoms)), dtype=ms.bool_) + fixed_atoms[constraint.index] = True + return fixed_atoms diff --git a/MindChem/applications/orb/src/base.py b/MindChem/applications/orb/src/base.py new file mode 100644 index 0000000000000000000000000000000000000000..046e5950f209f46aa72b2933468211a74d9ae67d --- /dev/null +++ b/MindChem/applications/orb/src/base.py @@ -0,0 +1,486 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Base Model class.""" + +from collections import defaultdict +from copy import deepcopy +from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, Union +import tree + +import mindspore as ms +from mindspore import ops, Tensor, mint + +from src import featurization_utilities + +Metric = Union[Tensor, int, float] +TensorDict = Mapping[str, Optional[Tensor]] + + +class ModelOutput(NamedTuple): + """A model's output.""" + + loss: Tensor + log: Mapping[str, Metric] + + +class AtomGraphs(NamedTuple): + """A class representing the input to a model for a graph. + + Args: + senders (torch.Tensor): The integer source nodes for each edge. + receivers (torch.Tensor): The integer destination nodes for each edge. + n_node (torch.Tensor): A (batch_size, ) shaped tensor containing the number of nodes per graph. + n_edge (torch.Tensor): A (batch_size, ) shaped tensor containing the number of edges per graph. + node_features (Dict[str, torch.Tensor]): A dictionary containing node feature tensors. + It will always contain "atomic_numbers" and "positions" keys, representing the + atomic numbers of each node, and the 3d cartesian positions of them respectively. + edge_features (Dict[str, torch.Tensor]): A dictionary containing edge feature tensors. + system_features (Optional[TensorDict]): An optional dictionary containing system-level features. 
+        node_targets (Optional[Dict[torch.Tensor]]): An optional dict of tensors containing targets
+            for individual nodes. This tensor is commonly expected to have shape (num_nodes, *).
+        edge_targets (Optional[torch.Tensor]): An optional tensor containing targets for individual edges.
+            This tensor is commonly expected to have shape (num_edges, *).
+        system_targets (Optional[Dict[torch.Tensor]]): An optional dict of tensors containing targets
+            for the entire system.
+        system_id (Optional[torch.Tensor]): An optional tensor containing the ID of the system.
+        fix_atoms (Optional[torch.Tensor]): An optional tensor containing information on fixed atoms in the system.
+    """
+
+    senders: Tensor
+    receivers: Tensor
+    n_node: Tensor
+    n_edge: Tensor
+    node_features: Dict[str, Tensor]
+    edge_features: Dict[str, Tensor]
+    system_features: Dict[str, Tensor]
+    node_targets: Optional[Dict[str, Tensor]] = None
+    edge_targets: Optional[Dict[str, Tensor]] = None
+    system_targets: Optional[Dict[str, Tensor]] = None
+    system_id: Optional[Tensor] = None
+    fix_atoms: Optional[Tensor] = None
+    tags: Optional[Tensor] = None
+    radius: Optional[float] = None
+    max_num_neighbors: Optional[int] = None
+
+    @property
+    def positions(self):
+        """Get positions of atoms."""
+        return self.node_features["positions"]
+
+    @positions.setter
+    def positions(self, val: Tensor):
+        self.node_features["positions"] = val
+
+    @property
+    def atomic_numbers(self):
+        """Get integer atomic numbers."""
+        return self.node_features["atomic_numbers"]
+
+    @atomic_numbers.setter
+    def atomic_numbers(self, val: Tensor):
+        self.node_features["atomic_numbers"] = val
+
+    @property
+    def cell(self):
+        """Get unit cells."""
+        assert self.system_features
+        return self.system_features.get("cell")
+
+    @cell.setter
+    def cell(self, val: Tensor):
+        assert self.system_features
+        self.system_features["cell"] = val
+
+    def clone(self) -> "AtomGraphs":
+        """Clone the AtomGraphs object.
+
+        Note: this differs from deepcopy() because it preserves gradients.
+        """
+
+        def _clone(x):
+            if isinstance(x, Tensor):
+                return x.clone()
+            return x
+
+        return tree.map_structure(_clone, self)
+
+    def to(self, device: Any = None) -> "AtomGraphs":
+        """Move AtomGraphs child tensors to a device.
+
+        MindSpore manages device placement through the global context, so this
+        method is a compatibility no-op that returns the graphs unchanged.
+        """
+        del device
+        return self
+
+    def detach(self) -> "AtomGraphs":
+        """Detach all child tensors."""
+
+        def _detach(x):
+            if hasattr(x, "detach"):
+                return x.detach()
+            return x
+
+        return tree.map_structure(_detach, self)
+
+    def equals(self, graphs: "AtomGraphs") -> bool:
+        """Check two atomgraphs are equal."""
+
+        def _is_equal(x, y):
+            if isinstance(x, Tensor):
+                return mint.equal(x, y)
+            return x == y
+
+        flat_results = tree.flatten(tree.map_structure(_is_equal, self, graphs))
+        return all(flat_results)
+
+    def allclose(self, graphs: "AtomGraphs", rtol=1e-5, atol=1e-8) -> bool:
+        """Check all tensors/scalars of two atomgraphs are close."""
+
+        def _is_close(x, y):
+            if isinstance(x, Tensor):
+                return mint.allclose(x, y, rtol=rtol, atol=atol)
+            if isinstance(x, (float, int)):
+                return mint.allclose(
+                    Tensor(x), Tensor(y), rtol=rtol, atol=atol
+                )
+            return x == y
+
+        flat_results = tree.flatten(tree.map_structure(_is_close, self, graphs))
+        return all(flat_results)
+
+    def to_dict(self):
+        """Return a dictionary mapping each AtomGraph property to a corresponding tensor/scalar.
+ + Any nested attributes of the AtomGraphs are unpacked so the + returned dict has keys like "positions" and "atomic_numbers". + + Any None attributes are not included in the dictionary. + + Returns: + dict: A dictionary mapping attribute_name -> tensor/scalar + """ + ret = {} + for key, val in self._asdict().items(): + if val is None: + continue + if isinstance(val, dict): + for k, v in val.items(): + ret[k] = v + else: + ret[key] = val + + return ret + + def to_batch_dict(self) -> Dict[str, Any]: + """Return a single dictionary mapping each AtomGraph property to a corresponding list of tensors/scalars. + + Returns: + dict: A dict mapping attribute_name -> list of length batch_size containing tensors/scalars. + """ + batch_dict = defaultdict(list) + for graph in self.split(self): + for key, value in graph.to_dict().items(): + batch_dict[key].append(value) + return batch_dict + + def split(self, clone=True) -> List["AtomGraphs"]: + """Splits batched AtomGraphs into constituent system AtomGraphs. + + Args: + graphs (AtomGraphs): A batched AtomGraphs object. + clone (bool): Whether to clone the graphs before splitting. + Cloning removes risk of side-effects, but uses more memory. + """ + graphs = self.clone() if clone else self + + batch_nodes = graphs.n_node.tolist() + batch_edges = graphs.n_edge.tolist() + + if not batch_nodes: + raise ValueError("Cannot split empty batch") + if len(batch_nodes) == 1: + return [graphs] + + batch_systems = mint.ones(len(batch_nodes), dtype=ms.int32).tolist() + node_features = _split_features(graphs.node_features, batch_nodes) + node_targets = _split_features(graphs.node_targets, batch_nodes) + edge_features = _split_features(graphs.edge_features, batch_edges) + edge_targets = _split_features(graphs.edge_targets, batch_edges) + system_features = _split_features(graphs.system_features, batch_systems) + system_targets = _split_features(graphs.system_targets, batch_systems) + system_ids = _split_tensors(graphs.system_id, batch_systems) + fix_atoms = _split_tensors(graphs.fix_atoms, batch_nodes) + tags = _split_tensors(graphs.tags, batch_nodes) + batch_nodes = [Tensor([n]) for n in batch_nodes] + batch_edges = [Tensor([e]) for e in batch_edges] + + # calculate the new senders and receivers + senders = list(_split_tensors(graphs.senders, batch_edges)) + receivers = list(_split_tensors(graphs.receivers, batch_edges)) + n_graphs = graphs.n_node.shape[0] + offsets = mint.cumsum(graphs.n_node[:-1], 0) + offsets = mint.cat([Tensor([0]), offsets]) + unbatched_senders = [] + unbatched_recievers = [] + for graph_index in range(n_graphs): + s = senders[graph_index] - offsets[graph_index] + r = receivers[graph_index] - offsets[graph_index] + unbatched_senders.append(s) + unbatched_recievers.append(r) + + return [ + AtomGraphs(*args) + for args in zip( + unbatched_senders, + unbatched_recievers, + batch_nodes, + batch_edges, + node_features, + edge_features, + system_features, + node_targets, + edge_targets, + system_targets, + system_ids, + fix_atoms, + tags, + [graphs.radius for _ in range(len(batch_nodes))], + [graphs.max_num_neighbors for _ in range(len(batch_nodes))], + ) + ] + + +def batch_graphs(graphs: List[AtomGraphs]) -> AtomGraphs: + """Batch graphs together by concatenating their nodes, edges, and features. + + Args: + graphs (List[AtomGraphs]): A list of AtomGraphs to be batched together. 
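+
+    Note:
+        All graphs in the batch must share the same ``radius`` and
+        ``max_num_neighbors``; sender and receiver indices are shifted by the
+        cumulative node counts so they remain valid after concatenation.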
+ + Returns: + AtomGraphs: A new AtomGraphs object with the concatenated nodes, + edges, and features from the input graphs, along with concatenated target, + system ID, and other information. + """ + # Calculates offsets for sender and receiver arrays, caused by concatenating + # the nodes arrays. + offsets = mint.cumsum( + Tensor([0] + [mint.sum(g.n_node) for g in graphs[:-1]]), 0 + ) + radius = graphs[0].radius + assert {graph.radius for graph in graphs} == {radius} + max_num_neighbours = graphs[0].max_num_neighbors + assert {graph.max_num_neighbors for graph in graphs} == {max_num_neighbours} + + return AtomGraphs( + n_node=mint.concat([g.n_node for g in graphs], dim=0).to(ms.int64), + n_edge=mint.concat([g.n_edge for g in graphs], dim=0).to(ms.int64), + senders=mint.concat( + [g.senders + o for g, o in zip(graphs, offsets)], dim=0 + ).to(ms.int64), + receivers=mint.concat( + [g.receivers + o for g, o in zip(graphs, offsets)], dim=0 + ).to(ms.int64), + node_features=_map_concat([g.node_features for g in graphs]), + edge_features=_map_concat([g.edge_features for g in graphs]), + system_features=_map_concat([g.system_features for g in graphs]), + node_targets=_map_concat([g.node_targets for g in graphs]), + edge_targets=_map_concat([g.edge_targets for g in graphs]), + system_targets=_map_concat([g.system_targets for g in graphs]), + system_id=_concat([g.system_id for g in graphs]), + fix_atoms=_concat([g.fix_atoms for g in graphs]), + tags=_concat([g.tags for g in graphs]), + radius=radius, + max_num_neighbors=max_num_neighbours, + ) + + +def refeaturize_atomgraphs( + atoms: AtomGraphs, + positions: Tensor, + atomic_number_embeddings: Optional[Tensor] = None, + cell: Optional[Tensor] = None, + recompute_neighbors=True, + updates: Optional[Tensor] = None, + fixed_atom_pos: Optional[Tensor] = None, + fixed_atom_type_embedding: Optional[Tensor] = None, + differentiable: bool = False, +) -> AtomGraphs: + """Return a graph updated according to the new positions, and (if given) atomic numbers and unit cells. + + Note: if a unit cell is given, it will *both* be used to do the + pbc-remapping and be set on the returned AtomGraphs + + Args: + atoms (AtomGraphs): The original AtomGraphs object. + positions (torch.Tensor): The new positions of the atoms. + atomic_number_embeddings (Optional[torch.Tensor]): The new atomic number embeddings. + cell (Optional[torch.Tensor]): The new unit cell. + recompute_neighbors (bool): Whether to recompute the neighbor list. + updates (Optional[torch.Tensor]): The updates to the positions. + fixed_atom_pos (Optional[torch.Tensor]): The positions of atoms + which are fixed when diffusing on a fixed trajectory. + fixed_atom_type_embedding (Optional[torch.Tensor]): If using atom type diffusion + with a fixed trajectory, the unormalized vectors of the fixed atoms. Shape (n_atoms, 118). + differentiable (bool): Whether to make the graph inputs require_grad. This includes + the positions and atomic number embeddings, if passed. + exact_pbc_image_neighborhood: bool: If the exact pbc image neighborhood calculation (from torch nl) + which considers boundary crossing for more than cell is used. + + Returns: + AtomGraphs: A refeaturized AtomGraphs object. 
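+
+    Note:
+        Positions are first wrapped back into the unit cell with
+        ``batch_map_to_pbc_cell``. The neighbour list is then rebuilt from scratch,
+        or, when ``recompute_neighbors`` is False, the existing edge vectors are
+        shifted by the per-node ``updates`` instead.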
+ """ + if cell is None: + cell = atoms.cell + + if atoms.fix_atoms is not None and fixed_atom_pos is not None: + positions[atoms.fix_atoms] = fixed_atom_pos[atoms.fix_atoms] + + if ( + atoms.fix_atoms is not None + and fixed_atom_type_embedding is not None + and atomic_number_embeddings is not None + ): + atomic_number_embeddings[atoms.fix_atoms] = fixed_atom_type_embedding[ + atoms.fix_atoms + ] + + num_atoms = atoms.n_node + positions = featurization_utilities.batch_map_to_pbc_cell( + positions, cell, num_atoms + ) + + if differentiable: + positions.requires_grad = True + if atomic_number_embeddings is not None: + atomic_number_embeddings.requires_grad = True + + if recompute_neighbors: + assert atoms.radius is not None and atoms.max_num_neighbors is not None + ( + edge_index, + edge_vectors, + batch_num_edges, + ) = featurization_utilities.batch_compute_pbc_radius_graph( + positions=positions, + periodic_boundaries=cell, + radius=atoms.radius, + image_idx=num_atoms, + max_number_neighbors=atoms.max_num_neighbors, + ) + new_senders = edge_index[0] + new_receivers = edge_index[1] + else: + assert updates is not None + new_senders = atoms.senders + new_receivers = atoms.receivers + edge_vectors = recompute_edge_vectors(atoms, updates) + batch_num_edges = atoms.n_edge + + edge_features = { + "vectors": edge_vectors.to(ms.float32), + } + + new_node_features = {} + if atoms.node_features is not None: + new_node_features = deepcopy(atoms.node_features) + new_node_features["positions"] = positions + if atomic_number_embeddings is not None: + new_node_features["atomic_numbers_embedding"] = atomic_number_embeddings + + new_system_features = {} + if atoms.system_features is not None: + new_system_features = deepcopy(atoms.system_features) + new_system_features["cell"] = cell + + new_atoms = AtomGraphs( + senders=new_senders, + receivers=new_receivers, + n_node=atoms.n_node, + n_edge=batch_num_edges, + node_features=new_node_features, + edge_features=edge_features, + system_features=new_system_features, + node_targets=atoms.node_targets, + system_targets=atoms.system_targets, + fix_atoms=atoms.fix_atoms, + tags=atoms.tags, + radius=atoms.radius, + max_num_neighbors=atoms.max_num_neighbors, + ) + + return new_atoms + + +def recompute_edge_vectors(atoms, updates): + """Recomputes edge vectors with per node updates.""" + updates = -updates + senders = atoms.senders + receivers = atoms.receivers + edge_translation = updates[senders] - updates[receivers] + return atoms.edge_features["vectors"] + edge_translation + + +def volume_atomgraphs(atoms: AtomGraphs): + """Returns the volume of the unit cell.""" + cell = atoms.cell + cross = ops.Cross(dim=1) + return (cell[:, 0] * cross(cell[:, 1], cell[:, 2])).sum(-1) + + +def _map_concat(nests): + concat = lambda *args: _concat(args) + return tree.map_structure(concat, *nests) + + +def _concat( + tensors: List[Optional[Tensor]], +) -> Optional[Tensor]: + """Splits tensors based on the intended split sizes.""" + if any([x is None for x in tensors]): + return None + return mint.concat(tensors, dim=0) + + +def _split_tensors( + features: Optional[Tensor], + split_sizes: List[int], +) -> Sequence[Optional[Tensor]]: + """Splits tensors based on the intended split sizes.""" + if features is None: + return [None] * len(split_sizes) + + return mint.split(features, split_sizes) + + +def _split_features( + features: Optional[TensorDict], + split_sizes: List[int], +) -> Sequence[Optional[TensorDict]]: + """Splits features based on the intended split sizes.""" + if 
features is None: + return [None] * len(split_sizes) + + split_dict = { + k: mint.split(v, split_sizes) if v is not None else [None] * len(split_sizes) + for k, v in features.items() + } + individual_tuples = zip(*[v for v in split_dict.values()]) + individual_dicts: List[Optional[TensorDict]] = list( + map(lambda k: dict(zip(split_dict.keys(), k)), individual_tuples) + ) + return individual_dicts diff --git a/MindChem/applications/orb/src/featurization_utilities.py b/MindChem/applications/orb/src/featurization_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..dd68d4ad8d4322bb334d9c70109ae7ce029b9050 --- /dev/null +++ b/MindChem/applications/orb/src/featurization_utilities.py @@ -0,0 +1,438 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Featurization utilities for molecular models.""" + +from typing import Callable, Optional, Tuple, Union +from pynanoflann import KDTree as NanoKDTree +from scipy.spatial import KDTree as SciKDTree + +import numpy as np +import mindspore as ms +from mindspore import ops, Tensor, mint + +DistanceFeaturizer = Callable[[Tensor], Tensor] + + + +def gaussian_basis_function( + scalars: Tensor, + num_bases: Union[Tensor, int], + radius: Union[Tensor, float], + scale: Union[Tensor, float] = 1.0, +) -> Tensor: + """Gaussian basis function applied to a tensor of scalars. + + Args: + scalars (Tensor): Scalars to compute the gbf on. Shape [num_scalars]. + num_bases (Tensor): The number of bases. An Int. + radius (Tensor): The largest centre of the bases. A Float. + scale (Tensor, optional): The width of the gaussians. Defaults to 1. + + Returns: + Tensor: A tensor of shape [num_scalars, num_bases]. + """ + assert len(scalars.shape) == 1 + gaussian_means = ops.arange( + 0, float(radius), float(radius) / num_bases + ) + return mint.exp( + -(scale**2) * (scalars.unsqueeze(1) - gaussian_means.unsqueeze(0)).abs() ** 2 + ) + + +def featurize_edges( + edge_vectors: Tensor, distance_featurization: DistanceFeaturizer +) -> Tensor: + """Featurizes edge features, provides concatenated unit vector along with featurized distances. + + Args: + edge_vectors (tensor): Edge vectors to featurize. Shape [num_edge, 3] + distance_featurization (DistanceFeaturization): A function that featurizes the distances of the vectors. + + Returns: + tensor: Edge features, shape [num_edge, num_edge_features]. 
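+
+    Example (sketch; 10 random edge vectors featurized with a 20-basis Gaussian
+        featurizer chosen here only for illustration, giving 20 + 3 = 23 features per edge):
+        >>> from functools import partial
+        >>> featurizer = partial(gaussian_basis_function, num_bases=20, radius=10.0)
+        >>> feats = featurize_edges(ops.randn((10, 3)), featurizer)  # shape (10, 23)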
+ """ + edge_features = [] + edge_norms = mint.linalg.norm(edge_vectors, dim=1) + featurized_edge_norms = distance_featurization(edge_norms) + unit_vectors = edge_vectors / edge_norms.unsqueeze(1) + unit_vectors = mint.nan_to_num(unit_vectors, nan=0, posinf=0, neginf=0) + edge_features.append(featurized_edge_norms) + edge_features.append(unit_vectors) + return mint.cat(edge_features, dim=-1).to(ms.float32) + + +def compute_edge_vectors( + edge_index: Tensor, positions: Tensor +) -> Tensor: + """Computes edge vectors from positions. + + Args: + edge_index (tensor): The edge index. First position the senders, second + position the receivers. Shape [2, num_edge]. + positions (tensor): Positions of each node. Shape [num_nodes, 3] + + Returns: + tensor: The vectors of each edge. + """ + senders = edge_index[0] + receivers = edge_index[1] + return positions[receivers] - positions[senders] + + +# These are offsets applied to coordinates to create a 3x3x3 +# tiled periodic image of the input structure. +OFFSETS = np.array( + [ + [-1.0, 1.0, -1.0], + [0.0, 1.0, -1.0], + [1.0, 1.0, -1.0], + [-1.0, 0.0, -1.0], + [0.0, 0.0, -1.0], + [1.0, 0.0, -1.0], + [-1.0, -1.0, -1.0], + [0.0, -1.0, -1.0], + [1.0, -1.0, -1.0], + [-1.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 1.0, 0.0], + [-1.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [-1.0, -1.0, 0.0], + [0.0, -1.0, 0.0], + [1.0, -1.0, 0.0], + [-1.0, 1.0, 1.0], + [0.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [-1.0, 0.0, 1.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 1.0], + [-1.0, -1.0, 1.0], + [0.0, -1.0, 1.0], + [1.0, -1.0, 1.0], + ] +) + +NUM_OFFSETS = len(OFFSETS) + + +def _compute_img_positions_torch( + positions: Tensor, periodic_boundaries: Tensor +) -> Tensor: + """Computes the positions of the periodic images of the input structure. + + Consider the following 2D periodic boundary image. + + --- + --- + --- + + | | | | + + --- + --- + --- + + | | x | | + + --- + --- + --- + + | | | | + + --- + --- + --- + + + Each tile in this has an associated translation to translate + 'x'. For example, the top left would by (-1, +1). These are + the 'OFFSETS', but OFFSETS are for a 3x3x3 grid. + + This is complicated by the fact that our periodic + boundaries are not orthogonal to each other, and so we form a new + translation by taking a linear combination of the unit cell axes. + + Args: + positions (Tensor): Positions of the atoms. Shape [num_atoms, 3]. + periodic_boundaries (Tensor): Periodic boundaries of the unit cell. + This can be 2 shapes - [3, 3] or [num_atoms, 3, 3]. If the shape is + [num_atoms, 3, 3], it is assumed that the PBC has been repeat_interleaved + for each atom, i.e this function is agnostic as to whether it is computing + with respect to a batch or not. + Returns: + Tensor: The positions of the periodic images. Shape [num_atoms, 27, 3]. 
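+
+    Example (sketch; two atoms in a hypothetical 10 Angstrom cubic cell):
+        >>> cell = mint.eye(3) * 10.0
+        >>> pos = Tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
+        >>> images = _compute_img_positions_torch(pos, cell)  # shape (2, 27, 3)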
+ """ + num_positions = len(positions) + + has_unbatched_pbc = periodic_boundaries.shape == (3, 3) + if has_unbatched_pbc: + periodic_boundaries = periodic_boundaries.unsqueeze(0) + periodic_boundaries = periodic_boundaries.expand((num_positions, 3, 3)) + + assert periodic_boundaries.shape[0] == positions.shape[0] + offsets = Tensor(OFFSETS, dtype=positions.dtype) + offsets = mint.unsqueeze(offsets, 0) + repeated_offsets = offsets.expand((num_positions, NUM_OFFSETS, 3)) + repeated_offsets = mint.unsqueeze(repeated_offsets, 3) + periodic_boundaries = mint.unsqueeze(periodic_boundaries, 1) + translations = repeated_offsets * periodic_boundaries + translations = translations.sum(2) + + # Expand the positions so we can broadcast add the translations per PBC image. + expanded_positions = positions.unsqueeze(1) + translated_positions = expanded_positions + translations + return translated_positions + + +def brute_force_knn( + img_positions: Tensor, positions: Tensor, k: int +) -> Tuple[Tensor, Tensor]: + """Brute force k-nearest neighbors. + + Args: + img_positions (Tensor): The positions of the images. Shape [num_atoms * 27, 3]. + positions (Tensor): The positions of the query atoms. Shape [num_atoms, 3]. + k (int): The number of nearest neighbors to find. + + Returns: + return_types.topk: The indices of the nearest neighbors. Shape [num_atoms, k]. + """ + dist = mint.cdist(positions, img_positions) + return mint.topk(dist, k, largest=False, sorted=True) + + +def compute_pbc_radius_graph( + *, + positions: Tensor, + periodic_boundaries: Tensor, + radius: Union[float, Tensor], + max_number_neighbors: int = 20, + brute_force: Optional[bool] = None, + library: str = "pynanoflann", + n_workers: int = 1, +) -> Tuple[Tensor, Tensor]: + """Computes periodic condition radius graph from positions. + + Args: + positions (Tensor): 3D positions of particles. Shape [num_particles, 3]. + periodic_boundaries (Tensor): A 3x3 matrix where the periodic boundary axes are rows or columns. + radius (Union[float, tensor]): The radius within which to connect atoms. + max_number_neighbors (int, optional): The maximum number of neighbors for each particle. Defaults to 20. + brute_force (bool, optional): Whether to use brute force knn. Defaults to None, in which case brute_force + is used if GPU is available (2-6x faster), but not on CPU (1.5x faster - 4x slower, depending on + system size). + library (str, optional): The KDTree library to use. Currently, either 'scipy' or 'pynanoflann'. + n_workers (int, optional): The number of workers to use for KDTree construction. Defaults to 1. + + Returns: + Tuple[Tensor, Tensor]: A 2-Tuple. First, an edge_index tensor, where the first index are the + sender indices and the second are the receiver indices. Second, the vector displacements between edges. + """ + if brute_force is None: + brute_force = ms.get_context("device_target") == "GPU" + + if mint.any(periodic_boundaries != 0.0): + supercell_positions = _compute_img_positions_torch( + positions=positions, periodic_boundaries=periodic_boundaries + ) + # CRITICALLY IMPORTANT: We need to reshape the supercell_positions to be + # flat, so we can use them for the nearest neighbors. The *way* in which + # they are flattened is important, because we need to be able to map the + # indices returned from the nearest neighbors to the original positions. 
+ # The easiest way to do this is to transpose, so that when we flatten, we + # have: + # [ + # img_0_atom_0, + # img_0_atom_1, + # ..., + # img_0_atom_N, + # img_1_atom_0, + # img_1_atom_1, + # ..., + # img_N_atom_N, + # etc + # ] + # This way, we can take the mod of the indices returned from the nearest + # neighbors to get the original indices. + # Shape (27, num_positions, 3) + supercell_positions = supercell_positions.transpose(0, 1) + supercell_positions = supercell_positions.reshape(-1, 3) + else: + supercell_positions = positions + + num_positions = positions.shape[0] + + if brute_force: + # Brute force + distance_values, nearest_img_neighbors = brute_force_knn( + supercell_positions, + positions, + min(max_number_neighbors + 1, len(supercell_positions)), + ) + + # remove distances greater than radius, and exclude self + within_radius = distance_values[:, 1:] < (radius + 1e-6) + + num_neighbors_per_position = within_radius.sum(-1) + # remove the self node which will be closest + index_array = nearest_img_neighbors[:, 1:] + + senders = mint.repeat_interleave( + mint.arange(num_positions), num_neighbors_per_position + ) + receivers_imgs = index_array[within_radius] + + receivers = receivers_imgs % num_positions + vectors = supercell_positions[receivers_imgs] - positions[senders] + stacked = mint.stack((senders, receivers), dim=0) + return stacked, vectors + + # Build a KDTree from the supercell positions. + # Query that KDTree just for the positions in the central cell. + tree_data = supercell_positions.clone().numpy() + tree_query = positions.clone().numpy() + distance_upper_bound = np.array(radius) + 1e-8 + if library == "scipy": + tree = SciKDTree(tree_data, leafsize=100) + _, nearest_img_neighbors = tree.query( + tree_query, + max_number_neighbors + 1, + distance_upper_bound=distance_upper_bound, + workers=n_workers, + p=2, + ) + # Remove the self-edge that will be closest + index_array = np.array(nearest_img_neighbors)[:, 1:] + # Remove any entry that equals len(supercell_positions), which are negative hits + receivers_imgs = index_array[index_array != len(supercell_positions)] + num_neighbors_per_position = (index_array != len(supercell_positions)).sum( + -1 + ) + elif library == "pynanoflann": + tree = NanoKDTree( + n_neighbors=min(max_number_neighbors + 1, len(supercell_positions)), + radius=radius, + leaf_size=100, + metric="l2", + ) + tree.fit(tree_data) + distance_values, nearest_img_neighbors = tree.kneighbors( + tree_query, n_jobs=n_workers + ) + nearest_img_neighbors = nearest_img_neighbors.astype(np.int32) + + # remove the self node which will be closest + index_array = nearest_img_neighbors[:, 1:] + # remove distances greater than radius + within_radius = distance_values[:, 1:] < (radius + 1e-6) + receivers_imgs = index_array[within_radius] + num_neighbors_per_position = within_radius.sum(-1) + + # We construct our senders and receiver indexes. + senders = np.repeat(np.arange(num_positions), list(num_neighbors_per_position)) + receivers_img_torch = Tensor(receivers_imgs, ms.int32) + # Map back to indexes on the central image. + receivers = receivers_img_torch % num_positions + senders_torch = Tensor(senders, ms.int32) + + # Finally compute the vector displacements between senders and receivers. 
+ vectors = supercell_positions[receivers_img_torch] - positions[senders_torch] + return mint.stack((senders_torch, receivers), dim=0), vectors + + +def batch_map_to_pbc_cell( + positions: Tensor, + periodic_boundary_conditions: Tensor, + num_atoms: Tensor, +) -> Tensor: + """Maps positions to within a periodic boundary cell, for a batched system. + + Args: + positions (Tensor): The positions to be mapped. Shape [num_particles, 3] + periodic_boundary_conditions (Tensor): The periodic boundary conditions. Shape [num_batches, 3, 3] + num_atoms (LongTensor): The number of atoms in each batch. Shape [num_batches] + """ + dtype = positions.dtype + positions = positions.double() + periodic_boundary_conditions = periodic_boundary_conditions.double() + + pbc_nodes = mint.repeat_interleave(periodic_boundary_conditions, num_atoms, dim=0) + + # To use the stable linalg.solve, we need to mask batch elements which don't + # have periodic boundaries. We do this by adding the identity matrix as their PBC, + # because we need the PBCs to be non-singular. + null_pbc = pbc_nodes.abs().sum(dim=[1, 2]) == 0 + identity = mint.eye(3, dtype=ms.bool_) + # Broadcast the identity to the elements of the batch that have a null pbc. + null_pbc_identity_mask = null_pbc.view(-1, 1, 1) & identity.view(1, 3, 3) + pbc_nodes_masked = pbc_nodes + null_pbc_identity_mask.double() + + lattice_coords = ops.matrix_solve(pbc_nodes_masked.transpose(1, 2), positions) + frac_coords = lattice_coords % 1.0 + + cartesian = mint.einsum("bi,bij->bj", frac_coords, pbc_nodes) + return mint.where(null_pbc.unsqueeze(1), positions, cartesian).to(dtype) + + +def batch_compute_pbc_radius_graph( + *, + positions: Tensor, + periodic_boundaries: Tensor, + radius: Union[float, Tensor], + image_idx: Tensor, + max_number_neighbors: int = 20, + brute_force: Optional[bool] = None, + library: str = "scipy", +): + """Computes batched periodic boundary condition radius graph from positions. + + This function is optimised for computation on CPU, and work work on device. GPU implementations + are likely to be significantly slower because of the irregularly sized tensor computations and the + lack of extremely fast GPU knn routines. + + Args: + positions (Tensor): 3D positions of a batch of particles. Shape [num_particles, 3]. + periodic_boundaries (Tensor): A batch where each element 3x3 matrix where the periodic boundary axes + are rows or columns. + radius (Union[float, tensor]): The radius within which to connect atoms. + image_idx (Tensor): A vector where each element indicates the number of particles in each element of + the batch. Of size len(batch). + max_number_neighbors (int, optional): The maximum number of neighbors for each particle. Defaults to 20. + brute_force (bool, optional): Whether to use brute force knn. Defaults to None, in which case brute_force + is used if we are on GPU (2-6x faster), but not on CPU (1.5x faster - 4x slower). + library (str, optional): The KDTree library to use. Currently, either 'scipy' or 'pynanoflann'. + + Returns: + Tuple[Tensor, Tensor]: A 2-Tuple. First, an edge_index tensor, where the first index are the + sender indices and the second are the receiver indices. Second, the vector displacements between edges. 
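+
+    Example (sketch; two single-atom systems, each in a hypothetical 10 Angstrom cubic cell):
+        >>> positions = Tensor([[0.0, 0.0, 0.0], [5.0, 5.0, 5.0]])
+        >>> cells = mint.eye(3).unsqueeze(0).tile((2, 1, 1)) * 10.0
+        >>> edges, vectors, n_edge = batch_compute_pbc_radius_graph(
+        ...     positions=positions, periodic_boundaries=cells,
+        ...     radius=6.0, image_idx=Tensor([1, 1], ms.int64))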
+ """ + idx = 0 + all_edges = [] + all_vectors = [] + num_edges = [] + + for p, pbc in zip( + ops.tensor_split(positions, mint.cumsum(image_idx, 0)[:-1]), + periodic_boundaries, + ): + edges, vectors = compute_pbc_radius_graph( + positions=p, + periodic_boundaries=pbc, + radius=radius, + max_number_neighbors=max_number_neighbors, + brute_force=brute_force, + library=library, + ) + if idx == 0: + offset = 0 + else: + offset += image_idx[idx - 1] + all_edges.append(edges + offset) + all_vectors.append(vectors) + num_edges.append(len(edges[0])) + idx += 1 + + all_edges = ms.numpy.concatenate(all_edges, 1) + all_vectors = ms.numpy.concatenate(all_vectors, 0) + num_edges = Tensor(num_edges, dtype=ms.int64) + return all_edges, all_vectors, num_edges diff --git a/MindChem/applications/orb/src/pretrained.py b/MindChem/applications/orb/src/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..cb66db19fe55e2694a492e836c78d0a175bca67d --- /dev/null +++ b/MindChem/applications/orb/src/pretrained.py @@ -0,0 +1,116 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""pretrained.""" + +import os +from typing import Optional + +from mindspore import nn, load_checkpoint, load_param_into_net + +from mindchemistry.cell import ( + EnergyHead, + GraphHead, + Orb, + NodeHead, + MoleculeGNS, +) + + +def get_gns( + latent_dim: int = 256, + mlp_hidden_dim: int = 512, + num_message_passing_steps: int = 15, + num_edge_in_features: int = 23, + distance_cutoff: bool = True, + attention_gate: str = "sigmoid", +) -> MoleculeGNS: + """Define the base pretrained model architecture.""" + return MoleculeGNS( + num_node_in_features=256, + num_node_out_features=3, + num_edge_in_features=num_edge_in_features, + latent_dim=latent_dim, + interactions="simple_attention", + interaction_params={ + "distance_cutoff": distance_cutoff, + "polynomial_order": 4, + "cutoff_rmax": 6, + "attention_gate": attention_gate, + }, + num_message_passing_steps=num_message_passing_steps, + num_mlp_layers=2, + mlp_hidden_dim=mlp_hidden_dim, + use_embedding=True, + node_feature_names=["feat"], + edge_feature_names=["feat"], + ) + + +def load_model_for_inference(model: nn.Cell, weights_path: str) -> nn.Cell: + """ + Load a pretrained model in inference mode, using GPU if available. 
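+
+    Example (sketch; the checkpoint path below is a placeholder):
+        >>> model = load_model_for_inference(get_gns(), "/path/to/orb-v2.ckpt")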
+ """ + if not os.path.exists(weights_path): + raise FileNotFoundError(f"Checkpoint file {weights_path} not found.") + param_dict = load_checkpoint(weights_path) + load_param_into_net(model, param_dict) + model.set_train(False) + + return model + +def orb_v2( + weights_path: Optional[str] = None, +): + """Load ORB v2.""" + gns = get_gns() + + model = Orb( + graph_head=EnergyHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=1, + node_aggregation="mean", + reference_energy_name="vasp-shifted", + train_reference=True, + predict_atom_avg=True, + ), + node_head=NodeHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=3, + remove_mean=True, + ), + stress_head=GraphHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=6, + compute_stress=True, + ), + model=gns, + ) + model = load_model_for_inference(model, weights_path) + return model + + +def orb_mptraj_only_v2( + weights_path: Optional[str] = None, +): + """Load ORB MPTraj Only v2.""" + + return orb_v2(weights_path,) diff --git a/MindChem/applications/orb/src/property_definitions.py b/MindChem/applications/orb/src/property_definitions.py new file mode 100644 index 0000000000000000000000000000000000000000..3951c06c0e0ca8b540bc8ad4c4c69649aa7affd6 --- /dev/null +++ b/MindChem/applications/orb/src/property_definitions.py @@ -0,0 +1,239 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Classes that define prediction targets.""" + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import ase.data +import ase.db +import ase.db.row +import ase.db.sqlite +import numpy as np + +import mindspore as ms +from mindspore import ops, Tensor, mint + +HARTREE_TO_EV = 27.211386245988 + + +def recursive_getattr(obj: object, attr: str) -> Any: + """Recursively access an object property using dot notation.""" + for sub_attr in attr.split("."): + obj = getattr(obj, sub_attr) + + return obj + + +def get_property_from_row( + name: Union[str, List[str]], + row: ase.db.row.AtomsRow, + conversion_factor: float = 1.0, +) -> Tensor: + """Retrieve arbitrary values from ase db data dict.""" + if isinstance(name, str): + names = [name] + else: + names = name + values = [] + for name_ in names: + attribute = recursive_getattr(row, name_) + target = np.array(attribute) + values.append(target) + + property_tensor = ms.from_numpy(np.hstack(values)).to(ms.float32) + + while len(property_tensor.shape) < 2: + property_tensor = property_tensor[None, ...] 
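+    # At this point the property is at least 2-D; a raw (3, 3) stress tensor is
+    # additionally reduced to 6-component Voigt form below.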
+ + if "stress" in name and property_tensor.shape == (3, 3): + # convert stress tensor to voigt notation + property_tensor = Tensor( + [ + property_tensor[0, 0], + property_tensor[1, 1], + property_tensor[2, 2], + property_tensor[1, 2], + property_tensor[0, 2], + property_tensor[0, 1], + ], + dtype=ms.float32, + ).unsqueeze(0) + return property_tensor * conversion_factor + + +@dataclass +class PropertyDefinition: + """Defines how to extract and transform a quantative property from an ase db. + + Such properties have two primary use-cases: + - as features for the model to use / condition on. + - as target variables for regression tasks. + + Args: + name: The name of the property. + dim: The dimensionality of the property variable. + domain: Whether the variable is real, binary or categorical. If using + this variable as a regression target, then var_type determines + the loss function used e.g. MSE, BCE or cross-entropy loss. + row_to_property_fn: A function defining how a target can be + retrieved from an ase database row. + means: The mean to transform this by in the model. + stds: The std to scale this by in the model. + """ + + name: str + dim: int + domain: Literal["real", "binary", "categorical"] + row_to_property_fn: Optional[Callable] = None + means: Optional[Tensor] = None + stds: Optional[Tensor] = None + + +def energy_row_fn(row: ase.db.row.AtomsRow, dataset: str) -> float: + """Energy data in eV. + + - Some datasets use sums of energy values e.g. PBE + D3. + - For external datasets, we should explicitly register how + to extract the energy property by adding it to `extract_info'. + - Unregistered datasets default to using the `energy` attribute + and a conversion factor of 1, which is always correct for our + internally generated datasets. + """ + extract_info: Dict[str, List[Tuple]] = { + "mp-traj": [("energy", 1)], + "mp-traj-d3": [("energy", 1), ("data.d3.energy", 1)], + "alexandria-d3": [("energy", 1), ("data.d3.energy", 1)], + } + if dataset not in extract_info: + if not hasattr(row, "energy"): + raise ValueError( + f"db row {row.id} doesn't have an energy attribute directly " + ", but also doesn't define a method to extract energy info." + ) + return get_property_from_row("energy", row, 1) + + energy_ = 0.0 + for row_attribute, conversion_factor in extract_info[dataset]: + energy_ += get_property_from_row(row_attribute, row, conversion_factor) + return energy_ + + +def forces_row_fn(row: ase.db.row.AtomsRow, dataset: str): + """Force data in eV / Angstrom. + + - Some datasets use sums of energy values e.g. PBE + D3. + - For external datasets, we should explicitly register how + to extract the energy property by adding it to `extract_info'. + - Unregistered datasets default to using the `energy` attribute + and a conversion factor of 1, which is always correct for our + internally generated datasets. + """ + extract_info: Dict[str, List[Tuple]] = { + "mp-traj": [("forces", 1)], + "mp-traj-d3": [("forces", 1), ("data.d3.forces", 1)], + "alexandria-d3": [("forces", 1), ("data.d3.forces", 1)], + } + if dataset not in extract_info: + if not hasattr(row, "forces"): + raise ValueError( + f"db row {row.id} doesn't have a forces attribute directly, " + "but also doesn't define a method to extract forces info." 
+ ) + return get_property_from_row("forces", row, 1) + + forces_ = 0.0 + for row_attribute, conversion_factor in extract_info[dataset]: + forces_ += get_property_from_row(row_attribute, row, conversion_factor) + return forces_ + + +def stress_row_fn(row: ase.db.row.AtomsRow, dataset: str) -> float: + """Extract stress data.""" + extract_info: Dict[str, List[Tuple]] = { + "mp-traj": [("stress", 1)], + "mp-traj-d3": [("stress", 1), ("data.d3.stress", 1)], + "alexandria-d3": [("stress", 1), ("data.d3.stress", 1)], + } + if dataset not in extract_info: + if not hasattr(row, "stress"): + raise ValueError( + f"db row {row.id} doesn't have an stress attribute directly " + ", but also doesn't define a method to extract stress info." + ) + return get_property_from_row("stress", row, 1) + + stress_ = 0.0 + for row_attribute, conversion_factor in extract_info[dataset]: + stress_ += get_property_from_row(row_attribute, row, conversion_factor) + return stress_ + + +def test_fixture_node_row_fn(row: ase.db.row.AtomsRow): + """Just return random noise.""" + + pos = ms.from_numpy(row.toatoms().positions) + return ops.rand_like(pos).to(ms.float32) + + +def test_fixture_graph_row_fn(): + """Just return random noise.""" + return mint.randn((1, 1)).to(ms.float32) + + +energy = PropertyDefinition( + name="energy", + dim=1, + domain="real", + row_to_property_fn=energy_row_fn, +) + +forces = PropertyDefinition( + name="forces", + dim=3, + domain="real", + row_to_property_fn=forces_row_fn, +) + +stress = PropertyDefinition( + name="stress", + dim=6, + domain="real", + row_to_property_fn=stress_row_fn, +) + +test_fixture = PropertyDefinition( + name="test-fixture", + dim=3, + domain="real", + row_to_property_fn=test_fixture_node_row_fn, +) + +test_graph_fixture = PropertyDefinition( + name="test-graph-fixture", + dim=1, + domain="real", + row_to_property_fn=test_fixture_graph_row_fn, +) + + +PROPERTIES = { + "energy": energy, + "forces": forces, + "stress": stress, + "test-fixture": test_fixture, + "test-graph-fixture": test_graph_fixture, +} diff --git a/MindChem/applications/orb/src/segment_ops.py b/MindChem/applications/orb/src/segment_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b671e4091723fbaa03d02fdfa853535d2ee3b8 --- /dev/null +++ b/MindChem/applications/orb/src/segment_ops.py @@ -0,0 +1,202 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Segment operations.""" + +from typing import Optional +import numpy as np +import mindspore as ms +from mindspore import ops, Tensor, mint + +MSINT = [ms.int64, ms.int32, ms.int16, ms.int8, ms.uint8] + + +def aggregate_nodes(tensor: Tensor, n_node: Tensor, reduction: str = "mean", deterministic: bool = False) -> Tensor: + """Aggregates over a tensor based on graph sizes.""" + count = len(n_node) + if deterministic: + ms.set_seed(1) + segments = ops.arange(count).repeat_interleave(n_node).astype(ms.int32) + if reduction == "sum": + return scatter_sum(tensor, segments, dim=0) + if reduction == "mean": + return scatter_mean(tensor, segments, dim=0) + if reduction == "max": + return scatter_max(tensor, segments, dim=0) + raise ValueError("Invalid reduction argument. Use sum, mean or max.") + + +def segment_sum(data: Tensor, segment_ids: Tensor, num_segments: int): + """Computes index based sum over segments of a tensor.""" + return scatter_sum(data, segment_ids, dim=0, dim_size=num_segments) + + +def segment_max(data: Tensor, segment_ids: Tensor, num_segments: int): + """Computes index based max over segments of a tensor.""" + assert segment_ids is not None, "segment_ids must not be None" + assert num_segments > 0, "num_segments must be greater than 0" + max_op = ops.ArgMaxWithValue(axis=0) + _, max_values = max_op(data) + return max_values + + +def segment_mean(data: Tensor, segment_ids: Tensor, num_segments: int): + """Computes index based mean over segments of a tensor.""" + sum_v = segment_sum(data, segment_ids, num_segments) + count = ops.scatter_add(ops.zeros( + (num_segments,), dtype=ms.int32), segment_ids, ops.ones_like(segment_ids)) + return sum_v / count.astype(sum_v.dtype) + + +def segment_softmax(data: Tensor, segment_ids: Tensor, num_segments: int, weights: Optional[Tensor] = None): + """Computes a softmax over segments of the tensor.""" + data_max = segment_max(data, segment_ids, num_segments) + data = data - data_max[segment_ids] + + unnormalised_probs = ops.exp(data) + if weights is not None: + unnormalised_probs = unnormalised_probs * weights + denominator = segment_sum(unnormalised_probs, segment_ids, num_segments) + + return safe_division(unnormalised_probs, denominator, segment_ids) + + +def safe_division(numerator: Tensor, denominator: Tensor, segment_ids: Tensor): + """Divides logits by denominator, setting 0 where the denominator is zero.""" + result = ops.where(denominator[segment_ids] == + 0, 0, numerator / denominator[segment_ids]) + return result + + +def _broadcast(src: Tensor, other: Tensor, dim: int): + """Broadcasts the source tensor to match the shape of the other tensor along the specified dimension.""" + if dim < 0: + dim = other.ndim + dim + if src.ndim == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.ndim, other.ndim): + src = src.unsqueeze(-1) + src = src.expand_as(other) + return src + + +def scatter_sum( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None, reduce: str = "sum" +) -> Tensor: + """Applies a sum reduction of the orb_models tensor along the specified dimension.""" + assert reduce == "sum" + index = _broadcast(index, src, dim) + if out is None: + size = list(src.shape) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = ops.zeros(size, dtype=src.dtype) + return mint.scatter_add(out, dim, 
index, src) + return mint.scatter_add(out, dim, index, src) + + +def scatter_std( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None, unbiased: bool = True +) -> Tensor: + """Computes the standard deviation of the orb_models tensor along the specified dimension.""" + if out is not None: + dim_size = out.shape[dim] + + if dim < 0: + dim = src.ndim + dim + + count_dim = dim + if index.ndim <= dim: + count_dim = index.ndim - 1 + + ones = ops.ones(index.shape, dtype=src.dtype) + count = scatter_sum(ones, index, count_dim, dim_size=dim_size) + + index = _broadcast(index, src, dim) + tmp = scatter_sum(src, index, dim, dim_size=dim_size) + count = _broadcast(count, tmp, dim).clip(1) + mean = tmp / count + + var = src - mean.gather(dim, index) + var = var * var + out = scatter_sum(var, index, dim, out=out, dim_size=dim_size) + + if unbiased: + count = count - 1 + count = count.clip(1) + out = out / (count + 1e-6) + out = ops.sqrt(out) + return out + + +def scatter_mean( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None +) -> Tensor: + """Computes the mean of the orb_models tensor along the specified dimension.""" + out = scatter_sum(src, index, dim, out=out, dim_size=dim_size) + dim_size = out.shape[dim] + + index_dim = dim + if index_dim < 0: + index_dim = index_dim + src.ndim + if index.ndim <= index_dim: + index_dim = index.ndim - 1 + + ones = ops.ones(index.shape, dtype=src.dtype) + count = scatter_sum(ones, index, index_dim, dim_size=dim_size) + count = count.clip(1) + count = _broadcast(count, out, dim) + out = out / count + return out + + +def scatter_max( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None +) -> Tensor: + """Computes the maximum of the orb_models tensor for each group defined by index along the specified dimension.""" + if out is not None: + raise NotImplementedError( + "The 'out' argument is not supported for scatter_max") + + if src.dtype in MSINT: + init_value = np.iinfo(src.dtype).min + else: + init_value = np.finfo(src.dtype).min + + if dim < 0: + dim = src.ndim + dim + + if dim_size is None: + dim_size = int(index.max()) + 1 + + result = ops.ones( + (dim_size, *src.shape[:dim], *src.shape[dim + 1:]), dtype=src.dtype) + result = init_value * result + broadcasted_index = _broadcast(index, src, dim) + + scatter_result = ops.ZerosLike()(result) + index = ops.expand_dims(broadcasted_index, dim) + scatter_result = scatter_result.scatter_update(index, src) + result = ops.Maximum()(result, scatter_result) + return result diff --git a/MindChem/applications/orb/src/trainer.py b/MindChem/applications/orb/src/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5b1404f6b59c4d696458d11150e306f183329fe7 --- /dev/null +++ b/MindChem/applications/orb/src/trainer.py @@ -0,0 +1,329 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Trainer.""" + +from typing import Dict, Optional, Tuple + +import mindspore as ms +from mindspore import ops, Tensor, mint + +from src import base, segment_ops + +class OrbLoss: + """Loss function for ORB models. + + This class is used to compute the loss for the ORB model. + It can be used to compute the loss for both node and graph predictions. + """ + + def __init__(self, model): + """Initializes the OrbLoss. + + Args: + target: either the name of a PropertyDefinition or a PropertyDefinition itself. + """ + self.model = model + + def loss_node(self, batch, out_batch=None): + """Apply mlp to compute loss and metrics.""" + batch_n_node = batch.n_node + assert batch.node_targets is not None + target = batch.node_targets['forces'].squeeze(-1) + pred = out_batch["node_pred"].squeeze(-1) + # make sure we remove fixed atoms before normalization + pred, target, batch_n_node = _remove_fixed_atoms( + pred, target, batch_n_node, batch.fix_atoms, self.model.training + ) + mae = mint.abs(pred - self.model.node_head.normalizer(target)) + raw_pred = self.model.node_head.normalizer.inverse(pred) + raw_mae = mint.abs(raw_pred - target) + + mae = mae.mean(dim=-1) + mae = segment_ops.aggregate_nodes( + mae, batch_n_node, reduction="mean" + ).mean() + raw_mae = raw_mae.mean(dim=-1) + raw_mae = segment_ops.aggregate_nodes( + raw_mae, batch_n_node, reduction="mean" + ).mean() + + metrics = { + "node_mae": mae.item(), + "node_mae_raw": raw_mae.item(), + "node_cosine_sim": ops.cosine_similarity(raw_pred, target, dim=-1).mean().item(), + "fwt_0.03": forces_within_threshold(raw_pred, target, batch_n_node), + } + return mae, base.ModelOutput(loss=mae, log=metrics) + + def loss_graph(self, batch, out_batch=None): + """Apply mlp to compute loss and metrics. + + Depending on whether the target is real/binary/categorical, we + use an MSE/cross-entropy loss. In the case of cross-entropy, the + preds are logits (not normalised) to take advantage of numerically + stable log-softmax. 
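+
+        Note: for this regressor the graph-level target handled here is the
+        stress tensor; the returned loss is the mean absolute error of the
+        normalized stress prediction, with the raw-scale MAE reported in the metrics.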
+ """ + assert batch.system_targets is not None + target = batch.system_targets['stress'].squeeze(-1) + if self.model.stress_head.compute_stress: + pred = out_batch["stress_pred"].squeeze(-1) + else: + pred = out_batch["graph_pred"].squeeze(-1) + + normalized_target = self.model.stress_head.normalizer(target) + errors = normalized_target - pred + mae = mint.abs(errors).mean() + + raw_pred = self.model.stress_head.normalizer.inverse(pred) + raw_mae = mint.abs(raw_pred - target).mean() + metrics = {"stress_mae": mae.item(), "stress_mae_raw": raw_mae.item()} + return mae, base.ModelOutput(loss=mae, log=metrics) + + + def loss_energy(self, batch, out_batch=None): + """Apply mlp to compute loss and metrics.""" + assert batch.system_targets is not None + target = batch.system_targets['energy'].squeeze(-1) + pred = out_batch["graph_pred"].squeeze(-1) + + reference = self.model.graph_head.reference(batch.atomic_numbers, batch.n_node).squeeze(-1) + reference_target = target - reference + if self.model.graph_head.atom_avg: + reference_target = reference_target / batch.n_node + + normalized_reference = self.model.graph_head.normalizer(reference_target) + model_loss = normalized_reference - pred + + raw_pred = self.model.graph_head.normalizer.inverse(pred) + if self.model.graph_head.atom_avg: + raw_pred = raw_pred * batch.n_node + raw_mae = mint.abs((raw_pred + reference) - target).mean() + + reference_mae = mint.abs(reference_target).mean() + model_mae = mint.abs(model_loss).mean() + metrics = { + "energy_reference_mae": reference_mae.item(), + "energy_mae": model_mae.item(), + "energy_mae_raw": raw_mae.item(), + } + return model_mae, base.ModelOutput(loss=model_mae, log=metrics) + + def loss(self, batch, label=None): + """Loss function of Orb GraphRegressor.""" + assert label is None, "Orb GraphRegressor does not support labels." + + out = self.model( + batch.edge_features, + batch.node_features, + batch.senders, + batch.receivers, + batch.n_node, + ) + loss = Tensor(0.0, ms.float32) + metrics: Dict = {} + + loss1, graph_out = self.loss_energy(batch, out) + metrics.update(graph_out.log) + loss = loss.type_as(loss1) + loss1 + + loss2, stress_out = self.loss_graph(batch, out) + metrics.update(stress_out.log) + loss = loss.type_as(loss2) + loss2 + + loss3, node_out = self.loss_node(batch, out) + metrics.update(node_out.log) + loss = loss.type_as(loss3) + loss3 + + metrics["loss"] = loss.item() + return loss, metrics + + +def binary_accuracy( + pred: Tensor, target: Tensor, threshold: float = 0.5 +) -> float: + """Calculate binary accuracy between 2 tensors. + + Args: + pred: the prediction tensor. + target: the tensor of target values. + threshold: Binary classification threshold. Default 0.5. + + Returns: + mean accuracy. + """ + return ((pred > threshold) == target).to(ms.float32).mean().item() + + +def categorical_accuracy(pred: Tensor, target: Tensor) -> float: + """Calculate accuracy for K class classification. + + Args: + pred: the tensor of logits for K classes of shape (..., K) + target: tensor of integer target values of shape (...) + + Returns: + mean accuracy. + """ + pred_labels = mint.argmax(pred, dim=-1) + return (pred_labels == target).to(ms.float32).mean().item() + + +def error_within_threshold( + pred: Tensor, target: Tensor, threshold: float = 0.02 +) -> float: + """Calculate MAE between 2 tensors within a threshold. + + Args: + pred: the prediction tensor. + target: the tensor of target values. + threshold: margin threshold. Default 0.02 (derived from OCP metrics). 
+ + Returns: + Mean predictions within threshold. + """ + error = mint.abs(pred - target) + within_threshold = error < threshold + return within_threshold.to(ms.float32).mean().item() + + +def forces_within_threshold( + pred: Tensor, + target: Tensor, + batch_num_nodes: Tensor, + threshold: float = 0.03, +) -> float: + """Calculate MAE between batched graph tensors within a threshold. + + The predictions for a graph are counted as being within the threshold + only if all nodes in the graph have predictions within the threshold. + + Args: + pred: the prediction tensor. + target: the tensor of target values. + batch_num_nodes: A tensor containing the number of nodes per + graph. + threshold: margin threshold. Default 0.03 (derived from OCP metrics). + + Returns: + Mean predictions within threshold. + """ + error = mint.abs(pred - target) + largest_dim_fwt = error.max(-1)[0] < threshold + + count_within_threshold = segment_ops.aggregate_nodes( + largest_dim_fwt.float(), batch_num_nodes, reduction="sum" + ) + # count equals batch_num_nodes if all nodes within threshold + return (count_within_threshold == batch_num_nodes).to(ms.float32).mean().item() + + +def energy_and_forces_within_threshold( + pred_energy: Tensor, + pred_forces: Tensor, + target_energy: Tensor, + target_forces: Tensor, + batch_num_nodes: Tensor, + fixed_atoms: Optional[Tensor] = None, + threshold: Tuple[float, float] = (0.02, 0.03), +) -> float: + """Calculate MAE between batched graph energies and forces within a threshold. + + The predictions for a graph are counted as being within the threshold + only if all nodes in the graph have predictions within the threshold AND + the energies are also within a threshold. A combo of the two above functions. + + Args: + pred_*: the prediction tensors. + target_*: the tensor of target values. + batch_num_nodes: A tensor containing the number of nodes per + graph. + fixed_atoms: A tensor of bools indicating which atoms are fixed. + threshold: margin threshold. Default (0.02, 0.03) (derived from OCP metrics). + Returns: + Mean predictions within threshold. 
+ """ + energy_err = mint.abs(pred_energy - target_energy) + ewt = energy_err < threshold[0] + + forces_err = mint.abs(pred_forces - target_forces) + largest_dim_fwt = forces_err.max(-1).values < threshold[1] + + working_largest_dim_fwt = largest_dim_fwt + + if fixed_atoms is not None: + fixed_per_graph = segment_ops.aggregate_nodes( + fixed_atoms.int(), batch_num_nodes, reduction="sum" + ) + # remove the fixed atoms from the counts + batch_num_nodes = batch_num_nodes - fixed_per_graph + # remove the fixed atoms from the forces + working_largest_dim_fwt = largest_dim_fwt[not fixed_atoms] + + force_count_within_threshold = segment_ops.aggregate_nodes( + working_largest_dim_fwt.int(), batch_num_nodes, reduction="sum" + ) + fwt = force_count_within_threshold == batch_num_nodes + + # count equals batch_num_nodes if all nodes within threshold + return (fwt & ewt).to(ms.float32).mean().item() + + +def _remove_fixed_atoms( + pred_node: Tensor, + node_target: Tensor, + batch_n_node: Tensor, + fix_atoms: Optional[Tensor], + training: bool, +): + """We use inf targets on purpose to designate nodes for removal.""" + assert len(pred_node) == len(node_target) + if fix_atoms is not None and not training: + pred_node = pred_node[~fix_atoms] + node_target = node_target[~fix_atoms] + batch_n_node = segment_ops.aggregate_nodes( + (~fix_atoms).int(), batch_n_node, reduction="sum" + ) + return pred_node, node_target, batch_n_node + + +def bce_loss( + pred: Tensor, target: Tensor, metric_prefix: str = "" +) -> Tuple: + """Binary cross-entropy loss with accuracy metric.""" + loss = mint.nn.BCEWithLogitsLoss()(pred, target.float()) + accuracy = binary_accuracy(pred, target) + return ( + loss, + { + f"{metric_prefix}_accuracy": accuracy, + f"{metric_prefix}_loss": loss.item(), + }, + ) + + +def cross_entropy_loss( + pred: Tensor, target: Tensor, metric_prefix: str = "" +) -> Tuple: + """Cross-entropy loss with accuracy metric.""" + loss = mint.nn.CrossEntropyLoss()(pred, target.long()) + accuracy = categorical_accuracy(pred, target) + return ( + loss, + { + f"{metric_prefix}_accuracy": accuracy, + f"{metric_prefix}_loss": loss.item(), + }, + ) diff --git a/MindChem/applications/orb/src/utils.py b/MindChem/applications/orb/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..34bd4623f1c452d269085eea8aa03c45550fb06a --- /dev/null +++ b/MindChem/applications/orb/src/utils.py @@ -0,0 +1,296 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Experiment utilities.""" + +import math +import random +import re +from collections import defaultdict +from typing import Dict, List, Mapping, Optional, Tuple, TypeVar, Any + +import yaml +import numpy as np +import mindspore as ms +from mindspore import Tensor, mint + +from src import base + +T = TypeVar("T") + + +def load_cfg(filename): + """load_cfg + + Load configurations from yaml file and return a namespace object + """ + from argparse import Namespace + with open(filename, "r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) + return Namespace(**cfg) + + +def ensure_detached(x: base.Metric) -> base.Metric: + """Ensure that the tensor is detached and on the CPU.""" + return x + + +def to_item(x: base.Metric) -> base.Metric: + """Convert a tensor to a python scalar.""" + if isinstance(x, Tensor): + return x.item() + return x + + +def prefix_keys( + dict_to_prefix: Dict[str, T], prefix: str, sep: str = "/" +) -> Dict[str, T]: + """Add a prefix to dictionary keys with a separator.""" + return {f"{prefix}{sep}{k}": v for k, v in dict_to_prefix.items()} + + +def seed_everything(seed: int, rank: int = 0) -> None: + """Set the seed for all pseudo random number generators.""" + random.seed(seed + rank) + np.random.seed(seed + rank) + ms.manual_seed(seed + rank) + + +class ScalarMetricTracker: + """Keep track of average scalar metric values.""" + + def __init__(self): + self.reset() + + def reset(self): + """Reset the AverageMetrics.""" + self.sums = defaultdict(float) + self.counts = defaultdict(int) + + def update(self, metrics: Mapping[str, base.Metric]) -> None: + """Update the metric counts with new values.""" + for k, v in metrics.items(): + if isinstance(v, Tensor) and v.nelement() > 1: + continue # only track scalar metrics + if isinstance(v, Tensor) and v.isnan().any(): + continue + self.sums[k] += ensure_detached(v) + self.counts[k] += 1 + + def get_metrics(self): + """Get the metric values, possibly reducing across gpu processes.""" + return {k: to_item(v) / self.counts[k] for k, v in self.sums.items()} + + +def gradient_clipping( + model: ms.nn.Cell, clip_value: float +) -> List[Any]: + """Add gradient clipping hooks to a model. + + This is the correct way to implement gradient clipping, because + gradients are clipped as gradients are computed, rather than after + all gradients are computed - this means expoding gradients are less likely, + because they are "caught" earlier. + + Args: + model: The model to add hooks to. + clip_value: The upper and lower threshold to clip the gradients to. + + Returns: + A list of handles to remove the hooks from the parameters. 
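+
+    Example (sketch; ``net`` stands for any ``ms.nn.Cell`` with trainable parameters):
+        >>> handles = gradient_clipping(net, clip_value=0.5)
+        >>> # each parameter's gradient is now clamped to [-0.5, 0.5] as it is computed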
+ """ + handles = [] + + def _clip(grad): + if grad is None: + return grad + return grad.clamp(min=-clip_value, max=clip_value) + + for parameter in model.trainable_params(): + if parameter.requires_grad: + h = parameter.register_hook(_clip) + handles.append(h) + + return handles + + +def get_optim( + lr: float, total_steps: int, model: ms.nn.Cell +) -> Tuple[ms.experimental.optim.Optimizer, Optional[ms.experimental.optim.lr_scheduler.LRScheduler]]: + """Configure optimizers, LR schedulers and EMA.""" + + # Initialize parameter groups + params = [] + + # Split parameters based on the regex + for param in model.trainable_params(): + name = param.name + if re.search(r"(.*bias|.*layer_norm.*|.*batch_norm.*)", name): + params.append({"params": param, "weight_decay": 0.0}) + else: + params.append({"params": param}) + + # Create the optimizer with the parameter groups + optimizer = ms.experimental.optim.Adam(params, lr=lr) + + # Create the learning rate scheduler + scheduler = ms.experimental.optim.lr_scheduler.CyclicLR( + optimizer, base_lr=1.0e-9, max_lr=lr, step_size_up=int(total_steps*0.04), step_size_down=total_steps + ) + + return optimizer, scheduler + + +def rand_angles(*shape, dtype=None): + r"""random rotation angles + + Parameters + ---------- + *shape : int + + Returns + ------- + alpha : `Tensor` + tensor of shape :math:`(\mathrm{shape})` + + beta : `Tensor` + tensor of shape :math:`(\mathrm{shape})` + + gamma : `Tensor` + tensor of shape :math:`(\mathrm{shape})` + """ + alpha, gamma = 2 * math.pi * mint.rand(2, *shape, dtype=dtype) + beta = mint.rand(shape, dtype=dtype).mul(2).sub(1).acos() + return alpha, beta, gamma + + +def matrix_x(angle: Tensor) -> Tensor: + r"""matrix of rotation around X axis + + Parameters + ---------- + angle : `Tensor` + tensor of any shape :math:`(...)` + + Returns + ------- + `Tensor` + matrices of shape :math:`(..., 3, 3)` + """ + c = angle.cos() + s = angle.sin() + o = mint.ones_like(angle) + z = mint.zeros_like(angle) + return mint.stack( + [ + mint.stack([o, z, z], dim=-1), + mint.stack([z, c, -s], dim=-1), + mint.stack([z, s, c], dim=-1), + ], + dim=-2, + ) + + +def matrix_y(angle: Tensor) -> Tensor: + r"""matrix of rotation around Y axis + + Parameters + ---------- + angle : `Tensor` + tensor of any shape :math:`(...)` + + Returns + ------- + `Tensor` + matrices of shape :math:`(..., 3, 3)` + """ + c = angle.cos() + s = angle.sin() + o = mint.ones_like(angle) + z = mint.zeros_like(angle) + return mint.stack( + [ + mint.stack([c, z, s], dim=-1), + mint.stack([z, o, z], dim=-1), + mint.stack([-s, z, c], dim=-1), + ], + dim=-2, + ) + + +def matrix_z(angle: Tensor) -> Tensor: + r"""matrix of rotation around Z axis + + Parameters + ---------- + angle : `Tensor` + tensor of any shape :math:`(...)` + + Returns + ------- + `Tensor` + matrices of shape :math:`(..., 3, 3)` + """ + c = angle.cos() + s = angle.sin() + o = mint.ones_like(angle) + z = mint.zeros_like(angle) + return mint.stack( + [ + mint.stack([c, -s, z], dim=-1), + mint.stack([s, c, z], dim=-1), + mint.stack([z, z, o], dim=-1), + ], + dim=-2, + ) + + +def angles_to_matrix(alpha, beta, gamma): + r"""conversion from angles to matrix + + Parameters + ---------- + alpha : `Tensor` + tensor of shape :math:`(...)` + + beta : `Tensor` + tensor of shape :math:`(...)` + + gamma : `Tensor` + tensor of shape :math:`(...)` + + Returns + ------- + `Tensor` + matrices of shape :math:`(..., 3, 3)` + """ + alpha, beta, gamma = ms.numpy.broadcast_arrays(alpha, beta, gamma) + return matrix_y(alpha) @ 
matrix_x(beta) @ matrix_y(gamma) + + +def rand_matrix(*shape, dtype=None): + r"""random rotation matrix + + Parameters + ---------- + *shape : int + + Returns + ------- + `Tensor` + tensor of shape :math:`(\mathrm{shape}, 3, 3)` + """ + rotation_matrix = angles_to_matrix(*rand_angles(*shape, dtype=dtype)) + return rotation_matrix diff --git a/MindChem/mindchemistry/cell/orb/__init__.py b/MindChem/mindchemistry/cell/orb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..709978030161be6191c7d5bc96b466e777c6ae3a --- /dev/null +++ b/MindChem/mindchemistry/cell/orb/__init__.py @@ -0,0 +1,36 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" + +from .orb import ( + NodeHead, + GraphHead, + EnergyHead, + Orb, +) +from .gns import ( + AttentionInteractionNetwork, + MoleculeGNS, +) + +__all__ = [ + "AttentionInteractionNetwork", + "EnergyHead", + "GraphHead", + "MoleculeGNS", + "NodeHead", + "Orb", +] diff --git a/MindChem/mindchemistry/cell/orb/gns.py b/MindChem/mindchemistry/cell/orb/gns.py new file mode 100644 index 0000000000000000000000000000000000000000..ab53083f8605dbe0463a5d9a0554d5e5990428f9 --- /dev/null +++ b/MindChem/mindchemistry/cell/orb/gns.py @@ -0,0 +1,690 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""GNS Molecule.""" + + +from typing import List, Literal, Optional, Dict, Any, Union +from functools import partial + +import numpy as np +from mindspore import nn, ops, Tensor, mint +from mindspore.common.initializer import Uniform +import mindspore.ops.operations as P + +from mindchemistry.cell.orb.utils import build_mlp + +_KEY = "feat" + + +def mlp_and_layer_norm(in_dim: int, out_dim: int, hidden_dim: int, n_layers: int) -> nn.SequentialCell: + """Create an MLP followed by layer norm. + + Args: + in_dim (int): Input dimension. + out_dim (int): Output dimension. + hidden_dim (int): Hidden dimension. + n_layers (int): Number of hidden layers. + + Returns: + nn.SequentialCell: A sequential cell containing the MLP and layer norm. 
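+
+    Example (sketch; dimensions chosen arbitrarily):
+        >>> block = mlp_and_layer_norm(in_dim=23, out_dim=256, hidden_dim=512, n_layers=2)
+        >>> out = block(mint.randn((8, 23)))  # shape (8, 256)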
+ """ + layers = build_mlp( + in_dim, + [hidden_dim for _ in range(n_layers)], + out_dim, + ) + layers.append(nn.LayerNorm((out_dim,))) + return layers + + +def get_cutoff(p: int, r: Tensor, r_max: float) -> Tensor: + """Get the cutoff function for attention. + + Args: + p (int): Polynomial order. + r (Tensor): Distance tensor. + r_max (float): Maximum distance for the cutoff. + + Returns: + Tensor: Cutoff tensor. + """ + envelope = 1.0 - ((p + 1.0) * (p + 2.0) / 2.0) * ops.pow(r / r_max, p) + \ + p * (p + 2.0) * ops.pow(r / r_max, p + 1) - \ + (p * (p + 1.0) / 2) * ops.pow(r / r_max, p + 2) + cutoff = ops.expand_dims( + ops.where(r < r_max, envelope, ops.zeros_like(envelope)), -1) + return cutoff + + +def gaussian_basis_function( + scalars: Tensor, + num_bases: Union[Tensor, int], + radius: Union[Tensor, float], + scale: Union[Tensor, float] = 1.0, +) -> Tensor: + """Gaussian basis function applied to a tensor of scalars. + + Args: + scalars (Tensor): Scalars to compute the gbf on. Shape [num_scalars]. + num_bases (Tensor): The number of bases. An Int. + radius (Tensor): The largest centre of the bases. A Float. + scale (Tensor, optional): The width of the gaussians. Defaults to 1. + + Returns: + Tensor: A tensor of shape [num_scalars, num_bases]. + """ + assert len(scalars.shape) == 1 + gaussian_means = ops.arange( + 0, float(radius), float(radius) / num_bases + ) + return mint.exp( + -(scale**2) * (scalars.unsqueeze(1) - gaussian_means.unsqueeze(0)).abs() ** 2 + ) + + +class AtomEmbedding(nn.Cell): + r""" + AtomEmbedding Layer. + + This layer initializes atom embeddings based on the atomic number of elements in the periodic table. + It uses an embedding table initialized with a uniform distribution over the range [-sqrt(3), sqrt(3)]. + + Args: + emb_size (int): Size of the embedding vector for each atom. + num_elements (int): Number of elements in the periodic table (typically 118 for known elements). + + Inputs: + - **x** (Tensor) - Input tensor of shape [..., num_atoms], where + each value represents the atomic number of an atom in the periodic table. + + Outputs: + - **h** (Tensor) - Output tensor of shape [..., num_atoms, emb_size], + where each atom's embedding is represented as a vector of size `emb_size`. + + Supported Platforms: + ``Ascend`` + """ + def __init__(self, emb_size, num_elements): + """init + """ + super().__init__() + self.emb_size = emb_size + self.embeddings = nn.Embedding( + num_elements + 1, emb_size, embedding_table=Uniform(np.sqrt(3))) + + def construct(self, x): + """construct + """ + h = self.embeddings(x) + return h + + +class Encoder(nn.Cell): + r""" + Encoder for Graph Network States (GNS). + + This encoder processes node and edge features using MLPs and layer normalization. + It concatenates the features of nodes and edges, applies MLPs to update their states, + and returns the updated features. + + Args: + num_node_in_features (int): Number of input features for nodes. + num_node_out_features (int): Number of output features for nodes. + num_edge_in_features (int): Number of input features for edges. + num_edge_out_features (int): Number of output features for edges. + num_mlp_layers (int): Number of MLP layers. + mlp_hidden_dim (int): Hidden dimension for the MLP. + node_feature_names (List[str]): List of node feature names. + edge_feature_names (List[str]): List of edge feature names. 
+ + Inputs: + - **nodes** (Dict[str, Tensor]) - Dictionary of node features, where keys are feature names + and values are tensors of shape (num_nodes, num_node_in_features). + - **edges** (Dict[str, Tensor]) - Dictionary of edge features, where keys are feature names + and values are tensors of shape (num_edges, num_edge_in_features). + + Outputs: + - **edges** (Dict[str, Tensor]) - Updated edge features dictionary, where key "feat" contains + the updated edge features of shape (num_edges, num_edge_out_features). + - **nodes** (Dict[str, Tensor]) - Updated node features dictionary, where key "feat" contains + the updated node features of shape (num_nodes, num_node_out_features). + + Supported Platforms: + ``Ascend`` + """ + + def __init__(self, + num_node_in_features: int, + num_node_out_features: int, + num_edge_in_features: int, + num_edge_out_features: int, + num_mlp_layers: int, + mlp_hidden_dim: int, + node_feature_names: List[str], + edge_feature_names: List[str]): + """init + """ + super().__init__() + self.node_feature_names = node_feature_names + self.edge_feature_names = edge_feature_names + self._node_fn = mlp_and_layer_norm( + num_node_in_features, num_node_out_features, mlp_hidden_dim, num_mlp_layers) + self._edge_fn = mlp_and_layer_norm( + num_edge_in_features, num_edge_out_features, mlp_hidden_dim, num_mlp_layers) + + def construct(self, nodes, edges): + """construct + """ + edge_features = ops.cat([edges[k] for k in self.edge_feature_names], axis=-1) + node_features = ops.cat([nodes[k] for k in self.node_feature_names], axis=-1) + + edges.update({_KEY: self._edge_fn(edge_features)}) + nodes.update({_KEY: self._node_fn(node_features)}) + return edges, nodes + + +class InteractionNetwork(nn.Cell): + r""" + Interaction Network. + + Implements a message passing neural network layer that updates node and edge features based on interactions. + This layer combines node and edge features, applies MLPs to update their states, and returns the updated features. + + Args: + num_node_in (int): Number of input features for nodes. + num_node_out (int): Number of output features for nodes. + num_edge_in (int): Number of input features for edges. + num_edge_out (int): Number of output features for edges. + num_mlp_layers (int): Number of MLP layers. + mlp_hidden_dim (int): Hidden dimension for the MLP. + + Inputs: + - **graph_edges** (Dict[str, Tensor]) - Dictionary of edge features, where key "feat" contains + the edge features of shape (num_edges, num_edge_in). + - **graph_nodes** (Dict[str, Tensor]) - Dictionary of node features, where key "feat" contains + the node features of shape (num_nodes, num_node_in). + - **senders** (Tensor) - Indices of the sender nodes for each edge, shape (num_edges,). + - **receivers** (Tensor) - Indices of the receiver nodes for each edge, shape (num_edges,). + + Outputs: + - **edges** (Dict[str, Tensor]) - Updated edge features dictionary, where key "feat" contains + the updated edge features of shape (num_edges, num_edge_out). + - **nodes** (Dict[str, Tensor]) - Updated node features dictionary, where key "feat" contains + the updated node features of shape (num_nodes, num_node_out). 
+ + Supported Platforms: + ``Ascend`` + """ + def __init__(self, + num_node_in: int, + num_node_out: int, + num_edge_in: int, + num_edge_out: int, + num_mlp_layers: int, + mlp_hidden_dim: int): + """init + """ + super().__init__() + self._node_mlp = mlp_and_layer_norm( + num_node_in + num_edge_out, num_node_out, mlp_hidden_dim, num_mlp_layers) + self._edge_mlp = mlp_and_layer_norm( + num_node_in + num_node_in + num_edge_in, num_edge_out, mlp_hidden_dim, num_mlp_layers) + + def construct(self, graph_edges, graph_nodes, senders, receivers): + """construct + """ + nodes = graph_nodes[_KEY] + edges = graph_edges[_KEY] + + sent_attributes = ops.gather(nodes, senders, 0) + received_attributes = ops.gather(nodes, receivers, 0) + + edge_features = ops.cat( + [edges, sent_attributes, received_attributes], axis=1) + updated_edges = self._edge_mlp(edge_features) + + received_attributes = ops.scatter_add( + ops.zeros_like(nodes), receivers, updated_edges) + + node_features = ops.cat([nodes, received_attributes], axis=1) + updated_nodes = self._node_mlp(node_features) + + nodes = graph_nodes[_KEY] + updated_nodes + edges = graph_edges[_KEY] + updated_edges + + node_features = {**graph_nodes, _KEY: nodes} + edge_features = {**graph_edges, _KEY: edges} + return edge_features, node_features + + +# pylint: disable=C0301 +class AttentionInteractionNetwork(nn.Cell): + r""" + Attention interaction network. + Implements attention-based message passing neural network layer for edge updates in molecular graphs. + + Args: + num_node_in (int): Number of input node features. + num_node_out (int): Number of output node features. + num_edge_in (int): Number of input edge features. + num_edge_out (int): Number of output edge features. + num_mlp_layers (int): Number of hidden layers in node and edge update MLPs. + mlp_hidden_dim (int): Hidden dimension size of MLPs. + attention_gate (str, optional): Attention gate type, ``"sigmoid"`` or ``"softmax"``. Default: ``"sigmoid"``. + distance_cutoff (bool, optional): Whether to use distance-based edge cutoff. Default: ``True``. + polynomial_order (int, optional): Order of polynomial cutoff function. Default: ``4``. + cutoff_rmax (float, optional): Maximum distance for cutoff. Default: ``6.0``. + + Inputs: + - **graph_edges** (dict) - Edge feature dictionary, must contain key "feat" with shape :math:`(n_{edges}, num\_edge\_in)`. + - **graph_nodes** (dict) - Node feature dictionary, must contain key "feat" with shape :math:`(n_{nodes}, num\_node\_in)`. + - **senders** (Tensor) - Sender node indices for each edge, shape :math:`(n_{edges},)`. + - **receivers** (Tensor) - Receiver node indices for each edge, shape :math:`(n_{edges},)`. + + Outputs: + - **edges** (dict) - Updated edge feature dictionary with key "feat" of shape :math:`(n_{edges}, num\_edge\_out)`. + - **nodes** (dict) - Updated node feature dictionary with key "feat" of shape :math:`(n_{nodes}, num\_node\_out)`. + + Raises: + ValueError: If `attention_gate` is not "sigmoid" or "softmax". + ValueError: If edge or node features do not contain the required "feat" key. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> from mindspore import Tensor + >>> from mindchemistry.cell.orb.gns import AttentionInteractionNetwork + >>> attn_net = AttentionInteractionNetwork( + ... num_node_in=256, + ... num_node_out=256, + ... num_edge_in=256, + ... num_edge_out=256, + ... num_mlp_layers=2, + ... mlp_hidden_dim=512, + ... 
) + >>> n_atoms = 4 + >>> n_edges = 10 + >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + >>> for i, num in enumerate(atomic_numbers.asnumpy()): + ... atomic_numbers_embedding_np[i, num - 1] = 1.0 + >>> node_features = { + ... "atomic_numbers": atomic_numbers, + ... "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + ... "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + ... "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + ... } + >>> edge_features = { + ... "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)), + ... "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10)), + ... "feat": Tensor(np.random.randn(n_edges, 256).astype(np.float32)) + ... } + >>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) + >>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) + >>> edges, nodes = attn_net( + ... edge_features, + ... node_features, + ... senders, + ... receivers, + ... ) + >>> print(edges["feat"].shape, nodes["feat"].shape) + (10, 256) (4, 256) + """ + + def __init__(self, + num_node_in: int, + num_node_out: int, + num_edge_in: int, + num_edge_out: int, + num_mlp_layers: int, + mlp_hidden_dim: int, + attention_gate: Literal["sigmoid", "softmax"] = "sigmoid", + distance_cutoff: bool = True, + polynomial_order: Optional[int] = 4, + cutoff_rmax: Optional[float] = 6.0): + """init + """ + super().__init__() + self._num_node_in = num_node_in + self._num_node_out = num_node_out + self._num_edge_in = num_edge_in + self._num_edge_out = num_edge_out + self._num_mlp_layers = num_mlp_layers + self._mlp_hidden_dim = mlp_hidden_dim + self._node_mlp = mlp_and_layer_norm( + num_node_in + num_edge_out + num_edge_out, num_node_out, mlp_hidden_dim, num_mlp_layers) + self._edge_mlp = mlp_and_layer_norm( + num_node_in + num_node_in + num_edge_in, num_edge_out, mlp_hidden_dim, num_mlp_layers) + self._receive_attn = nn.Dense(num_edge_in, 1) + self._send_attn = nn.Dense(num_edge_in, 1) + self._distance_cutoff = distance_cutoff + self._r_max = cutoff_rmax + self._polynomial_order = polynomial_order + self._attention_gate = attention_gate + + self.scatter_add = P.TensorScatterAdd() + + def construct(self, graph_edges, graph_nodes, senders, receivers): + """construct + """ + nodes = graph_nodes[_KEY] + edges = graph_edges[_KEY] + + p = self._polynomial_order + r_max = self._r_max + r = graph_edges['r'] + cutoff = get_cutoff(p, r, r_max) + + sent_attributes = ops.gather(nodes, senders, 0) + received_attributes = ops.gather(nodes, receivers, 0) + + if self._attention_gate == "softmax": + receive_attn = ops.softmax(self._receive_attn(edges), axis=0) + send_attn = ops.softmax(self._send_attn(edges), axis=0) + else: + receive_attn = ops.sigmoid(self._receive_attn(edges)) + send_attn = ops.sigmoid(self._send_attn(edges)) + + if self._distance_cutoff: + receive_attn = receive_attn * cutoff + send_attn = send_attn * cutoff + + edge_features = ops.cat( + [edges, sent_attributes, received_attributes], axis=1) + updated_edges = self._edge_mlp(edge_features) + + if senders.ndim < 2: + senders = senders.unsqueeze(-1) + sent_attributes = self.scatter_add( + ops.zeros_like(nodes), senders, updated_edges * send_attn) + if receivers.ndim < 2: + receivers = receivers.unsqueeze(-1) + received_attributes = self.scatter_add( + ops.zeros_like(nodes), receivers, updated_edges * 
receive_attn) + + node_features = ops.cat( + [nodes, received_attributes, sent_attributes], axis=1) + updated_nodes = self._node_mlp(node_features) + + nodes = graph_nodes[_KEY] + updated_nodes + edges = graph_edges[_KEY] + updated_edges + + node_features = {**graph_nodes, _KEY: nodes} + edge_features = {**graph_edges, _KEY: edges} + return edge_features, node_features + +class Decoder(nn.Cell): + r""" + Decoder for Graph Network States (GNS). + + This decoder processes node features using an MLP to produce predictions. + It takes the node features as input and outputs updated node features with predictions. + + Args: + num_node_in (int): Number of input features for nodes. + num_node_out (int): Number of output features for nodes. + num_mlp_layers (int): Number of MLP layers. + mlp_hidden_dim (int): Hidden dimension for the MLP. + batch_norm (bool, optional): Whether to apply batch normalization. Defaults to False. + + Inputs: + - **graph_nodes** (Dict[str, Tensor]) - Dictionary of node features, where key "feat" contains + the node features of shape (num_nodes, num_node_in). + + Outputs: + - **graph_nodes** (Dict[str, Tensor]) - Updated node features dictionary, where key "pred" contains + the predicted node features of shape (num_nodes, num_node_out). + + Supported Platforms: + ``Ascend`` + """ + def __init__(self, + num_node_in: int, + num_node_out: int, + num_mlp_layers: int, + mlp_hidden_dim: int, + batch_norm: bool = False): + """Initialization. + Args: + num_node_in (int): Number of input features for nodes. + num_node_out (int): Number of output features for nodes. + num_mlp_layers (int): Number of MLP layers. + mlp_hidden_dim (int): Hidden dimension for the MLP. + batch_norm (bool, optional): Whether to apply batch normalization. Defaults to False. + """ + super().__init__() + seq = build_mlp( + num_node_in, + [mlp_hidden_dim for _ in range(num_mlp_layers)], + num_node_out, + ) + if batch_norm: + seq.append(nn.BatchNorm1d(num_node_out)) + self.node_fn = nn.SequentialCell(seq) + + def construct(self, graph_nodes): + """Forward pass of the decoder. + Args: + graph_nodes (Dict[str, Tensor]): Dictionary of node features. + Returns: + Dict[str, Tensor]: Updated node features with predictions. + """ + nodes = graph_nodes[_KEY] + updated = self.node_fn(nodes) + return {**graph_nodes, "pred": updated} + + +# pylint: disable=C0301 +class MoleculeGNS(nn.Cell): + r""" + Molecular graph neural network. + Implements flexible modular graph neural network for molecular property prediction based on message passing + with attention or other interaction mechanisms. Supports node and edge embeddings, multiple message passing + steps, and customizable interaction layers for complex molecular graphs. + + Args: + num_node_in_features (int): Number of input features per node. + num_node_out_features (int): Number of output features per node. + num_edge_in_features (int): Number of input features per edge. + latent_dim (int): Latent dimension for node and edge representations. + num_message_passing_steps (int): Number of message passing layers. + num_mlp_layers (int): Number of hidden layers in node and edge update MLPs. + mlp_hidden_dim (int): Hidden dimension size of MLPs. + node_feature_names (List[str]): List of node feature keys to use from input dictionary. + edge_feature_names (List[str]): List of edge feature keys to use from input dictionary. + use_embedding (bool, optional): Whether to use atomic number embedding for nodes. Default: ``True``. 
+ interactions (str, optional): Type of interaction layer to use (e.g., ``"simple_attention"``). Default: ``"simple_attention"``. + interaction_params (Optional[Dict[str, Any]], optional): Parameters for interaction layer, e.g., cutoff, + polynomial order, gate type. Default: ``None``. + + Inputs: + - **edge_features** (dict) - Edge feature dictionary, must contain keys specified in `edge_feature_names`. + - **node_features** (dict) - Node feature dictionary, must contain keys specified in `node_feature_names`. + - **senders** (Tensor) - Sender node indices for each edge, shape :math:`(n_{edges},)`. + - **receivers** (Tensor) - Receiver node indices for each edge, shape :math:`(n_{edges},)`. + + Outputs: + - **edges** (dict) - Updated edge feature dictionary with key "feat" of shape :math:`(n_{edges}, latent\_dim)`. + - **nodes** (dict) - Updated node feature dictionary with key "feat" of shape :math:`(n_{nodes}, latent\_dim)`. + + Raises: + ValueError: If required feature keys are missing in `edge_features` or `node_features`. + ValueError: If `interactions` is not a supported type. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> from mindspore import Tensor + >>> from mindchemistry.cell.orb.gns import MoleculeGNS + >>> gns_model = MoleculeGNS( + ... num_node_in_features=256, + ... num_node_out_features=3, + ... num_edge_in_features=23, + ... latent_dim=256, + ... interactions="simple_attention", + ... interaction_params={ + ... "distance_cutoff": True, + ... "polynomial_order": 4, + ... "cutoff_rmax": 6, + ... "attention_gate": "sigmoid", + ... }, + ... num_message_passing_steps=15, + ... num_mlp_layers=2, + ... mlp_hidden_dim=512, + ... use_embedding=True, + ... node_feature_names=["feat"], + ... edge_feature_names=["feat"], + ... ) + >>> n_atoms = 4 + >>> n_edges = 10 + >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + >>> for i, num in enumerate(atomic_numbers.asnumpy()): + ... atomic_numbers_embedding_np[i, num - 1] = 1.0 + >>> node_features = { + ... "atomic_numbers": atomic_numbers, + ... "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + ... "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + ... "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + ... } + >>> edge_features = { + ... "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)), + ... "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10)), + ... "feat": Tensor(np.random.randn(n_edges, 256).astype(np.float32)) + ... } + >>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) + >>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) + >>> edges, nodes = gns_model( + ... edge_features, + ... node_features, + ... senders, + ... receivers, + ... 
) + >>> print(edges["feat"].shape, nodes["feat"].shape) + (10, 256) (4, 256) + """ + + def __init__(self, + num_node_in_features: int, + num_node_out_features: int, + num_edge_in_features: int, + latent_dim: int, + num_message_passing_steps: int, + num_mlp_layers: int, + mlp_hidden_dim: int, + node_feature_names: List[str], + edge_feature_names: List[str], + use_embedding: bool = True, + interactions: Literal["default", + "simple_attention"] = "simple_attention", + interaction_params: Optional[Dict[str, Any]] = None): + """init + """ + super().__init__() + self._encoder = Encoder( + num_node_in_features=num_node_in_features, + num_node_out_features=latent_dim, + num_edge_in_features=num_edge_in_features, + num_edge_out_features=latent_dim, + num_mlp_layers=num_mlp_layers, + mlp_hidden_dim=mlp_hidden_dim, + node_feature_names=node_feature_names, + edge_feature_names=edge_feature_names + ) + if interactions == "default": + InteractionNetworkClass = InteractionNetwork + elif interactions == "simple_attention": + InteractionNetworkClass = AttentionInteractionNetwork + self.num_message_passing_steps = num_message_passing_steps + if interaction_params is None: + interaction_params = {} + self.gnn_stacks = nn.CellList([ + InteractionNetworkClass( + num_node_in=latent_dim, + num_node_out=latent_dim, + num_edge_in=latent_dim, + num_edge_out=latent_dim, + num_mlp_layers=num_mlp_layers, + mlp_hidden_dim=mlp_hidden_dim, + **interaction_params + ) for _ in range(self.num_message_passing_steps) + ]) + self._decoder = Decoder( + num_node_in=latent_dim, + num_node_out=num_node_out_features, + num_mlp_layers=num_mlp_layers, + mlp_hidden_dim=mlp_hidden_dim + ) + self.rbf = partial(gaussian_basis_function, num_bases=20, radius=10.0) + self.use_embedding = use_embedding + if self.use_embedding: + self.atom_emb = AtomEmbedding(latent_dim, 118) + + def construct(self, edge_features, node_features, senders, receivers): + """construct + """ + edge_features = self.featurize_edges(edge_features) + node_features = self.featurize_nodes(node_features) + edges, nodes = self._encoder(node_features, edge_features) + for gnn in self.gnn_stacks: + edges, nodes = gnn(edges, nodes, senders, receivers) + nodes = self._decoder(nodes) + return edges, nodes + + def featurize_nodes(self, node_features): + """Featurize the nodes of a graph. + + Args: + node_features (Dict[str, Tensor]): Dictionary of node features. + + Returns: + Dict[str, Tensor]: Updated node features with atomic embeddings. + """ + one_hot_atomic = ops.OneHot()( + node_features["atomic_numbers"], 118, Tensor(1.0), Tensor(0.0) + ) + if self.use_embedding: + atomic_embedding = self.atom_emb(node_features["atomic_numbers"]) + else: + atomic_embedding = one_hot_atomic + + node_features = {**node_features, **{_KEY: atomic_embedding}} + return node_features + + def featurize_edges(self, edge_features): + """Featurize the edges of a graph. + + Args: + edge_features (Dict[str, Tensor]): Dictionary of edge features. + + Returns: + Dict[str, Tensor]: Updated edge features with radial basis functions and unit vectors. 
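+
+        Note:
+            The resulting "feat" entry concatenates 20 radial-basis values
+            (``num_bases=20``, ``radius=10.0``) with the 3 unit-vector components,
+            i.e. 23 features per edge, matching ``num_edge_in_features=23`` in the
+            class-level example above.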
+ """ + lengths = ops.norm(edge_features['vectors'], dim=1) + non_zero_divisor = ops.where( + lengths == 0, ops.ones_like(lengths), lengths) + unit_vectors = edge_features['vectors'] / ops.expand_dims(non_zero_divisor, 1) + rbfs = self.rbf(lengths) + edges = ops.cat([rbfs, unit_vectors], axis=1) + + edge_features = {**edge_features, **{_KEY: edges}} + return edge_features diff --git a/MindChem/mindchemistry/cell/orb/orb.py b/MindChem/mindchemistry/cell/orb/orb.py new file mode 100644 index 0000000000000000000000000000000000000000..8afc55dbf5e931505ad6747b53ff9fcd4e1fbe23 --- /dev/null +++ b/MindChem/mindchemistry/cell/orb/orb.py @@ -0,0 +1,698 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Orb GraphRegressor.""" + +from typing import Literal, Optional, Union +import numpy + +import mindspore as ms +from mindspore import Parameter, ops, Tensor, mint + +from mindchemistry.cell.orb.gns import _KEY, MoleculeGNS +from mindchemistry.cell.orb.utils import ( + aggregate_nodes, + build_mlp, + REFERENCE_ENERGIES, +) + + +class LinearReferenceEnergy(ms.nn.Cell): + r""" + Linear reference energy (no bias term). + + This class implements a linear reference energy model that can be used + to compute the reference energy for a given set of atomic numbers. + + Args: + weight_init (numpy.ndarray, optional): Initial weights for the linear layer. + If not provided, the weights will be initialized randomly. + trainable (bool, optional): Whether the weights are trainable or not. + If not provided, the weights will be trainable by default. + + Inputs: + - **atom_types** (Tensor) - A tensor of atomic numbers of shape (n_atoms,). + - **n_node** (Tensor) - A tensor of shape (n_graphs,) containing the number of nodes in each graph. + + Outputs: + - **Tensor** - A tensor of shape (n_graphs, 1) containing the reference energy. + + Raises: + ValueError: If the input tensor shapes are not compatible with the expected shapes. + TypeError: If the input types are not compatible with the expected types. 
+ + Supported Platforms: + ``Ascend`` + """ + def __init__( + self, + weight_init: Optional[numpy.ndarray] = None, + trainable: Optional[bool] = None, + ): + """init + """ + super().__init__() + + if trainable is None: + trainable = weight_init is None + + self.linear = ms.nn.Dense(118, 1, has_bias=False) + if weight_init is not None: + self.linear.weight.set_data(Tensor(weight_init, dtype=ms.float32).reshape(1, 118)) + if not trainable: + self.linear.weight.requires_grad = False + + def construct(self, atom_types: Tensor, n_node: Tensor): + """construct + """ + one_hot_atomic = ops.OneHot()(atom_types, 118, Tensor(1.0, ms.float32), Tensor(0.0, ms.float32)) + + reduced = aggregate_nodes(one_hot_atomic, n_node, reduction="sum") + return self.linear(reduced) + + +class ScalarNormalizer(ms.nn.Cell): + r""" + Scalar normalizer that learns mean and std from data. + + NOTE: Multi-dimensional tensors are flattened before updating + the running mean/std. This is desired behaviour for force targets. + + Args: + init_mean (Tensor or float, optional): Initial mean value for normalization. + If not provided, defaults to 0.0. + init_std (Tensor or float, optional): Initial standard deviation value for normalization. + If not provided, defaults to 1.0. + init_num_batches (int, optional): Initial number of batches for normalization. + If not provided, defaults to 1000. + + Inputs: + - **x** (Tensor) - A tensor of shape (n_samples, n_features) to normalize. + + Outputs: + - **Tensor** - A tensor of the same shape as x, normalized by the running mean and std. + + Raises: + ValueError: If the input tensor is not of the expected shape. + TypeError: If the input types are not compatible with the expected types. + + Supported Platforms: + ``Ascend`` + """ + def __init__( + self, + init_mean: Optional[Union[Tensor, float]] = None, + init_std: Optional[Union[Tensor, float]] = None, + init_num_batches: Optional[int] = 1000, + ): + """init + """ + super().__init__() + self.bn = mint.nn.BatchNorm1d(1, affine=False, momentum=None) + self.bn.running_mean = Parameter(Tensor([0], ms.float32)) + self.bn.running_var = Parameter(Tensor([1], ms.float32)) + self.bn.num_batches_tracked = Parameter(Tensor([1000], ms.float32)) + self.stastics = { + "running_mean": init_mean if init_mean is not None else 0.0, + "running_var": init_std**2 if init_std is not None else 1.0, + "num_batches_tracked": init_num_batches if init_num_batches is not None else 1000, + } + + def construct(self, x: Tensor): + """construct + """ + if self.training: + self.bn(x.view(-1, 1)) + if hasattr(self, "running_mean"): + return (x - self.running_mean) / mint.sqrt(self.running_var) + return (x - self.bn.running_mean) / mint.sqrt(self.bn.running_var) + + def inverse(self, x: Tensor): + """Reverse the construct normalization. + + Args: + x: A tensor of shape (n_samples, n_features) to inverse normalize. + + Returns: + A tensor of the same shape as x, inverse normalized by the running mean and std. + """ + if hasattr(self, "running_mean"): + return x * mint.sqrt(self.running_var) + self.running_mean + return x * mint.sqrt(self.bn.running_var) + self.bn.running_mean + + +# pylint: disable=C0301 +class NodeHead(ms.nn.Cell): + r""" + Node-level prediction head. + + Implements neural network head for predicting node-level properties from node features. This head can be + added to base models to enable auxiliary tasks during pretraining or added in fine-tuning steps. + + Args: + latent_dim (int): Input feature dimension for each node. 
+ num_mlp_layers (int): Number of hidden layers in MLP. + mlp_hidden_dim (int): Hidden dimension size of MLP. + target_property_dim (int): Output dimension of node-level target property. + dropout (Optional[float], optional): Dropout rate for MLP. Default: ``None``. + remove_mean (bool, optional): If True, remove mean from output, typically used for force prediction. + Default: ``True``. + + Inputs: + - **node_features** (dict) - Node feature dictionary, must contain key "feat" with shape :math:`(n_{nodes}, latent\_dim)`. + - **n_node** (Tensor) - Number of nodes in graph, shape :math:`(1,)`. + + Outputs: + - **output** (dict) - Dictionary containing key "node_pred" with value of shape :math:`(n_{nodes}, target\_property\_dim)`. + + Raises: + ValueError: If required feature keys are missing in `node_features`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> from mindspore import Tensor + >>> from mindchemistry.cell.orb.gns import NodeHead + >>> node_head = NodeHead( + ... latent_dim=256, + ... num_mlp_layers=1, + ... mlp_hidden_dim=256, + ... target_property_dim=3, + ... remove_mean=True, + ... ) + >>> n_atoms = 4 + >>> n_node = Tensor([n_atoms], mindspore.int32) + >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + >>> for i, num in enumerate(atomic_numbers.asnumpy()): + ... atomic_numbers_embedding_np[i, num - 1] = 1.0 + >>> node_features = { + ... "atomic_numbers": atomic_numbers, + ... "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + ... "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + ... "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + ... } + >>> output = node_head(node_features, n_node) + >>> print(output['node_pred'].shape) + (4, 3) + """ + def __init__( + self, + latent_dim: int, + num_mlp_layers: int, + mlp_hidden_dim: int, + target_property_dim: int, + dropout: Optional[float] = None, + remove_mean: bool = True, + ): + """init + """ + super().__init__() + self.target_property_dim = target_property_dim + self.normalizer = ScalarNormalizer() + + self.mlp = build_mlp( + input_size=latent_dim, + hidden_layer_sizes=[mlp_hidden_dim] * num_mlp_layers, + output_size=self.target_property_dim, + dropout=dropout, + ) + + self.remove_mean = remove_mean + + def construct(self, node_features, n_node): + """construct + """ + feat = node_features[_KEY] + pred = self.mlp(feat) + if self.remove_mean: + system_means = aggregate_nodes( + pred, n_node, reduction="mean" + ) + node_broadcasted_means = mint.repeat_interleave( + system_means, n_node, dim=0 + ) + pred = pred - node_broadcasted_means + res = {"node_pred": pred} + return res + + def predict(self, node_features, n_node): + """Predict node-level attributes. + + Args: + node_features: Node features tensor of shape (n_nodes, latent_dim). + n_node: Number of nodes in the graph. + + Returns: + node_pred: Node-level predictions of shape (n_nodes, target_property_dim). + """ + out = self(node_features, n_node) + pred = out["node_pred"] + return self.normalizer.inverse(pred) + + +# pylint: disable=C0301 +class GraphHead(ms.nn.Cell): + r""" + Graph-level prediction head. Implements graph-level prediction head that can be attached to base models + for predicting graph-level properties (e.g., stress tensor) from node features using aggregation and MLP. + + Args: + latent_dim (int): Input feature dimension for each node. 
+ num_mlp_layers (int): Number of hidden layers in MLP. + mlp_hidden_dim (int): Hidden dimension size of MLP. + target_property_dim (int): Output dimension of graph-level property. + node_aggregation (str, optional): Aggregation method for node predictions, e.g., ``"mean"`` or ``"sum"``. Default: ``"mean"``. + dropout (Optional[float], optional): Dropout rate for MLP. Default: ``None``. + compute_stress (bool, optional): Whether to compute and output stress tensor. Default: ``False``. + + Inputs: + - **node_features** (dict) - Node feature dictionary, must contain key "feat" with shape :math:`(n_{nodes}, latent\_dim)`. + - **n_node** (Tensor) - Number of nodes in graph, shape :math:`(1,)`. + + Outputs: + - **output** (dict) - Dictionary containing key "stress_pred" with value of shape :math:`(1, target\_property\_dim)`. + + Raises: + ValueError: If required feature keys are missing in `node_features`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> from mindspore import Tensor + >>> from mindchemistry.cell.orb.gns import GraphHead + >>> graph_head = GraphHead( + ... latent_dim=256, + ... num_mlp_layers=1, + ... mlp_hidden_dim=256, + ... target_property_dim=6, + ... compute_stress=True, + ... ) + >>> n_atoms = 4 + >>> n_node = Tensor([n_atoms], mindspore.int32) + >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + >>> for i, num in enumerate(atomic_numbers.asnumpy()): + ... atomic_numbers_embedding_np[i, num - 1] = 1.0 + >>> node_features = { + ... "atomic_numbers": atomic_numbers, + ... "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + ... "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + ... "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + ... } + >>> output = graph_head(node_features, n_node) + >>> print(output['stress_pred'].shape) + (1, 6) + """ + + def __init__( + self, + latent_dim: int, + num_mlp_layers: int, + mlp_hidden_dim: int, + target_property_dim: int, + node_aggregation: Literal["sum", "mean"] = "mean", + dropout: Optional[float] = None, + compute_stress: Optional[bool] = False, + ): + """init + """ + super().__init__() + self.target_property_dim = target_property_dim + self.normalizer = ScalarNormalizer() + + self.node_aggregation = node_aggregation + self.mlp = build_mlp( + input_size=latent_dim, + hidden_layer_sizes=[mlp_hidden_dim] * num_mlp_layers, + output_size=self.target_property_dim, + dropout=dropout, + ) + self.output_activation = ops.Identity() + self.compute_stress = compute_stress + + def construct(self, node_features, n_node): + """construct + """ + feat = node_features[_KEY] + + # aggregate to get a tensor of shape (num_graphs, latent_dim) + mlp_input = aggregate_nodes( + feat, + n_node, + reduction=self.node_aggregation, + ) + + pred = self.mlp(mlp_input) + if self.compute_stress: + # name the stress prediction differently + res = {"stress_pred": pred} + else: + res = {"graph_pred": pred} + return res + + def predict(self, node_features, n_node, atomic_numbers=None): + """Predict graph-level attributes. + + Args: + node_features: Node features tensor + n_node: Number of nodes + atomic_numbers: Optional atomic numbers for reference energy calculation + + Returns: + probs: Graph-level predictions of shape (n_graphs, target_property_dim). + If compute_stress is True, this will be the stress tensor. 
+ If compute_stress is False, this will be the graph-level property (e.g., energy). + """ + pred = self(node_features, n_node) + if self.compute_stress: + logits = pred["stress_pred"].squeeze(-1) + else: + assert atomic_numbers is not None, "atomic_numbers must be provided for graph prediction" + logits = pred["graph_pred"].squeeze(-1) + probs = self.output_activation(logits) + probs = self.normalizer.inverse(probs) + return probs + + +# pylint: disable=C0301 +class EnergyHead(GraphHead): + r""" + Graph-level energy prediction head. + Implements neural network head for predicting total energy or per-atom average energy of molecular graphs. + Supports node-level aggregation, reference energy offset, and flexible output modes. + + Args: + latent_dim (int): Input feature dimension for each node. + num_mlp_layers (int): Number of hidden layers in MLP. + mlp_hidden_dim (int): Hidden dimension size of MLP. + target_property_dim (int): Output dimension of energy property (typically 1). + predict_atom_avg (bool, optional): Whether to predict per-atom average energy instead of total energy. Default: ``True``. + reference_energy_name (str, optional): Reference energy name for offset, e.g., ``"vasp-shifted"``. Default: ``"mp-traj-d3"``. + train_reference (bool, optional): Whether to train reference energy as learnable parameter. Default: ``False``. + dropout (Optional[float], optional): Dropout rate for MLP. Default: ``None``. + node_aggregation (str, optional): Aggregation method for node predictions, e.g., ``"mean"`` or ``"sum"``. Default: ``None``. + + Inputs: + - **node_features** (dict) - Node feature dictionary, must contain key "feat" with shape :math:`(n_{nodes}, latent\_dim)`. + - **n_node** (Tensor) - Number of nodes in graph, shape :math:`(1,)`. + + Outputs: + - **output** (dict) - Dictionary containing key "graph_pred" with value of shape :math:`(1, target\_property\_dim)`. + + Raises: + ValueError: If required feature keys are missing in `node_features`. + ValueError: If `node_aggregation` is not a supported type. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> from mindspore import Tensor + >>> from mindchemistry.cell.orb.gns import EnergyHead + >>> energy_head = EnergyHead( + ... latent_dim=256, + ... num_mlp_layers=1, + ... mlp_hidden_dim=256, + ... target_property_dim=1, + ... node_aggregation="mean", + ... reference_energy_name="vasp-shifted", + ... train_reference=True, + ... predict_atom_avg=True, + ... ) + >>> n_atoms = 4 + >>> n_node = Tensor([n_atoms], mindspore.int32) + >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + >>> for i, num in enumerate(atomic_numbers.asnumpy()): + ... atomic_numbers_embedding_np[i, num - 1] = 1.0 + >>> node_features = { + ... "atomic_numbers": atomic_numbers, + ... "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + ... "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + ... "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + ... 
} + >>> output = energy_head(node_features, n_node) + >>> print(output['graph_pred'].shape) + (1, 1) + """ + + def __init__( + self, + latent_dim: int, + num_mlp_layers: int, + mlp_hidden_dim: int, + target_property_dim: int, + predict_atom_avg: bool = True, + reference_energy_name: str = "mp-traj-d3", + train_reference: bool = False, + dropout: Optional[float] = None, + node_aggregation: Optional[str] = "mean", + ): + """init + """ + ref = REFERENCE_ENERGIES[reference_energy_name] + + super().__init__( + latent_dim=latent_dim, + num_mlp_layers=num_mlp_layers, + mlp_hidden_dim=mlp_hidden_dim, + target_property_dim=target_property_dim, + node_aggregation=node_aggregation, + dropout=dropout, + ) + self.reference = LinearReferenceEnergy( + weight_init=ref.coefficients, trainable=train_reference + ) + self.atom_avg = predict_atom_avg + + def predict(self, node_features, n_node, atomic_numbers=None): + """Predict energy. + + Args: + node_features: Node features tensor + n_node: Number of nodes + atomic_numbers: Optional atomic numbers for reference energy calculation + + Returns: + graph_pred: Energy prediction + """ + if atomic_numbers is None: + raise ValueError("atomic_numbers is required for energy prediction") + + pred = self(node_features, n_node)["graph_pred"] + pred = self.normalizer.inverse(pred).squeeze(-1) + if self.atom_avg: + pred = pred * n_node + pred = pred + self.reference(atomic_numbers, n_node) + return pred + + +# pylint: disable=C0301 +class Orb(ms.nn.Cell): + r""" + Orb graph regressor. + Combines a pretrained base model (e.g., MoleculeGNS) with optional node, graph, and stress regression heads, supporting + fine-tuning or feature extraction workflows. + + Args: + model (MoleculeGNS): Pretrained or randomly initialized base model for message passing and feature extraction. + node_head (NodeHead, optional): Regression head for node-level property prediction. Default: ``None``. + graph_head (GraphHead, optional): Regression head for graph-level property prediction (e.g., energy). Default: ``None``. + stress_head (GraphHead, optional): Regression head for stress prediction. Default: ``None``. + model_requires_grad (bool, optional): Whether to fine-tune the base model (True) or freeze its parameters (False). Default: ``True``. + cutoff_layers (int, optional): If provided, only use the first ``cutoff_layers`` message passing layers of the base model. + Default: ``None``. + + Inputs: + - **edge_features** (dict) - Edge feature dictionary (e.g., `{"vectors": Tensor, "r": Tensor}`). + - **node_features** (dict) - Node feature dictionary (e.g., `{"atomic_numbers": Tensor, ...}`). + - **senders** (Tensor) - Sender node indices for each edge. Shape: :math:`(n_{edges},)`. + - **receivers** (Tensor) - Receiver node indices for each edge. Shape: :math:`(n_{edges},)`. + - **n_node** (Tensor) - Number of nodes for each graph in the batch. Shape: :math:`(n_{graphs},)`. + + Outputs: + - **output** (dict) - Dictionary containing: + - **edges** (dict) - Edge features after message passing, e.g., `{..., "feat": Tensor}`. + - **nodes** (dict) - Node features after message passing, e.g., `{..., "feat": Tensor}`. + - **graph_pred** (Tensor) - Graph-level predictions, e.g., energy. Shape: :math:`(n_{graphs}, target\_property\_dim)`. + - **node_pred** (Tensor) - Node-level predictions. Shape: :math:`(n_{nodes}, target\_property\_dim)`. + - **stress_pred** (Tensor) - Stress predictions (if stress_head is provided). Shape: :math:`(n_{graphs}, 6)`. 
+ + Raises: + ValueError: If neither node_head nor graph_head is provided. + ValueError: If cutoff_layers exceeds the number of message passing steps in the base model. + ValueError: If atomic_numbers is not provided when graph_head is required. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> from mindspore import Tensor + >>> from mindchemistry.cell.orb import Orb, MoleculeGNS, EnergyHead, NodeHead, GraphHead + >>> Orb = Orb( + ... model=MoleculeGNS( + ... num_node_in_features=256, + ... num_node_out_features=3, + ... num_edge_in_features=23, + ... latent_dim=256, + ... interactions="simple_attention", + ... interaction_params={ + ... "distance_cutoff": True, + ... "polynomial_order": 4, + ... "cutoff_rmax": 6, + ... "attention_gate": "sigmoid", + ... }, + ... num_message_passing_steps=15, + ... num_mlp_layers=2, + ... mlp_hidden_dim=512, + ... use_embedding=True, + ... node_feature_names=["feat"], + ... edge_feature_names=["feat"], + ... ), + ... graph_head=EnergyHead( + ... latent_dim=256, + ... num_mlp_layers=1, + ... mlp_hidden_dim=256, + ... target_property_dim=1, + ... node_aggregation="mean", + ... reference_energy_name="vasp-shifted", + ... train_reference=True, + ... predict_atom_avg=True, + ... ), + ... node_head=NodeHead( + ... latent_dim=256, + ... num_mlp_layers=1, + ... mlp_hidden_dim=256, + ... target_property_dim=3, + ... remove_mean=True, + ... ), + ... stress_head=GraphHead( + ... latent_dim=256, + ... num_mlp_layers=1, + ... mlp_hidden_dim=256, + ... target_property_dim=6, + ... compute_stress=True, + ... ), + ... ) + >>> n_atoms = 4 + >>> n_edges = 10 + >>> n_node = Tensor([n_atoms], mindspore.int32) + >>> atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + >>> atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + >>> for i, num in enumerate(atomic_numbers.asnumpy()): + ... atomic_numbers_embedding_np[i, num - 1] = 1.0 + >>> node_features = { + ... "atomic_numbers": atomic_numbers, + ... "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + ... "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)) + ... } + >>> edge_features = { + ... "vectors": Tensor(np.random.randn(n_edges, 3).astype(np.float32)), + ... "r": Tensor(np.abs(np.random.randn(n_edges).astype(np.float32) * 10)) + ... 
} + >>> senders = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) + >>> receivers = Tensor(np.random.randint(0, n_atoms, size=(n_edges,), dtype=np.int32)) + >>> output = Orb(edge_features, node_features, senders, receivers, n_node) + >>> print(output['graph_pred'].shape, output['node_pred'].shape, output['stress_pred'].shape) + (1, 1) (4, 3) (1, 6) + """ + + def __init__( + self, + model: MoleculeGNS, + node_head: Optional[NodeHead] = None, + graph_head: Optional[GraphHead] = None, + stress_head: Optional[GraphHead] = None, + model_requires_grad: bool = True, + cutoff_layers: Optional[int] = None, + ): + """init + """ + super().__init__() + + if (node_head is None) and (graph_head is None): + raise ValueError("Must provide at least one node/graph head.") + self.node_head = node_head + self.graph_head = graph_head + self.stress_head = stress_head + self.cutoff_layers = cutoff_layers + + self.model = model + + if self.cutoff_layers is not None: + if self.cutoff_layers > self.model.num_message_passing_steps: + raise ValueError( + f"cutoff_layers ({self.cutoff_layers}) must be less than or equal to" + f" the number of message passing steps ({self.model.num_message_passing_steps})" + ) + self.model.gnn_stacks = self.model.gnn_stacks[: self.cutoff_layers] + self.model.num_message_passing_steps = self.cutoff_layers + + self.model_requires_grad = model_requires_grad + + if not model_requires_grad: + for param in self.model.parameters(): + param.requires_grad = False + + + def predict(self, edge_features, node_features, senders, receivers, n_node, atomic_numbers): + """Predict node and/or graph level attributes. + + Args: + edge_features: A dictionary, e.g., `{"vectors": Tensor, "r": Tensor}`. + node_features: A dictionary, e.g., `{"atomic_numbers": Tensor, "positions": Tensor, + "atomic_numbers_embedding": Tensor}`. + senders: A tensor of shape (n_edges,) containing the sender node indices. + receivers: A tensor of shape (n_edges,) containing the receiver node indices. + n_node: A tensor of shape (1,) containing the number of nodes. + atomic_numbers: A tensor of atomic numbers for reference energy calculation. + + Returns: + ouput_dict: A dictionary containing the predictions: + - `graph_pred`: Graph-level predictions (e.g., energy) of shape (n_graphs, graph_property_dim). + - `stress_pred`: Stress predictions (if stress_head is provided) of shape (n_graphs, stress_dim). + - `node_pred`: Node-level predictions of shape (n_nodes, node_property_dim). 
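+
+        Note:
+            The graph, stress and node heads are invoked unconditionally here, so this
+            method assumes all three heads were provided when the regressor was built.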
+ """ + _, nodes = self.model(edge_features, node_features, senders, receivers) + + output = {} + output["graph_pred"] = self.graph_head.predict(nodes, n_node, atomic_numbers) + output["stress_pred"] = self.stress_head.predict(nodes, n_node) + output["node_pred"] = self.node_head.predict(nodes, n_node) + + return output + + def construct(self, edge_features, node_features, senders, receivers, n_node): + """construct + """ + edges, nodes = self.model(edge_features, node_features, senders, receivers) + + res = {"edges": edges, "nodes": nodes} + res.update(self.graph_head(nodes, n_node)) + res.update(self.stress_head(nodes, n_node)) + res.update(self.node_head(nodes, n_node)) + + return res diff --git a/MindChem/mindchemistry/cell/orb/utils.py b/MindChem/mindchemistry/cell/orb/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..26797d9df20cc84e83ce9a20ebf63653fc2be33b --- /dev/null +++ b/MindChem/mindchemistry/cell/orb/utils.py @@ -0,0 +1,737 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utils.""" + +from typing import NamedTuple, List, Optional, Type + +import numpy as np +import mindspore as ms +from mindspore import nn, ops, Tensor, mint, context + +MSINT = [ms.int64, ms.int32, ms.int16, ms.int8, ms.uint8] + + +def aggregate_nodes(tensor: Tensor, n_node: Tensor, reduction: str = "mean", deterministic: bool = False) -> Tensor: + """Aggregates over a tensor based on graph sizes.""" + count = len(n_node) + if deterministic: + ms.set_seed(1) + segments = ops.arange(count).repeat_interleave(n_node).astype(ms.int32) + if reduction == "sum": + return scatter_sum(tensor, segments, dim=0) + if reduction == "mean": + return scatter_mean(tensor, segments, dim=0) + if reduction == "max": + return scatter_max(tensor, segments, dim=0) + raise ValueError("Invalid reduction argument. 
Use sum, mean or max.") + + +def segment_sum(data: Tensor, segment_ids: Tensor, num_segments: int): + """Computes index based sum over segments of a tensor.""" + return scatter_sum(data, segment_ids, dim=0, dim_size=num_segments) + + +def segment_max(data: Tensor, segment_ids: Tensor, num_segments: int): + """Computes index based max over segments of a tensor.""" + assert segment_ids is not None, "segment_ids must not be None" + assert num_segments > 0, "num_segments must be greater than 0" + max_op = ops.ArgMaxWithValue(axis=0) + _, max_values = max_op(data) + return max_values + + +def segment_mean(data: Tensor, segment_ids: Tensor, num_segments: int): + """Computes index based mean over segments of a tensor.""" + sum_v = segment_sum(data, segment_ids, num_segments) + count = ops.scatter_add(ops.zeros( + (num_segments,), dtype=ms.int32), segment_ids, ops.ones_like(segment_ids)) + return sum_v / count.astype(sum_v.dtype) + + +def segment_softmax(data: Tensor, segment_ids: Tensor, num_segments: int, weights: Optional[Tensor] = None): + """Computes a softmax over segments of the tensor.""" + data_max = segment_max(data, segment_ids, num_segments) + data = data - data_max[segment_ids] + + unnormalised_probs = ops.exp(data) + if weights is not None: + unnormalised_probs = unnormalised_probs * weights + denominator = segment_sum(unnormalised_probs, segment_ids, num_segments) + + return safe_division(unnormalised_probs, denominator, segment_ids) + + +def safe_division(numerator: Tensor, denominator: Tensor, segment_ids: Tensor): + """Divides logits by denominator, setting 0 where the denominator is zero.""" + result = ops.where(denominator[segment_ids] == + 0, 0, numerator / denominator[segment_ids]) + return result + + +def _broadcast(src: Tensor, other: Tensor, dim: int): + """Broadcasts the source tensor to match the shape of the other tensor along the specified dimension.""" + if dim < 0: + dim = other.ndim + dim + if src.ndim == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.ndim, other.ndim): + src = src.unsqueeze(-1) + src = src.expand_as(other) + return src + + +def scatter_sum( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None, reduce: str = "sum" +) -> Tensor: + """Applies a sum reduction of the orb_models tensor along the specified dimension.""" + assert reduce == "sum" + index = _broadcast(index, src, dim) + if out is None: + size = list(src.shape) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = ops.zeros(size, dtype=src.dtype) + return mint.scatter_add(out, dim, index, src) + return mint.scatter_add(out, dim, index, src) + + +def scatter_std( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None, unbiased: bool = True +) -> Tensor: + """Computes the standard deviation of the orb_models tensor along the specified dimension.""" + if out is not None: + dim_size = out.shape[dim] + + if dim < 0: + dim = src.ndim + dim + + count_dim = dim + if index.ndim <= dim: + count_dim = index.ndim - 1 + + ones = ops.ones(index.shape, dtype=src.dtype) + count = scatter_sum(ones, index, count_dim, dim_size=dim_size) + + index = _broadcast(index, src, dim) + tmp = scatter_sum(src, index, dim, dim_size=dim_size) + count = _broadcast(count, tmp, dim).clip(1) + mean = tmp / count + + var = src - mean.gather(dim, index) + var = var * var + out = scatter_sum(var, index, dim, 
out=out, dim_size=dim_size) + + if unbiased: + count = count - 1 + count = count.clip(1) + out = out / (count + 1e-6) + out = ops.sqrt(out) + return out + + +def scatter_mean( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None +) -> Tensor: + """Computes the mean of the orb_models tensor along the specified dimension.""" + out = scatter_sum(src, index, dim, out=out, dim_size=dim_size) + dim_size = out.shape[dim] + + index_dim = dim + if index_dim < 0: + index_dim = index_dim + src.ndim + if index.ndim <= index_dim: + index_dim = index.ndim - 1 + + ones = ops.ones(index.shape, dtype=src.dtype) + count = scatter_sum(ones, index, index_dim, dim_size=dim_size) + count = count.clip(1) + count = _broadcast(count, out, dim) + out = out / count + return out + + +def scatter_max( + src: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, + dim_size: Optional[int] = None +) -> Tensor: + """Computes the maximum of the orb_models tensor for each group defined by index along the specified dimension.""" + if out is not None: + raise NotImplementedError( + "The 'out' argument is not supported for scatter_max") + + if src.dtype in MSINT: + init_value = np.iinfo(src.dtype).min + else: + init_value = np.finfo(src.dtype).min + + if dim < 0: + dim = src.ndim + dim + + if dim_size is None: + dim_size = int(index.max()) + 1 + + result = ops.ones( + (dim_size, *src.shape[:dim], *src.shape[dim + 1:]), dtype=src.dtype) + result = init_value * result + broadcasted_index = _broadcast(index, src, dim) + + scatter_result = ops.ZerosLike()(result) + index = ops.expand_dims(broadcasted_index, dim) + scatter_result = scatter_result.scatter_update(index, src) + result = ops.Maximum()(result, scatter_result) + return result + + +class SSP(nn.Cell): + """Shifted Softplus activation function. + + This activation is twice differentiable so can be used when regressing + gradients for conservative force fields. + """ + + def __init__(self, beta: int = 1, threshold: int = 20): + super().__init__() + self.beta = beta + self.threshold = threshold + + def construct(self, input_x: Tensor) -> Tensor: + sp0 = ops.softplus(ops.zeros(1), self.beta, self.threshold) + return ops.softplus(input_x, self.beta, self.threshold) - sp0 + + +def build_mlp( + input_size: int, + hidden_layer_sizes: List[int], + output_size: Optional[int] = None, + output_activation: Type[nn.Cell] = nn.Identity, + activation: Type[nn.Cell] = SSP, + dropout: Optional[float] = None, +) -> nn.Cell: + """Build a MultiLayer Perceptron. + + Args: + input_size: Size of input layer. + hidden_layer_sizes: An array of input size for each hidden layer. + output_size: Size of the output layer. + output_activation: Activation function for the output layer. + activation: Activation function for the hidden layers. + dropout: Dropout rate for hidden layers. + checkpoint: Whether to use checkpointing. + + Returns: + mlp: An MLP sequential container. 
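+
+    Examples:
+        A minimal illustrative sketch (layer sizes are arbitrary and only chosen
+        to show the resulting output shape):
+
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindchemistry.cell.orb.utils import build_mlp
+        >>> mlp = build_mlp(16, [32, 32], output_size=8)
+        >>> out = mlp(Tensor(np.random.randn(4, 16).astype(np.float32)))
+        >>> print(out.shape)
+        (4, 8)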
+ """ + # Size of each layer + layer_sizes = [input_size] + hidden_layer_sizes + if output_size: + layer_sizes.append(output_size) + + # Number of layers + nlayers = len(layer_sizes) - 1 + + # Create a list of activation functions and + # set the last element to output activation function + act = [activation for _ in range(nlayers)] + act[-1] = output_activation + + # Create a list to hold layers + layers = [] + for i in range(nlayers): + if dropout is not None: + layers.append(nn.Dropout(keep_prob=1 - dropout)) + layers.append(nn.Dense(layer_sizes[i], layer_sizes[i + 1])) + layers.append(act[i]()) + + # Create a sequential container + mlp = nn.SequentialCell(layers) + return mlp + + +class CheckpointedSequential(nn.Cell): + """Sequential container with checkpointing.""" + + def __init__(self, *args, n_layers: int = 1): + super().__init__() + self.n_layers = n_layers + self.layers = nn.CellList(list(args)) + + def construct(self, input_x: Tensor) -> Tensor: + """Forward pass with checkpointing enabled in training mode.""" + if context.get_context("mode") == context.GRAPH_MODE: + # In graph mode, checkpointing is handled by MindSpore's graph optimization + for layer in self.layers: + input_x = layer(input_x) + else: + # In PyNative mode, we can manually checkpoint each layer + for i in range(self.n_layers): + input_x = self.layers[i](input_x) + return input_x + + +class ReferenceEnergies(NamedTuple): + """ + Reference energies for an atomic system. + + Our vasp reference energies are computed by running vasp + optimisations on a single atom of each atom-type. + + Other reference energies are fitted using least-squares. + + Doing so with mp-traj-d3 gives the following: + + ---------- LSTQ ---------- + Reference MAE: 13.35608855004781 + (energy - ref) mean: 1.3931169304958624 + (energy - ref) std: 22.45615276341948 + (energy - ref)/natoms mean: 0.16737045963056316 + (energy - ref)/natoms std: 0.8189314920219992 + CO2: Predicted vs DFT: -23.154158610392408 vs -22.97 + H2O: Predicted vs DFT: -11.020918107591324 vs - 14.23 + ---------- VASP ---------- + Reference MAE: 152.4722089438871 + (energy - ref) mean: -152.47090833346033 + (energy - ref) std: 153.89049784836962 + (energy - ref)/natoms mean: -4.734136414817941 + (energy - ref)/natoms std: 1.3603868419157275 + CO2: Predicted vs DFT: -4.35888857 vs -22.97 + H2O: Predicted vs DFT: -2.66521147 vs - 14.23 + ---------- Shifted VASP ---------- + Reference MAE: 28.95948216608197 + (energy - ref) mean: 0.7083632520428979 + (energy - ref) std: 48.61861182844561 + (energy - ref)/natoms mean: 0.17320099403091083 + (energy - ref)/natoms std: 1.3603868419157275 + CO2: Predicted vs DFT: -19.080900796546562 vs -22.97 + H2O: Predicted vs DFT: -12.479886287697706 vs - 14.23 + + Args: + coefficients: Coefficients for each atom in the periodic table. + Must be of length 118 with first entry equal to 0. + residual_mean: Mean of (pred - target) + residual_std: Standard deviation of (pred - target) + residual_mean_per_atom: Mean of (pred - target)/n_atoms. + residual_std_per_atom: Standard deviation of (pred - target)/n_atoms. + """ + + coefficients: np.ndarray + residual_mean: float + residual_std: float + residual_mean_per_atom: float + residual_std_per_atom: float + + +# We have only computed these for the first +# 88 elements, and padded the remainder with 0. 
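+# How these coefficients are consumed (see EnergyHead/LinearReferenceEnergy in orb.py):
+# atom types are one-hot encoded over 118 slots, summed per graph, and passed through a
+# bias-free Dense(118, 1) layer initialised with these values, so a system's reference
+# energy is the sum of the per-element coefficients of its atoms. The leading 0.0 entry
+# is a padding slot so that coefficients are indexed directly by atomic number.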
+vasp_reference_energies = ReferenceEnergies( + coefficients=np.array( + [ + 0.0, # padding + -1.11725225e00, + 7.69290000e-04, + -3.22788480e-01, + -4.47021900e-02, + -2.90627280e-01, + -1.26297013e00, + -3.12415058e00, + -1.54795922e00, + -4.39757050e-01, + -1.25673900e-02, + -2.63927430e-01, + -1.92670300e-02, + -2.11267040e-01, + -8.24799500e-01, + -1.88734631e00, + -8.91048980e-01, + -2.58371430e-01, + -2.50008000e-02, + -2.71936150e-01, + -7.11147600e-02, + -2.06076796e00, + -2.42753196e00, + -3.57144559e00, + -5.45540047e00, + -5.15708214e00, + -3.31393675e00, + -1.84639284e00, + -6.32812480e-01, + -2.38017450e-01, + -1.41047600e-02, + -2.06349980e-01, + -7.77477960e-01, + -1.70160351e00, + -7.84231510e-01, + -2.27541260e-01, + -2.26104900e-02, + -2.79760570e-01, + -9.92851900e-02, + -2.18560872e00, + -2.26603086e00, + -3.14842282e00, + -4.61199158e00, + -3.34329762e00, + -2.48233722e00, + -1.27872811e00, + -1.47784242e00, + -2.04068960e-01, + -1.89639300e-02, + -1.88520140e-01, + -6.76700640e-01, + -1.42966694e00, + -6.57608340e-01, + -1.89308030e-01, + -1.20491300e-02, + -3.07991050e-01, + -1.58601400e-01, + -4.89728600e-01, + -1.35031403e00, + -3.31509450e-01, + -3.23660410e-01, + -3.15316610e-01, + -3.11184530e-01, + -8.44684689e00, + -1.04408371e01, + -2.30922790e-01, + -2.26295040e-01, + -2.92747580e-01, + -2.92191740e-01, + -2.91465170e-01, + -3.80611000e-02, + -2.87691040e-01, + -3.51528971e00, + -3.51343142e00, + -4.64232388e00, + -2.88816624e00, + -1.46089612e00, + -5.36042350e-01, + -1.87182020e-01, + -1.33549100e-02, + -1.68142250e-01, + -6.25378750e-01, + -1.32291753e00, + -3.26246040e-01, + -1.10239294e00, + -2.30839543e00, + -4.61968511e00, + -7.30638139e00, + -1.04613411e01, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + ] + ), + residual_mean=-152.47090833346033, + residual_std=153.89049784836962, + residual_mean_per_atom=-4.734136414817941, + residual_std_per_atom=1.3603868419157275, +) + +vasp_shifted_reference_energies = ReferenceEnergies( + coefficients=np.array( + [ + 0.0, # padding + -6.0245896588488534, + -4.9065681188488535, + -5.230125888848853, + -4.952039598848853, + -5.197964688848853, + -6.170307538848854, + -8.031487988848854, + -6.455296628848854, + -5.347094458848853, + -4.919904798848854, + -5.171264838848853, + -4.9266044388488535, + -5.118604448848854, + -5.732136908848854, + -6.794683718848853, + -5.798386388848853, + -5.165708838848854, + -4.932338208848853, + -5.179273558848854, + -4.978452168848854, + -6.968105368848853, + -7.334869368848853, + -8.478782998848853, + -10.362737878848854, + -10.064419548848853, + -8.221274158848853, + -6.7537302488488535, + -5.540149888848854, + -5.145354858848854, + -4.921442168848854, + -5.113687388848853, + -5.684815368848853, + -6.6089409188488535, + -5.691568918848853, + -5.134878668848853, + -4.929947898848853, + -5.187097978848853, + -5.006622598848853, + -7.092946128848853, + -7.173368268848853, + -8.055760228848854, + -9.519328988848853, + -8.250635028848853, + -7.389674628848853, + -6.186065518848854, + -6.3851798288488535, + -5.111406368848853, + 
-4.9263013388488535, + -5.095857548848853, + -5.5840380488488535, + -6.337004348848853, + -5.564945748848854, + -5.096645438848854, + -4.919386538848854, + -5.2153284588488535, + -5.065938808848854, + -5.397066008848854, + -6.257651438848853, + -5.238846858848854, + -5.230997818848854, + -5.2226540188488535, + -5.218521938848854, + -13.354184298848853, + -15.348174508848853, + -5.138260198848854, + -5.133632448848854, + -5.200084988848854, + -5.199529148848853, + -5.198802578848854, + -4.945398508848854, + -5.195028448848854, + -8.422627118848853, + -8.420768828848853, + -9.549661288848853, + -7.795503648848854, + -6.368233528848854, + -5.443379758848853, + -5.094519428848853, + -4.920692318848854, + -5.075479658848853, + -5.532716158848854, + -6.230254938848853, + -5.2335834488488535, + -6.009730348848853, + -7.2157328388488535, + -9.527022518848852, + -12.213718798848854, + -15.368678508848854, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + -4.9073374088488535, + ] + ), + residual_mean=0.7083632520428979, + residual_std=48.61861182844561, + residual_mean_per_atom=0.17320099403091083, + residual_std_per_atom=1.3603868419157275, +) + +mp_traj_d3_reference_energies = ReferenceEnergies( + coefficients=np.array( + [ + 0.0, # padding + -3.6818229500085327, + -1.3199148098871394, + -3.688797198716366, + -4.938608191337134, + -7.901604711660046, + -8.475968295226822, + -7.42601366967988, + -7.339095157582792, + -4.9239197309790725, + -0.061236726924086424, + -3.0526401941340806, + -3.0836199809602105, + -5.055909838526647, + -7.875649504560413, + -7.175538036602013, + -4.814514763424572, + -2.9198, + -0.13127266880110078, + -2.8792125576832865, + -5.635016298424046, + -8.164720105254204, + -10.712143655281858, + -9.00292017736733, + -9.619640942931085, + -8.610981088341331, + -7.3506162257219385, + -5.943664565392655, + -5.592846831852426, + -3.6868017794232077, + -1.579885044321145, + -3.744040760877656, + -4.945137332817033, + -4.2021571924020655, + -4.045303645442562, + -2.652667661940346, + 6.497305115069106, + -2.806819346028444, + -5.164089337915934, + -10.493037547114369, + -12.256967896681578, + -12.642602087796805, + -9.20874164629371, + -9.292405362859506, + -8.304141715175632, + -7.49355696426791, + -5.44150554776011, + -2.5621691409635474, + -0.9687174918829102, + -3.055905969721681, + -4.02975498585447, + -3.847125028451477, + -3.1016305514702203, + -1.8001556831915142, + 9.742275211909387, + -3.045410331644577, + -5.204088972093178, + -9.267561428901118, + -9.031458669303145, + -8.345252241333469, + -8.584977779192705, + -7.955970517402418, + -8.519743221802353, + -13.927799873369949, + -19.12242499580686, + -8.156787154342183, + -8.505944162624234, + -8.015433843487497, + -7.129355408977684, + -8.166165621829014, + -3.9995952334750644, + -7.884852034766514, + -13.281575162667238, + -14.598283494757041, + -9.729591400065184, + -11.798570715867179, + -9.878207068760076, + -7.891075131963705, + 
-5.964524120587406, + -2.9665634245721275, + -0.10530075207060852, + -2.649420791761001, + -4.00193074336809, + -3.7403644338639785, + -1.5543122344752192e-15, + -8.881784197001252e-16, + -8.881784197001252e-16, + 0.0, + 0.0, + -5.480602125607218, + -11.9439263006771, + -12.974770001312883, + -14.376719109855834, + -15.49262474740642, + -16.02533150334938, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ] + ), + residual_mean=1.3931169304958624, + residual_std=22.45615276341948, + residual_mean_per_atom=0.16737045963056316, + residual_std_per_atom=0.8189314920219992, +) + +REFERENCE_ENERGIES = { + "vasp": vasp_reference_energies, + "vasp-shifted": vasp_shifted_reference_energies, + "mp-traj-d3": mp_traj_d3_reference_energies, +} diff --git a/MindChemistry/applications/crystalflow/README.md b/MindChemistry/applications/crystalflow/README.md new file mode 100644 index 0000000000000000000000000000000000000000..afdea75c1fca29bd71d682992d6e88e8296d270f --- /dev/null +++ b/MindChemistry/applications/crystalflow/README.md @@ -0,0 +1,131 @@ + +# 模型名称 + +> CrystalFlow + +## 介绍 + +> 理论晶体结构预测是通过计算的手段寻找物质在给定的外界条件下最稳定结构的重要手段。传统结构预测方法依赖在势能面上广泛的随机采样来寻找最稳定结构,然而,这种方法需要对大量随机生成的结构进行局域优化,而局域优化通常需要消耗巨大的第一性原理计算成本,尤其在模拟多元素复杂体系时,这种计算开销会显著增加,从而带来巨大的挑战。近年来,基于深度学习生成模型的晶体结构生成方法因其能够在势能面上更高效地采样合理结构而逐渐受到关注。这种方法通过从已有的稳定或局域稳定结构数据中学习,进而生成合理的晶体结构,与随机采样相比,不仅能够减少局域优化的计算成本,还能通过较少的采样找到体系的最稳定结构。采用神经常微分方程和连续变化建模概率密度的归一化流流模型,相比采用扩散模型方法的生成模型具有更加简洁、灵活、高效的优点。本方法基于流模型架构,发展了以CrystalFlow命名的晶体结构生成模型,在MP20等基准数据集上达到优秀的水平。 + +## 环境要求 + +> 1. 安装`mindspore(2.5.0)` +> 2. 安装依赖包:`pip install -r requirement.txt` + +## 快速入门 + +> 1. 将Mindchemistry/mindchemistry文件包下载到当前目录 +> 2. 在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/dataset/)下载相应的数据集 +> 3. 安装依赖包:`pip install -r requirement.txt` +> 4. 训练命令: `python train.py` +> 5. 预测命令: `python evaluate.py` +> 6. 评估命令: `python compute_metric.py` +> 7. 评估结果放在`config.yaml`中指定的`metric_dir`路径的json文件中 + +### 代码目录结构 + +```text +代码主要模块在models文件夹下,其中cspnet.py是网络层,flow.py是流模型模块.data文件夹下是数据集处理模块。 + +applications + └── crystalflow # 模型名 + ├── readme.md # readme文件 + ├── config.yaml # 配置文件 + ├── train.py # 训练启动脚本 + ├── evaluate.py # 推理启动脚本 + ├── compute_metric.py # 评估启动脚本 + ├── requirement.txt # 环境依赖 + ├── data # 数据处理模块 + | ├── data_utils.py # 工具模块 + | ├── dataset.py # 构造数据集 + | └── crysloader.py # 构造数据加载器 + └── models + ├── conditioning.py # 条件生成工具模块 + ├── cspnet.py # 基于图神经网络的去噪器模块 + ├── cspnet_condition.py # 条件生成的网络层 + ├── diff_utils.py # 工具模块 + ├── flow.py # 流模型模块 + ├── flow_condition.py # 条件生成的流模型 + ├── infer_utils.py # 推理工具模块 + ├── lattice.py # 晶格矩阵处理工具 + └── train_utils.py # 训练工具模块 + +``` + +## 下载数据集 + +在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/dataset/)中下载相应的数据集文件夹和dataset_prop.txt数据集属性文件放置于当前路径的dataset文件夹下(如果没有需要自己手动创建),文件路径参考: + +```txt +crystalflow + ... + └─dataset + perov_5 钙钛矿数据集 + carbon_24 碳晶体数据集 + mp_20 晶胞内原子数最多为20的MP数据集 + mpts_52 晶胞内原子数最多为52的MP数据集 + dataset_prop.txt 数据集属性文件 + ... +``` + +## 训练过程 + +### 训练 + +将Mindchemistry/mindchemistry文件包下载到当前目录; + +更改config文件,设置训练参数: +> 1. 设置训练的dataset,见dataset字段 +> 2. 设置去噪器模型的配置,见model字段 +> 3. 设置训练保存的权重文件,更改train.ckpt_dir文件夹名称和checkpoint.last_path权重文件名称 +> 4. 
其它训练设置见train字段 + +```bash +pip install -r requirement.txt +python train.py +``` + +### 推理 + +更改config文件中的test字段来更改推理参数,特别是test.num_eval,它**决定了对于每个组分生成多少个样本**,对于后续的评估阶段很重要。 + +```bash +python evaluate.py +``` + +推理得到的晶体将保存在test.eval_save_path指定的文件中 + +文件中存储的内容为python字典,格式为: + +```python +{ + 'pred': [ + [晶体A sample 1, 晶体A sample 2, 晶体A sample 3, ... 晶体A sample num_eval], + [晶体B sample 1, 晶体B sample 2, 晶体B sample 3, ... 晶体B sample num_eval] + ... + ] + 'gt': [ + 晶体A ground truth, + 晶体B ground truth, + ... + ] +} +``` + +### 评估 + +将推理得到的晶体文件的path写入config文件的test.eval_save_path中; + +确保num_evals与进行推理时设置的对于每个组分生成样本的数量一致或更小。比如进行推理时,num_evals设置为1,那么评估时,num_evals只能设置为1;推理时,num_evals设置为20,那么评估时,num_evals可以设置为1-20的数字来进行评估。 + +更改config文件中的test.metric_dir字段来设置评估结果的保存路径 + +```bash +python compute_metric.py +``` + +得到的评估结果文件示例: + +```json +{"match_rate": 0.6107671899181959, "rms_dist": 0.07492558322002925} +``` diff --git a/MindChemistry/applications/crystalflow/test_crystalflow.py b/MindChemistry/applications/crystalflow/test_crystalflow.py new file mode 100644 index 0000000000000000000000000000000000000000..0f9531b135384ba2e7486eb0caaf5fa412fd7ac5 --- /dev/null +++ b/MindChemistry/applications/crystalflow/test_crystalflow.py @@ -0,0 +1,197 @@ +"""model test""" +import math +import os +import urllib.request + +import mindspore as ms +import mindspore.numpy as mnp +from mindspore import nn, ops, Tensor, mint, load_checkpoint, load_param_into_net +from mindchemistry.graph.loss import L2LossMask +import numpy as np + + +from models.cspnet import CSPNet +from models.flow import CSPFlow +from data.dataset import fullconnect_dataset +from data.crysloader import Crysloader as DataLoader + + +ms.set_seed(1234) +np.random.seed(1234) + +class SinusoidalTimeEmbeddings(nn.Cell): + """time embedding""" + def __init__(self, dim): + super(SinusoidalTimeEmbeddings, self).__init__() + self.dim = dim + + def construct(self, time): + half_dim = self.dim // 2 + embeddings = math.log(10000) / (half_dim - 1) + embeddings = ops.Exp()(mnp.arange(half_dim) * -embeddings) + embeddings = time[:, None] * embeddings[None, :] + embeddings = ops.Concat(axis=-1)( + (ops.Sin()(embeddings), ops.Cos()(embeddings))) + return embeddings + +def download_file(url, filename): + urllib.request.urlretrieve(url, filename) + print(f"File downloaded successfully: {filename}") + +def test_cspnet(): + """test cspnet.py""" + ms.set_seed(1234) + time_embedding = SinusoidalTimeEmbeddings(256) + cspnet = CSPNet(num_layers=6, hidden_dim=512, num_freqs=128) + atom_types = Tensor([61, 12, 52, 52, 46, 46], dtype=ms.int32) + frac_coords = Tensor( + [[5.00000000e-01, 5.00000000e-01, 5.00000000e-01], + [0.00000000e+00, 0.00000000e+00, 0.00000000e+00], + [6.66666687e-01, 3.33333343e-01, 7.50000000e-01], + [3.33333343e-01, 6.66666687e-01, 2.50000000e-01], + [0.00000000e+00, 0.00000000e+00, 0.00000000e+00], + [0.00000000e+00, 0.00000000e+00, 5.00000000e-01]], dtype=ms.float32) + lengths = Tensor( + [[3.86215806e+00, 3.86215806e+00, 3.86215806e+00], + [4.21191406e+00, 4.21191454e+00, 5.75016499e+00]], dtype=ms.float32) + lattice_polar = Tensor( + [[0.00000000e+00, 0.00000000e+00, 3.97458431e-1, 5.55111512e-16, 0.00000000e+00, 1.35122609e+00], + [-2.74653047e-01, 1.58676151e-16, 6.82046943e-17, -5.38849108e-08, -1.27743945e-01, 1.49374068e+00]], + dtype=ms.float32) + edge_index = Tensor( + [[0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5], + [0, 1, 0, 1, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5]], dtype=ms.int32) + node2graph = Tensor([0, 0, 1, 1, 
1, 1], dtype=ms.int32) + node_mask = Tensor([1, 1, 1, 1, 1, 1], dtype=ms.int32) + edge_mask = Tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=ms.int32,) + tar_lat_polar = Tensor( + [[-0.5366, 0.5920, 0.2546, 0.4013, -0.0032, 0.6611], + [-0.5696, 0.6870, 0.2512, 0.4647, 0.0228, 0.5979]] + ) + tar_coord = Tensor([[-0.7573, 0.2272, -0.4823], + [-0.7647, 0.2261, -0.4763], + [-0.7841, 0.2948, -0.3861], + [-0.7872, 0.2915, -0.3810], + [-0.7789, 0.2759, -0.4070], + [-0.7785, 0.2757, -0.4070]]) + + np.random.seed(1234) + times = np.random.rand(lengths.shape[0]) + times = ms.tensor(times, dtype=ms.float32) + t = time_embedding(times) + lattices_out, coords_out = cspnet(t, atom_types, frac_coords, lattice_polar, node2graph,\ + edge_index, node_mask, edge_mask) + assert mint.isclose(lattices_out, tar_lat_polar, rtol=1e-4, atol=1e-4).all(), \ + f"For `cspnet`, the output should be {tar_lat_polar}, but got {lattices_out}." + assert mint.isclose(coords_out, tar_coord, rtol=1e-4, atol=1e-4).all(), \ + f"For `cspnet`, the output should be {tar_coord}, but got {coords_out}." + +def test_flow(): + """test flow.py""" + ms.set_seed(1234) + cspnet = CSPNet(num_layers=6, hidden_dim=512, num_freqs=128) + cspflow = CSPFlow(cspnet) + atom_types = Tensor([61, 12, 52, 52, 46, 46], dtype=ms.int32) + frac_coords = Tensor( + [[5.00000000e-01, 5.00000000e-01, 5.00000000e-01], + [0.00000000e+00, 0.00000000e+00, 0.00000000e+00], + [6.66666687e-01, 3.33333343e-01, 7.50000000e-01], + [3.33333343e-01, 6.66666687e-01, 2.50000000e-01], + [0.00000000e+00, 0.00000000e+00, 0.00000000e+00], + [0.00000000e+00, 0.00000000e+00, 5.00000000e-01]], dtype=ms.float32) + lengths = Tensor( + [[3.86215806e+00, 3.86215806e+00, 3.86215806e+00], + [4.21191406e+00, 4.21191454e+00, 5.75016499e+00]], dtype=ms.float32) + angles = Tensor( + [[9.00000000e+01, 9.00000000e+01, 9.00000000e+01], + [9.00000000e+01, 9.00000000e+01, 1.20000000e+02]], dtype=ms.float32) + lattice_polar = Tensor( + [[0.00000000e+00, 0.00000000e+00, 3.97458431e-1, 5.55111512e-16, 0.00000000e+00, 1.35122609e+00], + [-2.74653047e-01, 1.58676151e-16, 6.82046943e-17, -5.38849108e-08, -1.27743945e-01, 1.49374068e+00]], \ + dtype=ms.float32) + num_atoms = Tensor([2, 4], dtype=ms.int32) + edge_index = Tensor( + [[0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5], + [0, 1, 0, 1, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5]], dtype=ms.int32) + node2graph = Tensor([0, 0, 1, 1, 1, 1], dtype=ms.int32) + node_mask = Tensor([1, 1, 1, 1, 1, 1], dtype=ms.int32) + edge_mask = Tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=ms.int32,) + batch_size_mask = Tensor([1, 1], dtype=ms.int32) + + pred_l, tar_l, pred_f, tar_f = cspflow(atom_types, atom_types, lengths, + angles, lattice_polar, num_atoms, frac_coords, node2graph, + edge_index, node_mask, edge_mask, batch_size_mask) + out_pred_l = Tensor([[-0.54417396, 0.6183988, 0.25345746, 0.41497535, -0.00219233, 0.6622897], + [-0.5647707, 0.68243337, 0.25912297, 0.45234668, 0.01847154, 0.6095263]]) + out_tar_l = Tensor([[-0.02254689, 0.04679973, 0.3856261, -0.08269336, -0.08592724, 0.45530552], + [-0.53301036, -0.21067567, -0.05119152, 0.04148455, -0.0907657, 0.3682214]]) + out_pred_f = Tensor([[-0.7662705, 0.24618103, -0.4741043], + [-0.77218896, 0.2367004, -0.4617761], + [-0.7825796, 0.28697833, -0.38660413], + [-0.7888657, 0.2943602, -0.39356205], + [-0.7792929, 0.26879176, -0.42642403], + [-0.77509487, 0.2633396, -0.41789246]]) + out_tar_f = Tensor([[0.20181239, -0.07186192, 
-0.40746307], + [-0.4028666, 0.18524933, 0.14020872], + [-0.31370556, 0.08878523, -0.18229586], + [-0.15636778, -0.44619012, 0.13355094], + [-0.03352255, -0.15093482, -0.13720155], + [-0.2018686, 0.07621789, -0.4946221]]) + assert mint.isclose(pred_l, out_pred_l, rtol=1e-4, atol=1e-4).all(), \ + f"For `cspnet`, the output should be {pred_l}, but got {out_pred_l}." + assert mint.isclose(pred_f, out_pred_f, rtol=1e-4, atol=1e-4).all(), \ + f"For `cspnet`, the output should be {pred_f}, but got {out_pred_f}." + assert mint.isclose(tar_l, out_tar_l, rtol=1e-4, atol=1e-4).all(), \ + f"For `cspnet`, the output should be {tar_l}, but got {out_tar_l}." + assert mint.isclose(tar_f, out_tar_f, rtol=1e-4, atol=1e-4).all(), \ + f"For `cspnet`, the output should be {tar_f}, but got {out_tar_f}." + +def test_loss(): + """test loss""" + ms.set_context(device_target="CPU") + ckpt_dir = "./ckpt/mp_20" + if not os.path.exists(ckpt_dir): + os.makedirs(ckpt_dir) + + ms.set_seed(1234) + batch_size_max = 256 + + cspnet = CSPNet(num_layers=6, hidden_dim=512, num_freqs=256) + cspflow = CSPFlow(cspnet) + download_file('https://download-mindspore.osinfra.cn/mindscience/mindchemistry/crystalflow/ms_flow.ckpt', 'ms_flow.ckpt') + mindspore_ckpt = load_checkpoint("ms_flow.ckpt") + load_param_into_net(cspflow, mindspore_ckpt) + + loss_func_mse = L2LossMask(reduction='mean') + def forward(atom_types_step, frac_coords_step, _, lengths_step, angles_step, lattice_polar_step, \ + num_atoms_step, edge_index_step, batch_node2graph, \ + node_mask_step, edge_mask_step, batch_mask, node_num_valid, batch_size_valid): + pred_l, tar_l, pred_x, tar_x = cspflow(batch_size_valid, atom_types_step, lengths_step, + angles_step, lattice_polar_step, num_atoms_step, + frac_coords_step, batch_node2graph, edge_index_step, + node_mask_step, edge_mask_step, batch_mask) + mseloss_l = loss_func_mse(pred_l, tar_l, mask=batch_mask, num=batch_size_valid) + mseloss_x = loss_func_mse(pred_x, tar_x, mask=node_mask_step, num=node_num_valid) + mseloss = mseloss_l + 10 * mseloss_x + + return mseloss, mseloss_l, mseloss_x + + train_datatset = fullconnect_dataset(name="mp_20", path='./dataset/mp_20/train.csv', + save_path='./dataset/mp_20/train.npy') + train_loader = DataLoader(batch_size_max, *train_datatset, shuffle_dataset=False) + + for atom_types_batch, frac_coords_batch, property_batch, lengths_batch, \ + angles_batch, lattice_polar_batch, num_atoms_batch,\ + edge_index_batch, batch_node2graph_, node_mask_batch, edge_mask_batch, batch_mask_batch,\ + node_num_valid_, batch_size_valid_ in train_loader: + + result = forward(atom_types_batch, frac_coords_batch, property_batch, + lengths_batch, angles_batch, lattice_polar_batch, + num_atoms_batch, edge_index_batch, batch_node2graph_, + node_mask_batch, edge_mask_batch, batch_mask_batch, node_num_valid_, + batch_size_valid_) + + _, mseloss_l, mseloss_x = result + break + assert mseloss_l <= 0.7, "The denoising of lattice accuracy is not successful." + assert mseloss_x <= 0.7, "The denoising of fractional coordinates accuracy is not successful." 
diff --git a/MindChemistry/mindchemistry/cell/__init__.py b/MindChemistry/mindchemistry/cell/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5660308bc1d2931107aaf328c36f1e927f1d8fc2 --- /dev/null +++ b/MindChemistry/mindchemistry/cell/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""initialization for cells""" +from .allegro import * +from .nequip import Nequip +from .cspnet import CSPNet +from .basic_block import AutoEncoder, FCNet, MLPNet +from .deephe3nn import * +from .matformer import * +from .dimenet import * +from .gemnet import * +from .orb import * + +__all__ = [ + "Nequip", 'AutoEncoder', 'FCNet', 'MLPNet', 'CSPNet' +] +__all__.extend(deephe3nn.__all__) +__all__.extend(matformer.__all__) +__all__.extend(allegro.__all__) +__all__.extend(dimenet.__all__) +__all__.extend(gemnet.__all__) +__all__.extend(orb.__all__) diff --git a/MindEarth/applications/climate-prediction/ensoforecast/src/utils.py b/MindEarth/applications/climate-prediction/ensoforecast/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c669d0738185323e5c50d623e5959045a4befe --- /dev/null +++ b/MindEarth/applications/climate-prediction/ensoforecast/src/utils.py @@ -0,0 +1,123 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""ensoforcast utils""" +import os + +import numpy as np +import matplotlib.pyplot as plt + +import mindspore.dataset as ds +from mindspore.train.serialization import load_checkpoint +from mindearth.utils import create_logger + +from .ctefnet import CTEFNet +from .dataset import CMIP5Data, ReanalysisData + + +def get_logger(config): + """Get logger for saving log""" + summary_params = config.get('summary') + logger = create_logger(path=os.path.join(summary_params.get("summary_dir"), "results.log")) + for key in config: + logger.info(config[key]) + return logger + + +def init_model(config, run_mode='train'): + """Init model""" + data_params = config.get("data") + model_params = config.get("model") + train_params = config.get("train") + train_params['load_ckpt'] = run_mode == "test" + model = CTEFNet( + cov_hidden_channels=model_params.get('cov_hidden_channels'), + cov_out_channels=model_params.get('cov_out_channels'), + heads=model_params.get('heads'), + num_layer=model_params.get('num_layer'), + feedforward_dims=model_params.get('feedforward_dims'), + dropout=model_params.get('dropout'), + obs_time=data_params.get('obs_time'), + pred_time=data_params.get('pred_time') + ) + return model + + +def init_dataloader(config): + """Init dataloader""" + data_params = config.get('data') + train_type = data_params.get('train_dataset') + valid_type = data_params.get('valid_dataset') + if train_type not in ['CMIP5', 'Reanalysis']: + raise ValueError(f"Unexpected Data Type {train_type}.") + if valid_type not in ['CMIP5', 'Reanalysis']: + raise ValueError(f"Unexpected Data Type {valid_type}.") + if train_type == 'CMIP5': + train_dataset = CMIP5Data(data_params.get('root_dir'), data_params.get('train_period'), + data_params.get('obs_time'), data_params.get('pred_time')) + else: + train_dataset = ReanalysisData(data_params.get('root_dir'), data_params.get('train_period'), + data_params.get('obs_time'), data_params.get('pred_time')) + if valid_type == 'CMIP5': + valid_dataset = CMIP5Data(data_params.get('root_dir'), data_params.get('valid_period'), + data_params.get('obs_time'), data_params.get('pred_time')) + else: + valid_dataset = ReanalysisData(data_params.get('root_dir'), data_params.get('valid_period'), + data_params.get('obs_time'), data_params.get('pred_time')) + train_dataloader = ds.GeneratorDataset(train_dataset, ["data", "index"], shuffle=True).batch( + data_params.get('train_batch_size'), False) + valid_dataloader = ds.GeneratorDataset(valid_dataset, ["data", "index"], shuffle=False).batch( + data_params.get('valid_batch_size'), False) + return train_dataloader, valid_dataloader + + +def get_param_dict(config, current_step): + """Get param dict when load checkpoint""" + summary_params = config.get("summary") + + ckpt_path = os.path.join(summary_params.get('summary_dir'), 'ckpt', f'step_{current_step}') + ckpt_list = os.listdir(ckpt_path) + ckpt_list.sort() + ckpt_name = ckpt_list[-1] + params_dict = load_checkpoint(os.path.join(ckpt_path, ckpt_name)) + return params_dict, ckpt_path + + +def plot_correlation(config, corr_list): + """Plot model eval result""" + n_line = len(corr_list) + summary_params = config.get('summary') + + n = len(corr_list[0]) + x = np.arange(1, n+1, 1) + plt.rc('font', size=16) + plt.figure(figsize=(15, 6), dpi=150) + plt.plot(x, corr_list[0], color='orangered', linestyle='-', marker='o', markerfacecolor='orangered', linewidth=5, + label='CTEFNet-pretrain', markersize='8') + if n_line > 1: 
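+        # Draw a second curve only when fine-tuned results are provided alongside the pretrained ones.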
+ plt.plot(x, corr_list[1], color='blue', linestyle='-', marker='o', markerfacecolor='blue', linewidth=5, + label='CTEFNet-finetune', markersize='8') + plt.xlabel('Forecast Lead (months)') + plt.ylabel('Correlation Skill') + plt.tick_params(labelsize=18) + my_x_ticks = np.arange(1, n+1, 1) + my_y_ticks = np.arange(0.1, 1.1, 0.1) + plt.xticks(my_x_ticks) + plt.yticks(my_y_ticks) + plt.grid(linewidth=0.1) + plt.legend(ncol=4) + plt.axhline(0.5, color='black') + plt.savefig(os.path.join(summary_params.get('summary_dir'), + 'Forecast_Correlation_Skill' + '.png')) + plt.show() diff --git a/MindEarth/applications/earthquake/G-TEAM/README.md b/MindEarth/applications/earthquake/G-TEAM/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fc37db4f345362b658769cb6fc9c58b11ca03070 --- /dev/null +++ b/MindEarth/applications/earthquake/G-TEAM/README.md @@ -0,0 +1,92 @@ +ENGLISH | [简体中文](README.md) + +# G-TEAM Earthquake Early Warning Model + +## Overview + +The earthquake early warning system aims to issue alerts before destructive seismic waves arrive, thereby reducing casualties and economic losses. The G-TEAM model is a data-driven national earthquake early warning system that integrates Graph Neural Networks (GNN) and Transformer architectures. It can rapidly estimate epicenter location, magnitude, and seismic intensity distribution within 3 seconds after earthquake occurrence. By directly processing raw seismic waveform data, the model eliminates limitations from manual feature selection and enhances prediction accuracy and real-time performance through multi-station data utilization. + +This model is an efficient earthquake early warning system combining Graph Neural Networks (GNN) and Transformer architectures, taking seismic waveform data from any number of seismic stations as input. It enables real-time processing of seismic signals to deliver fast and precise estimations of hypocenter location, magnitude, and seismic intensity distribution range (characterized by Peak Ground Acceleration, PGA). Leveraging deep learning methods, the model fully exploits spatial correlations and temporal features within seismic networks to improve warning accuracy and response speed, providing robust support for earthquake emergency response and disaster mitigation strategies. + +![](./images/image.png) + +The PGA prediction architecture using multi-source seismic station data operates as follows: + +1. The system receives position data and waveform recordings from multiple seismic stations, along with target coordinates for PGA estimation. +2. For each station's waveform data: + - Perform standardization + - Extract features via Convolutional Neural Networks (CNN) + - Fuse features through fully connected layers + - Combine with station coordinates to form feature vectors +3. Target PGA coordinates are processed through positional encoding to generate feature vectors. +4. All feature vectors are sequentially fed into a Transformer encoder that captures global dependencies via self-attention mechanisms. +5. Encoder outputs pass through three independent fully connected layers to perform regression tasks: magnitude estimation, epicenter localization, and PGA prediction. 
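+
+The MindSpore sketch below mirrors this data flow. It is illustrative only: the class name `PGASketch`, the layer sizes, the assumed tensor layouts, and the mean-pooled event summary are simplifications made for clarity, not the implementation in `src/`.
+
+```python
+from mindspore import nn, ops
+
+
+class PGASketch(nn.Cell):
+    """Toy model following steps 1-5 above (hypothetical sizes and layouts)."""
+
+    def __init__(self, channels=3, hidden=128, n_heads=4, n_layers=2):
+        super().__init__()
+        # Step 2: per-station CNN feature extraction followed by dense fusion.
+        self.cnn = nn.SequentialCell(
+            nn.Conv1d(channels, 32, kernel_size=16, stride=4, pad_mode="valid"),
+            nn.ReLU(),
+            nn.Conv1d(32, 64, kernel_size=16, stride=4, pad_mode="valid"),
+            nn.ReLU(),
+        )
+        self.fuse = nn.Dense(64, hidden - 3)   # leave 3 dims for station coordinates
+        # Step 3: encoding of the target PGA coordinates (here a simple Dense embedding).
+        self.target_embed = nn.Dense(3, hidden)
+        # Step 4: Transformer encoder over station tokens and target tokens.
+        layer = nn.TransformerEncoderLayer(hidden, n_heads, dim_feedforward=4 * hidden, batch_first=True)
+        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
+        # Step 5: three independent regression heads.
+        self.mag_head = nn.Dense(hidden, 1)
+        self.loc_head = nn.Dense(hidden, 3)
+        self.pga_head = nn.Dense(hidden, 1)
+
+    def construct(self, waveforms, station_coords, target_coords):
+        # Assumed layouts: waveforms (B, S, C, T); station_coords (B, S, 3); target_coords (B, P, 3).
+        b, s = waveforms.shape[0], waveforms.shape[1]
+        feats = self.cnn(waveforms.reshape((b * s,) + waveforms.shape[2:]))
+        feats = feats.mean(axis=-1).reshape((b, s, -1))          # pool over time
+        station_tokens = ops.concat((self.fuse(feats), station_coords), -1)
+        tokens = ops.concat((station_tokens, self.target_embed(target_coords)), 1)
+        encoded = self.encoder(tokens)
+        summary = encoded[:, :s].mean(axis=1)                    # crude event summary over station tokens
+        return self.mag_head(summary), self.loc_head(summary), self.pga_head(encoded[:, s:])
+```
+
+With such a module, a call like `mag, loc, pga = model(waveforms, station_coords, target_coords)` yields a magnitude estimate, an epicenter/depth estimate, and one PGA value per requested target location.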
+ +## Training Data + +The model is trained using the [Diting Dataset 2.0 - Multifunctional Large AI Training Dataset for China Seismic Network](http://www.esdc.ac.cn/article/137), which contains: + +- Waveform records from 1,177 fixed stations in China (15°-50°N, 65°-140°E) +- Data coverage: March 2020 to February 2023 +- 264,298 local seismic events (M > 0) +- Only retains initial P-wave and S-wave phases +- Includes events recorded by ≥3 stations for reliability + +This model has fully open-sourced both the inference and training modules. For the inference part, the provided [ckpt](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/) is used for inference, while the training part utilizes the provided [hdf5](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/) and [pkl](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/) files for training. + +## Quick Start + +You can download the required data and ckpt files for training and inference at [dataset](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/) + +### Execution + +Run via command line using the `main` script: +It is necessary to configure the istraining parameter in the config.yaml file in advance to set up inference or training: +istraining: false -- Inference +istraining: true -- Training + +```python +python main.py --cfg_path ./config/config.yaml --device_id 0 --device_target Ascend + +``` + +Parameters: +--cfg_path: Configuration file path (default: "./config/config.yaml") +--device_target: Hardware type (default: Ascend) +--device_id: Device ID (default: 0) + +### Inference + +### Visualization + +![](./images/pga.png) + +Scatter plot compares predicted vs actual PGA values (x-axis vs y-axis). Closer alignment to y=x line indicates higher accuracy. + +### Results Presentation + +| Parameter | NPU | +|:----------------------:|:--------------------------:| +| Hardware | Ascend, memory 64G | +| MindSpore Version | mindspore2.5.0 | +| Dataset | diting2_2020-2022_sc | +| Test Parameters | batch_size=1
steps=9 | +| Magnitude Error (RMSE, MSE) | [ 0.262, 0.257 ] | +| Epicenter Distance Error (RMSE, MAE) | [ 4.318 , 4.123 ] | +| Hypocenter Depth Error (RMSE, MAE) | [ 5.559 , 5.171 ] | +| PGA Error (RMSE, MSE) |[ 0.466, 0.287 ] | +| Inference Resource | 1NPU | +| Inference Speed(ms/step) | 556 | + +### Training + +### Results Presentation + +![](./images/train_loss.png) +Under normal circumstances, the Average Training Loss should continue to converge. + +## Contributors + +gitee id: xujiabao, longjundong, dinghongyang, chengjie + +email: funniless@163.com \ No newline at end of file diff --git a/MindEarth/applications/earthquake/G-TEAM/README_CN.md b/MindEarth/applications/earthquake/G-TEAM/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..3475aa849629194b4375bac120ce2ec3225ecdc4 --- /dev/null +++ b/MindEarth/applications/earthquake/G-TEAM/README_CN.md @@ -0,0 +1,72 @@ +[ENGLISH](README.md) | 简体中文 + +# G-TEAM地震预警模型 + +## 概述 + +地震预警系统旨在在破坏性震动到达前尽早发出警报,以减少人员伤亡和经济损失。G-TEAM 模型是一种数据驱动的全国地震预警系统,结合了图神经网络(GNN)和 Transformer 架构,能够在地震发生后 3 秒内迅速提供震中位置、震级及地震强度分布。该模型通过直接处理原始地震波形数据,避免了手动特征选择的限制,并充分利用多台站数据,提高了预测的准确性和实时性。 + +本模型是一款高效的地震预警系统,结合了图神经网络(Graph Neural Network, GNN)与 Transformer 架构,以任意数量的地震台站记录的地震波形数据作为输入。该模型能够实时接收地震信号,并对震源位置、震级以及地震烈度分布范围进行快速且精准的估计,其中烈度分布范围以地面峰值加速度(Peak Ground Acceleration, PGA)表征。通过深度学习方法,本模型可以充分利用地震台网的空间关联性与时序特征,提高预警精度和响应速度,为地震应急响应和减灾决策提供可靠支持。 + +![](./images/image.png) + +该模型采用多源地震台站数据进行PGA预测,具体架构如下:首先,系统接收多个地震台站的位置信息及其记录的地震波形数据,同时获取待估计PGA的目标位置坐标。对于每个地震台站的波形数据,首先进行标准化处理,随后通过卷积神经网络(CNN)进行特征提取。提取的特征经全连接层进行特征融合,并与对应台站的位置信息共同构成特征向量。 +目标PGA位置坐标经过位置编码模块处理后,形成特征向量。所有特征向量按序列形式输入到Transformer编码器中,编码器通过自注意力机制捕捉全局依赖关系。编码器输出依次通过三个独立的全连接层,分别完成地震事件震级、震中位置以及PGA的回归预测任务。 + +本模型的训练数据来源于[谛听数据集2.0 -中国地震台网多功能大型人工智能训练数据集](http://www.esdc.ac.cn/article/137),该数据集汇集了中国大陆及其邻近地区(15°-50°N,65°-140°E)1177 个中国地震台网固定台站的波形记录,覆盖时间范围为 2020 年 3 月至 2023 年 2 月。数据集包含研究区域内所有震级大于 0 的地方震事件,共计 264,298 个。我们在训练过程中仅选取了初至 P 波和 S 波震相,并且只保留至少被三个台站记录到的地震事件,以确保数据的可靠性和稳定性。 + +本模型已全部开源推理和训练模块,其中推理部分使用提供的[ckpt](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/)进行推理,训练部分使用提供的[hdf5](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/)和[pkl](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/)进行训练。 + +## 快速开始 + +可在[dataset](https://download-mindspore.osinfra.cn/mindscience/mindearth/dataset/G-TEAM/)下载训练所需要的数据集。 + +### 运行方式: 在命令行调用`main`脚本 + +需提前在config.yaml中配置istraining参数设定推理/训练 +istraining: false -- 推理 +istraining: true -- 训练 + +```python + +python main.py --cfg_path ./config/config.yaml --device_id 0 --device_target Ascend + +``` + +其中, --cfg_path表示配置文件路径,默认值"./config/config.yaml" --device_target 表示设备类型,默认Ascend。 --device_id 表示运行设备的编号,默认值0。 + +### 推理 + +### 可视化结果 + +![](./images/pga.png) + +图示为pga的点坐标,横轴表示预测值,纵轴表示实际值,点数据越靠近y=x这条直线代表数据越准确。 + +### 结果展示 + +| 参数 | NPU | +|:----------------------:|:--------------------------:| +| 硬件 | Ascend, memory 64G | +| mindspore版本 | mindspore2.5.0 | +| 数据集 | diting2_2020-2022_sc | +| 测试参数 | batch_size=1
steps=9 | +| Mag震级误差(RMSE, MSE) | [ 0.262, 0.257 ] | +| Loc震中距离误差(RMSE, MAE) | [ 4.318 , 4.123 ] | +| Loc震源距离误差(RMSE, MAE) | [ 5.559 , 5.171 ] | +| Pga峰值地面加速度误差(RMSE, MSE) |[ 0.466, 0.287 ] | +| 推理资源 | 1NPU | +| 推理速度(ms/step) | 556 | + +### 训练 + +### 结果展示 + +![](./images/train_loss.png) +正常情况Average Training Loss会持续收敛。 + +## 贡献者 + +gitee id: xujiabao, longjundong, dinghongyang, chengjie + +email: funniless@163.com \ No newline at end of file diff --git a/MindEarth/applications/earthquake/G-TEAM/config/GTEAM.yaml b/MindEarth/applications/earthquake/G-TEAM/config/GTEAM.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0faf89c09619b16e1e69f0b67c1905c0d2696249 --- /dev/null +++ b/MindEarth/applications/earthquake/G-TEAM/config/GTEAM.yaml @@ -0,0 +1,76 @@ +model: + istraining: false + use_mlp: False + hidden_dim: 1000 + hidden_dropout: 0.0 + n_heads: 10 + n_pga_targets: 15 + output_location_dims: [150,100,50,30,3] + output_mlp_dims: [150,100,50,30,1] + transformer_layers: 6 + waveform_model_dims: [500,500,500] + wavelength: [[0.01,15],[0.01,15],[0.01,10]] + times: [5] + run_with_less_data: false + pga: true + mode: test + no_event_token : False + max_stations: 5 +data: + root_dir: "./dataset" + batch_size: 64 + max_stations: 5 + disable_station_foreshadowing: true + key: Mag + magnitude_resampling: 1 + min_mag: None + min_upsample_magnitude: 4 + aug_large: True + pga_from_inactive: true + pga_key: pga + pga_selection_skew: 1000 + pos_offset: [30,102] + scale_metadata: false + selection_skew: 1000 + shuffle_train_dev: true + transform_target_only: false + trigger_based: true + waveform_shape: [3000, 3] + overwrite_sampling_rate: None + noise_seconds: 5 +training_params: + seed: 42 + clipnorm: 1.0 + data_path: ./diting2_2020-2022_sc_abridged.hdf5 + ensemble_rotation: true + epochs_full_model: 100 + epochs_single_station: 5 + filter_single_station_by_pick: true + generator_params: + - batch_size: 1 + cutout_end: 25 + cutout_start: -1 + disable_station_foreshadowing: true + key: Mag + magnitude_resampling: 1.5 + min_upsample_magnitude: 4 + pga_from_inactive: true + pga_key: pga + pga_selection_skew: 1000 + pos_offset: [30,102] + scale_metadata: false + selection_skew: 1000 + shuffle_train_dev: true + transform_target_only: false + translate: false + trigger_based: true + upsample_high_station_events: 10 + loss_weights: + location: 1 + magnitude: 0.3 + pga: 1 + lr: 1e-5 + workers: 1 +summary: + summary_dir: "./summary" + ckpt_path: "./dataset/ckpt/g_team.ckpt" \ No newline at end of file diff --git a/MindEarth/applications/earthquake/G-TEAM/images/image.png b/MindEarth/applications/earthquake/G-TEAM/images/image.png new file mode 100644 index 0000000000000000000000000000000000000000..455f89f7b4e7d98a145f983df99d13bdad6e8c31 Binary files /dev/null and b/MindEarth/applications/earthquake/G-TEAM/images/image.png differ diff --git a/MindEarth/applications/earthquake/G-TEAM/images/train_loss.png b/MindEarth/applications/earthquake/G-TEAM/images/train_loss.png new file mode 100644 index 0000000000000000000000000000000000000000..77ee77fa11f4db7eec0504a01a2e0acbac987697 Binary files /dev/null and b/MindEarth/applications/earthquake/G-TEAM/images/train_loss.png differ diff --git a/MindEarth/applications/earthquake/G-TEAM/main.py b/MindEarth/applications/earthquake/G-TEAM/main.py new file mode 100644 index 0000000000000000000000000000000000000000..2e12e479f90e8dbc6732172c1e66bbc4b59a56bf --- /dev/null +++ b/MindEarth/applications/earthquake/G-TEAM/main.py @@ -0,0 +1,63 @@ +# Copyright 
2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""main function""" +import argparse + +import mindspore as ms +from mindspore import context +from mindearth import load_yaml_config, make_dir + +from src.utils import init_model, get_logger +from src.forcast import GTeamInference, GTeamTrain + + +def get_args(): + """get args""" + parser = argparse.ArgumentParser() + parser.add_argument("--cfg_path", default="./config/GTEAM.yaml", type=str) + parser.add_argument("--device_id", default=0, type=int) + parser.add_argument("--device_target", default="Ascend", type=str) + parse_args = parser.parse_args() + return parse_args + +def test(cfg): + """main test""" + save_dir = cfg["summary"].get("summary_dir", "./summary") + make_dir(save_dir) + model = init_model(cfg) + logger_obj = get_logger(cfg) + processor = GTeamInference(model, cfg, save_dir, logger_obj) + processor.test() + + +def train(cfg): + """main train""" + save_dir = cfg["summary"].get("summary_dir", "./summary") + make_dir(save_dir) + model = init_model(cfg) + logger_obj = get_logger(cfg) + processor = GTeamTrain(model, cfg, save_dir, logger_obj) + processor.train() + + +if __name__ == '__main__': + args = get_args() + config = load_yaml_config(args.cfg_path) + context.set_context(mode=ms.PYNATIVE_MODE) + ms.set_device(device_target=args.device_target, device_id=args.device_id) + if config['model']['istraining']: + train(config) + else: + test(config) diff --git a/MindEarth/applications/earthquake/G-TEAM/src/data.py b/MindEarth/applications/earthquake/G-TEAM/src/data.py new file mode 100644 index 0000000000000000000000000000000000000000..52abb64ec4ca01c4bf1f9ea6410cccceb85932ee --- /dev/null +++ b/MindEarth/applications/earthquake/G-TEAM/src/data.py @@ -0,0 +1,757 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"load diting data" +import os +import pickle +import glob +import h5py +import numpy as np + +import mindspore as ms +from mindspore.dataset import Dataset + +# degrees to kilometers +D2KM = 111.19492664455874 + + +def load_pickle_data(filename): + """Load and deserialize data from a pickle file.""" + with open(filename, "rb") as file: + data = pickle.load(file) + print(f"Data loaded from {filename}") + return data + +def load_data(cfg): + """Load preprocessed seismic data from a configured pickle file.""" + data_path = glob.glob(os.path.join(cfg["data"].get("root_dir"), "*.hdf5"))[0] + file_basename = os.path.basename(data_path).split(".")[0] + filename = os.path.join( + cfg["data"].get("root_dir"), f"{file_basename}_test_filter_pga.pkl" + ) + loaded_pickle_data = load_pickle_data(filename) + _, evt_metadata, meta_data, data_data, evt_key, _ = loaded_pickle_data + return data_data, evt_key, evt_metadata, meta_data, data_path + + +def detect_location_keys(columns): + """Identify standardized location keys from column headers.""" + candidates = [ + ["LAT", "Latitude(°)", "Latitude"], + ["LON", "Longitude(°)", "Longitude"], + ["DEPTH", "JMA_Depth(km)", "Depth(km)", "Depth/Km"], + ] + + coord_keys = [] + for keyset in candidates: + for key in keyset: + if key in columns: + coord_keys += [key] + break + + if len(coord_keys) != len(candidates): + raise ValueError("Unknown location key format") + + return coord_keys + + +class DataGenerator(Dataset): + """ + A PyTorch Dataset subclass for generating earthquake detection training data. + Handles loading, preprocessing, and batching of seismic waveform data. + """ + def __init__(self, data_path, event_metadata_index, event_key, mag_key='M_J', batch_size=32, cutout=None, + sliding_window=False, windowlen=3000, shuffle=True, label_smoothing=False, decimate=1): + """ + Initialize the data generator. + + Args: + data_path (str): Path to the HDF5 file containing seismic data. + event_metadata_index (pd.DataFrame): DataFrame containing event metadata indices. + event_key (str): Column name in metadata used to identify events. + mag_key (str, optional): Column name containing magnitude values. Defaults to 'M_J'. + batch_size (int, optional): Number of samples per batch. Defaults to 32. + cutout (tuple, optional): Time window for data augmentation. Defaults to None. + sliding_window (bool, optional): Use sliding window for cutouts. Defaults to False. + windowlen (int, optional): Length of time window for analysis. Defaults to 3000. + shuffle (bool, optional): Shuffle data during epoch. Defaults to True. + label_smoothing (bool, optional): Apply label smoothing. Defaults to False. + decimate (int, optional): Decimation factor for waveform data. Defaults to 1. + """ + super().__init__() + self.data_path = data_path + self.event_metadata_index = event_metadata_index + self.event_key = event_key + self.batch_size = batch_size + self.mag_key = mag_key + self.cutout = cutout + self.sliding_window = sliding_window + self.windowlen = windowlen + self.shuffle = shuffle + self.label_smoothing = label_smoothing + self.decimate = decimate + self.indexes = np.arange(len(self.event_metadata_index)) + self.on_epoch_end() + + def __len__(self): + """ + Return the number of batches in the dataset. 
+ + Returns: + int: Number of batches = total samples / batch size (floor division) + """ + return int(np.floor(len(self.event_metadata_index) / self.batch_size)) + + def __getitem__(self, index): + """ + Get a batch of data by index. + + Args: + index (int): Batch index + + Returns: + tuple: (x, y) where X is input tensor, y is target tensor + """ + indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size] + batch_metadata = self.event_metadata_index.iloc[indexes] + x, y = self.__data_generation(batch_metadata) + return x, y + + def on_epoch_end(self): + """ + Called when an epoch ends. Resets indexes and shuffles if required. + """ + self.indexes = np.arange(len(self.event_metadata_index)) + if self.shuffle: + np.random.shuffle(self.indexes) + + def __data_generation(self, batch_metadata): + """ + Generate data for a batch of events. + + Args: + batch_metadata (pd.DataFrame): Metadata for batch of events + + Returns: + tuple: (x, y) where x is input tensor, y is target tensor + """ + x = [] + y = [] + with h5py.File(self.data_path, 'r') as f: + for _, event in batch_metadata.iterrows(): + event_name = str(int(event[self.event_key])) + if event_name not in f['data']: + continue + g_event = f['data'][event_name] + waveform = g_event['waveforms'][event['index'], ::self.decimate, :] + if self.cutout: + if self.sliding_window: + windowlen = self.windowlen + window_end = np.random.randint(max(windowlen, self.cutout[0]), + min(waveform.shape[1], self.cutout[1]) + 1) + waveform = waveform[:, window_end - windowlen:window_end] + else: + waveform[:, np.random.randint(*self.cutout):] = 0 + x.append(waveform) + y.append(event[self.mag_key]) + x = np.array(x) + y = np.array(y) + if self.label_smoothing: + y += (y > 4) * np.random.randn(y.shape[0]).reshape(y.shape) * (y - 4) * 0.05 + + return (ms.tensor(x, dtype=ms.float32), + ms.tensor(np.expand_dims(np.expand_dims(y, axis=1), axis=2), dtype=ms.float32)) + +class EarthquakeDataset(Dataset): + """ + Dataset class for loading and processing seismic event data. + Handles waveform loading, magnitude-based resampling, PGA target processing, + and batch preparation for earthquake analysis models. 
+ Key Features: + Batch processing of seismic waveforms and metadata + Magnitude-based data resampling for class balance + PGA (Peak Ground Acceleration) target handling + HDF5 waveform data loading + Flexible data shuffling and oversampling + """ + + def __init__( + self, + data_path, + event_key, + data, + event_metadata, + batch_size=32, + shuffle=True, + oversample=1, + magnitude_resampling=3, + min_upsample_magnitude=2, + pga_targets=None, + pga_mode=False, + pga_key="pga", + coord_keys=None, + **kwargs, + ): + + super().__init__() + + self.data_path = data_path + self.event_key = event_key + self.batch_size = batch_size + self.shuffle = shuffle + self.metadata = data["coords"] + self.event_metadata = event_metadata + self.pga = data[pga_key] + self.triggers = data["p_picks"] + self.oversample = oversample + + self.pga_mode = pga_mode + self.pga_targets = pga_targets + + self.base_indexes = np.arange(self.event_metadata.shape[0]) + self.reverse_index = None + + if magnitude_resampling > 1: + magnitude = self.event_metadata[kwargs["key"]].values + for i in np.arange(min_upsample_magnitude, 9): + ind = np.where(np.logical_and(i < magnitude, magnitude <= i + 1))[0] + self.base_indexes = np.concatenate( + ( + self.base_indexes, + np.repeat(ind, int(magnitude_resampling ** (i - 1) - 1)), + ) + ) + + if pga_mode and pga_targets is not None: + new_base_indexes = [] + self.reverse_index = [] + c = 0 + for idx in self.base_indexes: + num_samples = (len(self.pga[idx]) - 1) // pga_targets + 1 + new_base_indexes += [(idx, i) for i in range(num_samples)] + self.reverse_index += [c] + c += num_samples + self.reverse_index += [c] + self.base_indexes = new_base_indexes + if coord_keys is None: + self.coord_keys = detect_location_keys(event_metadata.columns) + else: + self.coord_keys = coord_keys + self.use_shuffle() + + def __len__(self): + """get length""" + return int(np.ceil(len(self.indexes) / self.batch_size)) + + def __getitem__(self, index): + """Load data.""" + batch_indexes = self.indexes[ + index * self.batch_size : (index + 1) * self.batch_size + ] + batch_data = { + "indexes": batch_indexes, + "waveforms": [], + "metadata": [], + "pga": [], + "p_picks": [], + "event_info": [], + } + if self.pga_mode: + batch_data["pga_indexes"] = [x[1] for x in batch_indexes] + batch_data["indexes"] = [x[0] for x in batch_indexes] + for idx in batch_data["indexes"]: + event = self.event_metadata.iloc[idx] + event_name = str(event[self.event_key]) + waveform_data = self._load_waveform_data(event_name) + batch_data["waveforms"].append(waveform_data) + batch_data["metadata"].append(self.metadata[idx]) + batch_data["pga"].append(self.pga[idx]) + batch_data["p_picks"].append(self.triggers[idx]) + batch_data["event_info"].append(event) + + return batch_data + + def _load_waveform_data(self, event_name): + """load waveform data""" + with h5py.File(self.data_path, "r") as f: + if "data" not in f or event_name not in f["data"]: + return None + g_event = f["data"][event_name] + if "waveforms" not in g_event: + return None + return g_event["waveforms"][:, :, :] + + def use_shuffle(self): + """shuffle index""" + self.indexes = np.repeat(self.base_indexes.copy(), self.oversample, axis=0) + if self.shuffle: + np.random.shuffle(self.indexes) + +class PreloadedEventGenerator(Dataset): + """ + A custom dataset generator for preloading seismic events. 
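+
+    Each generated batch bundles station waveforms, station metadata (coordinates),
+    event magnitudes, optional target coordinates, and optional PGA targets and
+    values, packed together by get_result.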
+ """ + def __init__(self, data_path, event_key, data, event_metadata, waveform_shape=(3000, 6), key='MA', batch_size=32, + cutout=None, sliding_window=False, windowlen=3000, shuffle=True, coords_target=True, oversample=1, + pos_offset=(-21, -69), label_smoothing=False, station_blinding=False, magnitude_resampling=3, + pga_targets=None, adjust_mean=True, transform_target_only=False, max_stations=None, trigger_based=None, + min_upsample_magnitude=2, disable_station_foreshadowing=False, selection_skew=None, + pga_from_inactive=False, integrate=False, differentiate=False, sampling_rate=100., select_first=False, + fake_borehole=False, scale_metadata=True, pga_key='pga', pga_mode=False, p_pick_limit=5000, + coord_keys=None, upsample_high_station_events=None, no_event_token=False, pga_selection_skew=None, + **kwargs): + ''' + Initializes the PreloadedEventGenerator. + + Args: + data_path: Path to the HDF5 file containing waveform data. + event_key: The key in the event metadata DataFrame identifying each event. + data: Dictionary containing 'coords' and 'pga' keys for metadata and PGA values. + event_metadata: Pandas DataFrame with event metadata. + waveform_shape: Shape of each waveform (number of samples, number of channels). + key: The key in event metadata to use for magnitude. + batch_size: Number of events per batch. + cutout: Tuple specifying the range for random cutout in the waveform. + sliding_window: Whether to use a sliding window for cutout. + windowlen: Length of the sliding window. + shuffle: Whether to shuffle the events at the end of each epoch. + coords_target: Whether to include event coordinates as targets. + oversample: Factor by which to oversample the events. + pos_offset: Offset to apply to event coordinates. + label_smoothing: Whether to apply label smoothing to magnitudes. + station_blinding: Whether to randomly blind stations in the waveforms. + magnitude_resampling: Factor by which to resample events based on their magnitude. + pga_targets: Number of PGA targets to sample per event. + adjust_mean: Whether to adjust the mean of the waveforms. + transform_target_only: Whether to apply transformations only to the target coordinates. + max_stations: Maximum number of stations to include per event. + trigger_based: Whether to zero out waveforms before the P-wave trigger. + min_upsample_magnitude: Minimum magnitude for upsampling. + disable_station_foreshadowing: Whether to disable station foreshadowing. + selection_skew: Skew parameter for selecting stations when max_stations is reached. + pga_from_inactive: Whether to sample PGA from inactive stations. + integrate: Whether to integrate the waveforms. + differentiate: Whether to differentiate the waveforms. + sampling_rate: Sampling rate of the waveforms. + select_first: Whether to select the first stations when max_stations is reached. + fake_borehole: Whether to add fake borehole channels to the waveforms. + scale_metadata: Whether to scale the metadata coordinates. + pga_key: Key in the data dictionary for PGA values. + pga_mode: Whether to operate in PGA mode. + p_pick_limit: Limit for P-wave picks. + coord_keys: Keys in the event metadata for coordinates. + upsample_high_station_events: Whether to upsample events with high station counts. + no_event_token: Whether to include an event token in the outputs. + pga_selection_skew: Skew parameter for selecting PGA targets. 
+ **kwargs: + ''' + super().__init__() + if kwargs: + print(f'Unused parameters: {", ".join(kwargs.keys())}') + self.data_path = data_path + self.event_key = event_key + self.batch_size = batch_size + self.shuffle = shuffle + self.waveform_shape = waveform_shape + self.metadata = data['coords'] + self.event_metadata = event_metadata + self.pga = data[pga_key] + self.key = key + self.cutout = cutout + self.sliding_window = sliding_window + self.windowlen = windowlen + self.coords_target = coords_target + self.oversample = oversample + self.pos_offset = pos_offset + self.label_smoothing = label_smoothing + self.station_blinding = station_blinding + self.magnitude_resampling = magnitude_resampling + self.pga_targets = pga_targets + self.adjust_mean = adjust_mean + self.transform_target_only = transform_target_only + self.max_stations = max_stations + self.trigger_based = trigger_based + self.disable_station_foreshadowing = disable_station_foreshadowing + self.selection_skew = selection_skew + self.pga_from_inactive = pga_from_inactive + self.pga_selection_skew = pga_selection_skew + self.integrate = integrate + self.differentiate = differentiate + self.sampling_rate = sampling_rate + self.select_first = select_first + self.fake_borehole = fake_borehole + self.scale_metadata = scale_metadata + self.upsample_high_station_events = upsample_high_station_events + self.no_event_token = no_event_token + self.triggers = data['p_picks'] + self.pga_mode = pga_mode + self.p_pick_limit = p_pick_limit + self.base_indexes = np.arange(self.event_metadata.shape[0]) + self.reverse_index = None + if magnitude_resampling > 1: + magnitude = self.event_metadata[key].values + for i in np.arange(min_upsample_magnitude, 9): + ind = np.where(np.logical_and(i < magnitude, magnitude <= i + 1))[0] + self.base_indexes = np.concatenate( + (self.base_indexes, np.repeat(ind, int(magnitude_resampling ** (i - 1) - 1)))) + if pga_mode: + new_base_indexes = [] + self.reverse_index = [] + c = 0 + for idx in self.base_indexes: + num_samples = (len(self.pga[idx]) - 1) // pga_targets + 1 + new_base_indexes += [(idx, i) for i in range(num_samples)] + self.reverse_index += [c] + c += num_samples + self.reverse_index += [c] + self.base_indexes = new_base_indexes + self.indexes = np.arange(len(self.event_metadata)) + if coord_keys is None: + self.coord_keys = detect_location_keys(event_metadata.columns) + else: + self.coord_keys = coord_keys + self.on_epoch_end() + + def __len__(self): + """ + Returns the number of batches in the dataset. + """ + return int(np.ceil(len(self.indexes) / self.batch_size)) + + def __getitem__(self, index): + """ + Retrieves a batch of events from the dataset. 
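+
+        The batch is assembled in three stages: waveforms and station metadata are
+        read from the HDF5 file (htpyfile_process), augmented and converted into
+        training targets (data_preprocessing), then masked and cleaned
+        (data_processing) before get_result packs the final outputs.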
+ """ + indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size] + true_batch_size = len(indexes) + if self.pga_mode: + self.pga_indexes = [x[1] for x in indexes] + indexes = [x[0] for x in indexes] + + waveforms = np.zeros([true_batch_size, self.max_stations] + self.waveform_shape) + true_max_stations_in_batch = max(max([self.metadata[idx].shape[0] for idx in indexes]), self.max_stations) + metadata = np.zeros((true_batch_size, true_max_stations_in_batch) + self.metadata[0].shape[1:]) + pga = np.zeros((true_batch_size, true_max_stations_in_batch)) + full_p_picks = np.zeros((true_batch_size, true_max_stations_in_batch)) + p_picks = np.zeros((true_batch_size, self.max_stations)) + reverse_selections = [] + + waveforms, metadata, pga, p_picks, reverse_selections, full_p_picks = ( + self.htpyfile_process(indexes, waveforms, metadata, pga, + p_picks, reverse_selections, full_p_picks)) + + magnitude = self.event_metadata.iloc[indexes][self.key].values.copy() + magnitude = magnitude.astype(np.float32) + + target, waveforms, magnitude, metadata, pga_values, pga_targets = ( + self.data_preprocessing(indexes, waveforms, p_picks, magnitude, metadata, + pga, true_batch_size, reverse_selections, full_p_picks)) + + waveforms, metadata = self.data_processing(waveforms, metadata) + + return self.get_result(waveforms, metadata, magnitude, target, pga_targets, pga_values) + + def htpyfile_process(self, indexes, waveforms, metadata, pga, + p_picks, reverse_selections, full_p_picks): + """ + Processes the HDF5 file to retrieve waveform data for a batch of events. + """ + with h5py.File(self.data_path, 'r') as f: + for i, idx in enumerate(indexes): + event = self.event_metadata.iloc[idx] + event_name = str(event[self.event_key]) + if event_name not in f['data']: + continue + g_event = f['data'][event_name] + waveform_data = g_event['waveforms'][:, :, :] + + num_stations = waveform_data.shape[0] + + if num_stations <= self.max_stations: + waveforms[i, :num_stations] = waveform_data + metadata[i, :len(self.metadata[idx])] = self.metadata[idx] + pga[i, :len(self.pga[idx])] = self.pga[idx] + p_picks[i, :len(self.triggers[idx])] = self.triggers[idx] + reverse_selections += [[]] + else: + if self.selection_skew is None: + selection = np.arange(0, num_stations) + np.random.shuffle(selection) + else: + tmp_p_picks = self.triggers[idx].copy() + mask = np.logical_and(tmp_p_picks <= 0, tmp_p_picks > self.p_pick_limit) + tmp_p_picks[mask] = min(np.max(tmp_p_picks), self.p_pick_limit) + coeffs = np.exp(-tmp_p_picks / self.selection_skew) + coeffs *= np.random.random(coeffs.shape) + coeffs[self.triggers[idx] == 0] = 0 + coeffs[self.triggers[idx] > self.waveform_shape[0]] = 0 + selection = np.argsort(-coeffs) + + if self.select_first: + selection = np.argsort(self.triggers[idx]) + + metadata[i, :len(selection)] = self.metadata[idx][selection] + pga[i, :len(selection)] = self.pga[idx][selection] + full_p_picks[i, :len(selection)] = self.triggers[idx][selection] + + tmp_reverse_selection = [0 for _ in selection] + for j, s in enumerate(selection): + tmp_reverse_selection[s] = j + reverse_selections += [tmp_reverse_selection] + + selection = selection[:self.max_stations] + waveforms[i] = waveform_data[selection] + p_picks[i] = self.triggers[idx][selection] + return waveforms, metadata, pga, p_picks, reverse_selections, full_p_picks + + def pga_mode_process(self, waveforms, reverse_selections, metadata, + pga_values, pga_targets, pga, indexes, full_p_picks): + """ + Processes the data in PGA mode. 
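+
+        In PGA mode, targets and values are taken as consecutive slices of the
+        (optionally re-ordered) station list, one slice per pga_index. Otherwise,
+        target positions are sampled from stations with non-zero PGA, optionally
+        skewed towards early P picks via pga_selection_skew.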
+ """ + if self.pga_mode: + for i in range(waveforms.shape[0]): + pga_index = self.pga_indexes[i] + if reverse_selections[i]: + sorted_pga = pga[i, reverse_selections[i]] + sorted_metadata = metadata[i, reverse_selections[i]] + else: + sorted_pga = pga[i] + sorted_metadata = metadata[i] + pga_values_pre = sorted_pga[pga_index * self.pga_targets:(pga_index + 1) * self.pga_targets] + pga_values[i, :len(pga_values_pre)] = pga_values_pre + pga_targets_pre = sorted_metadata[pga_index * self.pga_targets:(pga_index + 1) * self.pga_targets, :] + if pga_targets_pre.shape[-1] == 4: + pga_targets_pre = pga_targets_pre[:, (0, 1, 3)] + pga_targets[i, :len(pga_targets_pre), :] = pga_targets_pre + else: + pga[np.logical_or(np.isnan(pga), np.isinf(pga))] = 0 + for i in range(waveforms.shape[0]): + active = np.where(pga[i] != 0)[0] + l = len(active) + if l == 0: + raise ValueError(f'Found event without PGA idx={indexes[i]}') + while len(active) < self.pga_targets: + active = np.repeat(active, 2) + if self.pga_selection_skew is not None: + active_p_picks = full_p_picks[i, active] + mask = np.logical_and(active_p_picks <= 0, active_p_picks > self.p_pick_limit) + active_p_picks[mask] = min(np.max(active_p_picks), self.p_pick_limit) + coeffs = np.exp(-active_p_picks / self.pga_selection_skew) + coeffs *= np.random.random(coeffs.shape) + active = active[np.argsort(-coeffs)] + else: + np.random.shuffle(active) + + samples = active[:self.pga_targets] + if metadata.shape[-1] == 3: + pga_targets[i] = metadata[i, samples, :] + else: + full_targets = metadata[i, samples] + pga_targets[i] = full_targets[:, (0, 1, 3)] + pga_values[i] = pga[i, samples] + return pga_values, pga_targets + + def data_preprocessing(self, indexes, waveforms, p_picks, magnitude, metadata, + pga, true_batch_size, reverse_selections, full_p_picks): + """ + Data preprocessing. 
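+
+        Illustrative numbers for the cutout step (values depend on the config):
+        at a 100 Hz sampling rate with 5 s of pre-event noise and a prediction
+        time of 4 s, ``generator_from_config`` passes ``cutout = (900, 901)``,
+        so ``np.random.randint(*self.cutout)`` yields 900 and every sample
+        after index 900 of the 3000-sample waveforms is zeroed before the
+        batch reaches the model.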
+ """ + target = None + if self.coords_target: + target = self.event_metadata.iloc[indexes][self.coord_keys].values + target = target.astype(np.float32) + org_waveform_length = waveforms.shape[2] + if self.cutout: + if self.sliding_window: + windowlen = self.windowlen + window_end = np.random.randint(max(windowlen, self.cutout[0]), + min(waveforms.shape[2], self.cutout[1]) + 1) + waveforms = waveforms[:, :, window_end - windowlen: window_end] + cutout = window_end + if self.adjust_mean: + waveforms -= np.mean(waveforms, axis=2, keepdims=True) + else: + cutout = np.random.randint(*self.cutout) + if self.adjust_mean: + waveforms -= np.mean(waveforms[:, :, :cutout + 1], axis=2, keepdims=True) + waveforms[:, :, cutout:] = 0 + else: + cutout = waveforms.shape[2] + + if self.trigger_based: + p_picks[p_picks <= 0] = org_waveform_length + waveforms[cutout < p_picks, :, :] = 0 + if self.integrate: + waveforms = np.cumsum(waveforms, axis=2) / self.sampling_rate + if self.differentiate: + waveforms = np.diff(waveforms, axis=2) + + magnitude = np.expand_dims(np.expand_dims(magnitude, axis=-1), axis=-1) + if self.coords_target: + metadata, target = self.location_transformation(metadata, target) + else: + metadata = self.location_transformation(metadata) + + if self.label_smoothing: + magnitude += (magnitude > 4) * np.random.randn(magnitude.shape[0]).reshape(magnitude.shape) * ( + magnitude - 4) * 0.05 + if not self.pga_from_inactive and not self.pga_mode: + metadata = metadata[:, :self.max_stations] + pga = pga[:, :self.max_stations] + pga_values = () + pga_targets = () + if self.pga_targets: + pga_values = np.zeros((true_batch_size, self.pga_targets)) + pga_targets = np.zeros((true_batch_size, self.pga_targets, 3)) + + pga_values, pga_targets = self.pga_mode_process(waveforms, reverse_selections, metadata, + pga_values, pga_targets, pga, indexes, full_p_picks) + + pga_values = pga_values.reshape((true_batch_size, self.pga_targets, 1, 1)) + + return target, waveforms, magnitude, metadata, pga_values, pga_targets + + def data_processing(self, waveforms, metadata): + """ + Data process. + """ + metadata = metadata[:, :self.max_stations] + if self.station_blinding: + mask = np.zeros(waveforms.shape[:2], dtype=bool) + + for i in range(waveforms.shape[0]): + active = np.where((waveforms[i] != 0).any(axis=(1, 2)))[0] + l = len(active) + if l == 0: + active = np.zeros(1, dtype=int) + blind_length = np.random.randint(0, len(active)) + np.random.shuffle(active) + blind = active[:blind_length] + mask[i, blind] = True + + waveforms[mask] = 0 + metadata[mask] = 0 + + stations_without_trigger = (metadata != 0).any(axis=2) & (waveforms == 0).all(axis=(2, 3)) + if self.disable_station_foreshadowing: + metadata[stations_without_trigger] = 0 + else: + waveforms[stations_without_trigger, 0, 0] += 1e-9 + + mask = np.logical_and((metadata == 0).all(axis=(1, 2)), (waveforms == 0).all(axis=(1, 2, 3))) + waveforms[mask, 0, 0, 0] = 1e-9 + metadata[mask, 0, 0] = 1e-9 + + return waveforms, metadata + + def get_result(self, waveforms, metadata, magnitude, target, pga_targets, pga_values): + """ + get result. 
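+
+        Sketch of the returned structure for a batch of size ``B`` with
+        ``max_stations = S`` and ``pga_targets = N`` (shapes are indicative;
+        the station metadata may carry a fourth column for borehole datasets):
+
+            inputs  = [waveforms (B, S, 3000, 3), metadata (B, S, 3), pga_targets (B, N, 3)]
+            outputs = [magnitude (B, 1, 1), location (B, 3, 1), pga (B, N, 1, 1)]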
+ """ + inputs = [ms.tensor(waveforms, dtype=ms.float32), ms.tensor(metadata, dtype=ms.float32)] + outputs = [] + if not self.no_event_token: + outputs += [ms.tensor(magnitude, dtype=ms.float32)] + + if self.coords_target: + target = np.expand_dims(target, axis=-1) + outputs += [ms.tensor(target, dtype=ms.float32)] + + if self.pga_targets: + inputs += [ms.tensor(pga_targets, dtype=ms.float32)] + outputs += [ms.tensor(pga_values, dtype=ms.float32)] + + return inputs, outputs + + def on_epoch_end(self): + """ + Resets the indexes for a new epoch, optionally with oversampling and shuffling. + """ + self.indexes = np.repeat(self.base_indexes.copy(), self.oversample, axis=0) + if self.shuffle: + np.random.shuffle(self.indexes) + + def location_transformation(self, metadata, target=None): + """ + Transforms the event coordinates and optionally the target coordinates. + """ + transform_target_only = self.transform_target_only + metadata = metadata.copy() + + metadata_old = metadata + metadata = metadata.copy() + mask = (metadata == 0).all(axis=2) + if target is not None: + target[:, 0] -= self.pos_offset[0] + target[:, 1] -= self.pos_offset[1] + metadata[:, :, 0] -= self.pos_offset[0] + metadata[:, :, 1] -= self.pos_offset[1] + + # Coordinates to kilometers (assuming a flat earth, which is okay close to equator) + if self.scale_metadata: + metadata[:, :, :2] *= D2KM + if target is not None: + target[:, :2] *= D2KM + + metadata[mask] = 0 + + if self.scale_metadata: + metadata /= 100 + if target is not None: + target /= 100 + + if transform_target_only: + metadata = metadata_old + + if target is None: + return metadata + + return metadata, target + +def generator_from_config( + config, + data, + data_path, + event_key, + event_metadata, + time, + sampling_rate=100, + pga=False, +): + """init generator""" + generator_params = config["data"] + cutout = int(sampling_rate * (generator_params["noise_seconds"] + time)) + cutout = (cutout, cutout + 1) + + n_pga_targets = config["model"].get("n_pga_targets", 0) + if "data_path" in generator_params: + del generator_params["data_path"] + + generator = PreloadedEventGenerator( + data_path=data_path, + event_key=event_key, + data=data, + event_metadata=event_metadata, + coords_target=True, + cutout=cutout, + pga_targets=n_pga_targets, + sampling_rate=sampling_rate, + select_first=True, + shuffle=False, + pga_mode=pga, + **generator_params, + ) + + return generator diff --git a/MindEarth/applications/earthquake/G-TEAM/src/forcast.py b/MindEarth/applications/earthquake/G-TEAM/src/forcast.py new file mode 100644 index 0000000000000000000000000000000000000000..b1752b8e4ac55119e310e2698bf312674f7d6ec5 --- /dev/null +++ b/MindEarth/applications/earthquake/G-TEAM/src/forcast.py @@ -0,0 +1,521 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"GTeam forcast" +import os +from tqdm import tqdm +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + +from src.utils import ( + predict_at_time, + calc_mag_stats, + calc_loc_stats, + calc_pga_stats, +) +from src.data import DataGenerator, PreloadedEventGenerator, load_pickle_data, load_data +from src.models import SingleStationModel +from src.utils import evaluation, seed_np_tf +from src.visual import generate_true_pred_plot + + +class CustomWithLossCell(nn.Cell): + """ + A neural network cell that wraps a main network and loss function together, + allowing the entire forward pass including loss computation to be treated as a single cell. + + This class combines a neural network model and a loss function into a single computation unit, + which is useful for training loops and model encapsulation in deep learning frameworks. + + Attributes: + net (nn.Cell): The main neural network model whose output will be used in loss computation. + loss_fn (nn.Cell): The loss function cell that computes the difference between predictions + and true labels. + """ + + def __init__(self, net, loss_fn): + """ + Initializes the CustomWithLossCell with a network model and loss function. + + Args: + net (nn.Cell): The neural network model whose output will be used for loss calculation. + loss_fn (nn.Cell): The loss computation function that takes (true_labels, predictions) + and returns a scalar loss value. + """ + super().__init__() + self.net = net + self.loss_fn = loss_fn + + def construct(self, x, y): + ''' + Computes the loss by first passing input data through the network and then applying the loss function. + + Args: + X (Tensor): Input data tensor containing features. + y (Tensor): Ground truth labels tensor. + + Returns: + Tensor: Computed loss value. + + Note: + The input labels 'y' are squeezed along dimension 2 to match the output shape from the network. + This ensures the loss function receives inputs of the expected shape. + ''' + outputs = self.net(x) + return self.loss_fn(y.squeeze(2), outputs) + + +class GTeamInference: + """ + Initialize the GTeamInference class. + """ + + def __init__(self, model_ins, cfg, output_dir, logger): + """ + Args: + model_ins: The model instance used for inference. + cfg: Configuration dictionary containing model and data parameters. + output_dir: Directory to save the output results. + Attributes: + model: The model instance for inference. + cfg: Configuration dictionary. + output_dir: Directory to save outputs. + pga: Flag indicating if PGA (Peak Ground Acceleration) is enabled. + generator_params: Parameters for data generation. + model_params: Parameters specific to the model. + mag_key: Key for magnitude-related data. + pos_offset: Position offset for location predictions. + mag_stats: List to store magnitude prediction statistics. + loc_stats: List to store location prediction statistics. + pga_stats: List to store PGA prediction statistics. 
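+
+        A minimal usage sketch (``cfg`` is assumed to be the parsed YAML
+        configuration and the helpers come from ``src.utils``):
+
+            logger = get_logger(cfg)
+            model = init_model(cfg)
+            runner = GTeamInference(model, cfg, cfg["summary"]["summary_dir"], logger)
+            runner.test()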
+ """ + self.model = model_ins + self.cfg = cfg + self.output_dir = output_dir + self.logger = logger + self.pga = cfg["model"].get("pga", "true") + self.generator_params = cfg["data"] + self.model_params = cfg["model"] + self.output_dir = output_dir + self.mag_key = self.generator_params["key"] + self.pos_offset = self.generator_params["pos_offset"] + self.mag_stats = [] + self.loc_stats = [] + self.pga_stats = [] + + def _parse_predictions(self, pred): + """ + Parse the raw predictions into magnitude, location, and PGA components. + """ + mag_pred = pred[0] + loc_pred = pred[1] + pga_pred = pred[2] if self.pga else [] + return mag_pred, loc_pred, pga_pred + + def _process_predictions( + self, mag_pred, loc_pred, pga_pred, time, evt_metadata, pga_true + ): + """ + Process the parsed predictions to compute statistics and generate plots. + """ + mag_pred_np = [t[0].asnumpy() for t in mag_pred] + mag_pred_reshaped = np.concatenate(mag_pred_np, axis=0) + + loc_pred_np = [t[0].asnumpy() for t in loc_pred] + loc_pred_reshaped = np.array(loc_pred_np) + + pga_pred_np = [t.asnumpy() for t in pga_pred] + pga_pred_reshaped = np.concatenate(pga_pred_np, axis=0) + pga_true_reshaped = np.log( + np.abs(np.concatenate(pga_true, axis=0).reshape(-1, 1)) + ) + + if not self.model_params["no_event_token"]: + self.mag_stats += calc_mag_stats( + mag_pred_reshaped, evt_metadata, self.mag_key + ) + + self.loc_stats += calc_loc_stats( + loc_pred_reshaped, evt_metadata, self.pos_offset + ) + + generate_true_pred_plot( + mag_pred_reshaped, + evt_metadata[self.mag_key].values, + time, + self.output_dir, + ) + self.pga_stats = calc_pga_stats(pga_pred_reshaped, pga_true_reshaped) + + def _save_results(self): + """ + Save the final results (magnitude, location, and PGA statistics) to a JSON file. + """ + times = self.cfg["model"].get("times") + self.logger.info("times: {}".format(times)) + self.logger.info("mag_stats: {}".format(self.mag_stats)) + self.logger.info("loc_stats: {}".format(self.loc_stats)) + self.logger.info("pga_stats: {}".format(self.pga_stats)) + + def test(self): + """ + Perform inference for all specified times, process predictions, and save results. + This method iterates over the specified times, performs predictions, processes + the results, and saves the final statistics. + """ + data_data, evt_key, evt_metadata, meta_data, data_path = load_data(self.cfg) + pga_true = data_data["pga"] + for time in self.cfg["model"].get("times"): + pred = predict_at_time( + self.model, + time, + data_data, + data_path, + evt_key, + evt_metadata, + config=self.cfg, + pga=self.pga, + sampling_rate=meta_data["sampling_rate"], + ) + mag_pred, loc_pred, pga_pred = self._parse_predictions(pred) + self._process_predictions( + mag_pred, loc_pred, pga_pred, time, evt_metadata, pga_true + ) + self._save_results() + print("Inference completed and results saved") + +class GTeamTrain: + """ + A class to handle the training of a full model for earthquake detection and localization. + It manages data loading, training of single-station models, and full-model training. + """ + def __init__(self, model_ins, cfg, output_dir, logger): + """ + Initialize the GTeamTrain class with model, configuration, output directory, and logger. + Args: + model_ins (nn.Cell): The full model instance to be trained. + cfg (dict): Configuration dictionary containing training parameters and paths. + output_dir (str): Directory to save checkpoints and outputs. + logger (logging.Logger): Logger instance for logging messages. 
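+
+        Data layout assumed by the loaders (paths are illustrative): if
+        ``training_params['data_path']`` points to ``.../italy_events.hdf5`` and
+        ``cfg['data']['root_dir']`` is ``/path/to/data``, then
+        ``load_train_data`` reads ``/path/to/data/italy_events_train.pkl`` and
+        ``load_val_data`` reads ``/path/to/data/italy_events_val.pkl``.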
+ """ + self.full_model = model_ins + self.cfg = cfg + self.output_dir = output_dir + self.logger = logger + self.waveform_shape = [3000, 3] + self.training_params = self.cfg['training_params'] + self.generator_params = self.training_params.get('generator_params', [self.training_params.copy()]) + self.file_basename = os.path.basename(self.training_params['data_path']).split('.')[0] + + def load_train_data(self): + """ + Load training data from a pickle file. + Returns: + Data structure: The loaded training data. + """ + data_path = self.cfg['data']["root_dir"] + filename_train = os.path.join(data_path, f"{self.file_basename}_train.pkl") + return load_pickle_data(filename_train) + + def load_val_data(self): + """ + Load validation data from a pickle file. + Returns: + Data structure: The loaded validation data. + """ + data_path = self.cfg['data']["root_dir"] + filename_val = os.path.join(data_path, f"{self.file_basename}_val.pkl") + return load_pickle_data(filename_val) + + def init_single_generator(self, sampling_rate, event_metadata_index_train, event_key_train, + event_metadata_index_val, event_key_val, decimate_train): + """ + Initialize the single-station model and its data generators for training and validation. + Args: + sampling_rate (float): Sampling rate of the seismic data. + event_metadata_index_train (list): Indices for training events in the metadata. + event_key_train (str): Key for selecting the training event data. + event_metadata_index_val (list): Indices for validation events in the metadata. + event_key_val (str): Key for selecting the validation event data. + decimate_train (bool): Whether to decimate the training data. + """ + self.single_station_model = SingleStationModel(output_mlp_dims=self.cfg['model']['output_mlp_dims'], + use_mlp=self.cfg['model']['use_mlp']) + noise_seconds = self.generator_params[0].get('noise_seconds', 5) + cutout = (sampling_rate * (noise_seconds + self.generator_params[0]['cutout_start']), + sampling_rate * (noise_seconds + self.generator_params[0]['cutout_end'])) + self.single_train_generator = DataGenerator(self.training_params['data_path'], + event_metadata_index_train, event_key_train, + mag_key=self.generator_params[0]['key'], + batch_size=self.generator_params[0]['batch_size'], + cutout=cutout, + label_smoothing=True, + sliding_window=self.generator_params[0].get('sliding_window', + False), + decimate=decimate_train) + self.single_validation_generator = DataGenerator(self.training_params['data_path'], + event_metadata_index_val, event_key_val, + mag_key=self.generator_params[0]['key'], + batch_size=self.generator_params[0]['batch_size'], + cutout=cutout, + label_smoothing=True, + sliding_window=self.generator_params[0].get('sliding_window', + False), + decimate=decimate_train) + optimizer_single = nn.Adam(self.single_station_model.trainable_params(), learning_rate=1e-4) + self.criterion_single_mse = nn.MSELoss() + + loss_net = CustomWithLossCell(self.single_station_model, self.criterion_single_mse) + self.single_train_network = nn.TrainOneStepCell(loss_net, optimizer_single) + + self.single_station_model.set_train(True) + + def single_station_train(self, sampling_rate, event_metadata_index_train, event_key_train, + event_metadata_index_val, event_key_val, decimate_train): + """ + Train the single-station model. Loads a pre-trained model if specified, otherwise + initializes the generator and trains from scratch. + Args: + sampling_rate (float): Sampling rate of the seismic data. 
+ event_metadata_index_train (list): Indices for training events in the metadata. + event_key_train (str): Key for selecting the training event data. + event_metadata_index_val (list): Indices for validation events in the metadata. + event_key_val (str): Key for selecting the validation event data. + decimate_train (bool): Whether to decimate the training data. + """ + if 'single_station_model_path' in self.training_params: + print('Loading single station model') + param_dict = ms.load_checkpoint(self.training_params['single_station_model_path']) + ms.load_param_into_net(self.single_station_model, param_dict) + elif 'transfer_model_path' not in self.training_params: + self.init_single_generator(sampling_rate, event_metadata_index_train, event_key_train, + event_metadata_index_val, event_key_val, decimate_train) + + for epoch in tqdm(range(self.training_params['epochs_single_station']), + desc='training single station model'): + train_loss = 0.0 + + for i in range(len(self.single_train_generator)): + x, y = self.single_train_generator[i] + loss = self.single_train_network(x, y) + train_loss += loss.asnumpy() + + train_loss /= len(self.single_train_generator) + + val_loss = 0.0 + for i in range(len(self.single_validation_generator)): + x, y = self.single_validation_generator[i] + outputs = self.single_station_model(x) + loss = self.criterion_single_mse(y.squeeze(2), outputs) + val_loss += loss.item() + + val_loss /= len(self.single_validation_generator) + + print(f'Epoch {epoch + 1}/{self.training_params["epochs_single_station"]}, ' + f'Training Loss: {train_loss}, Validation Loss: {val_loss}') + + ms.save_checkpoint(self.single_station_model, + os.path.join(self.output_dir, f'single-station-{epoch + 1}')) + + def init_full_generator(self, sampling_rate, event_key_train, data_train, event_metadata_train, + max_stations, event_key_val, data_val, event_metadata_val): + """ + Initialize the full model's data generators and optimizer. + Args: + sampling_rate (float): Sampling rate of the seismic data. + event_key_train (str): Key for selecting the training event data. + data_train: Training data. + event_metadata_train: Metadata for training events. + max_stations (int): Maximum number of stations to consider. + event_key_val (str): Key for selecting the validation event data. + data_val: Validation data. + event_metadata_val: Metadata for validation events. 
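+
+        Illustrative ``training_params`` snippet consumed by this method (keys
+        mirror the code below; the values are placeholders rather than shipped
+        defaults):
+
+            training_params = {
+                'data_path': '/path/to/events.hdf5',
+                'loss_weights': {'magnitude': 0.3, 'location': 1.0, 'pga': 1.0},
+                'generator_params': [{'key': 'MA', 'batch_size': 64,
+                                      'noise_seconds': 5, 'cutout_start': 0,
+                                      'cutout_end': 25, 'oversample': 1}],
+            }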
+ """ + if 'load_model_path' in self.training_params: + print('Loading full model') + param_dict = ms.load_checkpoint(self.training_params['load_model_path']) + ms.load_param_into_net(self.full_model, param_dict) + + n_pga_targets = self.cfg['model'].get('n_pga_targets', 0) + no_event_token = self.cfg['model'].get('no_event_token', False) + + self.optimizer_full = nn.Adam(self.full_model.trainable_params(), learning_rate=1e-4) + self.losses_full_mse = {'magnitude': nn.MSELoss(), 'location': nn.MSELoss(), 'pga': nn.MSELoss()} + + generator_param_set = self.generator_params[0] + noise_seconds = generator_param_set.get('noise_seconds', 5) + cutout = (sampling_rate * (noise_seconds + generator_param_set['cutout_start']), + sampling_rate * (noise_seconds + generator_param_set['cutout_end'])) + + generator_param_set['transform_target_only'] = generator_param_set.get('transform_target_only', True) + + if 'data_path' in generator_param_set: + del generator_param_set['data_path'] + + self.full_train_generator = PreloadedEventGenerator(self.training_params['data_path'], + event_key_train, + data_train, + event_metadata_train, + waveform_shape=self.waveform_shape, + coords_target=True, + label_smoothing=True, + station_blinding=True, + cutout=cutout, + pga_targets=n_pga_targets, + max_stations=max_stations, + sampling_rate=sampling_rate, + no_event_token=no_event_token, + **generator_param_set) + + old_oversample = generator_param_set.get('oversample', 1) + generator_param_set['oversample'] = 4 + + self.full_validation_generator = PreloadedEventGenerator(self.training_params['data_path'], + event_key_val, + data_val, + event_metadata_val, + waveform_shape=self.waveform_shape, + coords_target=True, + station_blinding=True, + cutout=cutout, + pga_targets=n_pga_targets, + max_stations=max_stations, + sampling_rate=sampling_rate, + no_event_token=no_event_token, + **generator_param_set) + + generator_param_set['oversample'] = old_oversample + print('len(full_train_generator)', len(self.full_train_generator)) + + self.loss_weights = self.training_params['loss_weights'] + print(f'The total number of parameters: {sum(p.numel() for p in self.full_model.trainable_params())}') + + def full_station_train(self, sampling_rate, event_key_train, data_train, event_metadata_train, + max_stations, event_key_val, data_val, event_metadata_val): + """ + Train the full station model using the initialized generators and optimizer. 
+ + Args: + sampling_rate (float): Sampling rate of the seismic data + event_key_train (str): Key for selecting training event data + data_train: Training data + event_metadata_train: Training event metadata + max_stations (int): Maximum number of stations to consider + event_key_val (str): Key for selecting validation event data + data_val: Validation data + event_metadata_val: Validation event metadata + """ + self.init_full_generator(sampling_rate, event_key_train, data_train, event_metadata_train, + max_stations, event_key_val, data_val, event_metadata_val) + def calculate_total_loss(network, x, y): + train_mag_loss = 0 + train_loc_loss = 0 + train_pga_loss = 0 + outputs = network(x[0], x[1], x[2]) + total_loss = 0 + for k, loss_fn in self.losses_full_mse.items(): + if k == 'magnitude': + mag_pre = outputs[0] + mag_target = y[0] + mag_loss = loss_fn(mag_target.squeeze(2), mag_pre) * self.loss_weights[k] + train_mag_loss += mag_loss + total_loss += mag_loss + elif k == 'location': + loc_pre = outputs[1] + loc_target = y[1] + loc_loss = loss_fn(loc_target.squeeze(2), loc_pre) * self.loss_weights[k] + train_loc_loss += loc_loss + total_loss += loc_loss + elif k == 'pga': + pga_pre = outputs[2] + if 'italy' in self.file_basename: + pga_target = y[2] + else: + pga_target = ops.log(ops.abs(y[2])) + pga_loss = loss_fn(pga_target.squeeze(3), pga_pre) * self.loss_weights[k] + train_pga_loss += pga_loss + total_loss += pga_loss + return total_loss + + self.full_model.set_train() + grad_fn = ms.value_and_grad( + fn=calculate_total_loss, + grad_position=None, + weights=self.full_model.trainable_params(), + has_aux=False + ) + for epoch in tqdm(range(self.training_params['epochs_full_model']), desc='training full model'): + train_loss = 0 + + for i in range(len(self.full_train_generator)): + x, y = self.full_train_generator[i] + + total_loss, grads = grad_fn(self.full_model, x, y) + self.optimizer_full(grads) + + train_loss += total_loss.item() + avg_train_loss = train_loss / len(self.full_train_generator) + + avg_val_loss = evaluation(self.full_model, self.full_validation_generator, + self.losses_full_mse, self.loss_weights) + + print(f'Epoch {epoch + 1}/{self.training_params["epochs_full_model"]}', + f'Average Training Loss: {avg_train_loss}', f'Average val Loss: {avg_val_loss}') + + ms.save_checkpoint(self.full_model, os.path.join(self.output_dir, f'event-{epoch + 1}')) + + print('Training complete, and loss history saved.') + + def train(self): + """ + Train the full model for earthquake detection and localization. + + This method orchestrates the training process by: + 1. Setting the random seed for reproducibility. + 2. Loading training and validation datasets. + 3. Extracting key parameters like sampling rate and event metadata. + 4. Training single-station models for each station in the dataset. + 5. Training the full multi-station model using the pre-trained single-station models. + + Steps: + - Initialize random seed from configuration (default: 42) + - Load training data and extract metadata + - Load validation data + - Extract sampling rate and remove 'max_stations' from model config + - Train single-station models using training and validation data + - Train full model using combined data from all stations + + Note: This method assumes that the `single_station_train` and `full_station_train` methods are implemented. 
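+
+        A minimal driver sketch (assumes ``cfg`` holds the parsed configuration
+        and ``init_model``/``get_logger`` are taken from ``src.utils``):
+
+            model = init_model(cfg)
+            logger = get_logger(cfg)
+            GTeamTrain(model, cfg, cfg["summary"]["summary_dir"], logger).train()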
+ """ + seed_np_tf(self.cfg['training_params'].get('seed', 42)) + + print('Loading data') + (event_metadata_index_train, event_metadata_train, metadata_train, + data_train, event_key_train, decimate_train) = self.load_train_data() + (event_metadata_index_val, event_metadata_val, _, + data_val, event_key_val, _) = self.load_val_data() + + sampling_rate = metadata_train['sampling_rate'] + max_stations = self.cfg['model']['max_stations'] + del self.cfg['model']['max_stations'] + + print('training') + self.single_station_train(sampling_rate, event_metadata_index_train, event_key_train, + event_metadata_index_val, event_key_val, decimate_train) + + self.full_station_train(sampling_rate, event_key_train, data_train, event_metadata_train, + max_stations, event_key_val, data_val, event_metadata_val) diff --git a/MindEarth/applications/earthquake/G-TEAM/src/models.py b/MindEarth/applications/earthquake/G-TEAM/src/models.py new file mode 100644 index 0000000000000000000000000000000000000000..57a23019794223ba2af95741a36521a98d19d706 --- /dev/null +++ b/MindEarth/applications/earthquake/G-TEAM/src/models.py @@ -0,0 +1,470 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"GTeam model" +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops + + +class MLP(nn.Cell): + """ + A Multi-Layer Perceptron (MLP) class using MindSpore's nn.Cell. + Parameters: + input_shape: Tuple representing the shape of the input data. + dims: Tuple containing the dimensions of each layer. Default is (100, 50). + final_activation: The activation function for the final layer. Default is nn.ReLU. + """ + + def __init__(self, input_shape, dims=(100, 50), final_activation=nn.ReLU(), is_mlp=False): + super().__init__() + layers = [] + in_dim = input_shape[0] + if is_mlp: + for dim in dims[:-1]: + layers.append(nn.Dense(in_dim, dim)) + layers.append(nn.LayerNorm((dim,))) + layers.append(nn.ReLU()) + in_dim = dim + layers.append(nn.Dense(in_dim, dims[-1])) + + if final_activation: + layers.append(final_activation) + self.model = nn.SequentialCell(*layers) + else: + for dim in dims[:-1]: + layers.append(nn.Dense(in_dim, dim)) + layers.append(nn.ReLU()) + in_dim = dim + layers.append(nn.Dense(in_dim, dims[-1])) + + if final_activation: + layers.append(final_activation) + self.model = nn.SequentialCell(*layers) + + def construct(self, x): + """ + Forward pass through the network. + Parameters: + x: Input data to the MLP. + Returns: + Output after passing through the MLP. + """ + return self.model(x) + + +class NormalizedScaleEmbedding(nn.Cell): + """ + A neural network module that normalizes input data, extracts features using a series of + convolutional and pooling layers, and processes the features through a multi-layer perceptron (MLP). 
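+
+    Rough shape walk-through (sizes are indicative for 3000-sample,
+    3-component inputs): a batch of traces with shape ``(N, 3000, 3)`` is
+    peak-normalized per trace, passed through the Conv2d/Conv1d/MaxPool stack,
+    flattened to 864 features, concatenated with one log-amplitude scale
+    feature to form the 865-dimensional MLP input, and projected to an
+    ``mlp_dims[-1]``-dimensional embedding.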
+ """ + + def __init__(self, downsample=5, mlp_dims=(500, 300, 200, 150), eps=1e-8, use_mlp=False): + """ + Initialize the module with given parameters. + Parameters: + :downsample: Downsampling factor for the first convolutional layer. + :mlp_dims: Dimensions for the MLP layers. + :eps: A small value for numerical stability. + """ + super().__init__() + self.downsample = downsample + self.mlp_dims = mlp_dims + self.eps = eps + + self.conv2d_1 = nn.Conv2d( + 1, + 8, + kernel_size=(downsample, 1), + stride=(downsample, 1), + has_bias=True, + pad_mode="pad", + ) + self.conv2d_2 = nn.Conv2d( + 8, 32, kernel_size=(16, 3), stride=(1, 1), has_bias=True, pad_mode="pad" + ) + + self.conv1d_1 = nn.Conv1d(32, 64, kernel_size=16, has_bias=True, pad_mode="pad") + self.maxpool_1 = nn.MaxPool1d(kernel_size=2, stride=2) + self.conv1d_2 = nn.Conv1d( + 64, 128, kernel_size=16, has_bias=True, pad_mode="pad" + ) + self.maxpool_2 = nn.MaxPool1d(kernel_size=2, stride=2) + self.conv1d_3 = nn.Conv1d(128, 32, kernel_size=8, has_bias=True, pad_mode="pad") + self.maxpool_3 = nn.MaxPool1d(kernel_size=2, stride=2) + self.conv1d_4 = nn.Conv1d(32, 32, kernel_size=8, has_bias=True, pad_mode="pad") + self.conv1d_5 = nn.Conv1d(32, 16, kernel_size=4, has_bias=True, pad_mode="pad") + + self.flatten = nn.Flatten() + self.mlp = MLP((865,), dims=self.mlp_dims, is_mlp=use_mlp) + self.leaky_relu = nn.LeakyReLU(alpha=0.01) + self._initialize_weights() + + def _initialize_weights(self): + self.conv2d_1.bias.set_data(ms.numpy.zeros_like(self.conv2d_1.bias)) + self.conv2d_2.bias.set_data(ms.numpy.zeros_like(self.conv2d_2.bias)) + + # For Conv1d layers + self.conv1d_1.bias.set_data(ms.numpy.zeros_like(self.conv1d_1.bias)) + self.conv1d_2.bias.set_data(ms.numpy.zeros_like(self.conv1d_2.bias)) + self.conv1d_3.bias.set_data(ms.numpy.zeros_like(self.conv1d_3.bias)) + self.conv1d_4.bias.set_data(ms.numpy.zeros_like(self.conv1d_4.bias)) + self.conv1d_5.bias.set_data(ms.numpy.zeros_like(self.conv1d_5.bias)) + + def construct(self, x): + """ + Forward pass through the network. + :param x: Input tensor. + :return: Processed output tensor. + """ + original_input = x + x = ( + x + / ( + ops.max( + ops.max(ops.abs(x), axis=1, keepdims=True)[0], axis=2, keepdims=True + )[0] + + self.eps + ) + + self.eps + ) + x = ops.unsqueeze(x, dim=1) + + scale = ( + ops.log( + ops.max(ops.max(ops.abs(original_input), axis=1)[0], axis=1)[0] + + self.eps + ) + / 100 + + self.eps + ) + scale = ops.unsqueeze(scale, dim=1) + + x = self.leaky_relu(self.conv2d_1(x)) + x = self.leaky_relu(self.conv2d_2(x)) + + tmp_x = ops.Squeeze(axis=-1) + x = tmp_x(x) + x = self.leaky_relu(self.conv1d_1(x)) + x = self.maxpool_1(x) + x = self.leaky_relu(self.conv1d_2(x)) + x = self.maxpool_2(x) + x = self.leaky_relu(self.conv1d_3(x)) + x = self.maxpool_3(x) + x = self.leaky_relu(self.conv1d_4(x)) + x = self.leaky_relu(self.conv1d_5(x)) + + x = self.flatten(x) + x = ops.cat((x, scale), axis=1) + x = self.mlp(x) + return x + + +class TransformerEncoder(nn.Cell): + """ + TransformerEncoder class, used to implement the Transformer encoder. + Parameters: + d_model: Dimension of the input data. + nhead: Number of heads in multi-head attention. + num_layers: Number of layers in the encoder. + batch_first: Whether to consider the first dimension of the input data as the batch dimension. + activation: Type of activation function. + dim_feedforward: Dimension of the hidden layer in the feedforward network. + dropout: Proportion of dropout. 
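+
+    Note on usage (shapes are illustrative): with ``batch_first=True`` the
+    encoder expects ``(batch, seq_len, d_model)`` inputs, e.g.
+    ``(batch, 1 + stations + n_pga_targets, 500)`` in ``WaveformFullmodel``,
+    and an optional ``src_key_padding_mask`` of shape ``(batch, seq_len)`` in
+    which ``True`` marks padded positions to be ignored as attention keys.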
+ Methods: + __init__: Initialize the TransformerEncoder object. + construct: Construct the TransformerEncoder network. + """ + + def __init__( + self, + d_model=500, + nhead=10, + num_layers=6, + batch_first=True, + activation="gelu", + dim_feedforward=1000, + dropout=0.0, + ): + super().__init__() + self.encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + batch_first=batch_first, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + ) + self.transformer_encoder = nn.TransformerEncoder( + self.encoder_layer, num_layers=num_layers + ) + + def construct(self, x, src_key_padding_mask=None): + """Construct the TransformerEncoder network""" + return self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask) + + +class PositionEmbedding(nn.Cell): + """ + PositionEmbedding class, used to implement position embeddings. + Parameters: + wavelengths: Range of wavelengths. + emb_dim: Dimension of the embeddings. + Methods: + __init__: Initialize the PositionEmbedding object. + construct: Construct the PositionEmbedding network. + """ + + def __init__(self, wavelengths, emb_dim): + super().__init__() + self.wavelengths = wavelengths + self.emb_dim = emb_dim + + min_lat, max_lat = wavelengths[0] + min_lon, max_lon = wavelengths[1] + min_depth, max_depth = wavelengths[2] + if emb_dim % 10 != 0: + raise ValueError(f"emb_dim must be divisible by 10, but got {emb_dim}") + lat_dim = emb_dim // 5 + lon_dim = emb_dim // 5 + depth_dim = emb_dim // 10 + self.lat_coeff = ( + 2 + * np.pi + * 1.0 + / min_lat + * ((min_lat / max_lat) ** (np.arange(lat_dim) / lat_dim)) + ) + self.lon_coeff = ( + 2 + * np.pi + * 1.0 + / min_lon + * ((min_lon / max_lon) ** (np.arange(lon_dim) / lon_dim)) + ) + self.depth_coeff = ( + 2 + * np.pi + * 1.0 + / min_depth + * ((min_depth / max_depth) ** (np.arange(depth_dim) / depth_dim)) + ) + lat_sin_mask = np.arange(emb_dim) % 5 == 0 + lat_cos_mask = np.arange(emb_dim) % 5 == 1 + lon_sin_mask = np.arange(emb_dim) % 5 == 2 + lon_cos_mask = np.arange(emb_dim) % 5 == 3 + + depth_sin_mask = np.arange(emb_dim) % 10 == 4 + depth_cos_mask = np.arange(emb_dim) % 10 == 9 + + self.mask = np.zeros(emb_dim) + self.mask[lat_sin_mask] = np.arange(lat_dim) + self.mask[lat_cos_mask] = lat_dim + np.arange(lat_dim) + self.mask[lon_sin_mask] = 2 * lat_dim + np.arange(lon_dim) + self.mask[lon_cos_mask] = 2 * lat_dim + lon_dim + np.arange(lon_dim) + self.mask[depth_sin_mask] = 2 * lat_dim + 2 * lon_dim + np.arange(depth_dim) + self.mask[depth_cos_mask] = ( + 2 * lat_dim + 2 * lon_dim + depth_dim + np.arange(depth_dim) + ) + self.mask = ms.tensor(self.mask.astype("int32")) + + def construct(self, x): + """position embedding""" + lat_base = x[:, :, 0:1] * ms.tensor(self.lat_coeff, dtype=ms.float32) + lon_base = x[:, :, 1:2] * ms.tensor(self.lon_coeff, dtype=ms.float32) + depth_base = x[:, :, 2:3] * ms.tensor(self.depth_coeff, dtype=ms.float32) + + output = ops.cat( + [ + ops.sin(lat_base), + ops.cos(lat_base), + ops.sin(lon_base), + ops.cos(lon_base), + ops.sin(depth_base), + ops.cos(depth_base), + ], + axis=-1, + ) + output = ops.index_select(output, axis=-1, index=self.mask) + + return output + + +class AddEventToken(nn.Cell): + """ + AddEventToken class, used to implement adding event tokens. + + Parameters: + emb_dim: Dimension of the embeddings. + init_range: Initialization range. + + Methods: + __init__: Initialize the AddEventToken object. + construct: Construct the AddEventToken network. 
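+
+    Shape sketch (sizes illustrative): for station embeddings of shape
+    ``(batch, stations, emb_dim)``, the single learned token of shape
+    ``(1, 1, emb_dim)`` is broadcast and prepended along the sequence axis,
+    giving ``(batch, stations + 1, emb_dim)``; downstream, position 0 of the
+    transformer output is read out for the magnitude and location heads.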
+ """ + + def __init__(self, emb_dim, init_range): + super().__init__() + self.emb_dim = emb_dim + init_value = np.random.uniform(-init_range, init_range, (1, 1, emb_dim)).astype( + np.float32 + ) + self.event_token = ms.Parameter(ms.Tensor(init_value), name="event_token") + + def construct(self, x): + """add eventtoken""" + event_token = self.event_token + pad = ops.ones_like(x[:, :1, :]) * event_token + x = ops.cat([pad, x], axis=1) + + return x + +class SingleStationModel(nn.Cell): + """ + A neural network model for processing seismic waveforms from a single station. + This class implements a two-stage processing pipeline: waveform embedding followed by feature extraction. + """ + def __init__(self, waveform_model_dims=(500, 500, 500), + output_mlp_dims=(150, 100, 50, 30, 10), downsample=5, use_mlp=False): + """ + Initialize the SingleStationModel. + + Args: + waveform_model_dims (tuple): Dimensions of the MLP in the waveform embedding module. + Format: (input_dim, hidden_dim1, hidden_dim2, ...) + output_mlp_dims (tuple): Dimensions of the final MLP for feature extraction. + Format: (input_dim, hidden_dim1, hidden_dim2, ...) + downsample (int): Factor by which to downsample the input waveform data. + """ + super().__init__() + + self.waveform_model = NormalizedScaleEmbedding(downsample=downsample, mlp_dims=waveform_model_dims, + use_mlp=use_mlp) + self.mlp_mag_single_station = MLP((self.waveform_model.mlp_dims[-1],), output_mlp_dims) + + def construct(self, x): + """ + Forward pass of the SingleStationModel. + + Args: + x (Tensor): Input waveform data with shape (batch_size, time_steps, features) + + Returns: + Tensor: Extracted features with shape (batch_size, output_features) + """ + emb = self.waveform_model(x) + emb_mlp = self.mlp_mag_single_station(emb) + + return emb_mlp +def _init_pad_mask(waveforms, pga_targets): + """ + _init_pad_mask function, used to initialize the padding mask. + """ + station_pad_mask = abs(waveforms) < 1e-8 + station_pad_mask = ops.all(station_pad_mask, axis=2) + station_pad_mask = ops.all(station_pad_mask, axis=2) + + event_token_mask = ops.zeros((station_pad_mask.shape[0], 1), dtype=ms.dtype.bool_) + pad_mask = ops.cat([event_token_mask, station_pad_mask], axis=1) + + target_pad_mask = ms.numpy.ones_like(pga_targets, dtype=ms.dtype.bool_) + target_pad_mask = ops.all(target_pad_mask, 2) + + pad_mask = ops.cat((pad_mask, target_pad_mask), axis=1) + + return pad_mask + + +class WaveformFullmodel(nn.Cell): + """ + Waveform full model class, used for processing and predicting waveform data." 
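+
+    A minimal forward-call sketch (all sizes are illustrative and the random
+    inputs only demonstrate the expected layout):
+
+        model = WaveformFullmodel(n_pga_targets=20)
+        waveforms = ms.Tensor(np.random.rand(2, 25, 3000, 3), ms.float32)  # (batch, stations, samples, channels)
+        metadata = ms.Tensor(np.random.rand(2, 25, 3), ms.float32)         # station coordinates
+        pga_targets = ms.Tensor(np.random.rand(2, 20, 3), ms.float32)      # PGA target coordinates
+        mag, loc, pga = model(waveforms, metadata, pga_targets)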
+ """ + + def __init__( + self, + waveform_model_dims=(500, 500, 500), + output_mlp_dims=(150, 100, 50, 30, 10), + output_location_dims=(150, 100, 50, 50, 50), + wavelength=((0.01, 10), (0.01, 10), (0.01, 10)), + n_heads=10, + hidden_dim=1000, + transformer_layers=6, + hidden_dropout=0.0, + n_pga_targets=0, + downsample=5, + use_mlp=False + ): + super().__init__() + self.waveform_model = NormalizedScaleEmbedding( + downsample=downsample, mlp_dims=waveform_model_dims, use_mlp=use_mlp + ) + self.transformer = TransformerEncoder( + d_model=waveform_model_dims[-1], + nhead=n_heads, + num_layers=transformer_layers, + dim_feedforward=hidden_dim, + dropout=hidden_dropout, + ) + + self.mlp_mag = MLP((waveform_model_dims[-1],), output_mlp_dims, is_mlp=use_mlp) + self.mlp_loc = MLP( + (waveform_model_dims[-1],), output_location_dims, final_activation=None, is_mlp=use_mlp + ) + self.mlp_pga = MLP( + (waveform_model_dims[-1],), output_mlp_dims, final_activation=None, is_mlp=use_mlp + ) + + self.position_embedding = PositionEmbedding( + wavelengths=wavelength, emb_dim=waveform_model_dims[-1] + ) + self.addeventtoken = AddEventToken(emb_dim=500, init_range=0.02) + self.n_pga_targets = n_pga_targets + + def cal_waveforms_emb_normalized(self, waveforms_emb): + """Normalize the waveform embeddings""" + mean_vals = waveforms_emb.mean(axis=2, keep_dims=True) + std_vals = waveforms_emb.std(axis=2, keepdims=True) + waveforms_emb_normalized = (waveforms_emb - mean_vals) / (std_vals + 1e-8) + return waveforms_emb_normalized + + def construct(self, waveforms, metadata, pga_targets): + """ + Construct method to process the input waveforms, metadata, and PGA targets. + """ + batch_size, num_stations, seq_length, num_channels = waveforms.shape + waveforms_reshape = waveforms.reshape(-1, seq_length, num_channels) + + waveforms_emb = self.waveform_model(waveforms_reshape) + waveforms_emb = waveforms_emb.reshape(batch_size, num_stations, -1) + waveforms_emb_normalized = self.cal_waveforms_emb_normalized(waveforms_emb) + coords_emb = self.position_embedding(metadata) + pga_target_emb = self.position_embedding(pga_targets) + pad_mask = _init_pad_mask(waveforms, pga_targets) + + emb_pos = waveforms_emb_normalized + coords_emb + emb_pos = self.addeventtoken(emb_pos) + emb_pos_pga = ops.cat((emb_pos, pga_target_emb), axis=1) + emb_pos_pga_trans = self.transformer(emb_pos_pga, pad_mask) + emb_pga = emb_pos_pga_trans[:, -self.n_pga_targets :, :] + emb_mag_loc = emb_pos_pga_trans[:, 0, :] + + mag = self.mlp_mag(emb_mag_loc) + loc = self.mlp_loc(emb_mag_loc) + + pga_all = self.mlp_pga(emb_pga) + outputs = [mag, loc, pga_all] + + return outputs diff --git a/MindEarth/applications/earthquake/G-TEAM/src/utils.py b/MindEarth/applications/earthquake/G-TEAM/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ae9cabf85584f8adcb6d33eb1eaf4acec0cffc --- /dev/null +++ b/MindEarth/applications/earthquake/G-TEAM/src/utils.py @@ -0,0 +1,219 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GTeam util""" +import os +import copy +import numpy as np +import sklearn.metrics as metrics +from geopy.distance import geodesic + +import mindspore as ms +import mindspore.ops as ops +from mindearth import create_logger + +from src import data +from src.data import generator_from_config, D2KM +from src.models import WaveformFullmodel + +def predict_at_time( + model, + time, + data_data, + data_path, + event_key, + event_metadata, + config, + sampling_rate=100, + pga=False, +): + """Predict at a specific time point""" + generator = generator_from_config( + config, + data_data, + data_path, + event_key, + event_metadata, + time, + sampling_rate=sampling_rate, + pga=pga, + ) + + pred_list_mag = [] + pred_list_loc = [] + pred_list_pga = [] + for i in range(len(generator)): + x, _ = generator[i] + + pred = model(x[0], x[1], x[2]) + pred_list_mag.append(pred[0]) + pred_list_loc.append(pred[1]) + pred_list_pga.append(pred[2]) + + pre_mag = ops.cat(pred_list_mag, axis=0) + pre_loc = ops.cat(pred_list_loc, axis=0) + pre_pga = ops.cat(pred_list_pga, axis=0) + predictions = [pre_mag, pre_loc, pre_pga] + + mag_pred_filter = [] + loc_pred_filter = [] + pga_pred_filter = [] + + for i, (start, end) in enumerate(zip(generator.reverse_index[:-1], generator.reverse_index[1:])): + sample_mag_pred = predictions[0][start:end].reshape((-1,) + predictions[0].shape[-1:]) + sample_mag_pred = sample_mag_pred[:len(generator.pga[i])] + mag_pred_filter += [sample_mag_pred] + + sample_loc_pred = predictions[1][start:end].reshape((-1,) + predictions[1].shape[-1:]) + sample_loc_pred = sample_loc_pred[:len(generator.pga[i])] + loc_pred_filter += [sample_loc_pred] + + sample_pga_pred = predictions[2][start:end].reshape((-1,) + predictions[2].shape[-1:]) + sample_pga_pred = sample_pga_pred[:len(generator.pga[i])] + pga_pred_filter += [sample_pga_pred] + + preds = [mag_pred_filter, loc_pred_filter, pga_pred_filter] + + return preds + +def calc_mag_stats(mag_pred, event_metadata, key): + """Calculate statistical information for magnitude predictions""" + mean_mag = mag_pred + true_mag = event_metadata[key].values + # R^2 + r2 = metrics.r2_score(true_mag, mean_mag) + # RMSE + rmse = np.sqrt(metrics.mean_squared_error(true_mag, mean_mag)) + # MAE + mae = metrics.mean_absolute_error(true_mag, mean_mag) + return r2, rmse, mae + +def calc_pga_stats(pga_pred, pga_true, suffix=""): + """Calculate statistical information for PGA predictions""" + if suffix: + suffix += "_" + valid_mask = np.isfinite(pga_true) & np.isfinite(pga_pred) + pga_true_clean = pga_true[valid_mask] + pga_pred_clean = pga_pred[valid_mask] + r2 = metrics.r2_score(pga_true_clean, pga_pred_clean) + rmse = np.sqrt(metrics.mean_squared_error(pga_true_clean, pga_pred_clean)) + mae = metrics.mean_absolute_error(pga_true_clean, pga_pred_clean) + + return [r2, rmse, mae] + +def calc_loc_stats(loc_pred, event_metadata, pos_offset): + """Calculate statistical information for location predictions""" + coord_keys = data.detect_location_keys(event_metadata.columns) + true_coords = event_metadata[coord_keys].values + mean_coords = loc_pred + mean_coords *= 100 + mean_coords[:, :2] /= D2KM + mean_coords[:, 0] += pos_offset[0] + mean_coords[:, 1] += pos_offset[1] + + dist_epi = np.zeros(len(mean_coords)) + dist_hypo = np.zeros(len(mean_coords)) + real_dep = np.zeros(len(mean_coords)) + 
pred_dep = np.zeros(len(mean_coords)) + for i, (pred_coord, true_coord) in enumerate(zip(mean_coords, true_coords)): + dist_epi[i] = geodesic(pred_coord[:2], true_coord[:2]).km + dist_hypo[i] = np.sqrt(dist_epi[i] ** 2 + (pred_coord[2] - true_coord[2]) ** 2) + real_dep[i] = true_coord[2] + pred_dep[i] = pred_coord[2] + + rmse_epi = np.sqrt(np.mean(dist_epi**2)) + mae_epi = np.mean(np.abs(dist_epi)) + + rmse_hypo = np.sqrt(np.mean(dist_hypo**2)) + mae_hypo = np.mean(dist_hypo) + + return rmse_hypo, mae_hypo, rmse_epi, mae_epi + + +def seed_np_tf(seed): + '''Set the random seed for numpy and manual seed for mindspore.''' + np.random.seed(seed) + ms.manual_seed(seed) + + +def evaluation(full_model, val_generator, losses, loss_weights): + """ + Evaluates the performance of the full_model on the validation data provided by val_generator. + Calculates the average validation loss by accumulating losses from different components (magnitude, location, pga) + using the specified loss functions and weights. + Args: + full_model (nn.Cell): The complete model to be evaluated in inference mode. + val_generator (generator): A generator that yields batches of validation data (x, y). + Each x is expected to be a tuple of three input tensors, and y is a tuple of three target tensors. + losses (dict): A dictionary mapping loss names to their respective loss functions. + Supported keys: 'magnitude', 'location', 'pga'. + loss_weights (dict): A dictionary mapping loss names to their corresponding weights. + Returns: + float: The average validation loss over the entire validation dataset. + """ + full_model.set_train(False) + epoch_val_loss = 0 + for i in range(len(val_generator)): + x, y = val_generator[i] + outputs = full_model(x[0], x[1], x[2]) + total_val_loss = ms.Tensor(0) + + for k, loss_fn in losses.items(): + if k == 'magnitude': + mag_pre = outputs[0] + mag_target = y[0] + mag_loss = loss_fn(mag_target.squeeze(2), mag_pre) * loss_weights[k] + total_val_loss += mag_loss + elif k == 'location': + loc_pre = outputs[1] + loc_target = y[1] + loc_loss = loss_fn(loc_target.squeeze(2), loc_pre) * loss_weights[k] + total_val_loss += loc_loss + elif k == 'pga': + pga_pre = outputs[2] + pga_target = ops.log(ops.abs(y[2])) + pga_loss = loss_fn(pga_target.squeeze(3), pga_pre) * loss_weights[k] + total_val_loss += pga_loss + epoch_val_loss += total_val_loss.item() + avg_val_loss = epoch_val_loss / len(val_generator) + return avg_val_loss +def init_model(arg): + """set model""" + tmpcfg = copy.deepcopy(arg["model"]) + tmpcfg.pop("istraining") + tmpcfg.pop("no_event_token") + tmpcfg.pop("run_with_less_data") + tmpcfg.pop("pga") + tmpcfg.pop("mode") + tmpcfg.pop("times") + tmpcfg.pop("max_stations") + model = WaveformFullmodel(**tmpcfg) + if arg['model']['istraining']: + model.set_train(True) + else: + param_dict = ms.load_checkpoint(arg["summary"].get("ckpt_path")) + ms.load_param_into_net(model, param_dict) + model.set_train(False) + return model + + +def get_logger(config): + """Get logger for saving log""" + summary_params = config.get("summary") + logger = create_logger( + path=os.path.join(summary_params.get("summary_dir"), "results.log") + ) + for key in config: + logger.info(config[key]) + return logger diff --git a/MindEarth/applications/earthquake/G-TEAM/src/visual.py b/MindEarth/applications/earthquake/G-TEAM/src/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..78354a6f8fb4eb2a0b40cc4b8672a004eafda847 --- /dev/null +++ b/MindEarth/applications/earthquake/G-TEAM/src/visual.py @@ 
-0,0 +1,78 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"visualization" +import os +import matplotlib.pyplot as plt +import numpy as np +import sklearn.metrics as metrics + +def generate_true_pred_plot(pred_values, true_values, time, path, suffix=""): + """ + Generate a plot comparing true values and predicted values, and calculate + evaluation metrics including MAE, RMSE, R^2, and the standard deviation of residuals. + Parameters: + pred_values: List of predicted values + true_values: List of true values + time: Time, used for naming the image + path: Path to save the image + suffix: Suffix for image naming, default is an empty string + """ + if suffix: + suffix += "_" + fig = plt.figure(figsize=(9, 9)) + plt.plot(true_values, pred_values, "ok", alpha=0.2) + pred_value = pred_values + pred_value = np.array([x for x in pred_value]) + r2 = metrics.r2_score(true_values, pred_value) + rmse = np.sqrt(metrics.mean_squared_error(true_values, pred_value)) + mae = metrics.mean_absolute_error(true_values, pred_value) + + plt.text( + 0.6, + 6, + f"MAE: {mae:.2f}\nRMSE: {rmse:.2f}\n$R^{2}$: {r2:.2f}", + fontsize=30, + verticalalignment="top", + horizontalalignment="left", + ) + plt.plot(np.arange(0, 8), np.arange(0, 8), "-r") + plt.xlim(0, 7) + plt.ylim(0, 7) + ax = plt.gca() + ax.set_xlabel("True values", fontsize=20) + ax.set_ylabel("Pred values", fontsize=20) + ax.set_title(str(time) + " s", fontsize=20) + fig.savefig(os.path.join(path, f"truepred_{suffix}{time}.png"), bbox_inches="tight") + plt.close() + + residual = true_values - pred_value + fig = plt.figure(figsize=(9, 9)) + axs = fig.subplots(1, 1) + axs.hist(residual) + axs.set_xlabel("residual", fontsize=25) + axs.set_ylabel("Event Number", fontsize=25) + x_lim = axs.get_xlim() + y_lim = axs.get_ylim() + plt.text( + x_lim[1] * 0.95, + y_lim[1] * 0.95, + f"MAE: {mae:.2f}\nRMSE: {rmse:.2f}\n$R^{{2}}$: {r2:.2f}\nSTD: {np.std(residual):.2f}", + fontsize=30, + verticalalignment="top", + horizontalalignment="right", + ) + + fig.savefig(os.path.join(path, f"Residual_{suffix}{time}.png"), bbox_inches="tight") + plt.close() diff --git a/MindEarth/applications/medium-range/koopman_vit/src/callback.py b/MindEarth/applications/medium-range/koopman_vit/src/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..66b4f932030b01012a799567348236142998f661 --- /dev/null +++ b/MindEarth/applications/medium-range/koopman_vit/src/callback.py @@ -0,0 +1,212 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The callback of koopman_vit""" + +import mindspore.numpy as msnp +from mindspore import nn, ops, Tensor +from mindspore.train.callback import Callback + +from mindearth.module import WeatherForecast + +from .utils import plt_key_info + + + +class CustomWithLossCell(nn.Cell): + r""" + CustomWithLossCell is used to Connect the feedforward network and multi-label loss function. + + Args: + backbone: a feedforward neural network + loss_fn: a multi-label loss function + + Inputs: + - **data** (Tensor) - The input data of feedforward neural network. Tensor of any dimension. + - **label** (Tensor) - The input label. Tensor of any dimension. + + Outputs: + Tensor + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + """ + + def __init__(self, backbone, loss_fn): + super(CustomWithLossCell, self).__init__() + self._backbone = backbone + self._loss_fn = loss_fn + + def construct(self, data, label): + output, recons = self._backbone(data) + loss = self._loss_fn(output, recons, label, data) + return loss + + +class MultiMSELoss(nn.LossBase): + r""" + MultiMSELoss is used to calculate multiple MSELoss, then weighted summation. the MSEloss is used to calculate + the mean squared error between the predicted value and the label value. + + Inputs: + - **prediction1** (Tensor) - The predicted value of the input. Tensor of any dimension. + - **prediction2** (Tensor) - The predicted value of the input. Tensor of any dimension. + - **label1** (Tensor) - The input label. Tensor of any dimension. + - **label2** (Tensor) - The input label. Tensor of any dimension. + - **weight1** (Tensor) - The coefficient of l1. + - **weight2** (Tensor) - The coefficient of l2. 
+ + Outputs: + Tensor + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + # >>> import numpy as np + # >>> import mindspore + # >>> from mindspore import Tensor + # >>> from mindearth.loss import MultiMSELoss + # >>> # Case: prediction.shape = labels.shape = (3, 3) + # >>> prediction1 = Tensor(np.array([[1, 2, 3],[1, 2, 3],[1, 2, 3]]), mindspore.float32) + # >>> prediction2 = Tensor(np.array([[1, 2, 3],[1, 2, 3],[1, 2, 3]]), mindspore.float32) + # >>> label1 = Tensor(np.array([[1, 2, 2],[1, 2, 3],[1, 2, 3]]), mindspore.float32) + # >>> label2 = Tensor(np.array([[1, 2, 2],[1, 2, 3],[1, 2, 3]]), mindspore.float32) + # >>> loss_fn = MultiMSELoss() + # >>> loss = loss_fn(prediction1, prediction2, label1, label2) + # >>> print(loss) + 0.111111 + """ + + def __init__(self, ai, wj, sj_std, feature_dims, use_weight=False): + super(MultiMSELoss, self).__init__() + self.loss = nn.MSELoss() + self.wj = wj + self.ai = ai + self.sj_std = sj_std + self.feature_dims = feature_dims + self.use_weight = use_weight + + def construct(self, prediction1, prediction2, label1, label2, weight1=0.9, weight2=0.1): + """MultiMSELoss""" + prediction1 = prediction1.reshape(-1, self.feature_dims) + prediction2 = prediction2.reshape(-1, self.feature_dims) + label1 = label1.reshape(-1, self.feature_dims) + label2 = label2.reshape(-1, self.feature_dims) + + err1 = msnp.square(prediction1 - label1) + weighted_err1 = err1 * self.wj * self.ai / self.sj_std + l1 = msnp.average(weighted_err1) + + err2 = msnp.square(prediction2 - label2) + weighted_err2 = err2 * self.wj * self.ai / self.sj_std + l2 = msnp.average(weighted_err2) + return weight1 * l1 + weight2 * l2 + + +class InferenceModule(WeatherForecast): + """ + Perform multiple rounds of model inference. + + Args: + """ + + def __init__(self, model, config, logger): + super().__init__(model, config, logger) + self.model = model + self.config = config + self.logger = logger + + def forecast(self, inputs): + pred_lst = [] + for _ in range(self.t_out): + pred, _ = self.model(inputs) + pred_lst.append(pred.transpose(0, 2, 3, 1).reshape(self.batch_size, -1, self.feature_dims).asnumpy()) + inputs = pred + return pred_lst + + +class EvaluateCallBack(Callback): + """ + Monitor the prediction accuracy in training. + + Args: + """ + + def __init__(self, + model, + valid_dataset, + config, + logger + ): + super(EvaluateCallBack, self).__init__() + self.config = config + self.eval_time = 0 + self.model = model + self.valid_dataset = valid_dataset + self.predict_interval = config.get('summary').get("valid_frequency") + self.logger = logger + self.eval_net = InferenceModule(model, + config, + logger) + + def epoch_end(self, run_context): + """ + Evaluate the model at the end of epoch. + + Args: + run_context (RunContext): Context of the train running. 
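+
+        A wiring sketch (illustrative only; the surrounding trainer setup is
+        not part of this module): the callback is handed to the training loop,
+        e.g.
+
+            eval_cb = EvaluateCallBack(net, valid_dataset, config, logger)
+            trainer = mindspore.train.Model(CustomWithLossCell(net, loss_fn), optimizer=opt)
+            trainer.train(epochs, train_dataset, callbacks=[eval_cb])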
+ """ + cb_params = run_context.original_args() + if cb_params.cur_epoch_num % self.predict_interval == 0: + self.eval_time += 1 + lat_weight_rmse, lat_weight_acc = self.eval_net.eval(self.valid_dataset, generator_flag=True) + if self.config.get('summary').get('plt_key_info'): + plt_key_info(lat_weight_rmse, self.config, self.eval_time * self.predict_interval, metrics_type="RMSE", + loc="upper left") + plt_key_info(lat_weight_acc, self.config, self.eval_time * self.predict_interval, metrics_type="ACC", + loc="lower left") + + +class Lploss(nn.LossBase): + """Lploss""" + def __init__(self, p=2, size_average=True, reduction=True): + super(Lploss, self).__init__() + # Dimension and Lp-norm type are positive + if p <= 0: + raise ValueError(f"p must be positive, but got {p}") + self.p = p + self.reduction = reduction + self.size_average = size_average + + def loss(self, x, y): + """Get loss""" + num_examples = x.shape[0] + + diff_norms = ops.norm(x.reshape(num_examples, -1) - y.reshape(num_examples, -1), dim=1, ord=self.p) + y_norms = ops.norm(y.reshape(num_examples, -1), dim=1, ord=self.p) + + if self.reduction: + if self.size_average: + loss = ops.mean(diff_norms / y_norms) + else: + loss = Tensor.sum(diff_norms / y_norms) + else: + loss = diff_norms / y_norms + return loss + + def construct(self, prediction1, prediction2, label1, label2, weight1=0.8, weight2=0.2): + l1 = self.loss(prediction1, label1) + l2 = self.loss(prediction2, label2) + return weight1 * l1 + weight2 * l2 diff --git a/MindEarth/applications/nowcasting/PreDiff/src/diffusion/cuboid_transformer.py b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/cuboid_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..97e3e952bf51c43b3b7df15e7420c91f109707fc --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/cuboid_transformer.py @@ -0,0 +1,1078 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Substructure of cuboid_transformer_unet""" +from functools import lru_cache +from collections import OrderedDict + +import mindspore as ms +from mindspore import nn, ops, Parameter +from mindspore.common.initializer import initializer, TruncatedNormal + +from src.utils import ( + get_activation, + get_norm_layer, + generalize_padding, + generalize_unpadding, + apply_initialization, +) + + +class PosEmbed(nn.Cell): + """ + Spatiotemporal positional embedding layer combining temporal, height, and width embeddings. + """ + def __init__(self, embed_dim, max_t, max_h, max_w): + """ + Initialize positional embedding with separate temporal/spatial components. + Args: + embed_dim (int): Dimensionality of the embedding vectors. + maxT (int): Maximum temporal length (number of time steps). + maxH (int): Maximum height dimension size. + maxW (int): Maximum width dimension size. 
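+
+        Note:
+            The three embeddings are looked up independently in ``construct`` and broadcast-added to the
+            input along the T, H and W axes.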
+ """ + super().__init__() + self.embed_dim = embed_dim + # spatiotemporal learned positional embedding + self.t_embed = nn.Embedding(vocab_size=max_t, embedding_size=embed_dim) + self.h_embed = nn.Embedding(vocab_size=max_h, embedding_size=embed_dim) + self.w_embed = nn.Embedding(vocab_size=max_w, embedding_size=embed_dim) + self.reset_parameters() + + def reset_parameters(self): + for cell in self.cells(): + apply_initialization(cell, embed_mode="0") + + def construct(self, x): + """Forward pass of positional embedding. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, C) + + Returns: + Tensor: Output tensor with added positional embeddings + """ + + _, t, h, w, _ = x.shape + + t_idx = ops.arange(t) + h_idx = ops.arange(h) + w_idx = ops.arange(w) + return ( + x + + self.t_embed(t_idx).reshape(t, 1, 1, self.embed_dim) + + self.h_embed(h_idx).reshape(1, h, 1, self.embed_dim) + + self.w_embed(w_idx).reshape(1, 1, w, self.embed_dim) + ) + + +class PositionwiseFFN(nn.Cell): + """The Position-wise Feed-Forward Network layer used in Transformer architectures. + + This implements a two-layer MLP with optional gating mechanism and normalization. + The processing order depends on the pre_norm parameter: + + If pre_norm is True: + norm(data) -> fc1 -> act -> act_dropout -> fc2 -> dropout -> residual_add(+data) + Else: + data -> fc1 -> act -> act_dropout -> fc2 -> dropout -> norm(residual_add(+data)) + + When gated projection is enabled, uses: + fc1_1 * act(fc1_2(data)) for the first projection + """ + + def __init__( + self, + units: int = 512, + hidden_size: int = 2048, + activation_dropout: float = 0.0, + dropout: float = 0.1, + gated_proj: bool = False, + activation="relu", + normalization: str = "layer_norm", + layer_norm_eps: float = 1e-5, + pre_norm: bool = False, + linear_init_mode="0", + ffn2_linear_init_mode="2", + norm_init_mode="0", + ): + super().__init__() + self.linear_init_mode = linear_init_mode + self.ffn2_linear_init_mode = ffn2_linear_init_mode + self.norm_init_mode = norm_init_mode + + self._pre_norm = pre_norm + self._gated_proj = gated_proj + self._kwargs = OrderedDict( + [ + ("units", units), + ("hidden_size", hidden_size), + ("activation_dropout", activation_dropout), + ("activation", activation), + ("dropout", dropout), + ("normalization", normalization), + ("layer_norm_eps", layer_norm_eps), + ("gated_proj", gated_proj), + ("pre_norm", pre_norm), + ] + ) + self.dropout_layer = nn.Dropout(p=dropout) + self.activation_dropout_layer = nn.Dropout(p=activation_dropout) + self.ffn_1 = nn.Dense( + in_channels=units, out_channels=hidden_size, has_bias=True + ) + if self._gated_proj: + self.ffn_1_gate = nn.Dense( + in_channels=units, out_channels=hidden_size, has_bias=True + ) + self.activation = get_activation(activation) + self.ffn_2 = nn.Dense( + in_channels=hidden_size, out_channels=units, has_bias=True + ) + self.layer_norm = get_norm_layer( + norm_type=normalization, in_channels=units, epsilon=layer_norm_eps + ) + self.reset_parameters() + + def reset_parameters(self): + """Initialize all sublayers with specified initialization modes.""" + apply_initialization(self.ffn_1, linear_mode=self.linear_init_mode) + if self._gated_proj: + apply_initialization(self.ffn_1_gate, linear_mode=self.linear_init_mode) + apply_initialization(self.ffn_2, linear_mode=self.ffn2_linear_init_mode) + apply_initialization(self.layer_norm, norm_mode=self.norm_init_mode) + + def construct(self, data): + """ + Forward pass of the Position-wise FFN. 
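+
+        When gated projection is enabled the first projection is ``act(ffn_1_gate(data)) * ffn_1(data)``;
+        otherwise it is ``act(ffn_1(data))``. The placement of layer normalization around the residual
+        connection follows the ``pre_norm`` setting described in the class docstring.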
+ + Args: + data: Input tensor of shape (batch_size, sequence_length, units) + + Returns: + Output tensor of same shape as input with transformed features + """ + residual = data + if self._pre_norm: + data = self.layer_norm(data) + if self._gated_proj: + out = self.activation(self.ffn_1_gate(data)) * self.ffn_1(data) + else: + out = self.activation(self.ffn_1(data)) + out = self.activation_dropout_layer(out) + out = self.ffn_2(out) + out = self.dropout_layer(out) + out = out + residual + if not self._pre_norm: + out = self.layer_norm(out) + return out + + +class PatchMerging3D(nn.Cell): + """3D Patch Merging Layer for spatial-temporal feature downsampling. + This layer merges patches in 3D (temporal, height, width) and applies a linear transformation + to reduce the feature dimension while increasing the channel dimension. + """ + + def __init__( + self, + dim, + out_dim=None, + downsample=(1, 2, 2), + norm_layer="layer_norm", + padding_type="nearest", + linear_init_mode="0", + norm_init_mode="0", + ): + super().__init__() + self.linear_init_mode = linear_init_mode + self.norm_init_mode = norm_init_mode + self.dim = dim + if out_dim is None: + out_dim = max(downsample) * dim + self.out_dim = out_dim + self.downsample = downsample + self.padding_type = padding_type + self.reduction = nn.Dense( + downsample[0] * downsample[1] * downsample[2] * dim, out_dim, has_bias=False + ) + self.norm = get_norm_layer( + norm_layer, in_channels=downsample[0] * downsample[1] * downsample[2] * dim + ) + self.reset_parameters() + + def reset_parameters(self): + """Initialize all sublayers with specified initialization modes.""" + for cell in self.cells(): + apply_initialization( + cell, linear_mode=self.linear_init_mode, norm_mode=self.norm_init_mode + ) + + def get_out_shape(self, data_shape): + """ + Calculate the output shape given input dimensions. + + Args: + data_shape: Input shape tuple (T, H, W, C_in) + + Returns: + Tuple of output shape (T_out, H_out, W_out, C_out) + """ + t, h, w, _ = data_shape + pad_t = (self.downsample[0] - t % self.downsample[0]) % self.downsample[0] + pad_h = (self.downsample[1] - h % self.downsample[1]) % self.downsample[1] + pad_w = (self.downsample[2] - w % self.downsample[2]) % self.downsample[2] + return ( + (t + pad_t) // self.downsample[0], + (h + pad_h) // self.downsample[1], + (w + pad_w) // self.downsample[2], + self.out_dim, + ) + + def construct(self, x): + """ + Forward pass of the 3D Patch Merging layer. 
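+
+        T, H and W are padded up to multiples of ``downsample`` via ``generalize_padding``, each
+        ``downsample[0] * downsample[1] * downsample[2]`` patch is flattened into the channel dimension,
+        then normalized and projected to ``out_dim`` by ``self.reduction``.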
+
+        Args:
+            x: Input tensor of shape (B, T, H, W, C)
+
+        Returns:
+            Output tensor of shape:
+            (B, T//downsample[0], H//downsample[1], W//downsample[2], out_dim)
+        """
+        b, t, h, w, c = x.shape
+
+        # padding
+        pad_t = (self.downsample[0] - t % self.downsample[0]) % self.downsample[0]
+        pad_h = (self.downsample[1] - h % self.downsample[1]) % self.downsample[1]
+        pad_w = (self.downsample[2] - w % self.downsample[2]) % self.downsample[2]
+        if pad_t or pad_h or pad_w:
+            t += pad_t
+            h += pad_h
+            w += pad_w
+            x = generalize_padding(
+                x, pad_t, pad_h, pad_w, padding_type=self.padding_type
+            )
+
+        x = (
+            x.reshape(
+                (
+                    b,
+                    t // self.downsample[0],
+                    self.downsample[0],
+                    h // self.downsample[1],
+                    self.downsample[1],
+                    w // self.downsample[2],
+                    self.downsample[2],
+                    c,
+                )
+            )
+            .permute(0, 1, 3, 5, 2, 4, 6, 7)
+            .reshape(
+                b,
+                t // self.downsample[0],
+                h // self.downsample[1],
+                w // self.downsample[2],
+                self.downsample[0] * self.downsample[1] * self.downsample[2] * c,
+            )
+        )
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
+
+
+class Upsample3DLayer(nn.Cell):
+    """3D Upsampling Layer combining interpolation and convolution.
+
+    Performs spatial upsampling (with optional temporal upsampling) followed by convolution.
+    The operation consists of:
+    1. Spatial upsampling using nearest-neighbor interpolation
+    2. 2D or 3D convolution to refine features and adjust channel dimensions
+
+    Note: Currently only implements 2D upsampling (spatial only)
+    """
+
+    def __init__(
+        self,
+        dim,
+        out_dim,
+        target_size,
+        kernel_size=3,
+        conv_init_mode="0",
+    ):
+        super().__init__()
+        self.conv_init_mode = conv_init_mode
+        self.target_size = target_size
+        self.out_dim = out_dim
+        self.up = nn.Upsample(size=(target_size[1], target_size[2]), mode="nearest")
+        self.conv = nn.Conv2d(
+            in_channels=dim,
+            out_channels=out_dim,
+            kernel_size=(kernel_size, kernel_size),
+            padding=kernel_size // 2,
+            has_bias=True,
+            pad_mode="pad",
+        )
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        """Initialize all sublayers with specified initialization modes."""
+        for cell in self.cells():
+            apply_initialization(cell, conv_mode=self.conv_init_mode)
+
+    def construct(self, x):
+        """Forward pass of the 3D Upsampling layer."""
+        b, t, h, w, c = x.shape
+        if self.target_size[0] != t:
+            raise ValueError(
+                f"Target size mismatch: expected first dimension to be {self.target_size[0]}, "
+                f"but got {t}. Please ensure consistent dimensions."
+            )
+        x = x.reshape(b * t, h, w, c).permute(0, 3, 1, 2)
+        x = self.up(x)
+        return (
+            self.conv(x)
+            .permute(0, 2, 3, 1)
+            .reshape((b,) + self.target_size + (self.out_dim,))
+        )
+
+
+def cuboid_reorder(data, cuboid_size, strategy):
+    """Reorder the tensor into (B, num_cuboids, bT * bH * bW, C)
+
+    We assume that the tensor shapes are divisible by the cuboid sizes.
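+    For example, data of shape (B, 16, 32, 32, C) with cuboid_size (4, 8, 8) is reordered into
+    num_cuboids = 4 * 4 * 4 = 64 cuboids of 4 * 8 * 8 = 256 elements each.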
+ + Parameters + ---------- + data + The input data + cuboid_size + The size of the cuboid + strategy + The cuboid strategy + + Returns + ------- + reordered_data + Shape will be (B, num_cuboids, bT * bH * bW, C) + num_cuboids = T / bT * H / bH * W / bW + """ + b, t, h, w, c = data.shape + num_cuboids = t // cuboid_size[0] * h // cuboid_size[1] * w // cuboid_size[2] + cuboid_volume = cuboid_size[0] * cuboid_size[1] * cuboid_size[2] + intermediate_shape = [] + + nblock_axis = [] + block_axis = [] + for i, (block_size, total_size, ele_strategy) in enumerate( + zip(cuboid_size, (t, h, w), strategy) + ): + if ele_strategy == "l": + intermediate_shape.extend([total_size // block_size, block_size]) + nblock_axis.append(2 * i + 1) + block_axis.append(2 * i + 2) + elif ele_strategy == "d": + intermediate_shape.extend([block_size, total_size // block_size]) + nblock_axis.append(2 * i + 2) + block_axis.append(2 * i + 1) + else: + raise NotImplementedError + + a = (b,) + tuple(intermediate_shape) + (c,) + data = data.reshape(a) + reordered_data = data.permute((0,) + tuple(nblock_axis) + tuple(block_axis) + (7,)) + reordered_data = reordered_data.reshape((b, num_cuboids, cuboid_volume, c)) + return reordered_data + + +def cuboid_reorder_reverse(data, cuboid_size, strategy, orig_data_shape): + """Reverse the reordered cuboid back to the original space + + Parameters + ---------- + data + cuboid_size + strategy + orig_data_shape + + Returns + ------- + data + The recovered data + """ + b, _, _, c = data.shape + t, h, w = orig_data_shape + + permutation_axis = [0] + for i, (_, _, ele_strategy) in enumerate( + zip(cuboid_size, (t, h, w), strategy) + ): + if ele_strategy == "l": + permutation_axis.append(i + 1) + permutation_axis.append(i + 4) + elif ele_strategy == "d": + permutation_axis.append(i + 4) + permutation_axis.append(i + 1) + else: + raise NotImplementedError + permutation_axis.append(7) + data = data.reshape( + b, + t // cuboid_size[0], + h // cuboid_size[1], + w // cuboid_size[2], + cuboid_size[0], + cuboid_size[1], + cuboid_size[2], + c, + ) + data = data.permute(permutation_axis) + data = data.reshape((b, t, h, w, c)) + return data + + +@lru_cache() +def compute_cuboid_self_attention_mask( + data_shape, cuboid_size, shift_size, strategy, padding_type +): + """compute_cuboid_self_attention_mask""" + t, h, w = data_shape + pad_t = (cuboid_size[0] - t % cuboid_size[0]) % cuboid_size[0] + pad_h = (cuboid_size[1] - h % cuboid_size[1]) % cuboid_size[1] + pad_w = (cuboid_size[2] - w % cuboid_size[2]) % cuboid_size[2] + + data_mask = None + if pad_t > 0 or pad_h > 0 or pad_w > 0: + if padding_type == "ignore": + data_mask = ops.ones((1, t, h, w, 1), dtype=ms.bool_) + data_mask = ops.pad( + data_mask, ((0, 0), (0, pad_t), (0, pad_h), (0, pad_w), (0, 0)) + ) + else: + data_mask = ops.ones((1, t + pad_t, h + pad_h, w + pad_w, 1), dtype=ms.bool_) + + if any(i > 0 for i in shift_size): + if padding_type == "ignore": + data_mask = ops.roll( + data_mask, + shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), + dims=(1, 2, 3), + ) + t_padded, h_padded, w_padded = t + pad_t, h + pad_h, w + pad_w + if t_padded <= 0 or h_padded <= 0 or w_padded <= 0: + raise ValueError( + f"invalid padded dimensions: t={t_padded}, h={h_padded}, w={w_padded}" + ) + + shift_mask = ops.zeros((1, t_padded, h_padded, w_padded, 1)) + cnt = 0 + t_slices = ( + [ + slice(0, cuboid_size[0]), + slice(cuboid_size[0] - shift_size[0], t_padded - shift_size[0]), + slice(t_padded - cuboid_size[0], t_padded), + ] + if shift_size[0] > 0 + 
else [slice(0, t_padded)] + ) + + h_slices = ( + [ + slice(0, cuboid_size[1]), + slice(cuboid_size[1] - shift_size[1], h_padded - shift_size[1]), + slice(h_padded - cuboid_size[1], h_padded), + ] + if shift_size[1] > 0 + else [slice(0, h_padded)] + ) + + w_slices = ( + [ + slice(0, cuboid_size[2]), + slice(cuboid_size[2] - shift_size[2], w_padded - shift_size[2]), + slice(w_padded - cuboid_size[2], w_padded), + ] + if shift_size[2] > 0 + else [slice(0, w_padded)] + ) + + for t in t_slices: + for h in h_slices: + for w in w_slices: + shift_mask[:, t, h, w, :] = cnt + cnt += 1 + + shift_mask = cuboid_reorder(shift_mask, cuboid_size, strategy=strategy) + shift_mask = shift_mask.squeeze(-1).squeeze(0) # num_cuboids, cuboid_volume + attn_mask = (shift_mask.unsqueeze(1) - shift_mask.unsqueeze(2)) == 0 + + if padding_type == "ignore": + if padding_type == "ignore": + data_mask = cuboid_reorder(data_mask, cuboid_size, strategy=strategy) + data_mask = data_mask.squeeze(-1).squeeze(0) + attn_mask = data_mask.unsqueeze(1) * data_mask.unsqueeze(2) * attn_mask + + return attn_mask + + +def masked_softmax(att_score, mask, axis: int = -1): + """Computes softmax while ignoring masked elements with broadcastable masks. + + Parameters + ---------- + att_score : Tensor + mask : Tensor or None + Binary mask tensor of shape (..., length, ...) where: + - 1 indicates unmasked (valid) elements + - 0 indicates masked elements + Must be broadcastable with att_score + axis : int, optional + + Returns + ------- + Tensor + Softmax output of same shape as input att_score, with: + - Proper attention weights for unmasked elements + - Zero weights for masked elements + """ + if mask is not None: + # Fill in the masked scores with a very small value + if att_score.dtype == ms.float16: + att_score = att_score.masked_fill(ops.logical_not(mask), -1e4) + else: + att_score = att_score.masked_fill(ops.logical_not(mask), -1e18) + att_weights = ops.softmax(att_score, axis=axis) * mask + else: + att_weights = ops.softmax(att_score, axis=axis) + return att_weights + + +def update_cuboid_size_shift_size(data_shape, cuboid_size, shift_size, strategy): + """Update the + + Parameters + ---------- + data_shape + The shape of the data + cuboid_size + Size of the cuboid + shift_size + Size of the shift + strategy + The strategy of attention + + Returns + ------- + new_cuboid_size + Size of the cuboid + new_shift_size + Size of the shift + """ + new_cuboid_size = list(cuboid_size) + new_shift_size = list(shift_size) + for i in range(len(data_shape)): + if strategy[i] == "d": + new_shift_size[i] = 0 + if data_shape[i] <= cuboid_size[i]: + new_cuboid_size[i] = data_shape[i] + new_shift_size[i] = 0 + return tuple(new_cuboid_size), tuple(new_shift_size) + + +class CuboidSelfAttentionLayer(nn.Cell): + """ + A self-attention layer designed for 3D data (e.g., video or 3D images), + implementing cuboid-based attention with optional global vectors and relative position encoding. + """ + def __init__( + self, + dim, + num_heads, + cuboid_size=(2, 7, 7), + shift_size=(0, 0, 0), + strategy=("l", "l", "l"), + padding_type="ignore", + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + use_final_proj=True, + norm_layer="layer_norm", + use_global_vector=False, + use_global_self_attn=False, + separate_global_qkv=False, + global_dim_ratio=1, + use_relative_pos=True, + attn_linear_init_mode="0", + ffn_linear_init_mode="2", + norm_init_mode="0", + ): + """Initialize the CuboidSelfAttentionLayer. + + Args: + dim (int): Input feature dimension. 
+ num_heads (int): Number of attention heads. + cuboid_size (tuple): 3D dimensions (T, H, W) of the cuboid blocks. + shift_size (tuple): Shift sizes for each dimension to avoid attention blindness. + strategy (tuple): Strategy for each dimension ('l' for local, 'g' for global). + padding_type (str): Padding method for attention computation ("ignore", "zeros", "nearest"). + qkv_bias (bool): Whether to include bias in QKV projections. + qk_scale (float, optional): Scaling factor for QK dot product. Defaults to head_dim**-0.5. + attn_drop (float): Dropout rate after attention softmax. + proj_drop (float): Dropout rate after output projection. + use_final_proj (bool): Whether to apply the final linear projection. + norm_layer (str): Type of normalization layer ("layer_norm", etc.). + use_global_vector (bool): Whether to include a global vector in attention. + use_global_self_attn (bool): Whether to apply self-attention to global vectors. + separate_global_qkv (bool): Whether to use separate QKV for global vectors. + global_dim_ratio (int): Dimension ratio for global vector (requires separate_global_qkv=True if !=1). + use_relative_pos (bool): Whether to use relative position embeddings. + attn_linear_init_mode (str): Initialization mode for attention linear layers. + ffn_linear_init_mode (str): Initialization mode for FFN linear layers. + norm_init_mode (str): Initialization mode for normalization layers. + """ + super().__init__() + # initialization + self.attn_linear_init_mode = attn_linear_init_mode + self.ffn_linear_init_mode = ffn_linear_init_mode + self.norm_init_mode = norm_init_mode + + if dim % num_heads != 0: + raise ValueError( + f"Dimension {dim} must be divisible by number of heads {num_heads}. " + f"Got dim={dim}, num_heads={num_heads}" + ) + self.num_heads = num_heads + self.dim = dim + self.cuboid_size = cuboid_size + self.shift_size = shift_size + self.strategy = strategy + self.padding_type = padding_type + self.use_final_proj = use_final_proj + self.use_relative_pos = use_relative_pos + # global vectors + self.use_global_vector = use_global_vector + self.use_global_self_attn = use_global_self_attn + self.separate_global_qkv = separate_global_qkv + self.global_dim_ratio = global_dim_ratio + if self.padding_type not in ["ignore", "zeros", "nearest"]: + raise ValueError( + f"Invalid padding_type: '{self.padding_type}'. 
" + f"Expected one of: ['ignore', 'zeros', 'nearest']" + ) + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + if self.use_relative_pos: + self.relative_position_bias_table = Parameter( + initializer( + TruncatedNormal(sigma=0.02), + [ + (2 * cuboid_size[0] - 1) + * (2 * cuboid_size[1] - 1) + * (2 * cuboid_size[2] - 1), + num_heads, + ], + ms.float32, + ) + ) + self.relative_position_bias_table.name = "relative_position_bias_table" + coords_t = ops.arange(self.cuboid_size[0]) + coords_h = ops.arange(self.cuboid_size[1]) + coords_w = ops.arange(self.cuboid_size[2]) + coords = ops.stack(ops.meshgrid(coords_t, coords_h, coords_w)) + + coords_flatten = ops.flatten(coords, start_dim=1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0) + relative_coords[:, :, 0] += self.cuboid_size[0] - 1 + relative_coords[:, :, 1] += self.cuboid_size[1] - 1 + relative_coords[:, :, 2] += self.cuboid_size[2] - 1 + + relative_coords[:, :, 0] *= (2 * self.cuboid_size[1] - 1) * ( + 2 * self.cuboid_size[2] - 1 + ) + relative_coords[:, :, 1] *= 2 * self.cuboid_size[2] - 1 + relative_position_index = relative_coords.sum(-1) + self.relative_position_index = Parameter( + relative_position_index, + name="relative_position_index", + requires_grad=False, + ) + self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias) + self.attn_drop = nn.Dropout(p=attn_drop) + + if use_final_proj: + self.proj = nn.Dense(dim, dim) + self.proj_drop = nn.Dropout(p=proj_drop) + + if self.use_global_vector: + self.global_proj = nn.Dense( + in_channels=global_dim_ratio * dim, + out_channels=global_dim_ratio * dim, + ) + + self.norm = get_norm_layer(norm_layer, in_channels=dim) + if self.use_global_vector: + self.global_vec_norm = get_norm_layer( + norm_layer, in_channels=global_dim_ratio * dim + ) + + self.reset_parameters() + + def reset_parameters(self): + '''set_parameters''' + apply_initialization(self.qkv, linear_mode=self.attn_linear_init_mode) + if self.use_final_proj: + apply_initialization(self.proj, linear_mode=self.ffn_linear_init_mode) + apply_initialization(self.norm, norm_mode=self.norm_init_mode) + if self.use_global_vector: + if self.separate_global_qkv: + apply_initialization( + self.l2g_q_net, linear_mode=self.attn_linear_init_mode + ) + apply_initialization( + self.l2g_global_kv_net, linear_mode=self.attn_linear_init_mode + ) + apply_initialization( + self.g2l_global_q_net, linear_mode=self.attn_linear_init_mode + ) + apply_initialization( + self.g2l_k_net, linear_mode=self.attn_linear_init_mode + ) + apply_initialization( + self.g2l_v_net, linear_mode=self.attn_linear_init_mode + ) + if self.use_global_self_attn: + apply_initialization( + self.g2g_global_qkv_net, linear_mode=self.attn_linear_init_mode + ) + else: + apply_initialization( + self.global_qkv, linear_mode=self.attn_linear_init_mode + ) + apply_initialization(self.global_vec_norm, norm_mode=self.norm_init_mode) + + def construct(self, x): + """ + Constructs the output by applying normalization, padding, shifting, and attention mechanisms. + + Parameters: + - x (Tensor): Input tensor with shape (batch, time, height, width, channels). + - global_vectors (Tensor, optional): Global vectors used in global-local interactions. Defaults to None. + + Returns: + - Tensor: Processed tensor after applying all transformations. + - Tensor: Updated global vectors if global vectors are used; otherwise, returns only the processed tensor. 
+ """ + x = self.norm(x) + batch, time, height, width, channels = x.shape + if channels != self.dim: + raise ValueError( + f"Channel dimension mismatch: expected {self.dim}, got {channels}. " + f"Please ensure input channels match the layer's expected dimension." + ) + cuboid_size, shift_size = update_cuboid_size_shift_size( + (time, height, width), self.cuboid_size, self.shift_size, self.strategy + ) + pad_t = (cuboid_size[0] - time % cuboid_size[0]) % cuboid_size[0] + pad_h = (cuboid_size[1] - height % cuboid_size[1]) % cuboid_size[1] + pad_w = (cuboid_size[2] - width % cuboid_size[2]) % cuboid_size[2] + x = generalize_padding(x, pad_t, pad_h, pad_w, self.padding_type) + if any(i > 0 for i in shift_size): + shifted_x = ops.roll( + x, + shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), + dims=(1, 2, 3), + ) + else: + shifted_x = x + reordered_x = cuboid_reorder( + shifted_x, cuboid_size=cuboid_size, strategy=self.strategy + ) + _, num_cuboids, cuboid_volume, _ = reordered_x.shape + attn_mask = compute_cuboid_self_attention_mask( + (time, height, width), + cuboid_size, + shift_size=shift_size, + strategy=self.strategy, + padding_type=self.padding_type, + ) + head_c = channels // self.num_heads + qkv = ( + self.qkv(reordered_x) + .reshape(batch, num_cuboids, cuboid_volume, 3, self.num_heads, head_c) + .permute(3, 0, 4, 1, 2, 5) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) + q = q * self.scale + attn_score = q @ k.swapaxes(-2, -1) + if self.use_relative_pos: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index[:cuboid_volume, :cuboid_volume].reshape(-1) + ].reshape(cuboid_volume, cuboid_volume, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).unsqueeze( + 1 + ) + attn_score = attn_score + relative_position_bias + attn_score = masked_softmax(attn_score, mask=attn_mask) + attn_score = self.attn_drop(attn_score) + reordered_x = ( + (attn_score @ v) + .permute(0, 2, 3, 1, 4) + .reshape(batch, num_cuboids, cuboid_volume, self.dim) + ) + + if self.use_final_proj: + reordered_x = self.proj_drop(self.proj(reordered_x)) + if self.use_global_vector: + new_global_vector = self.proj_drop(self.global_proj(new_global_vector)) + shifted_x = cuboid_reorder_reverse( + reordered_x, + cuboid_size=cuboid_size, + strategy=self.strategy, + orig_data_shape=(time + pad_t, height + pad_h, width + pad_w), + ) + if any(i > 0 for i in shift_size): + x = ops.roll( + shifted_x, + shifts=(shift_size[0], shift_size[1], shift_size[2]), + dims=(1, 2, 3), + ) + else: + x = shifted_x + x = generalize_unpadding( + x, pad_t=pad_t, pad_h=pad_h, pad_w=pad_w, padding_type=self.padding_type + ) + if self.use_global_vector: + return x, new_global_vector + return x + + +class StackCuboidSelfAttentionBlock(nn.Cell): + """ + + - "use_inter_ffn" is True + x --> attn1 --> ffn1 --> attn2 --> ... --> ffn_k --> out + - "use_inter_ffn" is False + x --> attn1 --> attn2 --> ... 
attnk --> ffnk --> out + If we have enabled global memory vectors, each attention will be a + + """ + + def __init__( + self, + dim=None, + num_heads=None, + block_cuboid_size=None, + block_shift_size=None, + block_strategy=None, + padding_type="ignore", + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ffn_drop=0.0, + activation="leaky", + gated_ffn=False, + norm_layer="layer_norm", + use_inter_ffn=False, + use_global_vector=False, + use_global_vector_ffn=True, + use_global_self_attn=False, + separate_global_qkv=False, + global_dim_ratio=1, + use_relative_pos=True, + use_final_proj=True, + # initialization + attn_linear_init_mode="0", + ffn_linear_init_mode="0", + ffn2_linear_init_mode="2", + attn_proj_linear_init_mode="2", + norm_init_mode="0", + ): + super().__init__() + # initialization + self.attn_linear_init_mode = attn_linear_init_mode + self.ffn_linear_init_mode = ffn_linear_init_mode + self.attn_proj_linear_init_mode = attn_proj_linear_init_mode + self.norm_init_mode = norm_init_mode + self.num_attn = len(block_cuboid_size) + self.use_inter_ffn = use_inter_ffn + # global vectors + self.use_global_vector = use_global_vector + self.use_global_vector_ffn = use_global_vector_ffn + self.use_global_self_attn = use_global_self_attn + self.global_dim_ratio = global_dim_ratio + + if self.use_inter_ffn: + self.ffn_l = nn.CellList( + [ + PositionwiseFFN( + units=dim, + hidden_size=4 * dim, + activation_dropout=ffn_drop, + dropout=ffn_drop, + gated_proj=gated_ffn, + activation=activation, + normalization=norm_layer, + pre_norm=True, + linear_init_mode=ffn_linear_init_mode, + ffn2_linear_init_mode=ffn2_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(self.num_attn) + ] + ) + if self.use_global_vector_ffn and self.use_global_vector: + self.global_ffn_l = nn.CellList( + [ + PositionwiseFFN( + units=global_dim_ratio * dim, + hidden_size=global_dim_ratio * 4 * dim, + activation_dropout=ffn_drop, + dropout=ffn_drop, + gated_proj=gated_ffn, + activation=activation, + normalization=norm_layer, + pre_norm=True, + linear_init_mode=ffn_linear_init_mode, + ffn2_linear_init_mode=ffn2_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(self.num_attn) + ] + ) + else: + self.ffn_l = nn.CellList( + [ + PositionwiseFFN( + units=dim, + hidden_size=4 * dim, + activation_dropout=ffn_drop, + dropout=ffn_drop, + gated_proj=gated_ffn, + activation=activation, + normalization=norm_layer, + pre_norm=True, + linear_init_mode=ffn_linear_init_mode, + ffn2_linear_init_mode=ffn2_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + ] + ) + if self.use_global_vector_ffn and self.use_global_vector: + self.global_ffn_l = nn.CellList( + [ + PositionwiseFFN( + units=global_dim_ratio * dim, + hidden_size=global_dim_ratio * 4 * dim, + activation_dropout=ffn_drop, + dropout=ffn_drop, + gated_proj=gated_ffn, + activation=activation, + normalization=norm_layer, + pre_norm=True, + linear_init_mode=ffn_linear_init_mode, + ffn2_linear_init_mode=ffn2_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + ] + ) + self.attn_l = nn.CellList( + [ + CuboidSelfAttentionLayer( + dim=dim, + num_heads=num_heads, + cuboid_size=ele_cuboid_size, + shift_size=ele_shift_size, + strategy=ele_strategy, + padding_type=padding_type, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + use_global_vector=use_global_vector, + use_global_self_attn=use_global_self_attn, + separate_global_qkv=separate_global_qkv, + 
global_dim_ratio=global_dim_ratio, + use_relative_pos=use_relative_pos, + use_final_proj=use_final_proj, + attn_linear_init_mode=attn_linear_init_mode, + ffn_linear_init_mode=attn_proj_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for ele_cuboid_size, ele_shift_size, ele_strategy in zip( + block_cuboid_size, block_shift_size, block_strategy + ) + ] + ) + + def reset_parameters(self): + for m in self.ffn_l: + m.reset_parameters() + if self.use_global_vector_ffn and self.use_global_vector: + for m in self.global_ffn_l: + m.reset_parameters() + for m in self.attn_l: + m.reset_parameters() + + def construct(self, x, global_vectors=None): + """ + Constructs the network output by processing input data with attention and feed-forward layers. + + Args: + x (Tensor): Input data tensor. + global_vectors (Tensor, optional): Global vectors for contextual processing. Defaults to None. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: + - If `global_vectors` is used, returns a tuple (processed_x, updated_global_vectors). + - Otherwise, returns the processed input tensor x. + """ + if self.use_inter_ffn: + if self.use_global_vector: + for idx, (attn, ffn) in enumerate(zip(self.attn_l, self.ffn_l)): + x_out, global_vectors_out = attn(x, global_vectors) + x = x + x_out + global_vectors = global_vectors + global_vectors_out + x = ffn(x) + if self.use_global_vector_ffn: + global_vectors = self.global_ffn_l[idx](global_vectors) + return x, global_vectors + for idx, (attn, ffn) in enumerate(zip(self.attn_l, self.ffn_l)): + x_ = attn(x) + x = x + x_ + x = ffn(x) + return x + if self.use_global_vector: + for idx, attn in enumerate(self.attn_l): + x_out, global_vectors_out = attn(x, global_vectors) + x = x + x_out + global_vectors = global_vectors + global_vectors_out + x = self.ffn_l[0](x) + if self.use_global_vector_ffn: + global_vectors = self.global_ffn_l[0](global_vectors) + return x, global_vectors + for idx, attn in enumerate(self.attn_l): + out = attn(x) + x = x + out + x = self.ffn_l[0](x) + return x diff --git a/MindEarth/applications/nowcasting/PreDiff/src/diffusion/cuboid_transformer_unet.py b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/cuboid_transformer_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..67929ae0c9bf7a1f848b3a07026e8a87eb339567 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/cuboid_transformer_unet.py @@ -0,0 +1,575 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"CuboidTransformerUNet base class" +from mindspore import ops, nn, Parameter +import mindspore.common.initializer as initializer + +from src.utils import timestep_embedding, apply_initialization, round_to, self_axial +from .time_embed import TimeEmbedLayer, TimeEmbedResBlock +from .cuboid_transformer import ( + PosEmbed, + Upsample3DLayer, + PatchMerging3D, + StackCuboidSelfAttentionBlock, +) + + +class CuboidTransformerUNet(nn.Cell): + r""" + U-Net style CuboidTransformer that parametrizes `p(x_{t-1}|x_t)`. + It takes `x_t`, `t` as input. + The conditioning can be concatenated to the input like the U-Net in FVD paper. + + For each block, we apply the StackCuboidSelfAttention in U-Net style + + x --> attn --> downscale --> ... --> z --> attn --> upscale --> ... --> out + + Besides, we insert the embeddings of the timesteps `t` before each cuboid attention blocks. + """ + + def __init__( + self, + input_shape=None, + target_shape=None, + base_units=256, + block_units=None, + scale_alpha=1.0, + depth=None, + downsample=2, + downsample_type="patch_merge", + upsample_type="upsample", + upsample_kernel_size=3, + use_attn_pattern=True, + block_cuboid_size=None, + block_cuboid_strategy=None, + block_cuboid_shift_size=None, + num_heads=4, + attn_drop=0.0, + proj_drop=0.0, + ffn_drop=0.0, + ffn_activation="leaky", + gated_ffn=False, + norm_layer="layer_norm", + use_inter_ffn=True, + hierarchical_pos_embed=False, + padding_type="ignore", + use_relative_pos=True, + self_attn_use_final_proj=True, + # global vectors + num_global_vectors=False, + use_global_vector_ffn=True, + use_global_self_attn=False, + separate_global_qkv=False, + global_dim_ratio=1, + # initialization + attn_linear_init_mode="0", + ffn_linear_init_mode="0", + ffn2_linear_init_mode="2", + attn_proj_linear_init_mode="2", + conv_init_mode="0", + down_linear_init_mode="0", + global_proj_linear_init_mode="2", + norm_init_mode="0", + # timestep embedding for diffusion + time_embed_channels_mult=4, + time_embed_use_scale_shift_norm=False, + time_embed_dropout=0.0, + unet_res_connect=True, + ): + super().__init__() + # initialization mode + self.attn_linear_init_mode = attn_linear_init_mode + self.ffn_linear_init_mode = ffn_linear_init_mode + self.ffn2_linear_init_mode = ffn2_linear_init_mode + self.attn_proj_linear_init_mode = attn_proj_linear_init_mode + self.conv_init_mode = conv_init_mode + self.down_linear_init_mode = down_linear_init_mode + self.global_proj_linear_init_mode = global_proj_linear_init_mode + self.norm_init_mode = norm_init_mode + + self.input_shape = input_shape + self.target_shape = target_shape + self.num_blocks = len(depth) + self.depth = depth + self.base_units = base_units + self.scale_alpha = scale_alpha + self.downsample = downsample + self.downsample_type = downsample_type + self.upsample_type = upsample_type + self.upsample_kernel_size = upsample_kernel_size + if not isinstance(downsample, (tuple, list)): + downsample = (1, downsample, downsample) + if block_units is None: + block_units = [ + round_to(base_units * int((max(downsample) ** scale_alpha) ** i), 4) + for i in range(self.num_blocks) + ] + else: + if len(block_units) != self.num_blocks: + raise ValueError( + f"Length of block_units ({len(block_units)}) does not match " + f"num_blocks ({self.num_blocks}). They must be equal." 
+ ) + if block_units[0] != base_units: + raise ValueError( + f"First block_units value ({block_units[0]}) does not match " + f"base_units ({base_units}). The first unit must equal base_units." + ) + self.block_units = block_units + self.hierarchical_pos_embed = hierarchical_pos_embed + self.num_global_vectors = num_global_vectors + use_global_vector = num_global_vectors > 0 + self.use_global_vector = use_global_vector + if global_dim_ratio != 1: + if not separate_global_qkv: + raise ValueError( + "Configuration conflict: When global_dim_ratio != 1, " + "separate_global_qkv must be set to True. " + f"Current values: global_dim_ratio={global_dim_ratio}, " + f"separate_global_qkv={separate_global_qkv}" + ) + self.global_dim_ratio = global_dim_ratio + self.use_global_vector_ffn = use_global_vector_ffn + + self.time_embed_channels_mult = time_embed_channels_mult + self.time_embed_channels = self.block_units[0] * time_embed_channels_mult + self.time_embed_use_scale_shift_norm = time_embed_use_scale_shift_norm + self.time_embed_dropout = time_embed_dropout + self.unet_res_connect = unet_res_connect + + if self.use_global_vector: + self.init_global_vectors = Parameter( + ops.zeros((self.num_global_vectors, global_dim_ratio * base_units)) + ) + + t_in, h_in, w_in, c_in = input_shape + t_out, h_out, w_out, c_out = target_shape + if h_in != h_out or w_in != w_out or c_in != c_out: + mismatched_dims = [] + if h_in != h_out: + mismatched_dims.append(f"height ({h_in} vs {h_out})") + if w_in != w_out: + mismatched_dims.append(f"width ({w_in} vs {w_out})") + if c_in != c_out: + mismatched_dims.append(f"channels ({c_in} vs {c_out})") + raise ValueError( + f"Input and output dimensions mismatch. " + f"Mismatched dimensions: {', '.join(mismatched_dims)}. " + f"All dimensions must match for this operation." 
+ ) + self.t_in = t_in + self.t_out = t_out + self.first_proj = TimeEmbedResBlock( + channels=self.data_shape[-1], + emb_channels=None, + dropout=proj_drop, + out_channels=self.base_units, + use_conv=False, + use_embed=False, + use_scale_shift_norm=False, + dims=3, + up=False, + down=False, + ) + self.pos_embed = PosEmbed( + embed_dim=base_units, + max_t=self.data_shape[0], + max_h=h_in, + max_w=w_in, + ) + + # diffusion time embed + self.time_embed = TimeEmbedLayer( + base_channels=self.block_units[0], + time_embed_channels=self.time_embed_channels, + ) + # # inner U-Net + if self.num_blocks > 1: + # Construct downsampling layers + if downsample_type == "patch_merge": + self.downsample_layers = nn.CellList( + [ + PatchMerging3D( + dim=self.block_units[i], + downsample=downsample, + padding_type=padding_type, + out_dim=self.block_units[i + 1], + linear_init_mode=down_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for i in range(self.num_blocks - 1) + ] + ) + else: + raise NotImplementedError + if self.use_global_vector: + self.down_layer_global_proj = nn.CellList( + [ + nn.Dense( + in_channels=global_dim_ratio * self.block_units[i], + out_channels=global_dim_ratio * self.block_units[i + 1], + ) + for i in range(self.num_blocks - 1) + ] + ) + # Construct upsampling layers + if self.upsample_type == "upsample": + self.upsample_layers = nn.CellList( + [ + Upsample3DLayer( + dim=self.mem_shapes[i + 1][-1], + out_dim=self.mem_shapes[i][-1], + target_size=self.mem_shapes[i][:3], + kernel_size=upsample_kernel_size, + conv_init_mode=conv_init_mode, + ) + for i in range(self.num_blocks - 1) + ] + ) + else: + raise NotImplementedError + if self.use_global_vector: + self.up_layer_global_proj = nn.CellList( + [ + nn.Dense( + in_channels=global_dim_ratio * self.block_units[i + 1], + out_channels=global_dim_ratio * self.block_units[i], + ) + for i in range(self.num_blocks - 1) + ] + ) + if self.hierarchical_pos_embed: + self.down_hierarchical_pos_embed_l = nn.CellList( + [ + PosEmbed( + embed_dim=self.block_units[i], + max_t=self.mem_shapes[i][0], + max_h=self.mem_shapes[i][1], + max_w=self.mem_shapes[i][2], + ) + for i in range(self.num_blocks - 1) + ] + ) + self.up_hierarchical_pos_embed_l = nn.CellList( + [ + PosEmbed( + embed_dim=self.block_units[i], + max_t=self.mem_shapes[i][0], + max_h=self.mem_shapes[i][1], + max_w=self.mem_shapes[i][2], + ) + for i in range(self.num_blocks - 1) + ] + ) + + if use_attn_pattern: + block_attn_patterns = self.depth + block_cuboid_size = [] + block_cuboid_strategy = [] + block_cuboid_shift_size = [] + for idx, _ in enumerate(block_attn_patterns): + cuboid_size, strategy, shift_size = self_axial(self.mem_shapes[idx]) + block_cuboid_size.append(cuboid_size) + block_cuboid_strategy.append(strategy) + block_cuboid_shift_size.append(shift_size) + else: + if not isinstance(block_cuboid_size[0][0], (list, tuple)): + block_cuboid_size = [block_cuboid_size for _ in range(self.num_blocks)] + else: + if len(block_cuboid_size) != self.num_blocks: + raise ValueError( + f"Block cuboid size dimension mismatch. Expected {self.num_blocks} blocks, " + f"but got {len(block_cuboid_size)}. Received block_cuboid_size={block_cuboid_size}. " + f"Please ensure the input matches the expected number of blocks." 
+ ) + if not isinstance(block_cuboid_strategy[0][0], (list, tuple)): + block_cuboid_strategy = [ + block_cuboid_strategy for _ in range(self.num_blocks) + ] + else: + if len(block_cuboid_strategy) != self.num_blocks: + raise ValueError( + f"Configuration error: Expected {self.num_blocks} block strategies, " + f"but got {len(block_cuboid_strategy)}. " + f"Received block_cuboid_strategy={block_cuboid_strategy}. " + f"Please ensure the strategy list matches the number of blocks." + ) + + if not isinstance(block_cuboid_shift_size[0][0], (list, tuple)): + block_cuboid_shift_size = [ + block_cuboid_shift_size for _ in range(self.num_blocks) + ] + else: + if len(block_cuboid_shift_size) != self.num_blocks: + raise ValueError( + f"Block shift size configuration error: Expected {self.num_blocks} shift sizes, " + f"but received {len(block_cuboid_shift_size)}. " + f"Invalid configuration: block_cuboid_shift_size={block_cuboid_shift_size}. " + f"Please provide exactly {self.num_blocks} shift sizes in the list." + ) + self.block_cuboid_size = block_cuboid_size + self.block_cuboid_strategy = block_cuboid_strategy + self.block_cuboid_shift_size = block_cuboid_shift_size + + # cuboid self attention blocks + down_self_blocks = [] + up_self_blocks = [] + # ResBlocks that incorporate `time_embed` + down_time_embed_blocks = [] + up_time_embed_blocks = [] + for i in range(self.num_blocks): + down_time_embed_blocks.append( + TimeEmbedResBlock( + channels=self.mem_shapes[i][-1], + emb_channels=self.time_embed_channels, + dropout=self.time_embed_dropout, + out_channels=self.mem_shapes[i][-1], + use_conv=False, + use_embed=True, + use_scale_shift_norm=self.time_embed_use_scale_shift_norm, + dims=3, + up=False, + down=False, + ) + ) + + ele_depth = depth[i] + stack_cuboid_blocks = [ + StackCuboidSelfAttentionBlock( + dim=self.mem_shapes[i][-1], + num_heads=num_heads, + block_cuboid_size=block_cuboid_size[i], + block_strategy=block_cuboid_strategy[i], + block_shift_size=block_cuboid_shift_size[i], + attn_drop=attn_drop, + proj_drop=proj_drop, + ffn_drop=ffn_drop, + activation=ffn_activation, + gated_ffn=gated_ffn, + norm_layer=norm_layer, + use_inter_ffn=use_inter_ffn, + padding_type=padding_type, + use_global_vector=use_global_vector, + use_global_vector_ffn=use_global_vector_ffn, + use_global_self_attn=use_global_self_attn, + separate_global_qkv=separate_global_qkv, + global_dim_ratio=global_dim_ratio, + use_relative_pos=use_relative_pos, + use_final_proj=self_attn_use_final_proj, + # initialization + attn_linear_init_mode=attn_linear_init_mode, + ffn_linear_init_mode=ffn_linear_init_mode, + ffn2_linear_init_mode=ffn2_linear_init_mode, + attn_proj_linear_init_mode=attn_proj_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(ele_depth) + ] + down_self_blocks.append(nn.CellList(stack_cuboid_blocks)) + + up_time_embed_blocks.append( + TimeEmbedResBlock( + channels=self.mem_shapes[i][-1], + emb_channels=self.time_embed_channels, + dropout=self.time_embed_dropout, + out_channels=self.mem_shapes[i][-1], + use_conv=False, + use_embed=True, + use_scale_shift_norm=self.time_embed_use_scale_shift_norm, + dims=3, + up=False, + down=False, + ) + ) + + stack_cuboid_blocks = [ + StackCuboidSelfAttentionBlock( + dim=self.mem_shapes[i][-1], + num_heads=num_heads, + block_cuboid_size=block_cuboid_size[i], + block_strategy=block_cuboid_strategy[i], + block_shift_size=block_cuboid_shift_size[i], + attn_drop=attn_drop, + proj_drop=proj_drop, + ffn_drop=ffn_drop, + activation=ffn_activation, + gated_ffn=gated_ffn, 
+ norm_layer=norm_layer, + use_inter_ffn=use_inter_ffn, + padding_type=padding_type, + use_global_vector=use_global_vector, + use_global_vector_ffn=use_global_vector_ffn, + use_global_self_attn=use_global_self_attn, + separate_global_qkv=separate_global_qkv, + global_dim_ratio=global_dim_ratio, + use_relative_pos=use_relative_pos, + use_final_proj=self_attn_use_final_proj, + # initialization + attn_linear_init_mode=attn_linear_init_mode, + ffn_linear_init_mode=ffn_linear_init_mode, + ffn2_linear_init_mode=ffn2_linear_init_mode, + attn_proj_linear_init_mode=attn_proj_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(ele_depth) + ] + up_self_blocks.append(nn.CellList(stack_cuboid_blocks)) + self.down_self_blocks = nn.CellList(down_self_blocks) + self.up_self_blocks = nn.CellList(up_self_blocks) + self.down_time_embed_blocks = nn.CellList(down_time_embed_blocks) + self.up_time_embed_blocks = nn.CellList(up_time_embed_blocks) + self.final_proj = nn.Dense(self.base_units, c_out) + + self.reset_parameters() + + def reset_parameters(self): + '''init parameters''' + if self.num_global_vectors > 0: + initializer.TruncatedNormal(self.init_global_vectors, sigma=0.02) + self.first_proj.reset_parameters() + apply_initialization(self.final_proj, linear_mode="2") + self.pos_embed.reset_parameters() + for block in self.down_self_blocks: + for m in block: + m.reset_parameters() + for m in self.down_time_embed_blocks: + m.reset_parameters() + for block in self.up_self_blocks: + for m in block: + m.reset_parameters() + for m in self.up_time_embed_blocks: + m.reset_parameters() + if self.num_blocks > 1: + for m in self.downsample_layers: + m.reset_parameters() + for m in self.upsample_layers: + m.reset_parameters() + if self.use_global_vector: + apply_initialization( + self.down_layer_global_proj, + linear_mode=self.global_proj_linear_init_mode, + ) + apply_initialization( + self.up_layer_global_proj, + linear_mode=self.global_proj_linear_init_mode, + ) + if self.hierarchical_pos_embed: + for m in self.down_hierarchical_pos_embed_l: + m.reset_parameters() + for m in self.up_hierarchical_pos_embed_l: + m.reset_parameters() + + @property + def data_shape(self): + '''set datashape''' + if not hasattr(self, "_data_shape"): + t_in, h_in, w_in, c_in = self.input_shape + t_out, h_out, w_out, c_out = self.target_shape + if not (h_in == h_out and w_in == w_out and c_in == c_out): + mismatches = [] + if h_in != h_out: + mismatches.append(f"height ({h_in} vs {h_out})") + if w_in != w_out: + mismatches.append(f"width ({w_in} vs {w_out})") + if c_in != c_out: + mismatches.append(f"channels ({c_in} vs {c_out})") + raise ValueError( + f"Input-output dimension mismatch. Mismatched dimensions: {', '.join(mismatches)}. " + f"All dimensions must match for this operation. " + f"Input shape: (h={h_in}, w={w_in}, c={c_in}), " + f"Output shape: (h={h_out}, w={w_out}, c={c_out})" + ) + self._data_shape = ( + t_in + t_out, + h_in, + w_in, + c_in + 1, + ) + return self._data_shape + + @property + def mem_shapes(self): + """Get the shape of the output memory based on the input shape. This can be used for constructing the decoder. 
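+
+        Starting from ``data_shape[:3] + (base_units,)``, the shape returned by each downsampling layer's
+        ``get_out_shape`` is appended in turn, so the list contains ``num_blocks`` entries.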
+ + Returns + ------- + mem_shapes + A list of shapes of the output memory + """ + inner_data_shape = tuple(self.data_shape)[:3] + (self.base_units,) + if self.num_blocks == 1: + return [inner_data_shape] + mem_shapes = [inner_data_shape] + curr_shape = inner_data_shape + for down_layer in self.downsample_layers: + curr_shape = down_layer.get_out_shape(curr_shape) + mem_shapes.append(curr_shape) + return mem_shapes + + def construct(self, x, t, cond): + """ + + Parameters + ---------- + x: mindspore.Tensor + Shape (B, t_out, H, W, C) + t: mindspore.Tensor + Shape (B, ) + cond: mindspore.Tensor + Shape (B, t_in, H, W, C) + verbose: bool + + Returns + ------- + out: mindspore.Tensor + Shape (B, T, H, W, C) + """ + + x = ops.cat([cond, x], axis=1) + obs_indicator = ops.ones_like(x[..., :1]) + obs_indicator[:, self.t_in :, ...] = 0.0 + x = ops.cat([x, obs_indicator], axis=-1) + x = x.transpose((0, 4, 1, 2, 3)) + x = self.first_proj(x) + x = x.transpose((0, 2, 3, 4, 1)) + x = self.pos_embed(x) + # inner U-Net + t_emb = self.time_embed(timestep_embedding(t, self.block_units[0])) + if self.unet_res_connect: + res_connect_l = [] + for i in range(self.num_blocks): + # Downample + if i > 0: + x = self.downsample_layers[i - 1](x) + for idx in range(self.depth[i]): + x = x.transpose((0, 4, 1, 2, 3)) + x = self.down_time_embed_blocks[i](x, t_emb) + x = x.transpose((0, 2, 3, 4, 1)) + x = self.down_self_blocks[i][idx](x) + if self.unet_res_connect and i < self.num_blocks - 1: + res_connect_l.append(x) + + for i in range(self.num_blocks - 1, -1, -1): + if self.unet_res_connect and i < self.num_blocks - 1: + x = x + res_connect_l[i] + for idx in range(self.depth[i]): + x = x.transpose((0, 4, 1, 2, 3)) + x = self.up_time_embed_blocks[i](x, t_emb) + x = x.transpose((0, 2, 3, 4, 1)) + x = self.up_self_blocks[i][idx](x) + if i > 0: + x = self.upsample_layers[i - 1](x) + x = self.final_proj(x[:, self.t_in :, ...]) + return x diff --git a/MindEarth/applications/nowcasting/PreDiff/src/diffusion/latent_diffusion.py b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/latent_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..040afb9cb3f2310d48f654c93cd14f18fd52a035 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/latent_diffusion.py @@ -0,0 +1,1114 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"Latent Diffusion Model" +import warnings +from typing import Sequence, Dict, Any, Callable +from copy import deepcopy +from functools import partial +import numpy as np +from tqdm import tqdm +from einops import rearrange +from omegaconf import OmegaConf + +import mindspore as ms +from mindspore import nn, ops, Tensor, Parameter, mint + +from src.utils import ( + DiagonalGaussianDistribution, + make_beta_schedule, + extract_into_tensor, + noise_like, + default, + parse_layout_shape, + disabled_train, + layout_to_in_out_slice, + calculate_ssim, + SEVIRSkillScore, +) +from src.sevir_dataset import SEVIRDataModule +from src.vae import AutoencoderKL +from src.knowledge_alignment.alignment_net import AvgIntensityAlignment +from .cuboid_transformer_unet import CuboidTransformerUNet + + +class LatentDiffusion(nn.Cell): + """ + Base class for latent space diffusion models. Implements core diffusion processes including + noise scheduling, model application, loss calculation, and latent space operations. Integrates + main UNet model, VAE, and conditioning modules with support for temporal alignment. + """ + + def __init__( + self, + main_model: nn.Cell, + layout: str = "NTHWC", + data_shape: Sequence[int] = (10, 128, 128, 4), + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + monitor="val/loss", + log_every_t=100, + clip_denoised=False, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0.0, + v_posterior=0.0, + l_simple_weight=1.0, + learn_logvar=False, + logvar_init=0.0, + latent_shape: Sequence[int] = (10, 16, 16, 4), + first_stage_model: nn.Cell = None, + cond_stage_forward=None, + scale_by_std=False, + scale_factor=1.0, + ): + super().__init__() + + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.main_model = main_model + self.layout = layout + self.data_shape = data_shape + self.parse_layout_shape(layout=layout) + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + logvar = ops.full(fill_value=logvar_init, size=(self.num_timesteps,)).astype( + ms.float32 + ) + if self.learn_logvar: + self.logvar = Parameter(logvar, requires_grad=True) + else: + self.logvar = Parameter(logvar, name="logvar", requires_grad=False) + + self.latent_shape = latent_shape + self.scale_by_std = scale_by_std + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.logvar = Parameter( + scale_factor, name="scale_factor", requires_grad=False + ) + + self.instantiate_first_stage(first_stage_model) + self.instantiate_cond_stage(cond_stage_forward) + + def set_alignment(self, alignment_fn: Callable = None): + """ + Sets alignment function for denoising process after initialization. + Args: + alignment_fn (Callable): Alignment function with signature + `alignment_fn(zt, t, zc=None, y=None, **kwargs)` + """ + self.alignment_fn = alignment_fn + + def parse_layout_shape(self, layout): + """ + Parses data layout string to determine axis indices. 
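+        For the default ``'NTHWC'`` layout the parsed indices are expected to be batch_axis=0, t_axis=1,
+        h_axis=2, w_axis=3 and c_axis=4.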
+ Args: + layout (str): Data layout specification (e.g., 'NTHWC') + """ + parsed_dict = parse_layout_shape(layout=layout) + self.batch_axis = parsed_dict["batch_axis"] + self.t_axis = parsed_dict["t_axis"] + self.h_axis = parsed_dict["h_axis"] + self.w_axis = parsed_dict["w_axis"] + self.c_axis = parsed_dict["c_axis"] + self.all_slice = [ + slice(None, None), + ] * len(layout) + + def extract_into_tensor(self, a, t, x_shape): + """Extracts schedule parameters into tensor format for current batch.""" + return extract_into_tensor( + a=a, t=t, x_shape=x_shape, batch_axis=self.batch_axis + ) + + @property + def loss_mean_dim(self): + """Computes mean dimensions for loss calculation excluding batch axis.""" + if not hasattr(self, "loss_m_dim"): + loss_m_dim = list(range(len(self.layout))) + loss_m_dim.pop(self.batch_axis) + self.loss_m_dim = tuple(loss_m_dim) + return self.loss_m_dim + + def get_batch_latent_shape(self, batch_size=1): + """ + Generates latent shape with specified batch size. + Args: + batch_size (int): Desired batch size + """ + batch_latent_shape = deepcopy(list(self.latent_shape)) + batch_latent_shape.insert(self.batch_axis, batch_size) + self.batch_latent_shape = tuple(batch_latent_shape) + return self.batch_latent_shape + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + """ + Registers diffusion schedule parameters and precomputes necessary tensors. + Args: + given_betas (Tensor): Custom beta values + beta_schedule (str): Schedule type ('linear', 'cosine') + timesteps (int): Number of diffusion steps + linear_start (float): Linear schedule start value + linear_end (float): Linear schedule end value + cosine_s (float): Cosine schedule parameter + """ + if given_betas is not None: + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + if alphas_cumprod.shape[0] != self.num_timesteps: + raise ValueError( + f"Timestep dimension mismatch: alphas_cumprod has {alphas_cumprod.shape[0]} timesteps, " + f"but expected {self.num_timesteps}. " + "The alpha values must be defined for each diffusion timestep." 
+ ) + + to_mindspore = partial(Tensor, dtype=ms.float32) + self.betas = Parameter(to_mindspore(betas), name="betas", requires_grad=False) + self.alphas_cumprod = Parameter( + to_mindspore(alphas_cumprod), name="alphas_cumprod", requires_grad=False + ) + self.alphas_cumprod_prev = Parameter( + to_mindspore(alphas_cumprod_prev), + name="alphas_cumprod_prev", + requires_grad=False, + ) + self.sqrt_alphas_cumprod = Parameter( + to_mindspore(np.sqrt(alphas_cumprod)), + name="sqrt_alphas_cumprod", + requires_grad=False, + ) + self.sqrt_one_minus_alphas_cumprod = Parameter( + to_mindspore(np.sqrt(1.0 - alphas_cumprod)), + name="sqrt_one_minus_alphas_cumprod", + requires_grad=False, + ) + self.log_one_minus_alphas_cumprod = Parameter( + to_mindspore(np.log(1.0 - alphas_cumprod)), + name="log_one_minus_alphas_cumprod", + requires_grad=False, + ) + self.sqrt_recip_alphas_cumprod = Parameter( + to_mindspore(np.sqrt(1.0 / alphas_cumprod)), + name="sqrt_recip_alphas_cumprod", + requires_grad=False, + ) + self.sqrt_recipm1_alphas_cumprod = Parameter( + to_mindspore(np.sqrt(1.0 / alphas_cumprod - 1)), + name="sqrt_recipm1_alphas_cumprod", + requires_grad=False, + ) + + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + self.posterior_variance = Parameter( + to_mindspore(posterior_variance), + name="posterior_variance", + requires_grad=False, + ) + self.posterior_log_variance_clipped = Parameter( + to_mindspore(np.log(np.maximum(posterior_variance, 1e-20))), + name="posterior_log_variance_clipped", + requires_grad=False, + ) + self.posterior_mean_coef1 = Parameter( + to_mindspore(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + name="posterior_mean_coef1", + requires_grad=False, + ) + self.posterior_mean_coef2 = Parameter( + to_mindspore( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + name="posterior_mean_coef2", + requires_grad=False, + ) + + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * to_mindspore(alphas) + * (1 - self.alphas_cumprod) + ) + lvlb_weights[0] = lvlb_weights[1] + self.lvlb_weights = Parameter( + lvlb_weights, name="lvlb_weights", requires_grad=False + ) + if ops.isnan(self.lvlb_weights).all(): + raise ValueError( + "All lvlb_weights are NaN (Not a Number). " + "This indicates a numerical instability or uninitialized weights. " + "Please check the weight initialization or training process." + ) + + def instantiate_first_stage(self, first_stage_model): + """ + Initializes and freezes the first stage autoencoder model. + Args: + first_stage_model (nn.Cell): Autoencoder model instance + """ + if isinstance(first_stage_model, nn.Cell): + model = first_stage_model + else: + if first_stage_model is not None: + raise ValueError( + "Custom first_stage_model is not currently supported. " + f"Received: {type(first_stage_model).__name__}. " + "This functionality is planned for future implementation." + ) + raise NotImplementedError( + "Automatic first_stage_model initialization is not yet implemented. " + "Please check for framework updates or consider contributing." 
+ ) + self.first_stage_model = model.set_train(False) + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.trainable_params(): + param.requires_grad = False + + def instantiate_cond_stage(self, cond_stage_forward): + """Configures conditioning stage encoder with spatial rearrangement.""" + self.cond_stage_model = self.first_stage_model + for param in self.cond_stage_model.trainable_params(): + param.requires_grad = False + cond_stage_forward = self.cond_stage_model.encode + + def wrapper(cond_stage_forward: Callable): + def func(c: Dict[str, Any]): + c = c.get("y") + batch_size = c.shape[self.batch_axis] + c = c.transpose(0, 1, 4, 2, 3) + n_new, t_new, c_new, h_new, w_new = c.shape + c = c.reshape(n_new * t_new, c_new, h_new, w_new) + c = cond_stage_forward(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + n_new, c_new, h_new, w_new = c.shape + c = c.reshape(batch_size, -1, c_new, h_new, w_new) + c = c.transpose(0, 1, 3, 4, 2) + return c + + return func + + self.cond_stage_forward = wrapper(cond_stage_forward) + + def get_first_stage_encoding(self, encoder_posterior): + """ + Extracts latent representation from encoder output. + Args: + encoder_posterior (Tensor/DiagonalGaussianDistribution): Encoder output + """ + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + @property + def einops_layout(self): + """Returns Einops layout string for data rearrangement.""" + return " ".join(self.layout) + + @property + def einops_spatial_layout(self): + """Generates spatial Einops pattern for 2D/3D data handling.""" + if not hasattr(self, "_einops_spatial_layout"): + if len(self.layout) not in (4, 5): + raise ValueError( + f"Invalid layout dimension: expected 4 or 5 dimensions, but got {len(self.layout)}. " + f"Current layout: {self.layout}\n" + "Possible solutions:\n" + "1. For 2D data: use [batch, channel, height, width]\n" + "2. For 3D data: use [batch, channel, depth, height, width]" + ) + self._einops_spatial_layout = ( + "(N T) C H W" if self.layout.find("T") else "N C H W" + ) + return self._einops_spatial_layout + + def decode_first_stage(self, z): + """ + Decodes latent representation to data space with spatial rearrangement. + Args: + z (Tensor): Latent tensor + """ + z = 1.0 / self.scale_factor * z + batch_size = z.shape[self.batch_axis] + z = rearrange( + z.asnumpy(), f"{self.einops_layout} -> {self.einops_spatial_layout}" + ) + z = Tensor.from_numpy(z) + output = self.first_stage_model.decode(z) + output = rearrange( + output.asnumpy(), + f"{self.einops_spatial_layout} -> {self.einops_layout}", + N=batch_size, + ) + output = Tensor.from_numpy(output) + return output + + def encode_first_stage(self, x): + """ + Encodes input data into latent space. + Args: + x (Tensor): Input data tensor + """ + encoder_posterior = self.first_stage_model.encode(x) + output = self.get_first_stage_encoding(encoder_posterior) + return output + + def apply_model(self, x_noisy, t, cond): + """ + Applies main UNet model to denoise inputs. 
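+        The main model is trained to predict the noise injected by `q_sample`
+        (epsilon-parameterization), so its output has the same shape as `x_noisy`.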
+ Args: + x_noisy (Tensor): Noisy input tensor + t (Tensor): Time step tensor + cond (Dict): Conditioning information + Returns: + Tensor: Denoising model output + """ + x_recon = self.main_model(x_noisy, t, cond) + if isinstance(x_recon, tuple): + return x_recon[0] + return x_recon + + def q_sample(self, x_start, t, noise=None): + """ + Adds noise to clean data according to diffusion schedule. + Args: + x_start (Tensor): Clean data tensor + t (Tensor): Time step tensor + noise (Tensor): Optional noise tensor + Returns: + Tensor: Noisy data tensor + """ + noise = default(noise, lambda: ops.randn_like(x_start)) + return ( + self.extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) + * x_start + + self.extract_into_tensor( + self.sqrt_one_minus_alphas_cumprod, t, x_start.shape + ) + * noise + ) + + def get_loss(self, pred, target, mean=True): + """ + Calculates loss between prediction and target. + Args: + pred (Tensor): Model predictions + target (Tensor): Target values + mean (bool): Whether to return mean loss + Returns: + Tensor: Loss value(s) + """ + if self.loss_type == "l1": + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == "l2": + if mean: + loss = mint.nn.functional.mse_loss(target, pred) + else: + loss = mint.nn.functional.mse_loss(target, pred, reduction="none") + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, cond, t, noise=None): + """ + Computes diffusion training loss for given time steps. + Args: + x_start (Tensor): Clean data tensor + cond (Dict): Conditioning information + t (Tensor): Time step tensor + noise (Tensor): Optional noise tensor + Returns: + Tensor: Total training loss + """ + noise = default(noise, lambda: ops.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + loss_simple = self.get_loss(model_output, noise, mean=False).mean( + axis=self.loss_mean_dim + ) + + logvar_t = self.logvar[t] + + loss = loss_simple / ops.exp(logvar_t) + logvar_t + + loss = self.l_simple_weight * loss.mean() + return loss + + def predict_start_from_noise(self, x_t, t, noise): + """ + Reconstructs clean data from noisy input and predicted noise. + Args: + x_t (Tensor): Noisy data tensor + t (Tensor): Time step tensor + noise (Tensor): Predicted noise tensor + Returns: + Tensor: Reconstructed clean data + """ + sqrt_recip_alphas_cumprod_t = self.extract_into_tensor( + self.sqrt_recip_alphas_cumprod, t, x_t.shape + ) + sqrt_recipm1_alphas_cumprod_t = self.extract_into_tensor( + self.sqrt_recipm1_alphas_cumprod, t, x_t.shape + ) + term1 = sqrt_recip_alphas_cumprod_t * x_t + term2 = sqrt_recipm1_alphas_cumprod_t * noise + pred = term1 - term2 + return pred + + def q_posterior(self, x_start, x_t, t): + """ + Calculates posterior distribution parameters for given time steps. 
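+        Follows the standard DDPM posterior q(z_{t-1} | z_t, z_0), i.e.
+        mean = posterior_mean_coef1[t] * x_start + posterior_mean_coef2[t] * x_t,
+        with the variance and clipped log-variance precomputed in `register_schedule`.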
+ Args: + x_start (Tensor): Clean data tensor + x_t (Tensor): Noisy data tensor + t (Tensor): Time step tensor + Returns: + Tuple[Tensor]: (posterior mean, variance, log variance) + """ + posterior_mean = ( + self.extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + self.extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = self.extract_into_tensor( + self.posterior_variance, t, x_t.shape + ) + posterior_log_variance_clipped = self.extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, + zt, + zc, + t, + clip_denoised: bool, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + ): + """ + Computes predicted mean and variance during denoising. + Args: + zt (Tensor): Current latent sample + zc (Tensor): Conditioning tensor + t (Tensor): Time step tensor + clip_denoised (bool): Whether to clip denoised outputs + return_x0 (bool): Whether to return reconstructed x0 + score_corrector (Callable): Optional score correction function + corrector_kwargs (Dict): Correction function parameters + Returns: + Tuple[Tensor]: (mean, variance, log variance, [reconstructed x0]) + """ + t_in = t + model_out = self.apply_model(zt, t_in, zc) + if score_corrector is not None: + model_out = score_corrector.modify_score( + self, model_out, zt, t, zc, **corrector_kwargs + ) + z_recon = self.predict_start_from_noise(zt, t=t, noise=model_out) + if clip_denoised: + z_recon = z_recon.clamp(-1.0, 1.0) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=z_recon, x_t=zt, t=t + ) + if return_x0: + return model_mean, posterior_variance, posterior_log_variance, z_recon + return model_mean, posterior_variance, posterior_log_variance + + def aligned_mean(self, zt, t, zc, y, orig_mean, orig_log_var, **kwargs): + """ + Calculates aligned mean using gradient-based alignment function. + Args: + zt (Tensor): Current latent sample + t (Tensor): Time step tensor + zc (Tensor): Conditioning tensor + y (Tensor): Ground truth tensor + orig_mean (Tensor): Original mean + orig_log_var (Tensor): Original log variance + **kwargs: Additional alignment parameters + Returns: + Tensor: Aligned mean tensor + """ + align_gradient = self.alignment_fn(zt, t, zc=zc, y=y, **kwargs) + new_mean = orig_mean - (0.5 * orig_log_var).exp() * align_gradient + return new_mean + + def p_sample( + self, + zt, + zc, + t, + y=None, + use_alignment=False, + alignment_kwargs=None, + clip_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + ): + """ + Single step diffusion sampling. 
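+        One ancestral sampling step: the model output is turned into an estimate of
+        z_0 via `predict_start_from_noise`, `q_posterior` gives the posterior mean and
+        log-variance, and the next sample is mean + exp(0.5 * log_var) * noise,
+        where the noise term is masked out at t == 0.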
+ Args: + zt (Tensor): Current noisy sample at time step t + zc (Tensor/Dict): Condition input (latent or processed) + t (Tensor): Time step tensor + y (Tensor, optional): Additional conditioning information + use_alignment (bool): Whether to apply alignment correction + alignment_kwargs (dict, optional): Parameters for alignment correction + clip_denoised (bool): Clip model output to [-1,1] range + return_x0 (bool): Return estimated x0 along with sample + temperature (float): Noise scaling factor + noise_dropout (float): Dropout rate for noise component + score_corrector (object, optional): Model output corrector instance + corrector_kwargs (dict, optional): Parameters for score correction + + Returns: + Tensor: Next denoised sample + Tensor (optional): Estimated x0 if return_x0 is True + """ + batch_size = zt.shape[self.batch_axis] + outputs = self.p_mean_variance( + zt=zt, + zc=zc, + t=t, + clip_denoised=clip_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if use_alignment: + if alignment_kwargs is None: + alignment_kwargs = {} + model_mean, posterior_variance, model_log_variance, *_ = outputs + model_mean = self.aligned_mean( + zt=zt, + t=t, + zc=zc, + y=y, + orig_mean=model_mean, + orig_log_var=model_log_variance, + **alignment_kwargs, + ) + outputs = (model_mean, posterior_variance, model_log_variance, *outputs[3:]) + if return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(zt.shape) * temperature + if noise_dropout > 0.0: + noise = ops.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask_shape = [ + 1, + ] * len(zt.shape) + nonzero_mask_shape[self.batch_axis] = batch_size + nonzero_mask = (1 - (t == 0).float()).reshape(*nonzero_mask_shape) + + if return_x0: + return ( + model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, + x0, + ) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + def p_sample_loop( + self, + cond, + shape, + y=None, + use_alignment=False, + alignment_kwargs=None, + return_intermediates=False, + x_t=None, + verbose=False, + timesteps=None, + mask=None, + x0=None, + start_t=None, + log_every_t=None, + ): + """ + Full diffusion sampling loop. 
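+        Starts from Gaussian noise (or the provided `x_t`) and applies `p_sample`
+        from t = timesteps - 1 down to 0, optionally re-imposing known content via
+        `mask`/`x0` after every step. Minimal usage sketch (assuming `zc` is an
+        already-encoded condition tensor):
+
+        >>> shape = model.get_batch_latent_shape(batch_size=2)
+        >>> z = model.p_sample_loop(cond=zc, shape=shape, verbose=True)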
+ Args: + cond (Tensor/Dict): Conditioning input (processed) + shape (tuple): Output tensor shape (B, C, H, W) + y (Tensor, optional): Additional conditioning info + use_alignment (bool): Enable alignment correction during sampling + alignment_kwargs (dict, optional): Alignment parameters + return_intermediates (bool): Return intermediate steps + x_t (Tensor, optional): Initial noise sample (default: random) + verbose (bool): Show progress bar + timesteps (int): Number of sampling steps + mask (Tensor, optional): Mask for conditional generation (requires x0) + x0 (Tensor, optional): Original image for inpainting/conditional generation + start_t (int): Override maximum time step + log_every_t (int): Frequency of intermediate saves + + Returns: + Tensor: Final generated sample + list[Tensor] (optional): Intermediate samples if requested + """ + + if not log_every_t: + log_every_t = self.log_every_t + batch_size = shape[self.batch_axis] + if x_t is None: + img = ops.randn(shape) + + else: + img = x_t + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_t is not None: + timesteps = min(timesteps, start_t) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) + + if mask is not None: + if x0 is None: + raise ValueError( + "Missing required input: x0 cannot be None. " + "Please provide valid input data." + ) + + if x0.shape[2:3] != mask.shape[2:3]: + raise ValueError( + f"Spatial dimension mismatch between input and mask. " + f"Input spatial size: {x0.shape[2:3]}, " + f"Mask spatial size: {mask.shape[2:3]}. " + "The height and width dimensions must match exactly." + ) + for i in iterator: + ts = ops.full((batch_size,), i, dtype=ms.int64) + img = self.p_sample( + zt=img, + zc=cond, + t=ts, + y=y, + use_alignment=use_alignment, + alignment_kwargs=alignment_kwargs, + clip_denoised=self.clip_denoised, + ) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + + if return_intermediates: + return img, intermediates + return img + + def sample( + self, + cond, + batch_size=16, + use_alignment=False, + alignment_kwargs=None, + return_intermediates=False, + x_t=None, + verbose=False, + timesteps=None, + mask=None, + x0=None, + shape=None, + return_decoded=True, + ): + """ + High-level sampling interface with conditioning handling. + + Args: + cond (Tensor/Dict): Raw conditioning input (e.g., text/image) + batch_size (int): Number of samples to generate + use_alignment (bool): Enable alignment correction + alignment_kwargs (dict, optional): Alignment parameters + return_intermediates (bool): Return intermediate steps + x_t (Tensor, optional): Initial noise sample + verbose (bool): Show progress + timesteps (int): Sampling steps + mask (Tensor, optional): Inpainting mask (requires x0) + x0 (Tensor, optional): Original image for conditioning + shape (tuple, optional): Output shape override + return_decoded (bool): Return decoded image instead of latent + + Returns: + Tensor: Generated image (decoded if return_decoded) + list[Tensor] (optional): Decoded intermediate steps if requested + """ + if shape is None: + shape = self.get_batch_latent_shape(batch_size=batch_size) + if self.cond_stage_model is not None: + if cond is None: + raise ValueError( + "Required condition is None. " + "This parameter must be provided with a valid value." 
+ ) + cond_tensor_slice = [ + slice(None, None), + ] * len(self.data_shape) + cond_tensor_slice[self.batch_axis] = slice(0, batch_size) + if isinstance(cond, dict): + zc = { + key: ( + cond[key][cond_tensor_slice] + if not isinstance(cond[key], list) + else list(map(lambda x: x[cond_tensor_slice], cond[key])) + ) + for key in cond + } + else: + zc = ( + [c[cond_tensor_slice] for c in cond] + if isinstance(cond, list) + else cond[cond_tensor_slice] + ) + zc = self.cond_stage_forward(zc) + else: + zc = cond if isinstance(cond, Tensor) else cond.get("y", None) + y = cond if isinstance(cond, Tensor) else cond.get("y", None) + output = self.p_sample_loop( + cond=zc, + shape=shape, + y=y, + use_alignment=use_alignment, + alignment_kwargs=alignment_kwargs, + return_intermediates=return_intermediates, + x_t=x_t, + verbose=verbose, + timesteps=timesteps, + mask=mask, + x0=x0, + ) + if return_decoded: + if return_intermediates: + samples, intermediates = output + decoded_samples = self.decode_first_stage(samples) + decoded_intermediates = [ + self.decode_first_stage(ele) for ele in intermediates + ] + output = [decoded_samples, decoded_intermediates] + else: + output = self.decode_first_stage(output) + return output + + + +class PreDiffModule(LatentDiffusion): + """ + Main module for pre-training diffusion models with latent representations. + Integrates configuration loading, model creation, alignment setup, metric initialization, + and visualization parameters. Extends LatentDiffusion to handle cuboid-based UNet architectures + and knowledge alignment for sequential data generation tasks. + """ + + def __init__(self, oc_file: str = None): + self.oc = self._load_configs(oc_file) + latent_model = self._create_latent_model() + first_stage_model = self._create_vae_model() + super().__init__( + **self._prepare_parent_init_params(latent_model, first_stage_model) + ) + self._setup_alignment() + self._initialize_metrics() + self._setup_visualization() + + def _load_configs(self, oc_file): + """Loads all configuration files through a unified entry point.""" + oc_from_file = OmegaConf.load(open(oc_file, "r")) if oc_file else None + return self.get_base_config(oc_from_file=oc_from_file) + + def _create_latent_model(self): + """Builds the CuboidTransformerUNet model based on configurations.""" + latent_model_cfg = OmegaConf.to_object(self.oc.model.latent_model) + return CuboidTransformerUNet( + **{ + k: latent_model_cfg[k] + for k in latent_model_cfg + }, + ) + + def _process_attention_patterns(self, cfg, num_blocks): + """Processes attention patterns from configuration settings.""" + if isinstance(cfg["self_pattern"], str): + return [cfg["self_pattern"]] * num_blocks + return OmegaConf.to_container(cfg["self_pattern"]) + + def _create_vae_model(self): + """Creates and loads pretrained weights for the VAE model.""" + vae_cfg = OmegaConf.to_object(self.oc.model.vae) + model = AutoencoderKL( + **{ + k: vae_cfg[k] + for k in vae_cfg + if k not in ["pretrained_ckpt_path", "data_channels"] + } + ) + self._load_pretrained_weights(model, vae_cfg["pretrained_ckpt_path"]) + return model + + def _load_pretrained_weights(self, model, ckpt_path): + """Loads pretrained weights into the given model if a checkpoint path is provided.""" + if ckpt_path: + param_dict = ms.load_checkpoint(ckpt_path) + param_not_load, _ = ms.load_param_into_net(model, param_dict) + if param_not_load: + print(f"Unloaded AutoencoderKLparameters: {param_not_load}") + else: + warnings.warn( + "Pretrained weights for AutoencoderKL not set. 
Running sanity check only." + ) + + def _prepare_parent_init_params(self, latent_model, first_stage_model): + """Prepares initialization parameters for the parent class.""" + diffusion_cfg = OmegaConf.to_object(self.oc.model.diffusion) + return { + "main_model": latent_model, + "layout": self.oc.layout.layout, + "loss_type": self.oc.optim.loss_type, + "monitor": self.oc.optim.monitor, + "first_stage_model": first_stage_model, + **{ + k: diffusion_cfg[k] + for k in diffusion_cfg + if k not in ["latent_cond_shape"] + }, + } + + def _setup_alignment(self): + """Sets up alignment using AvgIntensityAlignment if specified in configurations.""" + # from src.knowledge_alignment.alignment_net import AvgIntensityAlignment + + knowledge_cfg = OmegaConf.to_object(self.oc.model.align) + self.alignment_type = knowledge_cfg["alignment_type"] + self.use_alignment = self.alignment_type is not None + + if self.use_alignment: + self.alignment_obj = AvgIntensityAlignment( + guide_scale=knowledge_cfg["guide_scale"], + model_args=knowledge_cfg["model_args"], + model_ckpt_path=knowledge_cfg["model_ckpt_path"], + ) + self.alignment_obj.model.set_train(False) + self.set_alignment(self.alignment_obj.get_mean_shift) + else: + self.set_alignment(None) + + def _initialize_metrics(self): + """Initializes metrics for evaluation based on configurations.""" + if self.oc.eval.eval_unaligned: + self._init_unaligned_metrics() + if self.oc.eval.eval_aligned: + self._init_aligned_metrics() + + def _init_unaligned_metrics(self): + """Initializes unaligned metrics for evaluation.""" + common_args = { + "mode": self.oc.data.metrics_mode, + "seq_in": self.oc.layout.t_out, + "layout": self.layout, + "threshold_list": self.oc.data.threshold_list, + "metrics_list": self.oc.data.metrics_list, + "eps": 1e-4, + } + + self.valid_score = SEVIRSkillScore(**common_args) + + self.test_ssim = calculate_ssim + self.test_aligned_ssim = calculate_ssim + self.test_score = SEVIRSkillScore(**common_args) + + def _init_aligned_metrics(self): + """Initializes aligned metrics for evaluation.""" + common_args = { + "mode": self.oc.data.metrics_mode, + "seq_in": self.oc.layout.t_out, + "layout": self.layout, + "threshold_list": self.oc.data.threshold_list, + "metrics_list": self.oc.data.metrics_list, + "eps": 1e-4, + } + + self.valid_aligned_score = SEVIRSkillScore(**common_args) + + self.test_aligned_ssim = nn.SSIM() + self.test_aligned_score = SEVIRSkillScore(**common_args) + + def _setup_visualization(self): + """Sets up visualization parameters based on configurations.""" + self.logging_prefix = self.oc.logging.logging_prefix + self.train_example_data_idx_list = list( + self.oc.eval.train_example_data_idx_list + ) + self.val_example_data_idx_list = list(self.oc.eval.val_example_data_idx_list) + self.test_example_data_idx_list = list(self.oc.eval.test_example_data_idx_list) + + def get_base_config(self, oc_from_file=None): + """Merges base configuration with configuration loaded from file.""" + if oc_from_file is None: + raise ValueError("oc_from_file is required but not provided") + oc = OmegaConf.create() + oc = OmegaConf.merge(oc, oc_from_file) + return oc + + @classmethod + def get_total_num_steps( + cls, num_samples: int, total_batch_size: int, epoch: int = None + ): + """ + Parameters + ---------- + num_samples: int + The number of samples of the datasets. `num_samples / micro_batch_size` is the number of steps per epoch. 
+ total_batch_size: int + `total_batch_size == micro_batch_size * world_size * grad_accum` + epoch: int + """ + if epoch is None: + epoch = cls.get_optim_config().max_epochs + return int(epoch * num_samples / total_batch_size) + + @staticmethod + def get_sevir_datamodule( + dataset_cfg, micro_batch_size: int = 1, num_workers: int = 8 + ): + """Creates and returns a SEVIRDataModule instance based on dataset configurations.""" + dm = SEVIRDataModule( + sevir_dir=dataset_cfg["root_dir"], + seq_in=dataset_cfg["seq_in"], + sample_mode=dataset_cfg["sample_mode"], + stride=dataset_cfg["stride"], + batch_size=micro_batch_size, + layout=dataset_cfg["layout"], + output_type=np.float32, + preprocess=True, + rescale_method="01", + verbose=False, + aug_mode=dataset_cfg["aug_mode"], + dataset_name=dataset_cfg["dataset_name"], + start_date=dataset_cfg["start_date"], + train_val_split_date=dataset_cfg["train_val_split_date"], + train_test_split_date=dataset_cfg["train_test_split_date"], + end_date=dataset_cfg["end_date"], + val_ratio=dataset_cfg["val_ratio"], + num_workers=num_workers, + raw_seq_len=dataset_cfg["raw_seq_len"] + ) + return dm + + @property + def in_slice(self): + """Returns the input slice based on layout and sequence length configurations.""" + if not hasattr(self, "_in_slice"): + in_slice, out_slice = layout_to_in_out_slice( + layout=self.oc.layout.layout, + t_in=self.oc.layout.t_in, + t_out=self.oc.layout.t_out, + ) + self._in_slice = in_slice + self._out_slice = out_slice + return self._in_slice + + @property + def out_slice(self): + """Returns the output slice based on layout and sequence length configurations.""" + if not hasattr(self, "_out_slice"): + in_slice, out_slice = layout_to_in_out_slice( + layout=self.oc.layout.layout, + t_in=self.oc.layout.t_in, + t_out=self.oc.layout.t_out, + ) + self._in_slice = in_slice + self._out_slice = out_slice + return self._out_slice + + def get_input(self, batch, **kwargs): + """Extracts input data and conditioning information from a raw data batch.""" + return self._get_input_sevirlr( + batch=batch, return_verbose=kwargs.get("return_verbose", False) + ) + + def _get_input_sevirlr(self, batch, return_verbose=False): + """Specific implementation of input extraction for SEVIRLR dataset.""" + seq = batch + in_seq = seq[self.in_slice] + out_seq = seq[self.out_slice] + if return_verbose: + return out_seq, {"y": in_seq}, in_seq + return out_seq, {"y": in_seq} diff --git a/MindEarth/applications/nowcasting/PreDiff/src/diffusion/solver.py b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/solver.py new file mode 100644 index 0000000000000000000000000000000000000000..bdbb5d4f85002b5943f4b5c290cf7c9d650dedaf --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/solver.py @@ -0,0 +1,144 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"diffusion model training" +import time +import os + + +import mindspore as ms +from mindspore import ops, nn +from mindspore.train.serialization import save_checkpoint + +from src.sevir_dataset import SEVIRDataset + + +class DiffusionTrainer(nn.Cell): + """ + Class managing the training pipeline for diffusion models. Handles dataset processing, + optimizer configuration, gradient clipping, checkpoint saving, and logging. + """ + def __init__(self, main_module, dm, logger, config): + """ + Initialize trainer with model, data module, logger, and configuration. + Args: + main_module: Main diffusion model to be trained + dm: Data module providing training dataset + logger: Logging utility for training progress + config: Configuration dictionary containing hyperparameters + """ + super().__init__() + self.main_module = main_module + self.traindataset = dm.sevir_train + self.logger = logger + self.datasetprocessing = SEVIRDataset( + data_types=["vil"], + layout="NHWT", + rescale_method=config.get("rescale_method", "01"), + ) + self.example_save_dir = config["summary"].get("summary_dir", "./summary") + self.fs = config["eval"].get("fs", 20) + self.label_offset = config["eval"].get("label_offset", [-0.5, 0.5]) + self.label_avg_int = config["eval"].get("label_avg_int", False) + self.current_epoch = 0 + self.learn_logvar = ( + config.get("model", {}).get("diffusion", {}).get("learn_logvar", False) + ) + self.logvar = main_module.logvar + self.maeloss = nn.MAELoss() + self.optim_config = config["optim"] + self.clip_norm = config.get("clip_norm", 2) + self.ckpt_dir = os.path.join(self.example_save_dir, "ckpt") + self.keep_ckpt_max = config["summary"].get("keep_ckpt_max", 100) + self.ckpt_history = [] + self.grad_clip_fn = ops.clip_by_global_norm + self.optimizer = nn.Adam(params=self.main_module.main_model.trainable_params(), + learning_rate=config["optim"].get("lr", 1e-5)) + os.makedirs(self.ckpt_dir, exist_ok=True) + + def train(self, total_steps: int): + """Execute complete training pipeline.""" + self.main_module.main_model.set_train(True) + self.logger.info(f"total_steps: {total_steps}") + self.logger.info("Initializing training process...") + loss_processor = Trainonestepforward(self.main_module) + grad_func = ms.ops.value_and_grad(loss_processor, None, self.main_module.main_model.trainable_params()) + for epoch in range(self.optim_config["max_epochs"]): + epoch_loss = 0.0 + epoch_start = time.time() + + iterator = self.traindataset.create_dict_iterator() + if not iterator: + raise ValueError( + "Empty dataset error: The provided dataset iterator contains no data. " + "Please verify your data loading pipeline and ensure the dataset is properly populated." 
+ ) + batch_idx = 0 + for batch_idx, batch in enumerate(iterator): + processed_data = self.datasetprocessing.process_data(batch["vil"]) + loss_value, gradients = grad_func(processed_data) + clipped_grads = self.grad_clip_fn(gradients, self.clip_norm) + self.optimizer(clipped_grads) + epoch_loss += loss_value.asnumpy() + self.logger.info( + f"epoch: {epoch} step: {batch_idx}, loss: {loss_value}" + ) + self._save_ckpt(epoch) + epoch_time = time.time() - epoch_start + self.logger.info( + f"Epoch {epoch} completed in {epoch_time:.2f}s | " + f"Avg Loss: {epoch_loss/(batch_idx+1):.4f}" + ) + + def _save_ckpt(self, epoch: int): + """Save model ckpt with rotation policy""" + ckpt_file = f"diffusion_epoch{epoch}.ckpt" + ckpt_path = os.path.join(self.ckpt_dir, ckpt_file) + + save_checkpoint(self.main_module.main_model, ckpt_path) + self.ckpt_history.append(ckpt_path) + + if len(self.ckpt_history) > self.keep_ckpt_max: + removed_ckpt = self.ckpt_history.pop(0) + os.remove(removed_ckpt) + self.logger.info(f"Removed outdated ckpt: {removed_ckpt}") + + +class Trainonestepforward(nn.Cell): + """A neural network cell that performs one training step forward pass for a diffusion model. + This class encapsulates the forward pass computation for training a diffusion model, + handling the input processing, latent space encoding, conditioning, and loss calculation. + Args: + model (nn.Cell): The main diffusion model containing the necessary submodules + for encoding, conditioning, and loss computation. + """ + + def __init__(self, model): + super().__init__() + self.main_module = model + + def construct(self, inputs): + """Perform one forward training step and compute the loss.""" + x, condition = self.main_module.get_input(inputs) + x = x.transpose(0, 1, 4, 2, 3) + n, t_, c_, h, w = x.shape + x = x.reshape(n * t_, c_, h, w) + z = self.main_module.encode_first_stage(x) + _, c_z, h_z, w_z = z.shape + z = z.reshape(n, -1, c_z, h_z, w_z) + z = z.transpose(0, 1, 3, 4, 2) + t = ops.randint(0, self.main_module.num_timesteps, (n,)).long() + zc = self.main_module.cond_stage_forward(condition) + loss = self.main_module.p_losses(z, zc, t, noise=None) + return loss diff --git a/MindEarth/applications/nowcasting/PreDiff/src/diffusion/time_embed.py b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/time_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..f052dc189b0f0c214c45d7b1017cc7e9c5ac560e --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/diffusion/time_embed.py @@ -0,0 +1,292 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"TimeEmbedLayer and TimeEmbedResBlock" +from mindspore import nn, ops + +from src.utils import conv_nd, apply_initialization, avg_pool_nd + + +class Upsample(nn.Cell): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. 
+ :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def construct(self, x): + '''upsample forward''' + if x.shape[1] != self.channels: + raise ValueError( + f"Channel dimension mismatch: input has {x.shape[1]} channels, " + f"but layer expects {self.channels} channels. " + f"Input shape: {x.shape}, expected channels dimension: {self.channels}. " + "Please adjust your input data or layer configuration." + ) + if self.dims == 3: + x = ops.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = ops.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Cell): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + if self.channels != self.out_channels: + raise ValueError( + f"Channel configuration mismatch: input channels ({self.channels}) " + f"must match output channels ({self.out_channels}) for this operation. " + "Please adjust either the input channels or the layer configuration." + ) + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def construct(self, x): + if x.shape[1] != self.channels: + raise ValueError( + f"Input channel mismatch: Expected {self.channels} channels, " + f"but received input with {x.shape[1]} channels. " + f"Full input shape: {x.shape}. " + "Please ensure your input data matches the layer's channel requirements." + ) + return self.op(x) + + +class TimeEmbedLayer(nn.Cell): + """ + A neural network layer that embeds time information into a higher-dimensional space. + + The layer consists of two linear layers separated by a SiLU activation function. + It takes an input tensor with a specified number of base channels and transforms it + into a tensor with a specified number of time embedding channels. + Parameters: + - base_channels (int): Number of channels in the input tensor. + - time_embed_channels (int): Number of channels in the output embedded tensor. + - linear_init_mode (str, optional): Initialization mode for the linear layers. Defaults to "0". 
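+
+    Example (illustrative only; channel sizes are hypothetical):
+        >>> layer = TimeEmbedLayer(base_channels=128, time_embed_channels=512)
+        >>> emb = layer(ops.randn(4, 128))  # -> Tensor of shape (4, 512)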
+ """ + + def __init__(self, base_channels, time_embed_channels, linear_init_mode="0"): + super().__init__() + self.layer = nn.SequentialCell( + nn.Dense(base_channels, time_embed_channels), + nn.SiLU(), + nn.Dense(time_embed_channels, time_embed_channels), + ) + self.linear_init_mode = linear_init_mode + + def construct(self, x): + """Forward pass through the TimeEmbedLayer.""" + return self.layer(x) + + def reset_parameters(self): + """Reset the parameters of the linear layers in the TimeEmbedLayer.""" + apply_initialization(self.layer[0], linear_mode=self.linear_init_mode) + apply_initialization(self.layer[2], linear_mode=self.linear_init_mode) + + +class TimeEmbedResBlock(nn.Cell): + r""" + Modifications: + 1. Change GroupNorm32 to use arbitrary `num_groups`. + 2. Add method `self.reset_parameters()`. + 3. Use gradient ckpt from mindspore instead of the stable diffusion implementation + 4. If no input time embed, it degrades to res block. + """ + + def __init__( + self, + channels, + dropout, + emb_channels=None, + out_channels=None, + use_conv=False, + use_embed=True, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + norm_groups=32, + ): + r""" + Parameters + ---------- + channels + dropout + emb_channels + out_channels + use_conv + use_embed: bool + include `emb` as input in `self.forward()` + use_scale_shift_norm: bool + take effect only when `use_embed == True` + dims + up + down + norm_groups + """ + super().__init__() + self.channels = channels + self.dropout = dropout + self.use_embed = use_embed + if use_embed: + if not isinstance(emb_channels, int): + raise TypeError( + f"Invalid type for emb_channels: expected integer, got {type(emb_channels).__name__}. " + f"Received value: {emb_channels}. " + "Please provide an integer value for the embedding channels." + ) + self.emb_channels = emb_channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.SequentialCell( + nn.GroupNorm( + num_groups=norm_groups if channels % norm_groups == 0 else channels, + num_channels=channels, + ), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + if use_embed: + self.emb_layers = nn.SequentialCell( + nn.SiLU(), + nn.Dense( + in_channels=emb_channels, + out_channels=( + 2 * self.out_channels + if use_scale_shift_norm + else self.out_channels + ), + ), + ) + self.out_layers = nn.SequentialCell( + nn.GroupNorm( + num_groups=( + norm_groups + if self.out_channels % norm_groups == 0 + else self.out_channels + ), + num_channels=self.out_channels, + ), + nn.SiLU(), + nn.Dropout(p=dropout), + # nn.Dropout(p=0), + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1), + ) + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + self.reset_parameters() + + def construct(self, x, emb=None): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + Parameters + ---------- + x: an [N x C x ...] Tensor of features. 
+ emb: an [N x emb_channels] Tensor of timestep embeddings. + + Returns + ------- + out: an [N x C x ...] Tensor of outputs. + """ + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + if self.use_embed: + emb_out = self.emb_layers(emb).astype(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = ops.chunk(emb_out, 2, axis=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + else: + h = self.out_layers(h) + n = self.skip_connection(x) + h + return n + + def reset_parameters(self): + for _, cell in self.cells_and_names(): + apply_initialization(cell) + for p in self.out_layers[-1].get_parameters(): + p.set_data(ops.zeros(p.shape, dtype=p.dtype)) diff --git a/MindEarth/applications/nowcasting/PreDiff/src/knowledge_alignment/alignment.py b/MindEarth/applications/nowcasting/PreDiff/src/knowledge_alignment/alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..ec4b7f1608bcfbee9891e0879d2d6b7e237c97d5 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/knowledge_alignment/alignment.py @@ -0,0 +1,594 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"NoisyCuboidTransformerEncoder" +import math +import numpy as np + +import mindspore as ms +from mindspore import nn, ops, mint +from mindspore.common.initializer import TruncatedNormal + +from src.utils import ( + conv_nd, + zero_module, + timestep_embedding, + apply_initialization, + round_to, + self_axial +) +from src.diffusion import ( + PatchMerging3D, + PosEmbed, + StackCuboidSelfAttentionBlock, + TimeEmbedLayer, + TimeEmbedResBlock, +) + + +class QKVAttention(nn.Cell): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def construct(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + if width % (3 * self.n_heads) != 0: + raise ValueError( + f"Dimension mismatch: width ({width}) must be divisible by {3 * self.n_heads} " + f"(3 * n_heads), but got remainder {width % (3 * self.n_heads)}. " + f"Current configuration: n_heads={self.n_heads}. " + "Please adjust either the input width or the number of attention heads." 
+ ) + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + q_transposed = ops.transpose( + (q * scale).view(bs * self.n_heads, ch, length), (0, 2, 1) + ) + k_reshaped = (k * scale).view(bs * self.n_heads, ch, length) + weight = ops.BatchMatMul()(q_transposed, k_reshaped) + weight = nn.Softmax(axis=-1)(weight.float()).type(weight.dtype) + weight_transposed = ops.transpose(weight, (0, 2, 1)) + v_reshaped = v.reshape(bs * self.n_heads, ch, length) + a = ops.BatchMatMul()(v_reshaped, weight_transposed) + return a.reshape(bs, -1, length) + + +class AttentionPool3d(nn.Cell): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + data_dim: int, + embed_dim: int, + num_heads: int, + output_dim: int = None, + init_mode: str = "0", + ): + r""" + Parameters + ---------- + data_dim: int + e.g. T*H*W if data is 3D + embed_dim: int + input data channels + num_heads: int + output_dim: int + """ + super().__init__() + self.positional_embedding = ms.Parameter( + ops.randn(embed_dim, data_dim + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = num_heads + self.attention = QKVAttention(self.num_heads) + self.init_mode = init_mode + + def construct(self, x): + r""" + + Parameters + ---------- + x: ms.Tensor + layout = "NCTHW" + + Returns + ------- + ret: ms.Tensor + layout = "NC" + """ + b, c, _ = x.shape + x = x.reshape(b, c, -1) + x = mint.cat([x.mean(axis=-1, keep_dims=True), x], dim=-1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + def reset_parameters(self): + '''set parameters''' + apply_initialization(self.qkv_proj, conv_mode="0") + apply_initialization(self.c_proj, conv_mode=self.init_mode) + + +class NoisyCuboidTransformerEncoder(nn.Cell): + r""" + Half U-Net style CuboidTransformerEncoder that parametrizes `U(z_t, t, ...)`. + It takes `x_t`, `t` as input. + The conditioning can be concatenated to the input like the U-Net in FVD paper. + + For each block, we apply the StackCuboidSelfAttention. The final block state is read out by a pooling layer. + + x --> attn --> downscale --> ... --> poll --> out + + Besides, we insert the embeddings of the timesteps `t` before each cuboid attention blocks. 
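+
+    When `block_units` is not given, the per-block widths grow with the downsample
+    factor, e.g. base_units=128, downsample=2, scale_alpha=1.0 and three blocks
+    yield block_units = [128, 256, 512].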
+ """ + + def __init__( + self, + input_shape=None, + out_channels=1, + base_units=128, + block_units=None, + scale_alpha=1.0, + depth=None, + downsample=2, + downsample_type="patch_merge", + use_attn_pattern=None, + block_cuboid_size=None, + block_cuboid_strategy=None, + block_cuboid_shift_size=None, + num_heads=4, + attn_drop=0.0, + proj_drop=0.0, + ffn_drop=0.0, + ffn_activation="gelu", + gated_ffn=False, + norm_layer="layer_norm", + use_inter_ffn=True, + hierarchical_pos_embed=False, + padding_type="zeros", + use_relative_pos=True, + self_attn_use_final_proj=True, + # global vectors + num_global_vectors=0, + use_global_vector_ffn=True, + use_global_self_attn=False, + separate_global_qkv=False, + global_dim_ratio=1, + # initialization + attn_linear_init_mode="0", + ffn_linear_init_mode="0", + ffn2_linear_init_mode="2", + attn_proj_linear_init_mode="2", + conv_init_mode="0", + down_linear_init_mode="0", + global_proj_linear_init_mode="2", + norm_init_mode="0", + # timestep embedding for diffusion + time_embed_channels_mult=4, + time_embed_use_scale_shift_norm=False, + time_embed_dropout=0.0, + # readout + pool: str = "attention", + readout_seq: bool = True, + t_out: int = None, + ): + super().__init__() + # initialization mode + self.attn_linear_init_mode = attn_linear_init_mode + self.ffn_linear_init_mode = ffn_linear_init_mode + self.ffn2_linear_init_mode = ffn2_linear_init_mode + self.attn_proj_linear_init_mode = attn_proj_linear_init_mode + self.conv_init_mode = conv_init_mode + self.down_linear_init_mode = down_linear_init_mode + self.global_proj_linear_init_mode = global_proj_linear_init_mode + self.norm_init_mode = norm_init_mode + + self.input_shape = input_shape + self.out_channels = out_channels + self.num_blocks = len(depth) + self.depth = depth + self.base_units = base_units + self.scale_alpha = scale_alpha + self.downsample = downsample + self.downsample_type = downsample_type + if not isinstance(downsample, (tuple, list)): + downsample = (1, downsample, downsample) + if block_units is None: + block_units = [ + round_to(base_units * int((max(downsample) ** scale_alpha) ** i), 4) + for i in range(self.num_blocks) + ] + else: + if len(block_units) != self.num_blocks: + raise ValueError( + f"Block configuration mismatch: Expected {self.num_blocks} blocks, " + f"but got {len(block_units)}. Received block_units: {block_units}" + ) + + if block_units[0] != base_units: + raise ValueError( + f"First block units mismatch: Expected {base_units}, " + f"but got {block_units[0]}. The first block must match base_units." 
+ ) + self.block_units = block_units + self.hierarchical_pos_embed = hierarchical_pos_embed + self.num_global_vectors = num_global_vectors + use_global_vector = num_global_vectors > 0 + self.use_global_vector = use_global_vector + self.global_dim_ratio = global_dim_ratio + self.use_global_vector_ffn = use_global_vector_ffn + + self.time_embed_channels_mult = time_embed_channels_mult + self.time_embed_channels = self.block_units[0] * time_embed_channels_mult + self.time_embed_use_scale_shift_norm = time_embed_use_scale_shift_norm + self.time_embed_dropout = time_embed_dropout + self.pool = pool + self.readout_seq = readout_seq + self.t_out = t_out + + if self.use_global_vector: + self.init_global_vectors = ms.Parameter( + mint.zeros((self.num_global_vectors, global_dim_ratio * base_units)) + ) + + _, h_in, w_in, _ = input_shape + self.first_proj = TimeEmbedResBlock( + channels=input_shape[-1], + emb_channels=None, + dropout=proj_drop, + out_channels=self.base_units, + use_conv=False, + use_embed=False, + use_scale_shift_norm=False, + dims=3, + up=False, + down=False, + ) + self.pos_embed = PosEmbed( + embed_dim=base_units, + max_t=input_shape[0], + max_h=h_in, + max_w=w_in, + ) + + # diffusion time embed + self.time_embed = TimeEmbedLayer( + base_channels=self.block_units[0], + time_embed_channels=self.time_embed_channels, + ) + if self.num_blocks > 1: + if downsample_type == "patch_merge": + self.downsample_layers = nn.CellList( + [ + PatchMerging3D( + dim=self.block_units[i], + downsample=downsample, + padding_type=padding_type, + out_dim=self.block_units[i + 1], + linear_init_mode=down_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for i in range(self.num_blocks - 1) + ] + ) + else: + raise NotImplementedError + if self.use_global_vector: + self.down_layer_global_proj = nn.CellList( + [ + mint.nn.Linear( + in_features=global_dim_ratio * self.block_units[i], + out_features=global_dim_ratio * self.block_units[i + 1], + ) + for i in range(self.num_blocks - 1) + ] + ) + if self.hierarchical_pos_embed: + self.down_hierarchical_pos_embed_l = nn.CellList( + [ + PosEmbed( + embed_dim=self.block_units[i], + max_t=self.mem_shapes[i][0], + max_h=self.mem_shapes[i][1], + max_w=self.mem_shapes[i][2], + ) + for i in range(self.num_blocks - 1) + ] + ) + + if use_attn_pattern: + block_attn_patterns = self.depth + block_cuboid_size = [] + block_cuboid_strategy = [] + block_cuboid_shift_size = [] + for idx, _ in enumerate(block_attn_patterns): + cuboid_size, strategy, shift_size = self_axial(self.mem_shapes[idx]) + block_cuboid_size.append(cuboid_size) + block_cuboid_strategy.append(strategy) + block_cuboid_shift_size.append(shift_size) + else: + if not isinstance(block_cuboid_size[0][0], (list, tuple)): + block_cuboid_size = [block_cuboid_size for _ in range(self.num_blocks)] + else: + if len(block_cuboid_size) != self.num_blocks: + raise ValueError( + f"Block cuboid configuration error: Expected {self.num_blocks} blocks, " + f"but received {len(block_cuboid_size)} block configurations. " + f"Received block_cuboid_size: {block_cuboid_size}\n" + "Please ensure the number of block configurations matches num_blocks." + ) + + if not isinstance(block_cuboid_strategy[0][0], (list, tuple)): + block_cuboid_strategy = [ + block_cuboid_strategy for _ in range(self.num_blocks) + ] + else: + if len(block_cuboid_strategy) != self.num_blocks: + raise ValueError( + f"Block strategy configuration error: Expected {self.num_blocks} strategies (one per block), " + f"but received {len(block_cuboid_strategy)}. 
" + f"Received strategies: {block_cuboid_strategy}\n" + "Each cuboid block must have a corresponding processing strategy." + ) + + if not isinstance(block_cuboid_shift_size[0][0], (list, tuple)): + block_cuboid_shift_size = [ + block_cuboid_shift_size for _ in range(self.num_blocks) + ] + else: + if len(block_cuboid_shift_size) != self.num_blocks: + raise ValueError( + f"Block shift configuration error: Expected {self.num_blocks} shift sizes (one per block), " + f"but received {len(block_cuboid_shift_size)}. " + f"Received shift sizes: {block_cuboid_shift_size}\n" + "Each cuboid block must have a corresponding shift size configuration." + ) + self.block_cuboid_size = block_cuboid_size + self.block_cuboid_strategy = block_cuboid_strategy + self.block_cuboid_shift_size = block_cuboid_shift_size + + # cuboid self attention blocks + down_self_blocks = [] + # ResBlocks that incorporate `time_embed` + down_time_embed_blocks = [] + for i in range(self.num_blocks): + down_time_embed_blocks.append( + TimeEmbedResBlock( + channels=self.mem_shapes[i][-1], + emb_channels=self.time_embed_channels, + dropout=self.time_embed_dropout, + out_channels=self.mem_shapes[i][-1], + use_conv=False, + use_embed=True, + use_scale_shift_norm=self.time_embed_use_scale_shift_norm, + dims=3, + up=False, + down=False, + ) + ) + + ele_depth = depth[i] + + stack_cuboid_blocks = [ + StackCuboidSelfAttentionBlock( + dim=self.mem_shapes[i][-1], + num_heads=num_heads, + block_cuboid_size=block_cuboid_size[i], + block_strategy=block_cuboid_strategy[i], + block_shift_size=block_cuboid_shift_size[i], + attn_drop=attn_drop, + proj_drop=proj_drop, + ffn_drop=ffn_drop, + activation=ffn_activation, + gated_ffn=gated_ffn, + norm_layer=norm_layer, + use_inter_ffn=use_inter_ffn, + padding_type=padding_type, + use_global_vector=use_global_vector, + use_global_vector_ffn=use_global_vector_ffn, + use_global_self_attn=use_global_self_attn, + separate_global_qkv=separate_global_qkv, + global_dim_ratio=global_dim_ratio, + use_relative_pos=use_relative_pos, + use_final_proj=self_attn_use_final_proj, + # initialization + attn_linear_init_mode=attn_linear_init_mode, + ffn_linear_init_mode=ffn_linear_init_mode, + ffn2_linear_init_mode=ffn2_linear_init_mode, + attn_proj_linear_init_mode=attn_proj_linear_init_mode, + norm_init_mode=norm_init_mode, + ) + for _ in range(ele_depth) + ] + down_self_blocks.append(nn.CellList(stack_cuboid_blocks)) + + self.down_self_blocks = nn.CellList(down_self_blocks) + self.down_time_embed_blocks = nn.CellList(down_time_embed_blocks) + + out_shape = self.mem_shapes[-1] + cuboid_out_channels = out_shape[-1] + if pool == "adaptive": + self.out = nn.SequentialCell( + nn.GroupNorm(min(cuboid_out_channels, 32), cuboid_out_channels), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(2, cuboid_out_channels, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + if readout_seq: + data_dim = np.prod(out_shape[1:-1]).item() + num_global_vectors + else: + data_dim = np.prod(out_shape[:-1]).item() + num_global_vectors + self.out = nn.SequentialCell( + nn.GroupNorm(min(cuboid_out_channels, 32), cuboid_out_channels), + nn.SiLU(), + AttentionPool3d( + data_dim, + cuboid_out_channels, + num_heads, + out_channels, + init_mode="0", + ), + ) + elif pool == "spatial": + self.out = nn.SequentialCell( + mint.nn.Linear(self._feature_size, 2048), + mint.nn.ReLU(), + mint.nn.Linear(2048, out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.SequentialCell( + mint.nn.Linear(self._feature_size, 2048), 
+ mint.nn.GroupNorm(2048, 2048), + nn.SiLU(), + mint.nn.Linear(2048, out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + self.reset_parameters() + + def reset_parameters(self): + """set parameters""" + if self.num_global_vectors > 0: + TruncatedNormal(self.init_global_vectors, sigma=0.02) + self.first_proj.reset_parameters() + self.pos_embed.reset_parameters() + # inner U-Net + for block in self.down_self_blocks: + for m in block: + m.reset_parameters() + for m in self.down_time_embed_blocks: + m.reset_parameters() + if self.num_blocks > 1: + for m in self.downsample_layers: + m.reset_parameters() + if self.use_global_vector: + apply_initialization( + self.down_layer_global_proj, + linear_mode=self.global_proj_linear_init_mode, + ) + if self.hierarchical_pos_embed: + for m in self.down_hierarchical_pos_embed_l: + m.reset_parameters() + if self.pool == "attention": + apply_initialization(self.out[0], norm_mode=self.norm_init_mode) + self.out[2].reset_parameters() + else: + raise NotImplementedError + + def transpose_and_first_proj(self, x, batch_size): + """transpose and first_proj""" + x = x.transpose(0, 4, 1, 2, 3) + x = self.first_proj(x) + x = x.transpose(0, 2, 3, 4, 1) + if self.use_global_vector: + global_vectors = self.init_global_vectors.broadcast_to( + batch_size, + self.num_global_vectors, + self.global_dim_ratio * self.base_units, + ) + return x, global_vectors + return x, None + + @property + def mem_shapes(self): + """Get the shape of the output memory based on the input shape. This can be used for constructing the decoder. + + Returns + ------- + mem_shapes + A list of shapes of the output memory + """ + inner_data_shape = tuple(self.input_shape)[:3] + (self.base_units,) + if self.num_blocks == 1: + return [inner_data_shape] + mem_shapes = [inner_data_shape] + curr_shape = inner_data_shape + for down_layer in self.downsample_layers: + curr_shape = down_layer.get_out_shape(curr_shape) + mem_shapes.append(curr_shape) + return mem_shapes + + def construct(self, x, t): + """ + Forward pass through the NoisyCuboidTransformerEncoder. + + Parameters: + - x (Tensor): Input tensor of shape (batch_size, seq_in, H, W, C). + - t (Tensor): Timestep tensor. + + Returns: + - Tensor: Output tensor after processing through the encoder. + """ + batch_size, seq_in, _, _, _ = x.shape + x, global_vectors = self.transpose_and_first_proj(x, batch_size) + x = self.pos_embed(x) + t_emb = self.time_embed(timestep_embedding(t, self.block_units[0])) + for i in range(self.num_blocks): + if i > 0: + x = self.downsample_layers[i - 1](x) + if self.hierarchical_pos_embed: + x = self.down_hierarchical_pos_embed_l[i - 1](x) + if self.use_global_vector: + global_vectors = self.down_layer_global_proj[i - 1](global_vectors) + for idx in range(self.depth[i]): + x = x.transpose(0, 4, 1, 2, 3) + x = self.down_time_embed_blocks[i](x, t_emb) + x = x.transpose(0, 2, 3, 4, 1) + if self.use_global_vector: + x, global_vectors = self.down_self_blocks[i][idx](x, global_vectors) + else: + x = self.down_self_blocks[i][idx](x) + + if self.readout_seq: + if self.t_out is not None: + seq_in = self.t_out + start_idx = x.shape[1] - self.t_out + x = x[:, start_idx:, ...] 
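+            # collapse the spatial grid of every output frame so that the pooling head
+            # reads out one feature vector per time step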
+ out = x.transpose(0, 1, 4, 2, 3) + b_out, t_out, c_out, h_out, w_out = out.shape + out = out.reshape(b_out * t_out, c_out, h_out * w_out) + if self.num_global_vectors > 0: + out_global = global_vectors.tile((seq_in, 1, 1)) + out_global = out_global.transpose(0, 2, 1) + out = mint.cat([out, out_global], dim=2) + out = self.out(out) + out = out.reshape(batch_size, seq_in, -1) + else: + out = x.transpose(0, 4, 1, 2, 3) + b_out, c_out, t_out, h_out, w_out = out.shape + out = out.reshape(b_out, c_out, t_out * h_out * w_out) + if self.num_global_vectors > 0: + out_global = global_vectors.transpose(0, 2, 1) + out = mint.cat([out, out_global], dim=2) + out = self.out(out) + return out diff --git a/MindEarth/applications/nowcasting/PreDiff/src/sevir_dataset.py b/MindEarth/applications/nowcasting/PreDiff/src/sevir_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5b9c48128820cf64ad0b23b4fe3f7258061e373d --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/sevir_dataset.py @@ -0,0 +1,1039 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"generate dataset" +import os +import datetime +from typing import Union, Sequence, Tuple +import h5py +import pandas as pd +from einops import rearrange +import numpy as np + +import mindspore as ms +import mindspore.dataset as ds +import mindspore.dataset.vision.transforms as vision +from mindspore import nn, ops, Tensor +from mindspore.dataset.vision import RandomRotation, Rotate +from mindspore.dataset.transforms import Compose + + +SEVIR_DATA_TYPES = ["vis", "ir069", "ir107", "vil", "lght"] +LIGHTING_FRAME_TIMES = np.arange(-120.0, 125.0, 5) * 60 +SEVIR_DATA_SHAPE = { + "lght": (48, 48), +} +PREPROCESS_SCALE_01 = { + "vis": 1, + "ir069": 1, + "ir107": 1, + "vil": 1 / 255, + "lght": 1, +} +PREPROCESS_OFFSET_01 = { + "vis": 0, + "ir069": 0, + "ir107": 0, + "vil": 0, + "lght": 0, +} + + +def path_splitall(path): + """ + Split a file path into all its components. + + Recursively splits the path into directory components and the final file name, + handling both absolute and relative paths across different OS conventions. + + Args: + path (str): Input file path to split + + Returns: + List[str]: List of path components from root to leaf + """ + allparts = [] + while 1: + parts = os.path.split(path) + if parts[0] == path: + allparts.insert(0, parts[0]) + break + elif parts[1] == path: + allparts.insert(0, parts[1]) + break + else: + path = parts[0] + allparts.insert(0, parts[1]) + return allparts + + +def change_layout(data, in_layout="NHWT", out_layout="NHWT"): + """ + Convert data layout between different dimension orderings. + + Handles layout transformations using einops.rearrange, with special handling + for 'C' (channel) dimensions which are treated as singleton dimensions. 
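+
+    For example (shapes are illustrative only):
+
+        >>> x = np.zeros((4, 384, 384, 25))              # "NHWT"
+        >>> change_layout(x, "NHWT", "NTHWC").shape
+        (4, 25, 384, 384, 1)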
+ + Args: + data (Tensor/ndarray): Input data to transform + in_layout (str): Current dimension order (e.g., "NHWT") + out_layout (str): Target dimension order (e.g., "THWC") + + Returns: + ndarray: Data in new layout with applied transformations + """ + if isinstance(data, ms.Tensor): + data = data.asnumpy() + in_layout = " ".join(in_layout.replace("C", "1")) + out_layout = " ".join(out_layout.replace("C", "1")) + data = rearrange(data, f"{in_layout} -> {out_layout}") + return data + + +class DatasetSEVIR: + """ + SEVIR Dataset class for weather event sequence data. + + Provides data loading and augmentation capabilities for SEVIR (Severe Weather Events Dataset) + with support for different temporal layouts and data preprocessing. + + Attributes: + layout (str): Output data layout configuration + sevir_dataloader (SEVIRDataLoader): Core data loading component + aug_pipeline (AugmentationPipeline): Data augmentation operations + """ + def __init__( + self, + seq_in: int = 25, + raw_seq_in: int = 49, + sample_mode: str = "sequent", + stride: int = 12, + layout: str = "THWC", + ori_layout: str = "NHWT", + split_mode: str = "uneven", + sevir_catalog: Union[str, pd.DataFrame] = None, + sevir_data_dir: str = None, + start_date: datetime.datetime = None, + end_date: datetime.datetime = None, + datetime_filter=None, + catalog_filter="default", + shuffle: bool = False, + shuffle_seed: int = 1, + output_type=np.float32, + preprocess: bool = True, + rescale_method: str = "01", + verbose: bool = False, + aug_mode: str = "0", + ): + super().__init__() + self.layout = layout.replace("C", "1") + self.sevir_dataloader = SEVIRDataLoader( + data_types=[ + "vil", + ], + seq_in=seq_in, + raw_seq_in=raw_seq_in, + sample_mode=sample_mode, + stride=stride, + batch_size=1, + layout=ori_layout, + num_shard=1, + rank=0, + split_mode=split_mode, + sevir_catalog=sevir_catalog, + sevir_data_dir=sevir_data_dir, + start_date=start_date, + end_date=end_date, + datetime_filter=datetime_filter, + catalog_filter=catalog_filter, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + output_type=output_type, + preprocess=preprocess, + rescale_method=rescale_method, + verbose=verbose, + ) + self.aug_mode = aug_mode + self.aug_pipeline = AugmentationPipeline( + self.aug_mode, + self.layout, + ) + + def __getitem__(self, index): + """ + Get processed data sample by index. + + Performs data extraction, augmentation, and layout conversion. + + Args: + index (int): Sample index + + Returns: + ndarray: Processed data in specified layout + """ + data_dict = self.sevir_dataloader.extract_data(index=index) + data = data_dict["vil"] + if self.aug_pipeline is not None: + data = self.aug_pipeline(data_dict) + return data + + def __len__(self): + """len""" + return self.sevir_dataloader.__len__() + + +class SEVIRDataModule(nn.Cell): + """ + DataModule for SEVIR dataset. + + Manages dataset splits (train/val/test), data loading, and augmentation + for training diffusion models on weather event sequences. 
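+
+    A minimal usage sketch (paths, dates and sizes below are illustrative only):
+
+        dm = SEVIRDataModule(sevir_dir="/path/to/sevir", seq_in=25, batch_size=2)
+        dm.setup("fit")
+        train_ds = dm.sevir_train  # batched mindspore GeneratorDataset over "vil"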
+ + Attributes: + sevir_dir (str): Root directory of SEVIR dataset + batch_size (int): Data loader batch size + num_workers (int): Number of data loader workers + aug_mode (str): Data augmentation configuration + layout (str): Data layout configuration + """ + + def __init__( + self, + seq_in: int = 25, + sample_mode: str = "sequent", + stride: int = 12, + layout: str = "NTHWC", + output_type=np.float32, + preprocess: bool = True, + rescale_method: str = "01", + verbose: bool = False, + aug_mode: str = "0", + dataset_name: str = "sevir", + sevir_dir: str = None, + start_date: Tuple[int] = None, + train_val_split_date: Tuple[int] = (2019, 3, 20), + train_test_split_date: Tuple[int] = (2019, 6, 1), + end_date: Tuple[int] = None, + val_ratio: float = 0.1, + batch_size: int = 1, + num_workers: int = 1, + raw_seq_len: int = 25, + seed: int = 0, + ): + super().__init__() + self.sevir_dir = sevir_dir + self.aug_mode = aug_mode + self.seq_in = seq_in + self.sample_mode = sample_mode + self.stride = stride + self.layout = layout.replace("N", "") + self.output_type = output_type + self.preprocess = preprocess + self.rescale_method = rescale_method + self.verbose = verbose + self.aug_mode = aug_mode + self.batch_size = batch_size + self.num_workers = num_workers + self.seed = seed + self.dataset_name = dataset_name + self.sevir_dir = sevir_dir + self.catalog_path = os.path.join(sevir_dir, "CATALOG.csv") + self.raw_data_dir = os.path.join(sevir_dir, "data") + self.raw_seq_in = raw_seq_len + self.start_date = ( + datetime.datetime(*start_date) if start_date is not None else None + ) + self.train_test_split_date = ( + datetime.datetime(*train_test_split_date) + if train_test_split_date is not None + else None + ) + self.train_val_split_date = ( + datetime.datetime(*train_val_split_date) + if train_val_split_date is not None + else None + ) + self.end_date = datetime.datetime(*end_date) if end_date is not None else None + self.val_ratio = val_ratio + + def setup(self, stage=None) -> None: + """ + Prepare dataset splits for different stages. + + Creates train/val/test splits based on date ranges and configuration. + + Args: + stage (str): Current stage ("fit", "test", etc.) 
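+
+        Note:
+            stage in (None, "fit") builds the train and validation splits, and
+            stage in (None, "test") builds the test split; passing None prepares
+            all three.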
+ """ + if stage in (None, "fit"): + print("train") + self.sevir_train_ori = DatasetSEVIR( + sevir_catalog=self.catalog_path, + sevir_data_dir=self.raw_data_dir, + raw_seq_in=self.raw_seq_in, + split_mode="uneven", + shuffle=False, + seq_in=self.seq_in, + stride=self.stride, + sample_mode=self.sample_mode, + layout=self.layout, + start_date=self.start_date, + end_date=self.train_val_split_date, + output_type=self.output_type, + preprocess=self.preprocess, + rescale_method=self.rescale_method, + verbose=self.verbose, + aug_mode=self.aug_mode, + ) + self.sevir_train = ds.GeneratorDataset( + source=self.sevir_train_ori, + column_names="vil", + shuffle=False, + num_parallel_workers=self.num_workers, + ) + self.sevir_train = self.sevir_train.batch(batch_size=self.batch_size) + + if stage in (None, "fit"): + print("val") + self.sevir_val = DatasetSEVIR( + sevir_catalog=self.catalog_path, + sevir_data_dir=self.raw_data_dir, + raw_seq_in=self.raw_seq_in, + split_mode="uneven", + shuffle=False, + seq_in=self.seq_in, + stride=self.stride, + sample_mode=self.sample_mode, + layout=self.layout, + start_date=self.train_val_split_date, + end_date=self.train_test_split_date, + output_type=self.output_type, + preprocess=self.preprocess, + rescale_method=self.rescale_method, + verbose=self.verbose, + aug_mode=self.aug_mode, + ) + self.sevir_val = ds.GeneratorDataset( + source=self.sevir_val, + column_names="vil", + shuffle=False, + num_parallel_workers=self.num_workers, + ) + self.sevir_val = self.sevir_val.batch(batch_size=self.batch_size) + + if stage in (None, "test"): + print("test") + self.sevir_test = DatasetSEVIR( + sevir_catalog=self.catalog_path, + sevir_data_dir=self.raw_data_dir, + raw_seq_in=self.raw_seq_in, + split_mode="uneven", + shuffle=False, + seq_in=self.seq_in, + stride=self.stride, + sample_mode=self.sample_mode, + layout=self.layout, + start_date=self.train_test_split_date, + end_date=self.end_date, + output_type=self.output_type, + preprocess=self.preprocess, + rescale_method=self.rescale_method, + verbose=self.verbose, + aug_mode=self.aug_mode, + ) + self.sevir_test = ds.GeneratorDataset( + source=self.sevir_test, + column_names="vil", + shuffle=False, + num_parallel_workers=self.num_workers, + ) + self.sevir_test = self.sevir_test.batch(batch_size=self.batch_size) + + @property + def num_train_samples(self): + """Get number of training samples""" + return len(self.sevir_train_ori) + + @property + def num_val_samples(self): + """Get number of validation samples""" + return len(self.sevir_val) + + @property + def num_test_samples(self): + """Get number of test samples""" + return len(self.sevir_test) + + +class SEVIRDataLoader: + r""" + DataLoader that loads SEVIR sequences, and spilts each event + into segments according to specified sequence length. 
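+
+    Each event of ``raw_seq_in`` frames yields
+    ``1 + (raw_seq_in - seq_in) // stride`` sub-sequences of length ``seq_in``
+    (see ``num_seq_per_event``).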
+ """ + + def __init__( + self, + data_types: Sequence[str] = None, + seq_in: int = 49, + raw_seq_in: int = 49, + sample_mode: str = "sequent", + stride: int = 12, + batch_size: int = 1, + layout: str = "NHWT", + num_shard: int = 1, + rank: int = 0, + split_mode: str = "uneven", + sevir_catalog: Union[str, pd.DataFrame] = None, + sevir_data_dir: str = None, + start_date: datetime.datetime = None, + end_date: datetime.datetime = None, + datetime_filter=None, + catalog_filter="default", + shuffle: bool = False, + shuffle_seed: int = 1, + output_type=np.float32, + preprocess: bool = True, + rescale_method: str = "01", + verbose: bool = False, + ): + super().__init__() + + # configs which should not be modified + self.lght_frame_times = LIGHTING_FRAME_TIMES + self.data_shape = SEVIR_DATA_SHAPE + + self.raw_seq_in = raw_seq_in + if seq_in > self.raw_seq_in: + raise ValueError( + f"Sequence length violation: Input sequence length ({seq_in}) " + f"exceeds maximum allowed length ({self.raw_seq_in}).\n" + f"Technical constraints: Processed sequence length must be ≤ original length.\n" + "Please check your sequence trimming/padding operations." + ) + self.seq_in = seq_in + if sample_mode not in ["random", "sequent"]: + raise ValueError( + f"Invalid sampling mode: '{sample_mode}'. " + f"Allowed options are: {['random', 'sequent']}\n" + "Please specify either:\n" + "- 'random' for stochastic sampling\n" + "- 'sequent' for sequential sampling" + ) + self.sample_mode = sample_mode + self.stride = stride + self.batch_size = batch_size + valid_layout = ("NHWT", "NTHW", "NTCHW", "NTHWC", "TNHW", "TNCHW") + if layout not in valid_layout: + raise ValueError( + f"Invalid layout = {layout}! Must be one of {valid_layout}." + ) + self.layout = layout + self.num_shard = num_shard + self.rank = rank + valid_split_mode = ("ceil", "floor", "uneven") + if split_mode not in valid_split_mode: + raise ValueError( + f"Invalid split_mode: {split_mode}! Must be one of {valid_split_mode}." + ) + self.split_mode = split_mode + self._samples = None + self._hdf_files = {} + self.data_types = data_types + if isinstance(sevir_catalog, str): + self.catalog = pd.read_csv( + sevir_catalog, parse_dates=["time_utc"], low_memory=False + ) + else: + self.catalog = sevir_catalog + self.sevir_data_dir = sevir_data_dir + self.datetime_filter = datetime_filter + self.catalog_filter = catalog_filter + self.start_date = start_date + self.end_date = end_date + self.shuffle = shuffle + self.shuffle_seed = int(shuffle_seed) + self.output_type = output_type + self.preprocess = preprocess + self.rescale_method = rescale_method + self.verbose = verbose + + if self.start_date is not None: + self.catalog = self.catalog[self.catalog.time_utc > self.start_date] + if self.end_date is not None: + self.catalog = self.catalog[self.catalog.time_utc <= self.end_date] + if self.datetime_filter: + self.catalog = self.catalog[self.datetime_filter(self.catalog.time_utc)] + + if self.catalog_filter is not None: + if self.catalog_filter == "default": + self.catalog_filter = lambda c: c.pct_missing == 0 + self.catalog = self.catalog[self.catalog_filter(self.catalog)] + + self._compute_samples() + print(self._samples.head(n=10)) + print("len", len(self._samples)) + self._open_files(verbose=self.verbose) + self.reset() + + def _compute_samples(self): + """ + Computes the list of samples in catalog to be used. 
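+        Only events with exactly one catalog entry per requested ``img_type`` are kept.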
This sets self._samples + """ + imgt = self.data_types + imgts = set(imgt) + filtcat = self.catalog[ + np.logical_or.reduce([self.catalog.img_type == i for i in imgt]) + ] + filtcat = filtcat.groupby("id").filter( + lambda x: imgts.issubset(set(x["img_type"])) + ) + filtcat = filtcat.groupby("id").filter(lambda x: x.shape[0] == len(imgt)) + self._samples = filtcat.groupby("id").apply( + lambda df: self._df_to_series(df, imgt) + ) + if self.shuffle: + self.shuffle_samples() + + def shuffle_samples(self): + """Shuffle the dataset samples using a fixed random seed for reproducibility.""" + self._samples = self._samples.sample(frac=1, random_state=self.shuffle_seed) + + def _df_to_series(self, df, imgt): + """Convert catalog DataFrame entries to structured format for multi-image types.""" + d = {} + df = df.set_index("img_type") + for i in imgt: + s = df.loc[i] + idx = s.file_index if i != "lght" else s.id + d.update({f"{i}_filename": [s.file_name], f"{i}_index": [idx]}) + + return pd.DataFrame(d) + + def _open_files(self, verbose=True): + """ + Opens HDF files + """ + imgt = self.data_types + hdf_filenames = [] + for t in imgt: + hdf_filenames += list(np.unique(self._samples[f"{t}_filename"].values)) + + print("hdf_filenames", hdf_filenames) + self._hdf_files = {} + for f in hdf_filenames: + print("Opening HDF5 file for reading", f) + if verbose: + print("Opening HDF5 file for reading", f) + self._hdf_files[f] = h5py.File(self.sevir_data_dir + "/" + f, "r") + print("f:", f) + print("self._hdf_files[f]:", self._hdf_files[f]) + + def close(self): + """ + Closes all open file handles + """ + for f in self._hdf_files: + self._hdf_files[f].close() + print("close: ", f) + self._hdf_files = {} + + @property + def num_seq_per_event(self): + """num seq per event""" + return 1 + (self.raw_seq_in - self.seq_in) // self.stride + + @property + def total_num_seq(self): + """ + The total number of sequences within each shard. + Notice that it is not the product of `self.num_seq_per_event` and `self.total_num_event`. + """ + return int(self.num_seq_per_event * self.num_event) + + @property + def total_num_event(self): + """ + The total number of events in the whole dataset, before split into different shards. + """ + return int(self._samples.shape[0]) + + @property + def start_event_idx(self): + """ + The event idx used in certain rank should satisfy event_idx >= start_event_idx + """ + return self.total_num_event // self.num_shard * self.rank + + @property + def end_event_idx(self): + """ + The event idx used in certain rank should satisfy event_idx < end_event_idx + + """ + if self.split_mode == "ceil": + last_start_event_idx = ( + self.total_num_event // self.num_shard * (self.num_shard - 1) + ) + num_event = self.total_num_event - last_start_event_idx + return self.start_event_idx + num_event + if self.split_mode == "floor": + return self.total_num_event // self.num_shard * (self.rank + 1) + if self.rank == self.num_shard - 1: + return self.total_num_event + return self.total_num_event // self.num_shard * (self.rank + 1) + + @property + def num_event(self): + """ + The number of events split into each rank + """ + return self.end_event_idx - self.start_event_idx + + def _read_data(self, row, data): + """ + Iteratively read data into data dict. Finally data[imgt] gets shape (batch_size, height, width, raw_seq_in). + + Parameters + ---------- + row + A series with fields IMGTYPE_filename, IMGTYPE_index, IMGTYPE_time_index. 
+ data + Dict, data[imgt] is a data tensor with shape = (tmp_batch_size, height, width, raw_seq_in). + + Returns + ------- + data + Updated data. Updated shape = (tmp_batch_size + 1, height, width, raw_seq_in). + """ + imgtyps = np.unique([x.split("_")[0] for x in list(row.keys())]) + for t in imgtyps: + fname = row[f"{t}_filename"] + idx = row[f"{t}_index"] + t_slice = slice(0, None) + if t == "lght": + lght_data = self._hdf_files[fname][idx][:] + data_i = self._lght_to_grid(lght_data, t_slice) + else: + data_i = self._hdf_files[fname][t][idx : idx + 1, :, :, t_slice] + data[t] = ( + np.concatenate((data[t], data_i), axis=0) if (t in data) else data_i + ) + + return data + + def _lght_to_grid(self, data, t_slice=slice(0, None)): + """ + Converts Nx5 lightning data matrix into a 2D grid of pixel counts + """ + + out_size = ( + (*self.data_shape["lght"], len(self.lght_frame_times)) + if t_slice.stop is None + else (*self.data_shape["lght"], 1) + ) + if data.shape[0] == 0: + return np.zeros((1,) + out_size, dtype=np.float32) + + x, y = data[:, 3], data[:, 4] + m = np.logical_and.reduce([x >= 0, x < out_size[0], y >= 0, y < out_size[1]]) + data = data[m, :] + if data.shape[0] == 0: + return np.zeros((1,) + out_size, dtype=np.float32) + t = data[:, 0] + if t_slice.stop is not None: + if t_slice.stop > 0: + if t_slice.stop < len(self.lght_frame_times): + tm = np.logical_and( + t >= self.lght_frame_times[t_slice.stop - 1], + t < self.lght_frame_times[t_slice.stop], + ) + else: + tm = t >= self.lght_frame_times[-1] + else: + tm = np.logical_and( + t >= self.lght_frame_times[0], t < self.lght_frame_times[1] + ) + + data = data[tm, :] + z = np.zeros(data.shape[0], dtype=np.int64) + else: + z = np.digitize(t, self.lght_frame_times) - 1 + z[z == -1] = 0 + + x = data[:, 3].astype(np.int64) + y = data[:, 4].astype(np.int64) + + k = np.ravel_multi_index(np.array([y, x, z]), out_size) + n = np.bincount(k, minlength=np.prod(out_size)) + return np.reshape(n, out_size).astype(np.int16)[np.newaxis, :] + + @property + def sample_count(self): + """ + Record how many times self.__next__() is called. + """ + return self._sample_count + + @property + def _curr_event_idx(self): + return self.__curr_event_idx + + @property + def _curr_seq_idx(self): + """ + Used only when self.sample_mode == 'sequent' + """ + return self.__curr_seq_idx + + def _set__curr_event_idx(self, val): + self.__curr_event_idx = val + + def _set__curr_seq_idx(self, val): + """ + Used only when self.sample_mode == 'sequent' + """ + self.__curr_seq_idx = val + + def reset(self, shuffle: bool = None): + """reset""" + self._set__curr_event_idx(val=self.start_event_idx) + self._set__curr_seq_idx(0) + self._sample_count = 0 + if shuffle is None: + shuffle = self.shuffle + if shuffle: + self.shuffle_samples() + + def __len__(self): + """ + Used only when self.sample_mode == 'sequent' + """ + return self.total_num_seq // self.batch_size + + def _load_event_batch(self, event_idx, event_batch_size): + """ + Loads a selected batch of events (not batch of sequences) into memory. + + Parameters + ---------- + idx + event_batch_size + event_batch[i] = all_type_i_available_events[idx:idx + event_batch_size] + Returns + ------- + event_batch + list of event batches. + event_batch[i] is the event batch of the i-th data type. 
+ Each event_batch[i] is a np.ndarray with shape = (event_batch_size, height, width, raw_seq_in) + """ + event_idx_slice_end = event_idx + event_batch_size + pad_size = 0 + if event_idx_slice_end > self.end_event_idx: + pad_size = event_idx_slice_end - self.end_event_idx + event_idx_slice_end = self.end_event_idx + pd_batch = self._samples.iloc[event_idx:event_idx_slice_end] + data = {} + for _, row in pd_batch.iterrows(): + data = self._read_data(row, data) + if pad_size > 0: + event_batch = [] + for t in self.data_types: + pad_shape = [ + pad_size, + ] + list(data[t].shape[1:]) + data_pad = np.concatenate( + ( + data[t].astype(self.output_type), + np.zeros(pad_shape, dtype=self.output_type), + ), + axis=0, + ) + event_batch.append(data_pad) + else: + event_batch = [data[t].astype(self.output_type) for t in self.data_types] + return event_batch + + + def extract_data(self, index): + """ + Extracts a batch of data without any processing. + + Parameters + ---------- + index + The index of the batch to sample. + + Returns + ------- + event_batch + The extracted data from the event batch without any processing. + """ + event_idx = (index * self.batch_size) // self.num_seq_per_event + seq_idx = (index * self.batch_size) % self.num_seq_per_event + num_sampled = 0 + sampled_idx_list = [] + while num_sampled < self.batch_size: + sampled_idx_list.append({"event_idx": event_idx, "seq_idx": seq_idx}) + seq_idx += 1 + if seq_idx >= self.num_seq_per_event: + event_idx += 1 + seq_idx = 0 + num_sampled += 1 + + start_event_idx = sampled_idx_list[0]["event_idx"] + event_batch_size = sampled_idx_list[-1]["event_idx"] - start_event_idx + 1 + + event_batch = self._load_event_batch( + event_idx=start_event_idx, event_batch_size=event_batch_size + ) + ret_dict = {} + for sampled_idx in sampled_idx_list: + batch_slice = [ + sampled_idx["event_idx"] - start_event_idx, + ] + seq_slice = slice( + sampled_idx["seq_idx"] * self.stride, + sampled_idx["seq_idx"] * self.stride + self.seq_in, + ) + for imgt_idx, imgt in enumerate(self.data_types): + sampled_seq = event_batch[imgt_idx][batch_slice, :, :, seq_slice] + if imgt in ret_dict: + ret_dict[imgt] = np.concatenate( + (ret_dict[imgt], sampled_seq), axis=0 + ) + else: + ret_dict.update({imgt: sampled_seq}) + + return ret_dict + + +class AugmentationPipeline: + """Data augmentation pipeline for multi-frame image processing. + """ + def __init__( + self, + aug_mode="0", + layout=None, + ): + self.layout = layout + self.aug_mode = aug_mode + + if aug_mode == "0": + self.aug = lambda x: x + elif self.aug_mode == "1": + self.aug = Compose( + [ + vision.RandomHorizontalFlip(), + vision.RandomVerticalFlip(), + RandomRotation(degrees=180), + ] + ) + elif aug_mode == "2": + self.aug = Compose( + [ + vision.RandomHorizontalFlip(), + vision.RandomVerticalFlip(), + FixedAngleRotation(angles=[0, 90, 180, 270]), + ] + ) + else: + raise NotImplementedError + + def rearrange_tensor(self, tensor, from_layout, to_layout): + """Permute and reshape tensor dimensions according to layout specifications.""" + return tensor.permute(*tuple(range(len(from_layout)))).reshape(to_layout) + + def __call__(self, data_dict): + """Apply augmentation pipeline to input data dictionary. 
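+
+        The "vil" entry is squeezed to (H, W, T); for aug modes "1"/"2" it is first
+        rearranged to (T, H, W) so flips/rotations act frame-wise, and finally the
+        result is rearranged into ``self.layout``.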
+ + Args: + data_dict (dict): Input data containing "vil" key with tensor data + + Returns: + ms.Tensor: Processed tensor with applied augmentations and layout conversion + """ + data = data_dict["vil"].squeeze(0) + if self.aug_mode != "0": + data = rearrange( + data, + "H W T -> T H W", + ) + data = self.aug(data) + data = rearrange(data, f"{' '.join('THW')} -> {' '.join(self.layout)}") + else: + data = rearrange( + data, + f"{' '.join('HWT')} -> {' '.join(self.layout)}", + ) + + return data + + +class FixedAngleRotation: + """Image augmentation for rotating images by fixed predefined angles. + + Args: + angles (List[int]): List of allowed rotation angles (degrees) + """ + def __init__(self, angles=None): + self.angles = angles + + def __call__(self, img): + """Apply random rotation from predefined angles. + + Args: + img (PIL.Image or mindspore.Tensor): Input image to transform + + Returns: + PIL.Image or mindspore.Tensor: Rotated image with same format as input + """ + angle = np.random.choice(self.angles) + return Rotate(angle)(img) + + +class SEVIRDataset: + """Base dataset class for processing SEVIR data with configurable preprocessing. + + Args: + data_types (Sequence[str], optional): + List of data types to process (e.g., ["vil", "lght"]). Defaults to SEVIR_DATA_TYPES. + layout (str, optional): + Tensor layout specification containing dimensions: + N - batch size + H - height + W - width + T - time/sequence length + C - channel + Defaults to "NHWT". + rescale_method (str, optional): + Data rescaling strategy identifier (e.g., "01" for 0-1 normalization). Defaults to "01". + """ + def __init__( + self, + data_types: Sequence[str] = None, + layout: str = "NHWT", + rescale_method: str = "01", + ): + super().__init__() + if data_types is None: + data_types = SEVIR_DATA_TYPES + else: + if not set(data_types).issubset(SEVIR_DATA_TYPES): + invalid_types = set(data_types) - set(SEVIR_DATA_TYPES) + raise ValueError( + f"Invalid data type(s) detected: {sorted(invalid_types)}\n" + f"Allowed SEVIR data types are: {sorted(SEVIR_DATA_TYPES)}\n" + "Please remove or replace the unsupported data types." + ) + + self.layout = layout + self.data_types = data_types + self.rescale_method = rescale_method + + @staticmethod + def preprocess_data_dict(data_dict, data_types=None, layout="NHWT"): + """ + Parameterss + ---------- + data_dict: Dict[str, Union[np.ndarray, ms.Tensor]] + data_types: Sequence[str] + The data types that we want to rescale. This mainly excludes "mask" from preprocessing. + layout: str + consists of batch_size 'N', seq_in 'T', channel 'C', height 'H', width 'W' + Returns + ------- + data_dict: Dict[str, Union[np.ndarray, ms.Tensor]] + preprocessed data + """ + scale_dict = PREPROCESS_SCALE_01 + offset_dict = PREPROCESS_OFFSET_01 + if data_types is None: + data_types = data_dict.keys() + for key, data in data_dict.items(): + if key in data_types: + if isinstance(data, np.ndarray): + data = data.astype(np.float32) + elif isinstance(data, ms.Tensor): + data = data.float() + else: + raise TypeError + data = change_layout( + data=scale_dict[key] * (data + offset_dict[key]), + in_layout="NHWT", + out_layout=layout, + ) + data_dict[key] = data + return data_dict + + @staticmethod + def data_dict_to_tensor(data_dict, data_types=None): + """ + Convert each element in data_dict to ms.Tensor (copy without grad). 
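+
+        NumPy arrays are wrapped with ``Tensor.from_numpy``; existing Tensors and
+        keys outside ``data_types`` are passed through unchanged.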
+ """ + ret_dict = {} + if data_types is None: + data_types = data_dict.keys() + for key, data in data_dict.items(): + if key in data_types: + if isinstance(data, ms.Tensor): + ret_dict[key] = data + elif isinstance(data, np.ndarray): + ret_dict[key] = Tensor.from_numpy(data) + else: + raise ValueError( + f"Invalid data type: {type(data)}. Should be ms.Tensor or np.ndarray" + ) + else: + ret_dict[key] = data + return ret_dict + + def process_data(self, data_dict): + """ + Processes the extracted data. + + Parameters + ---------- + data_dict + The dictionary containing the extracted data. + + Returns + ------- + processed_dict + The dictionary containing the processed data. + """ + split_tensors = data_dict.split(1, axis=0) + processed_tensors = [ + self.process_singledata(tensor) for tensor in split_tensors + ] + tensor_list = [] + for item in processed_tensors: + numpy_array = item["vil"] + tensor = Tensor(numpy_array) + tensor_list.append(tensor) + output_tensor = ops.Stack(axis=0)(tensor_list) + return output_tensor + + def process_singledata(self, singledata): + """process singledata""" + squeezed_tensor = ops.squeeze(singledata, 0) + singledata = {"vil": squeezed_tensor} + processed_dict = self.data_dict_to_tensor( + data_dict=singledata, data_types=self.data_types + ) + processed_dict = self.preprocess_data_dict( + data_dict=processed_dict, + data_types=self.data_types, + layout=self.layout, + ) + return processed_dict diff --git a/MindEarth/applications/nowcasting/PreDiff/src/utils.py b/MindEarth/applications/nowcasting/PreDiff/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1908b6a13b9e8b4b2da3935003a4af8cee1d40d2 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/utils.py @@ -0,0 +1,1091 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"all util" +import math +import os +import shutil +import re +from copy import deepcopy +from inspect import isfunction +from typing import Dict, Any, Callable, Optional, Sequence +import numpy as np +import cv2 +from einops import repeat, rearrange + +import mindspore as ms +from mindspore import ops, mint, nn, Parameter, Tensor +from mindspore.train.metrics.metric import Metric +from mindspore.common.initializer import ( + initializer, + One, + Zero, + HeNormal, + Uniform, + TruncatedNormal, +) +from mindearth.utils import create_logger + + +PREPROCESS_SCALE_01 = { + "vis": 1, + "ir069": 1, + "ir107": 1, + "vil": 1 / 255, + "lght": 1, +} +PREPROCESS_OFFSET_01 = { + "vis": 0, + "ir069": 0, + "ir107": 0, + "vil": 0, + "lght": 0, +} + + +class DiagonalGaussianDistribution(nn.Cell): + """Diagonal Gaussian distribution layer for variational autoencoders. + + This class represents a diagonal Gaussian distribution parameterized by mean and log-variance, + supporting sampling, KL divergence computation, and negative log-likelihood evaluation. 
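+
+    ``parameters`` is split in half along axis 1 into ``mean`` and ``logvar``
+    (clamped to [-30, 20]). Against a standard normal, the KL term reduces to
+    0.5 * sum(mean**2 + var - 1 - logvar) over dims (1, 2, 3).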
+ + Attributes: + mean (Tensor): Mean values of the distribution + logvar (Tensor): Clamped log-variance values + std (Tensor): Standard deviation derived from logvar + var (Tensor): Variance derived from logvar + deterministic (bool): Flag indicating deterministic sampling mode + """ + def __init__(self, parameters, deterministic=False): + super().__init__() + self.parameters = parameters + self.mean, self.logvar = ops.chunk(parameters, 2, axis=1) + self.logvar = ops.clamp(self.logvar, -30.0, 20.0) + + self.deterministic = deterministic + self.std = ops.exp(0.5 * self.logvar) + + self.var = ops.exp(self.logvar) + + if self.deterministic: + self.var = self.std = ops.zeros_like(self.mean) + + def sample(self): + """Generate a sample from the distribution. + + Returns: + Tensor: Sampled tensor with same shape as mean + + Notes: + - If deterministic=True, returns mean directly without noise + - Uses reparameterization trick for differentiable sampling + """ + sample = mint.randn(self.mean.shape) + + x = self.mean + self.std * sample + return x + + def kl(self, other=None): + """Compute KL divergence between this distribution and another or standard normal.""" + if self.deterministic: + return ms.Tensor([0.0]) + if other is None: + return 0.5 * ops.sum( + ops.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3] + ) + return 0.5 * ops.sum( + ops.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def mode(self): + """Return the mode of the distribution (mean value).""" + return self.mean + + +def _threshold(target, pred, t_input): + """Apply thresholding to target and prediction tensors.""" + t = (target >= t_input).float() + p = (pred >= t_input).float() + is_nan = ops.logical_or(ops.isnan(target), ops.isnan(pred)) + t[is_nan] = 0 + p[is_nan] = 0 + return t, p + + +@staticmethod +def process_data_dict_back(data_dict, data_types=None): + """Rescale and offset data in dictionary using predefined parameters. + + Applies normalization using scale and offset values from global dictionaries. + + Args: + data_dict (dict): Dictionary containing data tensors + data_types (list, optional): Keys to process. Defaults to all keys in data_dict. + rescale (str, optional): Rescaling mode identifier. Defaults to "01". + + Returns: + dict: Processed data dictionary with normalized values + """ + scale_dict = PREPROCESS_SCALE_01 + offset_dict = PREPROCESS_OFFSET_01 + if data_types is None: + data_types = data_dict.keys() + for key in data_types: + data = data_dict[key] + data = data.float() / scale_dict[key] - offset_dict[key] + data_dict[key] = data + return data_dict + + +class SEVIRSkillScore(Metric): + """ + Class for calculating meteorological skill scores using threshold-based metrics. + This metric class computes performance metrics like CSI, POD, etc., + across multiple thresholds for weather prediction evaluation. 
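+    For each threshold, predictions and targets are binarized and a contingency
+    table of hits, misses and false alarms (fas) is accumulated, from which e.g.
+    csi = hits / (hits + misses + fas) and pod = hits / (hits + misses) follow
+    (each with a small eps added for numerical stability).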
+ Args: + layout (str): Data dimension layout specification (default "NHWT") + mode (str): Operation mode affecting dimension handling ("0", "1", or "2") + seq_in (Optional[int]): Input sequence length (required for modes 1/2) + preprocess_type (str): Data preprocessing method ("sevir" or "sevir_pool*") + threshold_list (Sequence[int]): List of thresholds for binary classification + metrics_list (Sequence[str]): List of metrics to compute (csi, bias, sucr, pod) + eps (float): Small value to prevent division by zero + """ + def __init__( + self, + layout: str = "NHWT", + mode: str = "0", + seq_in: Optional[int] = None, + preprocess_type: str = "sevir", + threshold_list: Sequence[int] = (16, 74, 133, 160, 181, 219), + metrics_list: Sequence[str] = ("csi", "bias", "sucr", "pod"), + eps: float = 1e-4, + ): + super().__init__() + self.layout = layout + if not (preprocess_type == "sevir" or preprocess_type.startswith("sevir_pool")): + raise ValueError( + f"Invalid preprocessing type: '{preprocess_type}'\n" + "Allowed options:\n" + "- 'sevir' for standard processing\n" + "- 'sevir_pool*' for pooled variants (e.g. 'sevir_pool4')" + ) + self.preprocess_type = preprocess_type + self.threshold_list = threshold_list + self.metrics_list = metrics_list + self.eps = eps + self.mode = mode + self.seq_in = seq_in + if mode in ("0",): + self.keep_seq_in_dim = False + state_shape = (len(self.threshold_list),) + elif mode in ("1", "2"): + self.keep_seq_in_dim = True + if not isinstance(self.seq_in, int): + raise TypeError( + f"Invalid type for seq_in: expected integer, got {type(self.seq_in).__name__}. " + "This parameter is required when preserving the sequence dimension." + ) + state_shape = (len(self.threshold_list), self.seq_in) + + else: + raise NotImplementedError(f"mode {mode} not supported!") + + self.hits = Parameter(ops.zeros(state_shape), name="hits") + self.misses = Parameter(ops.zeros(state_shape), name="misses") + self.fas = Parameter(ops.zeros(state_shape), name="fas") + + @property + def hits_misses_fas_reduce_dims(self): + """Dimensions to reduce when calculating metric statistics. + + Returns: + list[int]: List of dimensions to collapse during metric computation + """ + if not hasattr(self, "_hits_misses_fas_reduce_dims"): + seq_dim = self.layout.find("T") + self._hits_misses_fas_reduce_dims = list(range(len(self.layout))) + if self.keep_seq_in_dim: + self._hits_misses_fas_reduce_dims.pop(seq_dim) + return self._hits_misses_fas_reduce_dims + + def clear(self): + """Clear the internal states.""" + self.hits.set_data(mint.zeros_like(self.hits)) + self.misses.set_data(mint.zeros_like(self.misses)) + self.fas.set_data(mint.zeros_like(self.fas)) + + @staticmethod + def pod(hits, misses, _, eps): + """Probability of Detection""" + return hits / (hits + misses + eps) + + @staticmethod + def sucr(hits, _, fas, eps): + """Probability of hits""" + return hits / (hits + fas + eps) + + @staticmethod + def csi(hits, misses, fas, eps): + """critical success index""" + return hits / (hits + misses + fas + eps) + + @staticmethod + def bias(hits, misses, fas, eps): + """Bias score""" + bias = (hits + fas) / (hits + misses + eps) + logbias = ops.pow(bias / ops.log(Tensor(2.0)), 2.0) + return logbias + + def calc_seq_hits_misses_fas(self, pred, target, threshold): + """Calculate contingency table statistics for given threshold. 
+ + Args: + pred (Tensor): Model prediction tensor + target (Tensor): Ground truth tensor + threshold (int): Threshold value for binarization + + Returns: + tuple[Tensor, Tensor, Tensor]: Hits, misses, false alarms + """ + t, p = _threshold(target, pred, threshold) + hits = ops.sum(t * p, dim=self.hits_misses_fas_reduce_dims).int() + misses = ops.sum(t * (1 - p), dim=self.hits_misses_fas_reduce_dims).int() + fas = ops.sum((1 - t) * p, dim=self.hits_misses_fas_reduce_dims).int() + return hits, misses, fas + + def preprocess(self, pred, target): + """Apply data preprocessing based on configuration. + + Handles SEVIR-specific normalization and optional spatial pooling. + + Args: + pred (Tensor): Raw model predictions + target (Tensor): Raw ground truth data + + Returns: + tuple[Tensor, Tensor]: Processed prediction and target tensors + """ + if self.preprocess_type == "sevir": + pred = process_data_dict_back(data_dict={"vil": pred.float()})["vil"] + target = process_data_dict_back(data_dict={"vil": target.float()})["vil"] + elif self.preprocess_type.startswith("sevir_pool"): + pred = process_data_dict_back(data_dict={"vil": pred.float()})["vil"] + target = process_data_dict_back(data_dict={"vil": target.float()})["vil"] + self.pool_scale = int(re.search(r"\d+", self.preprocess_type).group()) + batch_size = target.shape[0] + pred = rearrange( + pred, f"{self.einops_layout} -> {self.einops_spatial_layout}" + ) + target = rearrange( + target, f"{self.einops_layout} -> {self.einops_spatial_layout}" + ) + max_pool = nn.MaxPool2d( + kernel_size=self.pool_scale, stride=self.pool_scale, pad_mode="pad" + ) + pred = max_pool(pred) + target = max_pool(target) + pred = rearrange( + pred, + f"{self.einops_spatial_layout} -> {self.einops_layout}", + N=batch_size, + ) + target = rearrange( + target, + f"{self.einops_spatial_layout} -> {self.einops_layout}", + N=batch_size, + ) + else: + raise NotImplementedError + return pred, target + + def update(self, pred: Tensor, target: Tensor): + """Update metric statistics with new batch of predictions.""" + pred, target = self.preprocess(pred, target) + for i, threshold in enumerate(self.threshold_list): + hits, misses, fas = self.calc_seq_hits_misses_fas(pred, target, threshold) + self.hits[i] += hits + self.misses[i] += misses + self.fas[i] += fas + + def eval(self): + """Compute final metric scores across all thresholds.""" + metrics_dict = { + "pod": self.pod, + "csi": self.csi, + "sucr": self.sucr, + "bias": self.bias, + } + ret = {} + for threshold in self.threshold_list: + ret[threshold] = {} + ret["avg"] = {} + for metrics in self.metrics_list: + if self.keep_seq_in_dim: + score_avg = np.zeros((self.seq_in,)) + else: + score_avg = 0 + scores = metrics_dict[metrics](self.hits, self.misses, self.fas, self.eps) + scores = scores.asnumpy() + for i, threshold in enumerate(self.threshold_list): + if self.keep_seq_in_dim: + score = scores[i] + else: + score = scores[i].item() + if self.mode in ("0", "1"): + ret[threshold][metrics] = score + elif self.mode in ("2",): + ret[threshold][metrics] = np.mean(score).item() + else: + raise NotImplementedError + score_avg += score + score_avg /= len(self.threshold_list) + if self.mode in ("0", "1"): + ret["avg"][metrics] = score_avg + elif self.mode in ("2",): + ret["avg"][metrics] = np.mean(score_avg).item() + else: + raise NotImplementedError + return ret + + +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + """Generate beta schedule for diffusion models. 
+ + Supports linear, cosine, sqrt_linear and sqrt schedules. + + Args: + schedule (str): Schedule type ("linear", "cosine", etc.) + n_timestep (int): Number of time steps + linear_start (float): Linear schedule start value + linear_end (float): Linear schedule end value + cosine_s (float): Cosine schedule shift parameter + + Returns: + Tensor: Beta values for each time step + """ + if schedule == "linear": + betas = ( + mint.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=ms.float64 + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ops.arange(n_timestep + 1, dtype=ms.float64) / n_timestep + cosine_s + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = ops.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = ops.linspace(linear_start, linear_end, n_timestep) + elif schedule == "sqrt": + betas = ops.linspace(linear_start, linear_end, n_timestep) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.asnumpy() + + +def extract_into_tensor(a, t, x_shape, batch_axis=0): + """Extract tensor elements and reshape to match target dimensions.""" + batch_size = t.shape[0] + out = a.gather_elements(-1, t) + out_shape = [ + 1, + ] * len(x_shape) + out_shape[batch_axis] = batch_size + return out.reshape(out_shape) + + +def noise_like(shape): + """Generate random noise tensor matching given shape.""" + return ops.randn(shape) + + +def default(val, d): + """Return val if present, otherwise resolve default value.""" + if val is not None: + return val + return d() if isfunction(d) else d + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = ops.exp( + -math.log(max_period) + * ops.arange(start=0, end=half, dtype=ms.float32) + / half + ) + args = timesteps[:, None].float() * freqs[None] + embedding = ops.cat([ops.cos(args), ops.sin(args)], axis=-1) + if dim % 2: + embedding = ops.cat([embedding, ops.zeros_like(embedding[:, :1])], axis=-1) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for param in module.trainable_params(): + param.set_data(Zero()(shape=param.shape, dtype=param.dtype)) + return module + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + num_groups = min(32, channels) + return nn.GroupNorm(num_groups, channels) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs, pad_mode="pad", has_bias=True) + if dims == 2: + return nn.Conv2d(*args, **kwargs, pad_mode="pad", has_bias=True) + return mint.nn.Conv3d(*args, **kwargs) + + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Dense(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. 
+ """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + if dims == 2: + return mint.nn.AvgPool2d(*args, **kwargs) + if dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + + +def round_to(dat, c): + """round to""" + return dat + (dat - dat % c) % c + + +def get_activation(act, inplace=False, **kwargs): + """ + + Parameters + ---------- + act + Name of the activation + inplace + Whether to perform inplace activation + + Returns + ------- + activation_layer + The activation + """ + if act is None: + return lambda x: x + if isinstance(act, str): + if act == "leaky": + negative_slope = kwargs.get("negative_slope", 0.1) + return nn.LeakyReLU(negative_slope, inplace=inplace) + if act == "identity": + return nn.Identity() + if act == "elu": + return nn.ELU(inplace=inplace) + if act == "gelu": + return nn.GELU(approximate=False) + if act == "relu": + return nn.ReLU() + if act == "sigmoid": + return nn.Sigmoid() + if act == "tanh": + return nn.Tanh() + if act in ('softrelu', 'softplus'): + return ops.Softplus() + if act == "softsign": + return nn.Softsign() + raise NotImplementedError('act="{}" is not supported. ') + return act + + +def get_norm_layer( + norm_type: str = "layer_norm", + axis: int = -1, + epsilon: float = 1e-5, + in_channels: int = 0, + **kwargs, +): + """Get the normalization layer based on the provided type + + Parameters + ---------- + norm_type + The type of the layer normalization from ['layer_norm'] + axis + The axis to normalize the + epsilon + The epsilon of the normalization layer + in_channels + Input channel + + Returns + ------- + norm_layer + The layer normalization layer + """ + if isinstance(norm_type, str): + if norm_type == "layer_norm": + if in_channels <= 0: + raise ValueError( + f"Invalid number of input channels: {in_channels}. " + "in_channels must be a positive integer." + ) + + if axis != -1: + raise ValueError( + f"Invalid axis specification: {axis}. " + "This operation only supports axis=-1 (last dimension)." + ) + norm_layer = nn.LayerNorm( + normalized_shape=[in_channels], epsilon=epsilon, **kwargs + ) + else: + raise NotImplementedError("norm_type={} is not supported".format(norm_type)) + return norm_layer + if norm_type is None: + return nn.Identity() + raise NotImplementedError("The type of normalization must be str") + + +def generalize_padding(x, pad_t, pad_h, pad_w, padding_type, t_pad_left=False): + """ + + Parameters + ---------- + x + Shape (B, T, H, W, C) + pad_t + pad_h + pad_w + padding_type + t_pad_left + + Returns + ------- + out + The result after padding the x. Shape will be (B, T + pad_t, H + pad_h, W + pad_w, C) + """ + if pad_t == 0 and pad_h == 0 and pad_w == 0: + return x + + if padding_type not in ["zeros", "ignore", "nearest"]: + raise ValueError( + f"Invalid padding type: '{padding_type}'. " + f"Allowed options are: {['zeros', 'ignore', 'nearest']}\n" + "Please specify one of:\n" + "- 'zeros': pad with zero values\n" + "- 'ignore': maintain original values\n" + "- 'nearest': replicate edge values" + ) + _, t, h, w, _ = x.shape + + if padding_type == "nearest": + return ops.interpolate( + x.permute(0, 4, 1, 2, 3), size=(t + pad_t, h + pad_h, w + pad_w) + ).permute(0, 2, 3, 4, 1) + if t_pad_left: + return ops.pad(x, (0, 0, 0, pad_w, 0, pad_h, pad_t, 0)) + return ops.pad(x, (0, 0, 0, pad_w, 0, pad_h, 0, pad_t)) + + +def generalize_unpadding(x, pad_t, pad_h, pad_w, padding_type): + """Removes padding from a 5D tensor based on specified padding type and dimensions. 
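+
+    This is the inverse of ``generalize_padding``: for "zeros" and "ignore" the
+    padded rows are sliced off the end of each axis, while "nearest" resizes the
+    tensor back via interpolation.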
+ + Args: + x (Tensor): Input tensor with shape (batch, time, height, width, channels). + pad_t (int): Number of time steps to remove from the end. + pad_h (int): Number of height units to remove from the end. + pad_w (int): Number of width units to remove from the end. + padding_type (str): Type of padding removal method ("zeros", "ignore", "nearest"). + + Returns: + Tensor: Processed tensor with padding removed according to specified method. + + Raises: + AssertionError: If invalid padding_type is provided. + """ + if padding_type not in ["zeros", "ignore", "nearest"]: + raise ValueError( + f"Invalid padding_type: '{padding_type}'. " + f"Supported padding types are: 'zeros', 'ignore', 'nearest'" + ) + _, t, h, w, _ = x.shape + if pad_t == 0 and pad_h == 0 and pad_w == 0: + return x + + if padding_type == "nearest": + return ops.interpolate( + x.permute(0, 4, 1, 2, 3), size=(t - pad_t, h - pad_h, w - pad_w) + ).permute(0, 2, 3, 4, 1) + return x[:, : (t - pad_t), : (h - pad_h), : (w - pad_w), :] + + +def _calculate_fan_in_and_fan_out(parameter): + """Calculates fan_in and fan_out values for neural network weight initialization.""" + dimensions = parameter.ndim + if dimensions < 2: + raise ValueError( + "Fan in and fan out can not be computed for parameter with fewer than 2 dimensions" + ) + num_input_fmaps = parameter.shape[1] + num_output_fmaps = parameter.shape[0] + receptive_field_size = 1 + if dimensions > 2: + for s in parameter.shape[2:]: + receptive_field_size *= s + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + return fan_in, fan_out + + +def apply_initialization( + cell, linear_mode="0", conv_mode="0", norm_mode="0", embed_mode="0" +): + """Applies parameter initialization strategies to neural network layers. + + Args: + cell (nn.Cell): Neural network layer to initialize. + linear_mode (str): Initialization mode for dense layers ("0", "1", "2"). + conv_mode (str): Initialization mode for convolutional layers ("0", "1", "2"). + norm_mode (str): Initialization mode for normalization layers ("0"). + embed_mode (str): Initialization mode for embedding layers ("0"). + + Raises: + NotImplementedError: If unsupported initialization mode is requested. 
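+
+    Note:
+        Cells that do not match any of the handled types are left unchanged.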
+    """
+    if isinstance(cell, nn.Dense):
+        if linear_mode in ("0",):
+            cell.weight.set_data(
+                initializer(
+                    HeNormal(mode="fan_in", nonlinearity="linear"),
+                    cell.weight.shape,
+                    cell.weight.dtype,
+                )
+            )
+        elif linear_mode in ("1",):
+            cell.weight.set_data(
+                initializer(
+                    HeNormal(mode="fan_out", nonlinearity="leaky_relu"),
+                    cell.weight.shape,
+                    cell.weight.dtype,
+                )
+            )
+        elif linear_mode in ("2",):
+            zeros_tensor = ops.zeros(cell.weight.shape, cell.weight.dtype)
+            cell.weight.set_data(zeros_tensor)
+        else:
+            raise NotImplementedError
+        if hasattr(cell, "bias") and cell.bias is not None:
+            zeros_tensor = ops.zeros(cell.bias.shape, cell.bias.dtype)
+            cell.bias.set_data(zeros_tensor)
+
+    elif isinstance(
+        cell, (nn.Conv2d, nn.Conv3d, nn.Conv2dTranspose, nn.Conv3dTranspose)
+    ):
+        if conv_mode in ("0",):
+            cell.weight.set_data(
+                initializer(
+                    HeNormal(
+                        negative_slope=math.sqrt(5), mode="fan_out", nonlinearity="relu"
+                    ),
+                    cell.weight.shape,
+                    cell.weight.dtype,
+                )
+            )
+            if cell.has_bias:
+                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight)
+                if fan_in != 0:
+                    bound = 1 / math.sqrt(fan_in)
+                    cell.bias.set_data(
+                        initializer(Uniform(bound), cell.bias.shape, cell.bias.dtype)
+                    )
+        elif conv_mode in ("1",):
+            cell.weight.set_data(
+                initializer(
+                    HeNormal(
+                        mode="fan_out", nonlinearity="leaky_relu", negative_slope=0.1
+                    ),
+                    cell.weight.shape,
+                    cell.weight.dtype,
+                )
+            )
+            if hasattr(cell, "bias") and cell.bias is not None:
+                cell.bias.set_data(
+                    initializer(Zero(), cell.bias.shape, cell.bias.dtype)
+                )
+        elif conv_mode in ("2",):
+            cell.weight.set_data(
+                initializer(Zero(), cell.weight.shape, cell.weight.dtype)
+            )
+            if hasattr(cell, "bias") and cell.bias is not None:
+                cell.bias.set_data(
+                    initializer(Zero(), cell.bias.shape, cell.bias.dtype)
+                )
+        else:
+            raise NotImplementedError
+
+    elif isinstance(cell, nn.GroupNorm):
+        if norm_mode in ("0",):
+            if cell.gamma is not None:
+                cell.gamma.set_data(
+                    initializer(One(), cell.gamma.shape, cell.gamma.dtype)
+                )
+            if cell.beta is not None:
+                cell.beta.set_data(
+                    initializer(Zero(), cell.beta.shape, cell.beta.dtype)
+                )
+        else:
+            raise NotImplementedError("Normalization mode not supported")
+    elif isinstance(cell, nn.Embedding):
+        if embed_mode == "0":
+            cell.embedding_table.set_data(
+                initializer(
+                    TruncatedNormal(sigma=0.02),
+                    cell.embedding_table.shape,
+                    cell.embedding_table.dtype,
+                )
+            )
+        else:
+            raise NotImplementedError
+    else:
+        pass
+
+
+def prepare_output_directory(base_config, device_id):
+    """Creates/updates output directory for experiment results.
+
+    Args:
+        base_config (dict): Configuration dictionary containing directory paths.
+        device_id (int): Device identifier for directory naming.
+
+    Returns:
+        str: Path to the created/updated output directory.
+
+    Raises:
+        OSError: If directory operations fail unexpectedly.
+    """
+    output_path = os.path.join(
+        base_config["summary"]["summary_dir"], f"single_device{device_id}"
+    )
+
+    try:
+        if os.path.exists(output_path):
+            shutil.rmtree(output_path)
+            print(f"Cleared previous output directory: {output_path}")
+        os.makedirs(output_path, exist_ok=True)
+    except OSError as e:
+        print(f"Directory operation failed: {e}")
+        raise
+    base_config["summary"]["summary_dir"] = output_path
+    return output_path
+
+
+def configure_logging_system(output_dir, config):
+    """Sets up logging system for the application.
+
+    Args:
+        output_dir (str): Directory where logs should be stored.
+ config (dict): Configuration dictionary containing experiment parameters. + + Returns: + Logger: Configured logger instance. + """ + logger = create_logger(path=os.path.join(output_dir, "results.log")) + logger.info(f"Process ID: {os.getpid()}") + logger.info(config["summary"]) + return logger + + +def prepare_dataset(config, module): + """Initializes and prepares the dataset for training/evaluation. + + Args: + config (dict): Configuration dictionary with dataset parameters. + SEVIRPLModule (Module): Data module class for dataset handling. + + Returns: + tuple: (DataModule, total_num_steps) containing initialized data module and total training steps. + + Raises: + ValueError: If configuration is not provided. + """ + if config is not None: + dataset_cfg = config["data"] + total_batch_size = config["optim"]["total_batch_size"] + micro_batch_size = config["optim"]["micro_batch_size"] + max_epochs = config["optim"]["max_epochs"] + else: + raise ValueError("config is required but not provided") + dm = module.get_sevir_datamodule( + dataset_cfg=dataset_cfg, + micro_batch_size=micro_batch_size, + num_workers=8, + ) + dm.setup() + total_num_steps = module.get_total_num_steps( + epoch=max_epochs, + num_samples=dm.num_train_samples, + total_batch_size=total_batch_size, + ) + return dm, total_num_steps + + +def warmup_lambda(warmup_steps, min_lr_ratio=0.1): + """Creates a learning rate warmup schedule as a lambda function. + + Args: + warmup_steps (int): Number of steps for the warmup phase. + min_lr_ratio (float): Minimum learning rate ratio at the start of training. + + Returns: + function: Lambda function that calculates the warmup multiplier based on current step. + """ + def ret_lambda(epoch): + if epoch <= warmup_steps: + return min_lr_ratio + (1.0 - min_lr_ratio) * epoch / warmup_steps + return 1.0 + + return ret_lambda + + +def get_loss_fn(loss: str = "l2") -> Callable: + """ + Returns a loss function based on the provided loss type. + + Args: + loss (str): Type of loss function. Default is "l2". + + Returns: + Callable: A loss function corresponding to the provided loss type. + """ + if loss in ("l2", "mse"): + return nn.MSELoss() + return nn.L1Loss() + + +def disabled_train(self): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def disable_train(model: nn.Cell): + """ + Disable training to avoid error when used in pl.LightningModule + """ + model.set_train(False) + model.train = disabled_train + return model + + +def layout_to_in_out_slice(layout, t_in, t_out=None): + """layout_to_in_out_slice""" + t_axis = layout.find("T") + num_axes = len(layout) + in_slice = [ + slice(None, None), + ] * num_axes + out_slice = deepcopy(in_slice) + in_slice[t_axis] = slice(None, t_in) + if t_out is None: + out_slice[t_axis] = slice(t_in, None) + else: + out_slice[t_axis] = slice(t_in, t_in + t_out) + return in_slice, out_slice + + +def parse_layout_shape(layout: str) -> Dict[str, Any]: + r""" + + Parameters + ---------- + layout: str + e.g., "NTHWC", "NHWC". + + Returns + ------- + ret: Dict + """ + batch_axis = layout.find("N") + t_axis = layout.find("T") + h_axis = layout.find("H") + w_axis = layout.find("W") + c_axis = layout.find("C") + return { + "batch_axis": batch_axis, + "t_axis": t_axis, + "h_axis": h_axis, + "w_axis": w_axis, + "c_axis": c_axis, + } + + +def ssim(img1, img2): + """Compute Structural Similarity Index (SSIM) between two images. 
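+
+    The per-pixel map follows the standard formulation
+    ((2*mu1*mu2 + c1) * (2*sigma12 + c2)) /
+    ((mu1**2 + mu2**2 + c1) * (sigma1**2 + sigma2**2 + c2)),
+    averaged over the valid (unpadded) region.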
+ + Args: + img1 (np.ndarray): First input image (grayscale or single-channel), shape (H, W) + img2 (np.ndarray): Second input image with identical shape to img1 + + Returns: + float: SSIM value between 0 (completely dissimilar) and 1 (perfect similarity) + + Notes: + - Uses 11x11 Gaussian window with σ=1.5 for weighted filtering + - Follows the standard SSIM formulation with constants c1=0.0001, c2=0.0009 + - Computes valid convolution regions (edges truncated by kernel size) + """ + c1 = 0.01**2 + c2 = 0.03**2 + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ( + (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2) + ) + return ssim_map.mean() + + +def calculate_ssim_function(img1, img2): + """calculate ssim function""" + if not img1.shape == img2.shape: + raise ValueError("Input images must have the same dimensions.") + if img1.ndim == 2: + return ssim(img1, img2) + if img1.ndim == 3: + if img1.shape[0] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[i], img2[i])) + return np.array(ssims).mean() + if img1.shape[0] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + raise ValueError("Wrong input image dimensions.") + + + + +def calculate_ssim(videos1, videos2): + """Calculate Structural Similarity Index (SSIM) between two video sequences across all timestamps. + + Args: + videos1 (Tensor or np.ndarray): First video sequence with shape (batch_size, time_steps, + height, width, channels) + videos2 (Tensor or np.ndarray): Second video sequence with identical shape to videos1 + + Returns: + dict[int, float]: Dictionary where keys are timestamp indices and values are the mean SSIM values + across all batches for that timestamp + + Raises: + AssertionError: If input video tensors have different shapes + """ + ssim_results = [] + for video_num in range(videos1.shape[0]): + video1 = videos1[video_num] + video2 = videos2[video_num] + ssim_results_of_a_video = [] + for clip_timestamp in range(len(video1)): + img1 = video1[clip_timestamp] + img2 = video2[clip_timestamp] + ssim_results_of_a_video.append(calculate_ssim_function(img1, img2)) + ssim_results.append(ssim_results_of_a_video) + ssim_results = np.array(ssim_results) + ssim_score = {} + for clip_timestamp in range(len(video1)): + ssim_score[clip_timestamp] = np.mean(ssim_results[:, clip_timestamp]) + + return ssim_score + + +def init_model(module, config, mode): + """Initialize model with ckpt""" + summary_params = config.get("summary") + module.main_model.set_train(True) + if mode != "train": + summary_params["load_ckpt"] = "True" + module.main_model.set_train(False) + if summary_params["load_ckpt"]: + params = ms.load_checkpoint(summary_params.get("ckpt_path")) + ms.load_param_into_net( + module.main_model, params + ) + return module + +def self_axial(input_shape): + """Axial attention implementation from "Axial-Deeplab: + Efficient Convolutional Neural Networks for Semantic Segmentation" + Args: + input_shape (tuple): Input tensor shape (T, H, W, C). 
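A quick sanity check of the SSIM formula used above: for identical inputs the covariance equals both variances and the two means coincide, so the ratio collapses to 1. The sketch below uses a single global window instead of the 11x11 Gaussian filter, so it is a simplified variant of the function, not a re-implementation of it.

```python
# Simplified (single-window) SSIM check; identical inputs must give exactly 1.0.
import numpy as np

def ssim_box(img1, img2, c1=0.01**2, c2=0.03**2):
    mu1, mu2 = img1.mean(), img2.mean()
    s1, s2 = img1.var(), img2.var()
    s12 = ((img1 - mu1) * (img2 - mu2)).mean()
    return ((2 * mu1 * mu2 + c1) * (2 * s12 + c2)) / ((mu1**2 + mu2**2 + c1) * (s1 + s2 + c2))

rng = np.random.default_rng(0)
img = rng.random((64, 64))
print(ssim_box(img, img))                       # 1.0 (perfect similarity)
print(ssim_box(img, rng.random((64, 64))) < 1)  # True for unrelated images
```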
+ Returns: + tuple: Axial attention parameters with separate temporal/spatial cuboids. + """ + t, h, w, _ = input_shape + cuboid_size = [(t, 1, 1), (1, h, 1), (1, 1, w)] + strategy = [("l", "l", "l"), ("l", "l", "l"), ("l", "l", "l")] + shift_size = [(0, 0, 0), (0, 0, 0), (0, 0, 0)] + return cuboid_size, strategy, shift_size diff --git a/MindEarth/applications/nowcasting/PreDiff/src/vae/resnet.py b/MindEarth/applications/nowcasting/PreDiff/src/vae/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b954f9d4595df4e71880b317557987f254d77112 --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/vae/resnet.py @@ -0,0 +1,996 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"resnet model" +from functools import partial +import numpy as np + +import mindspore as ms +from mindspore import Tensor, mint, nn, ops + + +class AvgPool1d(nn.Cell): + """ + 1D average pooling layer implementation with customizable kernel size, stride, and padding. + Performs spatial downsampling by computing average values over sliding windows. + """ + def __init__(self, kernel_size, stride=1, padding=0): + """ + Initialize 1D average pooling parameters with validation checks. + + Args: + kernel_size (int): Length of the pooling window + stride (int): Stride size for window movement (default=1) + padding (int): Zero-padding added to both sides of input (default=0) + + Raises: + ValueError: If kernel_size ≤ 0, stride ≤ 0, or padding < 0 + """ + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.mean = ops.ReduceMean(keep_dims=False) + if stride <= 0: + raise ValueError("stride must be positive") + if kernel_size <= 0: + raise ValueError("kernel_size must be positive") + if padding < 0: + raise ValueError("padding must be non-negative") + + def construct(self, x): + """ + Apply 1D average pooling to input tensor. + """ + input_shape = x.shape + n, c, l_in = input_shape[0], input_shape[1], input_shape[2] + pad_left = self.padding + pad_right = self.padding + x = ops.Pad(((0, 0), (0, 0), (pad_left, pad_right)))(x) + l_in += pad_left + pad_right + l_out = (l_in - self.kernel_size) // self.stride + 1 + output = Tensor(np.zeros((n, c, l_out)), dtype=ms.float32) + for i in range(l_out): + start = i * self.stride + end = start + self.kernel_size + if end <= l_in: + window = x[:, :, start:end] + output[:, :, i] = self.mean(window, -1) + + return output + + +class Upsample1D(nn.Cell): + """ + An upsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. 
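The decomposition returned by `self_axial` simply factorises full 3D attention into three 1D (axial) passes, one per axis, each with a local strategy and no shift. For a concrete input shape:

```python
# Expected output of self_axial for an input of shape (T, H, W, C) = (8, 16, 16, 64).
input_shape = (8, 16, 16, 64)
t, h, w, _ = input_shape
cuboid_size = [(t, 1, 1), (1, h, 1), (1, 1, w)]
print(cuboid_size)   # [(8, 1, 1), (1, 16, 1), (1, 1, 16)]
```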
+ use_conv_transpose: + out_channels: + """ + + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=False, + out_channels=None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.Conv1dTranspose( + channels, + self.out_channels, + kernel_size=4, + stride=2, + pad_mode="pad", + padding=1, + has_bias=True, + ) + elif use_conv: + self.conv = nn.Conv1d( + self.channels, + self.out_channels, + 3, + padding=1, + pad_mode="pad", + has_bias=True, + ) + + def construct(self, x): + """forward""" + if x.shape[1] != self.channels: + raise ValueError( + f"Input channels mismatch. Expected {self.channels} channels, " + f"but got {x.shape[1]} channels in dimension 1 of input tensor." + ) + if self.use_conv_transpose: + return self.conv(x) + + x = ops.interpolate(x, scale_factor=2.0, mode="nearest") + + if self.use_conv: + x = self.conv(x) + + return x + + +class Downsample1D(nn.Cell): + """ + A downsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + out_channels: + padding: + """ + + def __init__( + self, channels, use_conv=False, out_channels=None, padding=1, name="conv" + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = nn.Conv1d( + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + pad_mode="pad", + has_bias=True, + ) + else: + if self.channels != self.out_channels: + raise RuntimeError( + f"Channels mismatch. Expected channels and out_channels to be equal, " + f"but got channels={self.channels}, out_channels={self.out_channels}." + ) + self.conv = AvgPool1d(kernel_size=stride, stride=stride) + + def construct(self, x): + return self.conv(x) + + +class Upsample2D(nn.Cell): + """ + An upsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + use_conv_transpose: + out_channels: + """ + + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=False, + out_channels=None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + conv = nn.Conv2dTranspose( + channels, + self.out_channels, + kernel_size=4, + stride=2, + padding=1, + pad_mode="pad", + has_bias=True, + ) + elif use_conv: + conv = nn.Conv2d( + self.channels, + self.out_channels, + kernel_size=3, + padding=1, + pad_mode="pad", + has_bias=True, + ) + if name == "conv": + self.conv = conv + else: + self.conv2d_0 = conv + + def construct(self, hidden_states, output_size=None): + """forward""" + if hidden_states.shape[1] != self.channels: + raise ValueError( + f"Channel dimension mismatch in hidden states. " + f"Expected {self.channels} channels at dimension 1, " + f"but received {hidden_states.shape[1]} channels. 
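When `Upsample1D` is built without a transposed convolution, its forward pass is nearest-neighbour interpolation with a fixed scale factor of 2, which for an integer factor is the same as repeating each sample along the length axis. A NumPy equivalent of that non-conv path:

```python
# NumPy equivalent of the non-conv path of Upsample1D (nearest-neighbour, scale factor 2).
import numpy as np

x = np.arange(4, dtype=np.float32).reshape(1, 1, 4)   # (batch, channels, length)
up = np.repeat(x, 2, axis=-1)
print(up)        # [[[0. 0. 1. 1. 2. 2. 3. 3.]]]
print(up.shape)  # (1, 1, 8) -> length doubled
```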
" + f"Full shape: {tuple(hidden_states.shape)}" + ) + + + if self.use_conv_transpose: + return self.conv(hidden_states) + + dtype = hidden_states.dtype + if dtype == ms.bfloat16: + hidden_states = hidden_states.to(ms.float32) + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + if output_size is None: + hidden_states = ops.interpolate( + hidden_states, + scale_factor=2.0, + recompute_scale_factor=True, + mode="nearest", + ) + else: + hidden_states = ops.interpolate( + hidden_states, size=output_size, mode="nearest" + ) + + if dtype == ms.bfloat16: + hidden_states = hidden_states.to(dtype) + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.conv2d_0(hidden_states) + + return hidden_states + + +class Downsample2D(nn.Cell): + """ + A downsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + out_channels: + padding: + """ + + def __init__( + self, channels, use_conv=False, out_channels=None, padding=1, name="conv" + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = nn.Conv2d( + self.channels, + self.out_channels, + kernel_size=3, + stride=stride, + padding=padding, + pad_mode="pad", + has_bias=True, + ) + else: + if self.channels != self.out_channels: + raise RuntimeError( + f"Layer configuration conflict. channels ({self.channels}) " + f"must equal out_channels ({self.out_channels}). " + f"Check layer initialization parameters." + ) + + conv = mint.nn.AvgPool2d(kernel_size=stride, stride=stride) + if name == "conv": + self.conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def construct(self, hidden_states): + """forward""" + if hidden_states.shape[1] != self.channels: + raise ValueError( + f"Channel dimension mismatch in hidden states. Expected {self.channels} channels at dimension 1, " + f"but received tensor with {hidden_states.shape[1]} channels. " + f"Full shape: {tuple(hidden_states.shape)}" + ) + + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1) + hidden_states = ops.pad(hidden_states, pad, mode="constant", value=None) + + if hidden_states.shape[1] != self.channels: + raise ValueError( + f"Channel dimension mismatch in hidden states. " + f"Layer expects {self.channels} channels at dimension 1, " + f"but received {hidden_states.shape[1]} channels. " + f"Full tensor shape: {tuple(hidden_states.shape)}\n" + f"Possible solutions:\n" + f"1. Check input data pipeline\n" + f"2. Verify layer configuration (current channels: {self.channels})\n" + f"3. Inspect preceding layer's output channels" + ) + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class FirUpsample2D(nn.Cell): + """ + 2D upsampling layer with optional FIR filtering and convolutional projection. + Implements pixel-shuffle based upsampling with optional convolutional transformation. + """ + def __init__( + self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1) + ): + """ + Initialize upsample layer parameters. 
+ + Args: + channels (int): Number of input channels + out_channels (int): Number of output channels (defaults to input channels if not specified) + use_conv (bool): Whether to apply 3x3 convolution after upsampling + fir_kernel (tuple): FIR filter kernel coefficients for antialiasing + + Raises: + ValueError: If invalid kernel parameters are provided + """ + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.conv2d_0 = nn.Conv2d( + channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + pad_mode="pad", + has_bias=True, + ) + self.use_conv = use_conv + self.fir_kernel = fir_kernel + self.out_channels = out_channels + + def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): + """ + Core upsampling operation with optional convolution and FIR filtering. + """ + + if not isinstance(factor, int): + raise TypeError( + f"Invalid type for 'factor'. Expected integer, " + f"but got {type(factor).__name__}." + ) + + if factor < 1: + raise ValueError( + f"Invalid value for 'factor'. Must be >= 1, " + f"but got {factor}." + ) + + # Setup filter kernel. + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = Tensor(kernel, dtype=ms.float32) + if kernel.ndim == 1: + kernel = ops.outer(kernel, kernel) + kernel /= ops.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + + if self.use_conv: + convh = weight.shape[2] + convw = weight.shape[3] + in_c = weight.shape[1] + + pad_value = (kernel.shape[0] - factor) - (convw - 1) + + stride = (factor, factor) + # Determine data dimensions. + output_shape = ( + (hidden_states.shape[2] - 1) * factor + convh, + (hidden_states.shape[3] - 1) * factor + convw, + ) + output_padding = ( + output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convh, + output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convw, + ) + if len(output_padding) < 2: + raise IndexError( + f"output_padding must have at least 2 elements, " + f"but got {len(output_padding)} elements" + ) + + errors = [] + if output_padding[0] < 0: + errors.append(f"output_padding[0] = {output_padding[0]} < 0") + if output_padding[1] < 0: + errors.append(f"output_padding[1] = {output_padding[1]} < 0") + + if errors: + raise ValueError( + f"Invalid output padding values:\n" + + "\n".join(errors) + + f"\nAll output padding values must be >= 0. " + f"Full output_padding: {output_padding}" + ) + num_groups = hidden_states.shape[1] // in_c + + # Transpose weights. + weight = ops.reshape(weight, (num_groups, -1, in_c, convh, convw)) + weight = ops.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) + weight = ops.reshape(weight, (num_groups * in_c, -1, convh, convw)) + conv_transpose2d = nn.Conv2dTranspose( + weight[0], + weight[1], + (weight[2], weight[3]), + stride=stride, + output_padding=output_padding, + padding=0, + pad_mode="pad", + ) + inverse_conv = conv_transpose2d(hidden_states) + + output = upfirdn2d_native( + inverse_conv, + ms.tensor(kernel), + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), + ) + else: + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + ms.tensor( + kernel, + ), + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + + return output + + def construct(self, hidden_states): + """ + Apply upsampling transformation with optional convolutional projection. 
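The FIR kernel handling in `_upsample_2d` above builds a separable 2D window from the 1D coefficients, normalises it to unit sum, and then rescales it by `gain * factor**2` so that the zero-insertion upsampling preserves intensity. A small NumPy sketch of that kernel construction:

```python
# How the separable FIR kernel is built from (1, 3, 3, 1) for factor=2, gain=1.
import numpy as np

k = np.array([1.0, 3.0, 3.0, 1.0])
window = np.outer(k, k)
window /= window.sum()             # unit sum -> preserves mean intensity
window *= 1.0 * 2**2               # gain * factor**2 compensates zero-insertion upsampling
print(window.shape, window.sum())  # (4, 4) 4.0
```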
+ """ + if self.use_conv: + height = self._upsample_2d( + hidden_states, self.conv2d_0.weight, kernel=self.fir_kernel + ) + height = height + self.conv2d_0.bias.reshape(1, -1, 1, 1) + else: + height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) + + return height + + +class FirDownsample2D(nn.Cell): + """ + 2D downsampling layer with optional FIR filtering and convolutional projection. + Implements anti-aliased downsampling with optional 3x3 convolution. + """ + def __init__( + self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1) + ): + """ + Initialize downsampling layer parameters. + Args: + channels (int): Number of input channels + out_channels (int): Number of output channels (defaults to input channels if not specified) + use_conv (bool): Whether to apply 3x3 convolution before downsampling + fir_kernel (tuple): FIR filter kernel coefficients for antialiasing + + Raises: + ValueError: If invalid kernel parameters are provided + """ + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.conv2d_0 = nn.Conv2d( + channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + pad_mode="pad", + has_bias=True, + ) + self.fir_kernel = fir_kernel + self.use_conv = use_conv + self.out_channels = out_channels + + def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): + """ + Core downsampling operation with optional convolution and FIR filtering. + """ + if not isinstance(factor, int): + raise TypeError( + f"Invalid type for 'factor'. Expected integer, " + f"but got {type(factor).__name__} with value {repr(factor)}." + ) + + if factor < 1: + raise ValueError( + f"Invalid value for 'factor'. Must be a positive integer >= 1, " + f"but got {factor}." + ) + if kernel is None: + kernel = [1] * factor + + # setup kernel + kernel = ms.tensor(kernel, dtype=ms.float32) + if kernel.ndim == 1: + kernel = ms.outer(kernel, kernel) + kernel /= ms.sum(kernel) + + kernel = kernel * gain + + if self.use_conv: + _, _, _, convw = weight.shape + pad_value = (kernel.shape[0] - factor) + (convw - 1) + stride_value = [factor, factor] + upfirdn_input = upfirdn2d_native( + hidden_states, + ms.tensor(kernel), + pad=((pad_value + 1) // 2, pad_value // 2), + ) + output = ops.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) + else: + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + ms.tensor(kernel), + down=factor, + pad=((pad_value + 1) // 2, pad_value // 2), + ) + + return output + + def construct(self, hidden_states): + """ + Apply downsampling transformation with optional convolutional projection. + """ + if self.use_conv: + downsample_input = self._downsample_2d( + hidden_states, weight=self.conv2d_0.weight, kernel=self.fir_kernel + ) + hidden_states = downsample_input + self.conv2d_0.bias.reshape(1, -1, 1, 1) + else: + hidden_states = self._downsample_2d( + hidden_states, kernel=self.fir_kernel, factor=2 + ) + + return hidden_states + + +class ResnetBlock2D(nn.Cell): + """ + 2D ResNet block with optional time embeddings and spatial transformations. + Implements pre-activation residual connections with optional upsampling/downsampling. 
+ """ + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + kernel=None, + output_scale_factor=1.0, + use_in_shortcut=None, + up=False, + down=False, + ): + """ + Initialize ResNet block with configurable normalization and spatial transformations. + + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels (defaults to in_channels) + conv_shortcut (bool): Use 1x1 convolution for shortcut connection + dropout (float): Dropout probability (default=0) + temb_channels (int): Time embedding dimension (default=512) + groups (int): Number of groups for group normalization + groups_out (int): Groups for second normalization layer (defaults to groups) + pre_norm (bool): Apply normalization before non-linearity + eps (float): Epsilon for numerical stability in normalization + non_linearity (str): Activation function type ("swish", "mish", "silu") + time_embedding_norm (str): Time embedding normalization mode ("default" or "scale_shift") + kernel (str): Upsample/downsample kernel type ("fir", "sde_vp") + output_scale_factor (float): Output scaling factor (default=1.0) + use_in_shortcut (bool): Force shortcut connection usage + up (bool): Enable upsampling transformation + down (bool): Enable downsampling transformation + + Raises: + ValueError: If invalid non_linearity or time_embedding_norm values are provided + """ + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.up = up + self.down = down + self.output_scale_factor = output_scale_factor + self.groups = groups + self.in_channels = in_channels + self.eps = eps + if groups_out is None: + groups_out = groups + + self.norm1 = nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + pad_mode="pad", + has_bias=True, + ) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError( + f"unknown time_embedding_norm : {self.time_embedding_norm} " + ) + + self.time_emb_proj = nn.Dense(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = nn.GroupNorm( + num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True + ) + self.dropout = nn.Dropout(p=dropout) + self.conv2 = mint.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if non_linearity == "swish": + self.nonlinearity = ops.silu() + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.upsample = self.downsample = None + if self.up: + if kernel == "fir": + fir_kernel = (1, 3, 3, 1) + self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.upsample = partial( + ops.interpolate, scale_factor=2.0, mode="nearest" + ) + else: + self.upsample = Upsample2D(in_channels, use_conv=False) + elif self.down: + if kernel == 
"fir": + fir_kernel = (1, 3, 3, 1) + self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) + elif kernel == "sde_vp": + self.downsample = partial(mint.nn.AvgPool2d, kernel_size=2, stride=2) + else: + self.downsample = Downsample2D( + in_channels, use_conv=False, padding=1, name="op" + ) + + self.use_in_shortcut = ( + self.in_channels != self.out_channels + if use_in_shortcut is None + else use_in_shortcut + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + pad_mode="pad", + has_bias=True, + ) + + def construct(self, input_tensor, temb): + """ + Forward pass of the ResNet block. + + Args: + input_tensor (Tensor): Input tensor of shape (batch, channels, height, width). + temb (Tensor): Optional time embedding tensor. + + Returns: + Tensor: Output tensor after applying residual block operations. + """ + hidden_states = input_tensor + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + if self.upsample is not None: + input_tensor = self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states) + + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states) + hidden_states = self.conv1(hidden_states) + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + hidden_states = self.norm2(hidden_states) + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = ops.chunk(temb, 2, axis=1) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + return output_tensor + + +class Mish(nn.Cell): + """Implements the Mish activation function: x * tanh(softplus(x)).""" + def __init__(self): + super().__init__() + self.tanh = ops.Tanh() + self.softplus = ops.Softplus() + + def construct(self, hidden_states): + """Compute Mish activation on input tensor.""" + return hidden_states * self.tanh(self.softplus(hidden_states)) + + +def rearrange_dims(tensor): + """ + Adjust tensor dimensions based on input shape: + - 2D → add two singleton dimensions + - 3D → add one singleton dimension + - 4D → squeeze spatial dimensions + + Args: + tensor (Tensor): Input tensor. + + Returns: + Tensor: Dimension-adjusted tensor. + + Raises: + ValueError: If input tensor has invalid dimensions. + """ + if len(tensor.shape) == 2: + return tensor[:, :, None] + if len(tensor.shape) == 3: + return tensor[:, :, None, :] + if len(tensor.shape) == 4: + return tensor[:, :, 0, :] + raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") + + +class Conv1dBlock(nn.Cell): + """ + 1D Convolution block with GroupNorm and Mish activation. + + Args: + inp_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int): Convolution kernel size. + n_groups (int): Number of groups for GroupNorm. Defaults to 8. 
+ """ + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.conv1d = nn.Conv1d( + inp_channels, + out_channels, + kernel_size, + padding=kernel_size // 2, + has_bias=True, + pad_mode="valid", + ) + self.group_norm = nn.GroupNorm(n_groups, out_channels) + self.mish = ops.mish() + + def construct(self, x): + """Apply convolution, normalization, dimension rearrangement and activation.""" + x = self.conv1d(x) + x = rearrange_dims(x) + x = self.group_norm(x) + x = rearrange_dims(x) + x = self.mish(x) + return x + + +class ResidualTemporalBlock1D(nn.Cell): + """ResidualTemporalBlock1D""" + def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): + super().__init__() + self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) + self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) + + self.time_emb_act = nn.Mish() + self.time_emb = nn.Linear(embed_dim, out_channels) + + self.residual_conv = ( + nn.Conv1d(inp_channels, out_channels, 1, has_bias=True, pad_mode="valid") + if inp_channels != out_channels + else nn.Identity() + ) + + def construct(self, x, t): + """ + Args: + x : [ batch_size x inp_channels x horizon ] + t : [ batch_size x embed_dim ] + + returns: + out : [ batch_size x out_channels x horizon ] + """ + t = self.time_emb_act(t) + t = self.time_emb(t) + out = self.conv_in(x) + rearrange_dims(t) + out = self.conv_out(out) + return out + self.residual_conv(x) + + +def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): + """Upsample2D a batch of 2D images with the given filter.""" + if not isinstance(factor, int): + raise TypeError( + f"Invalid type for 'factor'. Expected integer, " + f"but got {type(factor).__name__} with value {repr(factor)}." + ) + + if factor < 1: + raise ValueError( + f"Invalid value for 'factor'. Must be a positive integer >= 1, " + f"but got {factor}." 
+ ) + if kernel is None: + kernel = [1] * factor + + kernel = ms.tensor(kernel, dtype=ms.float32) + if kernel.ndim == 1: + kernel = ms.outer(kernel, kernel) + kernel /= ms.sum(kernel) + + kernel = kernel * (gain * (factor**2)) + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + up=factor, + pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), + ) + return output + + +def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): + """Downsample2D a batch of 2D images with the given filter.""" + if not isinstance(factor, int): + raise TypeError(f"factor must be an integer, got {type(factor).__name__}") + + if factor < 1: + raise ValueError(f"factor must be >= 1, got {factor}") + + if kernel is None: + kernel = [1] * factor + + kernel = ms.tensor(kernel, dtype=ms.float32) + if kernel.ndim == 1: + kernel = ms.outer(kernel, kernel) + kernel /= ms.sum(kernel) + + kernel = kernel * gain + pad_value = kernel.shape[0] - factor + output = upfirdn2d_native( + hidden_states, + down=factor, + pad=((pad_value + 1) // 2, pad_value // 2), + ) + return output + + +def upfirdn2d_native(tensor, kernel=None, up=1, down=1, pad=(0, 0)): + """upfirdn2d native""" + up_x = up_y = up + down_x = down_y = down + pad_x0 = pad_y0 = pad[0] + pad_x1 = pad_y1 = pad[1] + + _, channel, in_h, in_w = tensor.shape + tensor = tensor.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = tensor.shape + kernel_h, kernel_w = kernel.shape + + out = tensor.view(-1, in_h, 1, in_w, 1, minor) + out = ops.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = ops.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = ms.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = ops.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/MindEarth/applications/nowcasting/PreDiff/src/visual.py b/MindEarth/applications/nowcasting/PreDiff/src/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..65d27e6265b14b4f8573daa566707bd1005b411c --- /dev/null +++ b/MindEarth/applications/nowcasting/PreDiff/src/visual.py @@ -0,0 +1,220 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
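The spatial size produced by `upfirdn2d_native` follows `(in * up + pad0 + pad1 - kernel) // down + 1`. With the padding chosen by `upsample_2d` for `factor=2` and a width-4 kernel such as `(1, 3, 3, 1)`, that works out to exactly twice the input size:

```python
# Output-size arithmetic of upfirdn2d_native with the padding used by upsample_2d.
def upfirdn_out(in_size, up, down, pad0, pad1, k):
    return (in_size * up + pad0 + pad1 - k) // down + 1

factor, k = 2, 4
pad_value = k - factor                                          # 2
pad0, pad1 = (pad_value + 1) // 2 + factor - 1, pad_value // 2  # (2, 1)
print(upfirdn_out(64, factor, 1, pad0, pad1, k))                # 128 -> exactly doubled
```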
+# ============================================================================== +"visualization" +import math +from copy import deepcopy +from typing import Optional, Sequence, Union, Dict +from matplotlib import pyplot as plt +from matplotlib.colors import ListedColormap, BoundaryNorm +from matplotlib.font_manager import FontProperties +from matplotlib.patches import Patch +import numpy as np + + +VIL_COLORS = [ + [0, 0, 0], + [0.30196078431372547, 0.30196078431372547, 0.30196078431372547], + [0.1568627450980392, 0.7450980392156863, 0.1568627450980392], + [0.09803921568627451, 0.5882352941176471, 0.09803921568627451], + [0.0392156862745098, 0.4117647058823529, 0.0392156862745098], + [0.0392156862745098, 0.29411764705882354, 0.0392156862745098], + [0.9607843137254902, 0.9607843137254902, 0.0], + [0.9294117647058824, 0.6745098039215687, 0.0], + [0.9411764705882353, 0.43137254901960786, 0.0], + [0.6274509803921569, 0.0, 0.0], + [0.9058823529411765, 0.0, 1.0], +] + +VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0] + + +def vil_cmap(): + """ + Generate a ListedColormap and normalization for VIL (Vertically Integrated Liquid) visualization. + + This function creates a colormap with specific color levels for VIL data visualization. It sets under/over colors + for values outside the defined levels and handles invalid (NaN) values. + + Returns: + tuple: A tuple containing: + - cmap (ListedColormap): Colormap object with defined colors. + - norm (BoundaryNorm): Normalization object based on VIL levels. + - vmin (None): Minimum value for colormap (set to None). + - vmax (None): Maximum value for colormap (set to None). + """ + cols = deepcopy(VIL_COLORS) + lev = deepcopy(VIL_LEVELS) + nil = cols.pop(0) + under = cols[0] + over = cols[-1] + cmap = ListedColormap(cols) + cmap.set_bad(nil) + cmap.set_under(under) + cmap.set_over(over) + norm = BoundaryNorm(lev, cmap.N) + vmin, vmax = None, None + return cmap, norm, vmin, vmax + + +def vis_sevir_seq( + save_path, + seq: Union[np.ndarray, Sequence[np.ndarray]], + label: Union[str, Sequence[str]] = "pred", + norm: Optional[Dict[str, float]] = None, + interval_real_time: float = 10.0, + plot_stride=2, + label_rotation=0, + label_offset=(-0.06, 0.4), + label_avg_int=False, + fs=10, + max_cols=10, +): + """Visualize SEVIR sequence data as a grid of images with colormap and annotations. + Args: + save_path (str): Path to save the output visualization figure. + seq (Union[np.ndarray, Sequence[np.ndarray]]): Input data sequence(s) to visualize. + Can be a single array or list of arrays. + label (Union[str, Sequence[str]], optional): Labels for each sequence. Defaults to "pred". + norm (Optional[Dict[str, float]], optional): Normalization parameters (scale/shift). + Defaults to {"scale": 255, "shift": 0}. + interval_real_time (float, optional): Time interval between frames in real time. Defaults to 10.0. + plot_stride (int, optional): Stride for subsampling frames. Defaults to 2. + label_rotation (int, optional): Rotation angle for y-axis labels. Defaults to 0. + label_offset (tuple, optional): Offset for y-axis label position. Defaults to (-0.06, 0.4). + label_avg_int (bool, optional): Append average intensity to labels. Defaults to False. + fs (int, optional): Font size for text elements. Defaults to 10. + max_cols (int, optional): Maximum number of columns per row. Defaults to 10. + + Raises: + NotImplementedError: If input sequence type is not supported. 
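The `BoundaryNorm` built in `vil_cmap` assigns each pixel (on the 0-255 VIL scale) to one of the ten colour bins delimited by `VIL_LEVELS`. Up to how exact boundary values are treated, `np.digitize` reproduces that binning, which can help when checking which colour a given intensity maps to:

```python
# Reproducing the VIL_LEVELS binning used by the BoundaryNorm in vil_cmap.
import numpy as np

vil_levels = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]
pixels = np.array([5.0, 40.0, 120.0, 250.0])
bins = np.digitize(pixels, vil_levels) - 1
print(bins)   # [0 2 5 9] -> indices into the colour list (after the "nil" colour is removed)
```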
+ + Returns: + None: Saves visualization to disk and closes the figure. + """ + def cmap_dict(): + return { + "cmap": vil_cmap()[0], + "norm": vil_cmap()[1], + "vmin": vil_cmap()[2], + "vmax": vil_cmap()[3], + } + + fontproperties = FontProperties() + fontproperties.set_family("serif") + fontproperties.set_size(fs) + + if isinstance(seq, Sequence): + seq_list = [ele.astype(np.float32) for ele in seq] + if not isinstance(label, Sequence): + raise TypeError( + f"label must be a Sequence (list, tuple, etc.), " + f"got {type(label).__name__}" + ) + + if len(label) != len(seq): + raise ValueError( + f"Length mismatch: label and seq must have same length\n" + f"• len(label) = {len(label)}\n" + f"• len(seq) = {len(seq)}" + ) + label_list = label + elif isinstance(seq, np.ndarray): + seq_list = [ + seq.astype(np.float32), + ] + if not isinstance(label, str): + raise TypeError( + f"Invalid label type. Expected string, " + f"but got {type(label).__name__}. " + f"Value: {repr(label)}" + ) + + label_list = [ + label, + ] + else: + raise NotImplementedError + if label_avg_int: + label_list = [ + f"{ele1}\nAvgInt = {np.mean(ele2): .3f}" + for ele1, ele2 in zip(label_list, seq_list) + ] + seq_list = [ele[::plot_stride, ...] for ele in seq_list] + seq_in_list = [len(ele) for ele in seq_list] + max_len = max(seq_in_list) + max_len = min(max_len, max_cols) + seq_list_wrap = [] + label_list_wrap = [] + seq_in_list_wrap = [] + for i, (processed_seq, processed_label, seq_in) in enumerate(zip(seq_list, label_list, seq_in_list)): + num_row = math.ceil(seq_in / max_len) + for j in range(num_row): + slice_end = min(seq_in, (j + 1) * max_len) + seq_list_wrap.append(processed_seq[j * max_len : slice_end]) + if j == 0: + label_list_wrap.append(processed_label) + else: + label_list_wrap.append("") + seq_in_list_wrap.append(min(seq_in - j * max_len, max_len)) + + if norm is None: + norm = {"scale": 255, "shift": 0} + nrows = len(seq_list_wrap) + fig, ax = plt.subplots(nrows=nrows, ncols=max_len, figsize=(3 * max_len, 3 * nrows)) + + for i, (processed_seq, processed_label, seq_in) in enumerate( + zip(seq_list_wrap, label_list_wrap, seq_in_list_wrap) + ): + ax[i][0].set_ylabel( + ylabel=processed_label, fontproperties=fontproperties, rotation=label_rotation + ) + ax[i][0].yaxis.set_label_coords(label_offset[0], label_offset[1]) + for j in range(0, max_len): + if j < seq_in: + x = processed_seq[j] * norm["scale"] + norm["shift"] + ax[i][j].imshow(x, **cmap_dict()) + if i == len(seq_list) - 1 and i > 0: + ax[-1][j].set_title( + f"Min {int(interval_real_time * (j + 1) * plot_stride)}", + y=-0.25, + fontproperties=fontproperties, + ) + else: + ax[i][j].axis("off") + + for i in range(len(ax)): + for j in range(len(ax[i])): + ax[i][j].xaxis.set_ticks([]) + ax[i][j].yaxis.set_ticks([]) + + num_thresh_legend = len(VIL_LEVELS) - 1 + legend_elements = [ + Patch( + facecolor=VIL_COLORS[i], + label=f"{int(VIL_LEVELS[i - 1])}-{int(VIL_LEVELS[i])}", + ) + for i in range(1, num_thresh_legend + 1) + ] + ax[0][0].legend( + handles=legend_elements, + loc="center left", + bbox_to_anchor=(-1.2, -0.0), + borderaxespad=0, + frameon=False, + fontsize="10", + ) + plt.subplots_adjust(hspace=0.05, wspace=0.05) + plt.savefig(save_path) + plt.close(fig) diff --git a/MindFlow/applications/cfd/acoustic/cbs/cbs.py b/MindFlow/applications/cfd/acoustic/cbs/cbs.py new file mode 100644 index 0000000000000000000000000000000000000000..7496a8997482ed1e923352db325485ef4f0f3b5b --- /dev/null +++ b/MindFlow/applications/cfd/acoustic/cbs/cbs.py @@ -0,0 +1,272 @@ 
+# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The CBS (convergen Born series) API""" +from math import factorial +from time import time as toc +import numpy as np +import mindspore as ms +from mindspore import Tensor, nn, ops, numpy as mnp, lazy_inline + +from mindflow import DFTn, IDFTn + + +class CBSBlock(nn.Cell): + ''' The computation procedures for each iteration in CBS ''' + @lazy_inline + def __init__(self, shape): + ''' + No trainable parameters, but the dft cells needs initialization + Args: + shape: tuple of int, only the spatial shape, not including the batch and channel dimensions + ''' + super().__init__() + self.dft_cell = DFTn(shape) + self.idft_cell = IDFTn(shape) + + # Scattering potential calculation for real and imaginary parts + def op_v(self, ur, ui, vr, vi): + wr = ur * vr - ui * vi + wi = ur * vi + ui * vr + return wr, wi + + # Vectorized Helmholtz Green function for real and imaginary parts + def op_g(self, ur, ui, gr, gi): + fur, fui = self.dft_cell(ur, ui) + gur = gr * fur - gi * fui + gui = gi * fur + gr * fui + wr, wi = self.idft_cell(gur, gui) + return wr, wi + + # Vectorized Born iteration for real and imaginary parts + def construct(self, ur, ui, vr, vi, gr, gi, rhs, eps): + ''' run one iteration and return the incremental ''' + vur, vui = self.op_v(ur, ui, vr, vi) + gvr, gvi = self.op_g(vur + rhs, vui, gr, gi) + vgr, vgi = self.op_v(gvr - ur, gvi - ui, vr, vi) + + # eps > 0: Convergent Born series; eps == 0: Original Born Series + cond = ops.broadcast_to(eps, ur.shape) > 0 + dur = ops.select(cond, -vgi / (eps + 1e-8), gvr - ur) # '* (-1.)' comes from imag part multiplying i/eps + dui = ops.select(cond, vgr / (eps + 1e-8), gvi - ui) + + return ops.stack([dur, dui]) # return a single Tensor for compatibility with nn.SequentialCell + +class CBS(nn.Cell): + ''' The CBS cell for solving 2D acoustic equation ''' + def __init__(self, + shape, + n_iter=20, + pml_size=60, + alpha=1.0, + rampup=12, + remove_pml=True, + epsilon=None, + ): + """Configurations of the CBS solver + + Args: + shape (tuple[int]): only the spatial shape, not including the batch and channel dimensions + n_iter (int, optional): number of iterations in a single call. Defaults to 20. + pml_size (int, optional): number of grid layers to pad on each boundary for the wave to attenuate. + Defaults to 60. + alpha (float, optional): the strength of wave attenuation in PML layers. Defaults to 1.0. + rampup (int, optional): the smoothness of transition from interior domain to PML layers. Defaults to 12. + remove_pml (bool, optional): whether to remove the PML layers for the output. Defaults to True. + epsilon (float, optional): the small value to stabilize the iteration. + Defaults to None, calculating epsilon automatically. 
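Written out, the update computed by `CBSBlock.construct` corresponds, up to the conjugation convention noted in the code comments, to the convergent Born series step with scattering potential `V` and homogeneous Green's function `G` as assembled in `cbs_params`; when `eps == 0` the `i/ε` factor is dropped and the update reduces to the plain Born step `G(Vu + s) − u`.

```latex
% One convergent Born series iteration as implemented in CBSBlock.construct:
u_{k+1} = u_k + \frac{i}{\varepsilon}\, V\!\left[\, G\,(V u_k + s) - u_k \,\right],
\qquad
V(\mathbf{r}) = k^2(\mathbf{r}) - k_0^2 - i\varepsilon,
\qquad
\hat{G}(\mathbf{p}) = \frac{1}{|\mathbf{p}|^2 - k_0^2 - i\varepsilon}.
```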
+ """ + super().__init__() + + self.n_iter = n_iter + self.pml_size = pml_size + self.alpha = alpha + self.rampup = rampup + self.remove_pml = remove_pml + self.epsilon = epsilon + + shape_padded = tuple(n + 2 * pml_size for n in shape) + + dxs = (1.0, 1.0) + p_sq = sum(np.meshgrid( + *[np.fft.fftfreq(n, d)**2 for n, d in zip(shape_padded, dxs)], + indexing="ij")) * (2 * np.pi)**2 + self.p_sq = Tensor(p_sq, dtype=ms.float32, const_arg=True) + + pml_mask = 1 - np.pad(np.ones(shape), pml_size) + self.pml_mask = Tensor(pml_mask, dtype=ms.float32, const_arg=True) + + self.cbs_block = CBSBlock(shape_padded) + + def cbs_params(self, c_star, f_star): + ''' compute constant variables for CBS iteration ''' + pml_size = self.pml_size + nz, nx = c_star.shape[-2:] + dxs = (1.0, 1.0) + omg = 1.0 + + # source field + rhs = ops.pad(f_star / c_star**2, [pml_size] * 4) # (batch, 1, nz_padded, nx_padded) + + # homogeneous k field + k_max = omg / ops.amin(c_star, axis=(-2, -1), keepdims=True) + k_min = omg / ops.amax(c_star, axis=(-2, -1), keepdims=True) + k0 = ops.sqrt(0.5 * (k_max**2 + k_min**2)) # (batch, 1, 1, 1) + + # heterogeneous k field + ksq_r, ksq_i = self.cbs_pml( + (nz, nx), dxs, k_max, pml_size, self.alpha, self.rampup) # (batch, 1, nz_padded, nx_padded) + + ksq_r = ksq_r * self.pml_mask + ops.pad((omg / c_star)**2, [pml_size] * 4) * (1 - self.pml_mask) + ksq_i = ksq_i * self.pml_mask + + eps = ops.amax((ksq_r - k0**2)**2 + ksq_i**2, axis=(-2, -1), keepdims=True)**.5 # (batch, 1, 1, 1) + + # if epsilon given by user, use original BS instead of CBS + if isinstance(self.epsilon, (float, int)): + eps = self.epsilon * ops.ones_like(eps) + + # field variables needed by operator V & G + vr = ksq_r - k0**2 # (batch, 1, nz_padded, nx_padded) + vi = ksq_i - eps # (batch, 1, nz_padded, nx_padded) + gr = 1. / ((self.p_sq - k0**2)**2 + eps**2) * (self.p_sq - k0**2) # (batch, 1, nz_padded, nx_padded) + gi = 1. / ((self.p_sq - k0**2)**2 + eps**2) * eps # (batch, 1, nz_padded, nx_padded) + + return vr, vi, gr, gi, rhs, eps * (self.epsilon is None) + + @staticmethod + def cbs_pml(shape, dxs, k0, pml_size, alpha, rampup): + ''' construct the heterogeneous k field with PML BC embedded ''' + shape_padded = tuple(n + 2 * pml_size for n in shape) + + def num(x): + num_real = (alpha ** 2) * (rampup - alpha * x) * ((alpha * x) ** (rampup - 1)) + num_imag = (alpha ** 2) * (2 * k0 * x) * ((alpha * x) ** (rampup - 1)) + return num_real, num_imag + + def den(x): + return sum([(alpha * x) ** i / float(factorial(i)) for i in range(rampup + 1)]) * factorial(rampup) + + def transform_fun(x): + num_real, num_imag = num(x) + den_x = den(x) + transform_real, transform_imag = num_real / den_x, num_imag / den_x + return transform_real, transform_imag + + diff = ops.stack(mnp.meshgrid( + *[((ops.abs(mnp.linspace(1 - n, n - 1, n)) - n) / 2 + pml_size) * d for n, d in zip(shape_padded, dxs)], + indexing="ij"), axis=0) + + diff *= (diff > 0).astype(ms.float32) / 4. 
+ + dist = ops.norm(diff, dim=0) + k_k0_real, k_k0_imag = transform_fun(dist) + ksq_r = k_k0_real + k0 ** 2 + ksq_i = k_k0_imag + + return ksq_r, ksq_i + + def construct(self, c_star, f_star, ur_init=None, ui_init=None): + ''' + Run the solver to solve non-dimensionalized 2D acoustic equation for given c* and f* + Args: + c_star: float (batch_size, 1, nz, nx), the non-dimensionalized velocity field + f_star: float (batch_size, 1, nz, nx), the mask marking out the source locations + ur_init, ui_init: float (batch_size, 1, NZ, NX), initial wave field for iteration, real & imag parts. + If remove_pml is True, NZ = nz, NX = nx, otherwise NZ = nz + 2 * pml_size, NX = nx + 2 * pml_size. + Default is None, which means initialize from 0. + ''' + vr, vi, gr, gi, rhs, eps = self.cbs_params(c_star, f_star) + + n0 = self.remove_pml * self.pml_size + n1 = (ur_init is None or self.remove_pml) * self.pml_size + n2 = (ui_init is None or self.remove_pml) * self.pml_size + + # construct initial field + if ur_init is None: + ur_init = ops.zeros_like(c_star, dtype=ms.float32) # (batch, 1, nz, nx) + if ui_init is None: + ui_init = ops.zeros_like(c_star, dtype=ms.float32) # (batch, 1, nz, nx) + + # pad initial field + # note: here u_init is conjugated, because the output is also conjugated + ur = ops.pad(ur_init, padding=[n1] * 4, value=0) # note: better padding (with gradual damping) can be applied + ui = ops.pad(-1. * ui_init, padding=[n2] * 4, value=0) # (batch, 1, nz_padded, nx_padded) + + # start iteration + errs_list = [] + + for _ in range(self.n_iter): + dur, dui = self.cbs_block(ur, ui, vr, vi, gr, gi, rhs, eps) + ur += dur + ui += dui + + # calculate iteration residual + errs = (ops.sum(dur**2 + dui**2, dim=(-2, -1)) / ops.sum(ur**2 + ui**2, dim=(-2, -1)))**.5 + errs_list.append(errs) + + # remove pml layer + nz, nx = ur.shape[-2:] + ur = ur[..., n0:nz - n0, n0:nx - n0] + ui = ui[..., n0:nz - n0, n0:nx - n0] + ui *= -1. + # note: the conjugate here is because we define Fourier modes differently to JAX in that the frequencies + # are opposite, leading to opposite attenuation in PML, and finally the conjugation in results + + return ur, ui, errs_list + + def solve(self, + c_star, + f_star, + ur_init=None, + ui_init=None, + tol=1e-3, + max_iter=10000, + remove_pml=True, + print_info=True, + ): + """A convenient method for solving the equation to a given tolerance + + Args: + tol (float, optional): the tolerance for the relative error. Defaults to 1e-3. 
+ """ + msg = 'PML layers cannot be removed during iteration, but can be removed for the final result' + assert not self.remove_pml, msg + + tic = toc() + + ur, ui, errs_list = self(c_star, f_star, ur_init, ui_init) + + for ep in range(max_iter // self.n_iter): + err_max = float(errs_list[-1].max()) + err_min = float(errs_list[-1].min()) + err_ave = float(errs_list[-1].mean()) + + if print_info: + print(f'step {(ep + 1) * self.n_iter}, max error {err_max:.6f}', end=', ') + print(f'min error {err_min:.6f}, mean error {err_ave:.6f}', end=', ') + print(f'mean step time {(toc() - tic) / self.n_iter:.4f}s') + tic = toc() + + if err_max < tol: + break + + ur, ui, errs = self(c_star, f_star, ur, ui) + errs_list += errs + + if remove_pml and self.pml_size: + ur = ur[..., self.pml_size:-self.pml_size, self.pml_size:-self.pml_size] + ui = ui[..., self.pml_size:-self.pml_size, self.pml_size:-self.pml_size] + + return ur, ui, errs_list diff --git a/MindFlow/applications/data_driven/airfoil/2D_unsteady/src/fno2d.py b/MindFlow/applications/data_driven/airfoil/2D_unsteady/src/fno2d.py new file mode 100644 index 0000000000000000000000000000000000000000..8cd1d90be33c5977d52996c374a574f2fb7f9408 --- /dev/null +++ b/MindFlow/applications/data_driven/airfoil/2D_unsteady/src/fno2d.py @@ -0,0 +1,220 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +fno2d +""" +import numpy as np +import mindspore.common.dtype as mstype +from mindspore import ops, nn, Tensor, Parameter +from mindspore.ops import operations as P +from mindspore.common.initializer import Zero + +from mindflow.utils.check_func import check_param_type +from mindflow.core.math import get_grid_2d +from mindflow import RDFTn, IRDFTn + + +class FNO2D(nn.Cell): + r""" + The 2-dimensional Fourier Neural Operator (FNO2D) contains a lifting layer, + multiple Fourier layers and a decoder layer. + The details can be found in `Fourier neural operator for parametric + partial differential equations `_. + + Args: + in_channels (int): The number of channels in the input space. + out_channels (int): The number of channels in the output space. + resolution (int): The spatial resolution of the input. + modes (int): The number of low-frequency components to keep. + channels (int): The number of channels after dimension lifting of the input. Default: 20. + depths (int): The number of FNO layers. Default: 4. + mlp_ratio (int): The number of channels lifting ratio of the decoder layer. Default: 4. + compute_dtype (dtype.Number): The computation type of dense layer. + Default mstype.float16. + Should be mstype.float16 or mstype.float32. + mstype.float32 is recommended for the GPU backend, mstype.float16 is recommended for the Ascend backend. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(batch_size, resolution, resolution, in_channels)`. + + Outputs: + Tensor, the output of this FNO network. 
+ + - **output** (Tensor) - Tensor of shape :math:`(batch_size, resolution, resolution, out_channels)`. + - grid (Tensor) - Tensor of shape :`(1, resolution, resolution, 2)` + + Raises: + TypeError: If `in_channels` is not an int. + TypeError: If `out_channels` is not an int. + TypeError: If `resolution` is not an int. + TypeError: If `modes` is not an int. + ValueError: If `modes` is less than 1. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindspore.common.initializer import initializer, Normal + >>> from mindflow.cell.neural_operators import FNO2D + >>> B, H, W, C = 32, 64, 64, 1 + >>> input = initializer(Normal(), [B, H, W, C]) + >>> net = FNO2D(in_channels=1, out_channels=1, resolution=64, modes=12) + >>> output = net(input) + >>> print(output.shape) + (32, 64, 64, 1) + + """ + + def __init__(self, + in_channels, + out_channels, + resolution, + modes, + channels=20, + depths=4, + mlp_ratio=4, + compute_dtype=mstype.float32): + super().__init__() + check_param_type(in_channels, "in_channels", + data_type=int, exclude_type=bool) + check_param_type(out_channels, "out_channels", + data_type=int, exclude_type=bool) + check_param_type(resolution, "resolution", + data_type=int, exclude_type=bool) + check_param_type(modes, "modes", data_type=int, exclude_type=bool) + if modes < 1: + raise ValueError("modes must at least 1, but got mode: {}".format(modes)) + self.compute_dtype = compute_dtype + + self.modes1 = modes + self.channels = channels + self.fc_channel = mlp_ratio * channels + self.fc0 = nn.Dense(in_channels + 2, self.channels, has_bias=True, + weight_init='Uniform', bias_init='Uniform').to_float(self.compute_dtype) + self.layers = depths + + self.fno_seq = nn.SequentialCell() + for _ in range(self.layers): + self.fno_seq.append(FNOBlock(self.channels, self.channels, modes1=self.modes1, + resolution=resolution, compute_dtype=self.compute_dtype)) + + self.fc1 = nn.Dense(self.channels, 128, has_bias=True, weight_init='Uniform', + bias_init='Uniform').to_float(self.compute_dtype) + self.fc2 = nn.Dense(128, out_channels, has_bias=True, weight_init='Uniform', + bias_init='Uniform').to_float(self.compute_dtype) + + self.grid = Tensor(get_grid_2d(resolution), self.compute_dtype) + self.concat = ops.Concat(axis=-1) + self.act = ops.ReLU() + + def construct(self, x: Tensor): + """forward""" + batch_size = x.shape[0] + grid = self.grid.repeat(batch_size, axis=0) + x = P.Concat(-1)((x, grid)) + x = self.fc0(x) + x = P.Transpose()(x, (0, 3, 1, 2)) + x = self.fno_seq(x) + x = P.Transpose()(x, (0, 2, 3, 1)) + x = self.fc1(x) + x = self.act(x) + output = self.fc2(x) + return output + + +class FNOBlock(nn.Cell): + """FNOBlock""" + def __init__(self, in_channels, out_channels, modes1, resolution=128, compute_dtype=mstype.float32): + super().__init__() + self.compute_dtype = compute_dtype + self.conv = SpectralConv2dDft(in_channels, out_channels, modes1, modes1, resolution, + resolution, compute_dtype=mstype.float32) + self.w = nn.Conv2d(in_channels, out_channels, 1, has_bias=True, + weight_init='HeUniform').to_float(self.compute_dtype) + self.act = ops.ReLU() + + def construct(self, x): + return self.act(self.conv(x) + self.w(x)) + + +class SpectralConv2dDft(nn.Cell): + """SpectralConv2dDft""" + def __init__(self, in_channels, out_channels, modes1, modes2, column_resolution, raw_resolution, + compute_dtype=mstype.float32): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.modes1 = modes1 + self.modes2 = 
modes2 + self.column_resolution = column_resolution + self.raw_resolution = raw_resolution + self.compute_dtype = compute_dtype + self.scale = (1. / (in_channels * out_channels)) + + w_re1 = Tensor(self.scale * np.random.rand(in_channels, out_channels, modes1, modes2), + dtype=mstype.float32) + w_im1 = Tensor(self.scale * np.random.rand(in_channels, out_channels, modes1, modes2), + dtype=mstype.float32) + w_re2 = Tensor(self.scale * np.random.rand(in_channels, out_channels, modes1, modes2), + dtype=mstype.float32) + w_im2 = Tensor(self.scale * np.random.rand(in_channels, out_channels, modes1, modes2), + dtype=mstype.float32) + + self.w_re1 = Parameter(w_re1, requires_grad=True) + self.w_im1 = Parameter(w_im1, requires_grad=True) + self.w_re2 = Parameter(w_re2, requires_grad=True) + self.w_im2 = Parameter(w_im2, requires_grad=True) + self.dft2_cell = RDFTn(shape=(column_resolution, raw_resolution), norm='ortho', + modes=(modes1, modes2), compute_dtype=self.compute_dtype) + self.idft2_cell = IRDFTn(shape=(column_resolution, raw_resolution), norm='ortho', + modes=(modes1, modes2), compute_dtype=self.compute_dtype) + self.mat = Tensor(shape=(1, out_channels, column_resolution - 2 * modes1, modes2), + dtype=self.compute_dtype, init=Zero()) + self.concat = ops.Concat(-2) + + @staticmethod + def mul2d(inputs, weights): + weight = weights.expand_dims(0) + data = inputs.expand_dims(2) + out = weight * data + return out.sum(1) + + def construct(self, x: Tensor): + """forward""" + x_re = x + x_ft_re, x_ft_im = self.dft2_cell(x_re) + + out_ft_re1 = \ + self.mul2d(x_ft_re[:, :, :self.modes1, :self.modes2], self.w_re1) \ + - self.mul2d(x_ft_im[:, :, :self.modes1, :self.modes2], self.w_im1) + out_ft_im1 = \ + self.mul2d(x_ft_re[:, :, :self.modes1, :self.modes2], self.w_im1) \ + + self.mul2d(x_ft_im[:, :, :self.modes1, :self.modes2], self.w_re1) + + out_ft_re2 = \ + self.mul2d(x_ft_re[:, :, -self.modes1:, :self.modes2], self.w_re2) \ + - self.mul2d(x_ft_im[:, :, -self.modes1:, :self.modes2], self.w_im2) + out_ft_im2 = \ + self.mul2d(x_ft_re[:, :, -self.modes1:, :self.modes2], self.w_im2) \ + + self.mul2d(x_ft_im[:, :, -self.modes1:, :self.modes2], self.w_re2) + + batch_size = x.shape[0] + mat = ops.cast(self.mat.repeat(batch_size, 0), self.compute_dtype) + out_re = self.concat((out_ft_re1, mat, out_ft_re2)) + out_im = self.concat((out_ft_im1, mat, out_ft_im2)) + + x = self.idft2_cell(out_re, out_im) + return x diff --git a/MindFlow/mindflow/cell/__init__.py b/MindFlow/mindflow/cell/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb6dc6d8da6ec79601dc247d7023a7824e33bf33 --- /dev/null +++ b/MindFlow/mindflow/cell/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
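The real/imaginary recombination performed in `SpectralConv2dDft.construct` is exactly a complex contraction over input channels of the retained Fourier modes. A NumPy check against a complex `einsum`, with a `mul2d` that mirrors the broadcast-and-sum used above:

```python
# Check that the real/imag arithmetic in SpectralConv2dDft matches a complex
# channel contraction "bixy,ioxy->boxy" over the retained modes.
import numpy as np

rng = np.random.default_rng(0)
b, ci, co, m1, m2 = 2, 3, 5, 4, 4
x = rng.normal(size=(b, ci, m1, m2)) + 1j * rng.normal(size=(b, ci, m1, m2))
w = rng.normal(size=(ci, co, m1, m2)) + 1j * rng.normal(size=(ci, co, m1, m2))

def mul2d(inputs, weights):               # mirrors SpectralConv2dDft.mul2d
    return (weights[None] * inputs[:, :, None]).sum(1)

out_re = mul2d(x.real, w.real) - mul2d(x.imag, w.imag)
out_im = mul2d(x.real, w.imag) + mul2d(x.imag, w.real)
ref = np.einsum("bixy,ioxy->boxy", x, w)
print(np.allclose(out_re, ref.real), np.allclose(out_im, ref.imag))   # True True
```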
+# ============================================================================ +"""init""" +from .activation import get_activation +from .basic_block import LinearBlock, ResBlock, InputScale, FCSequential, MultiScaleFCSequential, DropPath +from .neural_operators import (FNO1D, FNO2D, FNO3D, KNO1D, KNO2D, PDENet, PeRCNN, SNO, SNO1D, SNO2D, SNO3D, FFNO, + FFNO1D, FFNO2D, FFNO3D) +from .attention import Attention, MultiHeadAttention, TransformerBlock +from .vit import ViT +from .unet2d import UNet2D +from .sno_utils import poly_data, get_poly_transform, interpolate_1d_dataset, interpolate_2d_dataset +from .diffusion import DiffusionScheduler, DiffusionTrainer, DDPMScheduler, DDIMScheduler, DDPMPipeline, DDIMPipeline +from .diffusion_transformer import DiffusionTransformer, ConditionDiffusionTransformer + +__all__ = ["get_activation", "FNO1D", "FNO2D", "FNO3D", "KNO1D", "KNO2D", "PDENet", "UNet2D", "PeRCNN", + "SNO", "SNO1D", "SNO2D", "SNO3D", "Attention", "MultiHeadAttention", "TransformerBlock", + "ViT", "DDPMPipeline", "DDIMPipeline", "DiffusionTrainer", "DiffusionScheduler", "DDPMScheduler", + "DDIMScheduler", "DiffusionTransformer", "ConditionDiffusionTransformer", + "FFNO", "FFNO1D", "FFNO2D", "FFNO3D"] +__all__.extend(basic_block.__all__) +__all__.extend(sno_utils.__all__) diff --git a/MindFlow/mindflow/cell/attention.py b/MindFlow/mindflow/cell/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4460809412ccc9301faf7d919b34a997675f1e --- /dev/null +++ b/MindFlow/mindflow/cell/attention.py @@ -0,0 +1,352 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Attention module""" +from typing import Optional +from mindspore import ops, nn, Tensor +import mindspore.common.dtype as mstype + +from .basic_block import DropPath + + +class Attention(nn.Cell): + r"""Attention implementation base class + + Args: + in_channels (int): The dimension of input vector. + num_heads (int): The number of attention heads. + compute_dtype (mindspore.dtype): Compute dtype. Default: ``mstype.float32``, indicates ``mindspore.float32``. + + Inputs: + - **x** (Tensor) - Tensor with shape :math:`(batch\_size, sequence\_len, in\_channels)`. + - **attn_mask** (Tensor, optional) - Tensor with shape :math:`(sequence\_len, sequence\_len)` or + or :math:`(batch\_size, 1, sequence\_len, sequence\_len)`. Default: ``None``. + - **key_padding_mask** (Tensor, optional) - Tensor with shape :math:`(batch\_size, sequence\_len)`. + Default: ``None``. + + Outputs: + - **output** (Tensor) - Tensor with shape :math:`(batch\_size, sequence\_len, in\_channels)`. 
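+
+    Note:
+        `attn_mask` and `key_padding_mask` may be given separately or together; `merge_mask`
+        combines them into a single ``uint8`` mask, and positions with a non-zero mask value are
+        suppressed by adding a large negative constant to the attention scores before the softmax.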
+ + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> from mindspore import ops + >>> from mindflow.cell import Attention + >>> model = Attention(in_channels=512, num_heads=4) + >>> x = ops.rand((2, 32, 512)) + >>> q, k, v = model.get_qkv(x) + >>> print(q.shape) + (2, 4, 32, 128) + """ + + def __init__(self, in_channels: int, num_heads: int, compute_dtype: mstype = mstype.float32): + super().__init__() + self.num_heads = num_heads + self.compute_dtype = compute_dtype + self.qkv = nn.Dense( + in_channels, in_channels * 3, weight_init="XavierUniform" + ).to_float(compute_dtype) + + @staticmethod + def merge_mask(attn_mask: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None) -> Tensor: + """merge mask""" + if attn_mask is None and key_padding_mask is None: + return None + mask = Tensor(0, dtype=mstype.uint8) + if attn_mask is not None: + node = attn_mask.shape[-1] + if len(attn_mask.shape) == 2: + attn_mask = attn_mask.reshape(1, 1, node, node) + elif len(attn_mask.shape) == 4: + pass + else: + raise Exception(f'attn_mask shape {attn_mask.shape} not support') + mask = mask + attn_mask.astype(mstype.uint8) + if key_padding_mask is not None: + batch, node = key_padding_mask.shape[0], key_padding_mask.shape[-1] + if len(key_padding_mask.shape) == 2: + key_padding_mask = ops.broadcast_to(key_padding_mask.unsqueeze(1), (batch, node, node)).unsqueeze(1) + else: + raise Exception(f'key_padding_mask shape {attn_mask.shape} not support') + mask = mask + key_padding_mask.astype(mstype.uint8) + return mask + + @staticmethod + def mask_scores(scores: Tensor, mask: Optional[Tensor] = None) -> Tensor: + """mask attention scores""" + if mask is None: + return scores + scores += mask * Tensor(-1e10, scores.dtype) + return scores + + def get_qkv(self, x: Tensor) -> tuple[Tensor]: + """get qkv value""" + b, n, _ = x.shape + qkv = ( + self.qkv(x).reshape(b, n, 3, self.num_heads, - + 1).transpose((2, 0, 3, 1, 4)) + ) + return qkv[0], qkv[1], qkv[2] + + def _reshape_output(self, x: Tensor) -> Tensor: + b, _, n, _ = x.shape + return x.transpose(0, 2, 1, 3).reshape(b, n, -1) + + def construct(self, x: Tensor, attn_mask: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None): + """Attention network construction.""" + raise NotImplementedError + + +class ScaledDot(nn.Cell): + """Scaled dot attention""" + + def __init__(self, scale): + super().__init__() + self.scale = scale + + def construct(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None): + scores = ops.matmul(query, key.swapaxes(-1, -2)) * self.scale + scores = Attention.mask_scores(scores, mask) + scores = scores.astype(mstype.float32) + attn = ops.softmax(scores, axis=-1) + attn = attn.astype(value.dtype) + output = ops.matmul(attn, value) + return output + + +class FlashAttn(nn.Cell): + r"""FlashAttention proposed in `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness `_. + + Args: + num_heads (int): The number of attention heads. + scale (float): The attention scale. + fa_dtype (mindspore.dtype, optional): FlashAttention compute dtype. Choose from `mstype.bfloat16`, + `mstype.float16`. Default: ``mstype.bfloat16``, indicates ``mindspore.bfloat16``. + + Inputs: + - **query** (Tensor) - Tensor with shape :math:`(batch\_size, num\_heads, sequence\_len, in\_channels)`. + - **key** (Tensor) - Tensor with shape :math:`(batch\_size, num\_heads, sequence\_len, in\_channels)`. 
+ - **value** (Tensor) - Tensor with shape :math:`(batch\_size, num\_heads, sequence\_len, in\_channels)`. + - **mask** (Tensor, optional) - Tensor with shape :math:`(sequence\_len, sequence\_len)` or + or :math:`(batch\_size, 1, sequence\_len, sequence\_len)`. Default: ``None``. + + Outputs: + - **output** (Tensor) - Tensor with shape :math:`(batch\_size, sequence\_len, in\_channels)`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindspore import ops + >>> from mindflow.cell import FlashAttn + >>> model = FlashAttn(num_heads=4, scale=0.25) + >>> in_shape = (2, 16, 32, 16) + >>> q, k, v = ops.rand(in_shape), ops.rand(in_shape), ops.rand(in_shape) + >>> mask_shape = (32, 32) + >>> mask = ops.randint(0, 2, mask_shape) + >>> output = model(q, k, v, mask) + >>> print(output.shape) + (2, 16, 32, 16) + """ + + def __init__(self, num_heads: int, scale: float, fa_dtype=mstype.bfloat16): + super().__init__() + self.fa_dtype = fa_dtype + self.num_heads = num_heads + self.scale = scale + + def construct(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None): + query, key, value = query.astype(self.fa_dtype), key.astype(self.fa_dtype), value.astype(self.fa_dtype) + if mask is not None: + mask = mask.astype(mstype.uint8) + scores = ops.flash_attention_score(query, key, value, input_layout='BNSD', head_num=self.num_heads, + attn_mask=mask, scalar_value=self.scale) + return scores + + +class MultiHeadAttention(Attention): + r"""Multi Head Attention proposed in `Attention Is All You Need `_. + + Args: + in_channels (int): The input channels. + num_heads (int): The number of attention heads. + enable_flash_attn (bool): Whether use flash attention. FlashAttention only supports Ascend backend. + FlashAttention proposed in `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness `_. + Default: ``False``. + fa_dtype (mindspore.dtype): FlashAttention compute dtype. Choose from `mstype.bfloat16`, `mstype.float16`. + Default: ``mstype.bfloat16``, indicates ``mindspore.bfloat16``. + drop_mode (str): Dropout method, ``dropout`` or ``droppath``. Default: ``dropout``. + dropout_rate (float): The drop rate of dropout layer, greater than 0 and less equal than 1. Default: ``0.0``. + compute_dtype (mindspore.dtype): Compute dtype. Default: ``mstype.float32``, indicates ``mindspore.float32``. + + Inputs: + - **x** (Tensor) - Tensor with shape :math:`(batch\_size, sequence\_len, in\_channels)`. + - **attn_mask** (Tensor, optional) - Tensor with shape :math:`(sequence\_len, sequence\_len)` or + or :math:`(batch\_size, 1, sequence\_len, sequence\_len)`. Default: ``None``. + - **key_padding_mask** (Tensor, optional) - Tensor with shape :math:`(batch\_size, sequence\_len)`. + Default: ``None``. + + Outputs: + - **output** (Tensor) - Tensor with shape :math:`(batch\_size, sequence\_len, in\_channels)`. 
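+
+    Note:
+        A `key_padding_mask` can be passed instead of (or together with) `attn_mask`; both are
+        merged into one mask and the marked positions receive a large negative score before the
+        softmax. A minimal, illustrative sketch (values are arbitrary; a non-zero entry marks a
+        position to ignore):
+
+        >>> from mindspore import ops
+        >>> from mindflow.cell import MultiHeadAttention
+        >>> net = MultiHeadAttention(in_channels=512, num_heads=4)
+        >>> inputs = ops.rand((2, 32, 512))
+        >>> pad_mask = ops.zeros((2, 32))
+        >>> out = net(inputs, attn_mask=None, key_padding_mask=pad_mask)
+        >>> print(out.shape)
+        (2, 32, 512)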
+ + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> from mindspore import ops + >>> from mindflow.cell import MultiHeadAttention + >>> model = MultiHeadAttention(in_channels=512, num_heads=4) + >>> x = ops.rand((2, 32, 512)) + >>> mask_shape = (32, 32) + >>> mask = ops.ones(mask_shape) + >>> output = model(x, mask) + >>> print(output.shape) + (2, 32, 512) + """ + + def __init__(self, in_channels: int, + num_heads: int, + enable_flash_attn: bool = False, + fa_dtype: mstype = mstype.bfloat16, + drop_mode: str = "dropout", + dropout_rate: float = 0.0, + compute_dtype: mstype = mstype.float32, + ): + super().__init__(in_channels, num_heads, compute_dtype) + assert ( + in_channels % num_heads == 0 + ), "hidden channels must be divisible by number of heads" + scale = (in_channels // num_heads) ** -0.5 + self.proj = nn.Dense(in_channels, in_channels).to_float(compute_dtype) + if enable_flash_attn: + print('use flash attention') + self.attn = FlashAttn(num_heads=num_heads, scale=scale, fa_dtype=fa_dtype) + else: + self.attn = ScaledDot(scale=scale) + if drop_mode == "dropout": + self.drop = nn.Dropout(p=dropout_rate) + else: + self.drop = DropPath(dropout_rate=dropout_rate) + + def construct(self, x: Tensor, attn_mask: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None): + """construct""" + query, key, value = self.get_qkv(x) + mask = self.merge_mask(attn_mask, key_padding_mask) + output = self.attn(query, key, value, mask) + output = output.astype(mstype.float32) + output = self._reshape_output(output) + output = self.proj(output) + output = self.drop(output) + return output + + +class FeedForward(nn.Cell): + """FeedForward""" + def __init__(self, in_channels, dropout_rate=0.0, compute_dtype=mstype.float16): + super().__init__() + self.fc1 = nn.Dense(in_channels, in_channels * 4).to_float(compute_dtype) + self.fc2 = nn.Dense(in_channels * 4, in_channels).to_float(compute_dtype) + self.act_fn = nn.GELU() + self.dropout = nn.Dropout(p=dropout_rate) + + def construct(self, x: Tensor): + """construct""" + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class TransformerBlock(nn.Cell): + r""" `TransformerBlock` comprises an `MultiHeadAttention` and an `FeedForward` layer. + + Args: + in_channels (int): The input channels. + num_heads (int): The number of attention heads. + enable_flash_attn (bool): Whether use flash attention. FlashAttention only supports Ascend backend. + FlashAttention proposed in `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness `_. + Default: ``False``. + fa_dtype (mindspore.dtype): FlashAttention compute dtype. Choose from `mstype.bfloat16`, `mstype.float16`. + Default: ``mstype.bfloat16``, indicates ``mindspore.bfloat16``. + drop_mode (str): Dropout method. Default: ``dropout``. Support ``dropout`` or ``droppath``. + dropout_rate (float): The drop rate of dropout layer, greater than 0 and less equal than 1. Default: ``0.0``. + compute_dtype (mindspore.dtype): Compute dtype. Default: ``mstype.float32``, indicates ``mindspore.float32``. + + Inputs: + - **x** (Tensor) - Tensor with shape :math:`(batch\_size, sequence\_len, in\_channels)`. + - **mask** (Tensor, optional) - Tensor with shape :math:`(sequence\_len, sequence\_len)` or + :math:`(batch\_size, 1, sequence\_len, sequence\_len)`. Default: ``None``. + + Outputs: + - **output** (Tensor) - Tensor with shape :math:`(batch\_size, sequence\_len, in\_channels)`. 
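+
+    Note:
+        The block uses a pre-norm layout: `LayerNorm` is applied before the attention and
+        feed-forward sub-layers, and the sub-layer outputs are combined with the block input
+        through residual additions. The optional `mask` is forwarded unchanged to the internal
+        `MultiHeadAttention`.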
+ + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> from mindspore import ops + >>> from mindflow.cell import TransformerBlock + >>> model = TransformerBlock(in_channels=256, num_heads=4) + >>> x = ops.rand((4, 100, 256)) + >>> output = model(x) + >>> print(output.shape) + (4, 100, 256) + """ + + def __init__(self, + in_channels: int, + num_heads: int, + enable_flash_attn: bool = False, + fa_dtype: mstype = mstype.bfloat16, + drop_mode: str = "dropout", + dropout_rate: float = 0.0, + compute_dtype: mstype = mstype.float32, + ): + super().__init__() + self.compute_dtype = compute_dtype + self.attention_norm = nn.LayerNorm([in_channels], epsilon=1e-6).to_float( + mstype.float32 + ) + self.ffn_norm = nn.LayerNorm([in_channels], epsilon=1e-6).to_float( + mstype.float32 + ) + self.ffn = FeedForward( + in_channels=in_channels, + dropout_rate=dropout_rate, + compute_dtype=compute_dtype, + ) + self.attention = MultiHeadAttention( + in_channels=in_channels, + num_heads=num_heads, + enable_flash_attn=enable_flash_attn, + fa_dtype=fa_dtype, + drop_mode=drop_mode, + dropout_rate=dropout_rate, + compute_dtype=compute_dtype, + ) + + def construct(self, x: Tensor, mask: Optional[Tensor] = None): + """construct""" + h = x + x = self.attention_norm(x) + x = self.attention(x, mask) + x = x + h + + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x diff --git a/MindFlow/mindflow/cell/diffusion_transformer.py b/MindFlow/mindflow/cell/diffusion_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..cdb08ffe429dbac5b97df54f4ae1afe4124ec253 --- /dev/null +++ b/MindFlow/mindflow/cell/diffusion_transformer.py @@ -0,0 +1,269 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Diffusion transformer api""" + +import math + +import numpy as np +from mindspore import nn, ops, Tensor +from mindspore import dtype as mstype +from mindflow.cell import TransformerBlock + + +class Mlp(nn.Cell): + """MLP""" + + def __init__(self, in_channels, out_channels, dropout=0., compute_dtype=mstype.float32): + super().__init__() + self.fc1 = nn.Dense( + in_channels, 4*in_channels).to_float(compute_dtype) + self.act = nn.GELU() + self.fc2 = nn.Dense( + 4*in_channels, out_channels).to_float(compute_dtype) + self.drop = nn.Dropout(p=dropout) + + def construct(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class SinusoidalPosEmb(nn.Cell): + """sinusoidal embedding model""" + + def __init__(self, dim, max_period=10000, compute_dtype=mstype.float32): + super().__init__() + half_dim = dim // 2 + self.concat_zero = (dim % 2 == 1) + freqs = np.exp(-math.log(max_period) * + np.arange(start=0, stop=half_dim) / half_dim) + self.freqs = Tensor(freqs, compute_dtype) + + def construct(self, x): + emb = x[:, None] * self.freqs[None, :] + emb = ops.concat((ops.cos(emb), ops.sin(emb)), axis=-1) + if self.concat_zero: + emb = ops.concat([emb, ops.zeros_like(emb[:, :1])], axis=-1) + return emb + + +class Transformer(nn.Cell): + """Transformer backbone model""" + + def __init__(self, hidden_channels, layers, heads, compute_dtype=mstype.float32): + super().__init__() + self.hidden_channels = hidden_channels + self.layers = layers + self.blocks = nn.CellList([ + TransformerBlock( + in_channels=hidden_channels, + num_heads=heads, + drop_mode="dropout", + dropout_rate=0.0, + compute_dtype=compute_dtype, + ) + for _ in range(layers)]) + + def construct(self, x): + for block in self.blocks: + x = block(x) + return x + + +class DiffusionTransformer(nn.Cell): + r""" + Diffusion model with Transformer backbone implementation. + + Args: + in_channels (int): The number of input channel. + out_channels (int): The number of output channel. + hidden_channels (int): The number of hidden channel. + layers (int): The number of transformer block layers. + heads (int): The number of transformer heads. + time_token_cond (bool): Whether to use timestep as condition token. Default: ``True``. + compute_dtype (mindspore.dtype): The dtype of compute, it can be ``mstype.float32`` or ``mstype.float16``. + Default: ``mstype.float32``, indicates ``mindspore.float32``. + + Inputs: + - **x** (Tensor) - The input has a shape of :math:`(batch\_size, sequence\_len, in\_channels)`. + - **timestep** (Tensor) - The timestep input has a shape of :math:`(batch\_size,)`. + + Outputs: + - **output** (Tensor) - The output has a shape of :math:`(batch\_size, sequence\_len, out\_channels)`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindspore import ops + >>> from mindflow.cell import DiffusionTransformer + >>> in_channels, out_channels, hidden_channels, layers, heads, batch_size, seq_len = 16, 16, 256, 3, 4, 8, 256 + >>> model = DiffusionTransformer(in_channels=in_channels, + ... out_channels=out_channels, + ... hidden_channels=hidden_channels, + ... layers=layers, + ... 
heads=heads) + >>> x = ops.rand((batch_size, seq_len, in_channels)) + >>> timestep = ops.randint(0, 1000, (batch_size,)) + >>> output = model(x, timestep) + >>> print(output.shape) + (8, 256, 16) + """ + + def __init__(self, + in_channels, + out_channels, + hidden_channels, + layers, + heads, + time_token_cond=True, + compute_dtype=mstype.float32): + super().__init__() + self.time_token_cond = time_token_cond + self.in_channels = in_channels + self.timestep_emb = SinusoidalPosEmb( + hidden_channels, compute_dtype=compute_dtype) + self.time_embed = Mlp( + in_channels=hidden_channels, + out_channels=hidden_channels, + dropout=0., + compute_dtype=compute_dtype + ) + + self.ln_pre = nn.LayerNorm( + (hidden_channels,), epsilon=1e-5).to_float(mstype.float32) + self.backbone = Transformer( + hidden_channels=hidden_channels, + layers=layers, + heads=heads, + compute_dtype=compute_dtype + ) + self.ln_post = nn.LayerNorm( + (hidden_channels,), epsilon=1e-5).to_float(mstype.float32) + self.input_proj = nn.Dense( + in_channels, hidden_channels).to_float(compute_dtype) + self.output_proj = nn.Dense( + hidden_channels, out_channels, weight_init='zeros', bias_init='zeros').to_float(compute_dtype) + + def construct(self, x, timestep): + """construct""" + t_embed = self.time_embed(self.timestep_emb(timestep)) + return self._forward_with_cond(x, [(t_embed, self.time_token_cond)]) + + def _forward_with_cond(self, x, cond_token_list): + """forward network with condition input""" + h = self.input_proj(x) + extra_tokens = [] + for tokens, as_token in cond_token_list: + if as_token: + if len(tokens.shape) == 2: + extra_tokens.append(tokens.unsqueeze(1)) + else: + extra_tokens.append(tokens) + else: + h = h + tokens.unsqueeze(1) + + if extra_tokens: + h = ops.concat(extra_tokens + [h], axis=1) + + h = self.ln_pre(h) + h = self.backbone(h) + h = self.ln_post(h) + if extra_tokens: + # keep sequence length unchanged + h = h[:, sum(token.shape[1] for token in extra_tokens):] + h = self.output_proj(h) + return h + + +class ConditionDiffusionTransformer(DiffusionTransformer): + r""" + Conditioned Diffusion Transformer implementation. + + Args: + in_channels (int): The number of input channel. + out_channels (int): The number of output channel. + hidden_channels (int): The number of hidden channel. + cond_channels (int): The number of condition channel. + layers (int): The number of transformer block layers. + heads (int): The number of transformer heads. + time_token_cond (bool): Whether to use timestep as condition token. Default: ``True``. + cond_as_token (bool): Whether to use condition as token. Default: ``True``. + compute_dtype (mindspore.dtype): the dtype of compute, it can be ``mstype.float32`` or ``mstype.float16``. + Default: ``mstype.float32``, indicates ``mindspore.float32``. + + Inputs: + - **x** (Tensor) - The input has a shape of :math:`(batch\_size, sequence\_len, in\_channels)`. + - **timestep** (Tensor) - The timestep input has a shape of :math:`(batch\_size,)`. + - **condition** (Tensor) - The condition input has a shape of :math:`(batch\_size, cond\_size)`. + Default: ``None``. + + Outputs: + - **output** (Tensor) - The output has a shape of :math:`(batch\_size, sequence\_len, out\_channels)`. 
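+
+    Note:
+        The condition is embedded to `hidden_channels` by a `Dense` layer. When `cond_as_token`
+        is ``True`` it is prepended to the sequence as an extra token and stripped from the
+        output, so the sequence length is unchanged; otherwise it is broadcast-added to every
+        position of the hidden sequence.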
+ + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindspore import ops + >>> from mindflow.cell import ConditionDiffusionTransformer + >>> in_channels, out_channels, cond_channels, hidden_channels = 16, 16, 10, 256 + >>> layers, heads, batch_size, seq_len = 3, 4, 8, 256 + >>> model = ConditionDiffusionTransformer(in_channels=in_channels, + ... out_channels=out_channels, + ... cond_channels=cond_channels, + ... hidden_channels=hidden_channels, + ... layers=layers, + ... heads=heads) + >>> x = ops.rand((batch_size, seq_len, in_channels)) + >>> cond = ops.rand((batch_size, cond_channels)) + >>> timestep = ops.randint(0, 1000, (batch_size,)) + >>> output = model(x, timestep, cond) + >>> print(output.shape) + (8, 256, 16) + """ + + def __init__(self, in_channels, + out_channels, + cond_channels, + hidden_channels, + layers, + heads, + time_token_cond=True, + cond_as_token=True, + compute_dtype=mstype.float32): + super().__init__(in_channels, + out_channels, + hidden_channels, + layers, + heads, + time_token_cond, + compute_dtype) + self.cond_as_token = cond_as_token + self.cond_embed = nn.Dense( + cond_channels, hidden_channels).to_float(compute_dtype) + + # pylint: disable=W0221 + def construct(self, x, timestep, condition=None): + """forward network with timestep and condition input """ + t_embed = self.time_embed(self.timestep_emb(timestep)) + full_cond = [(t_embed, self.time_token_cond)] + if condition is not None: + cond_emb = self.cond_embed(condition) + full_cond.append((cond_emb, self.cond_as_token)) + return self._forward_with_cond(x, full_cond) diff --git a/MindFlow/mindflow/cell/neural_operators/__init__.py b/MindFlow/mindflow/cell/neural_operators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..682152fd8c5248115624050974051ab44c395bb6 --- /dev/null +++ b/MindFlow/mindflow/cell/neural_operators/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" +from .fno import FNOBlocks, FNO1D, FNO2D, FNO3D +from .kno1d import KNO1D +from .kno2d import KNO2D +from .pdenet import PDENet +from .percnn import PeRCNN +from .sno import SNO, SNO1D, SNO2D, SNO3D +from .ffno import FFNOBlocks, FFNO, FFNO1D, FFNO2D, FFNO3D + +__all__ = ["FNOBlocks", "FNO1D", "FNO2D", "FNO3D", "KNO1D", "KNO2D", "PDENet", "PeRCNN", + "SNO", "SNO1D", "SNO2D", "SNO3D", "FFNOBlocks", "FFNO", "FFNO1D", "FFNO2D", "FFNO3D"] + +__all__.sort() diff --git a/MindFlow/mindflow/cell/neural_operators/dft.py b/MindFlow/mindflow/cell/neural_operators/dft.py new file mode 100644 index 0000000000000000000000000000000000000000..e41ba2b49cd27d07bf4d7b0f3ae4c7d11ba687e0 --- /dev/null +++ b/MindFlow/mindflow/cell/neural_operators/dft.py @@ -0,0 +1,723 @@ +'''' +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +''' +import numpy as np +from scipy.linalg import dft + +import mindspore +import mindspore.common.dtype as mstype +from mindspore import nn, ops, Tensor, Parameter, mint +from mindspore.common.initializer import Zero +from mindspore.ops import operations as P + +from ...utils.check_func import check_param_no_greater, check_param_value, check_param_type, check_param_even + + +class DFT1d(nn.Cell): + '''One dimensional Discrete Fourier Transformation''' + + def __init__(self, n, modes, last_index, idx=0, inv=False, compute_dtype=mindspore.float32): + super().__init__() + + self.n = n + self.dft_mat = dft(n, scale="sqrtn") + self.modes = modes + self.last_index = last_index + self.inv = inv + self.idx = idx + self.compute_dtype = compute_dtype + + self.dft_mode_mat_upper = self.dft_mat[:, :modes] + self.a_re_upper = Tensor( + self.dft_mode_mat_upper.real, dtype=compute_dtype) + self.a_im_upper = Tensor( + self.dft_mode_mat_upper.imag, dtype=compute_dtype) + + self.dft_mode_mat_lower = self.dft_mat[:, -modes:] + self.a_re_lower = Tensor( + self.dft_mode_mat_lower.real, dtype=compute_dtype) + self.a_im_lower = Tensor( + self.dft_mode_mat_lower.imag, dtype=compute_dtype) + self.concat = ops.Concat(axis=-1) + + if self.inv: + self.a_re_upper = self.a_re_upper.T + self.a_im_upper = -self.a_im_upper.T + if last_index: + if modes == n // 2 + 1: + self.dft_mat_res = self.dft_mat[:, -modes + 2:] + else: + self.dft_mat_res = self.dft_mat[:, -modes + 1:] + + mat = Tensor(np.zeros(n, ), dtype=compute_dtype).reshape(n, 1) + self.a_re_res = mindspore.numpy.flip( + Tensor(self.dft_mat_res.real, dtype=compute_dtype), axis=-1) + self.a_im_res = mindspore.numpy.flip( + Tensor(self.dft_mat_res.imag, dtype=compute_dtype), axis=-1) + if modes == n // 2 + 1: + self.a_re_res = self.concat((mat, self.a_re_res, mat)) + self.a_im_res = self.concat((mat, self.a_im_res, mat)) + else: + self.a_re_res = self.concat((mat, self.a_re_res)) + self.a_im_res = self.concat((mat, self.a_im_res)) + + self.a_re_res = self.a_re_res.T + self.a_im_res = -self.a_im_res.T + else: + self.a_re_res = self.a_re_lower.T + self.a_im_res = -self.a_im_lower.T + + if (self.n - 2 * self.modes) > 0: + self.mat = Tensor(shape=(self.n - 2 * self.modes), + dtype=compute_dtype, init=Zero()) + + def swap_axes(self, x_re, x_im): + return x_re.swapaxes(-1, self.idx), x_im.swapaxes(-1, self.idx) + + def complex_matmul(self, x_re, x_im, a_re, a_im): + y_re = ops.matmul(x_re, a_re) - ops.matmul(x_im, a_im) + y_im = ops.matmul(x_im, a_re) + ops.matmul(x_re, a_im) + return y_re, y_im + + def construct(self, x): + x_re, x_im = x + x_re, x_im = P.Cast()(x_re, self.compute_dtype), P.Cast()(x_im, self.compute_dtype) + if not self.inv: + x_re, x_im = self.swap_axes(x_re, x_im) + y_re, y_im = self.complex_matmul( + x_re=x_re, x_im=x_im, a_re=self.a_re_upper, a_im=self.a_im_upper) + + if not self.last_index: + y_re2, y_im2 = self.complex_matmul( + x_re=x_re, x_im=x_im, a_re=self.a_re_lower, a_im=self.a_im_lower) + + if self.n == self.modes * 2: + y_re = self.concat((y_re, 
y_re2)) + y_im = self.concat((y_im, y_im2)) + else: + dims = x_re.shape[:-1] + length = len(dims) + mat = self.mat + for i in range(length - 1, -1, -1): + mat = mint.repeat_interleave(mat.expand_dims(0), dims[i], 0) + y_re = self.concat((y_re, mat, y_re2)) + y_im = self.concat((y_im, mat, y_im2)) + + y_re, y_im = self.swap_axes(y_re, y_im) + return y_re, y_im + + x_re, x_im = self.swap_axes(x_re, x_im) + y_re, y_im = self.complex_matmul(x_re=x_re[..., :self.modes], x_im=x_im[..., :self.modes], + a_re=self.a_re_upper, + a_im=self.a_im_upper) + y_re, y_im = self.swap_axes(y_re, y_im) + + if self.last_index: + y_re_res, y_im_res = self.complex_matmul( + x_re=x_re, x_im=x_im, a_re=self.a_re_res, a_im=-self.a_im_res) + else: + y_re_res, y_im_res = self.complex_matmul(x_re=x_re[..., -self.modes:], x_im=x_im[..., -self.modes:], + a_re=self.a_re_res, a_im=self.a_im_res) + + y_re_res, y_im_res = self.swap_axes(y_re_res, y_im_res) + return y_re + y_re_res, y_im + y_im_res + + +class DFTn(nn.Cell): + '''N dimensional Discrete Fourier Transformation''' + + def __init__(self, shape, modes, dim=None, inv=False, compute_dtype=mindspore.float32): + super().__init__() + + if dim is None: + dim = range(len(shape)) + self.dft1_seq = nn.SequentialCell() + last_index = [False for _ in range(len(shape))] + last_index[-1] = True + for dim_id, idx in enumerate(dim): + self.dft1_seq.append( + DFT1d(n=shape[dim_id], modes=modes[dim_id], last_index=last_index[dim_id], idx=idx, inv=inv, + compute_dtype=compute_dtype)) + + def construct(self, x): + return self.dft1_seq(x) + + +def _dftn(shape, modes, dim=None, compute_dtype=mindspore.float32): + dftn_ = DFTn(shape=shape, modes=modes, dim=dim, + inv=False, compute_dtype=compute_dtype) + return dftn_ + + +def _idftn(shape, modes, dim=None, compute_dtype=mindspore.float32): + idftn_ = DFTn(shape=shape, modes=modes, dim=dim, + inv=True, compute_dtype=compute_dtype) + return idftn_ + + +def dft3(shape, modes, dim=(-3, -2, -1), compute_dtype=mindspore.float32): + r""" + Calculate three-dimensional discrete Fourier transform. Corresponding to the rfftn operator in torch. + + Args: + shape (tuple): Dimension of the input 'x'. + modes (tuple): The length of the output transform axis. The `modes` must be no greater than half of the + dimension of input 'x'. + dim (tuple): Dimensions to be transformed. + compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32. + + Inputs: + - **x** (Tensor, Tensor): The input data. It's 3-D tuple of Tensor. It's a complex, + including x real and imaginary. Tensor of shape :math:`(*, *)`. + + Returns: + Complex tensor with the same shape of input x. + + Raises: + TypeError: If `shape` is not a tuple. + ValueError: If the length of `shape` is no equal to 3. 
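+
+    Note:
+        `dft3` is a factory: it returns an `nn.Cell` that performs the transform when called.
+        The cell consumes and produces a ``(real, imaginary)`` tuple of Tensors, and only the
+        first ``modes[2]`` coefficients are kept along the last transformed axis, as shown in the
+        example below.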
+ + Examples: + >>> import numpy as np + >>> from mindspore import Tensor, ops + >>> import mindspore.common.dtype as mstype + >>> from mindflow.cell.neural_operators.dft import dft3 + >>> array = np.ones((6, 6, 6)) * np.arange(1, 7) + >>> x_re = Tensor(array, dtype=mstype.float32) + >>> x_im = x_re + >>> dft3_cell = dft3(shape=array.shape, modes=(2, 2, 2), compute_dtype=mstype.float32) + >>> ret, _ = dft3_cell((x_re, x_im)) + >>> print(ret) + [[[ 5.1439293e+01 -2.0076393e+01] + [ 7.9796671e-08 -1.9494735e-08] + [ 0.0000000e+00 0.0000000e+00] + [ 0.0000000e+00 0.0000000e+00] + [ 9.0537789e-08 1.0553553e-07] + [ 3.3567730e-07 1.0368046e-07]] + + [[ 4.7683722e-07 -3.1770034e-07] + [ 6.5267522e-15 -2.7775875e-15] + [ 0.0000000e+00 0.0000000e+00] + [ 0.0000000e+00 0.0000000e+00] + [-2.1755840e-15 -1.5215135e-15] + [ 3.6259736e-15 -4.0336615e-15]] + + [[ 0.0000000e+00 0.0000000e+00] + [ 0.0000000e+00 0.0000000e+00] + [ 0.0000000e+00 0.0000000e+00] + [ 0.0000000e+00 0.0000000e+00] + [ 0.0000000e+00 0.0000000e+00] + [ 0.0000000e+00 0.0000000e+00]] + + [[ 0.0000000e+00 0.0000000e+00] + [ 0.0000000e+00 0.0000000e+00] + [ 0.0000000e+00 0.0000000e+00] + [ 0.0000000e+00 0.0000000e+00] + [ 0.0000000e+00 0.0000000e+00] + [ 0.0000000e+00 0.0000000e+00]] + + [[ 1.1920930e-07 -5.1619136e-08] + [-3.6259733e-16 -1.0747753e-15] + [ 0.0000000e+00 0.0000000e+00] + [ 0.0000000e+00 0.0000000e+00] + [ 3.6259733e-16 -1.8129867e-16] + [ 3.6259733e-16 -1.4373726e-15]] + + [[ 5.9604650e-07 -2.5809570e-07] + [ 8.7023360e-15 -1.9812689e-15] + [ 0.0000000e+00 0.0000000e+00] + [ 0.0000000e+00 0.0000000e+00] + [ 2.9007787e-15 7.2519467e-16] + [ 8.7023360e-15 -1.7869532e-15]]] + + """ + check_param_type(shape, "shape", data_type=tuple) + check_param_type(modes, "modes", data_type=tuple) + check_param_value(len(shape), "shape length", 3) + check_param_value(len(modes), "modes length", 3) + check_param_even(shape, "shape") + check_param_no_greater(modes[0], "mode1", shape[0] // 2) + check_param_no_greater(modes[1], "mode2", shape[1] // 2) + check_param_no_greater(modes[2], "mode3", shape[2] // 2 + 1) + return _dftn(shape, modes, dim=dim, compute_dtype=compute_dtype) + + +def idft3(shape, modes, dim=(-3, -2, -1), compute_dtype=mindspore.float32): + r""" + Calculate three-dimensional discrete Fourier transform. Corresponding to the irfftn operator in torch. + + Args: + shape (tuple): Dimension of the input 'x'. + modes (tuple): The length of the output transform axis. The `modes` must be no greater than half of the + dimension of input 'x'. + dim (tuple): Dimensions to be transformed. + compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32. + + Inputs: + - **x** (Tensor, Tensor): The input data. It's 3-D tuple of Tensor. It's a complex, including x real and + imaginary. Tensor of shape :math:`(*, *)`. + + Returns: + Complex tensor with the same shape of input x. + + Raises: + TypeError: If `shape` is not a tuple. + ValueError: If the length of `shape` is no equal to 3. 
+ + Examples: + >>> import numpy as np + >>> from mindspore import Tensor, ops + >>> import mindspore.common.dtype as mstype + >>> from mindflow.cell.neural_operators.dft import idft3 + >>> array = np.ones((2, 2, 2)) * np.arange(1, 3) + >>> x_re = Tensor(array, dtype=mstype.float32) + >>> x_im = ops.zeros_like(x_re) + >>> idft3_cell = idft3(shape=(6, 6, 6), modes=(2, 2, 2), compute_dtype=mstype.float32) + >>> ret, _ = idft3_cell((x_re, x_im)) + >>> print(ret) + [[[ 5.44331074e+00 3.26598644e+00 -1.08866215e+00 -3.26598644e+00 -1.08866215e+00 3.26598644e+00] + [ 2.04124165e+00 2.04124165e+00 4.08248246e-01 -1.22474492e+00 -1.22474492e+00 4.08248365e-01] + [-6.80413842e-01 -1.22474492e+00 -6.80413783e-01 4.08248305e-01 9.52579379e-01 4.08248246e-01] + [ 0.00000000e+00 -2.30921616e-16 -2.30921616e-16 6.53092730e-32 2.30921616e-16 2.30921616e-16] + [-6.80413842e-01 4.08248246e-01 9.52579379e-01 4.08248305e-01 -6.80413783e-01 -1.22474492e+00] + [ 2.04124165e+00 4.08248365e-01 -1.22474492e+00 -1.22474492e+00 4.08248246e-01 2.04124165e+00]] + ...... + [[ 2.04124165e+00 4.08248544e-01 -1.22474492e+00 -1.22474504e+00 4.08248186e-01 2.04124165e+00] + [ 1.02062082e+00 6.12372518e-01 -2.04124182e-01 -6.12372518e-01 -2.04124182e-01 6.12372518e-01] + [-5.10310411e-01 -5.10310411e-01 -1.02062061e-01 3.06186229e-01 3.06186229e-01 -1.02062091e-01] + [-7.21630050e-17 -1.29893429e-16 -7.21630183e-17 4.32978030e-17 1.01028220e-16 4.32978163e-17] + [-6.08337416e-08 4.08248246e-01 4.08248305e-01 3.65002428e-08 -4.08248246e-01 -4.08248305e-01] + [ 5.10310471e-01 -3.06186140e-01 -7.14434564e-01 -3.06186318e-01 5.10310352e-01 9.18558717e-01]]] + + """ + check_param_type(shape, "shape", data_type=tuple) + check_param_type(modes, "modes", data_type=tuple) + check_param_value(len(shape), "shape length", 3) + check_param_value(len(modes), "modes length", 3) + check_param_even(shape, "shape") + check_param_no_greater(modes[0], "mode1", shape[0] // 2) + check_param_no_greater(modes[1], "mode2", shape[1] // 2) + check_param_no_greater(modes[2], "mode3", shape[2] // 2 + 1) + return _idftn(shape, modes, dim=dim, compute_dtype=compute_dtype) + + +def dft2(shape, modes, dim=(-2, -1), compute_dtype=mindspore.float32): + """ + Calculate two-dimensional discrete Fourier transform. Corresponding to the rfft2 operator in torch. + + Args: + shape (tuple): Dimension of the input 'x'. + modes (tuple): The length of the output transform axis. The `modes` must be no greater than half of the + dimension of input 'x'. + dim (tuple): Dimensions to be transformed. + compute_dtype (:class:`mindspore.dtype`): The type of input tensor. Default: mindspore.float32. + + Inputs: + - **x** (Tensor, Tensor): The input data. It's 2-D tuple of Tensor. It's a complex, + including x real and imaginary. Tensor of shape :math:`(*, *)`. + + Returns: + Complex tensor with the same shape of input x. + + Raises: + TypeError: If `shape` is not a tuple. + ValueError: If the length of `shape` is no equal to 2. 
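+
+    Note:
+        The returned cell keeps the full length along the first transformed axis and only the
+        first ``modes[1]`` coefficients along the last axis; for the :math:`(5, 5)` input in the
+        example below the transformed output therefore has shape :math:`(5, 2)`.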
+ + Examples: + >>> import numpy as np + >>> from mindspore import Tensor, ops + >>> import mindspore.common.dtype as mstype + >>> from mindflow.cell.neural_operators.dft import dft2 + >>> array = np.ones((5, 5)) * np.arange(1, 6) + >>> x_re = Tensor(array, dtype=mstype.float32) + >>> x_im = x_re + >>> dft2_cell = dft2(shape=array.shape, modes=(2, 2), compute_dtype=mstype.float32) + >>> ret, _ = dft2_cell((x_re, x_im)) + >>> print(ret) + [[ 1.5000000e+01 -5.9409552e+00] + [-2.4656805e-07 7.6130398e-08] + [ 0.0000000e+00 0.0000000e+00] + [-1.9992007e-07 7.3572544e-08] + [-2.4656805e-07 7.6130398e-08]] + + """ + check_param_type(shape, "shape", data_type=tuple) + check_param_type(modes, "modes", data_type=tuple) + check_param_value(len(shape), "shape length", 2) + check_param_value(len(modes), "modes length", 2) + check_param_even(shape, "shape") + check_param_no_greater(modes[0], "mode1", shape[0] // 2) + check_param_no_greater(modes[1], "mode2", shape[1] // 2 + 1) + return _dftn(shape, modes, dim=dim, compute_dtype=compute_dtype) + + +def idft2(shape, modes, dim=(-2, -1), compute_dtype=mindspore.float32): + """ + Calculate two-dimensional discrete Fourier transform. Corresponding to the irfft2 operator in torch. + + Args: + shape (tuple): Dimension of the input 'x'. + modes (tuple): The length of the output transform axis. The `modes` must be no greater than half of the + dimension of input 'x'. + dim (tuple): Dimensions to be transformed. + compute_dtype (:class:`mindspore.dtype`): The type of input tensor. Default: mindspore.float32. + + Inputs: + - **x** (Tensor, Tensor): The input data. It's 2-D tuple of Tensor. It's a complex, + including x real and imaginary. Tensor of shape :math:`(*, *)`. + + Returns: + Complex tensor with the same shape of input x. + + Raises: + TypeError: If `shape` is not a tuple. + ValueError: If the length of `shape` is no equal to 2. + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor, ops + >>> import mindspore.common.dtype as mstype + >>> from mindflow.cell.neural_operators.dft import idft2 + >>> array = np.ones((2, 2)) * np.arange(1, 3) + >>> x_re = Tensor(array, dtype=mstype.float32) + >>> x_im = ops.zeros_like(x_re) + >>> idft2_cell = idft2(shape=(5, 5), modes=(2, 2), compute_dtype=mstype.float32) + >>> ret, _ = idft2_cell((x_re, x_im)) + >>> print(ret) + [[ 3.9999998 1.7888544 -1.7888546 -1.7888546 1.7888544 ] + [ 0.80901694 0.80901694 -0.08541022 -0.6381966 -0.08541021] + [-0.30901706 -0.8618034 -0.30901694 0.5854102 0.5854101 ] + [-0.30901706 0.5854101 0.5854102 -0.30901694 -0.8618034 ] + [ 0.80901694 -0.08541021 -0.6381966 -0.08541022 0.80901694]] + + """ + check_param_type(shape, "shape", data_type=tuple) + check_param_type(modes, "modes", data_type=tuple) + check_param_value(len(shape), "shape length", 2) + check_param_value(len(modes), "modes length", 2) + check_param_even(shape, "shape") + check_param_no_greater(modes[0], "mode1", shape[0] // 2) + check_param_no_greater(modes[1], "mode2", shape[1] // 2 + 1) + return _idftn(shape, modes, dim=dim, compute_dtype=compute_dtype) + + +def dft1(shape, modes, dim=(-1,), compute_dtype=mindspore.float32): + """ + Calculate one-dimensional discrete Fourier transform. Corresponding to the rfft operator in torch. + + Args: + shape (tuple): Dimension of the input 'x'. + modes (int): The length of the output transform axis. The `modes` must be no greater than half of the + dimension of input 'x'. + dim (tuple): Dimensions to be transformed. 
+ compute_dtype (:class:`mindspore.dtype`): The type of input tensor. + Default: mindspore.float32. + + Inputs: + - **x** (Tensor, Tensor): The input data. It's 2-D tuple of Tensor. It's a complex, + including x real and imaginary. Tensor of shape :math:`(*, *)`. + + Returns: + Complex tensor with the same shape of input x. + + Raises: + TypeError: If `shape` is not a tuple. + ValueError: If the length of `shape` is no equal to 1. + + Examples: + >>> from mindspore import Tensor, ops + >>> import mindspore.common.dtype as mstype + >>> from mindflow.cell.neural_operators.dft import dft1 + >>> array = [i for i in range(5)] + >>> x_re = Tensor(array, dtype=mstype.float32) + >>> x_im = ops.zeros_like(x_re) + >>> dft1_cell = dft1(shape=(len(x_re),), modes=2, compute_dtype=mstype.float32) + >>> ret, _ = dft1_cell((x_re, x_im)) + >>> print(ret) + [ 4.4721355 -1.1180341] + + """ + check_param_type(shape, "shape", data_type=tuple) + check_param_type(modes, "modes", data_type=int) + check_param_value(len(shape), "shape length", 1) + check_param_even(shape, "shape") + check_param_no_greater(modes, "mode1", shape[0] // 2 + 1) + modes = (modes,) + return _dftn(shape, modes, dim=dim, compute_dtype=compute_dtype) + + +def idft1(shape, modes, dim=(-1,), compute_dtype=mindspore.float32): + """ + Calculate one-dimensional discrete Fourier transform. Corresponding to the irfft operator in torch. + + Args: + shape (tuple): Dimension of the input 'x'. + modes (int): The length of the output transform axis. The `modes` must be no greater than half of the + dimension of input 'x'. + dim (tuple): Dimensions to be transformed. + compute_dtype (:class:`mindspore.dtype`): The type of input tensor. Default: mindspore.float32. + + Inputs: + - **x** (Tensor, Tensor): The input data. It's 2-D tuple of Tensor. It's a complex, + including x real and imaginary. Tensor of shape :math:`(*, *)`. + + Returns: + Complex tensor with the same shape of input x. + + Raises: + TypeError: If `shape` is not a tuple. + ValueError: If the length of `shape` is no equal to 1. 
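+
+    Note:
+        The real and imaginary parts of the spectrum are passed to the returned cell as a tuple
+        of Tensors, and the reconstructed signal is returned as a tuple in the same way, as shown
+        in the example below.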
+ + Examples: + >>> from mindspore import Tensor, ops + >>> import mindspore.common.dtype as mstype + >>> from mindflow.cell.neural_operators.dft import idft1 + >>> array = [i for i in range(2)] + >>> x_re = Tensor(array, dtype=mstype.float32) + >>> x_im = x_re + >>> idft1_cell = idft1(shape=(len(x_re),), modes=2, compute_dtype=mstype.float32) + >>> ret, _ = idft1_cell((x_re, x_im)) + >>> print(ret) + [ 0.8944272 -0.5742576 -1.2493379 -0.19787574 1.127044 ] + + """ + check_param_type(shape, "shape", data_type=tuple) + check_param_type(modes, "modes", data_type=int) + check_param_value(len(shape), "shape length", 1) + check_param_even(shape, "shape") + check_param_no_greater(modes, "mode1", shape[0] // 2 + 1) + modes = (modes,) + return _idftn(shape, modes, dim=dim, compute_dtype=compute_dtype) + + +class SpectralConvDft(nn.Cell): + """Base Class for Fourier Layer, including DFT, linear transform, and Inverse DFT""" + + def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + if isinstance(n_modes, int): + n_modes = [n_modes] + self.n_modes = n_modes + if isinstance(resolutions, int): + resolutions = [resolutions] + self.resolutions = resolutions + if len(self.n_modes) != len(self.resolutions): + raise ValueError( + "The dimension of n_modes should be equal to that of resolutions, \ + but got dimension of n_modes {} and dimension of resolutions {}".format(len(self.n_modes), + len(self.resolutions))) + self.compute_dtype = compute_dtype + + def construct(self, x: Tensor): + raise NotImplementedError() + + def _einsum(self, inputs, weights): + weights = weights.expand_dims(0) + inputs = inputs.expand_dims(2) + out = inputs * weights + return out.sum(1) + + +class SpectralConv1dDft(SpectralConvDft): + """1D Fourier Layer. It does DFT, linear transform, and Inverse DFT.""" + + def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): + super().__init__(in_channels, out_channels, n_modes, resolutions) + self._scale = (1. / (self.in_channels * self.out_channels)) + w_re = Tensor(self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0]), + dtype=mstype.float32) + w_im = Tensor(self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0]), + dtype=mstype.float32) + self._w_re = Parameter(w_re, requires_grad=True) + self._w_im = Parameter(w_im, requires_grad=True) + self._dft1_cell = dft1(shape=(self.resolutions[0],), modes=self.n_modes[0], compute_dtype=self.compute_dtype) + self._idft1_cell = idft1(shape=(self.resolutions[0],), modes=self.n_modes[0], compute_dtype=self.compute_dtype) + + def construct(self, x: Tensor): + x_re = x + x_im = ops.zeros_like(x_re) + x_ft_re, x_ft_im = self._dft1_cell((x_re, x_im)) + w_re = P.Cast()(self._w_re, self.compute_dtype) + w_im = P.Cast()(self._w_im, self.compute_dtype) + out_ft_re = self._einsum(x_ft_re[:, :, :self.n_modes[0]], w_re) - self._einsum(x_ft_im[:, :, :self.n_modes[0]], + w_im) + out_ft_im = self._einsum(x_ft_re[:, :, :self.n_modes[0]], w_im) + self._einsum(x_ft_im[:, :, :self.n_modes[0]], + w_re) + + x, _ = self._idft1_cell((out_ft_re, out_ft_im)) + + return x + + +class SpectralConv2dDft(SpectralConvDft): + """2D Fourier Layer. 
It does DFT, linear transform, and Inverse DFT.""" + + def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): + super().__init__(in_channels, out_channels, n_modes, resolutions) + self._scale = (1. / (self.in_channels * self.out_channels)) + w_re1 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), + dtype=self.compute_dtype) + w_im1 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), + dtype=self.compute_dtype) + w_re2 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), + dtype=self.compute_dtype) + w_im2 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1]), + dtype=self.compute_dtype) + + self._w_re1 = Parameter(w_re1, requires_grad=True) + self._w_im1 = Parameter(w_im1, requires_grad=True) + self._w_re2 = Parameter(w_re2, requires_grad=True) + self._w_im2 = Parameter(w_im2, requires_grad=True) + + self._dft2_cell = dft2(shape=(self.resolutions[0], self.resolutions[1]), + modes=(self.n_modes[0], self.n_modes[1]), compute_dtype=self.compute_dtype) + self._idft2_cell = idft2(shape=(self.resolutions[0], self.resolutions[1]), + modes=(self.n_modes[0], self.n_modes[1]), compute_dtype=self.compute_dtype) + self._mat = Tensor(shape=(1, self.out_channels, self.resolutions[1] - 2 * self.n_modes[0], self.n_modes[1]), + dtype=self.compute_dtype, init=Zero()) + self._concat = ops.Concat(-2) + + def construct(self, x: Tensor): + x_re = x + x_im = ops.zeros_like(x_re) + x_ft_re, x_ft_im = self._dft2_cell((x_re, x_im)) + + out_ft_re1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1) - self._einsum( + x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1) + out_ft_im1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_im1) + self._einsum( + x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1]], self._w_re1) + + out_ft_re2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_re2) - self._einsum( + x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2) + out_ft_im2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_im2) + self._einsum( + x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1]], self._w_re2) + + batch_size = x.shape[0] + mat = mint.repeat_interleave(self._mat, batch_size, 0) + out_re = self._concat((out_ft_re1, mat, out_ft_re2)) + out_im = self._concat((out_ft_im1, mat, out_ft_im2)) + + x, _ = self._idft2_cell((out_re, out_im)) + + return x + + +class SpectralConv3dDft(SpectralConvDft): + """3D Fourier layer. 
It does DFT, linear transform, and Inverse DFT.""" + + def __init__(self, in_channels, out_channels, n_modes, resolutions, compute_dtype=mstype.float32): + super().__init__(in_channels, out_channels, n_modes, resolutions) + self._scale = (1 / (self.in_channels * self.out_channels)) + + w_re1 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + w_im1 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + w_re2 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + w_im2 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + w_re3 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + w_im3 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + w_re4 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + w_im4 = Tensor( + self._scale * np.random.rand(self.in_channels, self.out_channels, self.n_modes[0], self.n_modes[1], + self.n_modes[2]), dtype=self.compute_dtype) + + self._w_re1 = Parameter(w_re1, requires_grad=True) + self._w_im1 = Parameter(w_im1, requires_grad=True) + self._w_re2 = Parameter(w_re2, requires_grad=True) + self._w_im2 = Parameter(w_im2, requires_grad=True) + self._w_re3 = Parameter(w_re3, requires_grad=True) + self._w_im3 = Parameter(w_im3, requires_grad=True) + self._w_re4 = Parameter(w_re4, requires_grad=True) + self._w_im4 = Parameter(w_im4, requires_grad=True) + + self._dft3_cell = dft3(shape=(self.resolutions[0], self.resolutions[1], self.resolutions[2]), + modes=(self.n_modes[0], self.n_modes[1], self.n_modes[2]), + compute_dtype=self.compute_dtype) + self._idft3_cell = idft3(shape=(self.resolutions[0], self.resolutions[1], self.resolutions[2]), + modes=(self.n_modes[0], self.n_modes[1], self.n_modes[2]), + compute_dtype=self.compute_dtype) + self._mat_x = Tensor( + shape=(1, self.out_channels, self.resolutions[0] - 2 * self.n_modes[0], self.n_modes[1], self.n_modes[2]), + dtype=self.compute_dtype, init=Zero()) + self._mat_y = Tensor( + shape=(1, self.out_channels, self.resolutions[0], self.resolutions[1] - 2 * self.n_modes[1], + self.n_modes[2]), + dtype=self.compute_dtype, init=Zero()) + self._concat = ops.Concat(-2) + + def construct(self, x: Tensor): + x_re = x + x_im = ops.zeros_like(x_re) + x_ft_re, x_ft_im = self._dft3_cell((x_re, x_im)) + + out_ft_re1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], + self._w_re1) - self._einsum(x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1], + :self.n_modes[2]], self._w_im1) + out_ft_im1 = self._einsum(x_ft_re[:, :, :self.n_modes[0], :self.n_modes[1], :self.n_modes[2]], + self._w_im1) + self._einsum(x_ft_im[:, :, :self.n_modes[0], :self.n_modes[1], + :self.n_modes[2]], self._w_re1) + out_ft_re2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]], + self._w_re2) - self._einsum(x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1], + 
:self.n_modes[2]], self._w_im2) + out_ft_im2 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, :self.n_modes[1], :self.n_modes[2]], + self._w_im2) + self._einsum(x_ft_im[:, :, -self.n_modes[0]:, :self.n_modes[1], + :self.n_modes[2]], self._w_re2) + out_ft_re3 = self._einsum(x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]], + self._w_re3) - self._einsum(x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:, + :self.n_modes[2]], self._w_im3) + out_ft_im3 = self._einsum(x_ft_re[:, :, :self.n_modes[0], -self.n_modes[1]:, :self.n_modes[2]], + self._w_im3) + self._einsum(x_ft_im[:, :, :self.n_modes[0], -self.n_modes[1]:, + :self.n_modes[2]], self._w_re3) + out_ft_re4 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]], + self._w_re4) - self._einsum(x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:, + :self.n_modes[2]], self._w_im4) + out_ft_im4 = self._einsum(x_ft_re[:, :, -self.n_modes[0]:, -self.n_modes[1]:, :self.n_modes[2]], + self._w_im4) + self._einsum(x_ft_im[:, :, -self.n_modes[0]:, -self.n_modes[1]:, + :self.n_modes[2]], self._w_re4) + + batch_size = x.shape[0] + mat_x = mint.repeat_interleave(self._mat_x, batch_size, 0) + mat_y = mint.repeat_interleave(self._mat_y, batch_size, 0) + + out_re1 = ops.concat((out_ft_re1, mat_x, out_ft_re2), -3) + out_im1 = ops.concat((out_ft_im1, mat_x, out_ft_im2), -3) + + out_re2 = ops.concat((out_ft_re3, mat_x, out_ft_re4), -3) + out_im2 = ops.concat((out_ft_im3, mat_x, out_ft_im4), -3) + out_re = ops.concat((out_re1, mat_y, out_re2), -2) + out_im = ops.concat((out_im1, mat_y, out_im2), -2) + x, _ = self._idft3_cell((out_re, out_im)) + + return x diff --git a/MindFlow/mindflow/cell/neural_operators/fno.py b/MindFlow/mindflow/cell/neural_operators/fno.py new file mode 100644 index 0000000000000000000000000000000000000000..4c4c644ac94e06d71fa68b3593bea8854d8c545f --- /dev/null +++ b/MindFlow/mindflow/cell/neural_operators/fno.py @@ -0,0 +1,672 @@ +'''' +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +''' +# pylint: disable=W0235 + +from mindspore import nn, ops, Tensor, mint +import mindspore.common.dtype as mstype + +from .fno_sp import SpectralConv1dDft, SpectralConv2dDft, SpectralConv3dDft +from ..activation import get_activation +from ...core.math import get_grid_1d, get_grid_2d, get_grid_3d +from ...utils.check_func import check_param_type + + +class FNOBlocks(nn.Cell): + r""" + The FNOBlock, which usually accompanied by a Lifting Layer ahead and a Projection Layer behind, + is a part of Fourier Neural Operator. It contains a Fourier Layer and a FNO Skip Layer. + The details can be found in `Zongyi Li, et. al: FOURIER NEURAL OPERATOR FOR PARAMETRIC PARTIAL + DIFFERENTIAL EQUATIONS `_. + + Args: + in_channels (int): The number of channels in the input space. + out_channels (int): The number of channels in the output space. 
+ n_modes (Union[int, list(int)]): The number of modes reserved after linear transformation in Fourier Layer. + resolutions (Union[int, list(int)]): The resolutions of the input tensor. + act (Union[str, class]): The activation function, could be either str or class. Default: ``gelu``. + add_residual (bool): Whether to add residual in FNOBlock or not. Default: ``False``. + dft_compute_dtype (dtype.Number): The computation type of DFT in SpectralConvDft. Default: ``mstype.float32``. + fno_compute_dtype (dtype.Number): The computation type of MLP in fno skip. Default: ``mstype.float16``. + Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is recommended for + the GPU backend, mstype.float16 is recommended for the Ascend backend. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(batch\_size, in\_channels, resolution)`. + + Outputs: + Tensor, the output of this FNOBlocks. + + - **output** (Tensor) -Tensor of shape :math:`(batch\_size, out\_channels, resolution)`. + + Raises: + TypeError: If `in_channels` is not an int. + TypeError: If `out_channels` is not an int. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> import mindspore.common.dtype as mstype + >>> from mindflow.cell.neural_operators import FNOBlocks + >>> data = Tensor(np.ones([2, 3, 128, 128]), mstype.float32) + >>> net = FNOBlocks(in_channels=3, out_channels=3, n_modes=[20, 20], resolutions=[128, 128]) + >>> out = net(data) + >>> print(data.shape, out.shape) + (2, 3, 128, 128) (2, 3, 128, 128) + """ + + def __init__(self, + in_channels, + out_channels, + n_modes, + resolutions, + act="gelu", + add_residual=False, + dft_compute_dtype=mstype.float32, + fno_compute_dtype=mstype.float16 + ): + super().__init__() + check_param_type(in_channels, "in_channels", data_type=int) + check_param_type(out_channels, "out_channels", data_type=int) + self.in_channels = in_channels + self.out_channels = out_channels + if isinstance(n_modes, int): + n_modes = [n_modes] + self.n_modes = n_modes + if isinstance(resolutions, int): + resolutions = [resolutions] + self.resolutions = resolutions + if len(self.n_modes) != len(self.resolutions): + raise ValueError( + "The dimension of n_modes should be equal to that of resolutions\ + but got dimension of n_modes {} and dimension of resolutions {}".format(len(self.n_modes), + len(self.resolutions))) + self.act = get_activation(act) if isinstance(act, str) else act + self.add_residual = add_residual + self.dft_compute_dtype = dft_compute_dtype + self.fno_compute_dtype = fno_compute_dtype + + if len(self.resolutions) == 1: + self._convs = SpectralConv1dDft( + self.in_channels, + self.out_channels, + self.n_modes, + self.resolutions, + compute_dtype=self.dft_compute_dtype + ) + self._fno_skips = nn.Conv1d( + self.in_channels, + self.out_channels, + kernel_size=1, + has_bias=False, + weight_init="HeUniform" + ).to_float(self.fno_compute_dtype) + elif len(self.resolutions) == 2: + self._convs = SpectralConv2dDft( + self.in_channels, + self.out_channels, + self.n_modes, + self.resolutions, + compute_dtype=self.dft_compute_dtype + ) + self._fno_skips = nn.Conv2d( + self.in_channels, + self.out_channels, + kernel_size=1, + has_bias=False, + weight_init="HeUniform" + ).to_float(self.fno_compute_dtype) + elif len(self.resolutions) == 3: + self._convs = SpectralConv3dDft( + self.in_channels, + self.out_channels, + self.n_modes, + self.resolutions, + compute_dtype=self.dft_compute_dtype + ) + self._fno_skips = nn.Conv3d( 
+ self.in_channels, + self.out_channels, + kernel_size=1, + has_bias=False, + weight_init="HeUniform" + ).to_float(self.fno_compute_dtype) + else: + raise ValueError("The length of input resolutions dimensions should be in [1, 2, 3], but got: {}".format( + len(self.resolutions))) + + def construct(self, x: Tensor): + if self.add_residual: + x = self.act(self._convs(x) + self._fno_skips(x)) + x + else: + x = self.act(self._convs(x) + self._fno_skips(x)) + return x + + +class FNO(nn.Cell): + r""" + The FNO base class, which usually contains a Lifting Layer, a Fourier Block Layer and a Projection Layer. + The details can be found in `Zongyi Li, et. al: FOURIER NEURAL OPERATOR FOR PARAMETRIC PARTIAL DIFFERENTIAL + EQUATIONS `_. + + Args: + in_channels (int): The number of channels in the input space. + out_channels (int): The number of channels in the output space. + n_modes (Union[int, list(int)]): The number of modes reserved after linear transformation in Fourier Layer. + resolutions (Union[int, list(int)]): The resolutions of the input tensor. + hidden_channels (int): The number of channels of the FNOBlock input and output. Default: ``20``. + lifting_channels (int): The number of channels of the lifting layer mid channels. Default: None. + projection_channels (int): The number of channels of the projection layer mid channels. Default: ``128``. + n_layers (int): The number that Fourier Layer nests. Default: ``4``. + data_format (str): The input data channel sequence. Default: ``channels_last``. + fnoblock_act (Union[str, class]): The activation function for FNOBlock, could be either str or class. + Default: ``identity``. + mlp_act (Union[str, class]): The activation function for MLP layers, could be either str or class. + Default: ``gelu``. + add_residual (bool): Whether to add residual in FNOBlock or not. Default: ``False``. + positional_embedding (bool): Whether to embed positional information or not. Default: ``True``. + dft_compute_dtype (dtype.Number): The computation type of DFT in SpectralConvDft. Default: ``mstype.float32``. + fno_compute_dtype (dtype.Number): The computation type of MLP in fno skip. Default: ``mstype.float16``. + Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is recommended for + the GPU backend, mstype.float16 is recommended for the Ascend backend. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, in\_channels)`. + + Outputs: + Tensor, the output of this FNOBlocks. + + - **output** (Tensor) -Tensor of shape :math:`(batch\_size, resolution, out\_channels)`. + + Raises: + TypeError: If `in_channels` is not an int. + TypeError: If `out_channels` is not an int. + TypeError: If `hidden_channels` is not an int. + TypeError: If `lifting_channels` is not an int. + TypeError: If `projection_channels` is not an int. + TypeError: If `n_layers` is not an int. + TypeError: If `data_format` is not a str. + TypeError: If `add_residual` is not an bool. + TypeError: If `positional_embedding` is not an bool. 
+ + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindspore import Tensor + >>> import mindspore.common.dtype as mstype + >>> from mindflow.cell.neural_operators.fno import FNO + >>> data = Tensor(np.ones([2, 3, 128, 128]), mstype.float32) + >>> net = FNO(in_channels=3, out_channels=3, n_modes=[20, 20], resolutions=[128, 128]) + >>> out = net(data) + >>> print(data.shape, out.shape) + (2, 3, 128, 128) (2, 3, 128, 128) + """ + + def __init__( + self, + in_channels, + out_channels, + n_modes, + resolutions, + hidden_channels=20, + lifting_channels=None, + projection_channels=128, + n_layers=4, + data_format="channels_last", + fnoblock_act="gelu", + mlp_act="gelu", + add_residual=False, + positional_embedding=True, + dft_compute_dtype=mstype.float32, + fno_compute_dtype=mstype.float16 + ): + super().__init__() + check_param_type(in_channels, "in_channels", data_type=int, exclude_type=bool) + check_param_type(out_channels, "out_channels", data_type=int, exclude_type=bool) + check_param_type(hidden_channels, "hidden_channels", data_type=int, exclude_type=bool) + check_param_type(n_layers, "n_layers", data_type=int, exclude_type=bool) + check_param_type(data_format, "data_format", data_type=str, exclude_type=bool) + check_param_type(positional_embedding, "positional_embedding", data_type=bool, exclude_type=str) + check_param_type(add_residual, "add_residual", data_type=bool, exclude_type=str) + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.lifting_channels = lifting_channels + self.projection_channels = projection_channels + if isinstance(n_modes, int): + n_modes = [n_modes] + self.n_modes = n_modes + if isinstance(resolutions, int): + resolutions = [resolutions] + self.resolutions = resolutions + if len(self.n_modes) != len(self.resolutions): + raise ValueError( + "The dimension of n_modes should be equal to that of resolutions\ + but got dimension of n_modes {} and dimension of resolutions {}".format(len(self.n_modes), + len(self.resolutions))) + self.n_layers = n_layers + self.data_format = data_format + if fnoblock_act == "identity": + self.fnoblock_act = ops.Identity() + else: + self.fnoblock_act = get_activation(fnoblock_act) if isinstance(fnoblock_act, str) else fnoblock_act + self.mlp_act = get_activation(mlp_act) if isinstance(mlp_act, str) else mlp_act + self.add_residual = add_residual + self.positional_embedding = positional_embedding + if self.positional_embedding: + self.in_channels += len(self.resolutions) + self.dft_compute_dtype = dft_compute_dtype + self.fno_compute_dtype = fno_compute_dtype + self._concat = ops.Concat(axis=-1) + self._positional_embedding, self._input_perm, self._output_perm = self._transpose(len(self.resolutions)) + if self.lifting_channels: + self._lifting = nn.SequentialCell([ + nn.Dense(self.in_channels, self.lifting_channels, has_bias=False).to_float(self.fno_compute_dtype), + nn.Dense(self.lifting_channels, self.hidden_channels, has_bias=False).to_float(self.fno_compute_dtype) + ]) + else: + self._lifting = nn.SequentialCell( + nn.Dense(self.in_channels, self.hidden_channels, has_bias=False).to_float(self.fno_compute_dtype) + ) + self._fno_blocks = nn.SequentialCell() + for _ in range(self.n_layers): + self._fno_blocks.append(FNOBlocks(self.hidden_channels, self.hidden_channels, n_modes=self.n_modes, + resolutions=self.resolutions, act=self.fnoblock_act, + add_residual=self.add_residual, dft_compute_dtype=self.dft_compute_dtype, + 
fno_compute_dtype=self.fno_compute_dtype)) + if self.projection_channels: + self._projection = nn.SequentialCell([ + nn.Dense(self.hidden_channels, self.projection_channels, has_bias=False).to_float( + self.fno_compute_dtype), + self.mlp_act, + nn.Dense(self.projection_channels, self.out_channels, has_bias=False).to_float(self.fno_compute_dtype) + ]) + else: + self._projection = nn.SequentialCell( + nn.Dense(self.hidden_channels, self.out_channels, has_bias=False).to_float(self.fno_compute_dtype)) + + def construct(self, x: Tensor): + """construct""" + batch_size = x.shape[0] + grid = mint.repeat_interleave(self._positional_embedding.astype(x.dtype), batch_size, dim=0) + if self.data_format != "channels_last": + x = ops.transpose(x, input_perm=self._output_perm) + if self.positional_embedding: + x = self._concat((x, grid)) + x = self._lifting(x) + x = ops.transpose(x, input_perm=self._input_perm) + x = self._fno_blocks(x) + x = ops.transpose(x, input_perm=self._output_perm) + x = self._projection(x) + if self.data_format != "channels_last": + x = ops.transpose(x, input_perm=self._input_perm) + return x + + def _transpose(self, n_dim): + """transpose tensor""" + if n_dim == 1: + positional_embedding = Tensor(get_grid_1d(resolution=self.resolutions)) + input_perm = (0, 2, 1) + output_perm = (0, 2, 1) + elif n_dim == 2: + positional_embedding = Tensor(get_grid_2d(resolution=self.resolutions)) + input_perm = (0, 3, 1, 2) + output_perm = (0, 2, 3, 1) + elif n_dim == 3: + positional_embedding = Tensor(get_grid_3d(resolution=self.resolutions)) + input_perm = (0, 4, 1, 2, 3) + output_perm = (0, 2, 3, 4, 1) + else: + raise ValueError( + "The length of input resolutions dimensions should be in [1, 2, 3], but got: {}".format(n_dim)) + return positional_embedding, input_perm, output_perm + + +class FNO1D(FNO): + r""" + The 1D Fourier Neural Operator, which usually contains a Lifting Layer, + a Fourier Block Layer and a Projection Layer. The details can be found in + `Zongyi Li, et. al: FOURIER NEURAL OPERATOR FOR PARAMETRIC PARTIAL DIFFERENTIAL EQUATIONS + `_. + + Args: + in_channels (int): The number of channels in the input space. + out_channels (int): The number of channels in the output space. + n_modes (Union[int, list(int)]): The number of modes reserved after linear transformation in Fourier Layer. + resolutions (Union[int, list(int)]): The resolutions of the input tensor. + hidden_channels (int): The number of channels of the FNOBlock input and output. Default: ``20``. + lifting_channels (int): The number of channels of the lifting layer mid channels. Default: None. + projection_channels (int): The number of channels of the projection layer mid channels. Default: ``128``. + n_layers (int): The number that Fourier Layer nests. Default: ``4``. + data_format (str): The input data channel sequence. Default: ``"channels_last"``. + Support value: ``"channels_last"``, ``"channels_first"``. + fnoblock_act (Union[str, class]): The activation function for FNOBlock, could be either str or class. + Default: ``"gelu"``. + mlp_act (Union[str, class]): The activation function for MLP layers, could be either str or class. + Default: ``gelu``. + add_residual (bool): Whether to add residual in FNOBlock or not. Default: ``False``. + positional_embedding (bool): Whether to embed positional information or not. Default: ``True``. + dft_compute_dtype (dtype.Number): The computation type of DFT in SpectralConvDft. Default: ``mstype.float32``. 
+ fno_compute_dtype (dtype.Number): The computation type of MLP in fno skip. Default: ``mstype.float16``. + Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is recommended for + the GPU backend, mstype.float16 is recommended for the Ascend backend. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, in\_channels)`. + + Outputs: + Tensor, the output of this FNOBlocks. + + - **output** (Tensor) -Tensor of shape :math:`(batch\_size, resolution, out\_channels)`. + + Raises: + TypeError: If `in_channels` is not an int. + TypeError: If `out_channels` is not an int. + TypeError: If `hidden_channels` is not an int. + TypeError: If `lifting_channels` is not an int. + TypeError: If `projection_channels` is not an int. + TypeError: If `n_layers` is not an int. + TypeError: If `data_format` is not a str. + TypeError: If `add_residual` is not an bool. + TypeError: If `positional_embedding` is not an bool. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> import mindflow + >>> from mindspore import Tensor + >>> import mindspore.common.dtype as mstype + >>> from mindflow.cell import FNO1D + >>> data = Tensor(np.ones([2, 128, 3]), mstype.float32) + >>> net = FNO1D(in_channels=3, out_channels=3, n_modes=[20], resolutions=[128]) + >>> out = net(data) + >>> print(data.shape, out.shape) + (2, 128, 3) (2, 128, 3) + """ + + def __init__( + self, + in_channels, + out_channels, + n_modes, + resolutions, + hidden_channels=20, + lifting_channels=None, + projection_channels=128, + n_layers=4, + data_format="channels_last", + fnoblock_act="gelu", + mlp_act="gelu", + add_residual=False, + positional_embedding=True, + dft_compute_dtype=mstype.float32, + fno_compute_dtype=mstype.float16 + ): + super().__init__( + in_channels, + out_channels, + n_modes, + resolutions, + hidden_channels, + lifting_channels, + projection_channels, + n_layers, + data_format, + fnoblock_act, + mlp_act, + add_residual, + positional_embedding, + dft_compute_dtype, + fno_compute_dtype + ) + + +class FNO2D(FNO): + r""" + The 2D Fourier Neural Operator, which usually contains a Lifting Layer, + a Fourier Block Layer and a Projection Layer. The details can be found in + `Zongyi Li, et. al: FOURIER NEURAL OPERATOR FOR PARAMETRIC PARTIAL DIFFERENTIAL EQUATIONS + `_. + + Args: + in_channels (int): The number of channels in the input space. + out_channels (int): The number of channels in the output space. + n_modes (Union[int, list(int)]): The number of modes reserved after linear transformation in Fourier Layer. + resolutions (Union[int, list(int)]): The resolutions of the input tensor. + hidden_channels (int): The number of channels of the FNOBlock input and output. Default: ``20``. + lifting_channels (int): The number of channels of the lifting layer mid channels. Default: None. + projection_channels (int): The number of channels of the projection layer mid channels. Default: ``128``. + n_layers (int): The number that Fourier Layer nests. Default: ``4``. + data_format (str): The input data channel sequence. Default: ``channels_last``. + Support value: ``"channels_last"``, ``"channels_first"``. + fnoblock_act (Union[str, class]): The activation function for FNOBlock, could be either str or class. + Default: ``"gelu"``. + mlp_act (Union[str, class]): The activation function for MLP layers, could be either str or class. + Default: ``gelu``. + add_residual (bool): Whether to add residual in FNOBlock or not. Default: ``False``. 
+ positional_embedding (bool): Whether to embed positional information or not. Default: ``True``. + dft_compute_dtype (dtype.Number): The computation type of DFT in SpectralConvDft. Default: ``mstype.float32``. + fno_compute_dtype (dtype.Number): The computation type of MLP in fno skip. Default: ``mstype.float16``. + Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is recommended for + the GPU backend, mstype.float16 is recommended for the Ascend backend. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution[0], resolution[1], in\_channels)`. + + Outputs: + Tensor, the output of this FNOBlocks. + + - **output** (Tensor) -Tensor of shape :math:`(batch\_size, resolution[0], resolution[1], out\_channels)`. + + Raises: + TypeError: If `in_channels` is not an int. + TypeError: If `out_channels` is not an int. + TypeError: If `hidden_channels` is not an int. + TypeError: If `lifting_channels` is not an int. + TypeError: If `projection_channels` is not an int. + TypeError: If `n_layers` is not an int. + TypeError: If `data_format` is not a str. + TypeError: If `add_residual` is not an bool. + TypeError: If `positional_embedding` is not an bool. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> import mindflow + >>> from mindspore import Tensor + >>> import mindspore.common.dtype as mstype + >>> from mindflow.cell import FNO2D + >>> data = Tensor(np.ones([2, 128, 128, 3]), mstype.float32) + >>> net = FNO2D(in_channels=3, out_channels=3, n_modes=[20, 20], resolutions=[128, 128]) + >>> out = net(data) + >>> print(data.shape, out.shape) + (2, 128, 128, 3) (2, 128, 128, 3) + """ + + def __init__( + self, + in_channels, + out_channels, + n_modes, + resolutions, + hidden_channels=20, + lifting_channels=None, + projection_channels=128, + n_layers=4, + data_format="channels_last", + fnoblock_act="gelu", + mlp_act="gelu", + add_residual=False, + positional_embedding=True, + dft_compute_dtype=mstype.float32, + fno_compute_dtype=mstype.float16 + ): + if isinstance(n_modes, int): + n_modes = [n_modes, n_modes] + if isinstance(resolutions, int): + resolutions = [resolutions, resolutions] + if len(n_modes) != 2: + raise ValueError( + "The dimension of n_modes should be equal to 2 when using FNO2D\ + but got dimension of n_modes {}".format(len(n_modes))) + if len(resolutions) != 2: + raise ValueError( + "The dimension of resolutions should be equal to 2 when using FNO2D\ + but got dimension of resolutions {}".format(len(resolutions))) + super().__init__( + in_channels, + out_channels, + n_modes, + resolutions, + hidden_channels, + lifting_channels, + projection_channels, + n_layers, + data_format, + fnoblock_act, + mlp_act, + add_residual, + positional_embedding, + dft_compute_dtype, + fno_compute_dtype + ) + + +class FNO3D(FNO): + r""" + The 3D Fourier Neural Operator, which usually contains a Lifting Layer, + a Fourier Block Layer and a Projection Layer. The details can be found in + `Zongyi Li, et. al: FOURIER NEURAL OPERATOR FOR PARAMETRIC PARTIAL DIFFERENTIAL EQUATIONS + `_. + + Args: + in_channels (int): The number of channels in the input space. + out_channels (int): The number of channels in the output space. + n_modes (Union[int, list(int)]): The number of modes reserved after linear transformation in Fourier Layer. + resolutions (Union[int, list(int)]): The resolutions of the input tensor. + hidden_channels (int): The number of channels of the FNOBlock input and output. Default: ``20``. 
+        lifting_channels (int): The number of mid channels of the lifting layer. Default: ``None``.
+        projection_channels (int): The number of mid channels of the projection layer. Default: ``128``.
+        n_layers (int): The number of stacked Fourier Layers. Default: ``4``.
+        data_format (str): The input data channel sequence. Default: ``"channels_last"``.
+            Support value: ``"channels_last"``, ``"channels_first"``.
+        fnoblock_act (Union[str, class]): The activation function for FNOBlock, could be either str or class.
+            Default: ``"gelu"``.
+        mlp_act (Union[str, class]): The activation function for MLP layers, could be either str or class.
+            Default: ``"gelu"``.
+        add_residual (bool): Whether to add residual in FNOBlock or not. Default: ``False``.
+        positional_embedding (bool): Whether to embed positional information or not. Default: ``True``.
+        dft_compute_dtype (dtype.Number): The computation type of DFT in SpectralConvDft. Default: ``mstype.float32``.
+        fno_compute_dtype (dtype.Number): The computation type of MLP in fno skip. Default: ``mstype.float16``.
+            Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is recommended for
+            the GPU backend, mstype.float16 is recommended for the Ascend backend.
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution[0], resolution[1], resolution[2], \
+          in\_channels)`.
+
+    Outputs:
+        Tensor, the output of the FNO3D network.
+
+        - **output** (Tensor) - Tensor of shape :math:`(batch\_size, resolution[0], resolution[1],
+          resolution[2], out\_channels)`.
+
+    Raises:
+        TypeError: If `in_channels` is not an int.
+        TypeError: If `out_channels` is not an int.
+        TypeError: If `hidden_channels` is not an int.
+        TypeError: If `lifting_channels` is not an int.
+        TypeError: If `projection_channels` is not an int.
+        TypeError: If `n_layers` is not an int.
+        TypeError: If `data_format` is not a str.
+        TypeError: If `add_residual` is not a bool.
+        TypeError: If `positional_embedding` is not a bool.
+ + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> import mindspore + >>> import mindflow + >>> from mindspore import Tensor + >>> import mindspore.common.dtype as mstype + >>> from mindflow.cell import FNO3D + >>> data = Tensor(np.ones([2, 128, 128, 128, 3]), mstype.float32) + >>> net = FNO3D(in_channels=3, out_channels=3, n_modes=[20, 20, 20], resolutions=[128, 128, 128]) + >>> out = net(data) + >>> print(data.shape, out.shape) + (2, 128, 128, 128, 3) (2, 128, 128, 128, 3) + """ + + def __init__( + self, + in_channels, + out_channels, + n_modes, + resolutions, + hidden_channels=20, + lifting_channels=None, + projection_channels=128, + n_layers=4, + data_format="channels_last", + fnoblock_act="gelu", + mlp_act="gelu", + add_residual=False, + positional_embedding=True, + dft_compute_dtype=mstype.float32, + fno_compute_dtype=mstype.float16 + ): + if isinstance(n_modes, int): + n_modes = [n_modes, n_modes, n_modes] + if isinstance(resolutions, int): + resolutions = [resolutions, resolutions, resolutions] + if len(n_modes) != 3: + raise ValueError( + "The dimension of n_modes should be equal to 3 when using FNO3D\ + but got dimension of n_modes {}".format(len(n_modes))) + if len(resolutions) != 3: + raise ValueError( + "The dimension of resolutions should be equal to 3 when using FNO3D\ + but got dimension of resolutions {}".format(len(resolutions))) + super().__init__( + in_channels, + out_channels, + n_modes, + resolutions, + hidden_channels, + lifting_channels, + projection_channels, + n_layers, + data_format, + fnoblock_act, + mlp_act, + add_residual, + positional_embedding, + dft_compute_dtype, + fno_compute_dtype + ) diff --git a/MindFlow/mindflow/cell/neural_operators/kno1d.py b/MindFlow/mindflow/cell/neural_operators/kno1d.py new file mode 100644 index 0000000000000000000000000000000000000000..81820de728a5a6d5563e04c6819e28eb6db5bf1b --- /dev/null +++ b/MindFlow/mindflow/cell/neural_operators/kno1d.py @@ -0,0 +1,117 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""KNO1D""" +import mindspore.common.dtype as mstype +from mindspore import ops, nn, Tensor + +from .fno_sp import SpectralConv1dDft +from ...utils.check_func import check_param_type + + +class KNO1D(nn.Cell): + r""" + The 1-dimensional Koopman Neural Operator (KNO1D) contains a encoder layer and a decoder layer, + multiple Koopman layers. + The details can be found in `KoopmanLab: machine learning for solving complex physics equations + `_. + + Args: + in_channels (int): The number of channels in the input space. Default: ``1``. + channels (int): The number of channels after dimension lifting of the input. Default: ``32``. + modes (int): The number of low-frequency components to keep. Default: ``16``. + resolution (int): The spatial resolution of the input. Default: ``1024``. + depths (int): The number of KNO layers. Default: ``4``. 
+ compute_dtype (dtype.Number): The computation type of dense. Default: ``mstype.float16``. + Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is recommended for + the GPU backend, mstype.float16 is recommended for the Ascend backend. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, in\_channels)`. + + Outputs: + Tensor, the output of this KNO network. + + - **output** (Tensor) -Tensor of shape :math:`(batch\_size, resolution, in\_channels)`. + + Raises: + TypeError: If `in_channels` is not an int. + TypeError: If `channels` is not an int. + TypeError: If `modes` is not an int. + TypeError: If `depths` is not an int. + TypeError: If `resolution` is not an int. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindflow.cell.neural_operators import KNO1D + >>> input_ = Tensor(np.ones([32, 1024, 1]), mstype.float32) + >>> net = KNO1D() + >>> x, x_reconstruct = net(input_) + >>> print(x.shape, x_reconstruct.shape) + (32, 1024, 1) (32, 1024, 1) + """ + def __init__(self, + in_channels=1, + channels=32, + modes=16, + depths=4, + resolution=1024, + compute_dtype=mstype.float32): + super().__init__() + check_param_type(in_channels, "in_channels", + data_type=int, exclude_type=bool) + check_param_type(channels, "channels", + data_type=int, exclude_type=bool) + check_param_type(modes, "modes", + data_type=int, exclude_type=bool) + check_param_type(depths, "depths", + data_type=int, exclude_type=bool) + check_param_type(resolution, "resolution", + data_type=int, exclude_type=bool) + self.in_channels = in_channels + self.channels = channels + self.modes = modes + self.depths = depths + self.resolution = resolution + self.enc = nn.Dense(in_channels, channels, has_bias=True) + self.dec = nn.Dense(channels, in_channels, has_bias=True) + self.koopman_layer = SpectralConv1dDft(channels, channels, modes, resolution, compute_dtype=compute_dtype) + self.w0 = nn.Conv1d(channels, channels, 1, has_bias=True) + + def construct(self, x: Tensor): + """KNO1D forward function. + + Args: + x (Tensor): Input Tensor. + """ + # reconstruct + x_reconstruct = self.enc(x) + x_reconstruct = ops.tanh(x_reconstruct) + x_reconstruct = self.dec(x_reconstruct) + + # predict + x = self.enc(x) + x = ops.tanh(x) + x = x.transpose(0, 2, 1) + x_w = x + for _ in range(self.depths): + x1 = self.koopman_layer(x) + x = ops.tanh(x + x1) + x = ops.tanh(self.w0(x_w) + x) + x = x.transpose(0, 2, 1) + x = self.dec(x) + return x, x_reconstruct diff --git a/MindFlow/mindflow/cell/neural_operators/kno2d.py b/MindFlow/mindflow/cell/neural_operators/kno2d.py new file mode 100644 index 0000000000000000000000000000000000000000..79f9ae98a2a824cb2339287ced8f7c66e3655af5 --- /dev/null +++ b/MindFlow/mindflow/cell/neural_operators/kno2d.py @@ -0,0 +1,119 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""KNO2D""" +import mindspore.common.dtype as mstype +from mindspore import ops, nn, Tensor + +from .fno_sp import SpectralConv2dDft +from ...utils.check_func import check_param_type + + +class KNO2D(nn.Cell): + r""" + The 2-dimensional Koopman Neural Operator (KNO2D) contains a encoder layer and a decoder layer, + multiple Koopman layers. + The details can be found in `KoopmanLab: machine learning for solving complex physics equations + `_. + + Args: + in_channels (int): The number of channels in the input space. Default: ``1``. + channels (int): The number of channels after dimension lifting of the input. Default: ``32``. + modes (int): The number of low-frequency components to keep. Default: ``16``. + resolution (int): The spatial resolution of the input. Default: ``1024``. + depths (int): The number of KNO layers. Default: ``4``. + compute_dtype (dtype.Number): The computation type of dense. Default: ``mstype.float16``. + Should be ``mstype.float32`` or ``mstype.float16``. mstype.float32 is recommended for + the GPU backend, mstype.float16 is recommended for the Ascend backend. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(batch\_size, resolution, in\_channels)`. + + Outputs: + Tensor, the output of this KNO network. + + - **output** (Tensor) -Tensor of shape :math:`(batch\_size, resolution, in\_channels)`. + + Raises: + TypeError: If `in_channels` is not an int. + TypeError: If `channels` is not an int. + TypeError: If `modes` is not an int. + TypeError: If `depths` is not an int. + TypeError: If `resolution` is not an int. + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindflow.cell.neural_operators import KNO2D + >>> input_ = Tensor(np.ones([32, 64, 64, 10]), mstype.float32) + >>> net = KNO2D() + >>> x, x_reconstruct = net(input_) + >>> print(x.shape, x_reconstruct.shape) + (32, 64, 64, 10) (32, 64, 64, 10) + """ + + def __init__(self, + in_channels=10, + channels=32, + modes=16, + depths=4, + resolution=64, + compute_dtype=mstype.float32): + super().__init__() + check_param_type(in_channels, "in_channels", + data_type=int, exclude_type=bool) + check_param_type(channels, "channels", + data_type=int, exclude_type=bool) + check_param_type(modes, "modes", + data_type=int, exclude_type=bool) + check_param_type(depths, "depths", + data_type=int, exclude_type=bool) + check_param_type(resolution, "resolution", + data_type=int, exclude_type=bool) + self.in_channels = in_channels + self.channels = channels + self.modes = modes + self.depths = depths + self.resolution = resolution + self.enc = nn.Dense(in_channels, channels, has_bias=True) + self.dec = nn.Dense(channels, in_channels, has_bias=True) + self.koopman_layer = SpectralConv2dDft(channels, channels, [modes, modes], [resolution, resolution], + compute_dtype=compute_dtype) + self.w0 = nn.Conv2d(channels, channels, 1, has_bias=True) + + def construct(self, x: Tensor): + """KNO2D forward function. + + Args: + x (Tensor): Input Tensor. 
+ """ + # reconstruct + x_reconstruct = self.enc(x) + x_reconstruct = ops.tanh(x_reconstruct) + x_reconstruct = self.dec(x_reconstruct) + + # predict + x = self.enc(x) + x = ops.tanh(x) + x = x.transpose(0, 3, 1, 2) + x_w = x + for _ in range(self.depths): + x1 = self.koopman_layer(x) + x = ops.tanh(x + x1) + x = ops.tanh(self.w0(x_w) + x) + x = x.transpose(0, 2, 3, 1) + x = self.dec(x) + return x, x_reconstruct diff --git a/MindFlow/mindflow/cell/vit.py b/MindFlow/mindflow/cell/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c8e11818e35a785ef251310792c14b8ea3a8d6 --- /dev/null +++ b/MindFlow/mindflow/cell/vit.py @@ -0,0 +1,352 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +The ViT model +""" + +from mindspore import ops, Parameter, Tensor, nn +import mindspore.ops.operations as P +from mindspore.common.initializer import initializer, XavierUniform +import mindspore.common.dtype as mstype + +from .utils import to_2tuple, get_2d_sin_cos_pos_embed +from .attention import TransformerBlock + + +class PatchEmbedding(nn.Cell): + """Construct patch embeddings with positional embeddings""" + + def __init__(self, in_channels, hidden_channels, patch_size=16, compute_dtype=mstype.float16 + ): + super().__init__() + self.compute_dtype = compute_dtype + self.patch_embedding = nn.Conv2d( + in_channels=in_channels, + out_channels=hidden_channels, + kernel_size=patch_size, + stride=patch_size, + has_bias=True, + ).to_float(compute_dtype) + self.init_weights() + + def init_weights(self): + weight_shape = self.patch_embedding.weight.shape + xavier_init = initializer( + XavierUniform(), + [weight_shape[0], weight_shape[1] * weight_shape[2] * weight_shape[3]], + mstype.float32, + ) + + self.patch_embedding.weight = P.Reshape()( + xavier_init, + (weight_shape[0], weight_shape[1], + weight_shape[2], weight_shape[3]), + ) + + def construct(self, x): + x = self.patch_embedding(x) + x = P.Reshape()(x, (x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) + x = P.Transpose()(x, (0, 2, 1)) + return x + + +class VitEncoder(nn.Cell): + r""" + ViT Encoder module with multi-layer stacked of `MultiHeadAttention`, + including multihead self attention and feedforward layer. + + Args: + grid_size (tuple[int]): The grid_size size of input. + in_channels (int): The input feature size of input. Default: ``3``. + patch_size (int): The patch size of image. Default: ``16``. + depths (int): The encoder depth of encoder layer. + hidden_channels (int): The encoder embedding dimension of encoder layer. Default: ``768``. + num_heads (int): The encoder heads' number of encoder layer. Default: ``16``. + dropout_rate (float): The rate of dropout layer. Default: ``0.0``. + compute_dtype (dtype): The data type for encoder, encoding_embedding, encoder and dense layer. + Default: ``mstype.float16``. 
+ + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(batch\_size, feature\_size, image\_height, image\_width)`. + + Outputs: + - **output** (Tensor) - Tensor of shape :math:`(batch\_size, patchify\_size, embed\_dim)`. + where patchify_size = (image_height * image_width) / (patch_size * patch_size). + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> from mindspore import ops + >>> from mindflow.cell.vit import VitEncoder + >>> input_tensor = ops.rand(32, 3, 192, 384) + >>> print(input_tensor.shape) + (32, 3, 192, 384) + >>>encoder = VitEncoder(grid_size=(192 // 16, 384 // 16), + >>> in_channels=3, + >>> patch_size=16, + >>> depths=6, + >>> hidden_channels=768, + >>> num_heads=12, + >>> dropout_rate=0.0, + >>> compute_dtype=mstype.float16) + >>>output_tensor = encoder(input_tensor) + >>> print("output_tensor:",output_tensor.shape) + (32, 288, 768) + """ + + def __init__(self, + in_channels, + hidden_channels, + grid_size, + patch_size, + depths, + num_heads, + dropout_rate=0.0, + compute_dtype=mstype.float16, + ): + super().__init__() + self.patch_embedding = PatchEmbedding( + in_channels, hidden_channels, patch_size, compute_dtype=compute_dtype + ) + pos_embed = get_2d_sin_cos_pos_embed(hidden_channels, grid_size) + self.position_embedding = Parameter( + Tensor(pos_embed, mstype.float32), + name="encoder_pos_embed", + requires_grad=False, + ) + self.layer = nn.CellList([]) + self.encoder_norm = nn.LayerNorm([hidden_channels], epsilon=1e-6).to_float( + mstype.float32 + ) + for _ in range(depths): + layer = TransformerBlock( + in_channels=hidden_channels, + num_heads=num_heads, + dropout_rate=dropout_rate, + compute_dtype=compute_dtype, + ) + self.layer.append(layer) + + def construct(self, x): + """construct""" + x = self.patch_embedding(x) + x = x + self.position_embedding + for layer_block in self.layer: + x = layer_block(x) + x = self.encoder_norm(x) + return x + + +class VitDecoder(nn.Cell): + r""" + ViT Decoder module with multi-layer stacked of `MultiHeadAttention`, + including multihead self attention and feedforward layer. + + Args: + grid_size (tuple[int]): The grid_size size of input. + depths (int): The decoder depth of decoder layer. + hidden_channels (int): The decoder embedding dimension of decoder layer. + num_heads (int): The decoder heads' number of decoder layer. + dropout_rate (float): The rate of dropout layer. Default: ``0.0``. + compute_dtype (dtype): The data type for encoder, decoding_embedding, decoder and dense layer. + Default: ``mstype.float16``. + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(batch\_size, patchify\_size, embed\_dim)`. + + Outputs: + - **output** (Tensor) - Tensor of shape :math:`(batch\_size, patchify\_size, embed\_dim)`. + where patchify_size = (image_height * image_width) / (patch_size * patch_size). 
+ + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> from mindspore import ops + >>> from mindflow.cell.vit import VitDecoder + >>> input_tensor = ops.rand(32, 288, 512) + >>> print(input_tensor.shape) + (32, 288, 768) + >>> decoder = VitDecoder(grid_size=grid_size, + >>> depths=6, + >>> hidden_channels=512, + >>> num_heads=16, + >>> dropout_rate=0.0, + >>> compute_dtype=mstype.float16) + >>> output_tensor = VitDecoder(input_tensor) + >>> print("output_tensor:",output_tensor.shape) + (32, 288, 512) + """ + + def __init__(self, + grid_size, + depths, + hidden_channels, + num_heads, + dropout_rate=0.0, + compute_dtype=mstype.float16, + ): + super().__init__() + self.grid_size = grid_size + self.layer = nn.CellList([]) + pos_embed = get_2d_sin_cos_pos_embed(hidden_channels, grid_size) + self.position_embedding = Parameter( + Tensor(pos_embed, mstype.float32), + name="decoder_pos_embed", + requires_grad=False, + ) + self.decoder_norm = nn.LayerNorm([hidden_channels], epsilon=1e-6).to_float( + mstype.float32 + ) + for _ in range(depths): + layer = TransformerBlock( + in_channels=hidden_channels, + num_heads=num_heads, + dropout_rate=dropout_rate, + compute_dtype=compute_dtype, + ) + self.layer.append(layer) + + def construct(self, x): + """construct""" + x = x + self.position_embedding + for layer_block in self.layer: + x = layer_block(x) + x = self.decoder_norm(x) + return x + + +class ViT(nn.Cell): + r""" + This module based on ViT backbone which including encoder, decoding_embedding, decoder and dense layer. + + Args: + image_size (tuple[int]): The image size of input. Default: ``(192, 384)``. + in_channels (int): The input feature size of input. Default: ``7``. + out_channels (int): The output feature size of output. Default: ``3``. + patch_size (int): The patch size of image. Default: ``16``. + encoder_depths (int): The encoder depth of encoder layer. Default: ``12``. + encoder_embed_dim (int): The encoder embedding dimension of encoder layer. Default: ``768``. + encoder_num_heads (int): The encoder heads' number of encoder layer. Default: ``12``. + decoder_depths (int): The decoder depth of decoder layer. Default: ``8``. + decoder_embed_dim (int): The decoder embedding dimension of decoder layer. Default: ``512``. + decoder_num_heads (int): The decoder heads' number of decoder layer. Default: ``16``. + dropout_rate (float): The rate of dropout layer. Default: ``0.0``. + compute_dtype (dtype): The data type for encoder, decoding_embedding, decoder and dense layer. + Default: ``mstype.float16``. + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(batch\_size, feature\_size, image\_height, image\_width)`. + + Outputs: + - **output** (Tensor) - Tensor of shape :math:`(batch\_size, patchify\_size, embed\_dim)`. 
+ where patchify_size = (image_height * image_width) / (patch_size * patch_size) + + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> from mindspore import ops + >>> from mindflow.cell import ViT + >>> input_tensor = ops.rand(32, 3, 192, 384) + >>> print(input_tensor.shape) + (32, 3, 192, 384) + >>> model = ViT(in_channels=3, + >>> out_channels=3, + >>> encoder_depths=6, + >>> encoder_embed_dim=768, + >>> encoder_num_heads=12, + >>> decoder_depths=6, + >>> decoder_embed_dim=512, + >>> decoder_num_heads=16, + >>> ) + >>> output_tensor = model(input_tensor) + >>> print(output_tensor.shape) + (32, 288, 768) + """ + + def __init__(self, + image_size=(192, 384), + in_channels=7, + out_channels=3, + patch_size=16, + encoder_depths=12, + encoder_embed_dim=768, + encoder_num_heads=12, + decoder_depths=8, + decoder_embed_dim=512, + decoder_num_heads=16, + dropout_rate=0.0, + compute_dtype=mstype.float16, + ): + super().__init__() + image_size = to_2tuple(image_size) + grid_size = (image_size[0] // patch_size, image_size[1] // patch_size) + + self.patch_size = patch_size + self.out_channels = out_channels + self.in_channels = in_channels + + self.encoder_depths = encoder_depths + self.encoder_embed_dim = encoder_embed_dim + self.encoder_num_heads = encoder_num_heads + + self.decoder_depths = decoder_depths + self.decoder_embed_dim = decoder_embed_dim + self.decoder_num_heads = decoder_num_heads + + self.transpose = ops.Transpose() + + self.encoder = VitEncoder( + in_channels=in_channels, + hidden_channels=encoder_embed_dim, + patch_size=patch_size, + grid_size=grid_size, + depths=encoder_depths, + num_heads=encoder_num_heads, + dropout_rate=dropout_rate, + compute_dtype=compute_dtype, + ) + + self.decoder_embedding = nn.Dense( + encoder_embed_dim, + decoder_embed_dim, + has_bias=True, + weight_init="XavierUniform", + ).to_float(compute_dtype) + + self.decoder = VitDecoder( + hidden_channels=decoder_embed_dim, + grid_size=grid_size, + depths=decoder_depths, + num_heads=decoder_num_heads, + dropout_rate=dropout_rate, + compute_dtype=compute_dtype, + ) + + self.decoder_pred = nn.Dense( + decoder_embed_dim, + patch_size**2 * out_channels, + has_bias=True, + weight_init="XavierUniform", + ).to_float(compute_dtype) + + def construct(self, x): + x = self.encoder(x) + x = self.decoder_embedding(x) + x = self.decoder(x) + images = self.decoder_pred(x) + return images.astype(mstype.float32) diff --git a/MindFlow/mindflow/core/__init__.py b/MindFlow/mindflow/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7595f98becafc3483d3100f02612b450abcc78aa --- /dev/null +++ b/MindFlow/mindflow/core/__init__.py @@ -0,0 +1,42 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""init""" +from .lr_scheduler import get_poly_lr, get_multi_step_lr, get_warmup_cosine_annealing_lr +from .losses import get_loss_metric, WaveletTransformLoss, MTLWeightedLoss, RelativeRMSELoss +from .derivatives import batched_hessian, batched_jacobian +from .optimizers import AdaHessian +from .fourier import DFTn, IDFTn, RDFTn, IRDFTn, DCT, IDCT, DST, IDST + +__all__ = ["get_poly_lr", + "get_multi_step_lr", + "get_warmup_cosine_annealing_lr", + "get_loss_metric", + "WaveletTransformLoss", + "MTLWeightedLoss", + "RelativeRMSELoss", + "batched_hessian", + "batched_jacobian", + "AdaHessian", + "DFTn", + "IDFTn", + "RDFTn", + "IRDFTn", + "DCT", + "IDCT", + "DST", + "IDST", + ] + +__all__.sort() diff --git a/MindSPONGE/applications/model_cards/MEGAProtein.md b/MindSPONGE/applications/model_cards/MEGAProtein.md new file mode 100644 index 0000000000000000000000000000000000000000..1d9e74e7ca62714f3ffa374548f5ff315eb01bab --- /dev/null +++ b/MindSPONGE/applications/model_cards/MEGAProtein.md @@ -0,0 +1,317 @@ +# MEGAProtein + +## 模型介绍 + +使用计算机高效计算获取蛋白质空间结构的过程被称为蛋白质结构预测,传统的结构预测工具一直存在精度不足的问题,直至2020年谷歌DeepMind团队提出[AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2)[1,2],该模型相较于传统工具预测精度大幅提升,所得结构与真实结构误差接近实验方法,但是仍存在数据前处理耗时过长、缺少MSA时预测精度不准、缺乏通用评估结构质量工具的问题。针对这些问题,高毅勤老师团队与MindSpore科学计算团队合作进行了一系列创新研究,开发出更准确和更高效的蛋白质结构预测工具**MEGA-Protein**。 + +MEGA-Protein主要由三部分组成: + +- **蛋白质结构预测工具MEGA-Fold**,网络模型部分与AlphaFold2相同,在数据预处理的多序列对比环节采用了[MMseqs2](https://www.biorxiv.org/content/10.1101/2021.08.15.456425v1.full.pdf)[3]进行序列检索,相比于原版端到端速度提升2-3倍;同时借助内存复用大幅提升内存利用效率,同硬件条件下支持更长序列的推理,基于32GB内存的Ascend910运行时最长支持2048长度序列推理(以[Pipeline](https://gitee.com/mindspore/mindscience/tree/master/MindSPONGE/src/mindsponge/pipeline/models/megafold)模式运行,可以支持3072长度序列推理);我们还提供了结构预测模型训练能力,我们自己训练的权重获得了CAMEO-3D蛋白质结构预测赛道22年4月月榜第一。 + +
+MEGA-Fold获得CAMEO-3D蛋白质结构预测赛道月榜第一 +
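+
+MEGA-Fold 也可以单独通过 `mindsponge.PipeLine` 调用,下面给出一个最小推理示意(节选并简化自下文「如何使用」与「使用场景」中的完整流程,假设序列特征已检索并保存为 pickle 文件,路径仅作示例):
+
+```python
+import pickle
+import mindspore as ms
+from mindsponge import PipeLine
+
+ms.set_context(mode=ms.GRAPH_MODE)
+
+# 加载序列检索得到的特征(路径仅为示意)
+with open("./test.pkl", "rb") as f:
+    feature = pickle.load(f)
+
+fold_prediction = PipeLine(name="MEGAFold")
+fold_prediction.set_device_id(0)
+fold_prediction.initialize(key="predict_256")
+fold_prediction.model.from_pretrained()
+res = fold_prediction.predict(feature)
+```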
+ +- **MSA生成工具MEGA-EvoGen**,能显著提升单序列的预测速度,并且能够在MSA较少(few shot)甚至没有MSA(zero-shot,即单序列)的情况下,帮助MEGA-Fold/AlphaFold2等模型维持甚至提高推理精度,突破了在「孤儿序列」、高异变序列和人造蛋白等MSA匮乏场景下无法做出准确预测的限制。基于32GB内存的Ascend910运行时最长支持768长度序列推理。该方法获得了CAMEO-3D蛋白质结构预测赛道22年7月月榜第一。 + +
+MEGA-EvoGen方法获得CAMEO-3D蛋白质结构预测赛道月榜第一 +
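+
+MEGA-EvoGen 可直接以氨基酸序列为输入生成共进化特征,下面给出一个最小推理示意(节选自下文「如何使用」中的完整流程,序列为 CASP14 T1082-D1,仅作示例):
+
+```python
+from mindsponge import PipeLine
+
+fasta = "GYDKDLCEWSMTADQTEVETQIEADIMNIVKRDRPEMKAEVQKQLKSGGVMQYNYVLYCDKNFNNKNIIAEVVGE"
+
+msa_generator = PipeLine(name="MEGAEvoGen")
+msa_generator.set_device_id(0)
+msa_generator.initialize(key="evogen_predict_256")
+msa_generator.model.from_pretrained()
+msa_feature = msa_generator.predict(fasta)
+```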
+ +- **蛋白质结构评分工具MEGA-Assessment**,该工具可以评价蛋白质结构每个残基的准确性以及残基-残基之间的距离误差,同时可以基于评价结果对蛋白结构作出进一步的优化。基于32GB内存的Ascend910运行时最长支持2048长度序列推理。该方法获得了CAMEO-QE结构质量评估赛道22年7月月榜第一。 + +
+MEGA-Assessment方法获得CAMEO-QE结构质量评估赛道月榜第一 +
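+
+MEGA-Assessment 以序列特征与待评估结构(decoy)为输入,输出逐残基评分,下面给出一个最小推理示意(节选并简化自下文「使用场景」中的完整示例,pickle 与 pdb 路径仅作示意):
+
+```python
+import pickle
+import numpy as np
+from mindsponge import PipeLine
+from mindsponge.common.protein import from_pdb_string
+
+protein_assessment = PipeLine(name="MEGAAssessment")
+protein_assessment.set_device_id(0)
+protein_assessment.initialize(key="predict_256")
+protein_assessment.model.from_pretrained()
+
+# 序列检索特征与 MEGA-Fold 预测得到的结构
+with open("./test.pkl", "rb") as f:
+    raw_feature = pickle.load(f)
+with open("./res.pdb", "r") as f:
+    decoy = from_pdb_string(f.read())
+raw_feature["decoy_aatype"] = decoy.aatype
+raw_feature["decoy_atom_positions"] = decoy.atom_positions
+raw_feature["decoy_atom_mask"] = decoy.atom_mask
+
+res = protein_assessment.predict(raw_feature)
+print("score is:", np.mean(res))
+```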
+ +## 数据集 + +MEGA-Fold训练数据集为[PSP蛋白质结构数据集](http://ftp.cbi.pku.edu.cn/psp/),数据集大小为1.6TB,解压后为25TB。 +MEGA-Assessment训练数据集为PSP数据集中的[PSP lite](http://ftp.cbi.pku.edu.cn/psp/psp_lite/)。 + +```shell +. +└─PSP + ├─true_structure_dataset + | ├─pkl + | | └─256 pkl packages + | ├─pdb + | | └─256 pdb packages + | └─true_structure_data_statistics_729.json + ├─distillation_dataset + | ├─pkl + | | └─256 pkl packages + | ├─pdb + | | └─256 pdb packages + | └─distill_data_statistics_729.json + ├─new_validation_dataset + | ├─pkl.tar.gz + | ├─pdb.tar.gz + | └─nv_data_statistics.json + └─psp_lite + ├─true_structure_mini + | ├─pkl + | | └─32 pkl packages + | └─true_structure_mini.pdb.tar.gz + └─distillation_mini + ├─pkl + | └─32 pkl packages + └─distillation_mini.pdb.tar.gz +``` + +## 如何使用 + +mindsponge.PipeLine中分别提供了三个模型的推理流程,在使用时, + +1. 可将氨基酸序列输入MEGA-EvoGen中获取该蛋白的共进化信息,也可以将*传统数据库检索*生成的共进化信息输入MEGA-EvoGen进行强化 +2. 将共进化输入MEGA-Fold中进行蛋白质的结构预测 +3. 最后将蛋白质共进化与结构信息共同输入MEGA-Assessment中进行打分评估 + +以CASP14蛋白质T1082-D1为例,整体推理流程如下所示。 + +*传统数据库检索请参考`application/common_utils/database_query/README.md`配置。* + +```python +import numpy as np +import mindspore as ms +from mindsponge import PipeLine + +ms.set_context(mode=ms.GRAPH_MODE) + +# MEGA-EvoGen推理获取蛋白质生成MSA后的特征 +fasta = "GYDKDLCEWSMTADQTEVETQIEADIMNIVKRDRPEMKAEVQKQLKSGGVMQYNYVLYCDKNFNNKNIIAEVVGE" +msa_generator = PipeLine(name="MEGAEvoGen") +msa_generator.set_device_id(0) +msa_generator.initialize(key="evogen_predict_256") +msa_generator.model.from_pretrained() +msa_feature = msa_generator.predict(fasta) + +# MEGA-Fold推理获取蛋白质结构信息 +fold_prediction = PipeLine(name="MEGAFold") +fold_prediction.set_device_id(0) +fold_prediction.initialize(key="predict_256") +fold_prediction.model.from_pretrained() +final_atom_positions, final_atom_mask, aatype, _, _ = fold_prediction.model.predict(msa_feature) + +# MEGA-Assessment对蛋白质结构进行评价 +protein_assessment = PipeLine(name = "MEGAAssessment") +protein_assessment.set_device_id(0) +protein_assessment.initialize("predict_256") +protein_assessment.model.from_pretrained() +msa_feature['decoy_aatype'] = np.pad(aatype, (0, 256 - aatype.shape[0])) +msa_feature['decoy_atom_positions'] = np.pad(final_atom_positions, ((0, 256 - final_atom_positions.shape[0]), (0, 0), (0, 0))) +msa_feature['decoy_atom_mask'] = np.pad(final_atom_mask, ((0, 256 - final_atom_mask.shape[0]), (0, 0))) + +res = protein_assessment.model.predict(msa_feature) +print("score is:", np.mean(res[:msa_feature['num_residues']])) +``` + +### 使用场景 + +MEGAEvoGen,MEGAFold,MEGAAssessment均支持多种不同场景下的不同输入格式进行推理,详情如下: + +为方便说明使用场景,默认下载好config文件,通过修改内置参数的方式选择不同场景,用户使用时也可按照如下方式执行,若未提前下载config文件,可通过替换样例内代码的方式下载的同时进行config的修改与加载。 + +- MEGAEvoGen + + - 序列作为输入,样例如下: + + ```python + from mindsponge import PipeLine + from mindsponge.common.config_load import load_config + + fasta = "GYDKDLCEWSMTADQTEVETQIEADIMNIVKRDRPEMKAEVQKQLKSGGVMQYNYVLYCDKNFNNKNIIAEVVGE" + msa_generator = PipeLine(name="MEGAEvoGen") + + # 未获取config文件时,执行如下两行命令即可自动下载config文件,之后所有案例同理替换,仅提供代码样例,不做相同说明 + # from mindsponge.pipeline.pipeline import download_config + # download_config(msa_generator.config["evogen_predict_256"], msa_generator.config_path + "evogen_predict_256.yaml") + + conf = load_config(msa_generator.config_path + "evogen_predict_256.yaml") + conf.use_pkl = False + msa_generator.initialize(conf=conf) + msa_generator.model.from_pretrained() + features = msa_generator.predict(fasta) + + with open("./examples/MEGA-Protein/pkl/T1082-D1.pkl", "rb") as f: + data = pickle.load(f) + for k, v in features: + print(k, v.shape, 
v.dtype) + ``` + + - 序列搜索MSA后所获得的pickle文件作为输入,样例如下: + + ```python + import pickle + from mindsponge import PipeLine + + with open("./test.pkl", "rb") as f: + data = pickle.load(f) + msa_generator = PipeLine(name="MEGAEvoGen") + + # from mindsponge.pipeline.pipeline import download_config + # download_config(msa_generator.config["evogen_predict_256"], msa_generator.config_path + "evogen_predict_256.yaml") + + conf = load_config(msa_generator.config_path + "evogen_predict_256.yaml") + conf.use_pkl = True + msa_generator.initialize(conf=conf) + msa_generator.model.from_pretrained() + feature, mask = msa_generator.predict(data) + with open("./test.pkl", "rb") as f: + data = pickle.load(f) + for k, v in features: + print(k, v.shape, v.dtype) + ``` + +- MEGAFold + + - 使用搜索后所得pickle文件作为输入,样例如下: + + ```python + import pickle + import mindspore as ms + from mindsponge import PipeLine + ms.set_context(mode=ms.GRAPH_MODE) + + with open("./test.pkl", "rb") as f: + feature = pickle.load(f) + fold_prediction = PipeLine(name="MEGAFold") + fold_prediction.set_device_id(0) + fold_prediction.initialize(key="predict_256") + fold_prediction.model.from_pretrained() + res = fold_prediction.predict(feature) + pdb_file = res[-1] + os.makedirs(f'res.pdb', exist_ok=True) + os_flags = os.O_RDWR | os.O_CREAT + os_modes = stat.S_IRWXU + pdb_path = './res.pdb' + with os.fdopen(os.open(pdb_path, os_flags, os_modes), 'w') as fout: + fout.write(pdb_file) + + print(protein_structure) + ``` + + - 单序列进行MSA检索并进行推理(完整流程),其中MSA检索配置请参考`application/common_utils/database_query/README.md`。检索完成后使用pickle进行推理场景与上述另一场景完全相同,不重复提供代码。 + + - 后续MEGAFold会支持将蛋白质序列与template作为输入,不提供MSA进行推理的场景。 + +- MEGAAssessment + + - MEGAAssessment仅支持序列搜索所得pickle文件和MEGAFold推理所得pdb作为输入单场景,样例如下: + + ```python + import pickle + import numpy as np + from mindspore import context + from mindsponge import PipeLine + from mindsponge.common.config_load import load_config + from mindsponge.common.protein import from_pdb_string + + protein_assessment = PipeLine(name="MEGAAssessment") + protein_assessment.set_device_id(0) + + # from mindsponge.pipeline.pipeline import download_config + # download_config(protein_assessment.config["predict_256"], protein_assessment.config_path + "predict_256.yaml") + + conf = load_config(protein_assessment.config_path + "predict_256.yaml") + protein_assessment.initialize(key="predict_256") + protein_assessment.model.from_pretrained() + + # load raw feature + with open("./test.pkl", "rb") as f: + raw_feature = pickle.load(f) + # load decoy pdb + with open('./res.pdb', 'r') as f: + decoy_prot_pdb = from_pdb_string(f.read()) + raw_feature['decoy_aatype'] = decoy_prot_pdb.aatype + raw_feature['decoy_atom_positions'] = decoy_prot_pdb.atom_positions + raw_feature['decoy_atom_mask'] = decoy_prot_pdb.atom_mask + + res = protein_assessment.predict(raw_feature) + print("score is:", np.mean(res)) + ``` + +- 后处理 + + AI结构预测方法如MEGA-Fold/AlphaFold2结果只包含碳/氮等重原子的位置信息,缺少氢原子;同时AI方法预测的蛋白质结构可能违反物理化学原理,比如键长键角超出理论值范围等。MindSPONGE提供基于Amber力场的结构弛豫工具,补全氢原子位置信息的同时使结构更符合物理规律,请参考`application/common_utils/openmm_relaxation/README.md`配置 + +## 训练过程 + +Pipeline中提供了MEGAFold和MEGAAssessment两个模型的训练代码。MEGAFold的训练集为PSP数据集,MEGAAssessment的训练集为PSP lite数据集。 + +MEGAFold的训练样例代码如下所示: + +```bash +import mindspore as ms +from mindsponge import PipeLine + +ms.set_context(mode=ms.GRAPH_MODE) + +pipe = PipeLine(name="MEGAFold") +pipe.set_device_id(0) +pipe.initialize(key="initial_training") +pipe.train({YOUR_DATA_PATH}, num_epochs=1) +``` + +MEGAAssessment的训练样例代码如下所示: + +```bash +from 
mindsponge import PipeLine + +pipe = PipeLine(name="MEGAAssessment") +pipe.set_device_id(0) +pipe.initialize(key="initial_training") +pipe.train({YOUR_DATA_PATH}, num_epochs=1) +``` + +由于训练和推理代码网络结构存在差异,因此利用训练得到的权重进行推理、利用推理权重继续训练时,需要进行权重转换。 +当训练完成进行推理时,先进行权重转换,再通过`model.from_pretrained()`接口传入推理权重,即可进行推理。 +示例代码如下: + +```bash +from mindsponge.common.utils import get_predict_checkpoint, get_train_checkpoint + +# 将训练得到的权重转换为推理权重 +# training.ckpt: 训练得到的权重; +# 48:msa堆叠层数; +# predict.ckpt:需要被转换成的预测权重 +get_predict_checkpoint("training.ckpt", 48, "predict.ckpt") + +# 将推理时的权重转换为训练权重 +# training.ckpt: 需要进行训练使用的权重; +# 48:msa堆叠层数; +# predict.ckpt:预测时使用的权重 +get_train_checkpoint("training.ckpt", 48, "predict.ckpt") +``` + +## 引用 + +### 结构预测工具MEGA-Fold与训练数据集PSP + +```bash +@misc{https://doi.org/10.48550/arxiv.2206.12240, +doi = {10.48550/ARXIV.2206.12240}, +url = {https://arxiv.org/abs/2206.12240}, +author = {Liu, Sirui and Zhang, Jun and Chu, Haotian and Wang, Min and Xue, Boxin and Ni, Ningxi and Yu, Jialiang and Xie, Yuhao and Chen, Zhenyu and Chen, Mengyun and Liu, Yuan and Patra, Piya and Xu, Fan and Chen, Jie and Wang, Zidong and Yang, Lijiang and Yu, Fan and Chen, Lei and Gao, Yi Qin}, +title = {PSP: Million-level Protein Sequence Dataset for Protein Structure Prediction}, +publisher = {arXiv}, +year = {2022}, +copyright = {Creative Commons Attribution 4.0 International} +} +``` + +### MSA生成修正工具MEGA-EvoGen + +```bash +@article{doi:10.1021/acs.jctc.3c00528, +author = {Zhang, Jun and Liu, Sirui and Chen, Mengyun and Chu, Haotian and Wang, Min and Wang, Zidong and Yu, Jialiang and Ni, Ningxi and Yu, Fan and Chen, Dechin and Yang, Yi Isaac and Xue, Boxin and Yang, Lijiang and Liu, Yuan and Gao, Yi Qin}, +title = {Unsupervisedly Prompting AlphaFold2 for Accurate Few-Shot Protein Structure Prediction}, +journal = {Journal of Chemical Theory and Computation}, +volume = {19}, +number = {22}, +pages = {8460-8471}, +year = {2023}, +doi = {10.1021/acs.jctc.3c00528}, +note ={PMID: 37947474}, +} +``` diff --git a/MindSPONGE/applications/research/AlphaFold3/CMakeLists.txt b/MindSPONGE/applications/research/AlphaFold3/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..81162722e04a3d9a245c1d4bd5e558b65ed5110b --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/CMakeLists.txt @@ -0,0 +1,95 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +cmake_minimum_required(VERSION 3.28) +project( + "${SKBUILD_PROJECT_NAME}" + LANGUAGES CXX + VERSION "${SKBUILD_PROJECT_VERSION}") + +include(FetchContent) +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE TRUE) +set(ABSL_PROPAGATE_CXX_STD ON) + +# Remove support for scan deps, which is only useful when using C++ modules. 
+unset(CMAKE_CXX_SCANDEP_SOURCE) + +FetchContent_Declare( + abseil-cpp + GIT_REPOSITORY https://github.com/abseil/abseil-cpp + GIT_TAG d7aaad83b488fd62bd51c81ecf16cd938532cc0a # 20240116.2 + EXCLUDE_FROM_ALL) + +FetchContent_Declare( + pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11 + GIT_TAG 2e0815278cb899b20870a67ca8205996ef47e70f # v2.12.0 + EXCLUDE_FROM_ALL) + +FetchContent_Declare( + pybind11_abseil + GIT_REPOSITORY https://github.com/pybind/pybind11_abseil + GIT_TAG bddf30141f9fec8e577f515313caec45f559d319 # HEAD @ 2024-08-07 + EXCLUDE_FROM_ALL) + +FetchContent_Declare( + cifpp + GIT_REPOSITORY https://github.com/pdb-redo/libcifpp + GIT_TAG ac98531a2fc8daf21131faa0c3d73766efa46180 # v7.0.3 + # Don't `EXCLUDE_FROM_ALL` as necessary for build_data. +) + +FetchContent_Declare( + dssp + GIT_REPOSITORY https://github.com/PDB-REDO/dssp + GIT_TAG 57560472b4260dc41f457706bc45fc6ef0bc0f10 # v4.4.7 + EXCLUDE_FROM_ALL) + +FetchContent_MakeAvailable(pybind11 abseil-cpp pybind11_abseil cifpp dssp) + +find_package( + Python3 + COMPONENTS Interpreter Development NumPy + REQUIRED) + +include_directories(${PYTHON_INCLUDE_DIRS}) +include_directories(src/) + +file(GLOB_RECURSE cpp_srcs src/alphafold3/*.cc) +list(FILTER cpp_srcs EXCLUDE REGEX ".*\(_test\|_main\|_benchmark\).cc$") + +add_compile_definitions(NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION) + +pybind11_add_module(cpp ${cpp_srcs}) + +target_link_libraries( + cpp + PRIVATE absl::check + absl::flat_hash_map + absl::node_hash_map + absl::strings + absl::status + absl::statusor + absl::log + pybind11_abseil::absl_casters + Python3::NumPy + dssp::dssp + cifpp::cifpp) + +target_compile_definitions(cpp PRIVATE VERSION_INFO=${PROJECT_VERSION}) +install(TARGETS cpp LIBRARY DESTINATION alphafold3) +install( + FILES LICENSE + OUTPUT_TERMS_OF_USE.md + WEIGHTS_PROHIBITED_USE_POLICY.md + WEIGHTS_TERMS_OF_USE.md + DESTINATION alphafold3) diff --git a/MindSPONGE/applications/research/AlphaFold3/LICENSE b/MindSPONGE/applications/research/AlphaFold3/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..bfef380bf7d9cb74ec9ba533b37c3fbeef3bdc09 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/LICENSE @@ -0,0 +1,437 @@ +Attribution-NonCommercial-ShareAlike 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. 
+ + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International +Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-ShareAlike 4.0 International Public License +("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. BY-NC-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. 
Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. The License Elements of this + Public License are Attribution, NonCommercial, and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + l. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + m. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + n. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. 
For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. 
identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-NC-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + including for purposes of Section 3(b); and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. 
THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. 
Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. \ No newline at end of file diff --git a/MindSPONGE/applications/research/AlphaFold3/README.md b/MindSPONGE/applications/research/AlphaFold3/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6e23913ab7349a1e2a693b0c9a3b58092e45e56c --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/README.md @@ -0,0 +1,258 @@ +# AlphaFold3-MindSpore + +[**MindSpore版 AlphaFold3实现**] 一个基于MindSpore深度学习框架的AlphaFold3推理网络结构实现。 + +> 📖 **语言版本**: [中文](README.md) | [English](README_EN.md) + +## 📑 目录 + +- [项目简介](#项目简介) +- [安装](#安装) +- [快速开始](#快速开始) +- [详细使用说明](#详细使用说明) +- [许可证](#许可证) +- [致谢](#致谢) +- [参考文献](#参考文献) + +## 项目简介 + +**项目背景**: +AlphaFold3是DeepMind在2024年发布的革命性生物分子结构预测模型,能够预测蛋白质、DNA、RNA等生物大分子的三维结构。本项目基于Ascend NPU和MindSpore框架,实现了AlphaFold3的推理功能。 + +AlphaFold3 的模型结构如下图所示: + +![AlphaFold3 模型结构](image/af3_structure.jpg) + +- **推理流程**:首先输入的蛋白,核酸,配体等序列信息,经过模板搜索(Template Search)、多序列比对(Multiple Sequence Alignment, MSA)等预处理步骤,然后通过embeding部分对输入信息进行编码,之后通过Pairformer模块,获取序列及结构的关系,接着进入扩散模块生成三维结构,最后通过置信度模块给出预测的置信度评分 +- **生物分子结构预测**: 基于AlphaFold3算法的生物分子结构预测模型,支持包括蛋白质,DNA,RNA,小分子在内的多种输入形式;支持多链输入,预测相互作用和相对位置 +- **MindSpore支持**: 基于MindSpore对模型推理功能进行适配 + +### 硬件要求 + +- Atlas 800T A2 + +### 软件要求 + +- Python >= 3.11 +- MindSpore >= 2.5.0 +- CANN = 8.0.0 +- cmake >= 3.28.1 + +## 安装 + +### 1. 克隆仓库 + +```bash +git clone https://gitee.com/mindspore/mindscience +cd mindsience/MindSPONGE/application/research/AlphaFold3 +``` + +### 2. 安装依赖 + +```bash +pip install -r requirements.txt +#`{PATH}` 为当前目录 +export PYTHONPATH={PATH}/mindscience/MindSPONGE/src +export PYTHONPATH={PATH}/mindscience/MindChemistry +``` + +### 3. 
安装软件包
+
+[hmmer](http://eddylab.org/software/hmmer/) 在链接处下载安装包,如 `hmmer-3.4.tar.gz`,并放置在当前目录下,然后执行以下命令:
+
+```bash
+mkdir /path/to/hmmer_build /path/to/hmmer && \
+mv ./hmmer-3.4.tar.gz /path/to/hmmer_build && \
+cd /path/to/hmmer_build && tar -zxf hmmer-3.4.tar.gz && rm hmmer-3.4.tar.gz && \
+cd /path/to/hmmer_build/hmmer-3.4 && ./configure --prefix=/path/to/hmmer && \
+make -j8 && make install && \
+cd /path/to/hmmer_build/hmmer-3.4/easel && make install && \
+rm -rf /path/to/hmmer_build
+export PATH=/path/to/hmmer/bin:$PATH
+which jackhmmer
+```
+
+如果输出`/path/to/hmmer/bin/jackhmmer`,则安装成功。
+
+### 4. 编译
+
+```bash
+cd {PATH}/mindscience/MindSPONGE/applications/research/AlphaFold3
+mkdir build
+cd build
+cmake ..
+make
+cp ./cpp.cpython-311-aarch64-linux-gnu.so ../src/alphafold3
+cd ..
+```
+
+生成数据文件:
+
+```bash
+python ./src/alphafold3/build_data.py
+```
+
+如出现报错找不到components.cif,可以去[wwpdb](https://files.wwpdb.org/pub/pdb/data/monomers/components.cif)下载components.cif文件,放置在conda环境中`{CONDA_ENV_DIR}/lib/python3.11/site-packages/share/libcifpp`文件夹下。如不存在`share/libcifpp`文件夹,则需要手动创建。
+
+### 5. 下载数据库
+
+可以从DeepMind官网下载测试用小数据库[miniature_databases](https://github.com/google-deepmind/alphafold3/tree/main/src/alphafold3/test_data/miniature_databases)(仅供测试使用,会影响推理结果!)
+下载后放置在统一文件夹中并修改文件名如下所示(如统一放置在`/mindscience/MindSPONGE/applications/research/AlphaFold3/public_databases`可省略`--db_dir=/PATH/TO/DB_DIR`):
+
+```txt
+miniature_databases
+ └─ mmcif_files
+ │ bfd-first_non_consensus_sequences.fasta
+ │ mgy_clusters_2022_05.fa
+ │ pdb_seqres_2022_09_28.fasta
+ │ uniprot_all_2021_04.fa
+ │ uniref90_2022_05.fa
+ │ nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta
+ │ rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta
+ │ rnacentral_active_seq_id_90_cov_80_linclust.fasta
+```
+
+如果想要搜索完整的数据库,请从以下链接下载数据库,放置到同一文件夹中(如统一放置在`/mindscience/MindSPONGE/applications/research/AlphaFold3/public_databases`可省略`--db_dir=/PATH/TO/DB_DIR`):
+
+- [mmcif](https://storage.googleapis.com/alphafold-databases/v3.0/pdb_2022_09_28_mmcif_files.tar.zst)
+- [BFD](https://storage.googleapis.com/alphafold-databases/v3.0/bfd-first_non_consensus_sequences.fasta.zst)
+- [MGnify](https://storage.googleapis.com/alphafold-databases/v3.0/mgy_clusters_2022_05.fa.zst)
+- [PDB seqres](https://storage.googleapis.com/alphafold-databases/v3.0/pdb_seqres_2022_09_28.fasta.zst)
+- [UniProt](https://storage.googleapis.com/alphafold-databases/v3.0/uniprot_all_2021_04.fa.zst)
+- [uniref90](https://storage.googleapis.com/alphafold-databases/v3.0/uniref90_2022_05.fa.zst)
+- [NT](https://storage.googleapis.com/alphafold-databases/v3.0/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta.zst)
+- [RFam](https://storage.googleapis.com/alphafold-databases/v3.0/rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta.zst)
+- [RNACentral](https://storage.googleapis.com/alphafold-databases/v3.0/rnacentral_active_seq_id_90_cov_80_linclust.fasta.zst)
+
+请确保磁盘中有足够空间:
+
+| DataBase | Compressed Size | Uncompressed Size|
+|--------------|---------------------|------------------|
+| mmcif | 233G | 233G |
+| BFD | 9.2G | 16.9G |
+| MGnify | 64.5G | 119G |
+| PDB seqres| 25.3M | 217M |
+| UniProt | 45.3G | 101G |
+| uniref90 | 30.9G | 66.8G |
+| NT | 15.8G | 75.4G |
+| RFam | 53.9M | 217M |
+| RNACentral| 3.27G | 12.9G |
+| total | 402G | 534G |
+
+解压下载的数据文件:
+
+```bash
+cd /PATH/TO/YOUR/DATA_DIR
+tar --use-compress-program=unzstd -xf pdb_2022_09_28_mmcif_files.tar.zst
+zstd -d bfd-first_non_consensus_sequences.fasta.zst
+zstd -d mgy_clusters_2022_05.fa.zst
+zstd -d pdb_seqres_2022_09_28.fasta.zst
+zstd -d uniprot_all_2021_04.fa.zst +zstd -d uniref90_2022_05.fa.zst +zstd -d nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta.zst +zstd -d rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta.zst +zstd -d rnacentral_active_seq_id_90_cov_80_linclust.fasta.zst +``` + +如统一放置在`/mindscience/MindSPONGE/applications/research/AlphaFold3/public_databases`可在运行时省略`--db_dir=/PATH/TO/DB_DIR` + +## 快速开始 + +### 输入数据格式 + +示例输入JSON: + +```json +{ + "name": "5tgy", + "sequences": [ + { + "protein": { + "id": "A", + "sequence": "SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN" + } + } + ], + "modelSeeds": [1], + "dialect": "alphafold3", + "version": 1 +} +``` + +### 运行流程 + +使用以下命令运行模型(计算精度float32): + +```bash +source set_path.sh +python run_alphafold.py \ + --json_path=example_input.json \ + --output_dir=output \ + --run_data_pipeline=true \ + --run_inference=true \ + --db_dir=/PATH/TO/DB_DIR \ + --model_dir=/PATH/TO/MODEL_DIR\ + --buckets=256 +``` + +### 参数说明 + +- `--json_path`输入文件名称 +- `--output_dir`: 输出文件路径 +- `--run_data_pipeline`: 是否运行数据处理模块 +- `--run_inference`: 是否运行推理模块 +- `--db_dir`: 数据库存放路径, 默认 `{HOME}/public_databases` +- `--model_dir`: 模型文件路径, 默认 `{HOME}/ckpt` +- `--buckets`: 设定序列长度,如不设置会将序列长度padding到256的倍数,如传入则使用传入值作为序列长度 + +### 输入与输出文件说明 + +- **JSON格式数据输入**: 包含蛋白质核酸等的序列信息。当前支持输入种类与DeepMind版本相同,支持蛋白质,DNA,RNA及Ligand作为输入,当前推理版本为单卡版本支持序列长度不超过1000 + +- **输出文件**: 5个标准的蛋白质结构文件,及置信度信息 + +```txt +└─name_in_your_json + └─ seed-1_sample-0 # 第一个生成样本 + │ confidence.json # 第一个样本的详细置信度文件 + │ model.cif # 第一个样本的结构文件 + │ summary_confidence.json # 第一个样本的总体置信度文件 + └─ seed-1_sample-1 # 第二个生成样本 + └─ seed-1_sample-2 # 第三个生成样本 + └─ seed-1_sample-3 # 第四个生成样本 + └─ seed-1_sample-4 # 第五个生成样本 + │ {name}_confidences.json # 最优样本的详细置信度文件 + │ {name}_data.json # 数据处理后的数据文件 + │ {name}_model.cif # 最优样本的结构文件 + │ {name}_summary_confidence.json # 最优样本的总体置信度文件 + │ ranking_scores.csv # 五个样本的ranking score;ranking score越高,表明置信度越高 +``` + +### 推理完成 + +当看到如下日志,表明推理正常结束: + +```txt +=======write output to /PATH/TO/OUTPUT/DIR/name_of_your_input========== +Done processing fold input name_of_your_input. +Done processing 1 fold inputs. +``` + +## 许可证 + +详情请参阅 [LICENSE](LICENSE) 文件。 + +## 致谢 + +- `data`,`structure`,`common`,`constant`等模块使用了[DeepMind](https://deepmind.com/)实现 +- `model`,`utils`等模块基于[MindSpore](https://www.mindspore.cn/)实现 + +## 联系我们 + +如果您在使用过程中遇到任何问题或有任何建议,请通过以下方式与我们联系: + +- **Gitee仓库**:[AlphaFold3](https://gitee.com/mindspore/mindscience/tree/main/MindSPONGE/applications/research/AlphaFold3) +- **问题跟踪**:[问题单跟踪](https://gitee.com/mindspore/mindscience/issues) + +## 参考文献 + +- Abramson J, Adler J, Dunger J, et al. Accurate structure prediction of biomolecular interactions with AlphaFold 3[J]. Nature, 2024, 630(8016): 493-500. \ No newline at end of file diff --git a/MindSPONGE/applications/research/AlphaFold3/README_EN.md b/MindSPONGE/applications/research/AlphaFold3/README_EN.md new file mode 100644 index 0000000000000000000000000000000000000000..a2ab03ff8f0b247a8800b0c907fcf1c46f917b57 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/README_EN.md @@ -0,0 +1,258 @@ +# AlphaFold3-MindSpore + +[**MindSpore Implementation of AlphaFold3**] A MindSpore-based deep learning framework implementation of AlphaFold3 inference network architecture. 
+
+> 📖 **Language**: [中文](README.md) | [English](README_EN.md)
+
+## 📑 Table of Contents
+
+- [Project Overview](#project-overview)
+- [Installation](#installation)
+- [Quick Start](#quick-start)
+- [License](#license)
+- [Acknowledgments](#acknowledgments)
+- [Reference](#reference)
+
+## Project Overview
+
+**Project Background**:
+AlphaFold3 is a revolutionary biomolecular structure prediction model released by DeepMind in 2024, capable of predicting the three-dimensional structures of proteins, DNA, RNA, and other biological macromolecules. This project implements AlphaFold3's inference functionality based on Ascend NPU and the MindSpore framework.
+
+The model architecture is shown below:
+
+![AlphaFold3 Model Structure](image/af3_structure.jpg)
+
+- **Inference Pipeline**: The workflow begins with the sequence information of proteins, DNA, RNA, and ligands. This data undergoes preprocessing steps, including template search and multiple sequence alignment (MSA), before being fed into the model. Next, an embedding module encodes the input information. Subsequently, the Pairformer module analyzes the relationships between the sequences and their structures. Following this, a diffusion module generates the 3D structures. Finally, a confidence module assigns a confidence score to the predictions, providing a measure of their reliability.
+- **Biomolecular Structure Prediction**: A biomolecular structure prediction model based on the AlphaFold3 algorithm, supporting various input forms including proteins, DNA, RNA, and small molecules; enabling multi-chain inputs and predicting interactions and relative positions.
+- **MindSpore Support**: Model inference adaptation based on MindSpore.
+
+### Hardware Requirements
+
+- Atlas 800T A2
+
+### Software Requirements
+
+- Python >= 3.11
+- MindSpore >= 2.5.0
+- CANN = 8.0.0
+- cmake >= 3.28.1
+
+## Installation
+
+### 1. Clone Repository
+
+```bash
+git clone https://gitee.com/mindspore/mindscience
+cd mindscience/MindSPONGE/applications/research/AlphaFold3
+```
+
+### 2. Install Dependencies
+
+```bash
+pip install -r requirements.txt
+# `{PATH}` is the current path
+export PYTHONPATH=$PYTHONPATH:{PATH}/mindscience/MindSPONGE/src
+export PYTHONPATH=$PYTHONPATH:{PATH}/mindscience/MindChemistry
+```
+
+### 3. Installing the Software Package
+
+Download the [hmmer](http://eddylab.org/software/hmmer/) installation package, such as `hmmer-3.4.tar.gz`, place it in the current directory, and then run the following commands:
+
+```bash
+mkdir /path/to/hmmer_build /path/to/hmmer && \
+mv ./hmmer-3.4.tar.gz /path/to/hmmer_build && \
+cd /path/to/hmmer_build && tar -zxf hmmer-3.4.tar.gz && rm hmmer-3.4.tar.gz && \
+cd /path/to/hmmer_build/hmmer-3.4 && ./configure --prefix=/path/to/hmmer && \
+make -j8 && make install && \
+cd /path/to/hmmer_build/hmmer-3.4/easel && make install && \
+rm -rf /path/to/hmmer_build
+export PATH=/path/to/hmmer/bin:$PATH
+which jackhmmer
+```
+
+If the path `/path/to/hmmer/bin/jackhmmer` is printed, the installation is successful.
+
+### 4. Compile
+
+```bash
+cd {PATH}/mindscience/MindSPONGE/applications/research/AlphaFold3
+mkdir build
+cd build
+cmake ..
+make
+cp ./cpp.cpython-311-aarch64-linux-gnu.so ../src/alphafold3
+cd ..
+```
+
+Then generate the data files:
+
+```bash
+python ./src/alphafold3/build_data.py
+```
+
+If you see an error that `components.cif` could not be found, download the file from [wwpdb](https://files.wwpdb.org/pub/pdb/data/monomers/components.cif) and place it in your conda environment under `{CONDA_ENV_DIR}/lib/python3.11/site-packages/share/libcifpp`.
If the `share/libcifpp` directory does not exist, create it manually.
+
+### 5. Download DataBase
+
+You can download the small test databases from DeepMind ([miniature_databases](https://github.com/google-deepmind/alphafold3/tree/main/src/alphafold3/test_data/miniature_databases)); they are for testing only and will affect the inference results!
+Download the files, put them all in the same directory (there is no need to set `--db_dir=/PATH/TO/DB_DIR` if the databases are placed in `/mindscience/MindSPONGE/applications/research/AlphaFold3/public_databases`), and rename them as shown below:
+
+```txt
+miniature_databases
+ └─ mmcif_files
+ │ bfd-first_non_consensus_sequences.fasta
+ │ mgy_clusters_2022_05.fa
+ │ pdb_seqres_2022_09_28.fasta
+ │ uniprot_all_2021_04.fa
+ │ uniref90_2022_05.fa
+ │ nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta
+ │ rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta
+ │ rnacentral_active_seq_id_90_cov_80_linclust.fasta
+```
+
+If you want to search the full databases, download them from the following links and put them in the same directory (no need to set `--db_dir=/PATH/TO/DB_DIR` if all the databases are placed in `/mindscience/MindSPONGE/applications/research/AlphaFold3/public_databases`):
+
+- [mmcif](https://storage.googleapis.com/alphafold-databases/v3.0/pdb_2022_09_28_mmcif_files.tar.zst)
+- [BFD small](https://storage.googleapis.com/alphafold-databases/v3.0/bfd-first_non_consensus_sequences.fasta.zst)
+- [MGnify](https://storage.googleapis.com/alphafold-databases/v3.0/mgy_clusters_2022_05.fa.zst)
+- [PDB seqres](https://storage.googleapis.com/alphafold-databases/v3.0/pdb_seqres_2022_09_28.fasta.zst)
+- [UniProt](https://storage.googleapis.com/alphafold-databases/v3.0/uniprot_all_2021_04.fa.zst)
+- [uniref90](https://storage.googleapis.com/alphafold-databases/v3.0/uniref90_2022_05.fa.zst)
+- [NT](https://storage.googleapis.com/alphafold-databases/v3.0/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta.zst)
+- [RFam](https://storage.googleapis.com/alphafold-databases/v3.0/rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta.zst)
+- [RNACentral](https://storage.googleapis.com/alphafold-databases/v3.0/rnacentral_active_seq_id_90_cov_80_linclust.fasta.zst)
+
+Make sure there is enough space on disk:
+
+| DataBase | Compressed Size | Uncompressed Size|
+|--------------|---------------------|------------------|
+| mmcif | 233G | 233G |
+| BFD | 9.2G | 16.9G |
+| MGnify | 64.5G | 119G |
+| PDB seqres| 25.3M | 217M |
+| UniProt | 45.3G | 101G |
+| uniref90 | 30.9G | 66.8G |
+| NT | 15.8G | 75.4G |
+| RFam | 53.9M | 217M |
+| RNACentral| 3.27G | 12.9G |
+| total | 402G | 534G |
+
+Uncompress the downloaded database files:
+
+```bash
+cd /PATH/TO/YOUR/DATA_DIR
+tar --use-compress-program=unzstd -xf pdb_2022_09_28_mmcif_files.tar.zst
+zstd -d bfd-first_non_consensus_sequences.fasta.zst
+zstd -d mgy_clusters_2022_05.fa.zst
+zstd -d pdb_seqres_2022_09_28.fasta.zst
+zstd -d uniprot_all_2021_04.fa.zst
+zstd -d uniref90_2022_05.fa.zst
+zstd -d nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta.zst
+zstd -d rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta.zst
+zstd -d rnacentral_active_seq_id_90_cov_80_linclust.fasta.zst
+```
+
+If all the files are placed under `/mindscience/MindSPONGE/applications/research/AlphaFold3/public_databases`, the `--db_dir=/PATH/TO/DB_DIR` flag can be omitted.
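+
+The file names above must match the default `--*_database_path` flag values in `run_alphafold.py` unless you override those flags. As a quick sanity check before launching a run, a small script such as the sketch below can verify that the expected files are present. This is a minimal illustrative sketch, not part of the shipped code: the script name, the argument handling, and the fallback to the in-repo `public_databases` directory are assumptions; the file list is taken from this section.
+
+```python
+# check_databases.py -- minimal sketch: verify that the expected database files
+# exist under the directory that will later be passed to run_alphafold.py via --db_dir.
+import pathlib
+import sys
+
+EXPECTED = [
+    "mmcif_files",  # directory extracted from pdb_2022_09_28_mmcif_files.tar.zst
+    "bfd-first_non_consensus_sequences.fasta",
+    "mgy_clusters_2022_05.fa",
+    "pdb_seqres_2022_09_28.fasta",
+    "uniprot_all_2021_04.fa",
+    "uniref90_2022_05.fa",
+    "nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta",
+    "rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta",
+    "rnacentral_active_seq_id_90_cov_80_linclust.fasta",
+]
+
+# Default to the in-repo public_databases directory if no argument is given.
+db_dir = pathlib.Path(sys.argv[1] if len(sys.argv) > 1 else "public_databases")
+missing = [name for name in EXPECTED if not (db_dir / name).exists()]
+if missing:
+    print(f"Missing under {db_dir}: {', '.join(missing)}")
+else:
+    print(f"All expected database files found under {db_dir}.")
+```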
+
+## Quick Start
+
+### Input Structure
+
+Example Input JSON:
+
+```json
+{
+  "name": "5tgy",
+  "sequences": [
+    {
+      "protein": {
+        "id": "A",
+        "sequence": "SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN"
+      }
+    }
+  ],
+  "modelSeeds": [1],
+  "dialect": "alphafold3",
+  "version": 1
+}
+```
+
+### Running Pipeline
+
+AlphaFold3 can be run with the following command (precision: float32):
+
+```bash
+source set_path.sh
+python run_alphafold.py \
+    --json_path=example_input.json \
+    --output_dir=output \
+    --run_data_pipeline=true \
+    --run_inference=true \
+    --db_dir=/PATH/TO/DB_DIR \
+    --model_dir=/PATH/TO/MODEL_DIR \
+    --buckets=256
+```
+
+### Parameter Introduction
+
+- `--json_path`: Path to the input JSON file
+- `--output_dir`: Output directory
+- `--run_data_pipeline`: Whether to run the data pipeline
+- `--run_inference`: Whether to run inference
+- `--db_dir`: Path to the databases, default `{HOME}/public_databases`
+- `--model_dir`: Path to the model checkpoint, default `{HOME}/ckpt`
+- `--buckets`: Sequence length buckets; if not set, the sequence length is padded to a multiple of 256; if provided, the given values are used as the bucket sizes
+
+### Input & Output
+
+- **JSON Input**: Contains the sequence information of proteins and other molecules. The supported input types are the same as in the DeepMind version: protein, DNA, RNA, ligand, etc. The current release is a single-NPU version, and the maximum sequence length should be smaller than 1000.
+
+- **CIF Output**: Five standard structure files plus confidence information.
+
+```txt
+└─name_in_your_json
+ └─ seed-{random_seed}_sample-0       # First sample
+ │ confidence.json                    # Confidence of the first sample
+ │ model.cif                          # Predicted structure of the first sample
+ │ summary_confidence.json            # Summary confidence of the first sample
+ └─ seed-{random_seed}_sample-1       # Second sample
+ └─ seed-{random_seed}_sample-2       # Third sample
+ └─ seed-{random_seed}_sample-3       # Fourth sample
+ └─ seed-{random_seed}_sample-4       # Fifth sample
+ │ {name}_confidences.json            # Confidence of the best sample
+ │ {name}_data.json                   # Data JSON file after data processing
+ │ {name}_model.cif                   # Predicted structure of the best sample
+ │ {name}_summary_confidence.json     # Summary confidence of the best sample
+ │ ranking_scores.csv                 # Ranking scores of all five samples; the higher the ranking score, the higher the confidence of the sample
+```
+
+### End of Inference
+
+When you see the following log, the inference has finished correctly:
+
+```text
+=======write output to /PATH/TO/OUTPUT/DIR/name_of_your_input==========
+Done processing fold input name_of_your_input.
+Done processing 1 fold inputs.
+```
+
+## License
+
+See the [LICENSE](LICENSE) file for details.
+
+## Acknowledgments
+
+- The `data`, `structure`, `common`, and `constant` modules follow the [DeepMind](https://github.com/google-deepmind/alphafold3) implementation.
+- The `model` and `utils` modules are implemented based on [MindScience](https://gitee.com/mindspore/mindscience/).
+
+## Contact Us
+
+If you encounter any issues or have any suggestions during use, please contact us through the following methods:
+
+- **Gitee Repository**: [AlphaFold3](https://gitee.com/mindspore/mindscience/tree/main/MindSPONGE/applications/research/AlphaFold3)
+- **Issue Tracking**: [Issue Tracking](https://gitee.com/mindspore/mindscience/issues)
+
+## Reference
+
+- Abramson J, Adler J, Dunger J, et al. Accurate structure prediction of biomolecular interactions with AlphaFold 3[J]. Nature, 2024, 630(8016): 493-500.
diff --git a/MindSPONGE/applications/research/AlphaFold3/example_input.json b/MindSPONGE/applications/research/AlphaFold3/example_input.json new file mode 100644 index 0000000000000000000000000000000000000000..2e31369ccb0f17196d0fcd64ecbd7397311e0bcd --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/example_input.json @@ -0,0 +1,14 @@ +{ + "name": "5tgy", + "sequences": [ + { + "protein": { + "id": "A", + "sequence": "SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN" + } + } + ], + "modelSeeds": [1], + "dialect": "alphafold3", + "version": 1 + } \ No newline at end of file diff --git a/MindSPONGE/applications/research/AlphaFold3/image/af3_structure.jpg b/MindSPONGE/applications/research/AlphaFold3/image/af3_structure.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a8e383f913b2ed2b3cec8f11c30203e65d8b197b Binary files /dev/null and b/MindSPONGE/applications/research/AlphaFold3/image/af3_structure.jpg differ diff --git a/MindSPONGE/applications/research/AlphaFold3/requirements.txt b/MindSPONGE/applications/research/AlphaFold3/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1c230c66552966e3ca3002e405f99b2cb586c8b5 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/requirements.txt @@ -0,0 +1,6 @@ +mindSpore==2.5.0 +absl-py==2.1.0 +numpy==1.26.0 +rdkit==2024.3.5 +scipy==1.14.1 +tqdm==4.67.0 \ No newline at end of file diff --git a/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py b/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py new file mode 100644 index 0000000000000000000000000000000000000000..a9f2fe9bf53361d4d464662f51839fbd0114a1dd --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py @@ -0,0 +1,687 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +from collections.abc import Callable, Iterable, Sequence +import csv +import dataclasses +import datetime +import functools +import multiprocessing +import os +import pathlib +import shutil +import string +import textwrap +import time +import typing +from typing import Protocol, Self, TypeVar, overload + +from absl import app +from absl import flags +from alphafold3.common import base_config +from alphafold3.common import folding_input +from alphafold3.common import resources +from alphafold3.constants import chemical_components +import alphafold3.cpp +from alphafold3.data import featurisation +from alphafold3.data import pipeline +from alphafold3.utils.attention import attention +from alphafold3.model import features +from alphafold3.model.diffusion.load_ckpt import load_diffuser +# from alphafold3.model import params +from alphafold3.model import post_processing +from alphafold3.model.components import base_model +from alphafold3.model.components import utils +from alphafold3.model.diffusion import model as diffusion_model +from alphafold3.model.feat_batch import Batch +import mindspore as ms +import numpy as np + + +_HOME_DIR = pathlib.Path(os.environ.get('HOME')) +_DEFAULT_MODEL_DIR = _HOME_DIR / 'ckpt' +_DEFAULT_DB_DIR = _HOME_DIR / 'public_databases' + + +# Input and output paths. +_JSON_PATH = flags.DEFINE_string( + 'json_path', + None, + 'Path to the input JSON file.', +) +_INPUT_DIR = flags.DEFINE_string( + 'input_dir', + None, + 'Path to the directory containing input JSON files.', +) +_OUTPUT_DIR = flags.DEFINE_string( + 'output_dir', + None, + 'Path to a directory where the results will be saved.', +) +MODEL_DIR = flags.DEFINE_string( + 'model_dir', + _DEFAULT_MODEL_DIR.as_posix(), + 'Path to the model to use for inference.', +) + +# Control which stages to run. +_RUN_DATA_PIPELINE = flags.DEFINE_bool( + 'run_data_pipeline', + True, + 'Whether to run the data pipeline on the fold inputs.', +) +_RUN_INFERENCE = flags.DEFINE_bool( + 'run_inference', + True, + 'Whether to run inference on the fold inputs.', +) + +# Binary paths. +_JACKHMMER_BINARY_PATH = flags.DEFINE_string( + 'jackhmmer_binary_path', + shutil.which('jackhmmer'), + 'Path to the Jackhmmer binary.', +) +_NHMMER_BINARY_PATH = flags.DEFINE_string( + 'nhmmer_binary_path', + shutil.which('nhmmer'), + 'Path to the Nhmmer binary.', +) +_HMMALIGN_BINARY_PATH = flags.DEFINE_string( + 'hmmalign_binary_path', + shutil.which('hmmalign'), + 'Path to the Hmmalign binary.', +) +_HMMSEARCH_BINARY_PATH = flags.DEFINE_string( + 'hmmsearch_binary_path', + shutil.which('hmmsearch'), + 'Path to the Hmmsearch binary.', +) +_HMMBUILD_BINARY_PATH = flags.DEFINE_string( + 'hmmbuild_binary_path', + shutil.which('hmmbuild'), + 'Path to the Hmmbuild binary.', +) + +# Database paths. +DB_DIR = flags.DEFINE_multi_string( + 'db_dir', + (_DEFAULT_DB_DIR.as_posix(),), + 'Path to the directory containing the databases. 
Can be specified multiple' + ' times to search multiple directories in order.', +) + +_SMALL_BFD_DATABASE_PATH = flags.DEFINE_string( + 'small_bfd_database_path', + '${DB_DIR}/bfd-first_non_consensus_sequences.fasta', + 'Small BFD database path, used for protein MSA search.', +) +_MGNIFY_DATABASE_PATH = flags.DEFINE_string( + 'mgnify_database_path', + '${DB_DIR}/mgy_clusters_2022_05.fa', + 'Mgnify database path, used for protein MSA search.', +) +_UNIPROT_CLUSTER_ANNOT_DATABASE_PATH = flags.DEFINE_string( + 'uniprot_cluster_annot_database_path', + '${DB_DIR}/uniprot_all_2021_04.fa', + 'UniProt database path, used for protein paired MSA search.', +) +_UNIREF90_DATABASE_PATH = flags.DEFINE_string( + 'uniref90_database_path', + '${DB_DIR}/uniref90_2022_05.fa', + 'UniRef90 database path, used for MSA search. The MSA obtained by ' + 'searching it is used to construct the profile for template search.', +) +_NTRNA_DATABASE_PATH = flags.DEFINE_string( + 'ntrna_database_path', + '${DB_DIR}/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta', + 'NT-RNA database path, used for RNA MSA search.', +) +_RFAM_DATABASE_PATH = flags.DEFINE_string( + 'rfam_database_path', + '${DB_DIR}/rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta', + 'Rfam database path, used for RNA MSA search.', +) +_RNA_CENTRAL_DATABASE_PATH = flags.DEFINE_string( + 'rna_central_database_path', + '${DB_DIR}/rnacentral_active_seq_id_90_cov_80_linclust.fasta', + 'RNAcentral database path, used for RNA MSA search.', +) +_PDB_DATABASE_PATH = flags.DEFINE_string( + 'pdb_database_path', + '${DB_DIR}/mmcif_files', + 'PDB database directory with mmCIF files path, used for template search.', +) +_SEQRES_DATABASE_PATH = flags.DEFINE_string( + 'seqres_database_path', + '${DB_DIR}/pdb_seqres_2022_09_28.fasta', + 'PDB sequence database path, used for template search.', +) + +# Number of CPUs to use for MSA tools. +_JACKHMMER_N_CPU = flags.DEFINE_integer( + 'jackhmmer_n_cpu', + min(multiprocessing.cpu_count(), 8), + 'Number of CPUs to use for Jackhmmer. Default to min(cpu_count, 8). Going' + ' beyond 8 CPUs provides very little additional speedup.', +) +_NHMMER_N_CPU = flags.DEFINE_integer( + 'nhmmer_n_cpu', + min(multiprocessing.cpu_count(), 8), + 'Number of CPUs to use for Nhmmer. Default to min(cpu_count, 8). Going' + ' beyond 8 CPUs provides very little additional speedup.', +) + +# Template search configuration. +_MAX_TEMPLATE_DATE = flags.DEFINE_string( + 'max_template_date', + '2021-09-30', # By default, use the date from the AlphaFold 3 paper. + 'Maximum template release date to consider. Format: YYYY-MM-DD. All ' + 'templates released after this date will be ignored.', +) + + +_BUCKETS = flags.DEFINE_list( + 'buckets', + # pyformat: disable + ['256', '512', '768', '1024', '1280', '1536', '2048', '2560', '3072', + '3584', '4096', '4608', '5120'], + # pyformat: enable + 'Strictly increasing order of token sizes for which to cache compilations.' + ' For any input with more tokens than the largest bucket size, a new bucket' + ' is created for exactly that number of tokens.', +) +_FLASH_ATTENTION_IMPLEMENTATION = flags.DEFINE_enum( + 'flash_attention_implementation', + default='ms', + enum_values=['ms'], + help=( + "Flash attention implementation to use. 'triton' and 'cudnn' uses a" + ' Triton and cuDNN flash attention implementation, respectively. The' + ' Triton kernel is fastest and has been tested more thoroughly. The' + " Triton and cuDNN kernels require Ampere GPUs or later. 
'xla' uses an" + ' XLA attention implementation (no flash attention) and is portable' + ' across GPU devices.' + ), +) + + +class ConfigurableModel(Protocol): + """A model with a nested config class.""" + + class Config(base_config.BaseConfig): + ... + + def __call__(self, config: Config) -> Self: + ... + + @classmethod + def get_inference_result( + cls: Self, + batch: features.BatchDict, + result: base_model.ModelResult, + target_name: str = '', + ) -> Iterable[base_model.InferenceResult]: + ... + + +ModelT = TypeVar('ModelT', bound=ConfigurableModel) + + +def make_model_config(): + print('not implemented make_model_config') + return 'ab' + + +def make_model_config( + *, + model_class: type[ModelT] = diffusion_model.Diffuser, + flash_attention_implementation: attention.Implementation = 'ms', +): + config = model_class.Config() + if hasattr(config, '_configglobal'): + config.global_config.flash_attention_implementation = ( + flash_attention_implementation + ) + return config + + +class ModelRunner: + """Helper class to run structure prediction stages.""" + + def __init__( + self, + model_class: ConfigurableModel, + config: base_config.BaseConfig, + model_dir: pathlib.Path, + ): + self._model_class = model_class + self._model_config = config + self._model_dir = model_dir + + @functools.cached_property + def model_params(self): + """Loads model parameters from the model directory.""" + # Load parameters from checkpoint file + # param_dict = ms.load_checkpoint(self._model_dir / "test.ckpt") + # return param_dict + + @functools.cached_property + def _model( + self + ) -> Callable[[np.ndarray, features.BatchDict], base_model.ModelResult]: + """Loads model parameters and returns a model forward pass.""" + assert isinstance(self._model_config, self._model_class.Config) + + def forward_fn(batch): + num_residues = batch.token_features.residue_index.shape[0] + model = self._model_class(self._model_config, 447, (256, 447), (num_residues, 256, 128), (256, 256, 128), + (256, 384), (256, 24, 3), 128, 4, dtype=ms.float32) + load_diffuser(model, self._model_dir, dtype=ms.float32) + res = model(batch, 42) + return res + + return forward_fn + + def run_inference( + self, featurised_example: features.BatchDict + ) -> base_model.ModelResult: + """Computes a forward pass of the model on a featurised example.""" + featurised_example = Batch.from_data_dict(featurised_example) + featurised_example.convert_to_tensor(ms.float32) + + result = self._model(featurised_example) + + # Convert identifier to bytes + if '__identifier__' in result: + result['__identifier__'] = result['__identifier__'].tobytes() + return result + + def extract_structures( + self, + batch: features.BatchDict, + result: base_model.ModelResult, + target_name: str, + ) -> list[base_model.InferenceResult]: + """Generates structures from model outputs.""" + batch = Batch.from_data_dict(batch) + batch.convert_to_tensor(ms.float32) + return list( + self._model_class.get_inference_result( + batch=batch, result=result, target_name=target_name + ) + ) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class ResultsForSeed: + """Stores the inference results (diffusion samples) for a single seed. + + Attributes: + seed: The seed used to generate the samples. + inference_results: The inference results, one per sample. + full_fold_input: The fold input that must also include the results of + running the data pipeline - MSA and templates. 
+ """ + + seed: int + inference_results: Sequence[base_model.InferenceResult] + full_fold_input: folding_input.Input + + +def predict_structure( + fold_input: folding_input.Input, + model_runner: ModelRunner, + buckets: Sequence[int] | None = None, +) -> Sequence[ResultsForSeed]: + """Runs the full inference pipeline to predict structures for each seed.""" + + print(f'Featurising data for seeds {fold_input.rng_seeds}...') + featurisation_start_time = time.time() + ccd = chemical_components.cached_ccd(user_ccd=fold_input.user_ccd) + featurised_examples = featurisation.featurise_input( + fold_input=fold_input, buckets=buckets, ccd=ccd, verbose=True + ) + print( + f'Featurising data for seeds {fold_input.rng_seeds} took ' + f' {time.time() - featurisation_start_time:.2f} seconds.' + ) + all_inference_start_time = time.time() + all_inference_results = [] + for seed, example in zip(fold_input.rng_seeds, featurised_examples): + print(f'Running model inference for seed {seed}...') + inference_start_time = time.time() + result = model_runner.run_inference(example) + print( + f'Running model inference for seed {seed} took ' + f' {time.time() - inference_start_time:.2f} seconds.' + ) + print( + f'Extracting output structures (one per sample) for seed {seed}...') + extract_structures = time.time() + inference_results = model_runner.extract_structures( + batch=example, result=result, target_name=fold_input.name + ) + print( + f'Extracting output structures (one per sample) for seed {seed} took ' + f' {time.time() - extract_structures:.2f} seconds.' + ) + all_inference_results.append( + ResultsForSeed( + seed=seed, + inference_results=inference_results, + full_fold_input=fold_input, + ) + ) + print( + 'Running model inference and extracting output structures for seed' + f' {seed} took {time.time() - inference_start_time:.2f} seconds.' + ) + print( + 'Running model inference and extracting output structures for seeds' + f' {fold_input.rng_seeds} took ' + f' {time.time() - all_inference_start_time:.2f} seconds.' + ) + return all_inference_results + + +def write_fold_input_json( + fold_input: folding_input.Input, + output_dir: os.PathLike[str] | str, +) -> None: + """Writes the input JSON to the output directory.""" + os.makedirs(output_dir, exist_ok=True) + with open(os.path.join(output_dir, f'{fold_input.sanitised_name()}_data.json'), 'wt') as f: + f.write(fold_input.to_json()) + + +def write_outputs( + all_inference_results: Sequence[ResultsForSeed], + output_dir: os.PathLike[str] | str, + job_name: str, +) -> None: + """Writes outputs to the specified output directory.""" + ranking_scores = [] + max_ranking_score = None + max_ranking_result = None + + os.makedirs(output_dir, exist_ok=True) + for results_for_seed in all_inference_results: + seed = results_for_seed.seed + for sample_idx, result in enumerate(results_for_seed.inference_results): + sample_dir = os.path.join( + output_dir, f'seed-{seed}_sample-{sample_idx}') + os.makedirs(sample_dir, exist_ok=True) + post_processing.write_output( + inference_result=result, output_dir=sample_dir + ) + ranking_score = float(result.metadata['ranking_score']) + ranking_scores.append((seed, sample_idx, ranking_score)) + if max_ranking_score is None or ranking_score > max_ranking_score: + max_ranking_score = ranking_score + max_ranking_result = result + + if max_ranking_result is not None: # True iff ranking_scores non-empty. 
+ post_processing.write_output( + inference_result=max_ranking_result, + output_dir=output_dir, + # The output terms of use are the same for all seeds/samples. + # terms_of_use=output_terms, + terms_of_use=None, + name=job_name, + ) + # Save csv of ranking scores with seeds and sample indices, to allow easier + # comparison of ranking scores across different runs. + with open(os.path.join(output_dir, 'ranking_scores.csv'), 'wt') as f: + writer = csv.writer(f) + writer.writerow(['seed', 'sample', 'ranking_score']) + writer.writerows(ranking_scores) + + +@overload +def process_fold_input( + fold_input: folding_input.Input, + data_pipeline_config: pipeline.DataPipelineConfig | None, + model_runner: None, + output_dir: os.PathLike[str] | str, + buckets: Sequence[int] | None = None, +) -> folding_input.Input: + ... + + +@overload +def process_fold_input( + fold_input: folding_input.Input, + data_pipeline_config: pipeline.DataPipelineConfig | None, + model_runner: ModelRunner, + output_dir: os.PathLike[str] | str, + buckets: Sequence[int] | None = None, +) -> Sequence[ResultsForSeed]: + ... + + +def replace_db_dir(path_with_db_dir: str, db_dirs: Sequence[str]) -> str: + """Replaces the DB_DIR placeholder in a path with the given DB_DIR.""" + template = string.Template(path_with_db_dir) + if 'DB_DIR' in template.get_identifiers(): + for db_dir in db_dirs: + path = template.substitute(DB_DIR=db_dir) + if os.path.exists(path): + return path + raise FileNotFoundError( + f'{path_with_db_dir} with ${{DB_DIR}} not found in any of {db_dirs}.' + ) + if not os.path.exists(path_with_db_dir): + raise FileNotFoundError(f'{path_with_db_dir} does not exist.') + return path_with_db_dir + + +def process_fold_input( + fold_input: folding_input.Input, + data_pipeline_config: pipeline.DataPipelineConfig | None, + model_runner: ModelRunner | None, + output_dir: os.PathLike[str] | str, + buckets: Sequence[int] | None = None, +) -> folding_input.Input | Sequence[ResultsForSeed]: + """Runs data pipeline and/or inference on a single fold input. + + Args: + fold_input: Fold input to process. + data_pipeline_config: Data pipeline config to use. If None, skip the data + pipeline. + model_runner: Model runner to use. If None, skip inference. + output_dir: Output directory to write to. + buckets: Bucket sizes to pad the data to, to avoid excessive re-compilation + of the model. If None, calculate the appropriate bucket size from the + number of tokens. If not None, must be a sequence of at least one integer, + in strictly increasing order. Will raise an error if the number of tokens + is more than the largest bucket size. + + Returns: + The processed fold input, or the inference results for each seed. + + Raises: + ValueError: If the fold input has no chains. + """ + print(f'Processing fold input {fold_input.name}') + + if not fold_input.chains: + raise ValueError('Fold input has no chains.') + + if os.path.exists(output_dir) and os.listdir(output_dir): + new_output_dir = ( + f'{output_dir}_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}' + ) + print( + f'Output directory {output_dir} exists and non-empty, using instead ' + f' {new_output_dir}.' + ) + output_dir = new_output_dir + + if model_runner is not None: + # If we're running inference, check we can load the model parameters before + # (possibly) launching the data pipeline. 
+ print('Checking we can load the model parameters...') + _ = model_runner.model_params + + if data_pipeline_config is None: + print('Skipping data pipeline...') + else: + print('Running data pipeline...') + fold_input = pipeline.DataPipeline( + data_pipeline_config).process(fold_input) + + print(f'Output directory: {output_dir}') + print(f'Writing model input JSON to {output_dir}') + write_fold_input_json(fold_input, output_dir) + if model_runner is None: + print('Skipping inference...') + output = fold_input + else: + print( + f'Predicting 3D structure for {fold_input.name} for seed(s)' + f' {fold_input.rng_seeds}...' + ) + all_inference_results = predict_structure( + fold_input=fold_input, + model_runner=model_runner, + buckets=buckets, + ) + print( + f'Writing outputs for {fold_input.name} for seed(s)' + f' {fold_input.rng_seeds}...' + ) + write_outputs( + all_inference_results=all_inference_results, + output_dir=output_dir, + job_name=fold_input.sanitised_name(), + ) + output = all_inference_results + + print(f'Done processing fold input {fold_input.name}.') + return output + + +def main(_): + + if _JSON_PATH.value is None == _INPUT_DIR.value is None: + raise ValueError( + 'Exactly one of --json_path or --input_dir must be specified.' + ) + + if not _RUN_INFERENCE.value and not _RUN_DATA_PIPELINE.value: + raise ValueError( + 'At least one of --run_inference or --run_data_pipeline must be' + ' set to true.' + ) + + if _INPUT_DIR.value is not None: + fold_inputs = folding_input.load_fold_inputs_from_dir( + pathlib.Path(_INPUT_DIR.value) + ) + elif _JSON_PATH.value is not None: + fold_inputs = folding_input.load_fold_inputs_from_path( + pathlib.Path(_JSON_PATH.value) + ) + else: + raise AssertionError( + 'Exactly one of --json_path or --input_dir must be specified.' + ) + + # Make sure we can create the output directory before running anything. + try: + os.makedirs(_OUTPUT_DIR.value, exist_ok=True) + except OSError as e: + print(f'Failed to create output directory {_OUTPUT_DIR.value}: {e}') + raise + + notice = textwrap.wrap( + 'Running AlphaFold 3. Please note that standard AlphaFold 3 model' + ' parameters are only available under terms of use provided at' + ' https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.' 
+ ' If you do not agree to these terms and are using AlphaFold 3 derived' + ' model parameters, cancel execution of AlphaFold 3 inference with' + ' CTRL-C, and do not use the model parameters.', + break_long_words=False, + break_on_hyphens=False, + width=80, + ) + print('\n'.join(notice)) + + if _RUN_DATA_PIPELINE.value: + def expand_path(x): + return replace_db_dir(x, DB_DIR.value) + max_template_date = datetime.date.fromisoformat( + _MAX_TEMPLATE_DATE.value) + data_pipeline_config = pipeline.DataPipelineConfig( + jackhmmer_binary_path=_JACKHMMER_BINARY_PATH.value, + nhmmer_binary_path=_NHMMER_BINARY_PATH.value, + hmmalign_binary_path=_HMMALIGN_BINARY_PATH.value, + hmmsearch_binary_path=_HMMSEARCH_BINARY_PATH.value, + hmmbuild_binary_path=_HMMBUILD_BINARY_PATH.value, + small_bfd_database_path=expand_path( + _SMALL_BFD_DATABASE_PATH.value), + mgnify_database_path=expand_path(_MGNIFY_DATABASE_PATH.value), + uniprot_cluster_annot_database_path=expand_path( + _UNIPROT_CLUSTER_ANNOT_DATABASE_PATH.value + ), + uniref90_database_path=expand_path(_UNIREF90_DATABASE_PATH.value), + ntrna_database_path=expand_path(_NTRNA_DATABASE_PATH.value), + rfam_database_path=expand_path(_RFAM_DATABASE_PATH.value), + rna_central_database_path=expand_path( + _RNA_CENTRAL_DATABASE_PATH.value), + pdb_database_path=expand_path(_PDB_DATABASE_PATH.value), + seqres_database_path=expand_path(_SEQRES_DATABASE_PATH.value), + jackhmmer_n_cpu=_JACKHMMER_N_CPU.value, + nhmmer_n_cpu=_NHMMER_N_CPU.value, + max_template_date=max_template_date, + ) + else: + print('Skipping running the data pipeline.') + data_pipeline_config = None + + if _RUN_INFERENCE.value: + print('Building model from scratch...') + model_runner = ModelRunner( + model_class=diffusion_model.Diffuser, + config=make_model_config( + flash_attention_implementation=typing.cast( + attention.Implementation, _FLASH_ATTENTION_IMPLEMENTATION.value + ) + ), + model_dir=pathlib.Path(MODEL_DIR.value), + ) + else: + print('Skipping running model inference.') + model_runner = None + + print(f'Processing {len(fold_inputs)} fold inputs.') + for fold_input in fold_inputs: + process_fold_input( + fold_input=fold_input, + data_pipeline_config=data_pipeline_config, + model_runner=model_runner, + output_dir=os.path.join( + _OUTPUT_DIR.value, fold_input.sanitised_name()), + buckets=tuple(int(bucket) for bucket in _BUCKETS.value), + ) + + print(f'Done processing {len(fold_inputs)} fold inputs.') + + +if __name__ == '__main__': + flags.mark_flags_as_required([ + 'output_dir', + ]) + app.run(main) diff --git a/MindSPONGE/applications/research/AlphaFold3/set_path.sh b/MindSPONGE/applications/research/AlphaFold3/set_path.sh new file mode 100644 index 0000000000000000000000000000000000000000..f2efd467f939e773c0c64dc8a83a95a0c1d1e677 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/set_path.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +# Get the script directory to make paths more reliable +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +# From AlphaFold3 directory, go up to the mindscience directory +MINDSCIENCE_PATH="$(cd "$SCRIPT_DIR/../../../.." && pwd)" + +# Check if the base directory exists +if [ ! 
-d "$MINDSCIENCE_PATH" ]; then + echo "Error: MindScience path not found: $MINDSCIENCE_PATH" + echo "Please run this script from the correct directory" + exit 1 +fi + +# Function to add to PYTHONPATH if directory exists +add_to_pythonpath() { + local dir_path="$1" + if [ -d "$dir_path" ]; then + export PYTHONPATH="$PYTHONPATH:$dir_path" + echo "Added to PYTHONPATH: $dir_path" + else + echo "Warning: Directory not found, skipping: $dir_path" + fi +} + +add_to_pythonpath "$MINDSCIENCE_PATH/MindSPONGE/src" +add_to_pythonpath "$MINDSCIENCE_PATH/MindChemistry" +add_to_pythonpath "$MINDSCIENCE_PATH/MindSPONGE/applications/research/AlphaFold3/src" + +# Add directories to PATH +export PATH=$PATH:/hmmer/bin + +# Display current PYTHONPATH +echo "Current PYTHONPATH:" +echo "$PYTHONPATH" | tr ':' '\n' | sed 's/^/ /' + +echo "Environment setup completed." diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/__init__.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/build_data.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/build_data.py new file mode 100644 index 0000000000000000000000000000000000000000..58ae0c88b766cbb5ad1cfa28fe520260e112d326 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/build_data.py @@ -0,0 +1,45 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Script for building intermediate data.""" + +from importlib import resources +import pathlib +import site + +import alphafold3.constants.converters +from alphafold3.constants.converters import ccd_pickle_gen +from alphafold3.constants.converters import chemical_component_sets_gen + + +def build_data(): + """Builds intermediate data.""" + for site_path in site.getsitepackages(): + path = pathlib.Path(site_path) / 'share/libcifpp/components.cif' + if path.exists(): + cif_path = path + break + else: + raise ValueError('Could not find components.cif') + + out_root = resources.files(alphafold3.constants.converters) + ccd_pickle_path = out_root.joinpath('ccd.pickle') + chemical_component_sets_pickle_path = out_root.joinpath( + 'chemical_component_sets.pickle' + ) + ccd_pickle_gen.main(['', str(cif_path), str(ccd_pickle_path)]) + chemical_component_sets_gen.main( + ['', str(chemical_component_sets_pickle_path)] + ) + + +if __name__ == '__main__': + build_data() diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/base_config.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/base_config.py new file mode 100644 index 0000000000000000000000000000000000000000..27f6eba12e3f2ae7d04689e49a0dbc3e25883cb4 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/base_config.py @@ -0,0 +1,151 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ +"""Config for the protein folding model and experiment.""" + +from collections.abc import Mapping +import copy +import dataclasses +import types +import typing +from typing import Any, ClassVar, TypeVar + + +_T = TypeVar('_T') +_ConfigT = TypeVar('_ConfigT', bound='BaseConfig') + + +def _strip_optional(t: type[Any]) -> type[Any]: + """Transforms type annotations of the form `T | None` to `T`.""" + if typing.get_origin(t) in (typing.Union, types.UnionType): + args = set(typing.get_args(t)) - {types.NoneType} + if len(args) == 1: + return args.pop() + return t + + +_NO_UPDATE = object() + + +class _Autocreate: + + def __init__(self, **defaults: Any): + self.defaults = defaults + + +def autocreate(**defaults: Any) -> Any: + """Marks a field as having a default factory derived from its type.""" + return _Autocreate(**defaults) + + +def _clone_field( + field: dataclasses.Field[_T], new_default: _T +) -> dataclasses.Field[_T]: + if new_default is _NO_UPDATE: + return copy.copy(field) + return dataclasses.field( + default=new_default, + init=True, + kw_only=True, + repr=field.repr, + hash=field.hash, + compare=field.compare, + metadata=field.metadata, + ) + + +@typing.dataclass_transform() +class ConfigMeta(type): + """Metaclass that synthesizes a __post_init__ that coerces dicts to Config subclass instances.""" + + def __new__(mcs, name, bases, classdict): + cls = super().__new__(mcs, name, bases, classdict) + + def _coercable_fields(self) -> Mapping[str, tuple[ConfigMeta, Any]]: + type_hints = typing.get_type_hints(self.__class__) + fields = dataclasses.fields(self.__class__) + field_to_type_and_default = { + field.name: (_strip_optional( + type_hints[field.name]), field.default) + for field in fields + } + coercable_fields = { + f: t + for f, t in field_to_type_and_default.items() + if issubclass(type(t[0]), ConfigMeta) + } + return coercable_fields + + cls._coercable_fields = property(_coercable_fields) + + old_post_init = getattr(cls, '__post_init__', None) + + def _post_init(self) -> None: + # Use get_type_hints instead of Field.type to ensure that forward + # references are resolved. + for field_name, ( + field_type, + field_default, + ) in self._coercable_fields.items(): # pylint: disable=protected-access + field_value = getattr(self, field_name) + if field_value is None: + continue + try: + match field_value: + case _Autocreate(): + # Construct from field defaults. + setattr(self, field_name, field_type( + **field_value.defaults)) + case Mapping(): + # Field value is not yet a `Config` instance; Assume we can create + # one by splatting keys and values. + args = {} + # Apply default args first, if present. + if isinstance(field_default, _Autocreate): + args.update(field_default.defaults) + args.update(field_value) + setattr(self, field_name, field_type(**args)) + case _: + pass + except TypeError as e: + raise TypeError( + f'Failure while coercing field {field_name!r} of' + f' {self.__class__.__qualname__}' + ) from e + if old_post_init: + old_post_init(self) + + cls.__post_init__ = _post_init + + return dataclasses.dataclass(kw_only=True)(cls) + + +class BaseConfig(metaclass=ConfigMeta): + """Config base class. 
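+
+ Illustrative example (hypothetical `Inner`/`Outer` names): given
+ `class Inner(BaseConfig): x: int = 0` and
+ `class Outer(BaseConfig): inner: Inner = autocreate()`, constructing
+ `Outer(inner={'x': 1})` coerces the dict and yields `Outer(inner=Inner(x=1))`.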
+ + Subclassing Config automatically makes the subclass a kw_only dataclass with + a `__post_init__` that coerces Config-subclass field values from mappings to + instances of the right type. + """ + # Provided by dataclasses.make_dataclass + __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] + + # Overridden by metaclass + @property + def _coercable_fields(self) -> Mapping[str, tuple[type['BaseConfig'], Any]]: + return {} + + def as_dict(self) -> Mapping[str, Any]: + result = dataclasses.asdict(self) + for field_name in self._coercable_fields: + field_value = getattr(self, field_name, None) + if isinstance(field_value, BaseConfig): + result[field_name] = field_value.as_dict() + return result diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/folding_input.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/folding_input.py new file mode 100644 index 0000000000000000000000000000000000000000..cba8d0556396f6b62c094d034d402f50e7a884ed --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/folding_input.py @@ -0,0 +1,1115 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Model input dataclass.""" + +from collections.abc import Collection, Mapping, Sequence +import dataclasses +import json +import logging +import pathlib +import random +import re +import string +from typing_extensions import Any, Final, Self, TypeAlias + +from alphafold3 import structure +from alphafold3.constants import chemical_components +from alphafold3.constants import mmcif_names +from alphafold3.constants import residue_names +from alphafold3.structure import mmcif as mmcif_lib +import rdkit.Chem as rd_chem + + +BondAtomId: TypeAlias = tuple[str, int, str] + +JSON_DIALECT: Final[str] = 'alphafold3' +JSON_VERSION: Final[int] = 1 + +ALPHAFOLDSERVER_JSON_DIALECT: Final[str] = 'alphafoldserver' +ALPHAFOLDSERVER_JSON_VERSION: Final[int] = 1 + + +def _validate_keys(actual: Collection[str], expected: Collection[str]): + """Validates that the JSON doesn't contain any extra unwanted keys.""" + if bad_keys := set(actual) - set(expected): + raise ValueError( + f'Unexpected JSON keys in: {", ".join(sorted(bad_keys))}') + + +class Template: + """Structural template input.""" + + __slots__ = ('_mmcif', '_query_to_template') + + def __init__(self, mmcif: str, query_to_template_map: Mapping[int, int]): + """Initializes the template. + + Args: + mmcif: The structural template in mmCIF format. The mmCIF should have only + one protein chain. + query_to_template_map: A mapping from query residue index to template + residue index. + """ + self._mmcif = mmcif + # Needed to make the Template class hashable. 
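+ # The mapping view is rebuilt on demand by the query_to_template_map
+ # property.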
+ self._query_to_template = tuple(query_to_template_map.items()) + + @property + def query_to_template_map(self) -> Mapping[int, int]: + return dict(self._query_to_template) + + @property + def mmcif(self) -> str: + return self._mmcif + + def __hash__(self) -> int: + return hash((self._mmcif, tuple(sorted(self._query_to_template)))) + + def __eq__(self, other: Self) -> bool: + mmcifs_equal = self._mmcif == other._mmcif + maps_equal = sorted(self._query_to_template) == sorted( + other._query_to_template + ) + return mmcifs_equal and maps_equal + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class ProteinChain: + """Protein chain input. + + Attributes: + id: Unique protein chain identifier. + sequence: The amino acid sequence of the chain. + ptms: A list of tuples containing the post-translational modification type + and the (1-based) residue index where the modification is applied. + paired_msa: Paired A3M-formatted MSA for this chain. This MSA is not + deduplicated and will be used to compute paired features. If None, this + field is unset and must be filled in by the data pipeline before + featurisation. If set to an empty string, it will be treated as a custom + MSA with no sequences. + unpaired_msa: Unpaired A3M-formatted MSA for this chain. This will be + deduplicated and used to compute unpaired features. If None, this field is + unset and must be filled in by the data pipeline before featurisation. If + set to an empty string, it will be treated as a custom MSA with no + sequences. + templates: A list of structural templates for this chain. If None, this + field is unset and must be filled in by the data pipeline before + featurisation. The list can be empty or contain up to 20 templates. + """ + + id: str + sequence: str + ptms: Sequence[tuple[str, int]] + paired_msa: str | None = None + unpaired_msa: str | None = None + templates: Sequence[Template] | None = None + + def __post_init__(self): + if not all(res.isalpha() for res in self.sequence): + raise ValueError( + f'Protein must contain only letters, got "{self.sequence}"' + ) + if any(not 0 < mod[1] <= len(self.sequence) for mod in self.ptms): + raise ValueError( + f'Invalid protein modification index: {self.ptms}') + + # Use hashable types for ptms and templates. + if self.ptms is not None: + object.__setattr__(self, 'ptms', tuple(self.ptms)) + if self.templates is not None: + object.__setattr__(self, 'templates', tuple(self.templates)) + + @classmethod + def from_alphafoldserver_dict( + cls, json_dict: Mapping[str, Any], seq_id: str + ) -> Self: + """Constructs ProteinChain from the AlphaFoldServer JSON dict.""" + _validate_keys( + json_dict.keys(), + {'sequence', 'glycans', 'modifications', 'count'}, + ) + sequence = json_dict['sequence'] + + if 'glycans' in json_dict: + raise ValueError( + f'Specifying glycans in the `{ALPHAFOLDSERVER_JSON_DIALECT}` format' + ' is not currently supported.' 
+ ) + + ptms = [ + (mod['ptmType'].removeprefix('CCD_'), mod['ptmPosition']) + for mod in json_dict.get('modifications', []) + ] + return cls(id=seq_id, sequence=sequence, ptms=ptms) + + @classmethod + def from_dict( + cls, json_dict: Mapping[str, Any], seq_id: str | None = None + ) -> Self: + """Constructs ProteinChain from the AlphaFold JSON dict.""" + json_dict = json_dict['protein'] + _validate_keys( + json_dict.keys(), + { + 'id', + 'sequence', + 'modifications', + 'unpairedMsa', + 'pairedMsa', + 'templates', + }, + ) + + sequence = json_dict['sequence'] + ptms = [ + (mod['ptmType'], mod['ptmPosition']) + for mod in json_dict.get('modifications', []) + ] + + unpaired_msa = json_dict.get('unpairedMsa', None) + paired_msa = json_dict.get('pairedMsa', None) + + raw_templates = json_dict.get('templates', None) + + if raw_templates is None: + templates = None + else: + templates = [ + Template( + mmcif=template['mmcif'], + query_to_template_map=dict( + zip(template['queryIndices'], + template['templateIndices']) + ), + ) + for template in raw_templates + ] + + return cls( + id=seq_id or json_dict['id'], + sequence=sequence, + ptms=ptms, + paired_msa=paired_msa, + unpaired_msa=unpaired_msa, + templates=templates, + ) + + def to_dict(self) -> Mapping[str, Mapping[str, Any]]: + """Converts ProteinChain to an AlphaFold JSON dict.""" + if self.templates is None: + templates = None + else: + templates = [ + { + 'mmcif': template.mmcif, + 'queryIndices': list(template.query_to_template_map.keys()), + 'templateIndices': ( + list(template.query_to_template_map.values()) or None + ), + } + for template in self.templates + ] + contents = { + 'id': self.id, + 'sequence': self.sequence, + 'modifications': [ + {'ptmType': ptm[0], 'ptmPosition': ptm[1]} for ptm in self.ptms + ], + 'unpairedMsa': self.unpaired_msa, + 'pairedMsa': self.paired_msa, + 'templates': templates, + } + return {'protein': contents} + + def to_ccd_sequence(self) -> Sequence[str]: + """Converts to a sequence of CCD codes.""" + ccd_coded_seq = [ + residue_names.PROTEIN_COMMON_ONE_TO_THREE.get( + res, residue_names.UNK) + for res in self.sequence + ] + for ptm_code, ptm_index in self.ptms: + ccd_coded_seq[ptm_index - 1] = ptm_code + return ccd_coded_seq + + def fill_missing_fields(self) -> Self: + """Fill missing MSA and template fields with default values.""" + return dataclasses.replace( + self, + unpaired_msa=self.unpaired_msa or '', + paired_msa=self.paired_msa or '', + templates=self.templates or [], + ) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class RnaChain: + """RNA chain input. + + Attributes: + id: Unique RNA chain identifier. + sequence: The RNA sequence of the chain. + modifications: A list of tuples containing the modification type and the + (1-based) residue index where the modification is applied. + unpaired_msa: Unpaired A3M-formatted MSA for this chain. This will be + deduplicated and used to compute unpaired features. If None, this field is + unset and must be filled in by the data pipeline before featurisation. If + set to an empty string, it will be treated as a custom MSA with no + sequences. 
+ """ + + id: str + sequence: str + modifications: Sequence[tuple[str, int]] + unpaired_msa: str | None = None + + def __post_init__(self): + if not all(res.isalpha() for res in self.sequence): + raise ValueError( + f'RNA must contain only letters, got "{self.sequence}"') + if any(not 0 < mod[1] <= len(self.sequence) for mod in self.modifications): + raise ValueError( + f'Invalid RNA modification index: {self.modifications}') + + # Use hashable types for modifications. + object.__setattr__(self, 'modifications', tuple(self.modifications)) + + @classmethod + def from_alphafoldserver_dict( + cls, json_dict: Mapping[str, Any], seq_id: str + ) -> Self: + """Constructs RnaChain from the AlphaFoldServer JSON dict.""" + _validate_keys(json_dict.keys(), { + 'sequence', 'modifications', 'count'}) + sequence = json_dict['sequence'] + modifications = [ + (mod['modificationType'].removeprefix('CCD_'), mod['basePosition']) + for mod in json_dict.get('modifications', []) + ] + return cls(id=seq_id, sequence=sequence, modifications=modifications) + + @classmethod + def from_dict( + cls, json_dict: Mapping[str, Any], seq_id: str | None = None + ) -> Self: + """Constructs RnaChain from the AlphaFold JSON dict.""" + json_dict = json_dict['rna'] + _validate_keys( + json_dict.keys(), {'id', 'sequence', + 'unpairedMsa', 'modifications'} + ) + sequence = json_dict['sequence'] + modifications = [ + (mod['modificationType'], mod['basePosition']) + for mod in json_dict.get('modifications', []) + ] + unpaired_msa = json_dict.get('unpairedMsa', None) + return cls( + id=seq_id or json_dict['id'], + sequence=sequence, + modifications=modifications, + unpaired_msa=unpaired_msa, + ) + + def to_dict(self) -> Mapping[str, Mapping[str, Any]]: + """Converts RnaChain to an AlphaFold JSON dict.""" + contents = { + 'id': self.id, + 'sequence': self.sequence, + 'modifications': [ + {'modificationType': mod[0], 'basePosition': mod[1]} + for mod in self.modifications + ], + 'unpairedMsa': self.unpaired_msa, + } + return {'rna': contents} + + def to_ccd_sequence(self) -> Sequence[str]: + """Converts to a sequence of CCD codes.""" + mapping = { + r: r for r in residue_names.RNA_TYPES} # Same 1-letter and CCD. + ccd_coded_seq = [ + mapping.get(res, residue_names.UNK_RNA) for res in self.sequence + ] + for ccd_code, modification_index in self.modifications: + ccd_coded_seq[modification_index - 1] = ccd_code + return ccd_coded_seq + + def fill_missing_fields(self) -> Self: + """Fill missing MSA fields with default values.""" + return dataclasses.replace(self, unpaired_msa=self.unpaired_msa or '') + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class DnaChain: + """Single strand DNA chain input. + + Attributes: + id: Unique DNA chain identifier. + sequence: The DNA sequence of the chain. + modifications: A list of tuples containing the modification type and the + (1-based) residue index where the modification is applied. + """ + + id: str + sequence: str + modifications: Sequence[tuple[str, int]] + + def __post_init__(self): + if not all(res.isalpha() for res in self.sequence): + raise ValueError( + f'DNA must contain only letters, got "{self.sequence}"') + if any(not 0 < mod[1] <= len(self.sequence) for mod in self.modifications): + raise ValueError( + f'Invalid DNA modification index: {self.modifications}') + + # Use hashable types for modifications. 
+ object.__setattr__(self, 'modifications', tuple(self.modifications)) + + @classmethod + def from_alphafoldserver_dict( + cls, json_dict: Mapping[str, Any], seq_id: str + ) -> Self: + """Constructs DnaChain from the AlphaFoldServer JSON dict.""" + _validate_keys(json_dict.keys(), { + 'sequence', 'modifications', 'count'}) + sequence = json_dict['sequence'] + modifications = [ + (mod['modificationType'].removeprefix('CCD_'), mod['basePosition']) + for mod in json_dict.get('modifications', []) + ] + return cls(id=seq_id, sequence=sequence, modifications=modifications) + + @classmethod + def from_dict( + cls, json_dict: Mapping[str, Any], seq_id: str | None = None + ) -> Self: + """Constructs DnaChain from the AlphaFold JSON dict.""" + json_dict = json_dict['dna'] + _validate_keys(json_dict.keys(), {'id', 'sequence', 'modifications'}) + sequence = json_dict['sequence'] + modifications = [ + (mod['modificationType'], mod['basePosition']) + for mod in json_dict.get('modifications', []) + ] + return cls( + id=seq_id or json_dict['id'], + sequence=sequence, + modifications=modifications, + ) + + def to_dict(self) -> Mapping[str, Mapping[str, Any]]: + """Converts DnaChain to an AlphaFold JSON dict.""" + contents = { + 'id': self.id, + 'sequence': self.sequence, + 'modifications': [ + {'modificationType': mod[0], 'basePosition': mod[1]} + for mod in self.modifications + ], + } + return {'dna': contents} + + def to_ccd_sequence(self) -> Sequence[str]: + """Converts to a sequence of CCD codes.""" + ccd_coded_seq = [ + residue_names.DNA_COMMON_ONE_TO_TWO.get(res, residue_names.UNK_DNA) + for res in self.sequence + ] + for ccd_code, modification_index in self.modifications: + ccd_coded_seq[modification_index - 1] = ccd_code + return ccd_coded_seq + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class Ligand: + """Ligand input. + + Attributes: + id: Unique ligand "chain" identifier. + ccd_ids: The Chemical Component Dictionary or user-defined CCD IDs of the + chemical components of the ligand. Typically, this is just a single ID, + but some ligands are composed of multiple components. If that is the case, + a bond linking these components should be added to the bonded_atom_pairs + Input field. + smiles: The SMILES representation of the ligand. + """ + + id: str + ccd_ids: Sequence[str] | None = None + smiles: str | None = None + + def __post_init__(self): + if (self.ccd_ids is None) == (self.smiles is None): + raise ValueError('Ligand must have one of CCD ID or SMILES set.') + + if self.smiles is not None: + mol = rd_chem.MolFromSmiles(self.smiles) + if not mol: + raise ValueError( + f'Unable to make RDKit Mol from SMILES: {self.smiles}') + + # Use hashable types for ccd_ids. + if self.ccd_ids is not None: + object.__setattr__(self, 'ccd_ids', tuple(self.ccd_ids)) + + @classmethod + def from_alphafoldserver_dict( + cls, json_dict: Mapping[str, Any], seq_id: str + ) -> Self: + """Constructs Ligand from the AlphaFoldServer JSON dict.""" + # Ligand can be specified either as a ligand, or ion (special-case). 
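+ # The 'CCD_' prefix (if present) is stripped from 'ligand' values below;
+ # 'ion' values are used as CCD IDs as-is.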
+ _validate_keys(json_dict.keys(), {'ligand', 'ion', 'count'}) + if 'ligand' in json_dict: + return cls(id=seq_id, ccd_ids=[json_dict['ligand'].removeprefix('CCD_')]) + elif 'ion' in json_dict: + return cls(id=seq_id, ccd_ids=[json_dict['ion']]) + else: + raise ValueError(f'Unknown ligand type: {json_dict}') + + @classmethod + def from_dict( + cls, json_dict: Mapping[str, Any], seq_id: str | None = None + ) -> Self: + """Constructs Ligand from the AlphaFold JSON dict.""" + json_dict = json_dict['ligand'] + _validate_keys(json_dict.keys(), {'id', 'ccdCodes', 'smiles'}) + if json_dict.get('ccdCodes') and json_dict.get('smiles'): + raise ValueError( + 'Ligand cannot have both CCD code and SMILES set at the same time, ' + f'got CCD: {json_dict["ccdCodes"]} and SMILES: {json_dict["smiles"]}' + ) + + if 'ccdCodes' in json_dict: + return cls(id=seq_id or json_dict['id'], ccd_ids=json_dict['ccdCodes']) + elif 'smiles' in json_dict: + return cls(id=seq_id or json_dict['id'], smiles=json_dict['smiles']) + else: + raise ValueError(f'Unknown ligand type: {json_dict}') + + def to_dict(self) -> Mapping[str, Any]: + """Converts Ligand to an AlphaFold JSON dict.""" + contents = {'id': self.id} + if self.ccd_ids is not None: + contents['ccdCodes'] = self.ccd_ids + if self.smiles is not None: + contents['smiles'] = self.smiles + return {'ligand': contents} + + +def _sample_rng_seed() -> int: + """Sample a random seed for AlphaFoldServer job.""" + # See https://alphafoldserver.com/faq#what-are-seeds-and-how-are-they-set. + return random.randint(0, 2**32 - 1) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class Input: + """AlphaFold input. + + Attributes: + name: The name of the target. + chains: Protein chains, RNA chains, DNA chains, or ligands. + protein_chains: Protein chains. + rna_chains: RNA chains. + dna_chains: Single strand DNA chains. + ligands: Ligand (including ion) inputs. + rng_seeds: Random number generator seeds, one for each model execution. + bonded_atom_pairs: A list of tuples of atoms that are bonded to each other. + Each atom is defined by a tuple of (chain_id, res_id, atom_name). Chain + IDs must be set if there are any bonded atoms. Residue IDs are 1-indexed. + Atoms in ligands defined by SMILES can't be bonded since SMILES doesn't + define unique atom names. + user_ccd: Optional user-defined chemical component dictionary in the CIF + format. This can be used to provide additional CCD entries that are not + present in the default CCD and thus define arbitrary new ligands. This is + more expressive than SMILES since it allows to name all atoms within the + ligand which in turn makes it possible to define bonds using those atoms. + """ + + name: str + chains: Sequence[ProteinChain | RnaChain | DnaChain | Ligand] + rng_seeds: Sequence[int] + bonded_atom_pairs: Sequence[tuple[BondAtomId, BondAtomId]] | None = None + user_ccd: str | None = None + + def __post_init__(self): + if not self.rng_seeds: + raise ValueError('Input must have at least one RNG seed.') + + if not self.name.strip() or not self.sanitised_name(): + raise ValueError( + 'Input name must be non-empty and contain at least one valid' + ' character (letters, numbers, dots, dashes, underscores).' 
+ ) + + chain_ids = [c.id for c in self.chains] + if any(not c.id.isalpha() or c.id.islower() for c in self.chains): + raise ValueError( + f'IDs must be upper case letters, got: {chain_ids}') + if len(set(chain_ids)) != len(chain_ids): + raise ValueError( + 'Input JSON contains sequences with duplicate IDs.') + + # Use hashable types for chains, rng_seeds, and bonded_atom_pairs. + object.__setattr__(self, 'chains', tuple(self.chains)) + object.__setattr__(self, 'rng_seeds', tuple(self.rng_seeds)) + if self.bonded_atom_pairs is not None: + object.__setattr__( + self, 'bonded_atom_pairs', tuple(self.bonded_atom_pairs) + ) + + @property + def protein_chains(self) -> Sequence[ProteinChain]: + return [chain for chain in self.chains if isinstance(chain, ProteinChain)] + + @property + def rna_chains(self) -> Sequence[RnaChain]: + return [chain for chain in self.chains if isinstance(chain, RnaChain)] + + @property + def dna_chains(self) -> Sequence[DnaChain]: + return [chain for chain in self.chains if isinstance(chain, DnaChain)] + + @property + def ligands(self) -> Sequence[Ligand]: + return [chain for chain in self.chains if isinstance(chain, Ligand)] + + @classmethod + def from_alphafoldserver_fold_job(cls, fold_job: Mapping[str, Any]) -> Self: + """Constructs Input from an AlphaFoldServer fold job.""" + + # Validate the fold job has the correct format. + _validate_keys( + fold_job.keys(), + {'name', 'modelSeeds', 'sequences', 'dialect', 'version'}, + ) + if 'dialect' not in fold_job and 'version' not in fold_job: + dialect = ALPHAFOLDSERVER_JSON_DIALECT + version = ALPHAFOLDSERVER_JSON_VERSION + elif 'dialect' in fold_job and 'version' in fold_job: + dialect = fold_job['dialect'] + version = fold_job['version'] + else: + raise ValueError( + 'AlphaFold Server input JSON must either contain both `dialect` and' + ' `version` fields, or neither. If neither is specified, it is' + f' assumed that `dialect="{ALPHAFOLDSERVER_JSON_DIALECT}"` and' + f' `version="{ALPHAFOLDSERVER_JSON_VERSION}"`.' + ) + + if dialect != ALPHAFOLDSERVER_JSON_DIALECT: + raise ValueError( + f'AlphaFold Server input JSON has unsupported dialect: {dialect}, ' + f'expected {ALPHAFOLDSERVER_JSON_DIALECT}.' + ) + + # For now, there is only one AlphaFold Server JSON version. + if version != ALPHAFOLDSERVER_JSON_VERSION: + raise ValueError( + f'AlphaFold Server input JSON has unsupported version: {version}, ' + f'expected {ALPHAFOLDSERVER_JSON_VERSION}.' + ) + + # Parse the chains. 
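+ # Each entry is replicated `count` times, and every copy gets a fresh chain
+ # ID (A, B, C, ...) from mmcif_lib.int_id_to_str_id.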
+ chains = [] + for sequence in fold_job['sequences']: + if 'proteinChain' in sequence: + for _ in range(sequence['proteinChain'].get('count', 1)): + chains.append( + ProteinChain.from_alphafoldserver_dict( + sequence['proteinChain'], + seq_id=mmcif_lib.int_id_to_str_id(len(chains) + 1), + ) + ) + elif 'rnaSequence' in sequence: + for _ in range(sequence['rnaSequence'].get('count', 1)): + chains.append( + RnaChain.from_alphafoldserver_dict( + sequence['rnaSequence'], + seq_id=mmcif_lib.int_id_to_str_id(len(chains) + 1), + ) + ) + elif 'dnaSequence' in sequence: + for _ in range(sequence['dnaSequence'].get('count', 1)): + chains.append( + DnaChain.from_alphafoldserver_dict( + sequence['dnaSequence'], + seq_id=mmcif_lib.int_id_to_str_id(len(chains) + 1), + ) + ) + elif 'ion' in sequence: + for _ in range(sequence['ion'].get('count', 1)): + chains.append( + Ligand.from_alphafoldserver_dict( + sequence['ion'], + seq_id=mmcif_lib.int_id_to_str_id(len(chains) + 1), + ) + ) + elif 'ligand' in sequence: + for _ in range(sequence['ligand'].get('count', 1)): + chains.append( + Ligand.from_alphafoldserver_dict( + sequence['ligand'], + seq_id=mmcif_lib.int_id_to_str_id(len(chains) + 1), + ) + ) + else: + raise ValueError(f'Unknown sequence type: {sequence}') + + if 'modelSeeds' in fold_job and fold_job['modelSeeds']: + rng_seeds = [int(seed) for seed in fold_job['modelSeeds']] + else: + rng_seeds = [_sample_rng_seed()] + + return cls(name=fold_job['name'], chains=chains, rng_seeds=rng_seeds) + + @classmethod + def from_json(cls, json_str: str) -> Self: + """Loads the input from the AlphaFold JSON string.""" + raw_json = json.loads(json_str) + + _validate_keys( + raw_json.keys(), + { + 'dialect', + 'version', + 'name', + 'modelSeeds', + 'sequences', + 'bondedAtomPairs', + 'userCCD', + }, + ) + + if 'dialect' not in raw_json or 'version' not in raw_json: + raise ValueError( + 'AlphaFold 3 input JSON must contain `dialect` and `version` fields.' + ) + + if raw_json['dialect'] != JSON_DIALECT: + raise ValueError( + 'AlphaFold 3 input JSON has unsupported dialect:' + f' {raw_json["dialect"]}, expected {JSON_DIALECT}.' + ) + + # For now, there is only one AlphaFold 3 JSON version. + if raw_json['version'] != JSON_VERSION: + raise ValueError( + 'AlphaFold 3 input JSON has unsupported version:' + f' {raw_json["version"]}, expected {JSON_VERSION}.' + ) + + if 'sequences' not in raw_json: + raise ValueError( + 'AlphaFold 3 input JSON does not contain any sequences.') + + if 'modelSeeds' not in raw_json or not raw_json['modelSeeds']: + raise ValueError( + 'AlphaFold 3 input JSON must specify at least one rng seed in' + ' `modelSeeds`.' + ) + + sequences = raw_json['sequences'] + + # Make sure sequence IDs are all set. + raw_sequence_ids = [next(iter(s.values())).get('id') + for s in sequences] + if all(raw_sequence_ids): + sequence_ids = [] + for sequence_id in raw_sequence_ids: + if isinstance(sequence_id, list): + sequence_ids.append(sequence_id) + else: + sequence_ids.append([sequence_id]) + else: + raise ValueError( + 'AlphaFold 3 input JSON contains sequences with unset IDs.' 
+ ) + + flat_seq_ids = [] + for seq_ids in sequence_ids: + flat_seq_ids.extend(seq_ids) + + chains = [] + for seq_ids, sequence in zip(sequence_ids, sequences, strict=True): + if len(sequence) != 1: + raise ValueError(f'Chain {seq_ids} has more than 1 sequence.') + for seq_id in seq_ids: + if 'protein' in sequence: + chains.append(ProteinChain.from_dict( + sequence, seq_id=seq_id)) + elif 'rna' in sequence: + chains.append(RnaChain.from_dict(sequence, seq_id=seq_id)) + elif 'dna' in sequence: + chains.append(DnaChain.from_dict(sequence, seq_id=seq_id)) + elif 'ligand' in sequence: + chains.append(Ligand.from_dict(sequence, seq_id=seq_id)) + else: + raise ValueError(f'Unknown sequence type: {sequence}') + + ligands = [chain for chain in chains if isinstance(chain, Ligand)] + bonded_atom_pairs = None + if bonds := raw_json.get('bondedAtomPairs'): + bonded_atom_pairs = [] + for bond in bonds: + if len(bond) != 2: + raise ValueError( + f'Bond {bond} must have 2 atoms, got {len(bond)}.') + bond_beg, bond_end = bond + if ( + len(bond_beg) != 3 + or not isinstance(bond_beg[0], str) + or not isinstance(bond_beg[1], int) + or not isinstance(bond_beg[2], str) + ): + raise ValueError( + f'Atom {bond_beg} in bond {bond} must have 3 components: ' + '(chain_id: str, res_id: int, atom_name: str).' + ) + if ( + len(bond_end) != 3 + or not isinstance(bond_end[0], str) + or not isinstance(bond_end[1], int) + or not isinstance(bond_end[2], str) + ): + raise ValueError( + f'Atom {bond_end} in bond {bond} must have 3 components: ' + '(chain_id: str, res_id: int, atom_name: str).' + ) + if bond_beg[0] not in flat_seq_ids or bond_end[0] not in flat_seq_ids: + raise ValueError(f'Invalid chain ID(s) in bond {bond}') + if bond_beg[1] <= 0 or bond_end[1] <= 0: + raise ValueError(f'Invalid residue ID(s) in bond {bond}') + smiles_ligand_ids = set( + l.id for l in ligands if l.smiles is not None) + if bond_beg[0] in smiles_ligand_ids: + raise ValueError( + f'Bond {bond} involves an unsupported SMILES ligand {bond_beg[0]}' + ) + if bond_end[0] in smiles_ligand_ids: + raise ValueError( + f'Bond {bond} involves an unsupported SMILES ligand {bond_end[0]}' + ) + bonded_atom_pairs.append((tuple(bond_beg), tuple(bond_end))) + + return cls( + name=raw_json['name'], + chains=chains, + rng_seeds=[int(seed) for seed in raw_json['modelSeeds']], + bonded_atom_pairs=bonded_atom_pairs, + user_ccd=raw_json.get('userCCD'), + ) + + @classmethod + def from_mmcif(cls, mmcif_str: str, ccd: chemical_components.Ccd) -> Self: + """Loads the input from an mmCIF string. + + WARNING: Since rng seeds are not stored in mmCIFs, an rng seed is sampled + in the returned `Input`. + + Args: + mmcif_str: The mmCIF string. + ccd: The chemical components dictionary. + + Returns: + The input in an Input format. + """ + + struct = structure.from_mmcif( + mmcif_str, + include_water=False, + fix_mse_residues=True, + fix_unknown_dna=True, + include_bonds=True, + include_other=False, + ) + + # Create default bioassembly, expanding structures implied by stoichiometry. 
+ struct = struct.generate_bioassembly(None) + + sequences = struct.chain_single_letter_sequence( + include_missing_residues=True + ) + + chains = [] + for chain_id, chain_type in zip( + struct.group_by_chain.chain_id, struct.group_by_chain.chain_type + ): + sequence = sequences[chain_id] + + if chain_type in mmcif_names.NON_POLYMER_CHAIN_TYPES: + residues = list(struct.chain_res_name_sequence()[chain_id]) + if all(ccd.get(res) is not None for res in residues): + chains.append(Ligand(id=chain_id, ccd_ids=residues)) + elif len(residues) == 1: + comp_name = residues[0] + comps = struct.chemical_components_data + if comps is None: + raise ValueError( + 'Missing mmCIF chemical components data - this is required for ' + f'a non-CCD ligand {comp_name} defined using SMILES string.' + ) + chains.append( + Ligand(id=chain_id, + smiles=comps.chem_comp[comp_name].pdbx_smiles) + ) + else: + raise ValueError( + 'Multi-component ligand must be defined using CCD IDs, defining' + ' using SMILES is supported only for single-component ligands. ' + f'Got {residues}' + ) + else: + residues = struct.chain_res_name_sequence()[chain_id] + fixed = struct.chain_res_name_sequence( + fix_non_standard_polymer_res=True + )[chain_id] + modifications = [ + (orig, i + 1) + for i, (orig, fixed) in enumerate(zip(residues, fixed, strict=True)) + if orig != fixed + ] + + if chain_type == mmcif_names.PROTEIN_CHAIN: + chains.append( + ProteinChain(id=chain_id, sequence=sequence, + ptms=modifications) + ) + elif chain_type == mmcif_names.RNA_CHAIN: + chains.append( + RnaChain( + id=chain_id, sequence=sequence, modifications=modifications + ) + ) + elif chain_type == mmcif_names.DNA_CHAIN: + chains.append( + DnaChain( + id=chain_id, sequence=sequence, modifications=modifications + ) + ) + + bonded_atom_pairs = [] + chain_ids = set(c.id for c in chains) + for atom_a, atom_b, _ in struct.iter_bonds(): + if atom_a['chain_id'] in chain_ids and atom_b['chain_id'] in chain_ids: + beg = (atom_a['chain_id'], int( + atom_a['res_id']), atom_a['atom_name']) + end = (atom_b['chain_id'], int( + atom_b['res_id']), atom_b['atom_name']) + bonded_atom_pairs.append((beg, end)) + + return cls( + name=struct.name, + chains=chains, + # mmCIFs don't store rng seeds, so we need to sample one here. + rng_seeds=[_sample_rng_seed()], + bonded_atom_pairs=bonded_atom_pairs or None, + ) + + def to_structure(self, ccd: chemical_components.Ccd) -> structure.Structure: + """Converts Input to a Structure. + + WARNING: This method does not preserve the rng seeds. + + Args: + ccd: The chemical components dictionary. + + Returns: + The input in a structure.Structure format. 
+ """ + ids: list[str] = [] + sequences: list[str] = [] + poly_types: list[str] = [] + formats: list[structure.SequenceFormat] = [] + + for chain in self.chains: + ids.append(chain.id) + match chain: + case ProteinChain(): + sequences.append( + '(' + ')('.join(chain.to_ccd_sequence()) + ')') + poly_types.append(mmcif_names.PROTEIN_CHAIN) + formats.append(structure.SequenceFormat.CCD_CODES) + case RnaChain(): + sequences.append( + '(' + ')('.join(chain.to_ccd_sequence()) + ')') + poly_types.append(mmcif_names.RNA_CHAIN) + formats.append(structure.SequenceFormat.CCD_CODES) + case DnaChain(): + sequences.append( + '(' + ')('.join(chain.to_ccd_sequence()) + ')') + poly_types.append(mmcif_names.DNA_CHAIN) + formats.append(structure.SequenceFormat.CCD_CODES) + case Ligand(): + if chain.ccd_ids is not None: + sequences.append('(' + ')('.join(chain.ccd_ids) + ')') + if len(chain.ccd_ids) == 1: + poly_types.append(mmcif_names.NON_POLYMER_CHAIN) + else: + poly_types.append(mmcif_names.BRANCHED_CHAIN) + formats.append(structure.SequenceFormat.CCD_CODES) + elif chain.smiles is not None: + # Convert to `:` format that is expected + # by structure.from_sequences_and_bonds. + sequences.append(f'LIG_{chain.id}:{chain.smiles}') + poly_types.append(mmcif_names.NON_POLYMER_CHAIN) + formats.append(structure.SequenceFormat.LIGAND_SMILES) + else: + raise ValueError( + 'Ligand must have one of CCD ID or SMILES set.') + + # Remap bond chain IDs from chain IDs to chain indices and convert to + # 0-based residue indexing. + bonded_atom_pairs = [] + chain_indices = {cid: i for i, cid in enumerate(ids)} + if self.bonded_atom_pairs is not None: + for bond_beg, bond_end in self.bonded_atom_pairs: + bonded_atom_pairs.append(( + (chain_indices[bond_beg[0]], bond_beg[1] - 1, bond_beg[2]), + (chain_indices[bond_end[0]], bond_end[1] - 1, bond_end[2]), + )) + + struct = structure.from_sequences_and_bonds( + sequences=sequences, + chain_types=poly_types, + sequence_formats=formats, + bonded_atom_pairs=bonded_atom_pairs, + ccd=ccd, + name=self.sanitised_name(), + bond_type=mmcif_names.COVALENT_BOND, + release_date=None, + ) + # Rename chain IDs to the original ones. + return struct.rename_chain_ids(dict(zip(struct.chains, ids, strict=True))) + + def to_json(self) -> str: + """Converts Input to an AlphaFold JSON.""" + alphafold_json = json.dumps( + { + 'dialect': JSON_DIALECT, + 'version': JSON_VERSION, + 'name': self.name, + 'sequences': [chain.to_dict() for chain in self.chains], + 'modelSeeds': self.rng_seeds, + 'bondedAtomPairs': self.bonded_atom_pairs, + 'userCCD': self.user_ccd, + }, + indent=2, + ) + # Remove newlines from the query/template indices arrays. We match the + # queryIndices/templatesIndices with a non-capturing group. We then match + # the entire region between the square brackets by looking for lines + # containing only whitespace, number, or a comma. 
+ return re.sub( + r'("(?:queryIndices|templateIndices)": \[)([\s\n\d,]+)(\],?)', + lambda mtch: mtch[1] + + re.sub(r'\n\s+', ' ', mtch[2].strip()) + mtch[3], + alphafold_json, + ) + + def fill_missing_fields(self) -> Self: + """Fill missing MSA and template fields with default values.""" + with_missing_fields = [ + c.fill_missing_fields() + if isinstance(c, (ProteinChain, RnaChain)) + else c + for c in self.chains + ] + return dataclasses.replace(self, chains=with_missing_fields) + + def sanitised_name(self) -> str: + """Returns sanitised version of the name that can be used as a filename.""" + lower_spaceless_name = self.name.lower().replace(' ', '_') + allowed_chars = set(string.ascii_lowercase + string.digits + '_-.') + return ''.join(l for l in lower_spaceless_name if l in allowed_chars) + + +def check_unique_sanitised_names(fold_inputs: Sequence[Input]) -> None: + """Checks that the names of the fold inputs are unique.""" + names = [fi.sanitised_name() for fi in fold_inputs] + if len(set(names)) != len(names): + raise ValueError( + f'Fold inputs must have unique sanitised names, got {names}.' + ) + + +def load_fold_inputs_from_path(json_path: pathlib.Path) -> Sequence[Input]: + """Loads multiple fold inputs from a JSON string.""" + with open(json_path, 'r') as f: + json_str = f.read() + + # Parse the JSON string, so we can detect its format. + raw_json = json.loads(json_str) + + fold_inputs = [] + if isinstance(raw_json, list): + # AlphaFold Server JSON. + logging.info( + 'Detected %s is an AlphaFold Server JSON since the top-level is a' + ' list.', + json_path, + ) + + logging.info('Loading %d fold jobs from %s', len(raw_json), json_path) + for fold_job_idx, fold_job in enumerate(raw_json): + try: + fold_inputs.append( + Input.from_alphafoldserver_fold_job(fold_job)) + except ValueError as e: + raise ValueError( + f'Failed to load fold job {fold_job_idx} from {json_path}. The JSON' + f' at {json_path} was detected to be an AlphaFold Server JSON since' + ' the top-level is a list.' + ) from e + else: + logging.info( + 'Detected %s is an AlphaFold 3 JSON since the top-level is not a list.', + json_path, + ) + # AlphaFold 3 JSON. + try: + fold_inputs.append(Input.from_json(json_str)) + except ValueError as e: + raise ValueError( + f'Failed to load fold input from {json_path}. The JSON at' + f' {json_path} was detected to be an AlphaFold 3 JSON since the' + ' top-level is not a list.' + ) from e + + check_unique_sanitised_names(fold_inputs) + + return fold_inputs + + +def load_fold_inputs_from_dir(input_dir: pathlib.Path) -> Sequence[Input]: + """Loads multiple fold inputs from all JSON files in a given input_dir. + + Args: + input_dir: The directory containing the JSON files. + + Returns: + The fold inputs from all JSON files in the input directory. + + Raises: + ValueError: If the fold inputs have non-unique sanitised names. 
+ """ + fold_inputs = [] + for file_path in input_dir.glob('*.json'): + if not file_path.is_file(): + continue + + fold_inputs.extend(load_fold_inputs_from_path(file_path)) + + check_unique_sanitised_names(fold_inputs) + + return fold_inputs diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/resources.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/resources.py new file mode 100644 index 0000000000000000000000000000000000000000..74b40c148b5fc4dc2b441f8d4c4f0cdab966ec50 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/resources.py @@ -0,0 +1,77 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Load external resources, such as external tools or data resources.""" + +from collections.abc import Iterator +import os +import pathlib +import typing +from typing import BinaryIO, Final, Literal, TextIO + +from importlib import resources +import alphafold3.common + + +_DATA_ROOT: Final[pathlib.Path] = ( + resources.files(alphafold3.common).joinpath('..').resolve() +) +ROOT = _DATA_ROOT + + +def filename(name: str | os.PathLike[str]) -> str: + """Returns the absolute path to an external resource. + + Note that this calls resources.GetResourceFilename under the hood and hence + causes par file unpacking, which might be unfriendly on diskless machines. + + + Args: + name: the name of the resource corresponding to its path relative to the + root of the repository. + """ + return (_DATA_ROOT / name).as_posix() + + +@typing.overload +def open_resource( + name: str | os.PathLike[str], mode: Literal['r', 'rt'] = 'rt' +) -> TextIO: + ... + + +@typing.overload +def open_resource( + name: str | os.PathLike[str], mode: Literal['rb'] +) -> BinaryIO: + ... + + +def open_resource( + name: str | os.PathLike[str], mode: str = 'rb' +) -> TextIO | BinaryIO: + """Returns an open file object for the named resource. + + Args: + name: the name of the resource corresponding to its path relative to the + root of the repository. + mode: the mode to use when opening the file. + """ + return (_DATA_ROOT / name).open(mode) + + +def get_resource_dir(path: str | os.PathLike[str]) -> os.PathLike[str]: + return _DATA_ROOT / path + + +def walk(path: str) -> Iterator[tuple[str, list[str], list[str]]]: + """Walks the directory tree of resources similar to os.walk.""" + return os.walk((_DATA_ROOT / path).as_posix()) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/testing/data.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/testing/data.py new file mode 100644 index 0000000000000000000000000000000000000000..97a69d2c10f420983a04ec7b70d71a89fd095a75 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/testing/data.py @@ -0,0 +1,70 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Module that provides an abstraction for accessing test data.""" + +import os +import pathlib +from typing import Literal, overload + +from absl.testing import absltest + + +class Data: + """Provides an abstraction for accessing test data.""" + + def __init__(self, data_dir: os.PathLike[str] | str): + """Initiailizes data wrapper, providing users with high level data access. + + Args: + data_dir: Directory containing test data. + """ + self._data_dir = pathlib.Path(data_dir) + + def path(self, data_name: str | os.PathLike[str] | None = None) -> str: + """Returns the path to a given test data. + + Args: + data_name: the name of the test data file relative to data_dir. If not + set, this will return the absolute path to the data directory. + """ + data_dir_path = ( + pathlib.Path(absltest.get_default_test_srcdir()) / self._data_dir + ) + + if data_name: + return str(data_dir_path / data_name) + + return str(data_dir_path) + + @overload + def load( + self, data_name: str | os.PathLike[str], mode: Literal['rt'] = 'rt' + ) -> str: + ... + + @overload + def load( + self, data_name: str | os.PathLike[str], mode: Literal['rb'] = 'rb' + ) -> bytes: + ... + + def load( + self, data_name: str | os.PathLike[str], mode: str = 'rt' + ) -> str | bytes: + """Returns the contents of a given test data. + + Args: + data_name: the name of the test data file relative to data_dir. + mode: the mode in which to read the data file. Defaults to text ('rt'). + """ + with open(self.path(data_name), mode=mode) as f: + return f.read() diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/atom_types.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/atom_types.py new file mode 100644 index 0000000000000000000000000000000000000000..8630278a18f9c6ee44b57300b8698d42fb00f994 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/atom_types.py @@ -0,0 +1,262 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""List of atom types with reverse look-up.""" + +from collections.abc import Mapping, Sequence, Set +import itertools +import sys +from typing import Final +from alphafold3.constants import residue_names + +# Note: +# `sys.intern` places the values in the Python internal db for fast lookup. + +# 37 common residue atoms. 
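+# These names back the fixed-size ATOM37 encoding defined further below.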
+N = sys.intern('N') +CA = sys.intern('CA') +C = sys.intern('C') +CB = sys.intern('CB') +O = sys.intern('O') +CG = sys.intern('CG') +CG1 = sys.intern('CG1') +CG2 = sys.intern('CG2') +OG = sys.intern('OG') +OG1 = sys.intern('OG1') +SG = sys.intern('SG') +CD = sys.intern('CD') +CD1 = sys.intern('CD1') +CD2 = sys.intern('CD2') +ND1 = sys.intern('ND1') +ND2 = sys.intern('ND2') +OD1 = sys.intern('OD1') +OD2 = sys.intern('OD2') +SD = sys.intern('SD') +CE = sys.intern('CE') +CE1 = sys.intern('CE1') +CE2 = sys.intern('CE2') +CE3 = sys.intern('CE3') +NE = sys.intern('NE') +NE1 = sys.intern('NE1') +NE2 = sys.intern('NE2') +OE1 = sys.intern('OE1') +OE2 = sys.intern('OE2') +CH2 = sys.intern('CH2') +NH1 = sys.intern('NH1') +NH2 = sys.intern('NH2') +OH = sys.intern('OH') +CZ = sys.intern('CZ') +CZ2 = sys.intern('CZ2') +CZ3 = sys.intern('CZ3') +NZ = sys.intern('NZ') +OXT = sys.intern('OXT') + +# 29 common nucleic acid atoms. +C1PRIME = sys.intern("C1'") +C2 = sys.intern('C2') +C2PRIME = sys.intern("C2'") +C3PRIME = sys.intern("C3'") +C4 = sys.intern('C4') +C4PRIME = sys.intern("C4'") +C5 = sys.intern('C5') +C5PRIME = sys.intern("C5'") +C6 = sys.intern('C6') +C7 = sys.intern('C7') +C8 = sys.intern('C8') +N1 = sys.intern('N1') +N2 = sys.intern('N2') +N3 = sys.intern('N3') +N4 = sys.intern('N4') +N6 = sys.intern('N6') +N7 = sys.intern('N7') +N9 = sys.intern('N9') +O2 = sys.intern('O2') +O2PRIME = sys.intern("O2'") +O3PRIME = sys.intern("O3'") +O4 = sys.intern('O4') +O4PRIME = sys.intern("O4'") +O5PRIME = sys.intern("O5'") +O6 = sys.intern('O6') +OP1 = sys.intern('OP1') +OP2 = sys.intern('OP2') +OP3 = sys.intern('OP3') +P = sys.intern('P') + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +RESIDUE_ATOMS: Mapping[str, tuple[str, ...]] = { + residue_names.ALA: (C, CA, CB, N, O), + residue_names.ARG: (C, CA, CB, CG, CD, CZ, N, NE, O, NH1, NH2), + residue_names.ASN: (C, CA, CB, CG, N, ND2, O, OD1), + residue_names.ASP: (C, CA, CB, CG, N, O, OD1, OD2), + residue_names.CYS: (C, CA, CB, N, O, SG), + residue_names.GLN: (C, CA, CB, CG, CD, N, NE2, O, OE1), + residue_names.GLU: (C, CA, CB, CG, CD, N, O, OE1, OE2), + residue_names.GLY: (C, CA, N, O), + residue_names.HIS: (C, CA, CB, CG, CD2, CE1, N, ND1, NE2, O), + residue_names.ILE: (C, CA, CB, CG1, CG2, CD1, N, O), + residue_names.LEU: (C, CA, CB, CG, CD1, CD2, N, O), + residue_names.LYS: (C, CA, CB, CG, CD, CE, N, NZ, O), + residue_names.MET: (C, CA, CB, CG, CE, N, O, SD), + residue_names.PHE: (C, CA, CB, CG, CD1, CD2, CE1, CE2, CZ, N, O), + residue_names.PRO: (C, CA, CB, CG, CD, N, O), + residue_names.SER: (C, CA, CB, N, O, OG), + residue_names.THR: (C, CA, CB, CG2, N, O, OG1), + residue_names.TRP: + (C, CA, CB, CG, CD1, CD2, CE2, CE3, CZ2, CZ3, CH2, N, NE1, O), + residue_names.TYR: (C, CA, CB, CG, CD1, CD2, CE1, CE2, CZ, N, O, OH), + residue_names.VAL: (C, CA, CB, CG1, CG2, N, O), +} # pyformat: disable + +# Used to identify backbone for alignment and distance calculation for sterics. +PROTEIN_BACKBONE_ATOMS: tuple[str, ...] = (N, CA, C) + +# Naming swaps for ambiguous atom names. Due to symmetries in the amino acids +# the naming of atoms is ambiguous in 4 of the 20 amino acids. 
(The LDDT paper +# lists 7 amino acids as ambiguous, but the naming ambiguities in LEU, VAL and +# ARG can be resolved by using the 3D constellations of the 'ambiguous' atoms +# and their neighbours) +AMBIGUOUS_ATOM_NAMES: Mapping[str, Mapping[str, str]] = { + residue_names.ASP: {OD1: OD2}, + residue_names.GLU: {OE1: OE2}, + residue_names.PHE: {CD1: CD2, CE1: CE2}, + residue_names.TYR: {CD1: CD2, CE1: CE2}, +} + +# Used when we need to store atom data in a format that requires fixed atom data +# size for every protein residue (e.g. a numpy array). +ATOM37: tuple[str, ...] = ( + N, CA, C, CB, O, CG, CG1, CG2, OG, OG1, SG, CD, CD1, CD2, ND1, ND2, OD1, + OD2, SD, CE, CE1, CE2, CE3, NE, NE1, NE2, OE1, OE2, CH2, NH1, NH2, OH, CZ, + CZ2, CZ3, NZ, OXT) # pyformat: disable +ATOM37_ORDER: Mapping[str, int] = {name: i for i, name in enumerate(ATOM37)} +ATOM37_NUM: Final[int] = len(ATOM37) # := 37. + +# Used when we need to store protein atom data in a format that requires fixed +# atom data size for any residue but takes less space than ATOM37 by having 14 +# fields, which is sufficient for storing atoms of all protein residues (e.g. a +# numpy array). +ATOM14: Mapping[str, tuple[str, ...]] = { + residue_names.ALA: (N, CA, C, O, CB), + residue_names.ARG: (N, CA, C, O, CB, CG, CD, NE, CZ, NH1, NH2), + residue_names.ASN: (N, CA, C, O, CB, CG, OD1, ND2), + residue_names.ASP: (N, CA, C, O, CB, CG, OD1, OD2), + residue_names.CYS: (N, CA, C, O, CB, SG), + residue_names.GLN: (N, CA, C, O, CB, CG, CD, OE1, NE2), + residue_names.GLU: (N, CA, C, O, CB, CG, CD, OE1, OE2), + residue_names.GLY: (N, CA, C, O), + residue_names.HIS: (N, CA, C, O, CB, CG, ND1, CD2, CE1, NE2), + residue_names.ILE: (N, CA, C, O, CB, CG1, CG2, CD1), + residue_names.LEU: (N, CA, C, O, CB, CG, CD1, CD2), + residue_names.LYS: (N, CA, C, O, CB, CG, CD, CE, NZ), + residue_names.MET: (N, CA, C, O, CB, CG, SD, CE), + residue_names.PHE: (N, CA, C, O, CB, CG, CD1, CD2, CE1, CE2, CZ), + residue_names.PRO: (N, CA, C, O, CB, CG, CD), + residue_names.SER: (N, CA, C, O, CB, OG), + residue_names.THR: (N, CA, C, O, CB, OG1, CG2), + residue_names.TRP: + (N, CA, C, O, CB, CG, CD1, CD2, NE1, CE2, CE3, CZ2, CZ3, CH2), + residue_names.TYR: (N, CA, C, O, CB, CG, CD1, CD2, CE1, CE2, CZ, OH), + residue_names.VAL: (N, CA, C, O, CB, CG1, CG2), + residue_names.UNK: (), +} # pyformat: disable + +# A compact atom encoding with 14 columns, padded with '' in empty slots. +ATOM14_PADDED: Mapping[str, Sequence[str]] = { + k: [v for _, v in itertools.zip_longest(range(14), values, fillvalue='')] + for k, values in ATOM14.items() +} + +ATOM14_ORDER: Mapping[str, Mapping[str, int]] = { + k: {name: i for i, name in enumerate(v)} for k, v in ATOM14.items() +} +ATOM14_NUM: Final[int] = max(len(v) for v in ATOM14.values()) + +# Used when we need to store protein and nucleic atom library. +DENSE_ATOM: Mapping[str, tuple[str, ...]] = { + # Protein. 
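+ # The protein entries below repeat ATOM14 above; RNA and DNA entries follow.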
+ residue_names.ALA: (N, CA, C, O, CB), + residue_names.ARG: (N, CA, C, O, CB, CG, CD, NE, CZ, NH1, NH2), + residue_names.ASN: (N, CA, C, O, CB, CG, OD1, ND2), + residue_names.ASP: (N, CA, C, O, CB, CG, OD1, OD2), + residue_names.CYS: (N, CA, C, O, CB, SG), + residue_names.GLN: (N, CA, C, O, CB, CG, CD, OE1, NE2), + residue_names.GLU: (N, CA, C, O, CB, CG, CD, OE1, OE2), + residue_names.GLY: (N, CA, C, O), + residue_names.HIS: (N, CA, C, O, CB, CG, ND1, CD2, CE1, NE2), + residue_names.ILE: (N, CA, C, O, CB, CG1, CG2, CD1), + residue_names.LEU: (N, CA, C, O, CB, CG, CD1, CD2), + residue_names.LYS: (N, CA, C, O, CB, CG, CD, CE, NZ), + residue_names.MET: (N, CA, C, O, CB, CG, SD, CE), + residue_names.PHE: (N, CA, C, O, CB, CG, CD1, CD2, CE1, CE2, CZ), + residue_names.PRO: (N, CA, C, O, CB, CG, CD), + residue_names.SER: (N, CA, C, O, CB, OG), + residue_names.THR: (N, CA, C, O, CB, OG1, CG2), + residue_names.TRP: + (N, CA, C, O, CB, CG, CD1, CD2, NE1, CE2, CE3, CZ2, CZ3, CH2), + residue_names.TYR: (N, CA, C, O, CB, CG, CD1, CD2, CE1, CE2, CZ, OH), + residue_names.VAL: (N, CA, C, O, CB, CG1, CG2), + residue_names.UNK: (), + # RNA. + residue_names.A: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, O2PRIME, C1PRIME, N9, C8, N7, C5, C6, N6, N1, C2, N3, C4), + residue_names.C: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, O2PRIME, C1PRIME, N1, C2, O2, N3, C4, N4, C5, C6), + residue_names.G: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, O2PRIME, C1PRIME, N9, C8, N7, C5, C6, O6, N1, C2, N2, N3, C4), + residue_names.U: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, O2PRIME, C1PRIME, N1, C2, O2, N3, C4, O4, C5, C6), + residue_names.UNK_RNA: (), + # DNA. + residue_names.DA: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, C1PRIME, N9, C8, N7, C5, C6, N6, N1, C2, N3, C4), + residue_names.DC: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, C1PRIME, N1, C2, O2, N3, C4, N4, C5, C6), + residue_names.DG: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, C1PRIME, N9, C8, N7, C5, C6, O6, N1, C2, N2, N3, C4), + residue_names.DT: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, C1PRIME, N1, C2, O2, N3, C4, O4, C5, C7, C6), + # Unknown nucleic. + residue_names.UNK_DNA: (), +} # pyformat: disable + +DENSE_ATOM_ORDER: Mapping[str, Mapping[str, int]] = { + k: {name: i for i, name in enumerate(v)} for k, v in DENSE_ATOM.items() +} +DENSE_ATOM_NUM: Final[int] = max(len(v) for v in DENSE_ATOM.values()) + +# Used when we need to store atom data in a format that requires fixed atom data +# size for every nucleic molecule (e.g. a numpy array). +ATOM29: tuple[str, ...] = ( + "C1'", 'C2', "C2'", "C3'", 'C4', "C4'", 'C5', "C5'", 'C6', 'C7', 'C8', 'N1', + 'N2', 'N3', 'N4', 'N6', 'N7', 'N9', 'OP3', 'O2', "O2'", "O3'", 'O4', "O4'", + "O5'", 'O6', 'OP1', 'OP2', 'P') # pyformat: disable +ATOM29_ORDER: Mapping[str, int] = { + atom_type: i for i, atom_type in enumerate(ATOM29) +} +ATOM29_NUM: Final[int] = len(ATOM29) # := 29 + +# Hydrogens that exist depending on the protonation state of the residue. 
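+# For example, ASP carries HD2 only when its side-chain carboxylate is
+# protonated (the neutral form), and LYS carries HZ3 only in the charged
+# -NH3+ form, so their presence cannot be inferred from the residue type
+# alone.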
+# Extracted from third_party/py/openmm/app/data/hydrogens.xml +PROTONATION_HYDROGENS: Mapping[str, Set[str]] = { + 'ASP': {'HD2'}, + 'CYS': {'HG'}, + 'GLU': {'HE2'}, + 'HIS': {'HD1', 'HE2'}, + 'LYS': {'HZ3'}, +} diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_component_sets.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_component_sets.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf7b5db4a665cfbd6ec247f9b2faefb408fa3f2 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_component_sets.py @@ -0,0 +1,38 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Sets of chemical components.""" + +import pickle +from typing import Final + +from alphafold3.common import resources + + +_CCD_SETS_CCD_PICKLE_FILE = resources.filename( + resources.ROOT / 'constants/converters/chemical_component_sets.pickle' +) + +_CCD_SET = pickle.load(open(_CCD_SETS_CCD_PICKLE_FILE, 'rb')) + +# Glycan (or 'Saccharide') ligands. +# _chem_comp.type containing 'saccharide' and 'linking' (when lower-case). +GLYCAN_LINKING_LIGANDS: Final[frozenset[str]] = _CCD_SET['glycans_linking'] + +# _chem_comp.type containing 'saccharide' and not 'linking' (when lower-case). +GLYCAN_OTHER_LIGANDS: Final[frozenset[str]] = _CCD_SET['glycans_other'] + +# Each of these molecules appears in over 1k PDB structures, are used to +# facilitate crystallization conditions, but do not have biological relevance. +COMMON_CRYSTALLIZATION_AIDS: Final[frozenset[str]] = frozenset({ + 'SO4', 'GOL', 'EDO', 'PO4', 'ACT', 'PEG', 'DMS', 'TRS', 'PGE', 'PG4', 'FMT', + 'EPE', 'MPD', 'MES', 'CD', 'IOD', +}) # pyformat: disable diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_components.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_components.py new file mode 100644 index 0000000000000000000000000000000000000000..d1132d9955d2bb263b826e34338f90fc7fd5b2f9 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_components.py @@ -0,0 +1,192 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Chemical Components found in PDB (CCD) constants.""" + +from collections.abc import ItemsView, Iterator, KeysView, Mapping, Sequence, ValuesView +import dataclasses +import functools +import os +import pickle + +from alphafold3.common import resources +from alphafold3.cpp import cif_dict + + +_CCD_PICKLE_FILE = resources.filename( + resources.ROOT / 'constants/converters/ccd.pickle' +) + + +class Ccd(Mapping[str, Mapping[str, Sequence[str]]]): + """Chemical Components found in PDB (CCD) constants. + + See https://academic.oup.com/bioinformatics/article/31/8/1274/212200 for CCD + CIF format documentation. + + Wraps the dict to prevent accidental mutation. + """ + + __slots__ = ('_dict', '_ccd_pickle_path') + + def __init__( + self, + ccd_pickle_path: os.PathLike[str] | None = None, + user_ccd: str | None = None, + ): + """Initialises the chemical components dictionary. + + Args: + ccd_pickle_path: Path to the CCD pickle file. If None, uses the default + CCD pickle file included in the source code. + user_ccd: A string containing the user-provided CCD. This has to conform + to the same format as the CCD, see https://www.wwpdb.org/data/ccd. If + provided, takes precedence over the CCD for the the same key. This can + be used to override specific entries in the CCD if desired. + """ + self._ccd_pickle_path = ccd_pickle_path or _CCD_PICKLE_FILE + with open(self._ccd_pickle_path, 'rb') as f: + self._dict = pickle.loads(f.read()) + + if user_ccd is not None: + if not user_ccd: + raise ValueError('User CCD cannot be an empty string.') + user_ccd_cifs = { + key: {k: tuple(v) for k, v in value.items()} + for key, value in cif_dict.parse_multi_data_cif(user_ccd).items() + } + self._dict.update(user_ccd_cifs) + + def __getitem__(self, key: str) -> Mapping[str, Sequence[str]]: + return self._dict[key] + + def __contains__(self, key: str) -> bool: + return key in self._dict + + def __iter__(self) -> Iterator[str]: + return self._dict.__iter__() + + def __len__(self) -> int: + return len(self._dict) + + def __hash__(self) -> int: + return id(self) # Ok since this is immutable. + + def get( + self, key: str, default: None | Mapping[str, Sequence[str]] = None + ) -> Mapping[str, Sequence[str]] | None: + return self._dict.get(key, default) + + def items(self) -> ItemsView[str, Mapping[str, Sequence[str]]]: + return self._dict.items() + + def values(self) -> ValuesView[Mapping[str, Sequence[str]]]: + return self._dict.values() + + def keys(self) -> KeysView[str]: + return self._dict.keys() + + +@functools.cache +def cached_ccd(user_ccd: str | None = None) -> Ccd: + return Ccd(user_ccd=user_ccd) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class ComponentInfo: + name: str + type: str + pdbx_synonyms: str + formula: str + formula_weight: str + mon_nstd_parent_comp_id: str + mon_nstd_flag: str + pdbx_smiles: str + + +def mmcif_to_info(mmcif: Mapping[str, Sequence[str]]) -> ComponentInfo: + """Converts CCD mmCIFs to component info. 
Missing fields are left empty.""" + names = mmcif['_chem_comp.name'] + types = mmcif['_chem_comp.type'] + mon_nstd_parent_comp_ids = mmcif['_chem_comp.mon_nstd_parent_comp_id'] + pdbx_synonyms = mmcif['_chem_comp.pdbx_synonyms'] + formulas = mmcif['_chem_comp.formula'] + formula_weights = mmcif['_chem_comp.formula_weight'] + + def front_or_empty(values: Sequence[str]) -> str: + return values[0] if values else '' + + type_ = front_or_empty(types) + mon_nstd_parent_comp_id = front_or_empty(mon_nstd_parent_comp_ids) + if type_.lower() == 'non-polymer': + # Unset for non-polymers, e.g. water or ions. + mon_nstd_flag = '.' + elif mon_nstd_parent_comp_id == '?': + # A standard component - it doesn't have a standard parent, e.g. MET. + mon_nstd_flag = 'y' + else: + # A non-standard component, e.g. MSE. + mon_nstd_flag = 'n' + + canonical_pdbx_smiles = '' + fallback_pdbx_smiles = '' + descriptor_types = mmcif.get('_pdbx_chem_comp_descriptor.type', []) + descriptors = mmcif.get('_pdbx_chem_comp_descriptor.descriptor', []) + programs = mmcif.get('_pdbx_chem_comp_descriptor.program', []) + + for descriptor_type, descriptor in zip(descriptor_types, descriptors): + if descriptor_type == 'SMILES_CANONICAL': + if (not canonical_pdbx_smiles) or programs == 'OpenEye OEToolkits': + canonical_pdbx_smiles = descriptor + if not fallback_pdbx_smiles and descriptor_type == 'SMILES': + fallback_pdbx_smiles = descriptor + pdbx_smiles = canonical_pdbx_smiles or fallback_pdbx_smiles + + return ComponentInfo( + name=front_or_empty(names), + type=type_, + pdbx_synonyms=front_or_empty(pdbx_synonyms), + formula=front_or_empty(formulas), + formula_weight=front_or_empty(formula_weights), + mon_nstd_parent_comp_id=mon_nstd_parent_comp_id, + mon_nstd_flag=mon_nstd_flag, + pdbx_smiles=pdbx_smiles, + ) + + +@functools.lru_cache(maxsize=128) +def component_name_to_info(ccd: Ccd, res_name: str) -> ComponentInfo | None: + component = ccd.get(res_name) + if component is None: + return None + return mmcif_to_info(component) + + +def type_symbol(ccd: Ccd, res_name: str, atom_name: str) -> str: + """Returns the element type for the given component name and atom name. + + Args: + ccd: The chemical components dictionary. + res_name: The component name, e.g. ARG. + atom_name: The atom name, e.g. CB, OXT, or NH1. + + Returns: + Element type, e.g. C for (ARG, CB), O for (ARG, OXT), N for (ARG, NH1). + """ + res = ccd.get(res_name) + if res is None: + return '?' + try: + return res['_chem_comp_atom.type_symbol'][ + res['_chem_comp_atom.atom_id'].index(atom_name) + ] + except (ValueError, IndexError, KeyError): + return '?' diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/converters/ccd_pickle_gen.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/converters/ccd_pickle_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..e793f216b38433ceb5ec55ba4037657cee4ae418 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/converters/ccd_pickle_gen.py @@ -0,0 +1,53 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Reads Chemical Components gz file and generates a CCD pickle file.""" + +from collections.abc import Sequence +import gzip +import pickle +import sys + +from alphafold3.cpp import cif_dict +import tqdm + + +def main(argv: Sequence[str]) -> None: + if len(argv) != 3: + raise ValueError( + 'Must specify input_file components.cif and output_file') + + _, input_file, output_file = argv + + print(f'Parsing {input_file}', flush=True) + if input_file.endswith('.gz'): + opener = gzip.open + else: + opener = open + + with opener(input_file, 'rb') as f: + whole_file = f.read() + result = { + key: {k: tuple(v) for k, v in value.items()} + for key, value in tqdm.tqdm( + cif_dict.parse_multi_data_cif(whole_file).items() + ) + } + assert len(result) == whole_file.count(b'data_') + + print(f'Writing {output_file}', flush=True) + with open(output_file, 'wb') as f: + pickle.dump(result, f, protocol=pickle.HIGHEST_PROTOCOL) + print('Done', flush=True) + + +if __name__ == '__main__': + main(sys.argv) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/converters/chemical_component_sets_gen.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/converters/chemical_component_sets_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..31d05f7d2ac4765bd05d60c746a8d8e24d259bbd --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/converters/chemical_component_sets_gen.py @@ -0,0 +1,81 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Script for updating chemical_component_sets.py.""" + +from collections.abc import Mapping, Sequence +import pathlib +import pickle +import re +import sys + +from alphafold3.common import resources +import tqdm + + +_CCD_PICKLE_FILE = resources.filename( + 'constants/converters/ccd.pickle' +) + + +def find_ions_and_glycans_in_ccd( + ccd: Mapping[str, Mapping[str, Sequence[str]]], +) -> dict[str, frozenset[str]]: + """Finds glycans and ions in all version of CCD.""" + glycans_linking = [] + glycans_other = [] + ions = [] + for name, comp in tqdm.tqdm(ccd.items()): + if name == 'UNX': + continue # Skip "unknown atom or ion". + comp_type = comp['_chem_comp.type'][0].lower() + # Glycans have the type 'saccharide'. + if re.findall(r'\bsaccharide\b', comp_type): + # Separate out linking glycans from others. + if 'linking' in comp_type: + glycans_linking.append(name) + else: + glycans_other.append(name) + + # Ions have the word 'ion' in their name. 
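+    # Note that the word-boundary regex below only matches 'ion' as a
+    # standalone word (e.g. in a name like 'zinc ion'), not as a substring of
+    # a longer word such as 'cation' or 'anion'.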
+ comp_name = comp['_chem_comp.name'][0].lower() + if re.findall(r'\bion\b', comp_name): + ions.append(name) + result = dict( + glycans_linking=frozenset(glycans_linking), + glycans_other=frozenset(glycans_other), + ions=frozenset(ions), + ) + + return result + + +def main(argv: Sequence[str]) -> None: + if len(argv) != 2: + raise ValueError( + 'Directory to write to must be specified as a command-line arguments.' + ) + + print(f'Loading {_CCD_PICKLE_FILE}', flush=True) + with open(_CCD_PICKLE_FILE, 'rb') as f: + ccd: Mapping[str, Mapping[str, Sequence[str]]] = pickle.load(f) + output_path = pathlib.Path(argv[1]) + output_path.parent.mkdir(exist_ok=True) + print('Finding ions and glycans', flush=True) + result = find_ions_and_glycans_in_ccd(ccd) + print(f'writing to {output_path}', flush=True) + with output_path.open('wb') as f: + pickle.dump(result, f) + print('Done', flush=True) + + +if __name__ == '__main__': + main(sys.argv) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/mmcif_names.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/mmcif_names.py new file mode 100644 index 0000000000000000000000000000000000000000..15eabf2f98156ee9d9bb01427f11996e423aca6d --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/mmcif_names.py @@ -0,0 +1,218 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Names of things in mmCIF format. + +See https://www.iucr.org/__data/iucr/cifdic_html/2/cif_mm.dic/index.html +""" + +from collections.abc import Mapping, Sequence, Set +from typing import Final + +from alphafold3.constants import atom_types +from alphafold3.constants import residue_names + + +# The following are all possible values for the "_entity.type". +# https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_entity.type.html +BRANCHED_CHAIN: Final[str] = 'branched' +MACROLIDE_CHAIN: Final[str] = 'macrolide' +NON_POLYMER_CHAIN: Final[str] = 'non-polymer' +POLYMER_CHAIN: Final[str] = 'polymer' +WATER: Final[str] = 'water' + +CYCLIC_PSEUDO_PEPTIDE_CHAIN: Final[str] = 'cyclic-pseudo-peptide' +DNA_CHAIN: Final[str] = 'polydeoxyribonucleotide' +DNA_RNA_HYBRID_CHAIN: Final[str] = ( + 'polydeoxyribonucleotide/polyribonucleotide hybrid' +) +OTHER_CHAIN: Final[str] = 'other' +PEPTIDE_NUCLEIC_ACID_CHAIN: Final[str] = 'peptide nucleic acid' +POLYPEPTIDE_D_CHAIN: Final[str] = 'polypeptide(D)' +PROTEIN_CHAIN: Final[str] = 'polypeptide(L)' +RNA_CHAIN: Final[str] = 'polyribonucleotide' + +# Most common _entity_poly.types. +STANDARD_POLYMER_CHAIN_TYPES: Final[Set[str]] = { + PROTEIN_CHAIN, + DNA_CHAIN, + RNA_CHAIN, +} + +# Possible values for _entity.type other than polymer and water. +LIGAND_CHAIN_TYPES: Final[Set[str]] = { + BRANCHED_CHAIN, + MACROLIDE_CHAIN, + NON_POLYMER_CHAIN, +} + +# Possible values for _entity.type other than polymer. 
+NON_POLYMER_CHAIN_TYPES: Final[Set[str]] = { + *LIGAND_CHAIN_TYPES, + WATER, +} + +# Peptide possible values for _entity_poly.type. +PEPTIDE_CHAIN_TYPES: Final[Set[str]] = { + CYCLIC_PSEUDO_PEPTIDE_CHAIN, + POLYPEPTIDE_D_CHAIN, + PROTEIN_CHAIN, + PEPTIDE_NUCLEIC_ACID_CHAIN, +} + + +# Nucleic-acid possible values for _entity_poly.type. +NUCLEIC_ACID_CHAIN_TYPES: Final[Set[str]] = { + RNA_CHAIN, + DNA_CHAIN, + DNA_RNA_HYBRID_CHAIN, +} + +# All possible values for _entity_poly.type. +POLYMER_CHAIN_TYPES: Final[Set[str]] = { + *NUCLEIC_ACID_CHAIN_TYPES, + *PEPTIDE_CHAIN_TYPES, + OTHER_CHAIN, +} + + +TERMINAL_OXYGENS: Final[Mapping[str, str]] = { + PROTEIN_CHAIN: 'OXT', + DNA_CHAIN: 'OP3', + RNA_CHAIN: 'OP3', +} + + +# For each chain type, which atom should be used to represent each residue. +RESIDUE_REPRESENTATIVE_ATOMS: Final[Mapping[str, str]] = { + PROTEIN_CHAIN: atom_types.CA, + DNA_CHAIN: atom_types.C1PRIME, + RNA_CHAIN: atom_types.C1PRIME, +} + +# Methods involving crystallization. See the documentation at +# mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_exptl.method.html +# for the full list of experimental methods. +CRYSTALLIZATION_METHODS: Final[Set[str]] = { + 'X-RAY DIFFRACTION', + 'NEUTRON DIFFRACTION', + 'ELECTRON CRYSTALLOGRAPHY', + 'POWDER CRYSTALLOGRAPHY', + 'FIBER DIFFRACTION', +} + +# Possible bond types. +COVALENT_BOND: Final[str] = 'covale' +HYDROGEN_BOND: Final[str] = 'hydrog' +METAL_COORDINATION: Final[str] = 'metalc' +DISULFIDE_BRIDGE: Final[str] = 'disulf' + + +def is_standard_polymer_type(chain_type: str) -> bool: + """Returns if chain type is a protein, DNA or RNA chain type. + + Args: + chain_type: The type of the chain. + + Returns: + A bool for if the chain_type matches protein, DNA, or RNA. + """ + return chain_type in STANDARD_POLYMER_CHAIN_TYPES + + +def guess_polymer_type(chain_residues: Sequence[str]) -> str: + """Guess the polymer type (protein/rna/dna/other) based on the residues. + + The polymer type is guessed by first checking for any of the standard + protein residues. If one is present then the chain is considered to be a + polypeptide. Otherwise we decide by counting residue types and deciding by + majority voting (e.g. mostly DNA residues -> DNA). If there is a tie between + the counts, the ordering is rna > dna > other. + + Note that we count MSE and UNK as protein residues. + + Args: + chain_residues: A sequence of full residue name (1-letter for DNA, 2-letters + for RNA, 3 for protein). The _atom_site.label_comp_id column in mmCIF. + + Returns: + The most probable chain type as set in the _entity_poly mmCIF table: + protein - polypeptide(L), rna - polyribonucleotide, + dna - polydeoxyribonucleotide or other. + """ + residue_types = { + **{r: RNA_CHAIN for r in residue_names.RNA_TYPES}, + **{r: DNA_CHAIN for r in residue_names.DNA_TYPES}, + **{r: PROTEIN_CHAIN for r in residue_names.PROTEIN_TYPES_WITH_UNKNOWN}, + residue_names.MSE: PROTEIN_CHAIN, + } + + counts = {PROTEIN_CHAIN: 0, RNA_CHAIN: 0, DNA_CHAIN: 0, OTHER_CHAIN: 0} + for residue in chain_residues: + residue_type = residue_types.get(residue, OTHER_CHAIN) + # If we ever see a protein residue we'll consider this a polypeptide(L). + if residue_type == PROTEIN_CHAIN: + return residue_type + counts[residue_type] += 1 + + # Make sure protein > rna > dna > other if there is a tie. 
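+  # For example, a chain with residues ('DA', 'DA', 'A', 'A') has a count of 2
+  # for both DNA and RNA; max() below compares (count, priority) tuples, so the
+  # higher RNA priority breaks the tie and RNA_CHAIN is returned. The protein
+  # entry is included only for completeness: any protein residue already causes
+  # an early return above.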
+ tie_braker = {PROTEIN_CHAIN: 3, RNA_CHAIN: 2, DNA_CHAIN: 1, OTHER_CHAIN: 0} + + def order_fn(item): + name, count = item + return count, tie_braker[name] + + most_probable_type = max(counts.items(), key=order_fn)[0] + return most_probable_type + + +def fix_non_standard_polymer_res(*, res_name: str, chain_type: str) -> str: + """Returns the res_name of the closest standard protein/RNA/DNA residue. + + Optimized for the case where a single residue needs to be converted. + + If res_name is already a standard type, it is returned unaltered. + If a match cannot be found, returns 'UNK' for protein chains and 'N' for + RNA/DNA chains. + + Args: + res_name: A residue_name (monomer code from the CCD). + chain_type: The type of the chain, must be PROTEIN_CHAIN, RNA_CHAIN or + DNA_CHAIN. + + Returns: + An element from PROTEIN_TYPES_WITH_UNKNOWN | RNA_TYPES | DNA_TYPES | {'N'}. + + Raises: + ValueError: If chain_type not in PEPTIDE_CHAIN_TYPES or + {OTHER_CHAIN, RNA_CHAIN, DNA_CHAIN, DNA_RNA_HYBRID_CHAIN}. + """ + # Map to one letter code, then back to common res_names. + one_letter_code = residue_names.letters_three_to_one(res_name, default='X') + + if chain_type in PEPTIDE_CHAIN_TYPES or chain_type == OTHER_CHAIN: + return residue_names.PROTEIN_COMMON_ONE_TO_THREE.get(one_letter_code, 'UNK') + elif chain_type == RNA_CHAIN: + # RNA's CCD monomer code is single-letter. + return ( + one_letter_code if one_letter_code in residue_names.RNA_TYPES else 'N' + ) + elif chain_type == DNA_CHAIN: + return residue_names.DNA_COMMON_ONE_TO_TWO.get(one_letter_code, 'N') + elif chain_type == DNA_RNA_HYBRID_CHAIN: + return ( + res_name + if res_name in residue_names.NUCLEIC_TYPES_WITH_UNKNOWN + else 'N' + ) + else: + raise ValueError( + f'Expected a protein/DNA/RNA chain but got {chain_type}') diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/periodic_table.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/periodic_table.py new file mode 100644 index 0000000000000000000000000000000000000000..7385245ff8b69e548bf7b253355adc01fa056fb2 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/periodic_table.py @@ -0,0 +1,399 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Periodic table of elements.""" + +from collections.abc import Mapping, Sequence +import dataclasses +from typing import Final + +import numpy as np + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class Element: + name: str + number: int + symbol: str + weight: float + + +# Weights taken from rdkit/Code/GraphMol/atomic_data.cpp for compatibility. +# pylint: disable=invalid-name + +# X is an unknown element that can be present in the CCD, +# https://www.rcsb.org/ligand/UNX. 
+X: Final[Element] = Element(name='Unknown', number=0, symbol='X', weight=0.0) +H: Final[Element] = Element( + name='Hydrogen', number=1, symbol='H', weight=1.008) +He: Final[Element] = Element( + name='Helium', number=2, symbol='He', weight=4.003) +Li: Final[Element] = Element( + name='Lithium', number=3, symbol='Li', weight=6.941 +) +Be: Final[Element] = Element( + name='Beryllium', number=4, symbol='Be', weight=9.012 +) +B: Final[Element] = Element(name='Boron', number=5, symbol='B', weight=10.812) +C: Final[Element] = Element(name='Carbon', number=6, symbol='C', weight=12.011) +N: Final[Element] = Element( + name='Nitrogen', number=7, symbol='N', weight=14.007 +) +O: Final[Element] = Element(name='Oxygen', number=8, symbol='O', weight=15.999) +F: Final[Element] = Element( + name='Fluorine', number=9, symbol='F', weight=18.998 +) +Ne: Final[Element] = Element(name='Neon', number=10, symbol='Ne', weight=20.18) +Na: Final[Element] = Element( + name='Sodium', number=11, symbol='Na', weight=22.99 +) +Mg: Final[Element] = Element( + name='Magnesium', number=12, symbol='Mg', weight=24.305 +) +Al: Final[Element] = Element( + name='Aluminium', number=13, symbol='Al', weight=26.982 +) +Si: Final[Element] = Element( + name='Silicon', number=14, symbol='Si', weight=28.086 +) +P: Final[Element] = Element( + name='Phosphorus', number=15, symbol='P', weight=30.974 +) +S: Final[Element] = Element( + name='Sulfur', number=16, symbol='S', weight=32.067) +Cl: Final[Element] = Element( + name='Chlorine', number=17, symbol='Cl', weight=35.453 +) +Ar: Final[Element] = Element( + name='Argon', number=18, symbol='Ar', weight=39.948 +) +K: Final[Element] = Element( + name='Potassium', number=19, symbol='K', weight=39.098 +) +Ca: Final[Element] = Element( + name='Calcium', number=20, symbol='Ca', weight=40.078 +) +Sc: Final[Element] = Element( + name='Scandium', number=21, symbol='Sc', weight=44.956 +) +Ti: Final[Element] = Element( + name='Titanium', number=22, symbol='Ti', weight=47.867 +) +V: Final[Element] = Element( + name='Vanadium', number=23, symbol='V', weight=50.942 +) +Cr: Final[Element] = Element( + name='Chromium', number=24, symbol='Cr', weight=51.996 +) +Mn: Final[Element] = Element( + name='Manganese', number=25, symbol='Mn', weight=54.938 +) +Fe: Final[Element] = Element( + name='Iron', number=26, symbol='Fe', weight=55.845) +Co: Final[Element] = Element( + name='Cobalt', number=27, symbol='Co', weight=58.933 +) +Ni: Final[Element] = Element( + name='Nickel', number=28, symbol='Ni', weight=58.693 +) +Cu: Final[Element] = Element( + name='Copper', number=29, symbol='Cu', weight=63.546 +) +Zn: Final[Element] = Element(name='Zinc', number=30, symbol='Zn', weight=65.39) +Ga: Final[Element] = Element( + name='Gallium', number=31, symbol='Ga', weight=69.723 +) +Ge: Final[Element] = Element( + name='Germanium', number=32, symbol='Ge', weight=72.61 +) +As: Final[Element] = Element( + name='Arsenic', number=33, symbol='As', weight=74.922 +) +Se: Final[Element] = Element( + name='Selenium', number=34, symbol='Se', weight=78.96 +) +Br: Final[Element] = Element( + name='Bromine', number=35, symbol='Br', weight=79.904 +) +Kr: Final[Element] = Element( + name='Krypton', number=36, symbol='Kr', weight=83.8 +) +Rb: Final[Element] = Element( + name='Rubidium', number=37, symbol='Rb', weight=85.468 +) +Sr: Final[Element] = Element( + name='Strontium', number=38, symbol='Sr', weight=87.62 +) +Y: Final[Element] = Element( + name='Yttrium', number=39, symbol='Y', weight=88.906 +) +Zr: Final[Element] = Element( + 
name='Zirconium', number=40, symbol='Zr', weight=91.224 +) +Nb: Final[Element] = Element( + name='Niobium', number=41, symbol='Nb', weight=92.906 +) +Mo: Final[Element] = Element( + name='Molybdenum', number=42, symbol='Mo', weight=95.94 +) +Tc: Final[Element] = Element( + name='Technetium', number=43, symbol='Tc', weight=98 +) +Ru: Final[Element] = Element( + name='Ruthenium', number=44, symbol='Ru', weight=101.07 +) +Rh: Final[Element] = Element( + name='Rhodium', number=45, symbol='Rh', weight=102.906 +) +Pd: Final[Element] = Element( + name='Palladium', number=46, symbol='Pd', weight=106.42 +) +Ag: Final[Element] = Element( + name='Silver', number=47, symbol='Ag', weight=107.868 +) +Cd: Final[Element] = Element( + name='Cadmium', number=48, symbol='Cd', weight=112.412 +) +In: Final[Element] = Element( + name='Indium', number=49, symbol='In', weight=114.818 +) +Sn: Final[Element] = Element( + name='Tin', number=50, symbol='Sn', weight=118.711) +Sb: Final[Element] = Element( + name='Antimony', number=51, symbol='Sb', weight=121.76 +) +Te: Final[Element] = Element( + name='Tellurium', number=52, symbol='Te', weight=127.6 +) +I: Final[Element] = Element( + name='Iodine', number=53, symbol='I', weight=126.904 +) +Xe: Final[Element] = Element( + name='Xenon', number=54, symbol='Xe', weight=131.29 +) +Cs: Final[Element] = Element( + name='Caesium', number=55, symbol='Cs', weight=132.905 +) +Ba: Final[Element] = Element( + name='Barium', number=56, symbol='Ba', weight=137.328 +) +La: Final[Element] = Element( + name='Lanthanum', number=57, symbol='La', weight=138.906 +) +Ce: Final[Element] = Element( + name='Cerium', number=58, symbol='Ce', weight=140.116 +) +Pr: Final[Element] = Element( + name='Praseodymium', number=59, symbol='Pr', weight=140.908 +) +Nd: Final[Element] = Element( + name='Neodymium', number=60, symbol='Nd', weight=144.24 +) +Pm: Final[Element] = Element( + name='Promethium', number=61, symbol='Pm', weight=145 +) +Sm: Final[Element] = Element( + name='Samarium', number=62, symbol='Sm', weight=150.36 +) +Eu: Final[Element] = Element( + name='Europium', number=63, symbol='Eu', weight=151.964 +) +Gd: Final[Element] = Element( + name='Gadolinium', number=64, symbol='Gd', weight=157.25 +) +Tb: Final[Element] = Element( + name='Terbium', number=65, symbol='Tb', weight=158.925 +) +Dy: Final[Element] = Element( + name='Dysprosium', number=66, symbol='Dy', weight=162.5 +) +Ho: Final[Element] = Element( + name='Holmium', number=67, symbol='Ho', weight=164.93 +) +Er: Final[Element] = Element( + name='Erbium', number=68, symbol='Er', weight=167.26 +) +Tm: Final[Element] = Element( + name='Thulium', number=69, symbol='Tm', weight=168.934 +) +Yb: Final[Element] = Element( + name='Ytterbium', number=70, symbol='Yb', weight=173.04 +) +Lu: Final[Element] = Element( + name='Lutetium', number=71, symbol='Lu', weight=174.967 +) +Hf: Final[Element] = Element( + name='Hafnium', number=72, symbol='Hf', weight=178.49 +) +Ta: Final[Element] = Element( + name='Tantalum', number=73, symbol='Ta', weight=180.948 +) +W: Final[Element] = Element( + name='Tungsten', number=74, symbol='W', weight=183.84 +) +Re: Final[Element] = Element( + name='Rhenium', number=75, symbol='Re', weight=186.207 +) +Os: Final[Element] = Element( + name='Osmium', number=76, symbol='Os', weight=190.23 +) +Ir: Final[Element] = Element( + name='Iridium', number=77, symbol='Ir', weight=192.217 +) +Pt: Final[Element] = Element( + name='Platinum', number=78, symbol='Pt', weight=195.078 +) +Au: Final[Element] = Element( + name='Gold', 
number=79, symbol='Au', weight=196.967 +) +Hg: Final[Element] = Element( + name='Mercury', number=80, symbol='Hg', weight=200.59 +) +Tl: Final[Element] = Element( + name='Thallium', number=81, symbol='Tl', weight=204.383 +) +Pb: Final[Element] = Element(name='Lead', number=82, symbol='Pb', weight=207.2) +Bi: Final[Element] = Element( + name='Bismuth', number=83, symbol='Bi', weight=208.98 +) +Po: Final[Element] = Element( + name='Polonium', number=84, symbol='Po', weight=209 +) +At: Final[Element] = Element( + name='Astatine', number=85, symbol='At', weight=210 +) +Rn: Final[Element] = Element(name='Radon', number=86, symbol='Rn', weight=222) +Fr: Final[Element] = Element( + name='Francium', number=87, symbol='Fr', weight=223 +) +Ra: Final[Element] = Element(name='Radium', number=88, symbol='Ra', weight=226) +Ac: Final[Element] = Element( + name='Actinium', number=89, symbol='Ac', weight=227 +) +Th: Final[Element] = Element( + name='Thorium', number=90, symbol='Th', weight=232.038 +) +Pa: Final[Element] = Element( + name='Protactinium', number=91, symbol='Pa', weight=231.036 +) +U: Final[Element] = Element( + name='Uranium', number=92, symbol='U', weight=238.029 +) +Np: Final[Element] = Element( + name='Neptunium', number=93, symbol='Np', weight=237 +) +Pu: Final[Element] = Element( + name='Plutonium', number=94, symbol='Pu', weight=244 +) +Am: Final[Element] = Element( + name='Americium', number=95, symbol='Am', weight=243 +) +Cm: Final[Element] = Element(name='Curium', number=96, symbol='Cm', weight=247) +Bk: Final[Element] = Element( + name='Berkelium', number=97, symbol='Bk', weight=247 +) +Cf: Final[Element] = Element( + name='Californium', number=98, symbol='Cf', weight=251 +) +Es: Final[Element] = Element( + name='Einsteinium', number=99, symbol='Es', weight=252 +) +Fm: Final[Element] = Element( + name='Fermium', number=100, symbol='Fm', weight=257 +) +Md: Final[Element] = Element( + name='Mendelevium', number=101, symbol='Md', weight=258 +) +No: Final[Element] = Element( + name='Nobelium', number=102, symbol='No', weight=259 +) +Lr: Final[Element] = Element( + name='Lawrencium', number=103, symbol='Lr', weight=262 +) +Rf: Final[Element] = Element( + name='Rutherfordium', number=104, symbol='Rf', weight=267 +) +Db: Final[Element] = Element( + name='Dubnium', number=105, symbol='Db', weight=268 +) +Sg: Final[Element] = Element( + name='Seaborgium', number=106, symbol='Sg', weight=269 +) +Bh: Final[Element] = Element( + name='Bohrium', number=107, symbol='Bh', weight=270 +) +Hs: Final[Element] = Element( + name='Hassium', number=108, symbol='Hs', weight=269 +) +Mt: Final[Element] = Element( + name='Meitnerium', number=109, symbol='Mt', weight=278 +) +Ds: Final[Element] = Element( + name='Darmstadtium', number=110, symbol='Ds', weight=281 +) +Rg: Final[Element] = Element( + name='Roentgenium', number=111, symbol='Rg', weight=281 +) +Cn: Final[Element] = Element( + name='Copernicium', number=112, symbol='Cn', weight=285 +) +Nh: Final[Element] = Element( + name='Nihonium', number=113, symbol='Nh', weight=284 +) +Fl: Final[Element] = Element( + name='Flerovium', number=114, symbol='Fl', weight=289 +) +Mc: Final[Element] = Element( + name='Moscovium', number=115, symbol='Mc', weight=288 +) +Lv: Final[Element] = Element( + name='Livermorium', number=116, symbol='Lv', weight=293 +) +Ts: Final[Element] = Element( + name='Tennessine', number=117, symbol='Ts', weight=292 +) +Og: Final[Element] = Element( + name='Oganesson', number=118, symbol='Og', weight=294 +) +# pylint: enable=invalid-name + 
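+# --- Illustrative sketch (an editor's addition, not part of the upstream
+# AlphaFold3 API): a minimal example of how the lookup tables assembled at the
+# end of this module (ATOMIC_NUMBER, ATOMIC_WEIGHT) are typically consumed,
+# e.g. to compute an approximate molecular weight from mmCIF type_symbol
+# values. The helper name is hypothetical.
+def _example_molecular_weight(symbols: Sequence[str]) -> float:
+  """Sums atomic weights for element symbols; unknown symbols count as 0.0."""
+  # ATOMIC_NUMBER and ATOMIC_WEIGHT are defined below; they are resolved when
+  # this function is called, not when it is defined.
+  numbers = [ATOMIC_NUMBER.get(symbol, X.number) for symbol in symbols]
+  return float(sum(ATOMIC_WEIGHT[number] for number in numbers))
+
+# e.g. _example_molecular_weight(['C', 'O', 'O']) is roughly 44.0 (CO2).
+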
+# fmt: off +# Lanthanides +_L: Final[Sequence[Element]] = ( + La, Ce, Pr, Nd, Pm, Sm, Eu, Gd, Tb, Dy, Ho, Er, Tm, Yb, Lu) +# Actinides +_A: Final[Sequence[Element]] = ( + Ac, Th, Pa, U, Np, Pu, Am, Cm, Bk, Cf, Es, Fm, Md, No, Lr) + +# pylint: disable=bad-whitespace +PERIODIC_TABLE: Final[Sequence[Element]] = ( + X, # Unknown + H, He, + Li, Be, B, C, N, O, F, Ne, + Na, Mg, Al, Si, P, S, Cl, Ar, + K, Ca, Sc, Ti, V, Cr, Mn, Fe, Co, Ni, Cu, Zn, Ga, Ge, As, Se, Br, Kr, + Rb, Sr, Y, Zr, Nb, Mo, Tc, Ru, Rh, Pd, Ag, Cd, In, Sn, Sb, Te, I, Xe, + Cs, Ba, *_L, Hf, Ta, W, Re, Os, Ir, Pt, Au, Hg, Tl, Pb, Bi, Po, At, Rn, + Fr, Ra, *_A, Rf, Db, Sg, Bh, Hs, Mt, Ds, Rg, Cn, Nh, Fl, Mc, Lv, Ts, Og +) +# pylint: enable=bad-whitespace +# fmt: on +ATOMIC_SYMBOL: Mapping[int, str] = {e.number: e.symbol for e in PERIODIC_TABLE} +ATOMIC_NUMBER = {e.symbol: e.number for e in PERIODIC_TABLE} +# Add Deuterium as previous table contained it. +ATOMIC_NUMBER['D'] = 1 + +ATOMIC_NUMBER: Mapping[str, int] = ATOMIC_NUMBER +ATOMIC_WEIGHT: np.ndarray = np.zeros(len(PERIODIC_TABLE), dtype=np.float64) + +for e in PERIODIC_TABLE: + ATOMIC_WEIGHT[e.number] = e.weight +ATOMIC_WEIGHT.setflags(write=False) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/residue_names.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/residue_names.py new file mode 100644 index 0000000000000000000000000000000000000000..40d42587c7f333a50b4aa38188a849ea823d2227 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/residue_names.py @@ -0,0 +1,421 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Constants associated with residue names.""" + +from collections.abc import Mapping +import functools +import sys + +# pyformat: disable +# common_typos_disable +CCD_NAME_TO_ONE_LETTER: Mapping[str, str] = { + '00C': 'C', '01W': 'X', '02K': 'A', '03Y': 'C', '07O': 'C', '08P': 'C', + '0A0': 'D', '0A1': 'Y', '0A2': 'K', '0A8': 'C', '0AA': 'V', '0AB': 'V', + '0AC': 'G', '0AD': 'G', '0AF': 'W', '0AG': 'L', '0AH': 'S', '0AK': 'D', + '0AM': 'A', '0AP': 'C', '0AU': 'U', '0AV': 'A', '0AZ': 'P', '0BN': 'F', + '0C': 'C', '0CS': 'A', '0DC': 'C', '0DG': 'G', '0DT': 'T', '0FL': 'A', + '0G': 'G', '0NC': 'A', '0SP': 'A', '0U': 'U', '10C': 'C', '125': 'U', + '126': 'U', '127': 'U', '128': 'N', '12A': 'A', '143': 'C', '193': 'X', + '1AP': 'A', '1MA': 'A', '1MG': 'G', '1PA': 'F', '1PI': 'A', '1PR': 'N', + '1SC': 'C', '1TQ': 'W', '1TY': 'Y', '1X6': 'S', '200': 'F', '23F': 'F', + '23S': 'X', '26B': 'T', '2AD': 'X', '2AG': 'A', '2AO': 'X', '2AR': 'A', + '2AS': 'X', '2AT': 'T', '2AU': 'U', '2BD': 'I', '2BT': 'T', '2BU': 'A', + '2CO': 'C', '2DA': 'A', '2DF': 'N', '2DM': 'N', '2DO': 'X', '2DT': 'T', + '2EG': 'G', '2FE': 'N', '2FI': 'N', '2FM': 'M', '2GT': 'T', '2HF': 'H', + '2LU': 'L', '2MA': 'A', '2MG': 'G', '2ML': 'L', '2MR': 'R', '2MT': 'P', + '2MU': 'U', '2NT': 'T', '2OM': 'U', '2OT': 'T', '2PI': 'X', '2PR': 'G', + '2SA': 'N', '2SI': 'X', '2ST': 'T', '2TL': 'T', '2TY': 'Y', '2VA': 'V', + '2XA': 'C', '32S': 'X', '32T': 'X', '3AH': 'H', '3AR': 'X', '3CF': 'F', + '3DA': 'A', '3DR': 'N', '3GA': 'A', '3MD': 'D', '3ME': 'U', '3NF': 'Y', + '3QN': 'K', '3TY': 'X', '3XH': 'G', '4AC': 'N', '4BF': 'Y', '4CF': 'F', + '4CY': 'M', '4DP': 'W', '4FB': 'P', '4FW': 'W', '4HT': 'W', '4IN': 'W', + '4MF': 'N', '4MM': 'X', '4OC': 'C', '4PC': 'C', '4PD': 'C', '4PE': 'C', + '4PH': 'F', '4SC': 'C', '4SU': 'U', '4TA': 'N', '4U7': 'A', '56A': 'H', + '5AA': 'A', '5AB': 'A', '5AT': 'T', '5BU': 'U', '5CG': 'G', '5CM': 'C', + '5CS': 'C', '5FA': 'A', '5FC': 'C', '5FU': 'U', '5HP': 'E', '5HT': 'T', + '5HU': 'U', '5IC': 'C', '5IT': 'T', '5IU': 'U', '5MC': 'C', '5MD': 'N', + '5MU': 'U', '5NC': 'C', '5PC': 'C', '5PY': 'T', '5SE': 'U', '64T': 'T', + '6CL': 'K', '6CT': 'T', '6CW': 'W', '6HA': 'A', '6HC': 'C', '6HG': 'G', + '6HN': 'K', '6HT': 'T', '6IA': 'A', '6MA': 'A', '6MC': 'A', '6MI': 'N', + '6MT': 'A', '6MZ': 'N', '6OG': 'G', '70U': 'U', '7DA': 'A', '7GU': 'G', + '7JA': 'I', '7MG': 'G', '8AN': 'A', '8FG': 'G', '8MG': 'G', '8OG': 'G', + '9NE': 'E', '9NF': 'F', '9NR': 'R', '9NV': 'V', 'A': 'A', 'A1P': 'N', + 'A23': 'A', 'A2L': 'A', 'A2M': 'A', 'A34': 'A', 'A35': 'A', 'A38': 'A', + 'A39': 'A', 'A3A': 'A', 'A3P': 'A', 'A40': 'A', 'A43': 'A', 'A44': 'A', + 'A47': 'A', 'A5L': 'A', 'A5M': 'C', 'A5N': 'N', 'A5O': 'A', 'A66': 'X', + 'AA3': 'A', 'AA4': 'A', 'AAR': 'R', 'AB7': 'X', 'ABA': 'A', 'ABR': 'A', + 'ABS': 'A', 'ABT': 'N', 'ACB': 'D', 'ACL': 'R', 'AD2': 'A', 'ADD': 'X', + 'ADX': 'N', 'AEA': 'X', 'AEI': 'D', 'AET': 'A', 'AFA': 'N', 'AFF': 'N', + 'AFG': 'G', 'AGM': 'R', 'AGT': 'C', 'AHB': 'N', 'AHH': 'X', 'AHO': 'A', + 'AHP': 'A', 'AHS': 'X', 'AHT': 'X', 'AIB': 'A', 'AKL': 'D', 'AKZ': 'D', + 'ALA': 'A', 'ALC': 'A', 'ALM': 'A', 'ALN': 'A', 'ALO': 'T', 'ALQ': 'X', + 'ALS': 'A', 'ALT': 'A', 'ALV': 'A', 'ALY': 'K', 'AN8': 'A', 'AP7': 'A', + 'APE': 'X', 'APH': 'A', 'API': 'K', 'APK': 'K', 'APM': 'X', 'APP': 'X', + 'AR2': 'R', 'AR4': 'E', 
'AR7': 'R', 'ARG': 'R', 'ARM': 'R', 'ARO': 'R', + 'ARV': 'X', 'AS': 'A', 'AS2': 'D', 'AS9': 'X', 'ASA': 'D', 'ASB': 'D', + 'ASI': 'D', 'ASK': 'D', 'ASL': 'D', 'ASM': 'X', 'ASN': 'N', 'ASP': 'D', + 'ASQ': 'D', 'ASU': 'N', 'ASX': 'B', 'ATD': 'T', 'ATL': 'T', 'ATM': 'T', + 'AVC': 'A', 'AVN': 'X', 'AYA': 'A', 'AZK': 'K', 'AZS': 'S', 'AZY': 'Y', + 'B1F': 'F', 'B1P': 'N', 'B2A': 'A', 'B2F': 'F', 'B2I': 'I', 'B2V': 'V', + 'B3A': 'A', 'B3D': 'D', 'B3E': 'E', 'B3K': 'K', 'B3L': 'X', 'B3M': 'X', + 'B3Q': 'X', 'B3S': 'S', 'B3T': 'X', 'B3U': 'H', 'B3X': 'N', 'B3Y': 'Y', + 'BB6': 'C', 'BB7': 'C', 'BB8': 'F', 'BB9': 'C', 'BBC': 'C', 'BCS': 'C', + 'BE2': 'X', 'BFD': 'D', 'BG1': 'S', 'BGM': 'G', 'BH2': 'D', 'BHD': 'D', + 'BIF': 'F', 'BIL': 'X', 'BIU': 'I', 'BJH': 'X', 'BLE': 'L', 'BLY': 'K', + 'BMP': 'N', 'BMT': 'T', 'BNN': 'F', 'BNO': 'X', 'BOE': 'T', 'BOR': 'R', + 'BPE': 'C', 'BRU': 'U', 'BSE': 'S', 'BT5': 'N', 'BTA': 'L', 'BTC': 'C', + 'BTR': 'W', 'BUC': 'C', 'BUG': 'V', 'BVP': 'U', 'BZG': 'N', 'C': 'C', + 'C1X': 'K', 'C25': 'C', 'C2L': 'C', 'C2S': 'C', 'C31': 'C', 'C32': 'C', + 'C34': 'C', 'C36': 'C', 'C37': 'C', 'C38': 'C', 'C3Y': 'C', 'C42': 'C', + 'C43': 'C', 'C45': 'C', 'C46': 'C', 'C49': 'C', 'C4R': 'C', 'C4S': 'C', + 'C5C': 'C', 'C66': 'X', 'C6C': 'C', 'CAF': 'C', 'CAL': 'X', 'CAR': 'C', + 'CAS': 'C', 'CAV': 'X', 'CAY': 'C', 'CB2': 'C', 'CBR': 'C', 'CBV': 'C', + 'CCC': 'C', 'CCL': 'K', 'CCS': 'C', 'CDE': 'X', 'CDV': 'X', 'CDW': 'C', + 'CEA': 'C', 'CFL': 'C', 'CG1': 'G', 'CGA': 'E', 'CGU': 'E', 'CH': 'C', + 'CHF': 'X', 'CHG': 'X', 'CHP': 'G', 'CHS': 'X', 'CIR': 'R', 'CLE': 'L', + 'CLG': 'K', 'CLH': 'K', 'CM0': 'N', 'CME': 'C', 'CMH': 'C', 'CML': 'C', + 'CMR': 'C', 'CMT': 'C', 'CNU': 'U', 'CP1': 'C', 'CPC': 'X', 'CPI': 'X', + 'CR5': 'G', 'CS0': 'C', 'CS1': 'C', 'CS3': 'C', 'CS4': 'C', 'CS8': 'N', + 'CSA': 'C', 'CSB': 'C', 'CSD': 'C', 'CSE': 'C', 'CSF': 'C', 'CSI': 'G', + 'CSJ': 'C', 'CSL': 'C', 'CSO': 'C', 'CSP': 'C', 'CSR': 'C', 'CSS': 'C', + 'CSU': 'C', 'CSW': 'C', 'CSX': 'C', 'CSZ': 'C', 'CTE': 'W', 'CTG': 'T', + 'CTH': 'T', 'CUC': 'X', 'CWR': 'S', 'CXM': 'M', 'CY0': 'C', 'CY1': 'C', + 'CY3': 'C', 'CY4': 'C', 'CYA': 'C', 'CYD': 'C', 'CYF': 'C', 'CYG': 'C', + 'CYJ': 'X', 'CYM': 'C', 'CYQ': 'C', 'CYR': 'C', 'CYS': 'C', 'CZ2': 'C', + 'CZZ': 'C', 'D11': 'T', 'D1P': 'N', 'D3': 'N', 'D33': 'N', 'D3P': 'G', + 'D3T': 'T', 'D4M': 'T', 'D4P': 'X', 'DA': 'A', 'DA2': 'X', 'DAB': 'A', + 'DAH': 'F', 'DAL': 'A', 'DAR': 'R', 'DAS': 'D', 'DBB': 'T', 'DBM': 'N', + 'DBS': 'S', 'DBU': 'T', 'DBY': 'Y', 'DBZ': 'A', 'DC': 'C', 'DC2': 'C', + 'DCG': 'G', 'DCI': 'X', 'DCL': 'X', 'DCT': 'C', 'DCY': 'C', 'DDE': 'H', + 'DDG': 'G', 'DDN': 'U', 'DDX': 'N', 'DFC': 'C', 'DFG': 'G', 'DFI': 'X', + 'DFO': 'X', 'DFT': 'N', 'DG': 'G', 'DGH': 'G', 'DGI': 'G', 'DGL': 'E', + 'DGN': 'Q', 'DHA': 'S', 'DHI': 'H', 'DHL': 'X', 'DHN': 'V', 'DHP': 'X', + 'DHU': 'U', 'DHV': 'V', 'DI': 'I', 'DIL': 'I', 'DIR': 'R', 'DIV': 'V', + 'DLE': 'L', 'DLS': 'K', 'DLY': 'K', 'DM0': 'K', 'DMH': 'N', 'DMK': 'D', + 'DMT': 'X', 'DN': 'N', 'DNE': 'L', 'DNG': 'L', 'DNL': 'K', 'DNM': 'L', + 'DNP': 'A', 'DNR': 'C', 'DNS': 'K', 'DOA': 'X', 'DOC': 'C', 'DOH': 'D', + 'DON': 'L', 'DPB': 'T', 'DPH': 'F', 'DPL': 'P', 'DPP': 'A', 'DPQ': 'Y', + 'DPR': 'P', 'DPY': 'N', 'DRM': 'U', 'DRP': 'N', 'DRT': 'T', 'DRZ': 'N', + 'DSE': 'S', 'DSG': 'N', 'DSN': 'S', 'DSP': 'D', 'DT': 'T', 'DTH': 'T', + 'DTR': 'W', 'DTY': 'Y', 'DU': 'U', 'DVA': 'V', 'DXD': 'N', 'DXN': 'N', + 'DYS': 'C', 'DZM': 'A', 'E': 'A', 'E1X': 'A', 'ECC': 'Q', 'EDA': 'A', + 'EFC': 'C', 'EHP': 'F', 'EIT': 'T', 
'ENP': 'N', 'ESB': 'Y', 'ESC': 'M', + 'EXB': 'X', 'EXY': 'L', 'EY5': 'N', 'EYS': 'X', 'F2F': 'F', 'FA2': 'A', + 'FA5': 'N', 'FAG': 'N', 'FAI': 'N', 'FB5': 'A', 'FB6': 'A', 'FCL': 'F', + 'FFD': 'N', 'FGA': 'E', 'FGL': 'G', 'FGP': 'S', 'FHL': 'X', 'FHO': 'K', + 'FHU': 'U', 'FLA': 'A', 'FLE': 'L', 'FLT': 'Y', 'FME': 'M', 'FMG': 'G', + 'FMU': 'N', 'FOE': 'C', 'FOX': 'G', 'FP9': 'P', 'FPA': 'F', 'FRD': 'X', + 'FT6': 'W', 'FTR': 'W', 'FTY': 'Y', 'FVA': 'V', 'FZN': 'K', 'G': 'G', + 'G25': 'G', 'G2L': 'G', 'G2S': 'G', 'G31': 'G', 'G32': 'G', 'G33': 'G', + 'G36': 'G', 'G38': 'G', 'G42': 'G', 'G46': 'G', 'G47': 'G', 'G48': 'G', + 'G49': 'G', 'G4P': 'N', 'G7M': 'G', 'GAO': 'G', 'GAU': 'E', 'GCK': 'C', + 'GCM': 'X', 'GDP': 'G', 'GDR': 'G', 'GFL': 'G', 'GGL': 'E', 'GH3': 'G', + 'GHG': 'Q', 'GHP': 'G', 'GL3': 'G', 'GLH': 'Q', 'GLJ': 'E', 'GLK': 'E', + 'GLM': 'X', 'GLN': 'Q', 'GLQ': 'E', 'GLU': 'E', 'GLX': 'Z', 'GLY': 'G', + 'GLZ': 'G', 'GMA': 'E', 'GMS': 'G', 'GMU': 'U', 'GN7': 'G', 'GND': 'X', + 'GNE': 'N', 'GOM': 'G', 'GPL': 'K', 'GS': 'G', 'GSC': 'G', 'GSR': 'G', + 'GSS': 'G', 'GSU': 'E', 'GT9': 'C', 'GTP': 'G', 'GVL': 'X', 'H2U': 'U', + 'H5M': 'P', 'HAC': 'A', 'HAR': 'R', 'HBN': 'H', 'HCS': 'X', 'HDP': 'U', + 'HEU': 'U', 'HFA': 'X', 'HGL': 'X', 'HHI': 'H', 'HIA': 'H', 'HIC': 'H', + 'HIP': 'H', 'HIQ': 'H', 'HIS': 'H', 'HL2': 'L', 'HLU': 'L', 'HMR': 'R', + 'HOL': 'N', 'HPC': 'F', 'HPE': 'F', 'HPH': 'F', 'HPQ': 'F', 'HQA': 'A', + 'HRG': 'R', 'HRP': 'W', 'HS8': 'H', 'HS9': 'H', 'HSE': 'S', 'HSL': 'S', + 'HSO': 'H', 'HTI': 'C', 'HTN': 'N', 'HTR': 'W', 'HV5': 'A', 'HVA': 'V', + 'HY3': 'P', 'HYP': 'P', 'HZP': 'P', 'I': 'I', 'I2M': 'I', 'I58': 'K', + 'I5C': 'C', 'IAM': 'A', 'IAR': 'R', 'IAS': 'D', 'IC': 'C', 'IEL': 'K', + 'IG': 'G', 'IGL': 'G', 'IGU': 'G', 'IIL': 'I', 'ILE': 'I', 'ILG': 'E', + 'ILX': 'I', 'IMC': 'C', 'IML': 'I', 'IOY': 'F', 'IPG': 'G', 'IPN': 'N', + 'IRN': 'N', 'IT1': 'K', 'IU': 'U', 'IYR': 'Y', 'IYT': 'T', 'IZO': 'M', + 'JJJ': 'C', 'JJK': 'C', 'JJL': 'C', 'JW5': 'N', 'K1R': 'C', 'KAG': 'G', + 'KCX': 'K', 'KGC': 'K', 'KNB': 'A', 'KOR': 'M', 'KPI': 'K', 'KST': 'K', + 'KYQ': 'K', 'L2A': 'X', 'LA2': 'K', 'LAA': 'D', 'LAL': 'A', 'LBY': 'K', + 'LC': 'C', 'LCA': 'A', 'LCC': 'N', 'LCG': 'G', 'LCH': 'N', 'LCK': 'K', + 'LCX': 'K', 'LDH': 'K', 'LED': 'L', 'LEF': 'L', 'LEH': 'L', 'LEI': 'V', + 'LEM': 'L', 'LEN': 'L', 'LET': 'X', 'LEU': 'L', 'LEX': 'L', 'LG': 'G', + 'LGP': 'G', 'LHC': 'X', 'LHU': 'U', 'LKC': 'N', 'LLP': 'K', 'LLY': 'K', + 'LME': 'E', 'LMF': 'K', 'LMQ': 'Q', 'LMS': 'N', 'LP6': 'K', 'LPD': 'P', + 'LPG': 'G', 'LPL': 'X', 'LPS': 'S', 'LSO': 'X', 'LTA': 'X', 'LTR': 'W', + 'LVG': 'G', 'LVN': 'V', 'LYF': 'K', 'LYK': 'K', 'LYM': 'K', 'LYN': 'K', + 'LYR': 'K', 'LYS': 'K', 'LYX': 'K', 'LYZ': 'K', 'M0H': 'C', 'M1G': 'G', + 'M2G': 'G', 'M2L': 'K', 'M2S': 'M', 'M30': 'G', 'M3L': 'K', 'M5M': 'C', + 'MA': 'A', 'MA6': 'A', 'MA7': 'A', 'MAA': 'A', 'MAD': 'A', 'MAI': 'R', + 'MBQ': 'Y', 'MBZ': 'N', 'MC1': 'S', 'MCG': 'X', 'MCL': 'K', 'MCS': 'C', + 'MCY': 'C', 'MD3': 'C', 'MD6': 'G', 'MDH': 'X', 'MDR': 'N', 'MEA': 'F', + 'MED': 'M', 'MEG': 'E', 'MEN': 'N', 'MEP': 'U', 'MEQ': 'Q', 'MET': 'M', + 'MEU': 'G', 'MF3': 'X', 'MG1': 'G', 'MGG': 'R', 'MGN': 'Q', 'MGQ': 'A', + 'MGV': 'G', 'MGY': 'G', 'MHL': 'L', 'MHO': 'M', 'MHS': 'H', 'MIA': 'A', + 'MIS': 'S', 'MK8': 'L', 'ML3': 'K', 'MLE': 'L', 'MLL': 'L', 'MLY': 'K', + 'MLZ': 'K', 'MME': 'M', 'MMO': 'R', 'MMT': 'T', 'MND': 'N', 'MNL': 'L', + 'MNU': 'U', 'MNV': 'V', 'MOD': 'X', 'MP8': 'P', 'MPH': 'X', 'MPJ': 'X', + 'MPQ': 'G', 'MRG': 'G', 'MSA': 'G', 'MSE': 'M', 
'MSL': 'M', 'MSO': 'M', + 'MSP': 'X', 'MT2': 'M', 'MTR': 'T', 'MTU': 'A', 'MTY': 'Y', 'MVA': 'V', + 'N': 'N', 'N10': 'S', 'N2C': 'X', 'N5I': 'N', 'N5M': 'C', 'N6G': 'G', + 'N7P': 'P', 'NA8': 'A', 'NAL': 'A', 'NAM': 'A', 'NB8': 'N', 'NBQ': 'Y', + 'NC1': 'S', 'NCB': 'A', 'NCX': 'N', 'NCY': 'X', 'NDF': 'F', 'NDN': 'U', + 'NEM': 'H', 'NEP': 'H', 'NF2': 'N', 'NFA': 'F', 'NHL': 'E', 'NIT': 'X', + 'NIY': 'Y', 'NLE': 'L', 'NLN': 'L', 'NLO': 'L', 'NLP': 'L', 'NLQ': 'Q', + 'NMC': 'G', 'NMM': 'R', 'NMS': 'T', 'NMT': 'T', 'NNH': 'R', 'NP3': 'N', + 'NPH': 'C', 'NPI': 'A', 'NSK': 'X', 'NTY': 'Y', 'NVA': 'V', 'NYM': 'N', + 'NYS': 'C', 'NZH': 'H', 'O12': 'X', 'O2C': 'N', 'O2G': 'G', 'OAD': 'N', + 'OAS': 'S', 'OBF': 'X', 'OBS': 'X', 'OCS': 'C', 'OCY': 'C', 'ODP': 'N', + 'OHI': 'H', 'OHS': 'D', 'OIC': 'X', 'OIP': 'I', 'OLE': 'X', 'OLT': 'T', + 'OLZ': 'S', 'OMC': 'C', 'OMG': 'G', 'OMT': 'M', 'OMU': 'U', 'ONE': 'U', + 'ONH': 'A', 'ONL': 'X', 'OPR': 'R', 'ORN': 'A', 'ORQ': 'R', 'OSE': 'S', + 'OTB': 'X', 'OTH': 'T', 'OTY': 'Y', 'OXX': 'D', 'P': 'G', 'P1L': 'C', + 'P1P': 'N', 'P2T': 'T', 'P2U': 'U', 'P2Y': 'P', 'P5P': 'A', 'PAQ': 'Y', + 'PAS': 'D', 'PAT': 'W', 'PAU': 'A', 'PBB': 'C', 'PBF': 'F', 'PBT': 'N', + 'PCA': 'E', 'PCC': 'P', 'PCE': 'X', 'PCS': 'F', 'PDL': 'X', 'PDU': 'U', + 'PEC': 'C', 'PF5': 'F', 'PFF': 'F', 'PFX': 'X', 'PG1': 'S', 'PG7': 'G', + 'PG9': 'G', 'PGL': 'X', 'PGN': 'G', 'PGP': 'G', 'PGY': 'G', 'PHA': 'F', + 'PHD': 'D', 'PHE': 'F', 'PHI': 'F', 'PHL': 'F', 'PHM': 'F', 'PIV': 'X', + 'PLE': 'L', 'PM3': 'F', 'PMT': 'C', 'POM': 'P', 'PPN': 'F', 'PPU': 'A', + 'PPW': 'G', 'PQ1': 'N', 'PR3': 'C', 'PR5': 'A', 'PR9': 'P', 'PRN': 'A', + 'PRO': 'P', 'PRS': 'P', 'PSA': 'F', 'PSH': 'H', 'PST': 'T', 'PSU': 'U', + 'PSW': 'C', 'PTA': 'X', 'PTH': 'Y', 'PTM': 'Y', 'PTR': 'Y', 'PU': 'A', + 'PUY': 'N', 'PVH': 'H', 'PVL': 'X', 'PYA': 'A', 'PYO': 'U', 'PYX': 'C', + 'PYY': 'N', 'QMM': 'Q', 'QPA': 'C', 'QPH': 'F', 'QUO': 'G', 'R': 'A', + 'R1A': 'C', 'R4K': 'W', 'RE0': 'W', 'RE3': 'W', 'RIA': 'A', 'RMP': 'A', + 'RON': 'X', 'RT': 'T', 'RTP': 'N', 'S1H': 'S', 'S2C': 'C', 'S2D': 'A', + 'S2M': 'T', 'S2P': 'A', 'S4A': 'A', 'S4C': 'C', 'S4G': 'G', 'S4U': 'U', + 'S6G': 'G', 'SAC': 'S', 'SAH': 'C', 'SAR': 'G', 'SBL': 'S', 'SC': 'C', + 'SCH': 'C', 'SCS': 'C', 'SCY': 'C', 'SD2': 'X', 'SDG': 'G', 'SDP': 'S', + 'SEB': 'S', 'SEC': 'A', 'SEG': 'A', 'SEL': 'S', 'SEM': 'S', 'SEN': 'S', + 'SEP': 'S', 'SER': 'S', 'SET': 'S', 'SGB': 'S', 'SHC': 'C', 'SHP': 'G', + 'SHR': 'K', 'SIB': 'C', 'SLA': 'P', 'SLR': 'P', 'SLZ': 'K', 'SMC': 'C', + 'SME': 'M', 'SMF': 'F', 'SMP': 'A', 'SMT': 'T', 'SNC': 'C', 'SNN': 'N', + 'SOC': 'C', 'SOS': 'N', 'SOY': 'S', 'SPT': 'T', 'SRA': 'A', 'SSU': 'U', + 'STY': 'Y', 'SUB': 'X', 'SUN': 'S', 'SUR': 'U', 'SVA': 'S', 'SVV': 'S', + 'SVW': 'S', 'SVX': 'S', 'SVY': 'S', 'SVZ': 'X', 'SYS': 'C', 'T': 'T', + 'T11': 'F', 'T23': 'T', 'T2S': 'T', 'T2T': 'N', 'T31': 'U', 'T32': 'T', + 'T36': 'T', 'T37': 'T', 'T38': 'T', 'T39': 'T', 'T3P': 'T', 'T41': 'T', + 'T48': 'T', 'T49': 'T', 'T4S': 'T', 'T5O': 'U', 'T5S': 'T', 'T66': 'X', + 'T6A': 'A', 'TA3': 'T', 'TA4': 'X', 'TAF': 'T', 'TAL': 'N', 'TAV': 'D', + 'TBG': 'V', 'TBM': 'T', 'TC1': 'C', 'TCP': 'T', 'TCQ': 'Y', 'TCR': 'W', + 'TCY': 'A', 'TDD': 'L', 'TDY': 'T', 'TFE': 'T', 'TFO': 'A', 'TFQ': 'F', + 'TFT': 'T', 'TGP': 'G', 'TH6': 'T', 'THC': 'T', 'THO': 'X', 'THR': 'T', + 'THX': 'N', 'THZ': 'R', 'TIH': 'A', 'TLB': 'N', 'TLC': 'T', 'TLN': 'U', + 'TMB': 'T', 'TMD': 'T', 'TNB': 'C', 'TNR': 'S', 'TOX': 'W', 'TP1': 'T', + 'TPC': 'C', 'TPG': 'G', 'TPH': 'X', 'TPL': 'W', 'TPO': 'T', 
'TPQ': 'Y', + 'TQI': 'W', 'TQQ': 'W', 'TRF': 'W', 'TRG': 'K', 'TRN': 'W', 'TRO': 'W', + 'TRP': 'W', 'TRQ': 'W', 'TRW': 'W', 'TRX': 'W', 'TS': 'N', 'TST': 'X', + 'TT': 'N', 'TTD': 'T', 'TTI': 'U', 'TTM': 'T', 'TTQ': 'W', 'TTS': 'Y', + 'TY1': 'Y', 'TY2': 'Y', 'TY3': 'Y', 'TY5': 'Y', 'TYB': 'Y', 'TYI': 'Y', + 'TYJ': 'Y', 'TYN': 'Y', 'TYO': 'Y', 'TYQ': 'Y', 'TYR': 'Y', 'TYS': 'Y', + 'TYT': 'Y', 'TYU': 'N', 'TYW': 'Y', 'TYX': 'X', 'TYY': 'Y', 'TZB': 'X', + 'TZO': 'X', 'U': 'U', 'U25': 'U', 'U2L': 'U', 'U2N': 'U', 'U2P': 'U', + 'U31': 'U', 'U33': 'U', 'U34': 'U', 'U36': 'U', 'U37': 'U', 'U8U': 'U', + 'UAR': 'U', 'UCL': 'U', 'UD5': 'U', 'UDP': 'N', 'UFP': 'N', 'UFR': 'U', + 'UFT': 'U', 'UMA': 'A', 'UMP': 'U', 'UMS': 'U', 'UN1': 'X', 'UN2': 'X', + 'UNK': 'X', 'UR3': 'U', 'URD': 'U', 'US1': 'U', 'US2': 'U', 'US3': 'T', + 'US5': 'U', 'USM': 'U', 'VAD': 'V', 'VAF': 'V', 'VAL': 'V', 'VB1': 'K', + 'VDL': 'X', 'VLL': 'X', 'VLM': 'X', 'VMS': 'X', 'VOL': 'X', 'X': 'G', + 'X2W': 'E', 'X4A': 'N', 'XAD': 'A', 'XAE': 'N', 'XAL': 'A', 'XAR': 'N', + 'XCL': 'C', 'XCN': 'C', 'XCP': 'X', 'XCR': 'C', 'XCS': 'N', 'XCT': 'C', + 'XCY': 'C', 'XGA': 'N', 'XGL': 'G', 'XGR': 'G', 'XGU': 'G', 'XPR': 'P', + 'XSN': 'N', 'XTH': 'T', 'XTL': 'T', 'XTR': 'T', 'XTS': 'G', 'XTY': 'N', + 'XUA': 'A', 'XUG': 'G', 'XX1': 'K', 'Y': 'A', 'YCM': 'C', 'YG': 'G', + 'YOF': 'Y', 'YRR': 'N', 'YYG': 'G', 'Z': 'C', 'Z01': 'A', 'ZAD': 'A', + 'ZAL': 'A', 'ZBC': 'C', 'ZBU': 'U', 'ZCL': 'F', 'ZCY': 'C', 'ZDU': 'U', + 'ZFB': 'X', 'ZGU': 'G', 'ZHP': 'N', 'ZTH': 'T', 'ZU0': 'T', 'ZZJ': 'A', +} +# common_typos_enable +# pyformat: enable + + +@functools.lru_cache(maxsize=64) +def letters_three_to_one(restype: str, *, default: str) -> str: + """Returns single letter name if one exists otherwise returns default.""" + return CCD_NAME_TO_ONE_LETTER.get(restype, default) + + +ALA = sys.intern('ALA') +ARG = sys.intern('ARG') +ASN = sys.intern('ASN') +ASP = sys.intern('ASP') +CYS = sys.intern('CYS') +GLN = sys.intern('GLN') +GLU = sys.intern('GLU') +GLY = sys.intern('GLY') +HIS = sys.intern('HIS') +ILE = sys.intern('ILE') +LEU = sys.intern('LEU') +LYS = sys.intern('LYS') +MET = sys.intern('MET') +PHE = sys.intern('PHE') +PRO = sys.intern('PRO') +SER = sys.intern('SER') +THR = sys.intern('THR') +TRP = sys.intern('TRP') +TYR = sys.intern('TYR') +VAL = sys.intern('VAL') +UNK = sys.intern('UNK') +GAP = sys.intern('-') + +# Unknown ligand. +UNL = sys.intern('UNL') + +# Non-standard version of MET (with Se instead of S), but often appears in PDB. +MSE = sys.intern('MSE') + +# 20 standard protein amino acids (no unknown). +PROTEIN_TYPES: tuple[str, ...] = ( + ALA, ARG, ASN, ASP, CYS, GLN, GLU, GLY, HIS, ILE, LEU, LYS, MET, PHE, PRO, + SER, THR, TRP, TYR, VAL, +) # pyformat: disable + +# 20 standard protein amino acids plus the unknown (UNK) amino acid. +PROTEIN_TYPES_WITH_UNKNOWN: tuple[str, ...] = PROTEIN_TYPES + (UNK,) + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +# For legacy reasons this only refers to protein residues. + +PROTEIN_TYPES_ONE_LETTER: tuple[str, ...] = ( + 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', + 'S', 'T', 'W', 'Y', 'V', +) # pyformat: disable + +PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN: tuple[str, ...] = ( + PROTEIN_TYPES_ONE_LETTER + ('X',) +) +PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP: tuple[str, ...] 
= ( + PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN + (GAP,) +) + +PROTEIN_TYPES_ONE_LETTER_TO_INT: Mapping[str, int] = { + r: i for i, r in enumerate(PROTEIN_TYPES_ONE_LETTER) +} +PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_TO_INT: Mapping[str, int] = { + r: i for i, r in enumerate(PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN) +} + +PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP_TO_INT: Mapping[str, int] = { + r: i for i, r in enumerate(PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP) +} + + +PROTEIN_COMMON_ONE_TO_THREE: Mapping[str, str] = { + 'A': ALA, + 'R': ARG, + 'N': ASN, + 'D': ASP, + 'C': CYS, + 'Q': GLN, + 'E': GLU, + 'G': GLY, + 'H': HIS, + 'I': ILE, + 'L': LEU, + 'K': LYS, + 'M': MET, + 'F': PHE, + 'P': PRO, + 'S': SER, + 'T': THR, + 'W': TRP, + 'Y': TYR, + 'V': VAL, +} + +PROTEIN_COMMON_THREE_TO_ONE: Mapping[str, str] = { + v: k for k, v in PROTEIN_COMMON_ONE_TO_THREE.items() +} + +A = sys.intern('A') +G = sys.intern('G') +C = sys.intern('C') +U = sys.intern('U') +T = sys.intern('T') + +DA = sys.intern('DA') +DG = sys.intern('DG') +DC = sys.intern('DC') +DT = sys.intern('DT') + +UNK_NUCLEIC_ONE_LETTER = sys.intern('N') # Unknown nucleic acid single letter. +UNK_RNA = sys.intern('N') # Unknown RNA. +UNK_DNA = sys.intern('DN') # Unknown DNA residue (differs from N). + +RNA_TYPES: tuple[str, ...] = (A, G, C, U) +DNA_TYPES: tuple[str, ...] = (DA, DG, DC, DT) + +NUCLEIC_TYPES: tuple[str, ...] = RNA_TYPES + DNA_TYPES +# Without UNK DNA. +NUCLEIC_TYPES_WITH_UNKNOWN: tuple[str, ...] = NUCLEIC_TYPES + ( + UNK_NUCLEIC_ONE_LETTER, +) +NUCLEIC_TYPES_WITH_2_UNKS: tuple[str, ...] = NUCLEIC_TYPES + ( + UNK_RNA, + UNK_DNA, +) + +RNA_TYPES_ONE_LETTER_WITH_UNKNOWN: tuple[str, ...] = RNA_TYPES + (UNK_RNA,) +RNA_TYPES_ONE_LETTER_WITH_UNKNOWN_TO_INT: Mapping[str, int] = { + r: i for i, r in enumerate(RNA_TYPES_ONE_LETTER_WITH_UNKNOWN) +} + +DNA_TYPES_WITH_UNKNOWN: tuple[str, ...] = DNA_TYPES + (UNK_DNA,) +DNA_TYPES_ONE_LETTER: tuple[str, ...] = (A, G, C, T) +DNA_TYPES_ONE_LETTER_WITH_UNKNOWN: tuple[str, ...] = DNA_TYPES_ONE_LETTER + ( + UNK_NUCLEIC_ONE_LETTER, +) +DNA_TYPES_ONE_LETTER_WITH_UNKNOWN_TO_INT: Mapping[str, int] = { + r: i for i, r in enumerate(DNA_TYPES_ONE_LETTER_WITH_UNKNOWN) +} +DNA_COMMON_ONE_TO_TWO: Mapping[str, str] = { + 'A': 'DA', + 'G': 'DG', + 'C': 'DC', + 'T': 'DT', +} + +STANDARD_POLYMER_TYPES: tuple[str, ...] = PROTEIN_TYPES + NUCLEIC_TYPES +POLYMER_TYPES: tuple[str, ...] = PROTEIN_TYPES_WITH_UNKNOWN + NUCLEIC_TYPES +POLYMER_TYPES_WITH_UNKNOWN: tuple[str, ...] = ( + PROTEIN_TYPES_WITH_UNKNOWN + NUCLEIC_TYPES_WITH_UNKNOWN +) +POLYMER_TYPES_WITH_GAP: tuple[str, ...] = PROTEIN_TYPES + \ + (GAP,) + NUCLEIC_TYPES +POLYMER_TYPES_WITH_UNKNOWN_AND_GAP: tuple[str, ...] = ( + PROTEIN_TYPES_WITH_UNKNOWN + (GAP,) + NUCLEIC_TYPES_WITH_UNKNOWN +) +POLYMER_TYPES_WITH_ALL_UNKS_AND_GAP: tuple[str, ...] = ( + PROTEIN_TYPES_WITH_UNKNOWN + (GAP,) + NUCLEIC_TYPES_WITH_2_UNKS +) + +POLYMER_TYPES_ORDER = {restype: i for i, restype in enumerate(POLYMER_TYPES)} + +POLYMER_TYPES_ORDER_WITH_UNKNOWN = { + restype: i for i, restype in enumerate(POLYMER_TYPES_WITH_UNKNOWN) +} + +POLYMER_TYPES_ORDER_WITH_UNKNOWN_AND_GAP = { + restype: i for i, restype in enumerate(POLYMER_TYPES_WITH_UNKNOWN_AND_GAP) +} + +POLYMER_TYPES_ORDER_WITH_ALL_UNKS_AND_GAP = { + restype: i for i, restype in enumerate(POLYMER_TYPES_WITH_ALL_UNKS_AND_GAP) +} + +POLYMER_TYPES_NUM = len(POLYMER_TYPES) # := 29. +POLYMER_TYPES_NUM_WITH_UNKNOWN = len(POLYMER_TYPES_WITH_UNKNOWN) # := 30. +POLYMER_TYPES_NUM_WITH_GAP = len(POLYMER_TYPES_WITH_GAP) # := 29. 
+POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP = len( + POLYMER_TYPES_WITH_UNKNOWN_AND_GAP +) # := 31. +POLYMER_TYPES_NUM_ORDER_WITH_ALL_UNKS_AND_GAP = len( + POLYMER_TYPES_WITH_ALL_UNKS_AND_GAP +) # := 32. + +WATER_TYPES: tuple[str, ...] = ('HOH', 'DOD') + +UNKNOWN_TYPES: tuple[str, ...] = (UNK, UNK_RNA, UNK_DNA, UNL) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/side_chains.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/side_chains.py new file mode 100644 index 0000000000000000000000000000000000000000..0e8cd1297970daba874f45919d4903d652c873b0 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/side_chains.py @@ -0,0 +1,112 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Constants associated with side chains.""" + +from collections.abc import Mapping, Sequence +import itertools + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +CHI_ANGLES_ATOMS: Mapping[str, Sequence[tuple[str, ...]]] = { + 'ALA': [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. + 'ARG': [ + ('N', 'CA', 'CB', 'CG'), + ('CA', 'CB', 'CG', 'CD'), + ('CB', 'CG', 'CD', 'NE'), + ('CG', 'CD', 'NE', 'CZ'), + ], + 'ASN': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'OD1')], + 'ASP': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'OD1')], + 'CYS': [('N', 'CA', 'CB', 'SG')], + 'GLN': [ + ('N', 'CA', 'CB', 'CG'), + ('CA', 'CB', 'CG', 'CD'), + ('CB', 'CG', 'CD', 'OE1'), + ], + 'GLU': [ + ('N', 'CA', 'CB', 'CG'), + ('CA', 'CB', 'CG', 'CD'), + ('CB', 'CG', 'CD', 'OE1'), + ], + 'GLY': [], + 'HIS': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'ND1')], + 'ILE': [('N', 'CA', 'CB', 'CG1'), ('CA', 'CB', 'CG1', 'CD1')], + 'LEU': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'CD1')], + 'LYS': [ + ('N', 'CA', 'CB', 'CG'), + ('CA', 'CB', 'CG', 'CD'), + ('CB', 'CG', 'CD', 'CE'), + ('CG', 'CD', 'CE', 'NZ'), + ], + 'MET': [ + ('N', 'CA', 'CB', 'CG'), + ('CA', 'CB', 'CG', 'SD'), + ('CB', 'CG', 'SD', 'CE'), + ], + 'PHE': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'CD1')], + 'PRO': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'CD')], + 'SER': [('N', 'CA', 'CB', 'OG')], + 'THR': [('N', 'CA', 'CB', 'OG1')], + 'TRP': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'CD1')], + 'TYR': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'CD1')], + 'VAL': [('N', 'CA', 'CB', 'CG1')], +} + +CHI_GROUPS_FOR_ATOM = {} +for res_name, chi_angle_atoms_for_res in CHI_ANGLES_ATOMS.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + CHI_GROUPS_FOR_ATOM.setdefault((res_name, atom), []).append( + (chi_group_i, atom_i) + ) + +# Mapping from (residue_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. 
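+# For example (derived from CHI_ANGLES_ATOMS above): CG appears in all four
+# arginine chi groups, so CHI_GROUPS_FOR_ATOM[('ARG', 'CG')] is
+# [(0, 3), (1, 2), (2, 1), (3, 0)].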
+CHI_GROUPS_FOR_ATOM: Mapping[tuple[str, str], Sequence[tuple[int, int]]] = ( + CHI_GROUPS_FOR_ATOM +) + +MAX_NUM_CHI_ANGLES: int = 4 +ATOMS_PER_CHI_ANGLE: int = 4 + +# A list of atoms for each AA type that are involved in chi angle calculations. +CHI_ATOM_SETS: Mapping[str, set[str]] = { + residue_name: set(itertools.chain(*atoms)) + for residue_name, atoms in CHI_ANGLES_ATOMS.items() +} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). +CHI_ANGLES_MASK: Sequence[Sequence[float]] = ( + (0.0, 0.0, 0.0, 0.0), # ALA + (1.0, 1.0, 1.0, 1.0), # ARG + (1.0, 1.0, 0.0, 0.0), # ASN + (1.0, 1.0, 0.0, 0.0), # ASP + (1.0, 0.0, 0.0, 0.0), # CYS + (1.0, 1.0, 1.0, 0.0), # GLN + (1.0, 1.0, 1.0, 0.0), # GLU + (0.0, 0.0, 0.0, 0.0), # GLY + (1.0, 1.0, 0.0, 0.0), # HIS + (1.0, 1.0, 0.0, 0.0), # ILE + (1.0, 1.0, 0.0, 0.0), # LEU + (1.0, 1.0, 1.0, 1.0), # LYS + (1.0, 1.0, 1.0, 0.0), # MET + (1.0, 1.0, 0.0, 0.0), # PHE + (1.0, 1.0, 0.0, 0.0), # PRO + (1.0, 0.0, 0.0, 0.0), # SER + (1.0, 0.0, 0.0, 0.0), # THR + (1.0, 1.0, 0.0, 0.0), # TRP + (1.0, 1.0, 0.0, 0.0), # TYR + (1.0, 0.0, 0.0, 0.0), # VAL +) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/cpp.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/cpp.cc new file mode 100644 index 0000000000000000000000000000000000000000..b2286b5c3681bfe6a3bb16aba8583760e4656c4d --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/cpp.cc @@ -0,0 +1,48 @@ +/* +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + */ + +#include "alphafold3/data/cpp/msa_profile_pybind.h" +#include "alphafold3/model/mkdssp_pybind.h" +#include "alphafold3/parsers/cpp/cif_dict_pybind.h" +#include "alphafold3/parsers/cpp/fasta_iterator_pybind.h" +#include "alphafold3/parsers/cpp/msa_conversion_pybind.h" +#include "alphafold3/structure/cpp/aggregation_pybind.h" +#include "alphafold3/structure/cpp/membership_pybind.h" +#include "alphafold3/structure/cpp/mmcif_atom_site_pybind.h" +#include "alphafold3/structure/cpp/mmcif_layout_pybind.h" +#include "alphafold3/structure/cpp/mmcif_struct_conn_pybind.h" +#include "alphafold3/structure/cpp/mmcif_utils_pybind.h" +#include "alphafold3/structure/cpp/string_array_pybind.h" +#include "pybind11/pybind11.h" + +namespace alphafold3 { +namespace { + +// Include all modules as submodules to simplify building. 
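+// From Python, these are accessed as submodules of alphafold3.cpp, e.g.
+// `from alphafold3.cpp import fasta_iterator` as done in data/parsers.py
+// (illustrative note, not in the upstream source).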
+PYBIND11_MODULE(cpp, m) { + RegisterModuleCifDict(m.def_submodule("cif_dict")); + RegisterModuleFastaIterator(m.def_submodule("fasta_iterator")); + RegisterModuleMsaConversion(m.def_submodule("msa_conversion")); + RegisterModuleMmcifLayout(m.def_submodule("mmcif_layout")); + RegisterModuleMmcifStructConn(m.def_submodule("mmcif_struct_conn")); + RegisterModuleMembership(m.def_submodule("membership")); + RegisterModuleMmcifUtils(m.def_submodule("mmcif_utils")); + RegisterModuleAggregation(m.def_submodule("aggregation")); + RegisterModuleStringArray(m.def_submodule("string_array")); + RegisterModuleMmcifAtomSite(m.def_submodule("mmcif_atom_site")); + RegisterModuleMkdssp(m.def_submodule("mkdssp")); + RegisterModuleMsaProfile(m.def_submodule("msa_profile")); +} + +} // namespace +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.cc new file mode 100644 index 0000000000000000000000000000000000000000..83b86f4e2c3946323533c20dd266eda9e8b0ef57 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.cc @@ -0,0 +1,79 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include + +#include "absl/strings/str_cat.h" +#include "pybind11/cast.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" + +namespace { + +namespace py = pybind11; + +py::array_t ComputeMsaProfile( + const py::array_t& msa, int num_residue_types) { + if (msa.size() == 0) { + throw py::value_error("The MSA must be non-empty."); + } + if (msa.ndim() != 2) { + throw py::value_error(absl::StrCat("The MSA must be rectangular, got ", + msa.ndim(), "-dimensional MSA array.")); + } + const int msa_depth = msa.shape()[0]; + const int sequence_length = msa.shape()[1]; + + py::array_t profile({sequence_length, num_residue_types}); + std::fill(profile.mutable_data(), profile.mutable_data() + profile.size(), + 0.0f); + auto profile_unchecked = profile.mutable_unchecked<2>(); + + const double normalized_count = 1.0 / msa_depth; + const int* msa_it = msa.data(); + for (int row_index = 0; row_index < msa_depth; ++row_index) { + for (int column_index = 0; column_index < sequence_length; ++column_index) { + const int residue_code = *(msa_it++); + if (residue_code < 0 || residue_code >= num_residue_types) { + throw py::value_error( + absl::StrCat("All residue codes must be positive and smaller than " + "num_residue_types ", + num_residue_types, ", got ", residue_code)); + } + profile_unchecked(column_index, residue_code) += normalized_count; + } + } + return profile; +} + +constexpr char kComputeMsaProfileDoc[] = R"( +Computes MSA profile for the given encoded MSA. + +Args: + msa: A Numpy array of shape (num_msa, num_res) with the integer coded MSA. + num_residue_types: Integer that determines the number of unique residue types. + This will determine the shape of the output profile. 
+ +Returns: + A float Numpy array of shape (num_res, num_residue_types) with residue + frequency (residue type count normalized by MSA depth) for every column of the + MSA. +)"; + +} // namespace + +namespace alphafold3 { + +void RegisterModuleMsaProfile(pybind11::module m) { + m.def("compute_msa_profile", &ComputeMsaProfile, py::arg("msa"), + py::arg("num_residue_types"), py::doc(kComputeMsaProfileDoc + 1)); +} + +} // namespace alphafold3 \ No newline at end of file diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..1145d331bca619a80f3beef45ad65b7bc64e0bb5 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.h @@ -0,0 +1,25 @@ +/* +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_DATA_PYTHON_MSA_PROFILE_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_DATA_PYTHON_MSA_PROFILE_PYBIND_H_ + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleMsaProfile(pybind11::module m); + +} + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_DATA_PYTHON_MSA_PROFILE_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/featurisation.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/featurisation.py new file mode 100644 index 0000000000000000000000000000000000000000..ee351626ca91eda6688b729450a15d5c2b79e220 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/featurisation.py @@ -0,0 +1,90 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""AlphaFold 3 featurisation pipeline.""" + +from collections.abc import Sequence +import datetime +import time + +from alphafold3.common import folding_input +from alphafold3.constants import chemical_components +from alphafold3.model import features +from alphafold3.model.pipeline import pipeline +import numpy as np + + +def validate_fold_input(fold_input: folding_input.Input): + """Validates the fold input contains MSA and templates for featurisation.""" + for i, chain in enumerate(fold_input.protein_chains): + if chain.unpaired_msa is None: + raise ValueError(f'Protein chain {i + 1} is missing unpaired MSA.') + if chain.paired_msa is None: + raise ValueError(f'Protein chain {i + 1} is missing paired MSA.') + if chain.templates is None: + raise ValueError(f'Protein chain {i + 1} is missing Templates.') + for i, chain in enumerate(fold_input.rna_chains): + if chain.unpaired_msa is None: + raise ValueError(f'RNA chain {i + 1} is missing unpaired MSA.') + + +def featurise_input( + fold_input: folding_input.Input, + ccd: chemical_components.Ccd, + buckets: Sequence[int] | None, + max_template_date: datetime.date | None = None, + verbose: bool = False, +) -> Sequence[features.BatchDict]: + """Featurise the folding input. + + Args: + fold_input: The input to featurise. + ccd: The chemical components dictionary. + buckets: Bucket sizes to pad the data to, to avoid excessive re-compilation + of the model. If None, calculate the appropriate bucket size from the + number of tokens. If not None, must be a sequence of at least one integer, + in strictly increasing order. Will raise an error if the number of tokens + is more than the largest bucket size. + max_template_date: Optional max template date to prevent data leakage in + validation. + verbose: Whether to print progress messages. + + Returns: + A featurised batch for each rng_seed in the input. + """ + validate_fold_input(fold_input) + + # Set up data pipeline for single use. + data_pipeline = pipeline.WholePdbPipeline( + config=pipeline.WholePdbPipeline.Config( + buckets=buckets, max_template_date=max_template_date + ), + ) + + batches = [] + for rng_seed in fold_input.rng_seeds: + featurisation_start_time = time.time() + if verbose: + print(f'Featurising {fold_input.name} with rng_seed {rng_seed}.') + batch = data_pipeline.process_item( + fold_input=fold_input, + ccd=ccd, + random_state=np.random.RandomState(rng_seed), + random_seed=rng_seed, + ) + if verbose: + print( + f'Featurising {fold_input.name} with rng_seed {rng_seed} ' + f'took {time.time() - featurisation_start_time:.2f} seconds.' + ) + batches.append(batch) + + return batches diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa.py new file mode 100644 index 0000000000000000000000000000000000000000..51fe211777a077cda24b12353747dba3ecae805b --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa.py @@ -0,0 +1,346 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Functions for getting MSA and calculating alignment features.""" + +from collections.abc import MutableMapping, Sequence +import string +from typing import Self + +from absl import logging +from alphafold3.constants import mmcif_names +from alphafold3.data import msa_config +from alphafold3.data import msa_features +from alphafold3.data import parsers +from alphafold3.data.tools import jackhmmer +from alphafold3.data.tools import msa_tool +from alphafold3.data.tools import nhmmer +import numpy as np + + +class Error(Exception): + """Error indicatating a problem with MSA Search.""" + + +def _featurize(seq: str, chain_poly_type: str) -> str | list[int]: + if mmcif_names.is_standard_polymer_type(chain_poly_type): + featurized_seqs, _ = msa_features.extract_msa_features( + msa_sequences=[seq], chain_poly_type=chain_poly_type + ) + return featurized_seqs[0].tolist() + # For anything else simply require an identical match. + return seq + + +def sequences_are_feature_equivalent( + sequence1: str, + sequence2: str, + chain_poly_type: str, +) -> bool: + feat1 = _featurize(sequence1, chain_poly_type) + feat2 = _featurize(sequence2, chain_poly_type) + return feat1 == feat2 + + +class Msa: + """Multiple Sequence Alignment container with methods for manipulating it.""" + + def __init__( + self, + query_sequence: str, + chain_poly_type: str, + sequences: Sequence[str], + descriptions: Sequence[str], + deduplicate: bool = True, + ): + """Raw constructor, prefer using the from_{a3m,multiple_msas} class methods. + + The first sequence must be equal (in featurised form) to the query sequence. + If sequences/descriptions are empty, they will be initialised to the query. + + Args: + query_sequence: The sequence that was used to search for MSA. + chain_poly_type: Polymer type of the query sequence, see mmcif_names. + sequences: The sequences returned by the MSA search tool. + descriptions: Metadata for the sequences returned by the MSA search tool. + deduplicate: If True, the MSA sequences will be deduplicated in the input + order. Lowercase letters (insertions) are ignored when deduplicating. + """ + if len(sequences) != len(descriptions): + raise ValueError( + 'The number of sequences and descriptions must match.') + + self.query_sequence = query_sequence + self.chain_poly_type = chain_poly_type + + if not deduplicate: + self.sequences = sequences + self.descriptions = descriptions + else: + self.sequences = [] + self.descriptions = [] + # A replacement table that removes all lowercase characters. + deletion_table = str.maketrans('', '', string.ascii_lowercase) + unique_sequences = set() + for seq, desc in zip(sequences, descriptions, strict=True): + # Using string.translate is faster than re.sub('[a-z]+', ''). + sequence_no_deletions = seq.translate(deletion_table) + if sequence_no_deletions not in unique_sequences: + unique_sequences.add(sequence_no_deletions) + self.sequences.append(seq) + self.descriptions.append(desc) + + # Make sure the MSA always has at least the query. 
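+    # (This also covers Msa.from_empty(), which passes empty sequence and
+    # description lists.)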
+ self.sequences = self.sequences or [query_sequence] + self.descriptions = self.descriptions or ['Original query'] + + # Check if the 1st MSA sequence matches the query sequence. Since it may be + # mutated by the search tool (jackhmmer) check using the featurized version. + if not sequences_are_feature_equivalent( + self.sequences[0], query_sequence, chain_poly_type + ): + raise ValueError( + f'First MSA sequence {self.sequences[0]} is not the {query_sequence=}' + ) + + @classmethod + def from_multiple_msas( + cls, msas: Sequence[Self], deduplicate: bool = True + ) -> Self: + """Initializes the MSA from multiple MSAs. + + Args: + msas: A sequence of Msa objects representing individual MSAs produced by + different tools/dbs. + deduplicate: If True, the MSA sequences will be deduplicated in the input + order. Lowercase letters (insertions) are ignored when deduplicating. + + Returns: + An Msa object created by merging multiple MSAs. + """ + if not msas: + raise ValueError('At least one MSA must be provided.') + + query_sequence = msas[0].query_sequence + chain_poly_type = msas[0].chain_poly_type + sequences = [] + descriptions = [] + + for msa in msas: + if msa.query_sequence != query_sequence: + raise ValueError( + f'Query sequences must match: {[m.query_sequence for m in msas]}' + ) + if msa.chain_poly_type != chain_poly_type: + raise ValueError( + f'Chain poly types must match: {[m.chain_poly_type for m in msas]}' + ) + sequences.extend(msa.sequences) + descriptions.extend(msa.descriptions) + + return cls( + query_sequence=query_sequence, + chain_poly_type=chain_poly_type, + sequences=sequences, + descriptions=descriptions, + deduplicate=deduplicate, + ) + + @classmethod + def from_multiple_a3ms( + cls, a3ms: Sequence[str], chain_poly_type: str, deduplicate: bool = True + ) -> Self: + """Initializes the MSA from multiple A3M strings. + + Args: + a3ms: A sequence of A3M strings representing individual MSAs produced by + different tools/dbs. + chain_poly_type: Polymer type of the query sequence, see mmcif_names. + deduplicate: If True, the MSA sequences will be deduplicated in the input + order. Lowercase letters (insertions) are ignored when deduplicating. + + Returns: + An Msa object created by merging multiple A3Ms. 
+ """ + if not a3ms: + raise ValueError('At least one A3M must be provided.') + + query_sequence = None + all_sequences = [] + all_descriptions = [] + + for a3m in a3ms: + sequences, descriptions = parsers.parse_fasta(a3m) + if query_sequence is None: + query_sequence = sequences[0] + + if sequences[0] != query_sequence: + raise ValueError( + f'Query sequences must match: {sequences[0]=} != {query_sequence=}' + ) + all_sequences.extend(sequences) + all_descriptions.extend(descriptions) + + return cls( + query_sequence=query_sequence, + chain_poly_type=chain_poly_type, + sequences=all_sequences, + descriptions=all_descriptions, + deduplicate=deduplicate, + ) + + @classmethod + def from_a3m( + cls, + query_sequence: str, + chain_poly_type: str, + a3m: str, + max_depth: int | None = None, + deduplicate: bool = True, + ) -> Self: + """Parses the single A3M and builds the Msa object.""" + sequences, descriptions = parsers.parse_fasta(a3m) + + if max_depth is not None and 0 < max_depth < len(sequences): + logging.info( + 'MSA cropped from depth of %d to %d for %s.', + len(sequences), + max_depth, + query_sequence, + ) + sequences = sequences[:max_depth] + descriptions = descriptions[:max_depth] + + return cls( + query_sequence=query_sequence, + chain_poly_type=chain_poly_type, + sequences=sequences, + descriptions=descriptions, + deduplicate=deduplicate, + ) + + @classmethod + def from_empty(cls, query_sequence: str, chain_poly_type: str) -> Self: + """Creates an empty Msa containing just the query sequence.""" + return cls( + query_sequence=query_sequence, + chain_poly_type=chain_poly_type, + sequences=[], + descriptions=[], + deduplicate=False, + ) + + @property + def depth(self) -> int: + return len(self.sequences) + + def __repr__(self) -> str: + return f'Msa({self.depth} sequences, {self.chain_poly_type})' + + def to_a3m(self) -> str: + """Returns the MSA in the A3M format.""" + a3m_lines = [] + for desc, seq in zip(self.descriptions, self.sequences, strict=True): + a3m_lines.append(f'>{desc}') + a3m_lines.append(seq) + return '\n'.join(a3m_lines) + '\n' + + def featurize(self) -> MutableMapping[str, np.ndarray]: + """Featurises the MSA and returns a map of feature names to features. + + Returns: + A dictionary mapping feature names to values. + + Raises: + msa.Error: + * If the sequences in the MSA don't have the same length after deletions + (lower case letters) are removed. + * If the MSA contains an unknown amino acid code. + * If there are no sequences after aligning. 
+ """ + try: + msa, deletion_matrix = msa_features.extract_msa_features( + msa_sequences=self.sequences, chain_poly_type=self.chain_poly_type + ) + except ValueError as e: + raise Error( + f'Error extracting MSA or deletion features: {e}') from e + + if msa.shape == (0, 0): + raise Error(f'Empty MSA feature for {self}') + + species_ids = msa_features.extract_species_ids(self.descriptions) + + return { + 'msa_species_identifiers': np.array(species_ids, dtype=object), + 'num_alignments': np.array(self.depth, dtype=np.int32), + 'msa': msa, + 'deletion_matrix_int': deletion_matrix, + } + + +def get_msa_tool( + msa_tool_config: msa_config.JackhmmerConfig | msa_config.NhmmerConfig, +) -> msa_tool.MsaTool: + """Returns the requested MSA tool.""" + + match msa_tool_config: + case msa_config.JackhmmerConfig(): + return jackhmmer.Jackhmmer( + binary_path=msa_tool_config.binary_path, + database_path=msa_tool_config.database_config.path, + n_cpu=msa_tool_config.n_cpu, + n_iter=msa_tool_config.n_iter, + e_value=msa_tool_config.e_value, + z_value=msa_tool_config.z_value, + max_sequences=msa_tool_config.max_sequences, + ) + case msa_config.NhmmerConfig(): + return nhmmer.Nhmmer( + binary_path=msa_tool_config.binary_path, + hmmalign_binary_path=msa_tool_config.hmmalign_binary_path, + hmmbuild_binary_path=msa_tool_config.hmmbuild_binary_path, + database_path=msa_tool_config.database_config.path, + n_cpu=msa_tool_config.n_cpu, + e_value=msa_tool_config.e_value, + max_sequences=msa_tool_config.max_sequences, + alphabet=msa_tool_config.alphabet, + ) + case _: + raise ValueError(f'Unknown MSA tool: {msa_tool_config}.') + + +def get_msa( + target_sequence: str, + run_config: msa_config.RunConfig, + chain_poly_type: str, + deduplicate: bool = False, +) -> Msa: + """Computes the MSA for a given query sequence. + + Args: + target_sequence: The target amino-acid sequence. + run_config: MSA run configuration. + chain_poly_type: The type of chain for which to get an MSA. + deduplicate: If True, the MSA sequences will be deduplicated in the input + order. Lowercase letters (insertions) are ignored when deduplicating. + + Returns: + Aligned MSA sequences. + """ + + return Msa.from_a3m( + query_sequence=target_sequence, + chain_poly_type=chain_poly_type, + a3m=get_msa_tool(run_config.config).query(target_sequence).a3m, + max_depth=run_config.crop_size, + deduplicate=deduplicate, + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_config.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_config.py new file mode 100644 index 0000000000000000000000000000000000000000..efa2d9b9e33686f93d7f2a32271121914438dbdb --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_config.py @@ -0,0 +1,170 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Genetic search config settings for data pipelines.""" + +import dataclasses +import datetime +from typing import Self +from alphafold3.constants import mmcif_names + + +def _validate_chain_poly_type(chain_poly_type: str) -> None: + if chain_poly_type not in mmcif_names.STANDARD_POLYMER_CHAIN_TYPES: + raise ValueError( + 'chain_poly_type must be one of' + f' {mmcif_names.STANDARD_POLYMER_CHAIN_TYPES}: {chain_poly_type}' + ) + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class DatabaseConfig: + """Configuration for a database.""" + + name: str + path: str + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class JackhmmerConfig: + """Configuration for a jackhmmer run. + + Attributes: + binary_path: Path to the binary of the msa tool. + database_config: Database configuration. + n_cpu: An integer with the number of CPUs to use. + n_iter: An integer with the number of database search iterations. + e_value: e-value for the database lookup. + z_value: The Z-value representing the number of comparisons done (i.e + correct database size) for E-value calculation. + max_sequences: Max sequences to return in MSA. + """ + + binary_path: str + database_config: DatabaseConfig + n_cpu: int + n_iter: int + e_value: float + z_value: float | int | None + max_sequences: int + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class NhmmerConfig: + """Configuration for a nhmmer run. + + Attributes: + binary_path: Path to the binary of the msa tool. + hmmalign_binary_path: Path to the hmmalign binary. + hmmbuild_binary_path: Path to the hmmbuild binary. + database_config: Database configuration. + n_cpu: An integer with the number of CPUs to use. + e_value: e-value for the database lookup. + max_sequences: Max sequences to return in MSA. + alphabet: The alphabet when building a profile with hmmbuild. + """ + + binary_path: str + hmmalign_binary_path: str + hmmbuild_binary_path: str + database_config: DatabaseConfig + n_cpu: int + e_value: float + max_sequences: int + alphabet: str | None + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class RunConfig: + """Configuration for an MSA run. + + Attributes: + config: MSA tool config. + chain_poly_type: The chain type for which the tools will be run. + crop_size: The maximum number of sequences to keep in the MSA. If None, all + sequences are kept. Note that the query is included in the MSA, so it + doesn't make sense to set this to less than 2. 
+ """ + + config: JackhmmerConfig | NhmmerConfig + chain_poly_type: str + crop_size: int | None + + def __post_init__(self): + if self.crop_size is not None and self.crop_size < 2: + raise ValueError( + f'crop_size must be None or >= 2: {self.crop_size}') + + _validate_chain_poly_type(self.chain_poly_type) + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class HmmsearchConfig: + """Configuration for a hmmsearch.""" + + hmmsearch_binary_path: str + hmmbuild_binary_path: str + + e_value: float + inc_e: float + dom_e: float + incdom_e: float + alphabet: str = 'amino' + filter_f1: float | None = None + filter_f2: float | None = None + filter_f3: float | None = None + filter_max: bool = False + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class TemplateToolConfig: + """Configuration for a template tool.""" + + database_path: str + chain_poly_type: str + hmmsearch_config: HmmsearchConfig + max_a3m_query_sequences: int | None = 300 + + def __post_init__(self): + _validate_chain_poly_type(self.chain_poly_type) + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class TemplateFilterConfig: + """Configuration for a template filter.""" + + max_subsequence_ratio: float | None + min_align_ratio: float | None + min_hit_length: int | None + deduplicate_sequences: bool + max_hits: int | None + max_template_date: datetime.date + + @classmethod + def no_op_filter(cls) -> Self: + """Returns a config for filter that keeps everything.""" + return cls( + max_subsequence_ratio=None, + min_align_ratio=None, + min_hit_length=None, + deduplicate_sequences=False, + max_hits=None, + # Very far in the future. + max_template_date=datetime.date(3000, 1, 1), + ) + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class TemplatesConfig: + """Configuration for the template search pipeline.""" + + template_tool_config: TemplateToolConfig + filter_config: TemplateFilterConfig diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_features.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_features.py new file mode 100644 index 0000000000000000000000000000000000000000..7c6fff3f532de6635d93e30eb421d8943f08e63c --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_features.py @@ -0,0 +1,204 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Utilities for computing MSA features.""" + +from collections.abc import Sequence +import re +from alphafold3.constants import mmcif_names +import numpy as np + +_PROTEIN_TO_ID = { + 'A': 0, + 'B': 3, # Same as D. + 'C': 4, + 'D': 3, + 'E': 6, + 'F': 13, + 'G': 7, + 'H': 8, + 'I': 9, + 'J': 20, # Same as unknown (X). + 'K': 11, + 'L': 10, + 'M': 12, + 'N': 2, + 'O': 20, # Same as unknown (X). + 'P': 14, + 'Q': 5, + 'R': 1, + 'S': 15, + 'T': 16, + 'U': 4, # Same as C. + 'V': 19, + 'W': 17, + 'X': 20, + 'Y': 18, + 'Z': 6, # Same as E. 
+ '-': 21, +} + +_RNA_TO_ID = { + # Map non-standard residues to UNK_NUCLEIC (N) -> 30 + **{chr(i): 30 for i in range(ord('A'), ord('Z') + 1)}, + # Continue the RNA indices from where Protein indices left off. + '-': 21, + 'A': 22, + 'G': 23, + 'C': 24, + 'U': 25, +} + +_DNA_TO_ID = { + # Map non-standard residues to UNK_NUCLEIC (N) -> 30 + **{chr(i): 30 for i in range(ord('A'), ord('Z') + 1)}, + # Continue the DNA indices from where DNA indices left off. + '-': 21, + 'A': 26, + 'G': 27, + 'C': 28, + 'T': 29, +} + + +def extract_msa_features( + msa_sequences: Sequence[str], chain_poly_type: str +) -> tuple[np.ndarray, np.ndarray]: + """Extracts MSA features. + + Example: + The input raw MSA is: `[["AAAAAA"], ["Ai-CiDiiiEFa"]]` + The output MSA will be: `[["AAAAAA"], ["A-CDEF"]]` + The deletions will be: `[[0, 0, 0, 0, 0, 0], [0, 1, 0, 1, 3, 0]]` + + Args: + msa_sequences: A list of strings, each string with one MSA sequence. Each + string must have the same, constant number of non-lowercase (matching) + residues. + chain_poly_type: Either 'polypeptide(L)' (protein), 'polyribonucleotide' + (RNA), or 'polydeoxyribonucleotide' (DNA). Use the appropriate string + constant from mmcif_names.py. + + Returns: + A tuple with: + * MSA array of shape (num_seq, num_res) that contains only the uppercase + characters or gaps (-) from the original MSA. + * Deletions array of shape (num_seq, num_res) that contains the number + of deletions (lowercase letters in the MSA) to the left from each + non-deleted residue (uppercase letters in the MSA). + + Raises: + ValueError if any of the preconditions are not met. + """ + + # Select the appropriate character map based on the chain type. + if chain_poly_type == mmcif_names.RNA_CHAIN: + char_map = _RNA_TO_ID + elif chain_poly_type == mmcif_names.DNA_CHAIN: + char_map = _DNA_TO_ID + elif chain_poly_type == mmcif_names.PROTEIN_CHAIN: + char_map = _PROTEIN_TO_ID + else: + raise ValueError(f'{chain_poly_type=} invalid.') + + # Handle empty MSA. + if not msa_sequences: + empty_msa = np.array([], dtype=np.int32).reshape((0, 0)) + empty_deletions = np.array([], dtype=np.int32).reshape((0, 0)) + return empty_msa, empty_deletions + + # Get the number of rows and columns in the MSA. + num_rows = len(msa_sequences) + num_cols = sum(1 for c in msa_sequences[0] if c in char_map) + + # Initialize the output arrays. + msa_arr = np.zeros((num_rows, num_cols), dtype=np.int32) + deletions_arr = np.zeros((num_rows, num_cols), dtype=np.int32) + + # Populate the output arrays. + for problem_row, msa_sequence in enumerate(msa_sequences): + deletion_count = 0 + upper_count = 0 + problem_col = 0 + problems = [] + for current in msa_sequence: + msa_id = char_map.get(current, -1) + if msa_id == -1: + if not current.islower(): + problems.append( + f'({problem_row}, {problem_col}):{current}') + deletion_count += 1 + else: + # Check the access is safe before writing to the array. + # We don't need to check problem_row since it's guaranteed to be within + # the array bounds, while upper_count is incremented in the loop. + if upper_count < deletions_arr.shape[1]: + deletions_arr[problem_row, upper_count] = deletion_count + msa_arr[problem_row, upper_count] = msa_id + deletion_count = 0 + upper_count += 1 + problem_col += 1 + if problems: + raise ValueError( + f"Unknown residues in MSA: {', '.join(problems)}. 
" + f'target_sequence: {msa_sequences[0]}' + ) + if upper_count != num_cols: + raise ValueError( + 'Invalid shape all strings must have the same number ' + 'of non-lowercase characters; First string has ' + f"{num_cols} non-lowercase characters but '{msa_sequence}' has " + f'{upper_count}. target_sequence: {msa_sequences[0]}' + ) + + return msa_arr, deletions_arr + + +# UniProtKB SwissProt/TrEMBL dbs have the following description format: +# `db|UniqueIdentifier|EntryName`, e.g. `sp|P0C2L1|A3X1_LOXLA` or +# `tr|A0A146SKV9|A0A146SKV9_FUNHE`. +_UNIPROT_ENTRY_NAME_REGEX = re.compile( + # UniProtKB TrEMBL or SwissProt database. + r'(?:tr|sp)\|' + # A primary accession number of the UniProtKB entry. + r'(?:[A-Z0-9]{6,10})' + # Occasionally there is an isoform suffix (e.g. _1 or _10) which we ignore. + r'(?:_\d+)?\|' + # TrEMBL: Same as AccessionId (6-10 characters). + # SwissProt: A mnemonic protein identification code (1-5 characters). + r'(?:[A-Z0-9]{1,10}_)' + # A mnemonic species identification code. + r'(?P[A-Z0-9]{1,5})' +) + + +def extract_species_ids(msa_descriptions: Sequence[str]) -> Sequence[str]: + """Extracts species ID from MSA UniProtKB sequence identifiers. + + Args: + msa_descriptions: The descriptions (the FASTA/A3M comment line) for each of + the sequences. + + Returns: + Extracted UniProtKB species IDs if there is a regex match for each + description line, blank if the regex doesn't match. + """ + species_ids = [] + for msa_description in msa_descriptions: + msa_description = msa_description.strip() + match = _UNIPROT_ENTRY_NAME_REGEX.match(msa_description) + if match: + species_ids.append(match.group('SpeciesId')) + else: + # Handle cases where the regex doesn't match + # (e.g., append None or raise an error depending on your needs) + species_ids.append('') + return species_ids diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_identifiers.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_identifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..0296080bbc0150000b03718e081611b4f5c63f48 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_identifiers.py @@ -0,0 +1,86 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Utilities for extracting identifiers from MSA sequence descriptions.""" + +import dataclasses +import re + + +# Sequences coming from UniProtKB database come in the +# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE` +# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively). +_UNIPROT_PATTERN = re.compile( + r""" + ^ + # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot + (?:tr|sp) + \| + # A primary accession number of the UniProtKB entry. + (?P[A-Za-z0-9]{6,10}) + # Occasionally there is a _0 or _1 isoform suffix, which we ignore. + (?:_\d)? + \| + # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic + # protein ID code. 
+ (?:[A-Za-z0-9]+) + _ + # A mnemonic species identification code. + (?P([A-Za-z0-9]){1,5}) + # Small BFD uses a final value after an underscore, which we ignore. + (?:_\d+)? + $ + """, + re.VERBOSE, +) + + +@dataclasses.dataclass(frozen=True) +class Identifiers: + species_id: str = '' + + +def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers: + """Gets species from an msa sequence identifier. + + The sequence identifier has the format specified by + _UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN. + An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE` + + Args: + msa_sequence_identifier: a sequence identifier. + + Returns: + An `Identifiers` instance with species_id. These + can be empty in the case where no identifier was found. + """ + matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip()) + if matches: + return Identifiers(species_id=matches.group('SpeciesIdentifier')) + return Identifiers() + + +def _extract_sequence_identifier(description: str) -> str | None: + """Extracts sequence identifier from description. Returns None if no match.""" + split_description = description.split() + if split_description: + return split_description[0].partition('/')[0] + else: + return None + + +def get_identifiers(description: str) -> Identifiers: + """Computes extra MSA features from the description.""" + sequence_identifier = _extract_sequence_identifier(description) + if sequence_identifier is None: + return Identifiers() + else: + return _parse_sequence_identifier(sequence_identifier) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_store.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_store.py new file mode 100644 index 0000000000000000000000000000000000000000..cda58e4c969d56dfef021a54c1e86cba86481201 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_store.py @@ -0,0 +1,67 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Interface and implementations for fetching MSA data.""" + +from collections.abc import Sequence +from typing_extensions import Protocol, TypeAlias + +from alphafold3.data import msa +from alphafold3.data import msa_config + + +MsaErrors: TypeAlias = Sequence[tuple[msa_config.RunConfig, str]] + + +class MsaProvider(Protocol): + """Interface for providing Multiple Sequence Alignments.""" + + def __call__( + self, + query_sequence: str, + chain_polymer_type: str, + ) -> tuple[msa.Msa, MsaErrors]: + """Retrieve MSA for the given polymer query_sequence. + + Args: + query_sequence: The residue sequence of the polymer to search for. + chain_polymer_type: The polymer type of the query_sequence. This must + match the chain_polymer_type of the provider. + + Returns: + A tuple containing the MSA and MsaErrors. MsaErrors is a Sequence + containing a tuple for each msa_query that failed. 
Each tuple contains + the failing query and the associated error message. + """ + + +class EmptyMsaProvider: + """MSA provider that returns just the query sequence, useful for testing.""" + + def __init__(self, chain_polymer_type: str): + self._chain_polymer_type = chain_polymer_type + + def __call__( + self, query_sequence: str, chain_polymer_type: str + ) -> tuple[msa.Msa, MsaErrors]: + """Returns an MSA containing just the query sequence, never errors.""" + if chain_polymer_type != self._chain_polymer_type: + raise ValueError( + f'EmptyMsaProvider of type {self._chain_polymer_type} called with ' + f'sequence of {chain_polymer_type=}, {query_sequence=}.' + ) + return ( + msa.Msa.from_empty( + query_sequence=query_sequence, + chain_poly_type=self._chain_polymer_type, + ), + (), + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/parsers.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/parsers.py new file mode 100644 index 0000000000000000000000000000000000000000..608c9aaf014e2690669a98772cb9d77f7f4edf7d --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/parsers.py @@ -0,0 +1,181 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Functions for parsing various file formats.""" + +from collections.abc import Iterable, Sequence +from typing import IO, TypeAlias + +from alphafold3.cpp import fasta_iterator +from alphafold3.cpp import msa_conversion + + +DeletionMatrix: TypeAlias = Sequence[Sequence[int]] + + +def lazy_parse_fasta_string(fasta_string: str) -> Iterable[tuple[str, str]]: + """Lazily parses a FASTA/A3M string and yields (sequence, description) tuples. + + This implementation is more memory friendly than `fasta_sequence` while + offering comparable performance. The underlying implementation is in C++ and + is therefore faster than a pure Python implementation. + + Use this method when parsing FASTA files where you already have the FASTA + string, but need to control how far you iterate through its sequences. + + Arguments: + fasta_string: A string with the contents of FASTA/A3M file. + + Returns: + Iterator of (sequence, description). In the description, the leading ">" is + stripped. + + Raises: + ValueError if the FASTA/A3M file is invalid, e.g. empty. + """ + + # The lifetime of the FastaStringIterator is tied to the lifetime of + # fasta_string - fasta_string must be kept while the iterator is in use. + return fasta_iterator.FastaStringIterator(fasta_string) + + +def parse_fasta(fasta_string: str) -> tuple[Sequence[str], Sequence[str]]: + """Parses FASTA string and returns list of strings with amino-acid sequences. + + Arguments: + fasta_string: The string contents of a FASTA file. + + Returns: + A tuple of two lists: + * A list of sequences. + * A list of sequence descriptions taken from the comment lines. In the + same order as the sequences. 
+ """ + return fasta_iterator.parse_fasta_include_descriptions(fasta_string) + + +def convert_a3m_to_stockholm(a3m: str, max_seqs: int | None = None) -> str: + """Converts MSA in the A3M format to the Stockholm format.""" + sequences, descriptions = parse_fasta(a3m) + if max_seqs is not None: + sequences = sequences[:max_seqs] + descriptions = descriptions[:max_seqs] + + stockholm = ['# STOCKHOLM 1.0', ''] + + # Add the Stockholm header with the sequence metadata. + names = [] + for i, description in enumerate(descriptions): + name, _, rest = description.partition(' ') + # Ensure that the names are unique - stockholm format requires that + # the sequence names are unique. + name = f'{name}_{i}' + names.append(name) + # Avoid zero-length description due to historic hmmbuild parsing bug. + desc = rest.strip() or '' + stockholm.append(f'#=GS {name.strip()} DE {desc}') + stockholm.append('') + + # Convert insertions in a sequence into gaps in all other sequences that don't + # have an insertion in that column as well. + sequences = msa_conversion.convert_a3m_to_stockholm(sequences) + + # Add the MSA data. + max_name_width = max(len(name) for name in names) + for name, sequence in zip(names, sequences, strict=True): + # Align the names to the left and pad with spaces to the maximum length. + stockholm.append(f'{name:<{max_name_width}s} {sequence}') + + # Add the reference annotation for the query (the first sequence). + ref_annotation = ''.join('.' if c == '-' else 'x' for c in sequences[0]) + stockholm.append(f'{"#=GC RF":<{max_name_width}s} {ref_annotation}') + stockholm.append('//') + + return '\n'.join(stockholm) + + +def convert_stockholm_to_a3m( + stockholm: IO[str], + max_sequences: int | None = None, + remove_first_row_gaps: bool = True, + linewidth: int | None = None, +) -> str: + """Converts MSA in Stockholm format to the A3M format.""" + descriptions = {} + sequences = {} + reached_max_sequences = False + + if linewidth is not None and linewidth <= 0: + raise ValueError('linewidth must be > 0 or None') + + for line in stockholm: + reached_max_sequences = max_sequences and len( + sequences) >= max_sequences + line = line.strip() + # Ignore blank lines, markup and end symbols - remainder are alignment + # sequence parts. + if not line or line.startswith(('#', '//')): + continue + seqname, aligned_seq = line.split(maxsplit=1) + if seqname not in sequences: + if reached_max_sequences: + continue + sequences[seqname] = '' + sequences[seqname] += aligned_seq + + stockholm.seek(0) + for line in stockholm: + line = line.strip() + if line[:4] == '#=GS': + # Description row - example format is: + # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ... 
+ columns = line.split(maxsplit=3) + seqname, feature = columns[1:3] + value = columns[3] if len(columns) == 4 else '' + if feature != 'DE': + continue + if reached_max_sequences and seqname not in sequences: + continue + descriptions[seqname] = value + if len(descriptions) == len(sequences): + break + + assert len(descriptions) <= len(sequences) + + # Convert sto format to a3m line by line + a3m_sequences = {} + # query_sequence is assumed to be the first sequence + query_sequence = next(iter(sequences.values())) + for seqname, sto_sequence in sequences.items(): + if remove_first_row_gaps: + a3m_sequences[seqname] = msa_conversion.align_sequence_to_gapless_query( + sequence=sto_sequence, query_sequence=query_sequence + ).replace('.', '') + else: + a3m_sequences[seqname] = sto_sequence.replace('.', '') + + fasta_chunks = [] + + for seqname, seq in a3m_sequences.items(): + fasta_chunks.append(f'>{seqname} {descriptions.get(seqname, "")}') + + if linewidth: + fasta_chunks.extend( + seq[i: linewidth + i] for i in range(0, len(seq), linewidth) + ) + else: + fasta_chunks.append(seq) + + return '\n'.join(fasta_chunks) + '\n' # Include terminating newline. diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/pipeline.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..89ae3dff34473f191566753ff3928051d99cc84d --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/pipeline.py @@ -0,0 +1,543 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Functions for running the MSA and template tools for the AlphaFold model.""" + +from concurrent import futures +import dataclasses +import datetime +import functools +import logging +import time + +from alphafold3.common import folding_input +from alphafold3.constants import mmcif_names +from alphafold3.data import msa +from alphafold3.data import msa_config +from alphafold3.data import structure_stores +from alphafold3.data import templates as templates_lib + + +# Cache to avoid re-running template search for the same sequence in homomers. 
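+# (functools.cache keys on the call arguments; this works because they are all
+# hashable: strings, a bool, and the frozen msa_config dataclasses.)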
+@functools.cache +def _get_protein_templates( + sequence: str, + input_msa_a3m: str, + run_template_search: bool, + templates_config: msa_config.TemplatesConfig, + pdb_database_path: str, +) -> templates_lib.Templates: + """Searches for templates for a single protein chain.""" + if run_template_search: + templates_start_time = time.time() + logging.info('Getting protein templates for sequence %s', sequence) + protein_templates = templates_lib.Templates.from_seq_and_a3m( + query_sequence=sequence, + msa_a3m=input_msa_a3m, + max_template_date=templates_config.filter_config.max_template_date, + database_path=templates_config.template_tool_config.database_path, + hmmsearch_config=templates_config.template_tool_config.hmmsearch_config, + max_a3m_query_sequences=None, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + structure_store=structure_stores.StructureStore(pdb_database_path), + filter_config=templates_config.filter_config, + ) + logging.info( + 'Getting protein templates took %.2f seconds for sequence %s', + time.time() - templates_start_time, + sequence, + ) + else: + logging.info('Skipping template search for sequence %s', sequence) + protein_templates = templates_lib.Templates( + query_sequence=sequence, + hits=[], + max_template_date=templates_config.filter_config.max_template_date, + structure_store=structure_stores.StructureStore(pdb_database_path), + ) + return protein_templates + + +# Cache to avoid re-running the MSA tools for the same sequence in homomers. +@functools.cache +def _get_protein_msa_and_templates( + sequence: str, + run_template_search: bool, + uniref90_msa_config: msa_config.RunConfig, + mgnify_msa_config: msa_config.RunConfig, + small_bfd_msa_config: msa_config.RunConfig, + uniprot_msa_config: msa_config.RunConfig, + templates_config: msa_config.TemplatesConfig, + pdb_database_path: str, +) -> tuple[msa.Msa, msa.Msa, templates_lib.Templates]: + """Processes a single protein chain.""" + logging.info('Getting protein MSAs for sequence %s', sequence) + msa_start_time = time.time() + # Run various MSA tools in parallel. Use a ThreadPoolExecutor because + # they're not blocked by the GIL, as they're sub-shelled out. 
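+  # (max_workers=4 matches the four database searches submitted below.)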
+ with futures.ThreadPoolExecutor(max_workers=4) as executor: + uniref90_msa_future = executor.submit( + msa.get_msa, + target_sequence=sequence, + run_config=uniref90_msa_config, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ) + mgnify_msa_future = executor.submit( + msa.get_msa, + target_sequence=sequence, + run_config=mgnify_msa_config, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ) + small_bfd_msa_future = executor.submit( + msa.get_msa, + target_sequence=sequence, + run_config=small_bfd_msa_config, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ) + uniprot_msa_future = executor.submit( + msa.get_msa, + target_sequence=sequence, + run_config=uniprot_msa_config, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ) + uniref90_msa = uniref90_msa_future.result() + mgnify_msa = mgnify_msa_future.result() + small_bfd_msa = small_bfd_msa_future.result() + uniprot_msa = uniprot_msa_future.result() + logging.info( + 'Getting protein MSAs took %.2f seconds for sequence %s', + time.time() - msa_start_time, + sequence, + ) + + logging.info('Deduplicating MSAs for sequence %s', sequence) + msa_dedupe_start_time = time.time() + with futures.ThreadPoolExecutor() as executor: + unpaired_protein_msa_future = executor.submit( + msa.Msa.from_multiple_msas, + msas=[uniref90_msa, small_bfd_msa, mgnify_msa], + deduplicate=True, + ) + paired_protein_msa_future = executor.submit( + msa.Msa.from_multiple_msas, msas=[uniprot_msa], deduplicate=False + ) + unpaired_protein_msa = unpaired_protein_msa_future.result() + paired_protein_msa = paired_protein_msa_future.result() + logging.info( + 'Deduplicating MSAs took %.2f seconds for sequence %s', + time.time() - msa_dedupe_start_time, + sequence, + ) + + protein_templates = _get_protein_templates( + sequence=sequence, + input_msa_a3m=unpaired_protein_msa.to_a3m(), + run_template_search=run_template_search, + templates_config=templates_config, + pdb_database_path=pdb_database_path, + ) + + return unpaired_protein_msa, paired_protein_msa, protein_templates + + +# Cache to avoid re-running the Nhmmer for the same sequence in homomers. +@functools.cache +def _get_rna_msa( + sequence: str, + nt_rna_msa_config: msa_config.NhmmerConfig, + rfam_msa_config: msa_config.NhmmerConfig, + rnacentral_msa_config: msa_config.NhmmerConfig, +) -> msa.Msa: + """Processes a single RNA chain.""" + logging.info('Getting RNA MSAs for sequence %s', sequence) + rna_msa_start_time = time.time() + # Run various MSA tools in parallel. Use a ThreadPoolExecutor because + # they're not blocked by the GIL, as they're sub-shelled out. 
+ with futures.ThreadPoolExecutor() as executor: + nt_rna_msa_future = executor.submit( + msa.get_msa, + target_sequence=sequence, + run_config=nt_rna_msa_config, + chain_poly_type=mmcif_names.RNA_CHAIN, + ) + rfam_msa_future = executor.submit( + msa.get_msa, + target_sequence=sequence, + run_config=rfam_msa_config, + chain_poly_type=mmcif_names.RNA_CHAIN, + ) + rnacentral_msa_future = executor.submit( + msa.get_msa, + target_sequence=sequence, + run_config=rnacentral_msa_config, + chain_poly_type=mmcif_names.RNA_CHAIN, + ) + nt_rna_msa = nt_rna_msa_future.result() + rfam_msa = rfam_msa_future.result() + rnacentral_msa = rnacentral_msa_future.result() + logging.info( + 'Getting RNA MSAs took %.2f seconds for sequence %s', + time.time() - rna_msa_start_time, + sequence, + ) + + return msa.Msa.from_multiple_msas( + msas=[rfam_msa, rnacentral_msa, nt_rna_msa], + deduplicate=True, + ) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class DataPipelineConfig: + """The configuration for the data pipeline. + + Attributes: + jackhmmer_binary_path: Jackhmmer binary path, used for protein MSA search. + nhmmer_binary_path: Nhmmer binary path, used for RNA MSA search. + hmmalign_binary_path: Hmmalign binary path, used to align hits to the query + profile. + hmmsearch_binary_path: Hmmsearch binary path, used for template search. + hmmbuild_binary_path: Hmmbuild binary path, used to build HMM profile from + raw MSA in template search. + small_bfd_database_path: Small BFD database path, used for protein MSA + search. + mgnify_database_path: Mgnify database path, used for protein MSA search. + uniprot_cluster_annot_database_path: Uniprot database path, used for protein + paired MSA search. + uniref90_database_path: UniRef90 database path, used for MSA search, and the + MSA obtained by searching it is used to construct the profile for template + search. + ntrna_database_path: NT-RNA database path, used for RNA MSA search. + rfam_database_path: Rfam database path, used for RNA MSA search. + rna_central_database_path: RNAcentral database path, used for RNA MSA + search. + seqres_database_path: PDB sequence database path, used for template search. + pdb_database_path: PDB database directory with mmCIF files path, used for + template search. + jackhmmer_n_cpu: Number of CPUs to use for Jackhmmer. + nhmmer_n_cpu: Number of CPUs to use for Nhmmer. + max_template_date: The latest date of templates to use. + """ + + # Binary paths. + jackhmmer_binary_path: str + nhmmer_binary_path: str + hmmalign_binary_path: str + hmmsearch_binary_path: str + hmmbuild_binary_path: str + + # Jackhmmer databases. + small_bfd_database_path: str + mgnify_database_path: str + uniprot_cluster_annot_database_path: str + uniref90_database_path: str + # Nhmmer databases. + ntrna_database_path: str + rfam_database_path: str + rna_central_database_path: str + # Template search databases. + seqres_database_path: str + pdb_database_path: str + + # Optional configuration for MSA tools. 
+ jackhmmer_n_cpu: int = 8 + nhmmer_n_cpu: int = 8 + + max_template_date: datetime.date + + +class DataPipeline: + """Runs the alignment tools and assembles the input features.""" + + def __init__(self, data_pipeline_config: DataPipelineConfig): + """Initializes the data pipeline with default configurations.""" + self._uniref90_msa_config = msa_config.RunConfig( + config=msa_config.JackhmmerConfig( + binary_path=data_pipeline_config.jackhmmer_binary_path, + database_config=msa_config.DatabaseConfig( + name='uniref90', + path=data_pipeline_config.uniref90_database_path, + ), + n_cpu=data_pipeline_config.jackhmmer_n_cpu, + n_iter=1, + e_value=1e-4, + z_value=None, + max_sequences=10_000, + ), + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + crop_size=None, + ) + self._mgnify_msa_config = msa_config.RunConfig( + config=msa_config.JackhmmerConfig( + binary_path=data_pipeline_config.jackhmmer_binary_path, + database_config=msa_config.DatabaseConfig( + name='mgnify', + path=data_pipeline_config.mgnify_database_path, + ), + n_cpu=data_pipeline_config.jackhmmer_n_cpu, + n_iter=1, + e_value=1e-4, + z_value=None, + max_sequences=5_000, + ), + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + crop_size=None, + ) + self._small_bfd_msa_config = msa_config.RunConfig( + config=msa_config.JackhmmerConfig( + binary_path=data_pipeline_config.jackhmmer_binary_path, + database_config=msa_config.DatabaseConfig( + name='small_bfd', + path=data_pipeline_config.small_bfd_database_path, + ), + n_cpu=data_pipeline_config.jackhmmer_n_cpu, + n_iter=1, + e_value=1e-4, + # Set z_value=138_515_945 to match the z_value used in the paper. + # In practice, this has minimal impact on predicted structures. + z_value=None, + max_sequences=5_000, + ), + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + crop_size=None, + ) + self._uniprot_msa_config = msa_config.RunConfig( + config=msa_config.JackhmmerConfig( + binary_path=data_pipeline_config.jackhmmer_binary_path, + database_config=msa_config.DatabaseConfig( + name='uniprot_cluster_annot', + path=data_pipeline_config.uniprot_cluster_annot_database_path, + ), + n_cpu=data_pipeline_config.jackhmmer_n_cpu, + n_iter=1, + e_value=1e-4, + z_value=None, + max_sequences=50_000, + ), + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + crop_size=None, + ) + self._nt_rna_msa_config = msa_config.RunConfig( + config=msa_config.NhmmerConfig( + binary_path=data_pipeline_config.nhmmer_binary_path, + hmmalign_binary_path=data_pipeline_config.hmmalign_binary_path, + hmmbuild_binary_path=data_pipeline_config.hmmbuild_binary_path, + database_config=msa_config.DatabaseConfig( + name='nt_rna', + path=data_pipeline_config.ntrna_database_path, + ), + n_cpu=data_pipeline_config.nhmmer_n_cpu, + e_value=1e-3, + alphabet='rna', + max_sequences=10_000, + ), + chain_poly_type=mmcif_names.RNA_CHAIN, + crop_size=None, + ) + self._rfam_msa_config = msa_config.RunConfig( + config=msa_config.NhmmerConfig( + binary_path=data_pipeline_config.nhmmer_binary_path, + hmmalign_binary_path=data_pipeline_config.hmmalign_binary_path, + hmmbuild_binary_path=data_pipeline_config.hmmbuild_binary_path, + database_config=msa_config.DatabaseConfig( + name='rfam_rna', + path=data_pipeline_config.rfam_database_path, + ), + n_cpu=data_pipeline_config.nhmmer_n_cpu, + e_value=1e-3, + alphabet='rna', + max_sequences=10_000, + ), + chain_poly_type=mmcif_names.RNA_CHAIN, + crop_size=None, + ) + self._rnacentral_msa_config = msa_config.RunConfig( + config=msa_config.NhmmerConfig( + binary_path=data_pipeline_config.nhmmer_binary_path, + 
hmmalign_binary_path=data_pipeline_config.hmmalign_binary_path, + hmmbuild_binary_path=data_pipeline_config.hmmbuild_binary_path, + database_config=msa_config.DatabaseConfig( + name='rna_central_rna', + path=data_pipeline_config.rna_central_database_path, + ), + n_cpu=data_pipeline_config.nhmmer_n_cpu, + e_value=1e-3, + alphabet='rna', + max_sequences=10_000, + ), + chain_poly_type=mmcif_names.RNA_CHAIN, + crop_size=None, + ) + + self._templates_config = msa_config.TemplatesConfig( + template_tool_config=msa_config.TemplateToolConfig( + database_path=data_pipeline_config.seqres_database_path, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + hmmsearch_config=msa_config.HmmsearchConfig( + hmmsearch_binary_path=data_pipeline_config.hmmsearch_binary_path, + hmmbuild_binary_path=data_pipeline_config.hmmbuild_binary_path, + filter_f1=0.1, + filter_f2=0.1, + filter_f3=0.1, + e_value=100, + inc_e=100, + dom_e=100, + incdom_e=100, + alphabet='amino', + ), + ), + filter_config=msa_config.TemplateFilterConfig( + max_subsequence_ratio=0.95, + min_align_ratio=0.1, + min_hit_length=10, + deduplicate_sequences=True, + max_hits=4, + max_template_date=data_pipeline_config.max_template_date, + ), + ) + self._pdb_database_path = data_pipeline_config.pdb_database_path + + def process_protein_chain( + self, chain: folding_input.ProteinChain + ) -> folding_input.ProteinChain: + """Processes a single protein chain.""" + has_unpaired_msa = chain.unpaired_msa is not None + has_paired_msa = chain.paired_msa is not None + has_templates = chain.templates is not None + + if not has_unpaired_msa and not has_paired_msa and not chain.templates: + # MSA None - search. Templates either [] - don't search, or None - search. + unpaired_msa, paired_msa, template_hits = _get_protein_msa_and_templates( + sequence=chain.sequence, + # Skip template search if []. + run_template_search=not has_templates, + uniref90_msa_config=self._uniref90_msa_config, + mgnify_msa_config=self._mgnify_msa_config, + small_bfd_msa_config=self._small_bfd_msa_config, + uniprot_msa_config=self._uniprot_msa_config, + templates_config=self._templates_config, + pdb_database_path=self._pdb_database_path, + ) + unpaired_msa = unpaired_msa.to_a3m() + paired_msa = paired_msa.to_a3m() + templates = [ + folding_input.Template( + mmcif=struct.to_mmcif(), + query_to_template_map=hit.query_to_hit_mapping, + ) + for hit, struct in template_hits.get_hits_with_structures() + ] + elif has_unpaired_msa and has_paired_msa and not has_templates: + # Has MSA, but doesn't have templates. Search for templates only. + empty_msa = msa.Msa.from_empty( + query_sequence=chain.sequence, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ).to_a3m() + unpaired_msa = chain.unpaired_msa or empty_msa + paired_msa = chain.paired_msa or empty_msa + template_hits = _get_protein_templates( + sequence=chain.sequence, + input_msa_a3m=unpaired_msa, + run_template_search=True, + templates_config=self._templates_config, + pdb_database_path=self._pdb_database_path, + ) + templates = [ + folding_input.Template( + mmcif=struct.to_mmcif(), + query_to_template_map=hit.query_to_hit_mapping, + ) + for hit, struct in template_hits.get_hits_with_structures() + ] + else: + # Has MSA and templates, don't search for anything. + if not has_unpaired_msa or not has_paired_msa or not has_templates: + raise ValueError( + f'Protein chain {chain.id} has unpaired MSA, paired MSA, or' + ' templates set only partially. If you want to run the pipeline' + ' with custom MSA/templates, you need to set all of them. 
You can' + ' set MSA to empty string and templates to empty list to signify' + ' that they should not be used and searched for.' + ) + logging.info( + 'Skipping MSA and template search for protein chain %s because it ' + 'already has MSAs and templates.', + chain.id, + ) + if not chain.unpaired_msa: + logging.info( + 'Using empty unpaired MSA for protein chain %s', chain.id) + if not chain.paired_msa: + logging.info( + 'Using empty paired MSA for protein chain %s', chain.id) + if not chain.templates: + logging.info( + 'Using no templates for protein chain %s', chain.id) + empty_msa = msa.Msa.from_empty( + query_sequence=chain.sequence, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ).to_a3m() + unpaired_msa = chain.unpaired_msa or empty_msa + paired_msa = chain.paired_msa or empty_msa + templates = chain.templates + + return dataclasses.replace( + chain, + unpaired_msa=unpaired_msa, + paired_msa=paired_msa, + templates=templates, + ) + + def process_rna_chain( + self, chain: folding_input.RnaChain + ) -> folding_input.RnaChain: + """Processes a single RNA chain.""" + if chain.unpaired_msa is not None: + # Don't run MSA tools if the chain already has an MSA. + logging.info( + 'Skipping MSA search for RNA chain %s because it already has MSA.', + chain.id, + ) + if not chain.unpaired_msa: + logging.info( + 'Using empty unpaired MSA for RNA chain %s', chain.id) + empty_msa = msa.Msa.from_empty( + query_sequence=chain.sequence, chain_poly_type=mmcif_names.RNA_CHAIN + ).to_a3m() + unpaired_msa = chain.unpaired_msa or empty_msa + else: + unpaired_msa = _get_rna_msa( + sequence=chain.sequence, + nt_rna_msa_config=self._nt_rna_msa_config, + rfam_msa_config=self._rfam_msa_config, + rnacentral_msa_config=self._rnacentral_msa_config, + ).to_a3m() + return dataclasses.replace(chain, unpaired_msa=unpaired_msa) + + def process(self, fold_input: folding_input.Input) -> folding_input.Input: + """Runs MSA and template tools and returns a new Input with the results.""" + processed_chains = [] + for chain in fold_input.chains: + print(f'Processing chain {chain.id}') + process_chain_start_time = time.time() + match chain: + case folding_input.ProteinChain(): + processed_chains.append(self.process_protein_chain(chain)) + case folding_input.RnaChain(): + processed_chains.append(self.process_rna_chain(chain)) + case _: + processed_chains.append(chain) + print( + f'Processing chain {chain.id} took' + f' {time.time() - process_chain_start_time:.2f} seconds', + ) + + return dataclasses.replace(fold_input, chains=processed_chains) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/structure_stores.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/structure_stores.py new file mode 100644 index 0000000000000000000000000000000000000000..afaa10d320be780f211bc9bfa922b47e87764343 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/structure_stores.py @@ -0,0 +1,102 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Library for loading structure data from various sources.""" + +from collections.abc import Mapping, Sequence +import functools +import os +import pathlib +import tarfile + + +class NotFoundError(KeyError): + """Raised when the structure store doesn't contain the requested target.""" + + +class StructureStore: + """Handles the retrieval of mmCIF files from a filesystem.""" + + def __init__( + self, + structures: str | os.PathLike[str] | Mapping[str, str], + ): + """Initialises the instance. + + Args: + structures: Path of the directory where the mmCIF files are or a Mapping + from target name to mmCIF string. + """ + if isinstance(structures, Mapping): + self._structure_mapping = structures + self._structure_path = None + self._structure_tar = None + else: + self._structure_mapping = None + path_str = os.fspath(structures) + if path_str.endswith('.tar'): + self._structure_tar = tarfile.open(path_str, 'r') + self._structure_path = None + else: + self._structure_path = pathlib.Path(structures) + self._structure_tar = None + + @functools.cached_property + def _tar_members(self) -> Mapping[str, tarfile.TarInfo]: + assert self._structure_tar is not None + return { + path.stem: tarinfo + for tarinfo in self._structure_tar.getmembers() + if tarinfo.isfile() + and (path := pathlib.Path(tarinfo.path.lower())).suffix == '.cif' + } + + def get_mmcif_str(self, target_name: str) -> str: + """Returns an mmCIF for a given `target_name`. + + Args: + target_name: Name specifying the target mmCIF. + + Raises: + NotFoundError: If the target is not found. + """ + if self._structure_mapping is not None: + try: + return self._structure_mapping[target_name] + except KeyError as e: + raise NotFoundError(f'{target_name=} not found') from e + + if self._structure_tar is not None: + try: + member = self._tar_members[target_name] + if struct_file := self._structure_tar.extractfile(member): + return struct_file.read().decode() + else: + raise NotFoundError(f'{target_name=} not found') + except KeyError: + raise NotFoundError(f'{target_name=} not found') from None + + filepath = self._structure_path / f'{target_name}.cif' + try: + return filepath.read_text() + except FileNotFoundError as e: + raise NotFoundError( + f'{target_name=} not found at {filepath=}') from e + + def target_names(self) -> Sequence[str]: + """Returns all targets in the store.""" + if self._structure_mapping is not None: + return [*self._structure_mapping.keys()] + elif self._structure_tar is not None: + return sorted(self._tar_members.keys()) + elif self._structure_path is not None: + return sorted([path.stem for path in self._structure_path.glob('*.cif')]) + return () diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_realign.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_realign.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4d0215d889d710d15ca6515fb023661934544e --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_realign.py @@ -0,0 +1,170 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Realign sequences found in PDB seqres to the actual CIF sequences.""" + +from collections.abc import Mapping + + +class AlignmentError(Exception): + """Failed alignment between the hit sequence and the actual mmCIF sequence.""" + + +def realign_hit_to_structure( + *, + hit_sequence: str, + hit_start_index: int, + hit_end_index: int, + full_length: int, + structure_sequence: str, + query_to_hit_mapping: Mapping[int, int], +) -> Mapping[int, int]: + """Realigns the hit sequence to the Structure sequence. + + For example, for the given input: + query_sequence : ABCDEFGHIJKL + hit_sequence : ---DEFGHIJK- + struc_sequence : XDEFGHKL + the mapping is {3: 0, 4: 1, 5: 2, 6: 3, 7: 4, 8: 5, 9: 6, 10: 7}. However, the + actual Structure sequence has an extra X at the start as well as no IJ. So the + alignment from the query to the Structure sequence will be: + hit_sequence : ---DEFGHIJK- + struc_aligned : --XDEFGH--KL + and the new mapping will therefore be: {3: 1, 4: 2, 5: 3, 6: 4, 7: 5, 10: 6}. + + Args: + hit_sequence: The PDB seqres hit sequence obtained from Hmmsearch, but + without any gaps. This is not the full PDB seqres template sequence but + rather just its subsequence from hit_start_index to hit_end_index. + hit_start_index: The start index of the hit sequence in the full PDB seqres + template sequence (inclusive). + hit_end_index: The end index of the hit sequence in the full PDB seqres + template sequence (exclusive). + full_length: The length of the full PDB seqres template sequence. + structure_sequence: The actual sequence extracted from the Structure + corresponding to this template. In vast majority of cases this is the same + as the PDB seqres sequence, but this function handles the cases when not. + query_to_hit_mapping: The mapping from the query sequence to the + hit_sequence. + + Raises: + AlignmentError: if the alignment between the sequence returned by Hmmsearch + differs from the actual sequence found in the mmCIF and can't be aligned + using the simple alignment algorithm. + + Returns: + A mapping from the query sequence to the actual Structure sequence. + """ + max_num_gaps = full_length - len(structure_sequence) + if max_num_gaps < 0: + raise AlignmentError( + f'The Structure sequence ({len(structure_sequence)}) ' + f'must be shorter than the PDB seqres sequence ({full_length}):\n' + f'Structure sequence : {structure_sequence}\n' + f'PDB seqres sequence: {hit_sequence}' + ) + + if len(hit_sequence) != hit_end_index - hit_start_index: + raise AlignmentError( + f'The difference of {hit_end_index=} and {hit_start_index=} does not ' + f'equal to the length of the {hit_sequence}: {len(hit_sequence)}' + ) + + best_score = -1 + best_start = 0 + best_query_to_hit_mapping = query_to_hit_mapping + max_num_gaps_before_subseq = min(hit_start_index, max_num_gaps) + # It is possible the gaps needed to align the PDB seqres subsequence and + # the Structure subsequence need to be inserted before the match region. + # Try and pick the alignment with the best number of aligned residues. 
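Worked out by hand, the docstring example above amounts to composing the query-to-hit mapping with a hit-to-structure alignment. The indices below are copied directly from that example; this is an illustration, not code from this module.

```python
# Indices from the docstring example: hit DEFGHIJK aligned to structure
# subsequence DEFGH--KL (I and J have no structure residue), composed with the
# original query->hit mapping.
query_to_hit = {3: 0, 4: 1, 5: 2, 6: 3, 7: 4, 8: 5, 9: 6, 10: 7}
hit_to_struct = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 7: 6}
query_to_struct = {
    q: hit_to_struct[h] for q, h in query_to_hit.items() if h in hit_to_struct
}
assert query_to_struct == {3: 1, 4: 2, 5: 3, 6: 4, 7: 5, 10: 6}
```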
+ for num_gaps_before_subseq in range(0, max_num_gaps_before_subseq + 1): + start = hit_start_index - num_gaps_before_subseq + end = hit_end_index - num_gaps_before_subseq + structure_subseq = structure_sequence[start:end] + + new_query_to_hit_mapping, score = _remap_to_struc_seq( + hit_seq=hit_sequence, + struc_seq=structure_subseq, + max_num_gaps=max_num_gaps - num_gaps_before_subseq, + mapping=query_to_hit_mapping, + ) + if score >= best_score: + # Use >= to prefer matches with larger number of gaps before. + best_score = score + best_start = start + best_query_to_hit_mapping = new_query_to_hit_mapping + + return {q: h + best_start for q, h in best_query_to_hit_mapping.items()} + + +def _remap_to_struc_seq( + *, + hit_seq: str, + struc_seq: str, + max_num_gaps: int, + mapping: Mapping[int, int], +) -> tuple[Mapping[int, int], int]: + """Remaps the query -> hit mapping to match the actual Structure sequence. + + Args: + hit_seq: The hit sequence - a subsequence of the PDB seqres sequence without + any Hmmsearch modifications like inserted gaps or lowercased residues. + struc_seq: The actual sequence obtained from the corresponding Structure. + max_num_gaps: The maximum number of gaps that can be inserted in the + Structure sequence. In practice, this is the length difference between the + PDB seqres sequence and the actual Structure sequence. + mapping: The mapping from the query residues to the hit residues. This will + be remapped to point to the actual Structure sequence using a simple + realignment algorithm. + + Returns: + A tuple of (mapping, score): + * Mapping from the query to the actual Structure sequence. + * Score which is the number of matching aligned residues. + + Raises: + ValueError if the structure sequence isn't shorter than the seqres sequence. + ValueError if the alignment fails. + """ + hit_seq_idx = 0 + struc_seq_idx = 0 + hit_to_struc_seq_mapping = {} + score = 0 + + # This while loop is guaranteed to terminate since we increase both + # struc_seq_idx and hit_seq_idx by at least 1 in each iteration. + remaining_num_gaps = max_num_gaps + while hit_seq_idx < len(hit_seq) and struc_seq_idx < len(struc_seq): + if hit_seq[hit_seq_idx] != struc_seq[struc_seq_idx]: + # Explore which alignment aligns the next residue (if present). + best_shift = 0 + for shift in range(0, remaining_num_gaps + 1): + next_hit_res = hit_seq[hit_seq_idx + + shift: hit_seq_idx + shift + 1] + next_struc_res = struc_seq[struc_seq_idx: struc_seq_idx + 1] + if next_hit_res == next_struc_res: + best_shift = shift + break + hit_seq_idx += best_shift + remaining_num_gaps -= best_shift + + hit_to_struc_seq_mapping[hit_seq_idx] = struc_seq_idx + score += hit_seq[hit_seq_idx] == struc_seq[struc_seq_idx] + hit_seq_idx += 1 + struc_seq_idx += 1 + + fixed_mapping = {} + for query_idx, original_hit_idx in mapping.items(): + fixed_hit_idx = hit_to_struc_seq_mapping.get(original_hit_idx) + if fixed_hit_idx is not None: + fixed_mapping[query_idx] = fixed_hit_idx + + return fixed_mapping, score diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_store.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_store.py new file mode 100644 index 0000000000000000000000000000000000000000..004443960e0ed6f1581665da219b6a07f2829b12 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_store.py @@ -0,0 +1,47 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Interface and implementations for fetching templates data.""" + +from collections.abc import Mapping +import datetime +from typing import Any, Protocol, TypeAlias + + +TemplateFeatures: TypeAlias = Mapping[str, Any] + + +class TemplateFeatureProvider(Protocol): + """Interface for providing Template Features.""" + + def __call__( + self, + sequence: str, + release_date: datetime.date | None, + include_ligand_features: bool = True, + ) -> TemplateFeatures: + """Retrieve template features for the given sequence and release_date. + + Args: + sequence: The residue sequence of the query. + release_date: The release_date of the template query, this is used to + filter templates for training, ensuring that they do not leak structure + information from the future. + include_ligand_features: Whether to include ligand features. + + Returns: + Template features: A mapping of template feature labels to features, which + may be numpy arrays, bytes objects, or for the special case of label + `ligand_features`, a nested feature map of labels to numpy arrays. + + Raises: + TemplateRetrievalError if the template features were not found. + """ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/templates.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/templates.py new file mode 100644 index 0000000000000000000000000000000000000000..060dc7b83a39f6ec18ec5f142713ccb09265035b --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/templates.py @@ -0,0 +1,974 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""API for retrieving and manipulating template search results.""" + +from collections.abc import Iterable, Iterator, Mapping, Sequence +import dataclasses +import datetime +import functools +import os +import re +from typing import Any, Final, Self, TypeAlias +import numpy as np +from absl import logging +from alphafold3 import structure +from alphafold3.common import resources +from alphafold3.constants import atom_types +from alphafold3.constants import mmcif_names +from alphafold3.constants import residue_names +from alphafold3.data import msa_config +from alphafold3.data import parsers +from alphafold3.data import structure_stores +from alphafold3.data import template_realign +from alphafold3.data.tools import hmmsearch +from alphafold3.structure import mmcif + + +_POLYMER_FEATURES: Final[Mapping[str, np.float64 | np.int32 | object]] = { + 'template_aatype': np.int32, + 'template_all_atom_masks': np.float64, + 'template_all_atom_positions': np.float64, + 'template_domain_names': object, + 'template_release_date': object, + 'template_sequence': object, +} + +_LIGAND_FEATURES: Final[Mapping[str, Any]] = { + 'ligand_features': Mapping[str, Any] +} + + +TemplateFeatures: TypeAlias = Mapping[ + str, np.ndarray | bytes | Mapping[str, np.ndarray | bytes] +] +_REQUIRED_METADATA_COLUMNS: Final[Sequence[str]] = ( + 'seq_release_date', + 'seq_unresolved_res_num', + 'seq_author_chain_id', + 'seq_sequence', +) + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class _Polymer: + """Container for alphabet specific (dna, rna, protein) atom information.""" + + min_atoms: int + num_atom_types: int + atom_order: Mapping[str, int] + + +_POLYMERS = { + mmcif_names.PROTEIN_CHAIN: _Polymer( + min_atoms=5, + num_atom_types=atom_types.ATOM37_NUM, + atom_order=atom_types.ATOM37_ORDER, + ), + mmcif_names.DNA_CHAIN: _Polymer( + min_atoms=21, + num_atom_types=atom_types.ATOM29_NUM, + atom_order=atom_types.ATOM29_ORDER, + ), + mmcif_names.RNA_CHAIN: _Polymer( + min_atoms=20, + num_atom_types=atom_types.ATOM29_NUM, + atom_order=atom_types.ATOM29_ORDER, + ), +} + + +def _encode_restype( + chain_poly_type: str, + sequence: str, +) -> Sequence[int]: + """Encodes a sequence of residue names as a sequence of ints. + + Args: + chain_poly_type: Polymer chain type to determine sequence encoding. + sequence: Polymer residues. Protein encoded by single letters. RNA and DNA + encoded by multi-letter CCD codes. + + Returns: + A sequence of integers encoding amino acid types for the given chain type. + """ + if chain_poly_type == mmcif_names.PROTEIN_CHAIN: + return [ + residue_names.PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP_TO_INT[ + _STANDARDIZED_AA.get(res, res) + ] + for res in sequence + ] + + unk_nucleic = residue_names.UNK_NUCLEIC_ONE_LETTER + unk_nucleic_idx = residue_names.POLYMER_TYPES_ORDER_WITH_UNKNOWN_AND_GAP[ + unk_nucleic + ] + if chain_poly_type == mmcif_names.RNA_CHAIN: + return [ + residue_names.POLYMER_TYPES_ORDER_WITH_UNKNOWN_AND_GAP.get( + res, unk_nucleic_idx + ) + for res in sequence + ] + elif chain_poly_type == mmcif_names.DNA_CHAIN: + # Map UNK DNA to the generic nucleic UNK (N), which happens to also be the + # same as the RNA UNK. 
+ return [ + residue_names.POLYMER_TYPES_ORDER_WITH_UNKNOWN_AND_GAP.get( + residue_names.DNA_COMMON_ONE_TO_TWO.get(res, unk_nucleic), + unk_nucleic_idx, + ) + for res in sequence + ] + + raise NotImplementedError(f'"{chain_poly_type}" unsupported.') + + +_DAYS_BEFORE_QUERY_DATE: Final[int] = 60 +_HIT_DESCRIPTION_REGEX = re.compile( + r'(?P[a-z0-9]{4,})_(?P\w+)/(?P\d+)-(?P\d+) ' + r'.* length:(?P\d+)\b.*' +) + +_STANDARDIZED_AA = {'B': 'D', 'J': 'X', 'O': 'X', 'U': 'C', 'Z': 'E'} + + +class Error(Exception): + """Base class for exceptions.""" + + +class HitDateError(Error): + """An error indicating that invalid release date was detected.""" + + +class InvalidTemplateError(Error): + """An error indicating that template is invalid.""" + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class Hit: + """Template hit metrics derived from the MSA for filtering and featurising. + + Attributes: + pdb_id: The PDB ID of the hit. + auth_chain_id: The author chain ID of the hit. + hmmsearch_sequence: Hit sequence as given in hmmsearch a3m output. + structure_sequence: Hit sequence as given in PDB structure. + unresolved_res_indices: Indices of unresolved residues in the structure + sequence. 0-based. + query_sequence: The query nucleotide/amino acid sequence. + start_index: The start index of the sequence relative to the full PDB seqres + sequence. Inclusive and uses 0-based indexing. + end_index: The end index of the sequence relative to the full PDB seqres + sequence. Exclusive and uses 0-based indexing. + full_length: Length of the full PDB seqres sequence. This can be different + from the length from the actual sequence we get from the mmCIF and we use + this to detect whether we need to realign or not. + release_date: The release date of the PDB corresponding to this hit. + chain_poly_type: The polymer type of the selected hit structure. + """ + + pdb_id: str + auth_chain_id: str + hmmsearch_sequence: str + structure_sequence: str + unresolved_res_indices: Sequence[int] | None + query_sequence: str + start_index: int + end_index: int + full_length: int + release_date: datetime.date + chain_poly_type: str + + @functools.cached_property + def query_to_hit_mapping(self) -> Mapping[int, int]: + """0-based query index to hit index mapping.""" + query_to_hit_mapping = {} + hit_index = 0 + query_index = 0 + for residue in self.hmmsearch_sequence: + # Gap inserted in the template + if residue == '-': + query_index += 1 + # Deleted residue in the template (would be a gap in the query). + elif residue.islower(): + hit_index += 1 + # Normal aligned residue, in both query and template. Add to mapping. + elif residue.isupper(): + query_to_hit_mapping[query_index] = hit_index + query_index += 1 + hit_index += 1 + + structure_subseq = self.structure_sequence[ + self.start_index: self.end_index + ] + if self.matching_sequence != structure_subseq: + # The seqres sequence doesn't match the structure sequence. Two cases: + # 1. The sequences have the same length. The sequences are different + # because our 3->1 residue code mapping is different from the one PDB + # uses. We don't do anything in this case as both sequences have the + # same length, so the original query to hit mapping stays valid. + # 2. The sequences don't have the same length, the one in structure is + # shorter. In this case we change the mapping to match the actual + # structure sequence using a simple realignment algorithm. + # This procedure was validated on all PDB seqres (2023_01_12) sequences + # and handles all cases that can happen. 
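The `query_to_hit_mapping` parse above follows the usual hmmsearch A3M conventions: `-` consumes a query position, a lowercase residue consumes a hit position, and an uppercase residue consumes both and records a pair. Worked through on a toy hit string (not a real hmmsearch output):

```python
# Hand-worked illustration of the A3M conventions described above.
mapping, query_index, hit_index = {}, 0, 0
for residue in 'MK-aV':
    if residue == '-':
        query_index += 1
    elif residue.islower():
        hit_index += 1
    else:
        mapping[query_index] = hit_index
        query_index += 1
        hit_index += 1
print(mapping)  # {0: 0, 1: 1, 3: 3}
```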
+ if self.full_length != len(self.structure_sequence): + return template_realign.realign_hit_to_structure( + hit_sequence=self.matching_sequence, + hit_start_index=self.start_index, + hit_end_index=self.end_index, + full_length=self.full_length, + structure_sequence=self.structure_sequence, + query_to_hit_mapping=query_to_hit_mapping, + ) + + # Hmmsearch returns a subsequence and so far indices have been relative to + # the subsequence. Add an offset to index relative to the full structure + # sequence. + return {q: h + self.start_index for q, h in query_to_hit_mapping.items()} + + @property + def matching_sequence(self) -> str: + """Returns the matching hit sequence including insertions. + + Make deleted residues uppercase and remove gaps ("-"). + """ + return self.hmmsearch_sequence.upper().replace('-', '') + + @functools.cached_property + def output_templates_sequence(self) -> str: + """Returns the final template sequence.""" + result_seq = ['-'] * len(self.query_sequence) + for query_index, template_index in self.query_to_hit_mapping.items(): + result_seq[query_index] = self.structure_sequence[template_index] + return ''.join(result_seq) + + @property + def length_ratio(self) -> float: + """Ratio of the length of the hit sequence to the query.""" + return len(self.matching_sequence) / len(self.query_sequence) + + @property + def align_ratio(self) -> float: + """Ratio of the number of aligned residues to the query length.""" + return len(self.query_to_hit_mapping) / len(self.query_sequence) + + @functools.cached_property + def is_valid(self) -> bool: + """Whether hit can be used as a template.""" + if self.unresolved_res_indices is None: + return False + + return bool( + set(self.query_to_hit_mapping.values()) + - set(self.unresolved_res_indices) + ) + + @property + def full_name(self) -> str: + """A full name of the hit.""" + return f'{self.pdb_id}_{self.auth_chain_id}' + + def __post_init__(self): + if not self.pdb_id.islower() and not self.pdb_id.isdigit(): + raise ValueError(f'pdb_id must be lowercase {self.pdb_id}') + + if not (0 <= self.start_index <= self.end_index): + raise ValueError( + 'Start must be non-negative and less than or equal to end index. ' + f'Range: {self.start_index}-{self.end_index}' + ) + + if len(self.matching_sequence) != (self.end_index - self.start_index): + raise ValueError( + 'Sequence length must be equal to end_index - start_index. ' + f'{len(self.matching_sequence)} != {self.end_index} - ' + f'{self.start_index}' + ) + + if self.full_length < 0: + raise ValueError( + f'Full length must be non-negative: {self.full_length}') + + def keep( + self, + *, + release_date_cutoff: datetime.date | None, + max_subsequence_ratio: float | None, + min_hit_length: int | None, + min_align_ratio: float | None, + ) -> bool: + """Returns whether the hit should be kept. + + In addition to filtering on all of the provided parameters, this method also + excludes hits with unresolved residues. + + Args: + release_date_cutoff: Maximum release date of the template. + max_subsequence_ratio: If set, excludes hits which are an exact + subsequence of the query sequence, and longer than this ratio. Useful to + avoid ground truth leakage. + min_hit_length: If set, excludes hits which have fewer residues than this. + min_align_ratio: If set, excludes hits where the number of residues + aligned to the query is less than this proportion of the template + length. + """ + # Exclude hits which are too recent. 
+ if ( + release_date_cutoff is not None + and self.release_date > release_date_cutoff + ): + return False + + # Exclude hits which are large duplicates of the query_sequence. + if ( + max_subsequence_ratio is not None + and self.length_ratio > max_subsequence_ratio + ): + if self.matching_sequence in self.query_sequence: + return False + + # Exclude hits which are too short. + if ( + min_hit_length is not None + and len(self.matching_sequence) < min_hit_length + ): + return False + + # Exclude hits with unresolved residues. + if not self.is_valid: + return False + + # Exclude hits with too few alignments. + try: + if min_align_ratio is not None and self.align_ratio <= min_align_ratio: + return False + except template_realign.AlignmentError as e: + logging.warning('Failed to align %s: %s', self, str(e)) + return False + + return True + + +def _filter_hits( + hits: Iterable[Hit], + release_date_cutoff: datetime.date, + max_subsequence_ratio: float | None, + min_align_ratio: float | None, + min_hit_length: int | None, + deduplicate_sequences: bool, + max_hits: int | None, +) -> Sequence[Hit]: + """Filters hits based on the filter config.""" + filtered_hits = [] + seen_before = set() + for hit in hits: + if not hit.keep( + max_subsequence_ratio=max_subsequence_ratio, + min_align_ratio=min_align_ratio, + min_hit_length=min_hit_length, + release_date_cutoff=release_date_cutoff, + ): + continue + + # Remove duplicate templates, keeping the first. + if deduplicate_sequences: + if hit.output_templates_sequence in seen_before: + continue + seen_before.add(hit.output_templates_sequence) + + filtered_hits.append(hit) + if max_hits and len(filtered_hits) == max_hits: + break + + return filtered_hits + + +@dataclasses.dataclass(init=False) +class Templates: + """A container for templates that were found for the given query sequence. + + The structure_store is constructed from the config by default. Callers can + optionally supply a structure_store to the constructor to avoid the cost of + construction and metadata loading. + """ + + def __init__( + self, + *, + query_sequence: str, + hits: Sequence[Hit], + max_template_date: datetime.date, + structure_store: structure_stores.StructureStore, + query_release_date: datetime.date | None = None, + ): + self._query_sequence = query_sequence + self._hits = tuple(hits) + self._max_template_date = max_template_date + self._query_release_date = query_release_date + self._hit_structures = {} + self._structure_store = structure_store + + if any(h.query_sequence != self._query_sequence for h in self.hits): + raise ValueError('All hits must match the query sequence.') + + if self._hits: + chain_poly_type = self._hits[0].chain_poly_type + if any(h.chain_poly_type != chain_poly_type for h in self.hits): + raise ValueError( + 'All hits must have the same chain_poly_type.') + + @classmethod + def from_seq_and_a3m( + cls, + *, + query_sequence: str, + msa_a3m: str, + max_template_date: datetime.date, + database_path: os.PathLike[str] | str, + hmmsearch_config: msa_config.HmmsearchConfig, + max_a3m_query_sequences: int | None, + structure_store: structure_stores.StructureStore, + filter_config: msa_config.TemplateFilterConfig | None = None, + query_release_date: datetime.date | None = None, + chain_poly_type: str = mmcif_names.PROTEIN_CHAIN, + ) -> Self: + """Creates templates from a run of hmmsearch tool against a custom a3m. + + Args: + query_sequence: The polymer sequence of the target query. 
+ msa_a3m: An a3m of related polymers aligned to the query sequence, this is + used to create an HMM for the hmmsearch run. + max_template_date: This is used to filter templates for training, ensuring + that they do not leak ground truth information used in testing sets. + database_path: A path to the sequence database to search for templates. + hmmsearch_config: Config with Hmmsearch settings. + max_a3m_query_sequences: The maximum number of input MSA sequences to use + to construct the profile which is then used to search for templates. + structure_store: Structure store to fetch template structures from. + filter_config: Optional config that controls which and how many hits to + keep. More performant than constructing and then filtering. If not + provided, no filtering is done. + query_release_date: The release_date of the template query, this is used + to filter templates for training, ensuring that they do not leak + structure information from the future. + chain_poly_type: The polymer type of the templates. + + Returns: + Templates object containing a list of Hits initialised from the + structure_store metadata and a3m alignments. + """ + hmmsearch_a3m = run_hmmsearch_with_a3m( + database_path=database_path, + hmmsearch_config=hmmsearch_config, + max_a3m_query_sequences=max_a3m_query_sequences, + a3m=msa_a3m, + ) + return cls.from_hmmsearch_a3m( + query_sequence=query_sequence, + a3m=hmmsearch_a3m, + max_template_date=max_template_date, + query_release_date=query_release_date, + chain_poly_type=chain_poly_type, + structure_store=structure_store, + filter_config=filter_config, + ) + + @classmethod + def from_hmmsearch_a3m( + cls, + *, + query_sequence: str, + a3m: str, + max_template_date: datetime.date, + structure_store: structure_stores.StructureStore, + filter_config: msa_config.TemplateFilterConfig | None = None, + query_release_date: datetime.date | None = None, + chain_poly_type: str = mmcif_names.PROTEIN_CHAIN, + ) -> Self: + """Creates Templates from a Hmmsearch A3M. + + Args: + query_sequence: The polymer sequence of the target query. + a3m: Results of Hmmsearch in A3M format. This provides a list of potential + template alignments and pdb codes. + max_template_date: This is used to filter templates for training, ensuring + that they do not leak ground truth information used in testing sets. + structure_store: Structure store to fetch template structures from. + filter_config: Optional config that controls which and how many hits to + keep. More performant than constructing and then filtering. If not + provided, no filtering is done. + query_release_date: The release_date of the template query, this is used + to filter templates for training, ensuring that they do not leak + structure information from the future. + chain_poly_type: The polymer type of the templates. + + Returns: + Templates object containing a list of Hits initialised from the + structure_store metadata and a3m alignments. + """ + + def hit_generator(a3m: str): + for hit_seq, hit_desc in parsers.lazy_parse_fasta_string(a3m): + pdb_id, auth_chain_id, start, end, full_length = _parse_hit_description( + hit_desc + ) + + release_date, sequence, unresolved_res_ids = _parse_hit_metadata( + structure_store, pdb_id, auth_chain_id + ) + if unresolved_res_ids is None: + continue + + # seq_unresolved_res_num are 1-based, setting to 0-based indices. 
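A hedged usage sketch of the classmethod documented above is shown below. All paths and dates are placeholders, the hmmsearch settings simply mirror the values `DataPipeline` uses elsewhere in this patch, and this is not the pipeline's actual call site; it only illustrates how the pieces fit together, including an explicit `filter()` call when no `filter_config` is passed at construction.

```python
# Sketch only: placeholder paths and settings copied from DataPipeline above.
import datetime

from alphafold3.constants import mmcif_names
from alphafold3.data import msa_config
from alphafold3.data import structure_stores
from alphafold3.data import templates as templates_lib

hmmsearch_config = msa_config.HmmsearchConfig(
    hmmsearch_binary_path='/usr/bin/hmmsearch',  # placeholder
    hmmbuild_binary_path='/usr/bin/hmmbuild',    # placeholder
    filter_f1=0.1, filter_f2=0.1, filter_f3=0.1,
    e_value=100, inc_e=100, dom_e=100, incdom_e=100,
    alphabet='amino',
)
protein_templates = templates_lib.Templates.from_seq_and_a3m(
    query_sequence='MKWVTFISLLLLFSSAYS',
    msa_a3m='>query\nMKWVTFISLLLLFSSAYS\n',
    max_template_date=datetime.date(2021, 9, 30),
    database_path='/data/pdb_seqres.fasta',  # placeholder
    hmmsearch_config=hmmsearch_config,
    max_a3m_query_sequences=None,
    chain_poly_type=mmcif_names.PROTEIN_CHAIN,
    structure_store=structure_stores.StructureStore('/data/pdb_mmcif'),  # placeholder
)
filtered = protein_templates.filter(
    max_subsequence_ratio=0.95,
    min_align_ratio=0.1,
    min_hit_length=10,
    deduplicate_sequences=True,
    max_hits=4,
)
```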
+ unresolved_indices = [i - 1 for i in unresolved_res_ids] + + yield Hit( + pdb_id=pdb_id, + auth_chain_id=auth_chain_id, + hmmsearch_sequence=hit_seq, + structure_sequence=sequence, + query_sequence=query_sequence, + unresolved_res_indices=unresolved_indices, + # Raw value is residue number, not index. + start_index=start - 1, + end_index=end, + full_length=full_length, + release_date=datetime.date.fromisoformat(release_date), + chain_poly_type=chain_poly_type, + ) + + if filter_config is None: + hits = tuple(hit_generator(a3m)) + else: + hits = _filter_hits( + hit_generator(a3m), + release_date_cutoff=filter_config.max_template_date, + max_subsequence_ratio=filter_config.max_subsequence_ratio, + min_align_ratio=filter_config.min_align_ratio, + min_hit_length=filter_config.min_hit_length, + deduplicate_sequences=filter_config.deduplicate_sequences, + max_hits=filter_config.max_hits, + ) + + return Templates( + query_sequence=query_sequence, + query_release_date=query_release_date, + hits=hits, + max_template_date=max_template_date, + structure_store=structure_store, + ) + + @property + def query_sequence(self) -> str: + return self._query_sequence + + @property + def hits(self) -> tuple[Hit, ...]: + return self._hits + + @property + def query_release_date(self) -> datetime.date | None: + return self._query_release_date + + @property + def num_hits(self) -> int: + return len(self._hits) + + @functools.cached_property + def release_date_cutoff(self) -> datetime.date: + if self.query_release_date is None: + return self._max_template_date + return min( + self._max_template_date, + self.query_release_date + - datetime.timedelta(days=_DAYS_BEFORE_QUERY_DATE), + ) + + def __repr__(self) -> str: + return f'Templates({self.num_hits} hits)' + + def filter( + self, + *, + max_subsequence_ratio: float | None, + min_align_ratio: float | None, + min_hit_length: int | None, + deduplicate_sequences: bool, + max_hits: int | None, + ) -> Self: + """Returns a new Templates object with only the hits that pass all filters. + + This also filters on query_release_date and max_template_date. + + Args: + max_subsequence_ratio: If set, excludes hits which are an exact + subsequence of the query sequence, and longer than this ratio. Useful to + avoid ground truth leakage. + min_align_ratio: If set, excludes hits where the number of residues + aligned to the query is less than this proportion of the template + length. + min_hit_length: If set, excludes hits which have fewer residues than this. + deduplicate_sequences: Whether to exclude duplicate template sequences, + keeping only the first. This can be useful in increasing the diversity + of hits especially in the case of homomer hits. + max_hits: If set, excludes any hits which exceed this count. 
+ """ + filtered_hits = _filter_hits( + hits=self._hits, + release_date_cutoff=self.release_date_cutoff, + max_subsequence_ratio=max_subsequence_ratio, + min_align_ratio=min_align_ratio, + min_hit_length=min_hit_length, + deduplicate_sequences=deduplicate_sequences, + max_hits=max_hits, + ) + return Templates( + query_sequence=self.query_sequence, + query_release_date=self.query_release_date, + hits=filtered_hits, + max_template_date=self._max_template_date, + structure_store=self._structure_store, + ) + + def get_hits_with_structures( + self, + ) -> Sequence[tuple[Hit, structure.Structure]]: + """Returns hits + Structures, Structures filtered to the hit's chain.""" + results = [] + structures = {struct.name.lower(): struct for struct in self.structures} + for hit in self.hits: + if not hit.is_valid: + raise InvalidTemplateError( + 'Hits must be filtered before calling get_hits_with_structures.' + ) + struct = structures[hit.pdb_id] + label_chain_id = struct.polymer_auth_asym_id_to_label_asym_id().get( + hit.auth_chain_id + ) + results.append((hit, struct.filter(chain_id=label_chain_id))) + return results + + def featurize( + self, + include_ligand_features: bool = True, + ) -> TemplateFeatures: + """Featurises the templates and returns a map of feature names to features. + + NB: If you don't do any prefiltering, this method might be slow to run + as it has to fetch many CIFs and featurize them all. + + Args: + include_ligand_features: Whether to compute ligand features. + + Returns: + Template features: A mapping of template feature labels to features, which + may be numpy arrays, bytes objects, or for the special case of label + `ligand_features` (if `include_ligand_features` is True), a nested + feature map of labels to numpy arrays. + + Raises: + InvalidTemplateError: If hits haven't been filtered before featurization. + """ + hits_by_pdb_id = {} + for idx, hit in enumerate(self.hits): + if not hit.is_valid: + raise InvalidTemplateError( + f'Hits must be filtered before featurizing, got unprocessed {hit=}' + ) + hits_by_pdb_id.setdefault(hit.pdb_id, []).append((idx, hit)) + + unsorted_features = [] + for struct in self.structures: + pdb_id = str(struct.name).lower() + for idx, hit in hits_by_pdb_id[pdb_id]: + try: + label_chain_id = struct.polymer_auth_asym_id_to_label_asym_id()[ + hit.auth_chain_id + ] + hit_features = { + **get_polymer_features( + chain=struct.filter(chain_id=label_chain_id), + chain_poly_type=hit.chain_poly_type, + query_sequence_length=len(hit.query_sequence), + query_to_hit_mapping=hit.query_to_hit_mapping, + ), + } + if include_ligand_features: + hit_features['ligand_features'] = _get_ligand_features( + struct) + unsorted_features.append((idx, hit_features)) + except Error as e: + raise type(e)(f'Failed to featurise {hit=}') from e + + sorted_features = sorted(unsorted_features, key=lambda x: x[0]) + sorted_features = [feat for _, feat in sorted_features] + return package_template_features( + hit_features=sorted_features, + include_ligand_features=include_ligand_features, + ) + + @property + def structures(self) -> Iterator[structure.Structure]: + """Yields template structures for each unique PDB ID among hits. + + If there are multiple hits in the same Structure, the Structure will be + included only once by this method. + + Yields: + A Structure object for each unique PDB ID among hits. + + Raises: + HitDateError: If template's release date exceeds max cutoff date. 
+ """ + + for hit in self.hits: + if hit.release_date > self.release_date_cutoff: # pylint: disable=comparison-with-callable + raise HitDateError( + f'Invalid release date for hit {hit.pdb_id=}, when release date ' + f'cutoff is {self.release_date_cutoff}.' + ) + + # Get the set of pdbs to load. In particular, remove duplicate PDB IDs. + targets_to_load = tuple({hit.pdb_id for hit in self.hits}) + + for target_name in targets_to_load: + yield structure.from_mmcif( + mmcif_string=self._structure_store.get_mmcif_str(target_name), + fix_mse_residues=True, + fix_arginines=True, + include_water=False, + include_bonds=False, + include_other=True, # For non-standard polymer chains. + ) + + +def _parse_hit_description(description: str) -> tuple[str, str, int, int, int]: + """Parses the hmmsearch A3M sequence description line.""" + # Example lines (protein, nucleic, no description): + # >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text + # >4pqx_A/2-217 [subseq from] mol:na length:217 Free text + # >5g3r_A/1-55 [subseq from] mol:protein length:352 + if match := re.fullmatch(_HIT_DESCRIPTION_REGEX, description): + return ( + match['pdb_id'], + match['chain_id'], + int(match['start']), + int(match['end']), + int(match['length']), + ) + else: + raise ValueError(f'Could not parse description "{description}"') + + +def _parse_hit_metadata( + structure_store: structure_stores.StructureStore, + pdb_id: str, + auth_chain_id: str, +) -> tuple[Any, str | None, Sequence[int] | None]: + """Parse hit metadata by parsing mmCIF from structure store.""" + try: + cif = mmcif.from_string(structure_store.get_mmcif_str(pdb_id)) + except structure_stores.NotFoundError: + logging.warning('Failed to get mmCIF for %s.', pdb_id) + return None, None, None + release_date = mmcif.get_release_date(cif) + + try: + struct = structure.from_parsed_mmcif( + cif, + model_id=structure.ModelID.ALL, + include_water=True, + include_other=True, + include_bonds=False, + ) + except ValueError: + struct = structure.from_parsed_mmcif( + cif, + model_id=structure.ModelID.FIRST, + include_water=True, + include_other=True, + include_bonds=False, + ) + + sequence = struct.polymer_author_chain_single_letter_sequence( + include_missing_residues=True, + protein=True, + dna=True, + rna=True, + other=True, + )[auth_chain_id] + + unresolved_res_ids = struct.filter( + chain_auth_asym_id=auth_chain_id + ).unresolved_residues.id + + return release_date, sequence, unresolved_res_ids + + +def get_polymer_features( + *, + chain: structure.Structure, + chain_poly_type: str, + query_sequence_length: int, + query_to_hit_mapping: Mapping[int, int], +) -> Mapping[str, Any]: + """Returns features for this polymer chain. + + Args: + chain: Structure object representing the template. Must be already filtered + to a single chain. + chain_poly_type: The chain polymer type (protein, DNA, RNA). + query_sequence_length: The length of the query sequence. + query_to_hit_mapping: 0-based query index to hit index mapping. + + Returns: + A dictionary with polymer features for template_chain_id in the struct. + + Raises: + ValueError: If the input structure contains more than just a single chain. 
+ """ + if len(chain.polymer_auth_asym_id_to_label_asym_id()) != 1: + raise ValueError('The structure must be filtered to a single chain.') + + if chain.name is None: + raise ValueError('The structure must have a name.') + + if chain.release_date is None: + raise ValueError('The structure must have a release date.') + + auth_chain_id, label_chain_id = next( + iter(chain.polymer_auth_asym_id_to_label_asym_id().items()) + ) + chain_sequence = chain.chain_single_letter_sequence()[label_chain_id] + + polymer = _POLYMERS[chain_poly_type] + positions, positions_mask = chain.to_res_arrays( + include_missing_residues=True, atom_order=polymer.atom_order + ) + template_all_atom_positions = np.zeros( + (query_sequence_length, polymer.num_atom_types, 3), dtype=np.float64 + ) + template_all_atom_masks = np.zeros( + (query_sequence_length, polymer.num_atom_types), dtype=np.int64 + ) + + template_sequence = ['-'] * query_sequence_length + for query_index, template_index in query_to_hit_mapping.items(): + template_all_atom_positions[query_index] = positions[template_index] + template_all_atom_masks[query_index] = positions_mask[template_index] + template_sequence[query_index] = chain_sequence[template_index] + + template_sequence = ''.join(template_sequence) + template_aatype = _encode_restype(chain_poly_type, template_sequence) + template_name = f'{chain.name.lower()}_{auth_chain_id}' + release_date = chain.release_date.strftime('%Y-%m-%d') + return { + 'template_all_atom_positions': template_all_atom_positions, + 'template_all_atom_masks': template_all_atom_masks, + 'template_sequence': template_sequence.encode(), + 'template_aatype': np.array(template_aatype, dtype=np.int32), + 'template_domain_names': np.array(template_name.encode(), dtype=object), + 'template_release_date': np.array(release_date.encode(), dtype=object), + } + + +def _get_ligand_features( + struct: structure.Structure, +) -> Mapping[str, Mapping[str, np.ndarray | bytes]]: + """Returns features for the ligands in this structure.""" + ligand_struct = struct.filter_to_entity_type(ligand=True) + assert ligand_struct.coords is not None + assert ligand_struct.atom_name is not None + assert ligand_struct.atom_occupancy is not None + + ligand_features = {} + for ligand_chain_id in ligand_struct.chains: + idxs = np.where(ligand_struct.chain_id == ligand_chain_id)[0] + if idxs.shape[0]: + ligand_features[ligand_chain_id] = { + 'ligand_atom_positions': ligand_struct.coords[idxs, :].astype( + np.float32 + ), + 'ligand_atom_names': ligand_struct.atom_name[idxs].astype(object), + 'ligand_atom_occupancies': ligand_struct.atom_occupancy[idxs].astype( + np.float32 + ), + 'ccd_id': ligand_struct.res_name[idxs][0].encode(), + } + return ligand_features + + +def package_template_features( + *, + hit_features: Sequence[Mapping[str, Any]], + include_ligand_features: bool, +) -> Mapping[str, Any]: + """Stacks polymer features, adds empty and keeps ligand features unstacked.""" + + features_to_include = set(_POLYMER_FEATURES) + if include_ligand_features: + features_to_include.update(_LIGAND_FEATURES) + + features = { + feat: [single_hit_features[feat] + for single_hit_features in hit_features] + for feat in features_to_include + } + + stacked_features = {} + for k, v in features.items(): + if k in _POLYMER_FEATURES: + v = np.stack(v, axis=0) if v else np.array( + [], dtype=_POLYMER_FEATURES[k]) + stacked_features[k] = v + + return stacked_features + + +def _resolve_path(path: os.PathLike[str] | str) -> str: + """Resolves path for data dep paths, stringifies 
otherwise.""" + # Data dependency paths: db baked into the binary. + resolved_path = resources.filename(path) + if os.path.exists(resolved_path): + return resolved_path + else: + # Other paths, e.g. local. + return str(path) + + +def run_hmmsearch_with_a3m( + *, + database_path: os.PathLike[str] | str, + hmmsearch_config: msa_config.HmmsearchConfig, + max_a3m_query_sequences: int | None, + a3m: str | None, +) -> str: + """Runs Hmmsearch to get a3m string of hits.""" + searcher = hmmsearch.Hmmsearch( + binary_path=hmmsearch_config.hmmsearch_binary_path, + hmmbuild_binary_path=hmmsearch_config.hmmbuild_binary_path, + database_path=_resolve_path(database_path), + e_value=hmmsearch_config.e_value, + inc_e=hmmsearch_config.inc_e, + dom_e=hmmsearch_config.dom_e, + incdom_e=hmmsearch_config.incdom_e, + alphabet=hmmsearch_config.alphabet, + filter_f1=hmmsearch_config.filter_f1, + filter_f2=hmmsearch_config.filter_f2, + filter_f3=hmmsearch_config.filter_f3, + filter_max=hmmsearch_config.filter_max, + ) + # STO enables us to annotate query non-gap columns as reference columns. + sto = parsers.convert_a3m_to_stockholm(a3m, max_a3m_query_sequences) + return searcher.query_with_sto(sto, model_construction='hand') diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmalign.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmalign.py new file mode 100644 index 0000000000000000000000000000000000000000..f36967338daab9bcceca0776d75e1ddcf10d968b --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmalign.py @@ -0,0 +1,144 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""A Python wrapper for hmmalign from the HMMER Suite.""" + +from collections.abc import Mapping, Sequence +import os +import tempfile + +from alphafold3.data import parsers +from alphafold3.data.tools import subprocess_utils + + +def _to_a3m(sequences: Sequence[str], name_prefix: str = 'sequence') -> str: + a3m = '' + for i, sequence in enumerate(sequences, 1): + a3m += f'> {name_prefix} {i}\n{sequence}\n' + return a3m + + +class Hmmalign: + """Python wrapper of the hmmalign binary.""" + + def __init__(self, binary_path: str): + """Initializes the Python hmmalign wrapper. + + Args: + binary_path: Path to the hmmalign binary. + + Raises: + RuntimeError: If hmmalign binary not found within the path. 
+ """ + self.binary_path = binary_path + + subprocess_utils.check_binary_exists( + path=self.binary_path, name='hmmalign') + + def align_sequences( + self, + sequences: Sequence[str], + profile: str, + extra_flags: Mapping[str, str] | None = None, + ) -> str: + """Aligns sequence list to the profile and returns the alignment in A3M.""" + return self.align( + a3m_str=_to_a3m(sequences, name_prefix='query'), + profile=profile, + extra_flags=extra_flags, + ) + + def align( + self, + a3m_str: str, + profile: str, + extra_flags: Mapping[str, str] | None = None, + ) -> str: + """Aligns sequences in A3M to the profile and returns the alignment in A3M. + + Args: + a3m_str: A list of sequence strings. + profile: A hmm file with the hmm profile to align the sequences to. + extra_flags: Dictionary with extra flags, flag_name: flag_value, that are + added to hmmalign. + + Returns: + An A3M string with the aligned sequences. + + Raises: + RuntimeError: If hmmalign fails. + """ + with tempfile.TemporaryDirectory() as query_tmp_dir: + input_profile = os.path.join(query_tmp_dir, 'profile.hmm') + input_sequences = os.path.join(query_tmp_dir, 'sequences.a3m') + output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m') + + with open(input_profile, 'w') as f: + f.write(profile) + + with open(input_sequences, 'w') as f: + f.write(a3m_str) + + cmd = [ + self.binary_path, + *('-o', output_a3m_path), + *('--outformat', 'A2M'), # A2M is A3M in the HMMER suite. + ] + if extra_flags: + for flag_name, flag_value in extra_flags.items(): + cmd.extend([flag_name, flag_value]) + cmd.extend([input_profile, input_sequences]) + + subprocess_utils.run( + cmd=cmd, + cmd_name='hmmalign', + log_stdout=False, + log_stderr=True, + log_on_process_error=True, + ) + + with open(output_a3m_path, encoding='utf-8') as f: + a3m = f.read() + + return a3m + + def align_sequences_to_profile(self, profile: str, sequences_a3m: str) -> str: + """Aligns the sequences to profile and returns the alignment in A3M string. + + Uses hmmalign to align the sequences to the profile, then outputs the + sequence concatenated at the beginning of the sequences in the A3M format. + As the sequences are represented by an alignment with possible gaps ('-') + and insertions (lowercase characters), the method first removes the gaps, + then uppercases the insertions to prepare the sequences for realignment. + Sequences with gaps cannot be aligned, as '-'s are not a valid symbol to + align; lowercase characters must be uppercased to preserve the original + sequences before realignment. + + Args: + profile: The Hmmbuild profile to align the sequences to. + sequences_a3m: Sequences in A3M format to align to the profile. + + Returns: + An A3M string with the aligned sequences. + + Raises: + RuntimeError: If hmmalign fails. 
+ """ + deletion_table = str.maketrans('', '', '-') + sequences_no_gaps_a3m = [] + for seq, desc in parsers.lazy_parse_fasta_string(sequences_a3m): + sequences_no_gaps_a3m.append(f'>{desc}') + sequences_no_gaps_a3m.append(seq.translate(deletion_table)) + sequences_no_gaps_a3m = '\n'.join(sequences_no_gaps_a3m) + + aligned_sequences = self.align(sequences_no_gaps_a3m, profile) + + return aligned_sequences diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmbuild.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmbuild.py new file mode 100644 index 0000000000000000000000000000000000000000..8d1f798ba360562f90140f9c027f36c1b1c36ba1 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmbuild.py @@ -0,0 +1,148 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""A Python wrapper for hmmbuild - construct HMM profiles from MSA.""" + +import os +import re +import tempfile +from typing import Literal + +from alphafold3.data import parsers +from alphafold3.data.tools import subprocess_utils + + +class Hmmbuild(object): + """Python wrapper of the hmmbuild binary.""" + + def __init__( + self, + *, + binary_path: str, + singlemx: bool = False, + alphabet: str | None = None, + ): + """Initializes the Python hmmbuild wrapper. + + Args: + binary_path: The path to the hmmbuild executable. + singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to + just use a common substitution score matrix. + alphabet: The alphabet to assert when building a profile. Useful when + hmmbuild cannot guess the alphabet. If None, no alphabet is asserted. + + Raises: + RuntimeError: If hmmbuild binary not found within the path. + """ + self.binary_path = binary_path + self.singlemx = singlemx + self.alphabet = alphabet + + subprocess_utils.check_binary_exists( + path=self.binary_path, name='hmmbuild') + + def build_profile_from_sto(self, sto: str, model_construction='fast') -> str: + """Builds a HHM for the aligned sequences given as an A3M string. + + Args: + sto: A string with the aligned sequences in the Stockholm format. + model_construction: Whether to use reference annotation in the msa to + determine consensus columns ('hand') or default ('fast'). + + Returns: + A string with the profile in the HMM format. + + Raises: + RuntimeError: If hmmbuild fails. + """ + return self._build_profile( + sto, informat='stockholm', model_construction=model_construction + ) + + def build_profile_from_a3m(self, a3m: str) -> str: + """Builds a HHM for the aligned sequences given as an A3M string. + + Args: + a3m: A string with the aligned sequences in the A3M format. + + Returns: + A string with the profile in the HMM format. + + Raises: + RuntimeError: If hmmbuild fails. + """ + lines = [] + for sequence, description in parsers.lazy_parse_fasta_string(a3m): + # Remove inserted residues. 
+ sequence = re.sub('[a-z]+', '', sequence) + lines.append(f'>{description}\n{sequence}\n') + msa = ''.join(lines) + return self._build_profile(msa, informat='afa') + + def _build_profile( + self, + msa: str, + informat: Literal['afa', 'stockholm'], + model_construction: str = 'fast', + ) -> str: + """Builds a HMM for the aligned sequences given as an MSA string. + + Args: + msa: A string with the aligned sequences, in A3M or STO format. + informat: One of 'afa' (aligned FASTA) or 'sto' (Stockholm). + model_construction: Whether to use reference annotation in the msa to + determine consensus columns ('hand') or default ('fast'). + + Returns: + A string with the profile in the HMM format. + + Raises: + RuntimeError: If hmmbuild fails. + ValueError: If unspecified arguments are provided. + """ + if model_construction not in {'hand', 'fast'}: + raise ValueError( + f'Bad {model_construction=}. Only hand or fast allowed.') + + with tempfile.TemporaryDirectory() as query_tmp_dir: + input_msa_path = os.path.join(query_tmp_dir, 'query.msa') + output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm') + + with open(input_msa_path, 'w') as f: + f.write(msa) + + # Specify the format as we don't specify the input file extension. See + # https://github.com/EddyRivasLab/hmmer/issues/321 for more details. + cmd_flags = ['--informat', informat] + # If adding flags, we have to do so before the output and input: + if model_construction == 'hand': + cmd_flags.append(f'--{model_construction}') + if self.singlemx: + cmd_flags.append('--singlemx') + if self.alphabet: + cmd_flags.append(f'--{self.alphabet}') + + cmd_flags.extend([output_hmm_path, input_msa_path]) + + cmd = [self.binary_path, *cmd_flags] + + subprocess_utils.run( + cmd=cmd, + cmd_name='Hmmbuild', + log_stdout=False, + log_stderr=True, + log_on_process_error=True, + ) + + with open(output_hmm_path) as f: + hmm = f.read() + + return hmm diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmsearch.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmsearch.py new file mode 100644 index 0000000000000000000000000000000000000000..1b4854e5d65094649ec95aa07b3b048e5444a627 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmsearch.py @@ -0,0 +1,152 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""A Python wrapper for hmmsearch - search profile against a sequence db.""" + +import os +import tempfile + +from absl import logging +from alphafold3.data import parsers +from alphafold3.data.tools import hmmbuild +from alphafold3.data.tools import subprocess_utils + + +class Hmmsearch(object): + """Python wrapper of the hmmsearch binary.""" + + def __init__( + self, + *, + binary_path: str, + hmmbuild_binary_path: str, + database_path: str, + alphabet: str = 'amino', + filter_f1: float | None = None, + filter_f2: float | None = None, + filter_f3: float | None = None, + e_value: float | None = None, + inc_e: float | None = None, + dom_e: float | None = None, + incdom_e: float | None = None, + filter_max: bool = False, + ): + """Initializes the Python hmmsearch wrapper. + + Args: + binary_path: The path to the hmmsearch executable. + hmmbuild_binary_path: The path to the hmmbuild executable. Used to build + an hmm from an input a3m. + database_path: The path to the hmmsearch database (FASTA format). + alphabet: Chain type e.g. amino, rna, dna. + filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off. + filter_f2: Viterbi pre-filter, set to >1.0 to turn off. + filter_f3: Forward pre-filter, set to >1.0 to turn off. + e_value: E-value criteria for inclusion in tblout. + inc_e: E-value criteria for inclusion in MSA/next round. + dom_e: Domain e-value criteria for inclusion in tblout. + incdom_e: Domain e-value criteria for inclusion of domains in MSA/next + round. + filter_max: Remove all filters, will ignore all filter_f* settings. + + Raises: + RuntimeError: If hmmsearch binary not found within the path. + """ + self.binary_path = binary_path + self.hmmbuild_runner = hmmbuild.Hmmbuild( + alphabet=alphabet, binary_path=hmmbuild_binary_path + ) + self.database_path = database_path + flags = [] + if filter_max: + flags.append('--max') + else: + if filter_f1 is not None: + flags.extend(('--F1', filter_f1)) + if filter_f2 is not None: + flags.extend(('--F2', filter_f2)) + if filter_f3 is not None: + flags.extend(('--F3', filter_f3)) + + if e_value is not None: + flags.extend(('-E', e_value)) + if inc_e is not None: + flags.extend(('--incE', inc_e)) + if dom_e is not None: + flags.extend(('--domE', dom_e)) + if incdom_e is not None: + flags.extend(('--incdomE', incdom_e)) + + self.flags = tuple(map(str, flags)) + + subprocess_utils.check_binary_exists( + path=self.binary_path, name='hmmsearch' + ) + + if not os.path.exists(self.database_path): + logging.error( + 'Could not find hmmsearch database %s', database_path) + raise ValueError( + f'Could not find hmmsearch database {database_path}') + + def query_with_hmm(self, hmm: str) -> str: + """Queries the database using hmmsearch using a given hmm.""" + with tempfile.TemporaryDirectory() as query_tmp_dir: + hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm') + sto_out_path = os.path.join(query_tmp_dir, 'output.sto') + with open(hmm_input_path, 'w') as f: + f.write(hmm) + + cmd = [ + self.binary_path, + '--noali', # Don't include the alignment in stdout. 
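+          # Note: the number of worker threads below is currently hardcoded.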
+ *('--cpu', '8'), + ] + # If adding flags, we have to do so before the output and input: + if self.flags: + cmd.extend(self.flags) + cmd.extend([ + *('-A', sto_out_path), + hmm_input_path, + self.database_path, + ]) + + subprocess_utils.run( + cmd=cmd, + cmd_name='Hmmsearch', + log_stdout=False, + log_stderr=True, + log_on_process_error=True, + ) + + with open(sto_out_path) as f: + a3m_out = parsers.convert_stockholm_to_a3m( + f, remove_first_row_gaps=False, linewidth=60 + ) + + return a3m_out + + def query_with_a3m(self, a3m_in: str) -> str: + """Query the database using hmmsearch using a given a3m.""" + + # Only the "fast" model construction makes sense with A3M, as it doesn't + # have any way to annotate reference columns. + hmm = self.hmmbuild_runner.build_profile_from_a3m(a3m_in) + return self.query_with_hmm(hmm) + + def query_with_sto( + self, msa_sto: str, model_construction: str = 'fast' + ) -> str: + """Queries the database using hmmsearch using a given stockholm msa.""" + hmm = self.hmmbuild_runner.build_profile_from_sto( + msa_sto, model_construction=model_construction + ) + return self.query_with_hmm(hmm) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/jackhmmer.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/jackhmmer.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac8c80105deb20039b9882bcd3aa0e47dfafae7 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/jackhmmer.py @@ -0,0 +1,137 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Library to run Jackhmmer from Python.""" + +import os +import tempfile + +from absl import logging +from alphafold3.data import parsers +from alphafold3.data.tools import msa_tool +from alphafold3.data.tools import subprocess_utils + + +class Jackhmmer(msa_tool.MsaTool): + """Python wrapper of the Jackhmmer binary.""" + + def __init__( + self, + *, + binary_path: str, + database_path: str, + n_cpu: int = 8, + n_iter: int = 3, + e_value: float | None = 1e-3, + z_value: float | int | None = None, + max_sequences: int = 5000, + filter_f1: float = 5e-4, + filter_f2: float = 5e-5, + filter_f3: float = 5e-7, + ): + """Initializes the Python Jackhmmer wrapper. + + Args: + binary_path: The path to the jackhmmer executable. + database_path: The path to the jackhmmer database (FASTA format). + n_cpu: The number of CPUs to give Jackhmmer. + n_iter: The number of Jackhmmer iterations. + e_value: The E-value, see Jackhmmer docs for more details. + z_value: The Z-value representing the number of comparisons done (i.e + correct database size) for E-value calculation. + max_sequences: Maximum number of sequences to return in the MSA. + filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off. + filter_f2: Viterbi pre-filter, set to >1.0 to turn off. + filter_f3: Forward pre-filter, set to >1.0 to turn off. 
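+
+    Example (illustrative sketch; both paths and the query sequence are
+      hypothetical placeholders):
+        runner = Jackhmmer(
+            binary_path='/usr/bin/jackhmmer',
+            database_path='/data/uniref90.fasta')
+        result = runner.query('MKTAYIAKQRQISFVKSHFSRQ')
+        a3m_msa = result.a3m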
+ + Raises: + RuntimeError: If Jackhmmer binary not found within the path. + """ + self.binary_path = binary_path + self.database_path = database_path + + subprocess_utils.check_binary_exists( + path=self.binary_path, name='Jackhmmer' + ) + + if not os.path.exists(self.database_path): + raise ValueError( + f'Could not find Jackhmmer database {database_path}') + + self.n_cpu = n_cpu + self.n_iter = n_iter + self.e_value = e_value + self.z_value = z_value + self.max_sequences = max_sequences + self.filter_f1 = filter_f1 + self.filter_f2 = filter_f2 + self.filter_f3 = filter_f3 + + def query(self, target_sequence: str) -> msa_tool.MsaToolResult: + """Queries the database using Jackhmmer.""" + logging.info('Query sequence: %s', target_sequence) + with tempfile.TemporaryDirectory() as query_tmp_dir: + input_fasta_path = os.path.join(query_tmp_dir, 'query.fasta') + subprocess_utils.create_query_fasta_file( + sequence=target_sequence, path=input_fasta_path + ) + + output_sto_path = os.path.join(query_tmp_dir, 'output.sto') + + # The F1/F2/F3 are the expected proportion to pass each of the filtering + # stages (which get progressively more expensive), reducing these + # speeds up the pipeline at the expensive of sensitivity. They are + # currently set very low to make querying Mgnify run in a reasonable + # amount of time. + cmd_flags = [ + # Don't pollute stdout with Jackhmmer output. + *('-o', '/dev/null'), + *('-A', output_sto_path), + '--noali', + *('--F1', str(self.filter_f1)), + *('--F2', str(self.filter_f2)), + *('--F3', str(self.filter_f3)), + *('--cpu', str(self.n_cpu)), + *('-N', str(self.n_iter)), + ] + + # Report only sequences with E-values <= x in per-sequence output. + if self.e_value is not None: + cmd_flags.extend(['-E', str(self.e_value)]) + + # Use the same value as the reporting e-value (`-E` flag). + cmd_flags.extend(['--incE', str(self.e_value)]) + + if self.z_value is not None: + cmd_flags.extend(['-Z', str(self.z_value)]) + + cmd = ( + [self.binary_path] + + cmd_flags + + [input_fasta_path, self.database_path] + ) + + subprocess_utils.run( + cmd=cmd, + cmd_name='Jackhmmer', + log_stdout=False, + log_stderr=True, + log_on_process_error=True, + ) + + with open(output_sto_path) as f: + a3m = parsers.convert_stockholm_to_a3m( + f, max_sequences=self.max_sequences + ) + + return msa_tool.MsaToolResult( + target_sequence=target_sequence, a3m=a3m, e_value=self.e_value + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/msa_tool.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/msa_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..0c8bd1894f343548566588bd69f90def425df455 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/msa_tool.py @@ -0,0 +1,31 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Defines protocol for MSA tools.""" + +import dataclasses +from typing import Protocol + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class MsaToolResult: + """The result of a MSA tool query.""" + + target_sequence: str + e_value: float + a3m: str + + +class MsaTool(Protocol): + """Interface for MSA tools.""" + + def query(self, target_sequence: str) -> MsaToolResult: + """Runs the MSA tool on the target sequence.""" diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/nhmmer.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/nhmmer.py new file mode 100644 index 0000000000000000000000000000000000000000..70bc06149babf09d00a4264913e63f4a9be48c13 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/nhmmer.py @@ -0,0 +1,175 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Library to run Nhmmer from Python.""" + +import os +import pathlib +import tempfile +from typing import Final + +from absl import logging +from alphafold3.data import parsers +from alphafold3.data.tools import hmmalign +from alphafold3.data.tools import hmmbuild +from alphafold3.data.tools import msa_tool +from alphafold3.data.tools import subprocess_utils + +_SHORT_SEQUENCE_CUTOFF: Final[int] = 50 + + +class Nhmmer(msa_tool.MsaTool): + """Python wrapper of the Nhmmer binary.""" + + def __init__( + self, + binary_path: str, + hmmalign_binary_path: str, + hmmbuild_binary_path: str, + database_path: str, + n_cpu: int = 8, + e_value: float = 1e-3, + max_sequences: int = 5000, + filter_f3: float = 1e-5, + alphabet: str | None = None, + strand: str | None = None, + ): + """Initializes the Python Nhmmer wrapper. + + Args: + binary_path: Path to the Nhmmer binary. + hmmalign_binary_path: Path to the Hmmalign binary. + hmmbuild_binary_path: Path to the Hmmbuild binary. + database_path: MSA database path to search against. This can be either a + FASTA (slow) or HMMERDB produced from the FASTA using the makehmmerdb + binary. The HMMERDB is ~10x faster but experimental. + n_cpu: The number of CPUs to give Nhmmer. + e_value: The E-value, see Nhmmer docs for more details. Will be + overwritten if bit_score is set. + max_sequences: Maximum number of sequences to return in the MSA. + filter_f3: Forward pre-filter, set to >1.0 to turn off. + alphabet: The alphabet to assert when building a profile with hmmbuild. + This must be 'rna', 'dna', or None. + strand: "watson" searches query sequence, "crick" searches + reverse-compliment and default is None which means searching for both. + + Raises: + RuntimeError: If Nhmmer binary not found within the path. 
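+
+    Example (illustrative sketch; every path and the query sequence are
+      hypothetical placeholders):
+        runner = Nhmmer(
+            binary_path='/usr/bin/nhmmer',
+            hmmalign_binary_path='/usr/bin/hmmalign',
+            hmmbuild_binary_path='/usr/bin/hmmbuild',
+            database_path='/data/rnacentral.fasta',
+            alphabet='rna')
+        result = runner.query('GGGCUAUUAGCUCAGUUGGUUAGAGCGCACCC')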
+ """ + self._binary_path = binary_path + self._hmmalign_binary_path = hmmalign_binary_path + self._hmmbuild_binary_path = hmmbuild_binary_path + self._db_path = database_path + + subprocess_utils.check_binary_exists( + path=self._binary_path, name='Nhmmer') + + if strand and strand not in {'watson', 'crick'}: + raise ValueError( + f'Invalid {strand=}. only "watson" or "crick" supported') + + if alphabet and alphabet not in {'rna', 'dna'}: + raise ValueError( + f'Invalid {alphabet=}, only "rna" or "dna" supported') + + self._e_value = e_value + self._n_cpu = n_cpu + self._max_sequences = max_sequences + self._filter_f3 = filter_f3 + self._alphabet = alphabet + self._strand = strand + + def query(self, target_sequence: str) -> msa_tool.MsaToolResult: + """Query the database using Nhmmer.""" + logging.info('Query sequence: %s', target_sequence) + + with tempfile.TemporaryDirectory() as query_tmp_dir: + input_a3m_path = os.path.join(query_tmp_dir, 'query.a3m') + output_sto_path = os.path.join(query_tmp_dir, 'output.sto') + pathlib.Path(output_sto_path).touch() + subprocess_utils.create_query_fasta_file( + sequence=target_sequence, path=input_a3m_path + ) + + cmd_flags = [ + # Don't pollute stdout with nhmmer output. + *('-o', '/dev/null'), + '--noali', # Don't include the alignment in stdout. + *('--cpu', str(self._n_cpu)), + ] + + cmd_flags.extend(['-E', str(self._e_value)]) + + if self._alphabet: + cmd_flags.extend([f'--{self._alphabet}']) + + if self._strand is not None: + cmd_flags.extend([f'--{self._strand}']) + + cmd_flags.extend(['-A', output_sto_path]) + # As recommend by RNAcentral for short sequences. + if ( + self._alphabet == 'rna' + and len(target_sequence) < _SHORT_SEQUENCE_CUTOFF + ): + cmd_flags.extend(['--F3', str(0.02)]) + else: + cmd_flags.extend(['--F3', str(self._filter_f3)]) + + # The input A3M and the db are the last two arguments. + cmd_flags.extend((input_a3m_path, self._db_path)) + + cmd = [self._binary_path, *cmd_flags] + + subprocess_utils.run( + cmd=cmd, + cmd_name='Nhmmer', + log_stdout=False, + log_stderr=True, + log_on_process_error=True, + ) + + if os.path.getsize(output_sto_path) > 0: + with open(output_sto_path) as f: + a3m_out = parsers.convert_stockholm_to_a3m( + # Query not included. + f, max_sequences=self._max_sequences - 1 + ) + # Nhmmer hits are generally shorter than the query sequence. To get MSA + # of width equal to the query sequence, align hits to the query profile. + logging.info( + 'Aligning output a3m of size %d bytes', len(a3m_out)) + + aligner = hmmalign.Hmmalign(self._hmmalign_binary_path) + target_sequence_fasta = f'>query\n{target_sequence}\n' + profile_builder = hmmbuild.Hmmbuild( + binary_path=self._hmmbuild_binary_path, alphabet=self._alphabet + ) + profile = profile_builder.build_profile_from_a3m( + target_sequence_fasta) + a3m_out = aligner.align_sequences_to_profile( + profile=profile, sequences_a3m=a3m_out + ) + a3m_out = ''.join([target_sequence_fasta, a3m_out]) + + # Parse the output a3m to remove line breaks. + a3m = '\n'.join( + [f'>{n}\n{s}' for s, + n in parsers.lazy_parse_fasta_string(a3m_out)] + ) + else: + # Nhmmer returns an empty file if there are no hits. + # In this case return only the query sequence. 
+ a3m = f'>query\n{target_sequence}' + + return msa_tool.MsaToolResult( + target_sequence=target_sequence, e_value=self._e_value, a3m=a3m + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/rdkit_utils.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/rdkit_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eed6688223de6e3ece0a177485396c312c1d26c9 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/rdkit_utils.py @@ -0,0 +1,526 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Tools for calculating features for ligands.""" + +import collections +from collections.abc import Mapping, Sequence + +from absl import logging +from alphafold3.cpp import cif_dict +import numpy as np +import rdkit.Chem as rd_chem + + +_RDKIT_MMCIF_TO_BOND_TYPE: Mapping[str, rd_chem.BondType] = { + 'SING': rd_chem.BondType.SINGLE, + 'DOUB': rd_chem.BondType.DOUBLE, + 'TRIP': rd_chem.BondType.TRIPLE, +} + +_RDKIT_BOND_TYPE_TO_MMCIF: Mapping[rd_chem.BondType, str] = { + v: k for k, v in _RDKIT_MMCIF_TO_BOND_TYPE.items() +} + +_RDKIT_BOND_STEREO_TO_MMCIF: Mapping[rd_chem.BondStereo, str] = { + rd_chem.BondStereo.STEREONONE: 'N', + rd_chem.BondStereo.STEREOE: 'E', + rd_chem.BondStereo.STEREOZ: 'Z', + rd_chem.BondStereo.STEREOCIS: 'Z', + rd_chem.BondStereo.STEREOTRANS: 'E', +} + + +class MolFromMmcifError(Exception): + """Raised when conversion from mmCIF to RDKit Mol fails.""" + + +class UnsupportedMolBondError(Exception): + """Raised when we try to handle unsupported RDKit bonds.""" + + +def _populate_atoms_in_mol( + mol: rd_chem.Mol, + atom_names: Sequence[str], + atom_types: Sequence[str], + atom_charges: Sequence[int], + implicit_hydrogens: bool, + ligand_name: str, + atom_leaving_flags: Sequence[str], +): + """Populate the atoms of a Mol given atom features. + + Args: + mol: Mol object. + atom_names: Names of the atoms. + atom_types: Types of the atoms. + atom_charges: Charges of the atoms. + implicit_hydrogens: Whether to mark the atoms to allow implicit Hs. + ligand_name: Name of the ligand which the atoms are in. + atom_leaving_flags: Whether the atom is possibly a leaving atom. Values from + the CCD column `_chem_comp_atom.pdbx_leaving_atom_flag`. The expected + values are 'Y' (yes), 'N' (no), '?' (unknown/unset, interpreted as no). + + Raises: + ValueError: If atom type is invalid. + """ + # Map atom names to the position they will take in the rdkit molecule. 
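+  # For example (illustrative), atom_names ('N', 'CA', 'C') maps to
+  # {'N': 0, 'CA': 1, 'C': 2}; the assert below relies on AddAtom returning
+  # indices in this same order.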
+ atom_name_to_idx = {name: i for i, name in enumerate(atom_names)} + + for atom_name, atom_type, atom_charge, atom_leaving_flag in zip( + atom_names, atom_types, atom_charges, atom_leaving_flags, strict=True + ): + try: + if atom_type == 'X': + atom_type = '*' + atom = rd_chem.Atom(atom_type) + except RuntimeError as e: + raise ValueError(f'Failed to use atom type: {str(e)}') from e + + if not implicit_hydrogens: + atom.SetNoImplicit(True) + + atom.SetProp('atom_name', atom_name) + atom.SetProp('atom_leaving_flag', atom_leaving_flag) + atom.SetFormalCharge(atom_charge) + residue_info = rd_chem.AtomPDBResidueInfo() + residue_info.SetName(_format_atom_name(atom_name, atom_type)) + residue_info.SetIsHeteroAtom(True) + residue_info.SetResidueName(ligand_name) + residue_info.SetResidueNumber(1) + atom.SetPDBResidueInfo(residue_info) + atom_index = mol.AddAtom(atom) + assert atom_index == atom_name_to_idx[atom_name] + + +def _populate_bonds_in_mol( + mol: rd_chem.Mol, + atom_names: Sequence[str], + bond_begins: Sequence[str], + bond_ends: Sequence[str], + bond_orders: Sequence[str], + bond_is_aromatics: Sequence[bool], +): + """Populate the bonds of a Mol given bond features. + + Args: + mol: Mol object. + atom_names: Names of atoms in the molecule. + bond_begins: Names of atoms at the beginning of the bond. + bond_ends: Names of atoms at the end of the bond. + bond_orders: What order the bonds are. + bond_is_aromatics: Whether the bonds are aromatic. + """ + atom_name_to_idx = {name: i for i, name in enumerate(atom_names)} + for begin, end, bond_type, is_aromatic in zip( + bond_begins, bond_ends, bond_orders, bond_is_aromatics, strict=True + ): + begin_name, end_name = atom_name_to_idx[begin], atom_name_to_idx[end] + bond_idx = mol.AddBond(begin_name, end_name, bond_type) + mol.GetBondWithIdx(bond_idx - 1).SetIsAromatic(is_aromatic) + + +def _sanitize_mol(mol, sort_alphabetically, remove_hydrogens) -> rd_chem.Mol: + # https://www.rdkit.org/docs/source/rdkit.Chem.rdmolops.html#rdkit.Chem.rdmolops.SanitizeMol + # Kekulize, check valencies, set aromaticity, conjugation and hybridization. + # This can repair e.g. incorrect aromatic flags. + rd_chem.SanitizeMol(mol) + if sort_alphabetically: + mol = sort_atoms_by_name(mol) + if remove_hydrogens: + mol = rd_chem.RemoveHs(mol) + return mol + + +def _add_conformer_to_mol(mol, conformer, force_parse) -> rd_chem.Mol: + # Create conformer and use it to assign stereochemistry. + if conformer is not None: + try: + mol.AddConformer(conformer) + rd_chem.AssignStereochemistryFrom3D(mol) + except ValueError as e: + logging.warning('Failed to parse conformer: %s', e) + if not force_parse: + raise + + +def mol_from_ccd_cif( + mol_cif: cif_dict.CifDict, + *, + force_parse: bool = False, + sort_alphabetically: bool = True, + remove_hydrogens: bool = True, + implicit_hydrogens: bool = False, +) -> rd_chem.Mol: + """Creates an rdkit Mol object from a CCD mmcif data block. + + The atoms are renumbered so that their names are in alphabetical order and + these names are placed on the atoms under property 'atom_name'. + Only hydrogens which are not required to define the molecule are removed. + For example, hydrogens that define stereochemistry around a double bond are + retained. + See this link for more details. + https://www.rdkit.org/docs/source/rdkit.Chem.rdmolops.html#rdkit.Chem.rdmolops.RemoveHs + + Args: + mol_cif: An mmcif object representing a molecule. 
+ force_parse: If True, assumes missing aromatic flags are false, substitutes + deuterium for hydrogen, assumes missing charges are 0 and ignores missing + conformer / stereochemistry information. + sort_alphabetically: True: sort atom alphabetically; False: keep CCD order + remove_hydrogens: if True, remove non-important hydrogens + implicit_hydrogens: Sets a marker on the atom that allows implicit Hs. + + Returns: + An rdkit molecule, with the atoms sorted by name. + + Raises: + MolToMmcifError: If conversion from mmcif to rdkit Mol fails. More detailed + error is available as this error's cause. + """ + # Read data fields. + try: + atom_names, atom_types, atom_charges, atom_leaving_flags = parse_atom_data( + mol_cif, force_parse + ) + bond_begins, bond_ends, bond_orders, bond_is_aromatics = parse_bond_data( + mol_cif, force_parse + ) + lig_name = mol_cif['_chem_comp.id'][0].rjust(3) + except (KeyError, ValueError) as e: + raise MolFromMmcifError from e + + # Build Rdkit molecule. + mol = rd_chem.RWMol() + + # Per atom features. + try: + _populate_atoms_in_mol( + mol=mol, + atom_names=atom_names, + atom_types=atom_types, + atom_charges=atom_charges, + implicit_hydrogens=implicit_hydrogens, + ligand_name=lig_name, + atom_leaving_flags=atom_leaving_flags, + ) + except (ValueError, RuntimeError) as e: + raise MolFromMmcifError from e + + _populate_bonds_in_mol( + mol, atom_names, bond_begins, bond_ends, bond_orders, bond_is_aromatics + ) + + try: + conformer = _parse_ideal_conformer(mol_cif) + except (KeyError, ValueError) as e: + logging.warning('Failed to parse ideal conformer: %s', e) + if not force_parse: + raise MolFromMmcifError from e + conformer = None + + mol.UpdatePropertyCache(strict=False) + + try: + _add_conformer_to_mol(mol, conformer, force_parse) + mol = _sanitize_mol(mol, sort_alphabetically, remove_hydrogens) + except ( + ValueError, + rd_chem.KekulizeException, + rd_chem.AtomValenceException, + ) as e: + raise MolFromMmcifError from e + + return mol + + +def mol_to_ccd_cif( + mol: rd_chem.Mol, + component_id: str, + pdbx_smiles: str | None = None, + include_hydrogens: bool = True, +) -> cif_dict.CifDict: + """Creates a CCD-like mmcif data block from an rdkit Mol object. + + Only a subset of associated mmcif fields is populated, but that is + sufficient for further usage, e.g. in featurization code. + + Atom names can be specified via `atom_name` property. For atoms with + unspecified value of that property, the name is assigned based on element type + and the order in the Mol object. + + If the Mol object has associated conformers, atom positions from the first of + them will be populated in the resulting mmcif file. + + Args: + mol: An rdkit molecule. + component_id: Name of the molecule to use in the resulting mmcif. That is + equivalent to CCD code. + pdbx_smiles: If specified, the value will be used to populate + `_chem_comp.pdbx_smiles`. + include_hydrogens: Whether to include atom and bond data involving + hydrogens. + + Returns: + An mmcif data block corresponding for the given rdkit molecule. + + Raises: + UnsupportedMolBond: When a molecule contains a bond that can't be + represented with mmcif. 
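+
+  Example (illustrative sketch; the SMILES string and component id are
+    hypothetical placeholders):
+      mol = rd_chem.MolFromSmiles('c1ccccc1O')
+      ccd_cif = mol_to_ccd_cif(mol, component_id='LIG', pdbx_smiles='c1ccccc1O')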
+ """ + mol = rd_chem.Mol(mol) + if include_hydrogens: + mol = rd_chem.AddHs(mol) + rd_chem.Kekulize(mol) + + if mol.GetNumConformers() > 0: + ideal_conformer = mol.GetConformer(0).GetPositions() + ideal_conformer = np.vectorize(lambda x: f'{x:.3f}')(ideal_conformer) + else: + # No data will be populated in the resulting mmcif if the molecule doesn't + # have any conformers attached to it. + ideal_conformer = None + + mol_cif = collections.defaultdict(list) + mol_cif['data_'] = [component_id] + mol_cif['_chem_comp.id'] = [component_id] + if pdbx_smiles: + mol_cif['_chem_comp.pdbx_smiles'] = [pdbx_smiles] + + mol = assign_atom_names_from_graph(mol, keep_existing_names=True) + + for atom_idx, atom in enumerate(mol.GetAtoms()): + element = atom.GetSymbol() + if not include_hydrogens and element in ('H', 'D'): + continue + + mol_cif['_chem_comp_atom.comp_id'].append(component_id) + mol_cif['_chem_comp_atom.atom_id'].append(atom.GetProp('atom_name')) + mol_cif['_chem_comp_atom.type_symbol'].append(atom.GetSymbol().upper()) + mol_cif['_chem_comp_atom.charge'].append(str(atom.GetFormalCharge())) + if ideal_conformer is not None: + coords = ideal_conformer[atom_idx] + mol_cif['_chem_comp_atom.pdbx_model_Cartn_x_ideal'].append( + coords[0]) + mol_cif['_chem_comp_atom.pdbx_model_Cartn_y_ideal'].append( + coords[1]) + mol_cif['_chem_comp_atom.pdbx_model_Cartn_z_ideal'].append( + coords[2]) + + for bond in mol.GetBonds(): + atom1 = bond.GetBeginAtom() + atom2 = bond.GetEndAtom() + if not include_hydrogens and ( + atom1.GetSymbol() in ('H', 'D') or atom2.GetSymbol() in ('H', 'D') + ): + continue + mol_cif['_chem_comp_bond.comp_id'].append(component_id) + mol_cif['_chem_comp_bond.atom_id_1'].append( + bond.GetBeginAtom().GetProp('atom_name') + ) + mol_cif['_chem_comp_bond.atom_id_2'].append( + bond.GetEndAtom().GetProp('atom_name') + ) + try: + bond_type = bond.GetBondType() + # Older versions of RDKit did not have a DATIVE bond type. Convert it to + # SINGLE to match the AF3 training setup. + if bond_type == rd_chem.BondType.DATIVE: + bond_type = rd_chem.BondType.SINGLE + mol_cif['_chem_comp_bond.value_order'].append( + _RDKIT_BOND_TYPE_TO_MMCIF[bond_type] + ) + mol_cif['_chem_comp_bond.pdbx_stereo_config'].append( + _RDKIT_BOND_STEREO_TO_MMCIF[bond.GetStereo()] + ) + except KeyError as e: + raise UnsupportedMolBondError from e + mol_cif['_chem_comp_bond.pdbx_aromatic_flag'].append( + 'Y' if bond.GetIsAromatic() else 'N' + ) + + return cif_dict.CifDict(mol_cif) + + +def _format_atom_name(atom_name: str, atom_type: str) -> str: + """Formats an atom name to fit in the four characters specified in PDB. + + See for example the following note on atom name formatting in PDB files: + https://www.cgl.ucsf.edu/chimera/docs/UsersGuide/tutorials/pdbintro.html#note1 + + Args: + atom_name: The unformatted atom name. + atom_type: The atom element symbol. + + Returns: + formatted_atom_name: The formatted 4-character atom name. + """ + atom_name = atom_name.strip() + atom_type = atom_type.strip().upper() + if len(atom_name) == 1: + return atom_name.rjust(2).ljust(4) + elif len(atom_name) == 2: + if atom_name == atom_type: + return atom_name.ljust(4) + return atom_name.center(4) + elif len(atom_name) == 3: + if atom_name[:2] == atom_type: + return atom_name.ljust(4) + return atom_name.rjust(4) + elif len(atom_name) == 4: + return atom_name + else: + raise ValueError( + f'Atom name `{atom_name}` has more than four characters ' + 'or is an empty string.' 
+ ) + + +def parse_atom_data( + mol_cif: cif_dict.CifDict | Mapping[str, Sequence[str]], force_parse: bool +) -> tuple[Sequence[str], Sequence[str], Sequence[int], Sequence[str]]: + """Parses atoms. If force_parse is True, fix deuterium and missing charge.""" + atom_types = [t.capitalize() + for t in mol_cif['_chem_comp_atom.type_symbol']] + atom_names = mol_cif['_chem_comp_atom.atom_id'] + atom_charges = mol_cif['_chem_comp_atom.charge'] + atom_leaving_flags = ['?'] * len(atom_names) + if '_chem_comp_atom.pdbx_leaving_atom_flag' in mol_cif: + atom_leaving_flags = mol_cif['_chem_comp_atom.pdbx_leaving_atom_flag'] + + if force_parse: + # Replace missing charges with 0. + atom_charges = [charge if charge != + '?' else '0' for charge in atom_charges] + # Deuterium for hydrogen. + atom_types = [type_ if type_ != 'D' else 'H' for type_ in atom_types] + + atom_charges = [int(atom_charge) for atom_charge in atom_charges] + return atom_names, atom_types, atom_charges, atom_leaving_flags + + +def parse_bond_data( + mol_cif: cif_dict.CifDict | Mapping[str, Sequence[str]], force_parse: bool +) -> tuple[ + Sequence[str], Sequence[str], Sequence[rd_chem.BondType], Sequence[bool] +]: + """Parses bond data. If force_parse is True, ignore missing aromatic flags.""" + # The bond table isn't present if there are no bonds. Use [] in that case. + begin_atoms = mol_cif.get('_chem_comp_bond.atom_id_1', []) + end_atoms = mol_cif.get('_chem_comp_bond.atom_id_2', []) + orders = mol_cif.get('_chem_comp_bond.value_order', []) + bond_types = [_RDKIT_MMCIF_TO_BOND_TYPE[order] for order in orders] + + try: + aromatic_flags = mol_cif.get('_chem_comp_bond.pdbx_aromatic_flag', []) + is_aromatic = [{'Y': True, 'N': False}[flag] + for flag in aromatic_flags] + except KeyError: + if force_parse: + # Set them all to not aromatic. + is_aromatic = [False for _ in begin_atoms] + else: + raise + + return begin_atoms, end_atoms, bond_types, is_aromatic + + +def _parse_ideal_conformer(mol_cif: cif_dict.CifDict) -> rd_chem.Conformer: + """Builds a conformer containing the ideal coordinates from the CCD. + + Args: + mol_cif: An mmcif object representing a molecule. + + Returns: + An rdkit conformer filled with the ideal positions from the mmcif. + + Raises: + ValueError: if the positions can't be interpreted. + """ + atom_x = [ + float(x) for x in mol_cif['_chem_comp_atom.pdbx_model_Cartn_x_ideal'] + ] + atom_y = [ + float(y) for y in mol_cif['_chem_comp_atom.pdbx_model_Cartn_y_ideal'] + ] + atom_z = [ + float(z) for z in mol_cif['_chem_comp_atom.pdbx_model_Cartn_z_ideal'] + ] + atom_positions = zip(atom_x, atom_y, atom_z, strict=True) + + conformer = rd_chem.Conformer(len(atom_x)) + for atom_index, atom_position in enumerate(atom_positions): + conformer.SetAtomPosition(atom_index, atom_position) + + return conformer + + +def sort_atoms_by_name(mol: rd_chem.Mol) -> rd_chem.Mol: + """Sorts the atoms in the molecule by their names.""" + atom_names = { + atom.GetProp('atom_name'): atom.GetIdx() for atom in mol.GetAtoms() + } + + # Sort the name, int tuples by the names. + sorted_atom_names = sorted(atom_names.items()) + + # Zip these tuples back together to the sorted indices. + _, new_order = zip(*sorted_atom_names, strict=True) + + # Reorder the molecule. + # new_order is effectively an argsort of the names. + return rd_chem.RenumberAtoms(mol, new_order) + + +def assign_atom_names_from_graph( + mol: rd_chem.Mol, + keep_existing_names: bool = False, +) -> rd_chem.Mol: + """Assigns atom names from the molecular graph. 
+ + The atom name is stored as an atom property 'atom_name', accessible + with atom.GetProp('atom_name'). If the property is already specified, and + keep_existing_names is True we keep the original name. + + We traverse the graph in the order of the rdkit atom index and give each atom + a name equal to '{ELEMENT_TYPE}_{INDEX}'. E.g. C5 is the name for the fifth + unnamed carbon encountered. + + NOTE: A new mol is returned, the original is not changed in place. + + Args: + mol: + keep_existing_names: If True, atoms that already have the atom_name property + will keep their assigned names. + + Returns: + A new mol, with potentially new 'atom_name' properties. + """ + mol = rd_chem.Mol(mol) + + specified_atom_names = { + atom.GetProp('atom_name') + for atom in mol.GetAtoms() + if atom.HasProp('atom_name') and keep_existing_names + } + + element_counts = collections.Counter() + for atom in mol.GetAtoms(): + if not atom.HasProp('atom_name') or not keep_existing_names: + element = atom.GetSymbol() + while True: + element_counts[element] += 1 + new_name = f'{element}{element_counts[element]}' + if new_name not in specified_atom_names: + break + atom.SetProp('atom_name', new_name) + + return mol diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/subprocess_utils.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/subprocess_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..15aa61ba5f970d1728eebd7ff8efc8ab30f93c0e --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/subprocess_utils.py @@ -0,0 +1,108 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Helper functions for launching external tools.""" + +from collections.abc import Sequence +import os +import subprocess +import time +from typing import Any + +from absl import logging + + +def create_query_fasta_file(sequence: str, path: str, linewidth: int = 80): + """Creates a fasta file with the sequence with line width limit.""" + with open(path, 'w') as f: + f.write('>query\n') + + i = 0 + while i < len(sequence): + f.write(f'{sequence[i:(i + linewidth)]}\n') + i += linewidth + + +def check_binary_exists(path: str, name: str) -> None: + """Checks if a binary exists on the given path and raises otherwise.""" + if not os.path.exists(path): + raise RuntimeError(f'{name} binary not found at {path}') + + +def run( + cmd: Sequence[str], + cmd_name: str, + log_on_process_error: bool = False, + log_stderr: bool = False, + log_stdout: bool = False, + max_out_streams_len: int | None = 500_000, + **run_kwargs, +) -> subprocess.CompletedProcess[Any]: + """Launches a subprocess, times it, and checks for errors. + + Args: + cmd: Command to launch. + cmd_name: Human-readable command name to be used in logs. + log_on_process_error: Whether to use `logging.error` to log the process' + stderr on failure. + log_stderr: Whether to log the stderr of the command. 
+ log_stdout: Whether to log the stdout of the command. + max_out_streams_len: Max length of prefix of stdout and stderr included in + the exception message. Set to `None` to disable truncation. + **run_kwargs: Any other kwargs for `subprocess.run`. + + Returns: + The completed process object. + + Raises: + RuntimeError: if the process completes with a non-zero return code. + """ + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + + start_time = time.time() + try: + completed_process = subprocess.run( + cmd, + check=True, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + text=True, + **run_kwargs, + ) + except subprocess.CalledProcessError as e: + if log_on_process_error: + # Logs have a 15k character limit, so log the error line by line. + logging.error('%s failed. %s stderr begin:', cmd_name, cmd_name) + for error_line in e.stderr.splitlines(): + if stripped_error_line := error_line.strip(): + logging.error(stripped_error_line) + logging.error('%s stderr end.', cmd_name) + + error_msg = ( + f'{cmd_name} failed' + f'\nstdout:\n{e.stdout[:max_out_streams_len]}\n' + f'\nstderr:\n{e.stderr[:max_out_streams_len]}' + ) + raise RuntimeError(error_msg) from e + end_time = time.time() + + logging.info('Finished %s in %.3f seconds', + cmd_name, end_time - start_time) + stdout, stderr = completed_process.stdout, completed_process.stderr + + if log_stdout and stdout: + logging.info('%s stdout:\n%s', cmd_name, stdout) + + if log_stderr and stderr: + logging.info('%s stderr:\n%s', cmd_name, stderr) + + return completed_process diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/atom_layout/atom_layout.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/atom_layout/atom_layout.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb172eec1d2e798be36372d3bcf9f3266ede570 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/atom_layout/atom_layout.py @@ -0,0 +1,1194 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""Helper functions for different atom layouts and conversion between them.""" + +import collections +from collections.abc import Mapping, Sequence +import math +import dataclasses +import types +from typing import Any, TypeAlias + +import numpy as np +import mindspore as ms +from mindspore import ops +from rdkit import Chem + +from alphafold3 import structure +from alphafold3.constants import atom_types +from alphafold3.constants import chemical_component_sets +from alphafold3.constants import chemical_components +from alphafold3.constants import mmcif_names +from alphafold3.constants import residue_names +from alphafold3.structure import chemical_components as struc_chem_comps + + +xnp_ndarray: TypeAlias = np.ndarray # pylint: disable=invalid-name +NumpyIndex: TypeAlias = Any + + +def _assign_atom_names_from_graph(mol: Chem.Mol) -> Chem.Mol: + """Assigns atom names from the molecular graph. + + The atom name is stored as an atom property 'atom_name', accessible with + atom.GetProp('atom_name'). If the property is already specified, we keep the + original name. + + We traverse the graph in the order of the rdkit atom index and give each atom + a name equal to '{ELEMENT_TYPE}_{INDEX}'. E.g. C5 is the name for the fifth + unnamed carbon encountered. + + NOTE: A new mol is returned, the original is not changed in place. + + Args: + mol: RDKit molecule. + + Returns: + A new mol, with potentially new 'atom_name' properties. + """ + mol = Chem.Mol(mol) + + specified_atom_names = { + a.GetProp('atom_name') for a in mol.GetAtoms() if a.HasProp('atom_name') + } + + element_counts = collections.Counter() + for atom in mol.GetAtoms(): + if not atom.HasProp('atom_name'): + element = atom.GetSymbol() + while True: + element_counts[element] += 1 + new_name = f'{element}{element_counts[element]}' + if new_name not in specified_atom_names: + break + atom.SetProp('atom_name', new_name) + + return mol + + +@dataclasses.dataclass(frozen=True) +class AtomLayout: + """Atom layout in a fixed shape (usually 1-dim or 2-dim). + + Examples for atom layouts are atom37, atom14, and similar. + All members are np.ndarrays with the same shape, e.g. + - [num_atoms] + - [num_residues, max_atoms_per_residue] + - [num_fragments, max_fragments_per_residue] + All string arrays should have dtype=object to avoid pitfalls with Numpy's + fixed-size strings + + Attributes: + atom_name: np.ndarray of str: atom names (e.g. 'CA', 'NE2'), padding + elements have an empty string (''), None or any other value, that maps to + False for .astype(bool). mmCIF field: _atom_site.label_atom_id. + res_id: np.ndarray of int: residue index (usually starting from 1) padding + elements can have an arbitrary value. mmCIF field: + _atom_site.label_seq_id. + chain_id: np.ndarray of str: chain names (e.g. 'A', 'B') padding elements + can have an arbitrary value. mmCIF field: _atom_site.label_seq_id. + atom_element: np.ndarray of str: atom elements (e.g. 'C', 'N', 'O'), padding + elements have an empty string (''), None or any other value, that maps to + False for .astype(bool). mmCIF field: _atom_site.type_symbol. + res_name: np.ndarray of str: residue names (e.g. 'ARG', 'TRP') padding + elements can have an arbitrary value. mmCIF field: + _atom_site.label_comp_id. + chain_type: np.ndarray of str: chain types (e.g. 'polypeptide(L)'). 
padding + elements can have an arbitrary value. mmCIF field: _entity_poly.type OR + _entity.type (for non-polymers). + shape: shape of the layout (just returns atom_name.shape) + """ + + atom_name: np.ndarray + res_id: np.ndarray + chain_id: np.ndarray + atom_element: np.ndarray | None = None + res_name: np.ndarray | None = None + chain_type: np.ndarray | None = None + + def __post_init__(self): + """Assert all arrays have the same shape.""" + attribute_names = ( + 'atom_name', + 'atom_element', + 'res_name', + 'res_id', + 'chain_id', + 'chain_type', + ) + _assert_all_arrays_have_same_shape( + obj=self, + expected_shape=self.atom_name.shape, + attribute_names=attribute_names, + ) + # atom_name must have dtype object, such that we can convert it to bool to + # obtain the mask + if self.atom_name.dtype != object: + raise ValueError( + 'atom_name must have dtype object, such that it can ' + 'be converted converted to bool to obtain the mask' + ) + + def __getitem__(self, key: NumpyIndex) -> 'AtomLayout': + return AtomLayout( + atom_name=self.atom_name[key], + res_id=self.res_id[key], + chain_id=self.chain_id[key], + atom_element=( + self.atom_element[key] if self.atom_element is not None else None + ), + res_name=(self.res_name[key] + if self.res_name is not None else None), + chain_type=( + self.chain_type[key] if self.chain_type is not None else None + ), + ) + + def __eq__(self, other: 'AtomLayout') -> bool: + if not np.array_equal(self.atom_name, other.atom_name): + return False + + mask = self.atom_name.astype(bool) + # Check essential fields. + for field in ('res_id', 'chain_id'): + my_arr = getattr(self, field) + other_arr = getattr(other, field) + if not np.array_equal(my_arr[mask], other_arr[mask]): + return False + + # Check optional fields. + for field in ('atom_element', 'res_name', 'chain_type'): + my_arr = getattr(self, field) + other_arr = getattr(other, field) + if ( + my_arr is not None + and other_arr is not None + and not np.array_equal(my_arr[mask], other_arr[mask]) + ): + return False + + return True + + def copy_and_pad_to(self, shape: tuple[int, ...]) -> 'AtomLayout': + """Copies and pads the layout to the requested shape. + + Args: + shape: new shape for the atom layout + + Returns: + a copy of the atom layout padded to the requested shape + + Raises: + ValueError: incompatible shapes. + """ + if len(shape) != len(self.atom_name.shape): + raise ValueError( + f'Incompatible shape {shape}. Current layout has shape {self.shape}.' + ) + if any(new < old for old, new in zip(self.atom_name.shape, shape)): + raise ValueError( + "Can't pad to a smaller shape. Current layout has shape " + f'{self.shape} and you requested shape {shape}.' + ) + pad_width = [ + (0, new - old) for old, new in zip(self.atom_name.shape, shape) + ] + pad_val = np.array('', dtype=object) + return AtomLayout( + atom_name=np.pad(self.atom_name, pad_width, + constant_values=pad_val), + res_id=np.pad(self.res_id, pad_width, constant_values=0), + chain_id=np.pad(self.chain_id, pad_width, constant_values=pad_val), + atom_element=( + np.pad(self.atom_element, pad_width, constant_values=pad_val) + if self.atom_element is not None + else None + ), + res_name=( + np.pad(self.res_name, pad_width, constant_values=pad_val) + if self.res_name is not None + else None + ), + chain_type=( + np.pad(self.chain_type, pad_width, constant_values=pad_val) + if self.chain_type is not None + else None + ), + ) + + def to_array(self) -> np.ndarray: + """Stacks the fields to a numpy array with shape (6, ). 
+ + Creates a pure numpy array of type `object` by stacking the 6 fields of the + AtomLayout, i.e. (atom_name, atom_element, res_name, res_id, chain_id, + chain_type). This method together with from_array() provides an easy way to + apply pure numpy methods like np.concatenate() to `AtomLayout`s. + + Returns: + np.ndarray of object with shape (6, ), e.g. + array([['N', 'CA', 'C', ..., 'CB', 'CG', 'CD'], + ['N', 'C', 'C', ..., 'C', 'C', 'C'], + ['LEU', 'LEU', 'LEU', ..., 'PRO', 'PRO', 'PRO'], + [1, 1, 1, ..., 403, 403, 403], + ['A', 'A', 'A', ..., 'D', 'D', 'D'], + ['polypeptide(L)', 'polypeptide(L)', ..., 'polypeptide(L)']], + dtype=object) + """ + if ( + self.atom_element is None + or self.res_name is None + or self.chain_type is None + ): + raise ValueError('All optional fields need to be present.') + + return np.stack(dataclasses.astuple(self), axis=0) + + @classmethod + def from_array(cls, arr: np.ndarray) -> 'AtomLayout': + """Creates an AtomLayout object from a numpy array with shape (6, ...). + + see also to_array() + Args: + arr: np.ndarray of object with shape (6, ) + + Returns: + AtomLayout object with shape () + """ + if arr.shape[0] != 6: + raise ValueError( + 'Given array must have shape (6, ...) to match the 6 fields of ' + 'AtomLayout (atom_name, atom_element, res_name, res_id, chain_id, ' + f'chain_type). Your array has {arr.shape=}' + ) + return cls(*arr) + + @property + def shape(self) -> tuple[int, ...]: + return self.atom_name.shape + + +@dataclasses.dataclass(frozen=True) +class Residues: + """List of residues with meta data. + + Attributes: + res_name: np.ndarray of str [num_res], e.g. 'ARG', 'TRP' + res_id: np.ndarray of int [num_res] + chain_id: np.ndarray of str [num_res], e.g. 'A', 'B' + chain_type: np.ndarray of str [num_res], e.g. 'polypeptide(L)' + is_start_terminus: np.ndarray of bool [num_res] + is_end_terminus: np.ndarray of bool [num_res] + deprotonation: (optional) np.ndarray of set() [num_res], e.g. {'HD1', 'HE2'} + smiles_string: (optional) np.ndarray of str [num_res], e.g. 
'Cc1ccccc1' + shape: shape of the layout (just returns res_name.shape) + """ + + res_name: np.ndarray + res_id: np.ndarray + chain_id: np.ndarray + chain_type: np.ndarray + is_start_terminus: np.ndarray + is_end_terminus: np.ndarray + deprotonation: np.ndarray | None = None + smiles_string: np.ndarray | None = None + + def __post_init__(self): + """Assert all arrays are 1D have the same shape.""" + attribute_names = ( + 'res_name', + 'res_id', + 'chain_id', + 'chain_type', + 'is_start_terminus', + 'is_end_terminus', + 'deprotonation', + 'smiles_string', + ) + _assert_all_arrays_have_same_shape( + obj=self, + expected_shape=(self.res_name.shape[0],), + attribute_names=attribute_names, + ) + + def __getitem__(self, key: NumpyIndex) -> 'Residues': + return Residues( + res_name=self.res_name[key], + res_id=self.res_id[key], + chain_id=self.chain_id[key], + chain_type=self.chain_type[key], + is_start_terminus=self.is_start_terminus[key], + is_end_terminus=self.is_end_terminus[key], + deprotonation=( + self.deprotonation[key] if self.deprotonation is not None else None + ), + smiles_string=( + self.smiles_string[key] if self.smiles_string is not None else None + ), + ) + + def __eq__(self, other: 'Residues') -> bool: + return all( + np.array_equal(getattr(self, field.name), + getattr(other, field.name)) + for field in dataclasses.fields(self) + ) + + @property + def shape(self) -> tuple[int, ...]: + return self.res_name.shape + + +@dataclasses.dataclass # (frozen=True) +class GatherInfo: + """Gather indices to translate from one atom layout to another. + + All members are np or jnp ndarray (usually 1-dim or 2-dim) with the same + shape, e.g. + - [num_atoms] + - [num_residues, max_atoms_per_residue] + - [num_fragments, max_fragments_per_residue] + + Attributes: + gather_idxs: np or jnp ndarray of int: gather indices into a flattened array + gather_mask: np or jnp ndarray of bool: mask for resulting array + input_shape: np or jnp ndarray of int: the shape of the unflattened input + array + shape: output shape. Just returns gather_idxs.shape + """ + + gather_idxs: ms.Tensor + gather_mask: ms.Tensor + input_shape: ms.Tensor + + def __post_init__(self): + if self.gather_mask.shape != self.gather_idxs.shape: + raise ValueError( + 'All arrays must have the same shape. 
Got\n' + f'gather_idxs.shape = {self.gather_idxs.shape}\n' + f'gather_mask.shape = {self.gather_mask.shape}\n' + ) + + def __getitem__(self, key: NumpyIndex) -> 'GatherInfo': + return GatherInfo( + gather_idxs=self.gather_idxs[key], + gather_mask=self.gather_mask[key], + input_shape=self.input_shape, + ) + + @property + def shape(self) -> tuple[int, ...]: + return self.gather_idxs.shape + + def as_np_or_jnp(self, xnp: types.ModuleType) -> 'GatherInfo': + return GatherInfo( + gather_idxs=xnp.array(self.gather_idxs), + gather_mask=xnp.array(self.gather_mask), + input_shape=xnp.array(self.input_shape), + ) + + def as_dict( + self, + key_prefix: str | None = None, + ) -> dict[str, xnp_ndarray]: + prefix = f'{key_prefix}:' if key_prefix else '' + return { + prefix + 'gather_idxs': self.gather_idxs, + prefix + 'gather_mask': self.gather_mask, + prefix + 'input_shape': self.input_shape, + } + + @classmethod + def from_dict( + cls, + d: Mapping[str, xnp_ndarray], + key_prefix: str | None = None, + ) -> 'GatherInfo': + """Creates GatherInfo from a given dictionary.""" + prefix = f'{key_prefix}:' if key_prefix else '' + return cls( + gather_idxs=d[prefix + 'gather_idxs'], + gather_mask=d[prefix + 'gather_mask'], + input_shape=d[prefix + 'input_shape'], + ) + + +def fill_in_optional_fields( + minimal_atom_layout: AtomLayout, + reference_atoms: AtomLayout, +) -> AtomLayout: + """Fill in the optional fields (atom_element, res_name, chain_type). + + Extracts the optional fields (atom_element, res_name, chain_type) from a + flat reference layout and fills them into the fields from this layout. + + Args: + minimal_atom_layout: An AtomLayout that only contains the essential fields + (atom_name, res_id, chain_id). + reference_atoms: A flat layout that contains all fields for all atoms. + + Returns: + An AtomLayout that contains all fields. + + Raises: + ValueError: Reference atoms layout is not flat. + ValueError: Missing atoms in reference. + """ + if len(reference_atoms.shape) > 1: + raise ValueError('Only flat layouts are supported as reference.') + ref_to_self = compute_gather_idxs( + source_layout=reference_atoms, target_layout=minimal_atom_layout + ) + atom_mask = minimal_atom_layout.atom_name.astype(bool) + missing_atoms_mask = atom_mask & ~ref_to_self.gather_mask + if np.any(missing_atoms_mask): + raise ValueError( + f'{np.sum(missing_atoms_mask)} missing atoms in reference: ' + f'{minimal_atom_layout[missing_atoms_mask]}' + ) + + def _convert_str_array(gather: GatherInfo, arr: np.ndarray): + output = arr[gather.gather_idxs] + output[~gather.gather_mask] = '' + return output + + return dataclasses.replace( + minimal_atom_layout, + atom_element=_convert_str_array( + ref_to_self, reference_atoms.atom_element + ), + res_name=_convert_str_array(ref_to_self, reference_atoms.res_name), + chain_type=_convert_str_array(ref_to_self, reference_atoms.chain_type), + ) + + +def guess_deprotonation(residues: Residues) -> Residues: + """Convenience function to create a plausible deprotonation field. + + Assumes a pH of 7 and always prefers HE2 over HD1 for HIS. 
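+  For example, a HIS residue is assigned the entry {'HD1'} (i.e. its HD1
+  proton is treated as absent), and residues flagged as end terminus
+  additionally drop 'HXT'.
+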
+ Args: + residues: a Residues object without a depronotation field + + Returns: + a Residues object with a depronotation field + """ + num_residues = residues.res_name.shape[0] + deprotonation = np.empty(num_residues, dtype=object) + deprotonation_at_ph7 = { + 'ASP': 'HD2', + 'GLU': 'HE2', + 'HIS': 'HD1', + } + for idx, res_name in enumerate(residues.res_name): + deprotonation[idx] = set() + if res_name in deprotonation_at_ph7: + deprotonation[idx].add(deprotonation_at_ph7[res_name]) + if residues.is_end_terminus[idx]: + deprotonation[idx].add('HXT') + + return dataclasses.replace(residues, deprotonation=deprotonation) + + +def atom_layout_from_structure( + struct: structure.Structure, + *, + fix_non_standard_polymer_res: bool = False, +) -> AtomLayout: + """Extract AtomLayout from a Structure.""" + + if not fix_non_standard_polymer_res: + return AtomLayout( + atom_name=np.array(struct.atom_name, dtype=object), + atom_element=np.array(struct.atom_element, dtype=object), + res_name=np.array(struct.res_name, dtype=object), + res_id=np.array(struct.res_id, dtype=int), + chain_id=np.array(struct.chain_id, dtype=object), + chain_type=np.array(struct.chain_type, dtype=object), + ) + + # Target lists. + target_atom_names = [] + target_atom_elements = [] + target_res_ids = [] + target_res_names = [] + target_chain_ids = [] + target_chain_types = [] + + for atom in struct.iter_atoms(): + target_atom_names.append(atom['atom_name']) + target_atom_elements.append(atom['atom_element']) + target_res_ids.append(atom['res_id']) + target_chain_ids.append(atom['chain_id']) + target_chain_types.append(atom['chain_type']) + if mmcif_names.is_standard_polymer_type(atom['chain_type']): + fixed_res_name = mmcif_names.fix_non_standard_polymer_res( + res_name=atom['res_name'], chain_type=atom['chain_type'] + ) + target_res_names.append(fixed_res_name) + else: + target_res_names.append(atom['res_name']) + + return AtomLayout( + atom_name=np.array(target_atom_names, dtype=object), + atom_element=np.array(target_atom_elements, dtype=object), + res_name=np.array(target_res_names, dtype=object), + res_id=np.array(target_res_ids, dtype=int), + chain_id=np.array(target_chain_ids, dtype=object), + chain_type=np.array(target_chain_types, dtype=object), + ) + + +def residues_from_structure( + struct: structure.Structure, + *, + include_missing_residues: bool = True, + fix_non_standard_polymer_res: bool = False, +) -> Residues: + """Create a Residues object from a Structure object.""" + + def _get_smiles(res_name): + """Get SMILES string from chemical components.""" + smiles = None + if ( + struct.chemical_components_data is not None + and struct.chemical_components_data.chem_comp is not None + and struct.chemical_components_data.chem_comp.get(res_name) + ): + smiles = struct.chemical_components_data.chem_comp[res_name].pdbx_smiles + return smiles + + res_names_per_chain = struct.chain_res_name_sequence( + include_missing_residues=include_missing_residues, + fix_non_standard_polymer_res=fix_non_standard_polymer_res, + ) + res_name = [] + res_id = [] + chain_id = [] + chain_type = [] + smiles = [] + is_start_terminus = [] + for c in struct.iter_chains(): + if include_missing_residues: + this_res_ids = [ + id for (_, id) in struct.all_residues[c['chain_id']]] + else: + this_res_ids = [ + r['res_id'] + for r in struct.iter_residues() + if r['chain_id'] == c['chain_id'] + ] + fixed_res_names = res_names_per_chain[c['chain_id']] + assert len(this_res_ids) == len( + fixed_res_names + ), f'{len(this_res_ids)} != 
{len(fixed_res_names)}' + this_start_res_id = min(min(this_res_ids), 1) + this_is_start_terminus = [r == this_start_res_id for r in this_res_ids] + smiles.extend([_get_smiles(res_name) for res_name in fixed_res_names]) + num_res = len(fixed_res_names) + res_name.extend(fixed_res_names) + res_id.extend(this_res_ids) + chain_id.extend([c['chain_id']] * num_res) + chain_type.extend([c['chain_type']] * num_res) + is_start_terminus.extend(this_is_start_terminus) + res_name = np.array(res_name, dtype=object) + res_id = np.array(res_id, dtype=int) + chain_id = np.array(chain_id, dtype=object) + chain_type = np.array(chain_type, dtype=object) + smiles = np.array(smiles, dtype=object) + is_start_terminus = np.array(is_start_terminus, dtype=bool) + + res_uid_to_idx = { + uid: idx for idx, uid in enumerate(zip(chain_id, res_id, strict=True)) + } + + # Start terminus indicates whether residue index is 1 and chain is polymer. + is_polymer = np.isin(chain_type, tuple(mmcif_names.POLYMER_CHAIN_TYPES)) + is_start_terminus = is_start_terminus & is_polymer + + # Start also indicates whether amino acid is attached to H2 or proline to H. + start_terminus_atom_index = np.nonzero( + (struct.chain_type == mmcif_names.PROTEIN_CHAIN) + & ( + (struct.atom_name == 'H2') + | ((struct.atom_name == 'H') & (struct.res_name == 'PRO')) + ) + )[0] + + # Translate atom idx to residue idx to assign start terminus. + for atom_idx in start_terminus_atom_index: + res_uid = (struct.chain_id[atom_idx], struct.res_id[atom_idx]) + res_idx = res_uid_to_idx[res_uid] + is_start_terminus[res_idx] = True + + # Infer end terminus: Check for OXT, or in case of + # include_missing_residues==True for the last residue of the chain. + num_all_residues = res_name.shape[0] + is_end_terminus = np.zeros(num_all_residues, dtype=bool) + end_term_atom_idxs = np.nonzero(struct.atom_name == 'OXT')[0] + for atom_idx in end_term_atom_idxs: + res_uid = (struct.chain_id[atom_idx], struct.res_id[atom_idx]) + res_idx = res_uid_to_idx[res_uid] + is_end_terminus[res_idx] = True + + if include_missing_residues: + for idx in range(num_all_residues - 1): + if is_polymer[idx] and chain_id[idx] != chain_id[idx + 1]: + is_end_terminus[idx] = True + if (num_all_residues > 0) and is_polymer[-1]: + is_end_terminus[-1] = True + + # Infer (de-)protonation: Only if hydrogens are given. 
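+    # A hydrogen from atom_types.PROTONATION_HYDROGENS (plus 'HXT' for
+    # end-terminal residues) is recorded as a deprotonation when it is absent
+    # from the resolved atoms of the structure.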
+ num_hydrogens = np.sum( + (struct.atom_element == 'H') & (struct.chain_type == 'polypeptide(L)') + ) + if num_hydrogens > 0: + deprotonation = np.empty(num_all_residues, dtype=object) + all_atom_uids = set( + zip(struct.chain_id, struct.res_id, struct.atom_name, strict=True) + ) + for idx in range(num_all_residues): + deprotonation[idx] = set() + check_hydrogens = set() + if is_end_terminus[idx]: + check_hydrogens.add('HXT') + if res_name[idx] in atom_types.PROTONATION_HYDROGENS: + check_hydrogens.update( + atom_types.PROTONATION_HYDROGENS[res_name[idx]]) + for hydrogen in check_hydrogens: + if (chain_id[idx], res_id[idx], hydrogen) not in all_atom_uids: + deprotonation[idx].add(hydrogen) + else: + deprotonation = None + + return Residues( + res_name=res_name, + res_id=res_id, + chain_id=chain_id, + chain_type=chain_type, + is_start_terminus=is_start_terminus.astype(bool), + is_end_terminus=is_end_terminus, + deprotonation=deprotonation, + smiles_string=smiles, + ) + + +def get_link_drop_atoms( + res_name: str, + chain_type: str, + *, + is_start_terminus: bool, + is_end_terminus: bool, + bonded_atoms: set[str], + drop_ligand_leaving_atoms: bool = False, +) -> set[str]: + """Returns set of atoms that are dropped when this res_name gets linked. + + Args: + res_name: residue name, e.g. 'ARG' + chain_type: chain_type, e.g. 'polypeptide(L)' + is_start_terminus: whether the residue is the n-terminus + is_end_terminus: whether the residue is the c-terminus + bonded_atoms: Names of atoms coming off this residue. + drop_ligand_leaving_atoms: Flag to switch on/off leaving atoms for ligands. + + Returns: + Set of atoms that are dropped when this amino acid gets linked. + """ + drop_atoms = set() + if chain_type == mmcif_names.PROTEIN_CHAIN: + if res_name == 'PRO': + if not is_start_terminus: + drop_atoms.update({'H', 'H2', 'H3'}) + if not is_end_terminus: + drop_atoms.update({'OXT', 'HXT'}) + else: + if not is_start_terminus: + drop_atoms.update({'H2', 'H3'}) + if not is_end_terminus: + drop_atoms.update({'OXT', 'HXT'}) + elif chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES: + if not is_start_terminus: + drop_atoms.update({'OP3'}) + elif ( + drop_ligand_leaving_atoms and chain_type in mmcif_names.LIGAND_CHAIN_TYPES + ): + if res_name in { + *chemical_component_sets.GLYCAN_OTHER_LIGANDS, + *chemical_component_sets.GLYCAN_LINKING_LIGANDS, + }: + if 'O1' not in bonded_atoms: + drop_atoms.update({'O1'}) + return drop_atoms + + +def get_bonded_atoms( + polymer_ligand_bonds: AtomLayout, + ligand_ligand_bonds: AtomLayout, + res_id: int, + chain_id: str, +) -> set[str]: + """Finds the res_name on the opposite end of the bond, if a bond exists. + + Args: + polymer_ligand_bonds: Bond information for polymer-ligand pairs. + ligand_ligand_bonds: Bond information for ligand-ligand pairs. + res_id: residue id in question. + chain_id: chain id of residue in question. + + Returns: + res_name of bonded atom. + """ + bonded_atoms = set() + if polymer_ligand_bonds: + # Filter before searching to speed this up. 
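+        # Bond arrays store both endpoints per row, so res_id/chain_id have
+        # shape [num_bonds, 2]; .any(axis=1) keeps rows where either endpoint
+        # matches the residue in question.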
+ bond_idx = np.logical_and( + polymer_ligand_bonds.res_id == res_id, + polymer_ligand_bonds.chain_id == chain_id, + ).any(axis=1) + relevant_polymer_bonds = polymer_ligand_bonds[bond_idx] + for atom_names, res_ids, chain_ids in zip( + relevant_polymer_bonds.atom_name, + relevant_polymer_bonds.res_id, + relevant_polymer_bonds.chain_id, + ): + if (res_ids[0], chain_ids[0]) == (res_id, chain_id): + bonded_atoms.add(atom_names[0]) + elif (res_ids[1], chain_ids[1]) == (res_id, chain_id): + bonded_atoms.add(atom_names[1]) + if ligand_ligand_bonds: + bond_idx = np.logical_and( + ligand_ligand_bonds.res_id == res_id, + ligand_ligand_bonds.chain_id == chain_id, + ).any(axis=1) + relevant_ligand_bonds = ligand_ligand_bonds[bond_idx] + for atom_names, res_ids, chain_ids in zip( + relevant_ligand_bonds.atom_name, + relevant_ligand_bonds.res_id, + relevant_ligand_bonds.chain_id, + ): + if (res_ids[0], chain_ids[0]) == (res_id, chain_id): + bonded_atoms.add(atom_names[0]) + elif (res_ids[1], chain_ids[1]) == (res_id, chain_id): + bonded_atoms.add(atom_names[1]) + return bonded_atoms + + +def make_flat_atom_layout( + residues: Residues, + ccd: chemical_components.Ccd, + polymer_ligand_bonds: AtomLayout | None = None, + ligand_ligand_bonds: AtomLayout | None = None, + *, + with_hydrogens: bool = False, + skip_unk_residues: bool = True, + drop_ligand_leaving_atoms: bool = False, +) -> AtomLayout: + """Make a flat atom layout for given residues. + + Create a flat layout from a `Residues` object. The required atoms for each + amino acid type are taken from the CCD, hydrogens and oxygens are dropped to + make the linked residues. Terminal OXT's and protonation state for the + hydrogens come from the `Residues` object. + + Args: + residues: a `Residues` object. + ccd: The chemical components dictionary. + polymer_ligand_bonds: Bond information for polymer-ligand pairs. + ligand_ligand_bonds: Bond information for ligand-ligand pairs. + with_hydrogens: whether to create hydrogens + skip_unk_residues: whether to skip 'UNK' resides -- default is True to be + compatible with the rest of AlphaFold that does not predict atoms for + unknown residues + drop_ligand_leaving_atoms: Flag to switch on/ off leaving atoms for ligands. + + Returns: + an `AtomLayout` object + """ + num_res = residues.res_name.shape[0] + + # Target lists. + target_atom_names = [] + target_atom_elements = [] + target_res_ids = [] + target_res_names = [] + target_chain_ids = [] + target_chain_types = [] + + for idx in range(num_res): + # skip 'UNK' residues if requested + if ( + skip_unk_residues + and residues.res_name[idx] in residue_names.UNKNOWN_TYPES + ): + continue + + # Get the atoms for this residue type from CCD. + if ccd.get(residues.res_name[idx]): + res_atoms = struc_chem_comps.get_all_atoms_in_entry( + ccd=ccd, res_name=residues.res_name[idx] + ) + atom_names_elements = list( + zip( + res_atoms['_chem_comp_atom.atom_id'], + res_atoms['_chem_comp_atom.type_symbol'], + strict=True, + ) + ) + elif residues.smiles_string[idx]: + # Get atoms from RDKit via SMILES. + mol = Chem.MolFromSmiles(residues.smiles_string[idx]) + mol = _assign_atom_names_from_graph(mol) + atom_names_elements = [ + (a.GetProp('atom_name'), a.GetSymbol()) for a in mol.GetAtoms() + ] + else: + raise ValueError( + f'{residues.res_name[idx]} not found in CCD and no SMILES string' + ) + + # Remove hydrogens if requested. 
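+        # Deuterium ('D') entries are treated like hydrogens and dropped too.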
+ if not with_hydrogens: + atom_names_elements = [ + (n, e) for n, e in atom_names_elements if (e != 'H' and e != 'D') + ] + bonded_atoms = get_bonded_atoms( + polymer_ligand_bonds, + ligand_ligand_bonds, + residues.res_id[idx], + residues.chain_id[idx], + ) + # Connect the amino-acids, i.e. remove OXT, HXT and H2. + drop_atoms = get_link_drop_atoms( + res_name=residues.res_name[idx], + chain_type=residues.chain_type[idx], + is_start_terminus=residues.is_start_terminus[idx], + is_end_terminus=residues.is_end_terminus[idx], + bonded_atoms=bonded_atoms, + drop_ligand_leaving_atoms=drop_ligand_leaving_atoms, + ) + + # If deprotonation info is available, remove the specific atoms. + if residues.deprotonation is not None: + drop_atoms.update(residues.deprotonation[idx]) + + atom_names_elements = [ + (n, e) for n, e in atom_names_elements if n not in drop_atoms + ] + + # Append the found atoms to the target lists. + target_atom_names.extend([n for n, _ in atom_names_elements]) + target_atom_elements.extend([e for _, e in atom_names_elements]) + num_atoms = len(atom_names_elements) + target_res_names.extend([residues.res_name[idx]] * num_atoms) + target_res_ids.extend([residues.res_id[idx]] * num_atoms) + target_chain_ids.extend([residues.chain_id[idx]] * num_atoms) + target_chain_types.extend([residues.chain_type[idx]] * num_atoms) + + return AtomLayout( + atom_name=np.array(target_atom_names, dtype=object), + atom_element=np.array(target_atom_elements, dtype=object), + res_name=np.array(target_res_names, dtype=object), + res_id=np.array(target_res_ids, dtype=int), + chain_id=np.array(target_chain_ids, dtype=object), + chain_type=np.array(target_chain_types, dtype=object), + ) + + +def compute_gather_idxs( + *, + source_layout: AtomLayout, + target_layout: AtomLayout, + fill_value: int = 0, +) -> GatherInfo: + """Produce gather indices and mask to convert from source layout to target.""" + source_uid_to_idx = { + uid: idx + for idx, uid in enumerate( + zip( + source_layout.chain_id.ravel(), + source_layout.res_id.ravel(), + source_layout.atom_name.ravel(), + strict=True, + ) + ) + } + gather_idxs = [] + gather_mask = [] + for uid in zip( + target_layout.chain_id.ravel(), + target_layout.res_id.ravel(), + target_layout.atom_name.ravel(), + strict=True, + ): + if uid in source_uid_to_idx: + gather_idxs.append(source_uid_to_idx[uid]) + gather_mask.append(True) + else: + gather_idxs.append(fill_value) + gather_mask.append(False) + target_shape = target_layout.atom_name.shape + return GatherInfo( + gather_idxs=np.array(gather_idxs, dtype=int).reshape(target_shape), + gather_mask=np.array(gather_mask, dtype=bool).reshape(target_shape), + input_shape=np.array(source_layout.atom_name.shape), + ) + + +def convert( + gather_info: GatherInfo, + arr: xnp_ndarray, + *, + layout_axes: tuple[int, ...] = (0,), +) -> xnp_ndarray: + """Convert an array from one atom layout to another.""" + # Translate negative indices to the corresponding positives. + layout_axes = tuple(i if i >= 0 else i + arr.ndim for i in layout_axes) + + # Ensure that layout_axes are continuous. + layout_axes_begin = layout_axes[0] + layout_axes_end = layout_axes[-1] + 1 + + if layout_axes != tuple(range(layout_axes_begin, layout_axes_end)): + raise ValueError(f'layout_axes must be continuous. Got {layout_axes}.') + layout_shape = arr.shape[layout_axes_begin:layout_axes_end] + + # Ensure that the layout shape is compatible + # with the gather_info. I.e. 
the first axis size must be equal or greater + # than the gather_info.input_shape, and all subsequent axes sizes must match. + if (len(layout_shape) != gather_info.input_shape.size) or ( + isinstance(gather_info.input_shape, np.ndarray) + and ( + (layout_shape[0] < gather_info.input_shape[0]) + or (np.any(layout_shape[1:] != gather_info.input_shape[1:])) + ) + ): + raise ValueError( + 'Input array layout axes are incompatible. You specified layout ' + f'axes {layout_axes} with an input array of shape {arr.shape}, but ' + f'the gather info expects shape {gather_info.input_shape}. ' + 'Your first axis size must be equal or greater than the ' + 'gather_info.input_shape, and all subsequent axes sizes must ' + 'match.' + ) + + # Compute the shape of the input array with flattened layout. + batch_shape = arr.shape[:layout_axes_begin] + features_shape = arr.shape[layout_axes_end:] + arr_flattened_shape = batch_shape + \ + (int(np.prod(layout_shape)),) + features_shape + + # Flatten input array and perform the gather. + arr_flattened = arr.reshape(arr_flattened_shape) + if layout_axes_begin == 0: + out_arr = arr_flattened[gather_info.gather_idxs, ...] + elif layout_axes_begin == 1: + out_arr = arr_flattened[:, gather_info.gather_idxs, ...] + elif layout_axes_begin == 2: + out_arr = arr_flattened[:, :, gather_info.gather_idxs, ...] + elif layout_axes_begin == 3: + out_arr = arr_flattened[:, :, :, gather_info.gather_idxs, ...] + elif layout_axes_begin == 4: + out_arr = arr_flattened[:, :, :, :, gather_info.gather_idxs, ...] + else: + raise ValueError( + 'Only 4 batch axes supported. If you need more, the code ' + 'is easy to extend.' + ) + + # Broadcast the mask and apply it. + broadcasted_mask_shape = ( + (1,) * len(batch_shape) + + gather_info.gather_mask.shape + + (1,) * len(features_shape) + ) + out_arr *= gather_info.gather_mask.reshape(broadcasted_mask_shape) + return out_arr + + +def convert_ms( + gather_info: GatherInfo, + arr: ms.Tensor, + *, + layout_axes: tuple[int, ...] = (0,), +) -> ms.Tensor: + """Convert an array from one atom layout to another.""" + # Translate negative indices to the corresponding positives. + layout_axes = tuple(i if i >= 0 else i + arr.ndim for i in layout_axes) + + # Ensure that layout_axes are continuous. + layout_axes_begin = layout_axes[0] + layout_axes_end = layout_axes[-1] + 1 + + if layout_axes != tuple(range(layout_axes_begin, layout_axes_end)): + raise ValueError(f'layout_axes must be continuous. Got {layout_axes}.') + layout_shape = arr.shape[layout_axes_begin:layout_axes_end] + + # Ensure that the layout shape is compatible + # with the gather_info. I.e. the first axis size must be equal or greater + # than the gather_info.input_shape, and all subsequent axes sizes must match. + # if (len(layout_shape) != gather_info.input_shape.size) or ( + # isinstance(gather_info.input_shape, np.ndarray) + # and ( + # (layout_shape[0] < gather_info.input_shape[0]) + # or (np.any(layout_shape[1:] != gather_info.input_shape[1:])) + # ) + # ): + # raise ValueError( + # 'Input array layout axes are incompatible. You specified layout ' + # f'axes {layout_axes} with an input array of shape {arr.shape}, but ' + # f'the gather info expects shape {gather_info.input_shape}. ' + # 'Your first axis size must be equal or greater than the ' + # 'gather_info.input_shape, and all subsequent axes sizes must ' + # 'match.' + # ) + + # Compute the shape of the input array with flattened layout. 
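+    # E.g. an input of shape (num_batch, num_res, max_atoms, channels) with
+    # layout_axes=(1, 2) is flattened to (num_batch, num_res * max_atoms,
+    # channels), so a single gather over the flattened layout axis suffices.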
+ batch_shape = arr.shape[:layout_axes_begin] + features_shape = arr.shape[layout_axes_end:] + arr_flattened_shape = batch_shape + \ + (int(math.prod(layout_shape)),) + features_shape + + # Flatten input array and perform the gather. + arr_flattened = arr.reshape(arr_flattened_shape) + out_arr = ops.gather(arr_flattened, gather_info.gather_idxs, axis=layout_axes_begin) + + # Broadcast the mask and apply it. + broadcasted_mask_shape = ( + (1,) * len(batch_shape) + + gather_info.gather_mask.shape + + (1,) * len(features_shape) + ) + out_arr *= ms.Tensor(gather_info.gather_mask.reshape(broadcasted_mask_shape)) + return out_arr.astype(ms.float32) + + +def make_structure( + flat_layout: AtomLayout, + atom_coords: np.ndarray, + name: str, + *, + atom_b_factors: np.ndarray | None = None, + all_physical_residues: Residues | None = None, +) -> structure.Structure: + """Returns a Structure from a flat layout and atom coordinates. + + The provided flat_layout must be 1-dim and must not contain any padding + elements. The flat_layout.atom_name must conform to the OpenMM/CCD standard + and must not contain deuterium. + + Args: + flat_layout: flat 1-dim AtomLayout without padding elements + atom_coords: np.ndarray of float, shape (num_atoms, 3) + name: str: the name (usually PDB id), e.g. '1uao' + atom_b_factors: np.ndarray of float, shape (num_atoms,) or None. If None, + they will be set to all zeros. + all_physical_residues: a Residues object that contains all physically + existing residues, i.e. also those residues that have no resolved atoms. + This is common in experimental structures, but also appears in predicted + structures for 'UNK' or other non-standard residue types, where the model + does not predict coordinates. This will be used to create the + `all_residues` field of the structure object. 
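+
+  Returns:
+    A `structure.Structure` built via `structure.from_atom_arrays` from the
+    given layout, coordinates and b-factors.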
+ """ + + if flat_layout.atom_name.ndim != 1 or not np.all( + flat_layout.atom_name.astype(bool) + ): + raise ValueError( + 'flat_layout must be 1-dim and must not contain anypadding element' + ) + if ( + flat_layout.atom_element is None + or flat_layout.res_name is None + or flat_layout.chain_type is None + ): + raise ValueError('All optional fields must be present.') + + if atom_b_factors is None: + atom_b_factors = np.zeros(atom_coords.shape[:-1]) + + if all_physical_residues is not None: + # Create the all_residues field from a Residues object + # (unfortunately there is no central place to keep the chain_types in + # the structure class, so we drop it here) + all_residues = collections.defaultdict(list) + for chain_id, res_id, res_name in zip( + all_physical_residues.chain_id, + all_physical_residues.res_id, + all_physical_residues.res_name, + strict=True, + ): + all_residues[chain_id].append((res_name, res_id)) + else: + # Create the all_residues field from the flat_layout + all_residues = collections.defaultdict(list) + if flat_layout.chain_id.shape[0] > 0: + all_residues[flat_layout.chain_id[0]].append( + (flat_layout.res_name[0], flat_layout.res_id[0]) + ) + for i in range(1, flat_layout.shape[0]): + if ( + flat_layout.chain_id[i] != flat_layout.chain_id[i - 1] + or flat_layout.res_name[i] != flat_layout.res_name[i - 1] + or flat_layout.res_id[i] != flat_layout.res_id[i - 1] + ): + all_residues[flat_layout.chain_id[i]].append( + (flat_layout.res_name[i], flat_layout.res_id[i]) + ) + + return structure.from_atom_arrays( + name=name, + all_residues=dict(all_residues), + chain_id=flat_layout.chain_id, + chain_type=flat_layout.chain_type, + res_id=flat_layout.res_id.astype(np.int32), + res_name=flat_layout.res_name, + atom_name=flat_layout.atom_name, + atom_element=flat_layout.atom_element, + atom_x=atom_coords[..., 0], + atom_y=atom_coords[..., 1], + atom_z=atom_coords[..., 2], + atom_b_factor=atom_b_factors, + ) + + +def _assert_all_arrays_have_same_shape( + *, + obj: AtomLayout | Residues | GatherInfo, + expected_shape: tuple[int, ...], + attribute_names: Sequence[str], +) -> None: + """Checks that given attributes of the object have the expected shape.""" + attribute_shapes_description = [] + all_shapes_are_valid = True + + for attribute_name in attribute_names: + attribute = getattr(obj, attribute_name) + + if attribute is None: + attribute_shape = None + else: + attribute_shape = attribute.shape + + if attribute_shape is not None and expected_shape != attribute_shape: + all_shapes_are_valid = False + + attribute_shape_name = attribute_name + '.shape' + attribute_shapes_description.append( + f'{attribute_shape_name:25} = {attribute_shape}' + ) + + if not all_shapes_are_valid: + raise ValueError( + f'All arrays must have the same shape ({expected_shape=}). Got\n' + + '\n'.join(attribute_shapes_description) + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/base_config.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/base_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0d3a08b6206c02d88b74960e7c2f9a4b1c607027 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/base_config.py @@ -0,0 +1,153 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Config for the protein folding model and experiment.""" + +from collections.abc import Mapping +import copy +import dataclasses +import types +import typing +from typing import Any, ClassVar, TypeVar + + +_T = TypeVar('_T') +_ConfigT = TypeVar('_ConfigT', bound='BaseConfig') + + +def _strip_optional(t: type[Any]) -> type[Any]: + """Transforms type annotations of the form `T | None` to `T`.""" + if typing.get_origin(t) in (typing.Union, types.UnionType): + args = set(typing.get_args(t)) - {types.NoneType} + if len(args) == 1: + return args.pop() + return t + + +_NO_UPDATE = object() + + +class _Autocreate: + + def __init__(self, **defaults: Any): + self.defaults = defaults + + +def autocreate(**defaults: Any) -> Any: + """Marks a field as having a default factory derived from its type.""" + return _Autocreate(**defaults) + + +def _clone_field( + field: dataclasses.Field[_T], new_default: _T +) -> dataclasses.Field[_T]: + if new_default is _NO_UPDATE: + return copy.copy(field) + return dataclasses.field( + default=new_default, + init=True, + kw_only=True, + repr=field.repr, + hash=field.hash, + compare=field.compare, + metadata=field.metadata, + ) + + +@typing.dataclass_transform() +class ConfigMeta(type): + """Metaclass that synthesizes a __post_init__ that coerces dicts to Config subclass instances.""" + + def __new__(mcs, name, bases, classdict): + cls = super().__new__(mcs, name, bases, classdict) + + def _coercable_fields(self) -> Mapping[str, tuple[ConfigMeta, Any]]: + type_hints = typing.get_type_hints(self.__class__) + fields = dataclasses.fields(self.__class__) + field_to_type_and_default = { + field.name: (_strip_optional( + type_hints[field.name]), field.default) + for field in fields + } + coercable_fields = { + f: t + for f, t in field_to_type_and_default.items() + if issubclass(type(t[0]), ConfigMeta) + } + return coercable_fields + + cls._coercable_fields = property(_coercable_fields) + + old_post_init = getattr(cls, '__post_init__', None) + + def _post_init(self) -> None: + # Use get_type_hints instead of Field.type to ensure that forward + # references are resolved. + for field_name, ( + field_type, + field_default, + ) in self._coercable_fields.items(): # pylint: disable=protected-access + field_value = getattr(self, field_name) + if field_value is None: + continue + try: + match field_value: + case _Autocreate(): + # Construct from field defaults. + setattr(self, field_name, field_type( + **field_value.defaults)) + case Mapping(): + # Field value is not yet a `Config` instance; Assume we can create + # one by splatting keys and values. + args = {} + # Apply default args first, if present. 
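+                # Values supplied in the mapping then override these defaults.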
+ if isinstance(field_default, _Autocreate): + args.update(field_default.defaults) + args.update(field_value) + setattr(self, field_name, field_type(**args)) + case _: + pass + except TypeError as e: + raise TypeError( + f'Failure while coercing field {field_name!r} of' + f' {self.__class__.__qualname__}' + ) from e + if old_post_init: + old_post_init(self) + + cls.__post_init__ = _post_init + + return dataclasses.dataclass(kw_only=True)(cls) + + +class BaseConfig(metaclass=ConfigMeta): + """Config base class. + + Subclassing Config automatically makes the subclass a kw_only dataclass with + a `__post_init__` that coerces Config-subclass field values from mappings to + instances of the right type. + """ + # Provided by dataclasses.make_dataclass + __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] + + # Overridden by metaclass + @property + def _coercable_fields(self) -> Mapping[str, tuple[type['BaseConfig'], Any]]: + return {} + + def as_dict(self) -> Mapping[str, Any]: + result = dataclasses.asdict(self) + for field_name in self._coercable_fields: + field_value = getattr(self, field_name, None) + if isinstance(field_value, BaseConfig): + result[field_name] = field_value.as_dict() + return result diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_model.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5de365c85d21adecfe7c8ea53d46aef0b6aaff62 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_model.py @@ -0,0 +1,51 @@ +"""Defines interface of a BaseModel.""" + +from collections.abc import Mapping +import dataclasses +from typing import Any, TypeAlias +from alphafold3 import structure +import numpy as np +import mindspore as ms + +ModelResult: TypeAlias = Mapping[str, Any] +ScalarNumberOrArray: TypeAlias = Mapping[str, float | int | np.ndarray] + +# Eval result will contain scalars (e.g. metrics or losses), selected from the +# forward pass outputs or computed in the online evaluation; np.ndarrays or +# jax.Arrays generated from the forward pass outputs (e.g. distogram expected +# distances) or batch inputs; protein structures (predicted and ground-truth). +EvalResultValue: TypeAlias = ( + float | int | np.ndarray | ms.Tensor | structure.Structure +) +# Eval result may be None for some metrics if they are not computable. +EvalResults: TypeAlias = Mapping[str, EvalResultValue | None] +# Interface metrics are all floats or None. +InterfaceMetrics: TypeAlias = Mapping[str, float | None] +# Interface results are a mapping from interface name to mappings from score +# type to metric value. +InterfaceResults: TypeAlias = Mapping[str, Mapping[str, InterfaceMetrics]] +# Eval output consists of full eval results and a dict of interface metrics. +EvalOutput: TypeAlias = tuple[EvalResults, InterfaceResults] + +# Signature for `apply` method of hk.transform_with_state called on a BaseModel. +# ForwardFn: TypeAlias = Callable[ +# [hk.Params, hk.State, jax.Array, features.BatchDict], +# tuple[ModelResult, hk.State], +# ] + + +@dataclasses.dataclass(frozen=True) +class InferenceResult: + """Postprocessed model result.""" + + # Predicted protein structure. + predicted_structure: structure.Structure = dataclasses.field() + # Useful numerical data (scalars or arrays) to be saved at inference time. 
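+    # (e.g. distogram expected distances or other arrays from the forward pass)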
+ numerical_data: ScalarNumberOrArray = dataclasses.field( + default_factory=dict) + # Smaller numerical data (usually scalar) to be saved as inference metadata. + metadata: ScalarNumberOrArray = dataclasses.field(default_factory=dict) + # Additional dict for debugging, e.g. raw outputs of a model forward pass. + debug_outputs: ModelResult | None = dataclasses.field(default_factory=dict) + # Model identifier. + model_id: bytes = b'' diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_modules.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..f5353738226f67697b8469307abef406a781d0b1 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_modules.py @@ -0,0 +1,148 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""Common modules.""" + +from collections.abc import Sequence +import contextlib +import numbers +from typing import TypeAlias + +import numpy as np +import mindspore as ms +from mindspore import nn, ops +from mindspore.common import initializer +from mindchemistry.e3.utils import Ncon + +# Useful for mocking in tests. +DEFAULT_PRECISION = None + +# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) +TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray( + 0.87962566103423978, dtype=np.float32 +) + + +class LayerNorm(nn.Cell): + """LayerNorm module. + + Equivalent to ms.nn.LayerNorm. In most cases, it can be replaced by ms.nn.LayerNorm. + Here, gamma is scale, beta is shift or offset + Args: + normalized_shape (tuple | list): The shape of Tensor which need to LayerNorm. + name (str): Name of this layer. + begin_norm_axis(int): From which axis norm begin + begin_params_axis(int): From which axis params begin + gamma_init('str'): Initializer of gamma + beta_init('str'): Initializer of beta + epsilon(float): epsilon value + dtype(ms.type): Type of output + create_beta(bool): whether to create a trainable beta parameter + create_gamma(bool): whether to create a trainable gamma parameter + Inputs: + - **x** (Tensor) - Tensor of any shape + Outputs: + The shape of tensor is the same as x. 
+ Supported Platforms: + ``Ascend`` + """ + + def __init__(self, normalized_shape, name=None, begin_norm_axis=-1, + begin_params_axis=-1, gamma_init='ones', + beta_init='zeros', epsilon=1e-5, dtype=ms.float32, + create_beta=True, create_gamma=True): + super().__init__() + if not create_beta: + beta_init = 'zeros' + if not create_gamma: + gamma_init = 'ones' + self.layernorm = nn.LayerNorm(normalized_shape[begin_norm_axis:], begin_norm_axis=begin_norm_axis, + begin_params_axis=begin_params_axis, gamma_init=gamma_init, + beta_init=beta_init, epsilon=epsilon, dtype=dtype) + if create_beta is False: + self.layernorm.beta.requires_grad = False + if create_gamma is False: + self.layernorm.gamma.requires_grad = False + self.dtype = dtype + + def construct(self, x): + out = self.layernorm(x.astype(ms.float32)).astype(x.dtype) + return out + + +class CustomDense(nn.Cell): + """ + Custom Linear Module. It can be apply to a high dimension Tensor, and can be used on more than 1D Matmul. + In Alphafold, they use Einsum to replace Matmul, here we use Ncon to replace Matmul. if in_shape and out_shape + are both int, this layer is equivalence to nn.Dense. + Args: + in_shape (Union(int, List, Tuple)): input shape, that need to be multiplied. + out_shape (Union(int, List, Tuple)): output shape, that need to be multiplied. + Inputs: + - **x** (Tensor) + Outputs: + + Supported Platforms: + ``Ascend`` + """ + + def __init__(self, in_shape, out_shape, weight_init="zeros", use_bias=False, \ + bias_init="zeros", ndim=None, dtype=ms.float32): + super().__init__() + if isinstance(in_shape, int): + in_shape = (in_shape,) + if isinstance(out_shape, int): + out_shape = (out_shape,) + self.num_output_dims = len(out_shape) + self.num_input_dims = len(in_shape) + if ndim is None: + ndim = len(in_shape) + 1 + if weight_init in ["relu", "linear"]: + self.weight = custom_initializer( + weight_init, in_shape + out_shape, dtype=dtype) + else: + self.weight = ms.Parameter(initializer.initializer( + weight_init, in_shape + out_shape, dtype=dtype)) + self.use_bias = use_bias + if self.use_bias: + self.bias = ms.Parameter( + initializer.initializer(bias_init, out_shape, dtype=dtype)) + ncon_list1 = [-i-1 for i in range(ndim - self.num_input_dims)] + [ + i+1 for i in range(len(in_shape))] + ncon_list2 = (ncon_list1[ndim - self.num_input_dims:]) + \ + [-i-ndim+self.num_input_dims-1 for i in range(len(out_shape))] + self.ncon = Ncon([ncon_list1, ncon_list2]) + + in_letters = 'abcde'[: self.num_input_dims] + out_letters = 'hijkl'[: self.num_output_dims] + self.equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}' + + def construct(self, x): + if self.use_bias: + output = self.ncon([x, self.weight]) + self.bias + else: + output = self.ncon([x, self.weight]) + return output + + +def custom_initializer(initializer_name, input_shape, dtype=ms.float32): + noise_scale = ms.Tensor(1.0) + for channel_dim in input_shape: + noise_scale /= channel_dim + if initializer_name == 'relu': + noise_scale *= 2 + stddev = ops.sqrt(noise_scale) + stddev = stddev / ms.Tensor(TRUNCATED_NORMAL_STDDEV_FACTOR) + param = ms.Parameter(initializer.initializer( + initializer.TruncatedNormal(stddev, 0), input_shape, dtype)) + return param diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/mapping.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5f660bebb9a895914bf6b27b104759624ceb09 --- 
/dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/mapping.py @@ -0,0 +1,353 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""Specialized mapping functions.""" + +from collections.abc import Callable, Sequence +import functools +from typing import Any +import mindspore as ms + + +Pytree = Any +PytreeJaxArray = Any + +partial = functools.partial +PROXY = object() + + +def _maybe_slice(array, i, slice_size, axis): + "modified to mindspore" + if axis is PROXY: + return array + start = [0]*array.ndim + start[axis] = i + size = list(array.shape) + size[axis] = slice_size + return ms.ops.slice(array, start, size) + + +def _maybe_get_size(array, axis): + "modified to mindspore" + if axis == PROXY: + return -1 + return array.shape[axis] + + +def tree_flatten(tree): + if isinstance(tree, (list, tuple)): + flat, structure = [], [] + for item in tree: + sub_flat, sub_struct = tree_flatten(item) + flat.extend(sub_flat) + structure.append(sub_struct) + return flat, structure + elif isinstance(tree, dict): + flat, structure = [], {} + for key, value in tree.items(): + sub_flat, sub_struct = tree_flatten(value) + flat.extend(sub_flat) + structure[key] = sub_struct + return flat, structure + else: + return [tree], None + + +def tree_unflatten(flat, structure): + if isinstance(structure, list): + result, idx = [], 0 + for sub_struct in structure: + sub_tree, idx = tree_unflatten(flat[idx:], sub_struct) + result.append(sub_tree) + return result, idx + elif isinstance(structure, dict): + result, idx = {}, 0 + for key, sub_struct in structure.items(): + sub_tree, idx = tree_unflatten(flat[idx:], sub_struct) + result[key] = sub_tree + return result, idx + else: + return flat[0], 1 + + +def _expand_axes(axes, values, name="sharded_apply"): + values_tree_def = tree_flatten(values)[1] + # flat_axes = tree_flatten(axes)[0] + flat_axes = [PROXY if axes is None else axes for _ in values_tree_def] + expanded_axes, _ = tree_unflatten(flat_axes, values_tree_def) + return expanded_axes + + +def tree_map(fn, *trees): + "Mindspore do not have the same function like Jax.tree.map, so try to write a mindspore version." 
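+    # Assumes every tree in `trees` shares the same nested structure and
+    # container types; dict inputs are additionally checked for matching keys.
+    # E.g. tree_map(lambda a, b: a + b, {'x': 1}, {'x': 2}) -> {'x': 3}.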
+ tree_types = {type(tree) for tree in trees} + tree_type = tree_types.pop() + if tree_type in (list,): + return tree_type(tree_map(fn, *subtrees) for subtrees in zip(*trees)) + if tree_type is dict: + keys = trees[0].keys() + if not all(tree.keys() == keys for tree in trees): + raise ValueError("All input dictionaries must have the same keys") + return {key: tree_map(fn, *(tree[key] for tree in trees)) for key in keys} + return fn(*trees) + + +def tree_leaves(tree): + "same as tree_map" + if isinstance(tree, (list, tuple)): + leaves = [] + for item in tree: + leaves.extend(tree_leaves(item)) + return leaves + if isinstance(tree, dict): + leaves = [] + for key in tree: + leaves.extend(tree_leaves(tree[key])) + return leaves + return [tree] + + +def eval_shape(fun, *args, **kwargs): + fake_inputs = [ms.ops.zeros(arg.shape, dtype=arg.dtype) if isinstance( + arg, ms.Tensor) else arg for arg in args] + output = fun(*fake_inputs, **kwargs) + return output + + +def sharded_apply( + fun: Callable[..., PytreeJaxArray], + shard_size: int | None = 1, + in_axes: int | Pytree = 0, + out_axes: int | Pytree = 0, + new_out_axes: bool = False, +) -> Callable[..., PytreeJaxArray]: + """Sharded apply. + + Applies `fun` over shards to axes, in a way similar to vmap, + but does so in shards of `shard_size`. Shards are stacked after. + This allows a smooth trade-off between + memory usage (as in a plain map) vs higher throughput (as in a vmap). + + Args: + fun: Function to apply smap transform to. + shard_size: Integer denoting shard size. + in_axes: Either integer or pytree describing which axis to map over for each + input to `fun`, None denotes broadcasting. + out_axes: Integer or pytree denoting to what axis in the output the mapped + over axis maps. + new_out_axes: Whether to stack outputs on new axes. This assumes that the + output sizes for each shard (including the possible remainder shard) are + the same. + + Returns: + Function with smap applied. + """ + docstr = ( + "Mapped version of {fun}. Takes similar arguments to {fun} " + "but with additional array axes over which {fun} is mapped." + ) + if new_out_axes: + raise NotImplementedError("New output axes not yet implemented.") + + # shard size None denotes no sharding + if shard_size is None: + return fun + + def mapped_fn(*args, **kwargs): + # Expand in axes and determine loop range. + in_axes_ = _expand_axes(ms.Tensor(in_axes), args) + + in_sizes = tree_map(_maybe_get_size, list(args), in_axes_) + in_size = max(tree_leaves(in_sizes)) + + num_extra_shards = (in_size - 1) // shard_size + + # Fix if necessary. 
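+        # E.g. in_size=10, shard_size=4 -> num_extra_shards=2 and
+        # last_shard_size=2; a zero remainder means the last shard is full.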
+ last_shard_size = in_size % shard_size + last_shard_size = shard_size if last_shard_size == 0 else last_shard_size + + def apply_fun_to_slice(slice_start, slice_size, args, in_axes_): + input_slice = tree_map( + lambda array, axis: _maybe_slice( + array, slice_start, slice_size, axis + ), + args, + in_axes_, + ) + return fun(input_slice, **kwargs) + + remainder_shape_dtype = eval_shape( + lambda array, axis: apply_fun_to_slice( + 0, last_shard_size, array, axis), + args, in_axes_ + ) + + out_shapes = tree_map(lambda x: x.shape, remainder_shape_dtype) + out_dtypes = tree_map(lambda x: x.dtype, remainder_shape_dtype) + out_axes_ = _expand_axes(out_axes, out_shapes) + + if num_extra_shards > 0: + regular_shard_shape_dtype = eval_shape( + lambda array, axis: apply_fun_to_slice( + 0, shard_size, array, axis), + args, in_axes_ + ) + shard_shapes = tree_map( + lambda x: x.shape, regular_shard_shape_dtype) + + def make_output_shape(axis, shard_shape, remainder_shape): + axis = axis if isinstance(axis, int) else int(axis[0]) + shard_shape = tuple(shard_shape) + remainder_shape = tuple(remainder_shape) + return ms.ops.stack( + shard_shape[:axis] + + (shard_shape[axis] * num_extra_shards + + remainder_shape[axis],) + + shard_shape[axis + 1:] + ) + + out_shapes = tree_map( + make_output_shape, out_axes_[0], ms.Tensor( + shard_shapes), ms.Tensor(out_shapes) + ) + + # Calls dynamic Update slice with different argument order. + # This is here since tree_map only works with positional arguments. + def dynamic_update_slice_in_dim(array, slice_size, axis, i): + start = [0]*array.ndim + start[axis] = int(i) + size = list(array.shape) + size[axis] = slice_size.shape[axis] + # return ms.ops.slice(array, start, size) + end = [x + y for x, y in zip(start, size)] + array[start[0]: end[0]] = slice_size + return array + + def compute_shard(outputs, slice_start, slice_size): + slice_out = (lambda array, axis: apply_fun_to_slice( + int(slice_start), shard_size, array, axis))(args, in_axes_) + update_slice = partial(dynamic_update_slice_in_dim, i=slice_start) + # slice_out = (slice_out,) if not isinstance(slice, (int, float)) else [int(x) for x in slice_out] + return tree_map(update_slice, outputs, slice_out, out_axes_[0]) + + def scan_iteration(outputs, i): + new_outputs = compute_shard(outputs, i, shard_size) + return new_outputs + + slice_starts = ms.ops.arange(0, in_size - shard_size + 1, shard_size) + + def allocate_buffer(dtype, shape): + return ms.ops.zeros(shape, dtype=dtype) + + outputs = tree_map(allocate_buffer, out_dtypes, out_shapes) + + if slice_starts.shape[0] > 0: + for slice_start in slice_starts: + outputs = scan_iteration(outputs, slice_start) + # scan_op = ms.ops.Scan() + # outputs, _ = scan_op(scan_iteration, outputs, slice_starts) + + if last_shard_size != shard_size: + remainder_start = in_size - last_shard_size + outputs = compute_shard(outputs, remainder_start, last_shard_size) + + return outputs + + return mapped_fn + + +def sharded_map(fun, shard_size=1, in_axes=0, out_axes=0): + vmapped_fun = ms.vmap(fun, int(in_axes), int(out_axes)) + return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes) + + +def reshape_partitioned_inputs(batched_args, partitioned_dim, subbatch_size): + """Reshapes so subbatching doesn't happen on the partitioned dim.""" + subbatched_args = [] + for arg in batched_args: + shape = arg.shape + new_shape = ( + shape[:partitioned_dim] + + (subbatch_size, shape[partitioned_dim] // subbatch_size) + + shape[partitioned_dim + 1:] + ) + 
subbatched_args.append(arg.reshape(new_shape)) + return subbatched_args + + +def reshape_partitioned_output(output, output_subbatch_dim): + """Reshapes outputs as if reshape_partitioned_inputs were never applied.""" + out_shape = ( + output.shape[: output_subbatch_dim - 1] + + (-1,) + + output.shape[output_subbatch_dim + 1:] + ) + return output.reshape(out_shape) + + +def inference_subbatch(module, subbatch_size, batched_args, + nonbatched_args, input_subbatch_dim=0, output_subbatch_dim=None, + input_subbatch_dim_is_partitioned=False): + """Run through subbatches (like batch apply but with split and concat).""" + assert len(batched_args) > 0 + if output_subbatch_dim is None: + output_subbatch_dim = input_subbatch_dim + if input_subbatch_dim_is_partitioned: + # Subbatching along the partitioned axis would induce an all-gather that + # undoes the partitioning. So instead we reshape such that + # [..., partitioned_input_size, ...] becomes [..., subbatch_size, + # partitioned_input_size // subbatch_size, ...] and then actually subbatch + # along the partitioned_input_size // subbatch_size axis in slices of + # size 1. Partitioning is then preserved on the partitioned axis, except + # that dimension is now of size subbatch_size instead of + # partitioned_input_size. Note that the module itself still sees inputs of + # size [..., subbatch_size, ...], just as it would if this reshaping were + # not applied. + batched_args = reshape_partitioned_inputs( + batched_args, input_subbatch_dim, subbatch_size + ) + input_subbatch_dim += 1 + output_subbatch_dim += 1 + subbatch_size = 1 + + def run_module(*batched_args): + if input_subbatch_dim_is_partitioned: + # Squeeze off the singleton dimension (otherwise the module would see + # [..., subbatch_size, 1, ...]). + batched_args = [b.squeeze(axis=input_subbatch_dim) + for b in batched_args] + args = list(batched_args)[0] + list(nonbatched_args) + res = module(*args) + if input_subbatch_dim_is_partitioned: + # Add back in the singleton dimension so the outputs are stacked on the + # axis we are actually subbatching over (i.e stacked back to + # [..., subbatch_size, partitioned_input_size // subbatch_size, ...]), + # rather than on the partitioned axis, which would again induce an + # all-gather that breaks partitioning. + res = ms.ops.expand_dims(res, axis=output_subbatch_dim) + return res + sharded_module = sharded_apply( + run_module, + shard_size=subbatch_size, + in_axes=input_subbatch_dim, + out_axes=output_subbatch_dim, + ) + output = sharded_module(*batched_args) + if input_subbatch_dim_is_partitioned: + # The is of the same shape as the inputs [..., subbatch_size, + # partitioned_input_size // subbatch_size, ...]. Reshape to + # [..., partitioned_input_size, ...] as if the reshaping due to partitioning + # had never been applied. + output = reshape_partitioned_output(output, output_subbatch_dim) + + return output diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/utils.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..537ee0648247d167a890cbf7af567a7a677ef11d --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/utils.py @@ -0,0 +1,65 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +from collections import abc +import numbers + +import numpy as np +import mindspore as ms + +VALID_DTYPES = [np.float32, np.float64, np.int8, np.int32, np.int32, bool] + + +def remove_invalidly_typed_feats(batch): + """Remove features of types we don't want to send to the TPU e.g. strings.""" + return { + k: v + for k, v in batch.items() + if hasattr(v, 'dtype') and v.dtype in VALID_DTYPES + } + + +def mask_mean(mask, value, axis=None, keepdims=False, eps=1e-10): + """Masked mean.""" + + mask_shape = mask.shape + value_shape = value.shape + + assert len(mask_shape) == len( + value_shape + ), 'Shapes are not compatible, shapes: {}, {}'.format(mask_shape, value_shape) + + if isinstance(axis, numbers.Integral): + axis = [axis] + elif axis is None: + axis = list(range(len(mask_shape))) + assert isinstance( + axis, abc.Iterable + ), 'axis needs to be either an iterable, integer or "None"' + + broadcast_factor = 1.0 + for axis_ in axis: + value_size = value_shape[axis_] + mask_size = mask_shape[axis_] + if mask_size == 1: + broadcast_factor *= value_size + else: + error = f'Shapes are not compatible, shapes: {mask_shape}, {value_shape}' + assert mask_size == value_size, error + + return ms.ops.sum(mask * value, keepdim=keepdims, dim=axis) / ( + ms.ops.maximum( + ms.ops.sum(mask, keepdim=keepdims, dim=axis) * + broadcast_factor, eps + ) + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/confidence_types.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/confidence_types.py new file mode 100644 index 0000000000000000000000000000000000000000..3ceb926d6b9409021e881c12a8c772d8cfc1d8b7 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/confidence_types.py @@ -0,0 +1,310 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Confidence categories for predictions.""" + +import dataclasses +import enum +import json +from typing import Any, Self + +from absl import logging +import numpy as np +import mindspore as ms +from alphafold3.model.components import base_model +from alphafold3.model.components.mapping import tree_map + + +class StructureConfidenceFullEncoder(json.JSONEncoder): + """JSON encoder for serializing confidence types.""" + + def __init__(self, **kwargs): + super().__init__(**(kwargs | dict(separators=(',', ':')))) + + def encode(self, o: 'StructureConfidenceFull'): + # Cast to np.float64 before rounding, since casting to Python float will + # cast to a 64 bit float, potentially undoing np.float32 rounding. 
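+        # pLDDTs are clipped to [0, 99.99] (2 decimals), contact probabilities
+        # to [0, 1] (2 decimals), and PAE to [0, 99.9] (1 decimal).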
+ atom_plddts = np.round( + np.clip(np.asarray(o.atom_plddts, dtype=np.float64), 0.0, 99.99), 2 + ).astype(float) + contact_probs = np.round( + np.clip(np.asarray(o.contact_probs, dtype=np.float64), 0.0, 1.0), 2 + ).astype(float) + pae = np.round( + np.clip(np.asarray(o.pae, dtype=np.float64), 0.0, 99.9), 1 + ).astype(float) + return """\ +{ + "atom_chain_ids": %s, + "atom_plddts": %s, + "contact_probs": %s, + "pae": %s, + "token_chain_ids": %s, + "token_res_ids": %s +}""" % ( + super().encode(o.atom_chain_ids), + super().encode(list(atom_plddts)).replace('NaN', 'null'), + super().encode([list(x) + for x in contact_probs]).replace('NaN', 'null'), + super().encode([list(x) for x in pae]).replace('NaN', 'null'), + super().encode(o.token_chain_ids), + super().encode(o.token_res_ids), + ) + + +def _dump_json(data: Any, indent: int | None = None) -> str: + """Dumps a json string with JSON compatible NaN representation.""" + json_str = json.dumps( + data, + sort_keys=True, + indent=indent, + separators=(',', ': '), + ) + return json_str.replace('NaN', 'null') + + +@enum.unique +class ConfidenceCategory(enum.Enum): + """Confidence categories for AlphaFold predictions.""" + + HIGH = 0 + MEDIUM = 1 + LOW = 2 + DISORDERED = 3 + + @classmethod + def from_char(cls, char: str) -> Self: + match char: + case 'H': + return cls.HIGH + case 'M': + return cls.MEDIUM + case 'L': + return cls.LOW + case 'D': + return cls.DISORDERED + case _: + raise ValueError( + f'Unknown character. Expected one of H, M, L or D; got: {char}' + ) + + def to_char(self) -> str: + match self: + case self.HIGH: + return 'H' + case self.MEDIUM: + return 'M' + case self.LOW: + return 'L' + case self.DISORDERED: + return 'D' + + @classmethod + def from_confidence_score(cls, confidence: float) -> Self: + if 90 <= confidence <= 100: + return cls.HIGH + if 70 <= confidence < 90: + return cls.MEDIUM + if 50 <= confidence < 70: + return cls.LOW + if 0 <= confidence < 50: + return cls.DISORDERED + raise ValueError( + f'Confidence score out of range [0, 100]: {confidence}') + + +@dataclasses.dataclass() +class AtomConfidence: + """Dataclass for 1D per-atom confidences from AlphaFold.""" + + chain_id: list[str] + atom_number: list[int] + confidence: list[float] + confidence_category: list[ConfidenceCategory] + + def __post_init__(self): + num_res = len(self.atom_number) + if not all( + len(v) == num_res + for v in [self.chain_id, self.confidence, self.confidence_category] + ): + raise ValueError( + 'All confidence fields must have the same length.') + + @classmethod + def from_inference_result( + cls, inference_result: base_model.InferenceResult + ) -> Self: + """Instantiates an AtomConfidence from a structure. + + Args: + inference_result: Inference result from AlphaFold. + + Returns: + Scores in AtomConfidence dataclass. 
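+
+        Note:
+          Per-atom confidences are read from the B-factor column of the
+          predicted structure and binned into H/M/L/D categories via
+          ConfidenceCategory.from_confidence_score.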
+ """ + struct = inference_result.predicted_structure + as_dict = { + 'chain_id': [], + 'atom_number': [], + 'confidence': [], + 'confidence_category': [], + } + for atom_number, atom in enumerate(struct.iter_atoms()): + this_confidence = float(struct.atom_b_factor[atom_number]) + as_dict['chain_id'].append(atom['chain_id']) + as_dict['atom_number'].append(atom_number) + as_dict['confidence'].append(round(this_confidence, 2)) + as_dict['confidence_category'].append( + ConfidenceCategory.from_confidence_score(this_confidence) + ) + return cls(**as_dict) + + @classmethod + def from_json(cls, json_string: str) -> Self: + """Instantiates a AtomConfidence from a json string.""" + input_dict = json.loads(json_string) + input_dict['confidence_category'] = [ + ConfidenceCategory.from_char(k) + for k in input_dict['confidence_category'] + ] + return cls(**input_dict) + + def to_json(self) -> str: + output = dataclasses.asdict(self) + output['confidence_category'] = [ + k.to_char() for k in output['confidence_category'] + ] + output['atom_number'] = [int(k) for k in output['atom_number']] + return _dump_json(output) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class StructureConfidenceSummary: + """Dataclass for the summary of structure scores from AlphaFold. + + Attributes: + ptm: Predicted TM global score. + iptm: Interface predicted TM global score. + ranking_score: Ranking score extracted from CIF metadata. + fraction_disordered: Fraction disordered, measured with RASA. + has_clash: Has significant clashing. + chain_pair_pae_min: [num_chains, num_chains] Minimum cross chain PAE. + chain_pair_iptm: [num_chains, num_chains] Chain pair ipTM. + chain_ptm: [num_chains] Chain pTM. + chain_iptm: [num_chains] Mean cross chain ipTM for a chain. + """ + + ptm: float + iptm: float + ranking_score: float + fraction_disordered: float + has_clash: float + chain_pair_pae_min: np.ndarray + chain_pair_iptm: np.ndarray + chain_ptm: np.ndarray + chain_iptm: np.ndarray + + @classmethod + def from_inference_result( + cls, inference_result: base_model.InferenceResult + ) -> Self: + """Returns a new instance based on a given inference result.""" + return cls( + ptm=float(inference_result.metadata['ptm']), + iptm=float(inference_result.metadata['iptm']), + ranking_score=float(inference_result.metadata['ranking_score']), + fraction_disordered=float( + inference_result.metadata['fraction_disordered'] + ), + has_clash=float(inference_result.metadata['has_clash']), + chain_pair_pae_min=inference_result.metadata['chain_pair_pae_min'], + chain_pair_iptm=inference_result.metadata['chain_pair_iptm'], + chain_ptm=inference_result.metadata['iptm_ichain'], + chain_iptm=inference_result.metadata['iptm_xchain'], + ) + + @classmethod + def from_json(cls, json_string: str) -> Self: + """Returns a new instance from a given json string.""" + return cls(**json.loads(json_string)) + + def to_json(self) -> str: + def convert(data): + if isinstance(data, np.ndarray): + # Cast to np.float64 before rounding, since casting to Python float will + # cast to a 64 bit float, potentially undoing np.float32 rounding. 
+ rounded_data = np.round(data.astype( + np.float64), decimals=2).tolist() + else: + rounded_data = np.round(data, decimals=2) + return rounded_data + + return _dump_json(tree_map(convert, dataclasses.asdict(self)), indent=1) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class StructureConfidenceFull: + """Dataclass for full structure data from AlphaFold.""" + + pae: np.ndarray + token_chain_ids: list[str] + token_res_ids: list[int] + atom_plddts: list[float] + atom_chain_ids: list[str] + contact_probs: np.ndarray # [num_tokens, num_tokens] + + @classmethod + def from_inference_result( + cls, inference_result: base_model.InferenceResult + ) -> Self: + """Returns a new instance based on a given inference result.""" + + pae = inference_result.numerical_data['full_pae'] + if isinstance(pae, ms.Tensor): + pae = pae.asnumpy() + if not isinstance(pae, np.ndarray): + logging.info('%s', type(pae)) + raise TypeError('pae should be a numpy array.') + + contact_probs = inference_result.numerical_data['contact_probs'] + if isinstance(contact_probs, ms.Tensor): + contact_probs = contact_probs.asnumpy() + if not isinstance(contact_probs, np.ndarray): + logging.info('%s', type(contact_probs)) + raise TypeError('contact_probs should be a numpy array.') + + struct = inference_result.predicted_structure + chain_ids = struct.chain_id.tolist() + atom_plddts = struct.atom_b_factor.tolist() + token_chain_ids = [ + str(token_id) + for token_id in inference_result.metadata['token_chain_ids'] + ] + token_res_ids = [ + int(token_id) for token_id in inference_result.metadata['token_res_ids'] + ] + return cls( + pae=pae, + token_chain_ids=token_chain_ids, + token_res_ids=token_res_ids, + atom_plddts=atom_plddts, + atom_chain_ids=chain_ids, + contact_probs=contact_probs, + ) + + @classmethod + def from_json(cls, json_string: str) -> Self: + """Returns a new instance from a given json string.""" + return cls(**json.loads(json_string)) + + def to_json(self) -> str: + """Converts StructureConfidenceFull to json string.""" + return json.dumps(self, cls=StructureConfidenceFullEncoder) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/confidences.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/confidences.py new file mode 100644 index 0000000000000000000000000000000000000000..9eeb22a25a06e21a2ebeacdca8bc9aa584cfd685 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/confidences.py @@ -0,0 +1,665 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Functions for extracting and processing confidences from model outputs.""" +import warnings +import numpy as np +from absl import logging +from alphafold3 import structure +from alphafold3.constants import residue_names +from alphafold3.cpp import mkdssp +from scipy import spatial + + +# From Sander & Rost 1994 https://doi.org/10.1002/prot.340200303 + +MAX_ACCESSIBLE_SURFACE_AREA = { + 'ALA': 106.0, + 'ARG': 248.0, + 'ASN': 157.0, + 'ASP': 163.0, + 'CYS': 135.0, + 'GLN': 198.0, + 'GLU': 194.0, + 'GLY': 84.0, + 'HIS': 184.0, + 'ILE': 169.0, + 'LEU': 164.0, + 'LYS': 205.0, + 'MET': 188.0, + 'PHE': 197.0, + 'PRO': 136.0, + 'SER': 130.0, + 'THR': 142.0, + 'TRP': 227.0, + 'TYR': 222.0, + 'VAL': 142.0, +} + +# Weights for ranking confidence. +_IPTM_WEIGHT = 0.8 +_FRACTION_DISORDERED_WEIGHT = 0.5 +_CLASH_PENALIZATION_WEIGHT = 100.0 + + +def windowed_solvent_accessible_area(cif: str, window: int = 25) -> np.ndarray: + """Implementation of AlphaFold_RSA. + + AlphaFold_RSA defined in + https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9601767/. + + Args: + cif: raw cif string. + window: The window over which to average accessible surface area + + Returns: + An array of size num_res that predicts disorder by using windowed solvent + accessible surface area. + """ + result = mkdssp.get_dssp(cif, calculate_surface_accessibility=True) + parse_row = False + rasa = [] + for row in result.splitlines(): + if parse_row: + aa = row[13:14] + if aa == '!': + continue + aa3 = residue_names.PROTEIN_COMMON_ONE_TO_THREE.get(aa, 'ALA') + max_acc = MAX_ACCESSIBLE_SURFACE_AREA[aa3] + acc = int(row[34:38]) + norm_acc = acc / max_acc + if norm_acc > 1.0: + norm_acc = 1.0 + rasa.append(norm_acc) + if row.startswith(' # RESIDUE'): + parse_row = True + + half_w = (window - 1) // 2 + pad_rasa = np.pad(rasa, (half_w, half_w), 'reflect') + rasa = np.convolve(pad_rasa, np.ones(window), 'valid') / window + return rasa + + +def fraction_disordered( + struct: structure.Structure, rasa_disorder_cutoff: float = 0.581 +) -> float: + """Compute fraction of protein residues that are disordered. + + Args: + struct: A structure to compute rASA metrics on. + rasa_disorder_cutoff: The threshold at which residues are considered + disordered. Default value taken from + https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9601767/. + + Returns: + The fraction of protein residues that are disordered + (rasa > rasa_disorder_cutoff). + """ + struct = struct.filter_to_entity_type(protein=True) + rasa = [] + seq_rasa = {} + for chain_id, chain_seq in struct.chain_single_letter_sequence().items(): + if chain_seq in seq_rasa: + # We assume that identical sequences have approximately similar rasa + # values to speed up the computation. + rasa.extend(seq_rasa[chain_seq]) + continue + chain_struct = struct.filter(chain_id=chain_id) + try: + rasa_per_residue = windowed_solvent_accessible_area( + chain_struct.to_mmcif() + ) + seq_rasa[chain_seq] = rasa_per_residue + rasa.extend(rasa_per_residue) + except (ValueError, RuntimeError): + logging.warning('%s: rasa calculation failed', struct.name) + + if not rasa: + return 0.0 + return np.mean(np.array(rasa) > rasa_disorder_cutoff) + + +def has_clash( + struct: structure.Structure, + cutoff_radius: float = 1.1, + min_clashes_for_overlap: int = 100, + min_fraction_for_overlap: float = 0.5, +) -> bool: + """Determine whether the structure has at least one clashing chain. 
+ + A clashing chain is defined as having greater than 100 polymer atoms within + 1.1A of another polymer atom, or having more than 50% of the chain with + clashing atoms. + + Args: + struct: A structure to get clash metrics for. + cutoff_radius: atom distances under this threshold are considered a clash. + min_clashes_for_overlap: The minimum number of atom-atom clashes for a chain + to be considered overlapping. + min_fraction_for_overlap: The minimum fraction of atoms within a chain that + are clashing for the chain to be considered overlapping. + + Returns: + True if the structure has at least one clashing chain. + """ + struct = struct.filter_to_entity_type(protein=True, rna=True, dna=True) + if not struct.chains: + return False + coords = struct.coords + coord_kdtree = spatial.cKDTree(coords) + clashes_per_atom = coord_kdtree.query_ball_point( + coords, p=2.0, r=cutoff_radius + ) + per_atom_has_clash = np.zeros(len(coords), dtype=np.int32) + for atom_idx, clashing_indices in enumerate(clashes_per_atom): + for clashing_idx in clashing_indices: + if np.abs(struct.res_id[atom_idx] - struct.res_id[clashing_idx]) > 1 or ( + struct.chain_id[atom_idx] != struct.chain_id[clashing_idx] + ): + per_atom_has_clash[atom_idx] = True + break + for chain_id in struct.chains: + mask = struct.chain_id == chain_id + num_atoms = np.sum(mask) + if num_atoms == 0: + continue + num_clashes = np.sum(per_atom_has_clash * mask) + frac_clashes = num_clashes / num_atoms + if ( + num_clashes > min_clashes_for_overlap + or frac_clashes > min_fraction_for_overlap + ): + return True + return False + + +def get_ranking_score( + ptm: float, iptm: float, fraction_disordered_: float, has_clash_: bool +) -> float: + # ipTM is NaN for single chain structures. Use pTM for such cases. + if np.isnan(iptm): + ptm_iptm_average = ptm + else: + ptm_iptm_average = _IPTM_WEIGHT * iptm + (1.0 - _IPTM_WEIGHT) * ptm + return ( + ptm_iptm_average + + _FRACTION_DISORDERED_WEIGHT * fraction_disordered_ + - _CLASH_PENALIZATION_WEIGHT * has_clash_ + ) + + +def rank_metric( + full_pde: np.ndarray, contact_probs: np.ndarray +) -> np.ndarray: + """Compute the metric that will be used to rank predictions, higher is better. + + Args: + full_pde: A [num_samples, num_tokens,num_tokens] matrix of predicted + distance errors between pairs of tokens. + contact_probs: A [num_tokens, num_tokens] matrix consisting of the + probability of contact (<8A) that is returned from the distogram head. + + Returns: + A scalar that can be used to rank (higher is better). + """ + if not isinstance(full_pde, type(contact_probs)): + raise ValueError( + 'full_pde and contact_probs must be of the same type.') + + if isinstance(full_pde, np.ndarray): + sum_fn = np.sum + else: + raise ValueError('full_pde must be a numpy array or a jax array.') + # It was found that taking the contact_map weighted average was better than + # just the predicted distance error on its own. + return -sum_fn(full_pde * contact_probs[None, :, :], axis=(-2, -1)) / ( + sum_fn(contact_probs) + 1e-6 + ) + + +def weighted_mean(mask, value, axis): + return np.mean(mask * value, axis=axis) / (1e-8 + np.mean(mask, axis=axis)) + + +def pde_single( + num_tokens: int, + asym_ids: np.ndarray, + full_pde: np.ndarray, + contact_probs: np.ndarray, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute 1D PDE summaries. + + Args: + num_tokens: The number of tokens (not including padding). + asym_ids: The asym_ids (array of shape num_tokens). 
+ full_pde: A [num_samples, num_tokens, num_tokens] matrix of predicted + distance errors. + contact_probs: A [num_tokens, num_tokens] matrix consisting of the + probability of contact (<8A) that is returned from the distogram head. + + Returns: + A tuple (ichain, xchain, full_chain) where: + `ichain` is a [num_samples, num_chains] matrix where the + value assigned to each chain is an average of the full PDE matrix over all + its within-chain interactions, weighted by `contact_probs`. + `xchain` is a [num_samples, num_chains] matrix where the + value assigned to each chain is an average of the full PDE matrix over all + its cross-chain interactions, weighted by `contact_probs`. + `full_chain` is a [num_samples, num_tokens] matrix where the + value assigned to each token is an average of it PDE against all tokens, + weighted by `contact_probs`. + """ + + full_pde = full_pde[:, :num_tokens, :num_tokens] + contact_probs = contact_probs[:num_tokens, :num_tokens] + asym_ids = asym_ids[:num_tokens] + unique_asym_ids = np.unique(asym_ids) + num_chains = len(unique_asym_ids) + num_samples = full_pde.shape[0] + + asym_ids = asym_ids[None] + contact_probs = contact_probs[None] + + ichain = np.zeros((num_samples, num_chains)) + xchain = np.zeros((num_samples, num_chains)) + + for idx, asym_id in enumerate(unique_asym_ids): + my_asym_id = asym_ids == asym_id + imask = my_asym_id[:, :, None] * my_asym_id[:, None, :] + xmask = my_asym_id[:, :, None] * ~my_asym_id[:, None, :] + imask = imask * contact_probs + xmask = xmask * contact_probs + ichain[:, idx] = weighted_mean( + mask=imask, value=full_pde, axis=(-2, -1)) + xchain[:, idx] = weighted_mean( + mask=xmask, value=full_pde, axis=(-2, -1)) + + full_chain = weighted_mean(mask=contact_probs, value=full_pde, axis=(-1,)) + + return ichain, xchain, full_chain + + +def chain_pair_pde( + num_tokens: int, asym_ids: np.ndarray, full_pde: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + """Compute predicted distance errors for all pairs of chains. + + Args: + num_tokens: The number of tokens (not including padding). + asym_ids: The asym_ids (array of shape num_tokens). + full_pde: A [num_samples, num_tokens, num_tokens] matrix of predicted + distance errors. + + Returns: + chain_pair_pred_err_mean - a [num_chains, num_chains] matrix with average + per chain-pair predicted distance error. + chain_pair_pred_err_min - a [num_chains, num_chains] matrix with min + per chain-pair predicted distance error. 
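+      Both arrays carry a leading num_samples dimension, i.e. their full shape
+      is [num_samples, num_chains, num_chains].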
+ """ + full_pde = full_pde[:, :num_tokens, :num_tokens] + asym_ids = asym_ids[:num_tokens] + unique_asym_ids = np.unique(asym_ids) + num_chains = len(unique_asym_ids) + num_samples = full_pde.shape[0] + chain_pair_pred_err_mean = np.zeros((num_samples, num_chains, num_chains)) + chain_pair_pred_err_min = np.zeros((num_samples, num_chains, num_chains)) + + for idx1, asym_id_1 in enumerate(unique_asym_ids): + subset = full_pde[:, asym_ids == asym_id_1, :] + for idx2, asym_id_2 in enumerate(unique_asym_ids): + subsubset = subset[:, :, asym_ids == asym_id_2] + chain_pair_pred_err_mean[:, idx1, idx2] = np.mean( + subsubset, axis=(1, 2)) + chain_pair_pred_err_min[:, idx1, idx2] = np.min( + subsubset, axis=(1, 2)) + return chain_pair_pred_err_mean, chain_pair_pred_err_min + + +def weighted_nanmean( + value: np.ndarray, mask: np.ndarray, axis: int +) -> np.ndarray: + """Nan-mean with weighting -- empty slices return NaN.""" + assert mask.shape == value.shape + assert not np.isnan(mask).all() + + nan_idxs = np.where(np.isnan(value)) + # Need to NaN the mask to get the correct denominator weighting. + mask_with_nan = mask.copy() + mask_with_nan[nan_idxs] = np.nan + with warnings.catch_warnings(): + # Mean of empty slice is ok and should return a NaN. + warnings.filterwarnings(action='ignore', message='Mean of empty slice') + return np.nanmean(value * mask_with_nan, axis=axis) / np.nanmean( + mask_with_nan, axis=axis + ) + + +def chain_pair_pae( + *, + num_tokens: int, + asym_ids: np.ndarray, + full_pae: np.ndarray, + mask: np.ndarray | None=None, + contact_probs: np.ndarray | None=None, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute predicted errors for all pairs of chains. + + Args: + num_tokens: The number of tokens (not including padding). + asym_ids: The asym_ids (array of shape num_tokens). + full_pae: A [num_samples, num_tokens, num_tokens] matrix of predicted + errors. + mask: A [num_tokens, num_tokens] mask matrix. + contact_probs: A [num_tokens, num_tokens] matrix consisting of the + probability of contact (<8A) that is returned from the distogram head. + + Returns: + chain_pair_pred_err_mean - a [num_chains, num_chains] matrix with average + per chain-pair predicted error. 
+ """ + if mask is None: + mask = np.ones(shape=full_pae.shape[1:], dtype=bool) + if contact_probs is None: + contact_probs = np.ones(shape=full_pae.shape[1:], dtype=float) + assert mask.shape == full_pae.shape[1:] + + full_pae = full_pae[:, :num_tokens, :num_tokens] + mask = mask[:num_tokens, :num_tokens] + asym_ids = asym_ids[:num_tokens] + contact_probs = contact_probs[:num_tokens, :num_tokens] + unique_asym_ids = np.unique(asym_ids) + num_chains = len(unique_asym_ids) + num_samples = full_pae.shape[0] + chain_pair_pred_err_mean = np.zeros((num_samples, num_chains, num_chains)) + chain_pair_pred_err_min = np.zeros((num_samples, num_chains, num_chains)) + + for idx1, asym_id_1 in enumerate(unique_asym_ids): + subset = full_pae[:, asym_ids == asym_id_1, :] + subset_mask = mask[asym_ids == asym_id_1, :] + subset_contact_probs = contact_probs[asym_ids == asym_id_1, :] + for idx2, asym_id_2 in enumerate(unique_asym_ids): + subsubset = subset[:, :, asym_ids == asym_id_2] + subsubset_mask = subset_mask[:, asym_ids == asym_id_2] + subsubset_contact_probs = subset_contact_probs[:, + asym_ids == asym_id_2] + (flat_mask_idxs,) = np.where(subsubset_mask.flatten() > 0) + flat_subsubset = subsubset.reshape([num_samples, -1]) + flat_contact_probs = subsubset_contact_probs.flatten() + # A ligand chain will have no valid frames if it contains fewer than + # three non-colinear atoms (e.g. a sodium ion). + if not flat_mask_idxs.size: + chain_pair_pred_err_mean[:, idx1, idx2] = np.nan + chain_pair_pred_err_min[:, idx1, idx2] = np.nan + else: + chain_pair_pred_err_min[:, idx1, idx2] = np.min( + flat_subsubset[:, flat_mask_idxs], axis=1 + ) + chain_pair_pred_err_mean[:, idx1, idx2] = weighted_mean( + mask=flat_contact_probs[flat_mask_idxs], + value=flat_subsubset[:, flat_mask_idxs], + axis=-1, + ) + return chain_pair_pred_err_mean, chain_pair_pred_err_min, unique_asym_ids + + +def reduce_chain_pair( + *, + chain_pair_met: np.ndarray, + num_chain_tokens: np.ndarray, + agg_over_col: bool, + agg_type: str, + weight_method: str, +) -> tuple[np.ndarray, np.ndarray]: + """Compute 1D summaries from a chain-pair summary. + + Args: + chain_pair_met: A [num_samples, num_chains, num_chains] aggregate matrix. + num_chain_tokens: A [num_chains] array of number of tokens for each chain. + Used for 'per_token' weighting. + agg_over_col: Whether to aggregate the PAE over rows (i.e. average error + when aligned to me) or columns (i.e. my average error when aligned to all + others.) + agg_type: The type of aggregation to use, 'mean' or 'min'. + weight_method: The method to use for weighting the PAE, 'per_token' or + 'per_chain'. + + Returns: + A tuple (ichain, xchain) where: + `ichain` is a [num_samples, num_chains] matrix where the + value assigned to each chain is an average of the full PAE matrix over all + its within-chain interactions, weighted by `contact_probs`. + `xchain` is a [num_samples, num_chains] matrix where the + value assigned to each chain is an average of the full PAE matrix over all + its cross-chain interactions, weighted by `contact_probs`. 
+ """ + num_samples, num_chains, _ = chain_pair_met.shape + + ichain = chain_pair_met.diagonal(axis1=-2, axis2=-1) + + if weight_method == 'per_chain': + chain_weight = np.ones((num_chains,), dtype=float) + elif weight_method == 'per_token': + chain_weight = num_chain_tokens + else: + raise ValueError(f'Unknown weight method: {weight_method}') + + if agg_over_col: + agg_axis = -1 + else: + agg_axis = -2 + + if agg_type == 'mean': + weight = np.ones((num_samples, num_chains, num_chains), dtype=float) + weight -= np.eye(num_chains, dtype=float) + weight *= chain_weight[None] * chain_weight[:, None] + xchain = weighted_nanmean(chain_pair_met, mask=weight, axis=agg_axis) + elif agg_type == 'min': + is_self = np.eye(num_chains) + with warnings.catch_warnings(): + # Min over empty slice is ok and should return a NaN. + warnings.filterwarnings( + 'ignore', message='All-NaN slice encountered') + xchain = np.nanmin(chain_pair_met + 1e8 * is_self, axis=agg_axis) + else: + raise ValueError(f'Unknown aggregation method: {agg_type}') + + return ichain, xchain + + +def pae_metrics( + num_tokens: int, + asym_ids: np.ndarray, + full_pae: np.ndarray, + mask: np.ndarray, + contact_probs: np.ndarray, + tm_adjusted_pae: np.ndarray, +): + """PAE aggregate metrics.""" + assert mask.shape == full_pae.shape[1:] + assert contact_probs.shape == full_pae.shape[1:] + + chain_pair_contact_weighted, _, unique_asym_ids = chain_pair_pae( + num_tokens=num_tokens, + asym_ids=asym_ids, + full_pae=full_pae, + mask=mask, + contact_probs=contact_probs, + ) + + ret = {} + ret['chain_pair_pae_mean'], ret['chain_pair_pae_min'], _ = chain_pair_pae( + num_tokens=num_tokens, + asym_ids=asym_ids, + full_pae=full_pae, + mask=mask, + ) + chain_pair_iptm = np.stack( + [ + chain_pairwise_predicted_tm_scores( + tm_adjusted_pae=sample_tm_adjusted_pae[:num_tokens], + asym_id=asym_ids[:num_tokens], + pair_mask=mask[:num_tokens, :num_tokens], + ) + for sample_tm_adjusted_pae in tm_adjusted_pae + ], + axis=0, + ) + + num_chain_tokens = np.array( + [sum(asym_ids == asym_id) for asym_id in unique_asym_ids] + ) + + def reduce_chain_pair_fn(chain_pair: np.ndarray): + def inner(agg_over_col): + ichain_pae, xchain_pae = reduce_chain_pair( + num_chain_tokens=num_chain_tokens, + chain_pair_met=chain_pair, + agg_over_col=agg_over_col, + agg_type='mean', + weight_method='per_chain', + ) + return ichain_pae, xchain_pae + + ichain, xchain_row_agg = inner(False) + _, xchain_col_agg = inner(True) + with warnings.catch_warnings(): + # Mean of empty slice is ok and should return a NaN. 
+ warnings.filterwarnings( + action='ignore', message='Mean of empty slice') + xchain = np.nanmean( + np.stack([xchain_row_agg, xchain_col_agg], axis=0), axis=0 + ) + return ichain, xchain + + pae_ichain, pae_xchain = reduce_chain_pair_fn(chain_pair_contact_weighted) + iptm_ichain, iptm_xchain = reduce_chain_pair_fn(chain_pair_iptm) + + ret.update({ + 'chain_pair_iptm': chain_pair_iptm, + 'iptm_ichain': iptm_ichain, + 'iptm_xchain': iptm_xchain, + 'pae_ichain': pae_ichain, + 'pae_xchain': pae_xchain, + }) + + return ret + + +def get_iptm_xchain(chain_pair_iptm: np.ndarray) -> np.ndarray: + """Cross chain aggregate ipTM.""" + num_samples, num_chains, _ = chain_pair_iptm.shape + weight = np.ones((num_samples, num_chains, num_chains), dtype=float) + weight -= np.eye(num_chains, dtype=float) + xchain_row_agg = weighted_nanmean(chain_pair_iptm, mask=weight, axis=-2) + xchain_col_agg = weighted_nanmean(chain_pair_iptm, mask=weight, axis=-1) + with warnings.catch_warnings(): + # Mean of empty slice is ok and should return a NaN. + warnings.filterwarnings(action='ignore', message='Mean of empty slice') + iptm_xchain = np.nanmean( + np.stack([xchain_row_agg, xchain_col_agg], axis=0), axis=0 + ) + return iptm_xchain + + +def predicted_tm_score( + tm_adjusted_pae: np.ndarray, + pair_mask: np.ndarray, + asym_id: np.ndarray, + interface: bool = False, +) -> float: + """Computes predicted TM alignment or predicted interface TM alignment score. + + Args: + tm_adjusted_pae: [num_res, num_res] Relevant tensor for computing TMScore + values. + pair_mask: A [num_res, num_res] mask. The TM score will only aggregate over + masked-on entries. + asym_id: [num_res] asymmetric unit ID (the chain ID). Only needed for ipTM + calculation, i.e. when interface=True. + interface: If True, the interface predicted TM score is computed. If False, + the predicted TM score without any residue pair restrictions is computed. + + Returns: + score: pTM or ipTM score. + """ + num_tokens, _ = tm_adjusted_pae.shape + if tm_adjusted_pae.shape != (num_tokens, num_tokens): + raise ValueError( + f'Bad tm_adjusted_pae shape, expected ({num_tokens, num_tokens}), got ' + f'{tm_adjusted_pae.shape}.' + ) + + if pair_mask.shape != (num_tokens, num_tokens): + raise ValueError( + f'Bad pair_mask shape, expected ({num_tokens, num_tokens}), got ' + f'{pair_mask.shape}.' + ) + if pair_mask.dtype != bool: + raise TypeError( + f'Bad pair mask type, expected bool, got {pair_mask.dtype}') + if asym_id.shape[0] != num_tokens: + raise ValueError( + f'Bad asym_id shape, expected ({num_tokens},), got {asym_id.shape}.' + ) + + # Create pair mask. + if interface: + pair_mask = pair_mask * (asym_id[:, None] != asym_id[None, :]) + + # Ions and other ligands with colinear atoms have ill-defined frames. + if pair_mask.sum() == 0: + return np.nan + + normed_residue_mask = pair_mask / ( + 1e-8 + np.sum(pair_mask, axis=-1, keepdims=True) + ) + per_alignment = np.sum(tm_adjusted_pae * normed_residue_mask, axis=-1) + return per_alignment.max() + + +def chain_pairwise_predicted_tm_scores( + tm_adjusted_pae: np.ndarray, + pair_mask: np.ndarray, + asym_id: np.ndarray, +) -> np.ndarray: + """Compute predicted TM (pTM) between each pair of chains independently. + + Args: + tm_adjusted_pae: [num_res, num_res] Relevant tensor for computing TMScore + values. + pair_mask: A [num_res, num_res] mask specifying which frames are valid. + Invalid frames can be the result of chains with not enough atoms (e.g. + ions). + asym_id: [num_res] asymmetric unit ID (the chain ID). 
+ + Returns: + A [num_chains, num_chains] matrix, where row i, column j indicates the + predicted TM-score for the interface between chain i and chain j. + """ + unique_chains = list(np.unique(asym_id)) + num_chains = len(unique_chains) + all_pairs_iptms = np.zeros((num_chains, num_chains)) + for i, chain_i in enumerate(unique_chains): + chain_i_mask = asym_id == chain_i + for j, chain_j in enumerate(unique_chains[i:]): + chain_j_mask = asym_id == chain_j + mask = chain_i_mask | chain_j_mask + (indices,) = np.where(mask) + is_interface = chain_i != chain_j + indices = np.ix_(indices, indices) + iptm = predicted_tm_score( + tm_adjusted_pae=tm_adjusted_pae[indices], + pair_mask=pair_mask[indices], + asym_id=asym_id[mask], + interface=is_interface, + ) + all_pairs_iptms[i, i + j] = iptm + all_pairs_iptms[i + j, i] = iptm + return all_pairs_iptms diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/data3.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/data3.py new file mode 100644 index 0000000000000000000000000000000000000000..26c7993873728911b040615b51d6c13702a4d4f0 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/data3.py @@ -0,0 +1,127 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Protein features that are computed from parsed mmCIF objects.""" + +from collections.abc import Mapping, MutableMapping +import datetime +from typing import TypeAlias + +from alphafold3.constants import residue_names +from alphafold3.cpp import msa_profile +from alphafold3.model import protein_data_processing +import numpy as np + + +FeatureDict: TypeAlias = Mapping[str, np.ndarray] +MutableFeatureDict: TypeAlias = MutableMapping[str, np.ndarray] + + +def fix_features(msa_features: MutableFeatureDict) -> MutableFeatureDict: + """Renames the deletion_matrix feature.""" + msa_features['deletion_matrix'] = msa_features.pop('deletion_matrix_int') + return msa_features + + +def get_profile_features( + msa: np.ndarray, deletion_matrix: np.ndarray +) -> FeatureDict: + """Returns the MSA profile and deletion_mean features.""" + num_restypes = residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP + profile = msa_profile.compute_msa_profile( + msa=msa, num_residue_types=num_restypes + ) + + return { + 'profile': profile.astype(np.float32), + 'deletion_mean': np.mean(deletion_matrix, axis=0), + } + + +def fix_template_features( + sequence: str, + template_features: FeatureDict, +) -> FeatureDict: + """Convert template features to AlphaFold 3 format. + + Args: + sequence: amino acid sequence of the protein. + template_features: Template features for the protein. + + Returns: + Updated template_features for the chain. 
+ """ + num_res = len(sequence) + if not template_features['template_aatype'].shape[0]: + template_features = empty_template_features(num_res) + else: + template_release_timestamp = [ + _get_timestamp(x.decode('utf-8')) + for x in template_features['template_release_date'] + ] + + # Convert from atom37 to dense atom + dense_atom_indices = np.take( + protein_data_processing.PROTEIN_AATYPE_DENSE_ATOM_TO_ATOM37, + template_features['template_aatype'], + axis=0, + ) + + atom_mask = np.take_along_axis( + template_features['template_all_atom_masks'], dense_atom_indices, axis=2 + ) + atom_positions = np.take_along_axis( + template_features['template_all_atom_positions'], + dense_atom_indices[..., None], + axis=2, + ) + atom_positions *= atom_mask[..., None] + + template_features = { + 'template_aatype': template_features['template_aatype'], + 'template_atom_mask': atom_mask.astype(np.int32), + 'template_atom_positions': atom_positions.astype(np.float32), + 'template_domain_names': np.array( + template_features['template_domain_names'], dtype=object + ), + 'template_release_timestamp': np.array( + template_release_timestamp, dtype=np.float32 + ), + } + return template_features + + +def empty_template_features(num_res: int) -> FeatureDict: + """Creates a fully masked out template features to allow padding to work. + + Args: + num_res: The length of the target chain. + + Returns: + Empty template features for the chain. + """ + template_features = { + 'template_aatype': np.zeros(num_res, dtype=np.int32)[None, ...], + 'template_atom_mask': np.zeros( + (num_res, protein_data_processing.NUM_DENSE), dtype=np.int32 + )[None, ...], + 'template_atom_positions': np.zeros( + (num_res, protein_data_processing.NUM_DENSE, 3), dtype=np.float32 + )[None, ...], + 'template_domain_names': np.array([b''], dtype=object), + 'template_release_timestamp': np.array([0.0], dtype=np.float32), + } + return template_features + + +def _get_timestamp(date_str: str): + dt = datetime.datetime.fromisoformat(date_str) + dt = dt.replace(tzinfo=datetime.timezone.utc) + return dt.timestamp() diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/data_constants.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/data_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..eabdcfda922340cfbdd1f03f0cf11c6e4692bcb2 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/data_constants.py @@ -0,0 +1,27 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Constants shared across modules in the AlphaFold data pipeline.""" + +from alphafold3.constants import residue_names + +MSA_GAP_IDX = residue_names.PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP.index( + '-' +) + +# Feature groups. 
+NUM_SEQ_NUM_RES_MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix') +NUM_SEQ_MSA_FEATURES = ('msa_species_identifiers',) +TEMPLATE_FEATURES = ( + 'template_aatype', + 'template_atom_positions', + 'template_atom_mask', +) +MSA_PAD_VALUES = {'msa': MSA_GAP_IDX, 'msa_mask': 1, 'deletion_matrix': 0} diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/atom_cross_attention.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/atom_cross_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..f10cbfd0bb43050e78f1fb361e60b984ad849b07 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/atom_cross_attention.py @@ -0,0 +1,466 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +from dataclasses import dataclass +import mindspore as ms +from mindspore import nn, ops, Tensor + +from alphafold3.model import base_config +from alphafold3.model.atom_layout import atom_layout +from alphafold3.model.components import base_modules as bm +from alphafold3.model.components import utils +from alphafold3.model.diffusion import diffusion_transformer + +@dataclass +class AtomCrossAttEncoderConfig(base_config.BaseConfig): + per_token_channels: int = 768 + per_atom_channels: int = 128 + atom_transformer: diffusion_transformer.CrossAttTransformer.Config = ( + base_config.autocreate(num_intermediate_factor=2, num_blocks=3) + ) + per_atom_pair_channels: int = 16 + + +class _PerAtomConditioning(nn.Cell): + """ + A class to compute per-atom and pairwise conditioning information for structural data. + + Args: + config: Configuration object containing model parameters. + + Inputs: + - **batch** (dict) - A dictionary containing structural information: + - **ref_structure.positions** (Tensor) - Tensor of atomic positions. + - **ref_structure.mask** (Tensor) - Tensor of masks indicating valid atoms. + - **ref_structure.element** (Tensor) - Tensor of atomic elements. + - **ref_structure.charge** (Tensor) - Tensor of atomic charges. + - **ref_structure.atom_name_chars** (Tensor) - Tensor of atomic name characters. + + Outputs: + - **act** (Tensor) - Per-atom conditioning information. + - **pair_act** (Tensor) - Pairwise conditioning information. 
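+          Shapes are (num_tokens, num_dense, per_atom_channels) and
+          (num_tokens, num_dense, num_dense, per_atom_pair_channels),
+          respectively.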
+ """ + + def __init__(self, config): + super().__init__() + self.c = config + self.linear1 = nn.Dense(3, self.c.per_atom_channels, has_bias=False) + self.linear2 = nn.Dense(1, self.c.per_atom_channels, has_bias=False) + self.linear3 = nn.Dense(128, self.c.per_atom_channels, has_bias=False) + self.linear4 = nn.Dense(1, self.c.per_atom_channels, has_bias=False) + self.linear5 = nn.Dense(256, self.c.per_atom_channels, has_bias=False) + self.linear_row_act = nn.Dense( + self.c.per_atom_channels, self.c.per_atom_pair_channels, has_bias=False) + self.linear_col_act = nn.Dense( + self.c.per_atom_channels, self.c.per_atom_pair_channels, has_bias=False) + self.linear_pair_act1 = nn.Dense( + 3, self.c.per_atom_pair_channels, has_bias=False) + self.linear_pair_act2 = nn.Dense( + 1, self.c.per_atom_pair_channels, has_bias=False) + + @ms.jit + def construct(self, batch): + # Compute per-atom single conditioning + # Shape (num_tokens, num_dense, channels) + act = self.linear1(batch.ref_structure.positions) + act += self.linear2(batch.ref_structure.mask[:, :, None]) + # Element is encoded as atomic number if the periodic table, so + # 128 should be fine. + act += self.linear3( + ops.one_hot(batch.ref_structure.element, 128, + Tensor(1.0, ms.float32), Tensor(0.0, ms.float32)) + .astype(act.dtype)) + act += self.linear4(ops.arcsinh(batch.ref_structure.charge) + [:, :, None]) + # Characters are encoded as ASCII code minus 32, so we need 64 classes, + # to encode all standard ASCII characters between 32 and 96. + atom_name_chars_1hot = ops.one_hot(batch.ref_structure.atom_name_chars, 64, + Tensor(1.0, ms.float32), Tensor(0.0, ms.float32)).astype(act.dtype) + num_token, num_dense, _ = act.shape + act += self.linear5(atom_name_chars_1hot.reshape(num_token, num_dense, -1)) + act *= batch.ref_structure.mask[:, :, None] + + # Compute pair conditioning + # shape (num_tokens, num_dense, num_dense, channels) + # Embed single features + row_act = self.linear_row_act(ops.relu(act)) + col_act = self.linear_col_act(ops.relu(act)) + pair_act = row_act[:, :, None, :] + col_act[:, None, :, :] + + # Embed pairwise offsets + pair_act += self.linear_pair_act1(batch.ref_structure.positions[:, :, None, :] + - batch.ref_structure.positions[:, None, :, :]) + # Embed pairwise inverse squared distances + sq_dists = ops.sum(ops.square(batch.ref_structure.positions[:, :, None, :] + - batch.ref_structure.positions[:, None, :, :]), dim=-1) + pair_act += self.linear_pair_act2(1.0 / (1 + sq_dists[:, :, :, None])) + return act, pair_act + +@dataclass +class AtomCrossAttEncoderOutput: + def __init__( + self, + token_act, + skip_connection, + queries_mask, + queries_single_cond, + keys_mask, + keys_single_cond, + pair_cond, + ): + self.token_act = token_act + self.skip_connection = skip_connection + self.queries_mask = queries_mask + self.queries_single_cond = queries_single_cond + self.keys_mask = keys_mask + self.keys_single_cond = keys_single_cond + self.pair_cond = pair_cond + + +class AtomCrossAttEncoder(nn.Cell): + """Cross-attention on flat atom subsets and mapping to per-token features. + + Args: + config: Configuration object containing model parameters. + global_config: Global configuration object with initialization settings. + name (str): Name of the module. + cond_channels (int): Number of conditioning channels. Default: ``384``. + with_cond (bool): Whether to include conditioning layers. Default: ``True``. + + Inputs: + - **token_atoms_act** (ms.Tensor): Tensor representing token atom activations. 
+ - **trunk_single_cond** (ms.Tensor): Tensor representing single token conditioning. + - **trunk_pair_cond** (ms.Tensor): Tensor representing pair token conditioning. + - **batch** (feat_batch.Batch) : Batch of input data. + + Outputs: + - **token_act** (ms.Tensor): Activations for tokens after processing. + - **skip_connection** (ms.Tensor): Skip connection tensor for token queries. + - **queries_mask** (ms.Tensor): Mask for token queries. + - **queries_single_cond** (ms.Tensor): Single conditioning for token queries. + - **keys_mask** (ms.Tensor): Mask for token keys. + - **keys_single_cond** (ms.Tensor): Single conditioning for token keys. + - **pair_cond** (ms.Tensor): Pair conditioning tensor. + """ + + def __init__(self, config, global_config, name, cond_channels=384, with_cond=True, dtype=ms.float32): + super().__init__() + self.c = config + self.with_cond = with_cond + self.dtype = dtype + self._per_atom_conditioning = _PerAtomConditioning(config) + if self.with_cond: + self._embed_trunk_single_cond = nn.Dense(cond_channels, self.c.per_atom_channels, + weight_init=global_config.final_init, has_bias=False, dtype=dtype) + self._lnorm_trunk_single_cond = bm.LayerNorm((cond_channels,), + create_beta=False, gamma_init="ones", dtype=dtype) + self._atom_positions_to_features = nn.Dense(3, self.c.per_atom_channels, has_bias=False, dtype=dtype) + self._embed_trunk_pair_cond = nn.Dense(self.c.per_atom_channels, self.c.per_atom_pair_channels, + weight_init=global_config.final_init, has_bias=False, dtype=dtype) + self._lnorm_trunk_pair_cond = bm.LayerNorm((self.c.per_atom_channels,), create_beta=False, + gamma_init="ones", dtype=dtype) + + self._single_to_pair_cond_row = nn.Dense( + self.c.per_atom_channels, self.c.per_atom_pair_channels, has_bias=False, dtype=dtype) + self._single_to_pair_cond_col = nn.Dense( + self.c.per_atom_channels, self.c.per_atom_pair_channels, has_bias=False, dtype=dtype) + + self._embed_pair_offsets = nn.Dense( + 3, self.c.per_atom_pair_channels, has_bias=False, dtype=dtype) + self._embed_pair_distances = nn.Dense( + 1, self.c.per_atom_pair_channels, has_bias=False, dtype=dtype) + self._embed_pair_offsets_valid = nn.Dense( + 1, self.c.per_atom_pair_channels, has_bias=False, dtype=dtype) + + self._pair_mlp_1 = nn.Dense( + self.c.per_atom_pair_channels, self.c.per_atom_pair_channels, has_bias=False, dtype=dtype) + self._pair_mlp_2 = nn.Dense( + self.c.per_atom_pair_channels, self.c.per_atom_pair_channels, has_bias=False, dtype=dtype) + self._pair_mlp_3 = nn.Dense(self.c.per_atom_pair_channels, self.c.per_atom_pair_channels, + weight_init=global_config.final_init, has_bias=False, dtype=dtype) + self.relu = nn.ReLU() + self._project_atom_features_for_aggr = nn.Dense( + self.c.per_atom_channels, self.c.per_token_channels, has_bias=False, dtype=dtype) + + self._atom_transformer_encoder = diffusion_transformer.CrossAttTransformer( + self.c.atom_transformer, global_config, in_shape=[ + self.c.per_atom_channels, self.c.per_atom_pair_channels], dtype=dtype + ) + + def construct( + self, + token_atoms_act, + trunk_single_cond, + trunk_pair_cond, + batch, + ): + # Compute single conditioning from atom meta data and convert to queries + # layout. 
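+        # token_atoms_to_queries re-gathers the (num_tokens, max_atoms_per_token)
+        # layout into the flat atom subsets ("queries") that the atom
+        # transformer attends over.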
+ token_atoms_single_cond, _ = self._per_atom_conditioning( + batch) + token_atoms_mask = batch.predicted_structure_info.atom_mask + queries_single_cond = atom_layout.convert_ms( + batch.atom_cross_att.token_atoms_to_queries, + token_atoms_single_cond, + layout_axes=(-3, -2), + ) + queries_mask = atom_layout.convert_ms( + batch.atom_cross_att.token_atoms_to_queries, + token_atoms_mask, + layout_axes=(-2, -1), + ) + + # If provided, broadcast single conditioning from trunk to all queries + if trunk_single_cond is not None: + trunk_single_cond = self._embed_trunk_single_cond( + self._lnorm_trunk_single_cond( + trunk_single_cond) + ) + queries_single_cond += atom_layout.convert_ms( + batch.atom_cross_att.tokens_to_queries, + trunk_single_cond, + layout_axes=(-2,), + ) + + if token_atoms_act is None: + # if no token_atoms_act is given (e.g. begin of evoformer), we use the + # static conditioning only + queries_act = queries_single_cond + else: + # Convert token_atoms_act to queries layout and map to per_atom_channels + queries_act = atom_layout.convert_ms( + batch.atom_cross_att.token_atoms_to_queries, + token_atoms_act, + layout_axes=(-3, -2), + ) + queries_act = self._atom_positions_to_features( + queries_act) + queries_act *= queries_mask[..., None] + queries_act += queries_single_cond + + # Gather the keys from the queries. + keys_single_cond = atom_layout.convert_ms( + batch.atom_cross_att.queries_to_keys, queries_single_cond, layout_axes=( + -3, -2), + ) + keys_mask = atom_layout.convert_ms( + batch.atom_cross_att.queries_to_keys, queries_mask, layout_axes=( + -2, -1) + ) + + # Embed single features into the pair conditioning. + row_act = self._single_to_pair_cond_row( + self.relu(queries_single_cond)) + pair_cond_keys_input = atom_layout.convert_ms( + batch.atom_cross_att.queries_to_keys, queries_single_cond, layout_axes=( + -3, -2), + ) + col_act = self._single_to_pair_cond_col( + self.relu(pair_cond_keys_input)) + pair_act = row_act[:, :, None, :] + col_act[:, None, :, :] + + if trunk_pair_cond is not None: + # If provided, broadcast the pair conditioning for the trunk (evoformer + # pairs) to the atom pair activations. This should boost ligands, but also + # help for cross attention within proteins, because we always have atoms + # from multiple residues in a subset. + # Map trunk pair conditioning to per_atom_pair_channels + trunk_pair_cond = self._embed_trunk_pair_cond( + self._lnorm_trunk_pair_cond( + trunk_pair_cond) + ) + + # Create the GatherInfo into a flattened trunk_pair_cond from the + # queries and keys gather infos. + num_tokens = trunk_pair_cond.shape[0] + tokens_to_queries = batch.atom_cross_att.tokens_to_queries + tokens_to_keys = batch.atom_cross_att.tokens_to_keys + + # Gather the conditioning and add it to the atom-pair activations. 
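+            # gather_idxs are flat indices into the flattened
+            # (num_tokens * num_tokens) trunk pair tensor: row * num_tokens + col
+            # for each (query, key) pair, e.g. with num_tokens=4 the pair
+            # (row=2, col=3) maps to index 11.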
+ gather_idxs = Tensor(num_tokens * tokens_to_queries.gather_idxs[:, :, None] + + tokens_to_keys.gather_idxs[:, None, :]) + gather_mask = ops.logical_and(tokens_to_queries.gather_mask[:, :, None], + tokens_to_keys.gather_mask[:, None, :]) + input_shape = Tensor((num_tokens, num_tokens)) + trunk_pair_to_atom_pair = atom_layout.GatherInfo(gather_idxs=gather_idxs, + gather_mask=gather_mask, + input_shape=input_shape) + pair_act += atom_layout.convert_ms( + trunk_pair_to_atom_pair, trunk_pair_cond, layout_axes=(-3, -2) + ) + + # Embed pairwise offsets + queries_ref_pos = atom_layout.convert_ms( + batch.atom_cross_att.token_atoms_to_queries, + batch.ref_structure.positions, + layout_axes=(-3, -2), + ) + queries_ref_space_uid = atom_layout.convert_ms( + batch.atom_cross_att.token_atoms_to_queries, + batch.ref_structure.ref_space_uid, + layout_axes=(-2, -1), + ) + keys_ref_pos = atom_layout.convert_ms( + batch.atom_cross_att.queries_to_keys, + queries_ref_pos, + layout_axes=(-3, -2), + ) + keys_ref_space_uid = atom_layout.convert_ms( + batch.atom_cross_att.queries_to_keys, + batch.ref_structure.ref_space_uid, + layout_axes=(-2, -1), + ) + + offsets_valid = ( + queries_ref_space_uid[:, :, None] == keys_ref_space_uid[:, None, :] + ) + offsets = queries_ref_pos[:, :, None, :] - keys_ref_pos[:, None, :, :] + pair_act += (self._embed_pair_offsets(offsets) + * offsets_valid[:, :, :, None]) + + # Embed pairwise inverse squared distances + sq_dists = ops.sum(ops.square(offsets), dim=-1) + pair_act += ( + self._embed_pair_distances(1.0 / (1 + sq_dists[:, :, :, None])) + * offsets_valid[:, :, :, None] + ) + + # Embed offsets valid mask + pair_act += self._embed_pair_offsets_valid( + offsets_valid[:, :, :, None].astype(ms.float32)) + + # Run a small MLP on the pair acitvations + pair_act2 = self._pair_mlp_1(self.relu(pair_act)) + pair_act2 = self._pair_mlp_2(self.relu(pair_act2)) + pair_act += self._pair_mlp_3(self.relu(pair_act2)) + + # Run the atom cross attention transformer. + queries_act = self._atom_transformer_encoder( + queries_act=queries_act, + queries_mask=queries_mask, + queries_to_keys=batch.atom_cross_att.queries_to_keys, + keys_mask=keys_mask, + queries_single_cond=queries_single_cond, + keys_single_cond=keys_single_cond, + pair_cond=pair_act, + ) + queries_act *= queries_mask[..., None] + skip_connection = queries_act + + # convert back to token-atom layout and aggregate to tokens + queries_act = self._project_atom_features_for_aggr(queries_act) + token_atoms_act = atom_layout.convert_ms( + batch.atom_cross_att.queries_to_token_atoms, + queries_act, + layout_axes=(-3, -2), + ) + token_act = utils.mask_mean( + token_atoms_mask[..., None], self.relu(token_atoms_act), axis=-2 + ) + + return AtomCrossAttEncoderOutput( + token_act=token_act, + skip_connection=skip_connection, + queries_mask=queries_mask, + queries_single_cond=queries_single_cond, + keys_mask=keys_mask, + keys_single_cond=keys_single_cond, + pair_cond=pair_act, + ) + +@dataclass +class AtomCrossAttDecoderConfig(base_config.BaseConfig): + per_token_channels: int = 768 + per_atom_channels: int = 128 + per_atom_pair_channels: int = 16 + atom_transformer: diffusion_transformer.CrossAttTransformer.Config = ( + base_config.autocreate(num_intermediate_factor=2, num_blocks=3) + ) + + +class AtomCrossAttDecoder(nn.Cell): + """Mapping to per-atom features and self-attention on subsets. + + Args: + config: Configuration object containing model parameters. + global_config: Global configuration object with additional parameters. 
+ name (str): Name of the decoder. Default: ``None``. + + Inputs: + - **token_act** (Tensor) - Tensor representing token activations. + - **enc** (AtomCrossAttEncoderOutput) - Output from the encoder containing necessary features and masks. + - **batch** (feat_batch.Batch) - Batch containing atom cross attention features. + + Outputs: + - **position_update** (Tensor) - Tensor representing the updated positions after processing. + """ + + def __init__(self, config, global_config, name, dtype=ms.float32): + super().__init__() + self.c = config + self._project_token_features_for_broadcast = nn.Dense( + self.c.per_token_channels, self.c.per_atom_channels, has_bias=False, dtype=dtype) + self._atom_features_layer_norm = bm.LayerNorm( + (self.c.per_atom_channels,), create_beta=False, gamma_init="ones", dtype=dtype) + self._atom_features_to_position_update = nn.Dense( + self.c.per_atom_channels, 3, weight_init=global_config.final_init, has_bias=False, dtype=dtype) + self._atom_transformer_decoder = diffusion_transformer.CrossAttTransformer( + self.c.atom_transformer, global_config, in_shape=[ + self.c.per_atom_channels, self.c.per_atom_pair_channels], dtype=dtype + ) + + def construct( + self, + token_act, + enc, + batch, + ): + # map per-token act down to per_atom channels + token_act = self._project_token_features_for_broadcast(token_act) + # Broadcast to token-atoms layout and convert to queries layout. + num_token, max_atoms_per_token = ( + batch.atom_cross_att.queries_to_token_atoms.shape + ) + token_atom_act = ops.broadcast_to( + token_act[:, None, :], + (num_token, max_atoms_per_token, self.c.per_atom_channels), + ) + queries_act = atom_layout.convert_ms( + batch.atom_cross_att.token_atoms_to_queries, + token_atom_act, + layout_axes=(-3, -2), + ) + queries_act += enc.skip_connection + queries_act *= enc.queries_mask[..., None] + + # Run the atom cross attention transformer. + queries_act = self._atom_transformer_decoder( + queries_act=queries_act, + queries_mask=enc.queries_mask, + queries_to_keys=batch.atom_cross_att.queries_to_keys, + keys_mask=enc.keys_mask, + queries_single_cond=enc.queries_single_cond, + keys_single_cond=enc.keys_single_cond, + pair_cond=enc.pair_cond, + ) + + queries_act *= enc.queries_mask[..., None] + queries_position_update = self._atom_features_to_position_update( + self._atom_features_layer_norm(queries_act) + ) + position_update = atom_layout.convert_ms( + batch.atom_cross_att.queries_to_token_atoms, + queries_position_update, + layout_axes=(-3, -2), + ) + return position_update diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/confidence_head.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/confidence_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3cd568129c85217bee1c49ab80a9961f70d7bd16 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/confidence_head.py @@ -0,0 +1,289 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""Confidence Head.""" +from dataclasses import dataclass +import mindspore as ms +from mindspore import nn, ops +from alphafold3.model import base_config +from alphafold3.model.atom_layout import atom_layout +from alphafold3.model.components import base_modules as bm +from alphafold3.model.diffusion import modules +from alphafold3.model.diffusion import template_modules + + +def _safe_norm(x, keepdims, axis, eps=1e-8): + return ops.sqrt(eps + ops.sum(ops.square(x), dim=axis, keepdims=keepdims)) + + +class ConfidenceHead(nn.Cell): + """Head to predict the distance errors in a prediction. + + Args: + config (ConfidenceHead.Config): Configuration for the ConfidenceHead module. + global_config (base_config.BaseConfig): Global configuration for the model. + pair_shape (tuple): Shape of the pair features. + single_shape (tuple): Shape of the single features. + atom_shape (tuple): Shape of the atom features. + feat_in_channel (int): Number of input channels for feature projections. + out_channel (int): Number of output channels for feature projections. + + Inputs: + - **dense_atom_positions** (Tensor): [N_res, N_atom, 3] array of atom positions. + - **embeddings** (dict): Dictionary containing pair, single, and target features. + - **seq_mask** (Tensor): Sequence mask indicating valid residues. + - **token_atoms_to_pseudo_beta** (Tensor): Pseudo beta information for atom tokens. + - **asym_id** (Tensor): Asym ID token features. + + Outputs: + - **predicted_lddt** (Tensor): Predicted LDDT scores for each residue. + - **predicted_experimentally_resolved** (Tensor): Predicted experimental resolution scores. + - **full_pde** (Tensor): Full predicted distance errors. + - **average_pde** (Tensor): Average predicted distance errors. + - **pae_outputs** (dict): Additional outputs from PAE (Predicted Alignment Error) calculations. 
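+          `pae_outputs` includes `full_pae` together with
+          `tmscore_adjusted_pae_global` and `tmscore_adjusted_pae_interface`.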
+ """ + @dataclass + class PAEConfig(base_config.BaseConfig): + max_error_bin: float = 31.0 + num_bins: int = 64 + + @dataclass + class Config(base_config.BaseConfig): + """Configuration for ConfidenceHead.""" + + pairformer: modules.PairFormerIteration.Config = base_config.autocreate( + single_attention=base_config.autocreate(), + single_transition=base_config.autocreate(), + num_layer=4, + ) + max_error_bin: float = 31.0 + num_plddt_bins: int = 50 + num_bins: int = 64 + no_embedding_prob: float = 0.2 + pae: 'ConfidenceHead.PAEConfig' = base_config.autocreate() + dgram_features: template_modules.DistogramFeaturesConfig = ( + base_config.autocreate() + ) + + def __init__(self, config, global_config, pair_shape, single_shape, atom_shape, + feat_in_channel, out_channel, dtype=ms.float32): + super().__init__() + self.dtype = dtype + self.config = config + self.global_config = global_config + self.left_target_feat_project = nn.Dense( + feat_in_channel, out_channel, has_bias=False, dtype=dtype) + self.right_target_feat_project = nn.Dense( + feat_in_channel, out_channel, has_bias=False, dtype=dtype) + self.distogram_feat_project = nn.Dense( + template_modules.DistogramFeaturesConfig.num_bins, out_channel, has_bias=False, dtype=dtype) + self.pairformer_block = ms.nn.CellList( + [ + modules.PairFormerIteration( + self.config.pairformer, global_config, pair_shape, single_shape, with_single=True, dtype=dtype + ) + for _ in range(self.config.pairformer.num_layer) + ] + ) + self.left_half_distance_logits = nn.Dense( + pair_shape[-1], self.config.num_bins, has_bias=False, dtype=ms.float32) + self.logits_ln = bm.LayerNorm(pair_shape, dtype=ms.float32) + self.pae_logits = nn.Dense( + pair_shape[-1], self.config.pae.num_bins, has_bias=False, dtype=ms.float32) + self.pae_logits_ln = bm.LayerNorm(pair_shape, dtype=ms.float32) + self.plddt_logits = bm.CustomDense( + single_shape[-1], (atom_shape[-2], self.config.num_plddt_bins), ndim=2, dtype=ms.float32) + self.plddt_logits_ln = bm.LayerNorm(single_shape, dtype=ms.float32) + self.experimentally_resolved_logits = bm.CustomDense( + single_shape[-1], (atom_shape[-2], 2), ndim=2, dtype=ms.float32) + self.experimentally_resolved_ln = bm.LayerNorm(single_shape, dtype=ms.float32) + + def _embed_features(self, dense_atom_positions, token_atoms_to_pseude_beta, + pair_mask, target_feat): + out = self.left_target_feat_project(target_feat) + out2 = self.right_target_feat_project(target_feat)[:, None] + out = out + out2 + positions = atom_layout.convert_ms( + token_atoms_to_pseude_beta, + dense_atom_positions, + layout_axes=(-3, -2), + ) + dgram = template_modules.dgram_from_positions( + positions, self.config.dgram_features, dtype=ms.float32 + ) + dgram *= pair_mask[..., None] + out += self.distogram_feat_project(dgram) + return out + + def construct(self, dense_atom_positions, embeddings, seq_mask, + token_atoms_to_pseudo_beta, asym_id): + seq_mask_cast = seq_mask + pair_mask = seq_mask_cast[:, None] * seq_mask_cast[None, :] + pair_act = embeddings['pair'] + single_act = embeddings['single'] + target_feat = embeddings['target_feat'] + pair_act += self._embed_features( + dense_atom_positions, + token_atoms_to_pseudo_beta, + pair_mask, + target_feat, + ) + + for i in range(self.config.pairformer.num_layer): + pair_act, single_act = self.pairformer_block[i]( + pair_act, pair_mask, single_act, seq_mask) + pair_act = pair_act.astype(ms.float32) + + # Produce logits to predict a distogram of pairwise distance errors + # between the input prediction and the ground truth. 
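The code that follows turns those symmetrised error logits into an expected per-pair distance error by taking an expectation over bin centres. As a quick, self-contained illustration (NumPy only, not part of this patch; `max_error_bin` and `num_bins` mirror the `Config` defaults above, and the logits are random stand-ins):

```python
# Standalone sketch of the bins -> expected-error step used by the confidence head.
import numpy as np

max_error_bin, num_bins = 31.0, 64

breaks = np.linspace(0.0, max_error_bin, num_bins - 1)                 # bin edges
step = breaks[1] - breaks[0]
bin_centers = breaks + step / 2                                        # centre of each bin
bin_centers = np.concatenate([bin_centers, bin_centers[-1:] + step])   # catch-all last bin

rng = np.random.default_rng(0)
logits = rng.normal(size=(5, 5, num_bins))                             # toy (num_res, num_res, num_bins)
probs = np.exp(logits - logits.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)                                  # softmax over bins

expected_error = (probs * bin_centers).sum(-1)                         # expected error per residue pair
print(expected_error.shape, float(expected_error.mean()))
```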
+ # Shape (num_res, num_res, num_bins) + left_distance_logits = self.left_half_distance_logits( + self.logits_ln(pair_act)) + right_distance_logits = left_distance_logits + distance_logits = left_distance_logits + ops.swapaxes( # Symmetrize. + right_distance_logits, -2, -3 + ) + # Shape (num_bins,) + distance_breaks = ops.linspace( + 0.0, self.config.max_error_bin, self.config.num_bins - 1 + ) + + step = distance_breaks[1] - distance_breaks[0] + + # Add half-step to get the center + bin_centers = distance_breaks + step / 2 + # Add a catch-all bin at the end. + bin_centers = ops.concat( + [bin_centers, bin_centers[-1:] + step], axis=0 + ) + + distance_probs = ops.softmax(distance_logits, axis=-1) + + pred_distance_error = ( + ops.sum(distance_probs * bin_centers, dim=-1) * pair_mask + ) + average_pred_distance_error = ops.sum( + pred_distance_error, dim=[-2, -1] + ) / ops.sum(pair_mask, dim=[-2, -1]) + + # Predicted aligned error + pae_outputs = {} + # Shape (num_res, num_res, num_bins) + pae_logits = self.pae_logits(self.pae_logits_ln(pair_act)) + # Shape (num_bins,) + pae_breaks = ops.linspace( + 0.0, self.config.pae.max_error_bin, self.config.pae.num_bins - 1 + ) + step = pae_breaks[1] - pae_breaks[0] + # Add half-step to get the center + bin_centers = pae_breaks + step / 2 + # Add a catch-all bin at the end. + bin_centers = ops.concat( + [bin_centers, bin_centers[-1:] + step], axis=0 + ) + pae_probs = ops.softmax(pae_logits, axis=-1) + + seq_mask_bool = seq_mask.astype(bool) + pair_mask_bool = seq_mask_bool[:, None] * seq_mask_bool[None, :] + pae = ops.sum(pae_probs * bin_centers, dim=-1) * pair_mask_bool + pae_outputs.update({ + 'full_pae': pae, + }) + + # The pTM is computed outside of bfloat16 context. + tmscore_adjusted_pae_global, tmscore_adjusted_pae_interface = ( + self._get_tmscore_adjusted_pae( + asym_id=asym_id, + seq_mask=seq_mask, + pair_mask=pair_mask_bool, + bin_centers=bin_centers, + pae_probs=pae_probs, + ) + ) + pae_outputs.update({ + 'tmscore_adjusted_pae_global': tmscore_adjusted_pae_global, + 'tmscore_adjusted_pae_interface': tmscore_adjusted_pae_interface, + }) + + # pLDDT + # Shape (num_res, num_atom, num_bins) + plddt_logits = self.plddt_logits(self.plddt_logits_ln(single_act)) + + bin_width = 1.0 / self.config.num_plddt_bins + bin_centers = ops.arange(0.5 * bin_width, 1.0, bin_width) + predicted_lddt = ops.sum( + ops.softmax(plddt_logits, axis=-1) * bin_centers, dim=-1 + ) + predicted_lddt = predicted_lddt * 100.0 + + # Experimentally resolved + # Shape (num_res, num_atom, 2) + experimentally_resolved_logits = self.experimentally_resolved_logits( + self.experimentally_resolved_ln(single_act) + ) + + predicted_experimentally_resolved = ops.softmax( + experimentally_resolved_logits, axis=-1 + )[..., 1] + + return { + 'predicted_lddt': predicted_lddt, + 'predicted_experimentally_resolved': predicted_experimentally_resolved, + 'full_pde': pred_distance_error, + 'average_pde': average_pred_distance_error, + **pae_outputs, + } + + def _get_tmscore_adjusted_pae( + self, asym_id, seq_mask, pair_mask, bin_centers, pae_probs, + ): + def get_tmscore_adjusted_pae(num_interface_tokens, bin_centers, pae_probs): + # Clip to avoid negative/undefined d0. + clipped_num_res = ops.maximum(num_interface_tokens, 19) + + # Compute d_0(num_res) as defined by TM-score, eqn. (5) in + # http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf + # Yang & Skolnick "Scoring function for automated + # assessment of protein structure template quality" 2004. 
+ d0 = 1.24 * (clipped_num_res - 15) ** (1.0 / 3) - 1.8 + + # Make compatible with [num_tokens, num_tokens, num_bins] + d0 = d0[:, :, None] + bin_centers = bin_centers[None, None, :] + + # TM-Score term for every bin. + tm_per_bin = 1.0 / (1 + ops.square(bin_centers) / ops.square(d0)) + # E_distances tm(distance). + predicted_tm_term = ops.sum(pae_probs * tm_per_bin, dim=-1) + return predicted_tm_term + + # Interface version + x = asym_id[None, :] == asym_id[:, None] + num_chain_tokens = ops.sum(x * pair_mask, dim=-1) + num_interface_tokens = num_chain_tokens[None, + :] + num_chain_tokens[:, None] + # Don't double-count within a single chain + num_interface_tokens -= x * (num_interface_tokens // 2) + num_interface_tokens = num_interface_tokens * pair_mask + + num_global_tokens = ops.full( + size=pair_mask.shape, fill_value=seq_mask.sum() + ).astype(ms.int32) + + global_apae = get_tmscore_adjusted_pae( + num_global_tokens, bin_centers, pae_probs + ) + interface_apae = get_tmscore_adjusted_pae( + num_interface_tokens, bin_centers, pae_probs + ) + return global_apae, interface_apae diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_head.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_head.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb9c5f0e239706abe402b3ef0d499165b308519 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_head.py @@ -0,0 +1,331 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""Diffusion Head.""" + +from dataclasses import dataclass +from collections.abc import Callable +import math +import numpy as np +import mindspore as ms +from mindspore import mint, nn +from alphafold3.constants import residue_names +from alphafold3.model import base_config +from alphafold3.model.components import base_modules as bm +from alphafold3.model.components import utils +from alphafold3.model.diffusion import atom_cross_attention +from alphafold3.model.diffusion import diffusion_transformer +from alphafold3.model.diffusion import featurization + + +# Carefully measured by averaging multimer training set. 
+SIGMA_DATA = 16.0 +WEIGHT = ms.Tensor(np.load(f"./src/alphafold3/model/diffusion/random/weight.npy"), dtype=ms.float32) +BIAS = ms.Tensor(np.load(f"./src/alphafold3/model/diffusion/random/bias.npy"), dtype=ms.float32) + +def fourier_embeddings(x): + return mint.cos(2 * math.pi * (x[..., None] * WEIGHT + BIAS)) + +def random_rotation(key): + # Create a random rotation (Gram-Schmidt orthogonalization of two + # random normal vectors) + np.random.seed(key) + v0, v1 = ms.Tensor(np.random.normal(0, 1, (2, 3)), dtype=ms.float32) + e0 = v0 / mint.maximum(1e-10, mint.norm(v0)) + v1 = v1 - e0 * mint.matmul(v1, e0) + e1 = v1 / mint.maximum(1e-10, mint.norm(v1)) + e2 = mint.cross(e0, e1) + return mint.stack([e0, e1, e2]) + +def random_augmentation(rng_key, positions, mask): + """Apply random rigid augmentation. + Args: + rng_key: random key + positions: atom positions of shape (, 3) + mask: per-atom mask of shape (,) + Returns: + Transformed positions with the same shape as input positions. + """ + center = utils.mask_mean( + mask.unsqueeze(-1), positions, axis=(-2, -3), keepdims=True, eps=1e-6 + ).astype(ms.float32) + rot = random_rotation(rng_key) + np.random.seed(rng_key) + translation = ms.Tensor(np.random.normal(0, 1, (3,)), dtype=ms.float32) + + augmented_positions = ( + mint.einsum( + '...i,ij->...j', + (positions - center).astype(ms.float32), + rot, + ) + + translation + ) + return augmented_positions * mask[..., None] + +def noise_schedule(t, smin=0.0004, smax=160.0, p=7): + return ( + SIGMA_DATA + * (smax ** (1 / p) + t * (smin ** (1 / p) - smax ** (1 / p))) ** p + ) + +@dataclass +class ConditioningConfig(base_config.BaseConfig): + pair_channel: int + seq_channel: int + prob: float + +@dataclass +class SampleConfig(base_config.BaseConfig): + steps: int + gamma_0: float = 0.8 + gamma_min: float = 1.0 + noise_scale: float = 1.003 + step_scale: float = 1.5 + num_samples: int = 1 + +class DiffusionHead(nn.Cell): + """Denoising Diffusion Head. + + Args: + config (Config): Configuration object containing parameters for the diffusion head. + global_config (GlobalConfig): Global configuration object containing shared parameters. + in_shape (tuple): Input shape for the module. + max_relative_chain (int): Maximum number of relative chains for positional encoding. Default: ``2``. + max_relative_idx (int): Maximum relative index for positional encoding. Default: ``32``. + + Inputs: + - **positions_noisy** (Tensor) - Noisy atomic positions tensor. + - **noise_level** (Tensor) - Tensor representing the noise level. + - **batch** (Batch) - Batch of input data containing token features and structure information. + - **embeddings** (dict) - Dictionary of embeddings for single and pair features. + - **use_conditioning** (bool) - Flag to enable or disable conditioning. + + Outputs: + - **position_update** (Tensor) - Refined atomic positions tensor. 
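As an aside on the constants and helpers defined at the top of this file: `noise_schedule` sweeps the noise level from `SIGMA_DATA * smax` down to `SIGMA_DATA * smin` as `t` goes from 0 to 1. A minimal standalone check (NumPy, same constants as the source; not part of the patch):

```python
# Evaluate the noise schedule at a few points to see the sigma range the sampler covers.
import numpy as np

SIGMA_DATA, SMIN, SMAX, P = 16.0, 0.0004, 160.0, 7

def noise_schedule(t):
    return SIGMA_DATA * (SMAX ** (1 / P) + t * (SMIN ** (1 / P) - SMAX ** (1 / P))) ** P

t = np.linspace(0.0, 1.0, 5)
print(noise_schedule(t))
# t = 0 gives SIGMA_DATA * smax = 2560.0  (very noisy start of sampling)
# t = 1 gives SIGMA_DATA * smin = 0.0064  (essentially clean positions)
```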
+ """ + + class Config( + atom_cross_attention.AtomCrossAttEncoderConfig, + atom_cross_attention.AtomCrossAttDecoderConfig, + ): + """Configuration for DiffusionHead.""" + eval_batch_size: int = 5 + eval_batch_dim_shard_size: int = 5 + conditioning: ConditioningConfig = base_config.autocreate( + prob=0.8, pair_channel=128, seq_channel=384 + ) + eval: SampleConfig = base_config.autocreate( + num_samples=5, + steps=200, + ) + transformer: diffusion_transformer.Transformer.Config = ( + base_config.autocreate() + ) + + def __init__(self, config, global_config, in_shape, max_relative_chain=2, max_relative_idx=32, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.dtype = dtype + in_channel = in_shape[-1] + self.max_relative_chain = max_relative_chain + self.max_relative_idx = max_relative_idx + + # _conditioning modules + in_channel_pair = in_channel + 4 * self.max_relative_idx + 4 + 2 * self.max_relative_chain + 2 + 1 + self.pair_cond_initial_norm = bm.LayerNorm( + in_shape[:-1] + (in_channel_pair,), + create_beta=False, gamma_init="ones", + name='pair_cond_initial_norm', dtype=dtype) + self.pair_cond_initial_projection = nn.Dense(in_channel_pair, self.config.conditioning.pair_channel, + has_bias=False, dtype=ms.float32) + self.transition_block1 = diffusion_transformer.TransitionBlock( + global_config, in_channel, 2, with_single_cond=False, + name=f'pair_transition_1', dtype=dtype + ) + self.transition_block2 = diffusion_transformer.TransitionBlock( + global_config, in_channel, 2, with_single_cond=False, + name=f'pair_transition_2', dtype=dtype + ) + in_channel_single = self.config.conditioning.seq_channel * 2 \ + + residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP * 2 + 1 + self.single_cond_initial_norm = bm.LayerNorm( + in_shape[:-1] + (in_channel_single,), + create_beta=False, gamma_init="ones", + name='single_cond_initial_norm', dtype=dtype) + self.single_cond_initial_projection = nn.Dense(in_channel_single, self.config.conditioning.seq_channel, + has_bias=False, dtype=dtype) + self.num_noise_embedding = 256 + self.layer_norm_noise = bm.LayerNorm( + in_shape[:-1]+(self.num_noise_embedding,), + create_beta=False, gamma_init="ones", + name='noise_embedding_initial_norm', dtype=dtype) + self.linear_noise = nn.Dense(self.num_noise_embedding, self.config.conditioning.seq_channel, + has_bias=False, dtype=dtype) + self.single_transition1 = diffusion_transformer.TransitionBlock( + global_config, self.config.conditioning.seq_channel, 2, + ndim=2, with_single_cond=False, name=f'single_transition_1', + dtype=dtype + ) + self.single_transition2 = diffusion_transformer.TransitionBlock( + global_config, self.config.conditioning.seq_channel, 2, + ndim=2, with_single_cond=False, name=f'single_transition_2', + dtype=dtype + ) + + # modules + self.layer_norm_act = bm.LayerNorm( + (in_channel,)+(self.config.conditioning.seq_channel,), + create_beta=False, gamma_init="ones", + name='single_cond_embedding_norm', dtype=dtype) + self.linear_act = nn.Dense(self.config.conditioning.seq_channel, + self.config.per_token_channels, has_bias=False, dtype=dtype) + self.layer_norm_out = bm.LayerNorm( + in_shape[:-1]+(self.config.per_token_channels,), + create_beta=False, gamma_init="ones", + name='output_norm', dtype=dtype) + self.atom_cross_att_encoder = atom_cross_attention.AtomCrossAttEncoder( + self.config, self.global_config, "", dtype=dtype + ) + self.transformer = diffusion_transformer.Transformer( + self.config.transformer, self.global_config, 
in_shape[:-1] + (self.config.conditioning.seq_channel * 2,), + in_shape, using_pair_act=True, dtype=dtype + ) + self.atom_cross_att_decoder = atom_cross_attention.AtomCrossAttDecoder( + self.config, self.global_config, '', dtype=dtype + ) + + def _conditioning(self, batch, embeddings, noise_level, use_conditioning): + single_embedding = use_conditioning * embeddings['single'] + pair_embedding = use_conditioning * embeddings['pair'] + rel_features = featurization.create_relative_encoding( + batch.token_features, max_relative_idx=self.max_relative_idx, max_relative_chain=self.max_relative_chain + ).astype(pair_embedding.dtype) + features_2d = mint.concat([pair_embedding, rel_features], dim=-1) + pair_cond = self.pair_cond_initial_projection( + self.pair_cond_initial_norm(features_2d) + ) + pair_cond += self.transition_block1(pair_cond) + pair_cond += self.transition_block2(pair_cond) + + target_feat = embeddings['target_feat'] + features_1d = mint.concat([single_embedding, target_feat], dim=-1) + single_cond = self.single_cond_initial_norm(features_1d) + single_cond = self.single_cond_initial_projection(single_cond) + noise_embedding = fourier_embeddings( + (1 / 4) * mint.log(noise_level / SIGMA_DATA) + ) + single_cond += self.linear_noise(self.layer_norm_noise(noise_embedding)) + single_cond += self.single_transition1(single_cond) + single_cond += self.single_transition2(single_cond) + + return single_cond, pair_cond + + def construct(self, positions_noisy, noise_level, batch, embeddings, use_conditioning): + trunk_single_cond, trunk_pair_cond = self._conditioning( + batch=batch, + embeddings=embeddings, + noise_level=noise_level, + use_conditioning=use_conditioning, + ) + + # Extract features + sequence_mask = batch.token_features.mask + atom_mask = batch.predicted_structure_info.atom_mask + # Position features + act = positions_noisy * atom_mask[..., None] + act = act / mint.sqrt(noise_level**2 + SIGMA_DATA**2) + enc = self.atom_cross_att_encoder(act, embeddings["single"], trunk_pair_cond, batch) + + act = enc.token_act + act += self.linear_act(self.layer_norm_act(trunk_single_cond)) + act = self.transformer(act, trunk_single_cond, sequence_mask, trunk_pair_cond) + act = self.layer_norm_out(act) + position_update = self.atom_cross_att_decoder(act, enc, batch) + skip_scaling = SIGMA_DATA**2 / (noise_level**2 + SIGMA_DATA**2) + out_scaling = ( + noise_level * SIGMA_DATA / mint.sqrt(noise_level**2 + SIGMA_DATA**2) + ) + return ( + skip_scaling * positions_noisy + out_scaling * position_update + ) * atom_mask[..., None] + +def sample(denoising_step, batch, key, config, init_positions=None): + """Sample using denoiser on batch. + + Args: + denoising_step: the denoising function. + batch: the batch + key: random key + config: config for the sampling process (e.g. number of denoising steps, + etc.) 
+ + Returns: + a dict + { + 'atom_positions': ms.Tensor # shape (, 3) + 'mask': ms.Tensor # shape (,) + } + where the are + (num_samples, num_tokens, max_atoms_per_token) + """ + + mask = batch.predicted_structure_info.atom_mask + # get weight and bias from Jax, this two values cannot be randomly generated + + def apply_denoising_step(carry, noise_level): + key, positions, noise_level_prev = carry + + positions = random_augmentation( + rng_key=key, positions=positions, mask=mask, + ) + gamma = config.gamma_0 * (noise_level > config.gamma_min) + t_hat = noise_level_prev * (1 + gamma) + + noise_scale = config.noise_scale * mint.sqrt(t_hat**2 - noise_level_prev**2) + np.random.seed(key) + noise = noise_scale * ms.Tensor(np.random.normal(0, 1, positions.shape), dtype=ms.float32) + positions_noisy = positions + noise + + positions_denoised = denoising_step(positions_noisy, t_hat) + grad = (positions_noisy - positions_denoised) / t_hat + + d_t = noise_level - t_hat + positions_out = positions_noisy + config.step_scale * d_t * grad + + return (key, positions_out, noise_level), positions_out + + num_samples = config.num_samples + + noise_levels = noise_schedule(mint.linspace(0, 1, config.steps + 1)) + + noise_key, key = key, key + 1 + np.random.seed(noise_key) + if init_positions is None: + init_positions = ms.Tensor(np.random.normal(0, 1, (num_samples,) + mask.shape + (3,)), dtype=ms.float32) + init_positions *= noise_levels[0] + init = (ms.Tensor([key + i for i in range(num_samples)]).reshape((-1, 1)), + init_positions, + mint.tile(noise_levels[None, 0], (num_samples,)).reshape((-1, 1))) + count = 0 + for noise_level in noise_levels[1:]: + for i in range(num_samples): + temp, _ = apply_denoising_step((count * 10 + i, init[1][i], init[2][i]), noise_level) + init[0][i], init[1][i], init[2][i] = temp + count += 1 + _, positions_out, _ = init + + final_dense_atom_mask = mint.tile(mask[None], (num_samples, 1, 1)) + + return {'atom_positions': positions_out, 'mask': final_dense_atom_mask} diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_transformer.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..df02b870cd62f35e9b83525cdc24d5753f962e76 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_transformer.py @@ -0,0 +1,488 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Diffusion transformer model.""" + +from dataclasses import dataclass +from alphafold3.model import base_config +from alphafold3.utils.gated_linear_unit import gated_linear_unit +from alphafold3.model.atom_layout import atom_layout +from alphafold3.model.components import base_modules as bm + +from mindspore import mint +import mindspore as ms +from mindspore import nn, ops +from mindchemistry.e3.utils import Ncon + + +class AdaptiveLayernorm(nn.Cell): + """ + If single condition is None, this layer is the same as layernorm. 
+ If single condition is given, the layer is modified from Scalable Diffusion Models with Transformers + https://arxiv.org/abs/2212.09748 + + Args: + num_channels (int): Number of channels in the input tensor. + single_channel (int, optional): Number of channels in the single condition tensor. Required if `with_single_cond` is True. Default: ``None``. + ndim (int, optional): Number of dimensions for the dense layers. Default: ``3``. + with_single_cond (bool, optional): Whether to include the single condition adaptation. Default: ``True``. + + Inputs: + - **x** (Tensor) - Input tensor to be normalized. + - **single_cond** (Tensor, optional) - Optional single condition tensor used to adapt the normalization parameters. Required if `with_single_cond` is True. + + Outputs: + - **output** (Tensor) - The normalized output tensor. + """ + + def __init__(self, num_channels, single_channel=None, ndim=3, with_single_cond=True, dtype=ms.float32): + super().__init__() + self.with_single_cond = with_single_cond + if self.with_single_cond: + self.layernorm = bm.LayerNorm([num_channels], name='layer_norm', + create_gamma=False, create_beta=False, + gamma_init='ones', beta_init='zeros', dtype=ms.float32) + self.single_cond_layer_norm = bm.LayerNorm([single_channel], name='single_cond_layer_norm', + create_beta=False, gamma_init='ones', beta_init='zeros', + dtype=ms.float32) + self.single_cond_scale = bm.CustomDense(single_channel, num_channels, weight_init='zeros', + use_bias=True, bias_init='ones', ndim=ndim, dtype=dtype) + self.single_cond_bias = bm.CustomDense( + single_channel, num_channels, weight_init='zeros', ndim=ndim, dtype=dtype) + else: + self.layernorm = bm.LayerNorm([num_channels], dtype=ms.float32) + + def construct(self, x, single_cond=None): + if not self.with_single_cond: + x = self.layernorm(x) + else: + x = self.layernorm(x) + single_cond = self.single_cond_layer_norm(single_cond) + single_scale = self.single_cond_scale(single_cond) + single_bias = self.single_cond_bias(single_cond) + x = mint.add(mint.mul(mint.sigmoid(single_scale), x), single_bias) + return x + + +class AdaptiveZeroInit(nn.Cell): + """ + An adaptive initialization layer that combines two conditional linear transformations. + + Args: + global_config: Configuration object containing initialization settings. + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + single_channels (int, optional): Number of single conditional channels. Default: ``None``. + ndim (int, optional): Number of dimensions for the dense layer input. Default: ``3``. + with_single_cond (bool, optional): Whether to use single conditional transformation. Default: ``True``. + + Inputs: + - **x** (Tensor) - Input tensor to the layer. + - **single_cond** (Tensor, optional) - Single conditional tensor. Required if `with_single_cond` is True. + + Outputs: + - **output** (Tensor) - Output tensor after applying the adaptive initialization. 
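For reference, when a single condition is present the modulation applied by `AdaptiveLayernorm` above is the DiT-style `sigmoid(scale(cond)) * LayerNorm(x) + bias(cond)`. A minimal NumPy sketch of just that arithmetic (toy shapes and random projection weights standing in for `single_cond_scale` / `single_cond_bias`; purely illustrative, not the MindSpore module):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
num_tokens, channels, cond_channels = 8, 16, 4
x = rng.normal(size=(num_tokens, channels))
cond = rng.normal(size=(num_tokens, cond_channels))

w_scale = rng.normal(size=(cond_channels, channels))   # stand-in for single_cond_scale
w_bias = rng.normal(size=(cond_channels, channels))    # stand-in for single_cond_bias

cond_norm = layer_norm(cond)
out = sigmoid(cond_norm @ w_scale) * layer_norm(x) + cond_norm @ w_bias
print(out.shape)  # (8, 16)
```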
+ """ + + def __init__(self, global_config, in_channels, out_channels, single_channels=None, ndim=3, with_single_cond=True, dtype=ms.float32): + super().__init__() + self.with_single_cond = with_single_cond + self.cond_linear1 = bm.CustomDense( + in_channels, out_channels, weight_init='zeros', ndim=ndim, dtype=dtype) + if self.with_single_cond: + if single_channels is None: + single_channels = in_channels + self.cond_linear2 = bm.CustomDense(single_channels, out_channels, weight_init='zeros', + use_bias=True, bias_init='zeros', ndim=ndim, dtype=dtype) + self.cond_linear2.bias = ms.Parameter(self.cond_linear2.bias * (-2)) + + def construct(self, x, single_cond=None): + if not self.with_single_cond: + output = self.cond_linear1(x) + else: + output = self.cond_linear1(x) + cond = self.cond_linear2(single_cond) + output = mint.mul(mint.sigmoid(cond), output) + return output + + +class TransitionBlock(nn.Cell): + """ + A neural network layer that combines adaptive layer normalization, a gated linear unit (GLU), and adaptive zero initialization to process input data with optional conditional inputs. + + Args: + global_config: Configuration object containing initialization settings. + in_channels (int): Number of input channels. + num_intermediate_factor (int): Factor to determine the number of intermediate channels. + single_channels (int, optional): Number of single conditional channels. Default: ``None``. + ndim (int, optional): Number of dimensions for input tensor. Default: ``3``. + with_single_cond (bool, optional): Whether to use single conditional processing. Default: ``True``. + use_glu_kernel (bool, optional): Whether to use GLU. Default: ``True``. + name (str, optional): Name of the layer. Default: ``''``. + + Inputs: + - **x** (Tensor) - Input tensor to the layer. + - **single_cond** (Tensor, optional) - Single conditional tensor. Required if `with_single_cond` is True. + + Outputs: + - **output** (Tensor) - Output tensor after processing through the TransitionBlock. 
+ """ + + def __init__(self, global_config, in_channels, num_intermediate_factor, single_channels=None, ndim=3, with_single_cond=True, use_glu_kernel=True, name='', dtype=ms.float32): + super().__init__() + self.num_intermediate = num_intermediate_factor * in_channels + if single_channels is None: + single_channels = in_channels + self.adaptive_layernorm = AdaptiveLayernorm( + in_channels, single_channels, ndim=ndim, with_single_cond=with_single_cond, dtype=dtype) + self.use_glu_kernel = use_glu_kernel + if self.use_glu_kernel: + self.weights = bm.custom_initializer( + 'relu', [in_channels, self.num_intermediate * 2], dtype=dtype) + self.weights = ms.Parameter(ms.Tensor(self.weights).reshape( + in_channels, 2, self.num_intermediate)) + else: + self.linear = bm.CustomDense( + in_channels, self.num_intermediate * 2, weight_init='zeros', ndim=3, dtype=dtype) + self.adaptive_zero_init = AdaptiveZeroInit( + global_config, self.num_intermediate, in_channels, single_channels, ndim=ndim, with_single_cond=with_single_cond, dtype=dtype) + + def construct(self, x, single_cond=None): + x = self.adaptive_layernorm(x, single_cond) + if self.use_glu_kernel: + c = gated_linear_unit.gated_linear_unit( + x=x, weight=self.weights.astype(x.dtype), + implementation=None, activation=mint.nn.functional.silu, precision=None + ).astype(x.dtype) + else: + x = self.linear(x) + x0, x1 = ops.split(x, int(x.shape[-1]/2), axis=-1) + c = ops.silu(x0) * x1 + output = self.adaptive_zero_init(c, single_cond) + return output + +@dataclass +class SelfAttentionConfig(base_config.BaseConfig): + num_head: int = 16 + key_dim: int | None = None + value_dim: int | None = None + + +class SelfAttention(nn.Cell): + """ + A self-attention mechanism implementation with adaptive layer normalization and adaptive zero initialization. + + This class implements the self-attention mechanism commonly used in transformer models. It includes adaptive layer normalization for input processing and adaptive zero initialization for the final output. The mechanism computes attention scores using query, key, and value transformations, applies masking, and optionally incorporates pair-wise logits. + + Args: + config: Configuration object containing parameters such as key dimension, value dimension, and number of attention heads. + global_config: Global configuration object for additional settings. + num_channels (int): Number of channels in the input tensor. + in_shape (tuple): Shape of the input tensor. + ndim (int, optional): Number of dimensions for the dense layers. Default: ``3``. + with_single_cond (bool, optional): Whether to include single condition adaptation. Default: ``True``. + + Inputs: + - **x** (Tensor) - Input tensor to the self-attention layer. + - **mask** (Tensor) - Attention mask to apply. + - **single_cond** (Tensor, optional) - Single condition tensor for adaptation. + - **pair_logits** (Tensor, optional) - Additional logits to incorporate into attention scores. + + Outputs: + - **output** (Tensor) - The output tensor after self-attention and adaptive zero initialization. + + Notes: + - The class uses adaptive layer normalization and adaptive zero initialization for processing inputs and outputs. + - The attention mechanism supports optional single condition adaptation and pair-wise logits. 
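The `TransitionBlock` above is a gated (SwiGLU-style) feed-forward: one projection to twice the intermediate width, a split into two halves, a SiLU gate, and a conditioned output projection. A standalone NumPy sketch of just the gating arithmetic (random placeholder weights; this bypasses the fused `gated_linear_unit` kernel path and the adaptive output projection):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
channels, factor = 16, 2
inter = factor * channels

x = rng.normal(size=(8, channels))
w_in = rng.normal(size=(channels, 2 * inter))
w_out = rng.normal(size=(inter, channels))

a, b = np.split(x @ w_in, 2, axis=-1)   # the two halves of the GLU projection
y = (silu(a) * b) @ w_out               # gated activation, then output projection
print(y.shape)  # (8, 16)
```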
+ """ + + def __init__(self, config, global_config, num_channels, in_shape, ndim=3, with_single_cond=True, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.adaptive_layernorm = AdaptiveLayernorm(num_channels, int( + num_channels//2), ndim=ndim, with_single_cond=with_single_cond, dtype=dtype) + key_dim = self.config.key_dim if self.config.key_dim is not None else num_channels + value_dim = self.config.value_dim if self.config.value_dim is not None else num_channels + num_head = self.config.num_head + assert key_dim % num_head == 0, f'{key_dim=} % {num_head=} != 0' + assert value_dim % num_head == 0, f'{value_dim=} % {num_head=} != 0' + key_dim = key_dim // num_head + self.key_dim = key_dim + value_dim = value_dim // num_head + qk_shape = (num_head, key_dim) + v_shape = (num_head, value_dim) + self.q_linear = bm.CustomDense(num_channels, qk_shape, use_bias=True, dtype=dtype) + self.k_linear = bm.CustomDense(num_channels, qk_shape, use_bias=False, dtype=dtype) + self.v_linear = bm.CustomDense(num_channels, v_shape, use_bias=False, dtype=dtype) + self.linear = bm.CustomDense( + num_channels, num_head * value_dim, weight_init='zeros', dtype=dtype) + self.adaptive_zero_init = AdaptiveZeroInit(global_config, num_channels, num_channels, int( + num_channels//2), 2, with_single_cond=with_single_cond, dtype=dtype) + self.ncon1 = Ncon([[-2, -1, 1], [-3, -1, 1]]) + self.ncon2 = Ncon([[-2, -1, 2], [2, -2, -3]]) + + def construct(self, x, mask, single_cond, pair_logits): + bias = (1e9 * (mask - 1.0))[..., None, None, :].astype(x.dtype) + x = self.adaptive_layernorm(x, single_cond) + q = self.q_linear(x) + k = self.k_linear(x) + logits = mint.einsum('...qhc,...khc->...hqk', q * self.key_dim ** (-0.5), k) + bias + if pair_logits is not None: + logits += pair_logits + weights = mint.softmax(logits, dim=-1) + weights = weights.astype(q.dtype) + v = self.v_linear(x) + weighted_avg = mint.einsum('...hqk,...khc->...qhc', weights, v) + weighted_avg = weighted_avg.reshape(weighted_avg.shape[:-2] + (-1,)) + gate_logits = self.linear(x) + weighted_avg *= mint.sigmoid(gate_logits) + output = self.adaptive_zero_init(weighted_avg, single_cond) + return output + + +class Transformer(nn.Cell): + @dataclass + class Config(base_config.BaseConfig): + attention: SelfAttentionConfig = base_config.autocreate() + num_blocks: int = 24 + block_remat: bool = False + super_block_size: int = 4 + num_intermediate_factor: int = 2 + + def __init__(self, config, global_config, in_shape, pair_shape, using_pair_act=False, name="transformer", dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.using_pair_act = using_pair_act + self.act = [] + if using_pair_act: + self.pair_layernorm = bm.LayerNorm(pair_shape, create_beta=False, dtype=ms.float32) + else: + self.pair_layernorm = None + assert self.config.num_blocks % self.config.super_block_size == 0 + self.num_super_blocks = self.config.num_blocks // self.config.super_block_size + self.super_blocks = ms.nn.CellList( + [ + SuperBlock( + config, global_config, self.config.num_blocks, + using_pair_act, in_shape, pair_shape, name, dtype=dtype + ) + for _ in range(self.num_super_blocks) + ] + ) + + @ms.jit + def construct(self, act, single_cond, mask, pair_cond=None): + if pair_cond is None: + pair_act = None + else: + pair_act = self.pair_layernorm(pair_cond) + for i in range(self.num_super_blocks): + act = self.super_blocks[i](act, mask, single_cond, pair_act) + return act + + 
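Before the per-layer `Block` definition below, here is a compact NumPy sketch of the attention pattern `SelfAttention` implements: scaled dot-product logits, an additive mask bias, an additive per-head pair bias, and a sigmoid gate on the attended values. All tensors are random toy data; this is an illustration, not the model code or its MindSpore kernels.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, num_head, head_dim = 6, 2, 4

q = rng.normal(size=(num_tokens, num_head, head_dim))
k = rng.normal(size=(num_tokens, num_head, head_dim))
v = rng.normal(size=(num_tokens, num_head, head_dim))
mask = np.array([1, 1, 1, 1, 0, 0], dtype=float)            # last two tokens are padding
pair_logits = rng.normal(size=(num_head, num_tokens, num_tokens))
gate_logits = rng.normal(size=(num_tokens, num_head * head_dim))

bias = 1e9 * (mask - 1.0)[None, None, :]                     # -1e9 on masked keys
logits = np.einsum('qhc,khc->hqk', q * head_dim ** -0.5, k) + bias + pair_logits
weights = np.exp(logits - logits.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)                    # softmax over keys

attended = np.einsum('hqk,khc->qhc', weights, v).reshape(num_tokens, -1)
gated = attended * (1.0 / (1.0 + np.exp(-gate_logits)))      # sigmoid gating of the output
print(gated.shape)  # (6, 8)
```

In the real module the per-head pair bias is projected once per super-block from the pair activations and each of the four inner blocks reuses its own slice, which is also why the checkpoint loader later indexes the stacked weights by a (super-block, block) pair.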
+class Block(nn.Cell): + def __init__(self, config, global_config, in_shape, dtype=ms.float32): + super().__init__() + self.self_attention = SelfAttention( + config.attention, global_config, in_shape[-1], in_shape, ndim=2, dtype=dtype) + self.transition_block = TransitionBlock(global_config, in_shape[-1], + config.num_intermediate_factor, int(in_shape[-1]//2), ndim=2, dtype=dtype) + + def construct(self, act, mask, single_cond, pair_logits): + act += self.self_attention(act, mask, single_cond, pair_logits) + act += self.transition_block(act, single_cond) + return act + + +class SuperBlock(nn.Cell): + def __init__(self, config, global_config, num_blocks, using_pair_act, in_shape, pair_shape=None, name='', dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.num_blocks = num_blocks + self.using_pair_act = using_pair_act + self.blocks = ms.nn.CellList( + [ + Block( + config, global_config, in_shape, dtype=dtype + ) + for _ in range(self.config.super_block_size) + ] + ) + if self.using_pair_act: + self.pair_linear = bm.CustomDense( + pair_shape[-1], (self.config.super_block_size, self.config.attention.num_head), ndim=3, dtype=dtype) + else: + self.pair_linear = None + + def construct(self, act, mask, single_cond, pair_act): + if pair_act is None: + pair_logits = None + else: + pair_logits = self.pair_linear(pair_act).transpose([2, 3, 0, 1]) + for j in range(self.config.super_block_size): + act = self.blocks[j](act, mask, single_cond, pair_logits[j]) + return act + +@dataclass +class CrossAttentionConfig(base_config.BaseConfig): + num_head: int = 4 + key_dim: int = 128 + value_dim: int = 128 + + +class CrossAttention(nn.Cell): + """ + A CrossAttention class implementing multi-head cross-attention mechanism for processing sequential data. + + Args: + config (Config): Configuration object containing attention settings. + global_config (GlobalConfig): Global configuration object. + in_channel (int): Input dimension for the attention mechanism. + + Inputs: + - **x_q** (Tensor) - Query tensor. + - **x_k** (Tensor) - Key tensor. + - **mask_q** (Tensor) - Query mask tensor. + - **mask_k** (Tensor) - Key mask tensor. + - **pair_logits** (Tensor, optional) - Optional pair logits tensor. Default: ``None``. + - **single_cond_q** (Tensor) - Single condition tensor for queries. + - **single_cond_k** (Tensor) - Single condition tensor for keys. + + Outputs: + - **output** (Tensor) - Output tensor after cross-attention processing. 
+ """ + + def __init__(self, config, global_config, in_channel, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.adaptive_layernorm_q = AdaptiveLayernorm(in_channel, in_channel, dtype=dtype) + self.adaptive_layernorm_k = AdaptiveLayernorm(in_channel, in_channel, dtype=dtype) + assert config.key_dim % config.num_head == 0 + assert config.value_dim % config.num_head == 0 + self.key_dim = config.key_dim // config.num_head + self.value_dim = config.value_dim // config.num_head + self.linear_q = bm.CustomDense( + in_channel, (self.config.num_head, self.key_dim), use_bias=True, ndim=3, dtype=dtype) + self.linear_k = bm.CustomDense( + in_channel, (self.config.num_head, self.key_dim), use_bias=False, ndim=3, dtype=dtype) + self.linear_v = bm.CustomDense( + in_channel, (self.config.num_head, self.value_dim), use_bias=False, ndim=3, dtype=dtype) + self.ncon1 = Ncon([[-1, -3, -2, 1], [-1, -4, -2, 1]]) + self.ncon2 = Ncon([[-1, -3, -2, 1], [-1, 1, -3, -4]]) + self.gating_query = bm.CustomDense( + in_channel, self.config.num_head * self.value_dim, use_bias=False, + weight_init='zeros', bias_init='ones', ndim=3, dtype=dtype) + self.adaptive_zero_init = AdaptiveZeroInit( + global_config, in_channel, in_channel, in_channel, dtype=dtype) + + def construct(self, x_q, x_k, mask_q, mask_k, pair_logits, single_cond_q, single_cond_k): + """Multihead self-attention.""" + bias = ( + 1e9 + * (mask_q - 1.0)[..., None, :, None] + * (mask_k - 1.0)[..., None, None, :] + ) + x_q = self.adaptive_layernorm_q(x_q, single_cond_q) + x_k = self.adaptive_layernorm_k(x_k, single_cond_k) + q = self.linear_q(x_q) + k = self.linear_k(x_k) + logits = mint.einsum('...qhc,...khc->...hqk', q * self.key_dim ** (-0.5), k) + bias + if pair_logits is not None: + logits += pair_logits + weights = ops.softmax(logits, axis=-1) + v = self.linear_v(x_k) + weighted_avg = mint.einsum('...hqk,...khc->...qhc', weights, v) + weighted_avg = ops.reshape( + weighted_avg, weighted_avg.shape[:-2] + (-1,)) + + gate_logits = self.gating_query(x_q) + weighted_avg *= ops.sigmoid(gate_logits) + + output = self.adaptive_zero_init(weighted_avg, single_cond_q,) + return output + + +class CrossAttTransformer(nn.Cell): + """ + A CrossAttTransformer class implementing a transformer that applies cross attention between two sets of subsets. + + Args: + config (Config): Configuration object containing settings for the transformer. + global_config (GlobalConfig): Global configuration object. + in_shape (tuple): Input shape for the transformer. + + Inputs: + - **queries_act** (Tensor) - Query activations tensor. + - **queries_mask** (Tensor) - Mask tensor for queries. + - **queries_to_keys** (Tensor) - Tensor mapping queries to keys. + - **keys_mask** (Tensor) - Mask tensor for keys. + - **queries_single_cond** (Tensor) - Single condition tensor for queries. + - **keys_single_cond** (Tensor) - Single condition tensor for keys. + - **pair_cond** (Tensor) - Pair condition tensor. + + Outputs: + - **queries_act** (Tensor) - Processed query activations tensor after cross attention. 
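`CrossAttTransformer` (defined just below) follows a "project once, slice per block" pattern: the pair conditioning is projected a single time into one set of attention-bias logits per (block, head), and each block then consumes its own slice. A toy NumPy sketch of that reshaping, with made-up shapes (illustrative only, not the model code):

```python
import numpy as np

rng = np.random.default_rng(0)
num_subsets, num_queries, num_keys, pair_channels = 3, 4, 5, 8
num_blocks, num_head = 2, 4

pair_cond = rng.normal(size=(num_subsets, num_queries, num_keys, pair_channels))
proj = rng.normal(size=(pair_channels, num_blocks, num_head))

# One projection produces the bias for every block and head at once.
pair_logits = np.einsum('sqkc,cbh->bshqk', pair_cond, proj)
print(pair_logits.shape)          # (2, 3, 4, 4, 5): (block, subset, head, q, k)
per_block_bias = pair_logits[0]   # the slice block 0 adds to its attention logits
```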
+ """ + @dataclass + class Config(base_config.BaseConfig): + num_intermediate_factor: int + num_blocks: int + attention: CrossAttentionConfig = base_config.autocreate() + + def __init__(self, config, global_config, in_shape, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.pair_input_layer_norm = bm.LayerNorm(in_shape, create_beta=False, dtype=ms.float32) + self.pair_logits_projection = bm.CustomDense( + in_shape[-1], (self.config.num_blocks, self.config.attention.num_head), ndim=4, dtype=dtype) + self.block = ms.nn.CellList( + [ + CrossAttTransformerBlock( + config, global_config, in_shape[-2], dtype=dtype + ) + for _ in range(self.config.num_blocks) + ] + ) + + def construct(self, queries_act, queries_mask, queries_to_keys, + keys_mask, queries_single_cond, keys_single_cond, + pair_cond): + pair_act = self.pair_input_layer_norm(pair_cond) + pair_logits = self.pair_logits_projection(pair_act) + pair_logits = ops.transpose(pair_logits, (3, 0, 4, 1, 2)) + for i in range(self.config.num_blocks): + queries_act = self.block[i](queries_act, queries_mask, queries_to_keys, keys_mask, pair_logits[i], + queries_single_cond, keys_single_cond) + return queries_act + + +class CrossAttTransformerBlock(nn.Cell): + def __init__(self, config, global_config, in_channel, dtype=ms.float32): + super().__init__() + self.cross_attention = CrossAttention( + config.attention, global_config, in_channel, dtype=dtype) + self.transition = TransitionBlock( + global_config, in_channel, config.num_intermediate_factor, dtype=dtype) + + def construct(self, queries_act, queries_mask, queries_to_keys, keys_mask, pair_logits, + queries_single_cond, keys_single_cond): + keys_act = atom_layout.convert_ms( + queries_to_keys, queries_act, layout_axes=(-3, -2) + ) + queries_act += self.cross_attention(queries_act, keys_act, queries_mask, keys_mask, + pair_logits, queries_single_cond, keys_single_cond) + queries_act += self.transition(queries_act, queries_single_cond) + return queries_act diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/distogram_head.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/distogram_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0d7a665f6457dbce9efcb463cd58e8b11d3f3ef5 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/distogram_head.py @@ -0,0 +1,88 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""Distogram head.""" + +from typing import Final +from dataclasses import dataclass +import mindspore as ms +from mindspore import nn, ops +from alphafold3.model import base_config +from alphafold3.model.components import base_modules as bm +from mindchemistry.e3.utils import Ncon + + +_CONTACT_THRESHOLD: Final[float] = 8.0 +_CONTACT_EPSILON: Final[float] = 1e-3 + + +class DistogramHead(nn.Cell): + """ + A DistogramHead class that computes a distogram from pair embeddings, predicting distances between residues. + + Args: + config (Config): Configuration object containing parameters for the distogram head. + global_config (GlobalConfig): Global configuration object. + in_channel (int): Number of input channels for the linear layer. + + Inputs: + - **batch** (dict) - Dictionary containing batch features. + - **embeddings** (dict) - Dictionary containing pair embeddings. + + Outputs: + - **bin_edges** (Tensor) - Tensor of bin edges for distance predictions. + - **contact_probs** (Tensor) - Tensor of contact probabilities. + + Notes: + - The distogram head computes distance probabilities using a linear transformation and softmax. + - The Ncon class is used for tensor contraction operations. + """ + @dataclass + class Config(base_config.BaseConfig): + first_break: float = 2.3125 + last_break: float = 21.6875 + num_bins: int = 64 + + def __init__( + self, config, global_config, in_channel, dtype=ms.float32 + ): + super().__init__() + self.config = config + self.global_config = global_config + self.linear = bm.CustomDense( + in_channel, self.config.num_bins, weight_init=self.global_config.final_init, ndim=3, dtype=dtype) + self.ncon = Ncon([[-1, -2, 1], [1]]) + + def construct(self, batch, embeddings): + pair_act = embeddings["pair"] + seq_mask = batch.token_features.mask.astype(ms.bool_) + pair_mask = seq_mask[:, None] * seq_mask[None, :] + left_half_logits = self.linear(pair_act) + right_half_logits = left_half_logits + logits = left_half_logits + ops.swapaxes(right_half_logits, -2, -3) + probs = ops.softmax(logits, axis=-1) + breaks = ops.linspace( + self.config.first_break, + self.config.last_break, + self.config.num_bins - 1, + ) + bin_tops = ops.concat( + (breaks, (breaks[-1] + breaks[-1] - breaks[-2]).reshape(-1))) + threshold = _CONTACT_THRESHOLD + _CONTACT_EPSILON + is_contact_bin = 1.0 * (bin_tops <= threshold) + contact_probs = self.ncon([probs.astype(ms.float32), is_contact_bin.astype(ms.float32)]) + contact_probs = pair_mask * contact_probs + return { + 'bin_edges': breaks, + 'contact_probs': contact_probs, + } diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/featurization.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/featurization.py new file mode 100644 index 0000000000000000000000000000000000000000..439c6502b632ee49bcce6c041cd50ec02d2427a8 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/featurization.py @@ -0,0 +1,212 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""Model-side of the input features processing.""" +import math +import numpy as np +import mindspore as ms +from mindspore import ops +from alphafold3.constants import residue_names +from alphafold3.model.components import utils + + +def _grid_keys(key, shape): + """Generate a grid of rng keys that is consistent with different padding. + + Generate random keys such that the keys will be identical, regardless of + how much padding is added to any dimension. + + Args: + key: A PRNG key. + shape: The shape of the output array of keys that will be generated. + + Returns: + An array of shape `shape` consisting of random keys. + """ + if not shape: + return key + + def partial_bitwise_xor(other): + return ops.bitwise_xor(key, other) + + def _partial_grid_keys(key): + return _grid_keys(key, shape[1:]) + new_keys = ms.vmap(partial_bitwise_xor)( + ops.arange(shape[0]) + ) + return ms.vmap(_partial_grid_keys)(new_keys) + + +def _padding_consistent_rng(f): + def inner(key, shape, **kwargs): + keys = _grid_keys(key, shape) + out = keys.flatten() + count = 0 + for key in keys.flatten(): + out[count] = (f((), key)) + count += 1 + return out.reshape(keys.shape) + return inner + + +def gumbel_sample(shape): + uniform_samples = ms.Tensor(np.random.uniform(0.0, 1.0, shape)) + gumbel_samples = -ops.log(-ops.log(uniform_samples)) + return gumbel_samples + + +def gumbel_argsort_sample_idx(key, logits): + gumbel = _padding_consistent_rng(gumbel_sample) + z = gumbel(key, logits.shape) + perm = ops.argsort(logits + z, axis=-1, descending=False) + return perm[::-1] + + +def create_msa_feat(msa): + msa_1hot = ops.one_hot(msa.rows.astype( + ms.int64), residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP + 1) + deletion_matrix = msa.deletion_matrix + has_deletion = ops.clip(deletion_matrix, 0.0, 1.0)[..., None] + deletion_value = (ops.arctan(deletion_matrix / 3.0) + * (2.0 / math.pi))[..., None] + msa_feat = [msa_1hot.astype(deletion_value.dtype), has_deletion, deletion_value] + return ops.concat(msa_feat, axis=-1) + + +def truncate_msa_batch(msa, num_msa): + indices = ops.arange(num_msa) + return msa.index_msa_rows(indices) + + +def create_target_feat(batch, append_per_atom_features, dtype=ms.float32): + token_features = batch.token_features + target_features = [] + target_features.append(ops.one_hot( + token_features.aatype.astype(ms.int64), + residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP).astype(dtype)) + target_features.append(batch.msa.profile) + target_features.append(batch.msa.deletion_mean[..., None]) + + if append_per_atom_features: + ref_mask = batch.ref_structure.mask + element_feat = ops.one_hot(batch.ref_structure.element, 128) + element_feat = utils.mask_mean( + mask=ref_mask[..., None], value=element_feat, axis=-2, eps=1e-6) + target_features.append(element_feat) + pos_feat = batch.ref_structure.positions + pos_feat = pos_feat.reshape([pos_feat.shape[0], -1]) + target_features.append(pos_feat) + target_features.append(ref_mask) + return ops.concat(target_features, axis=-1) + + +def 
create_relative_encoding( + seq_features, + max_relative_idx, + max_relative_chain, +): + """Add relative position encodings.""" + rel_feats = [] + token_index = seq_features.token_index + residue_index = seq_features.residue_index + asym_id = seq_features.asym_id + entity_id = seq_features.entity_id + sym_id = seq_features.sym_id + + left_asym_id = asym_id[:, None] + right_asym_id = asym_id[None, :] + + left_residue_index = residue_index[:, None] + right_residue_index = residue_index[None, :] + + left_token_index = token_index[:, None] + right_token_index = token_index[None, :] + + left_entity_id = entity_id[:, None] + right_entity_id = entity_id[None, :] + left_sym_id = sym_id[:, None] + right_sym_id = sym_id[None, :] + + # Embed relative positions using a one-hot embedding of distance along chain + offset = left_residue_index - right_residue_index + clipped_offset = ops.clip( + offset + max_relative_idx, min=0, max=2 * max_relative_idx + ) + asym_id_same = left_asym_id == right_asym_id + final_offset = ops.where( + asym_id_same, + clipped_offset, + (2 * max_relative_idx + 1) * ops.ones_like(clipped_offset), + ) + rel_pos = ops.one_hot(final_offset.astype( + ms.int64), 2 * max_relative_idx + 2) + rel_feats.append(rel_pos) + + # Embed relative token index as a one-hot embedding of distance along residue + token_offset = left_token_index - right_token_index + clipped_token_offset = ops.clip( + token_offset + max_relative_idx, min=0, max=2 * max_relative_idx + ) + residue_same = ops.logical_and((left_asym_id == right_asym_id), ( + left_residue_index == right_residue_index + )) + final_token_offset = ops.where( + residue_same, + clipped_token_offset, + (2 * max_relative_idx + 1) * ops.ones_like(clipped_token_offset), + ) + rel_token = ops.one_hot(final_token_offset.astype( + ms.int64), 2 * max_relative_idx + 2) + rel_feats.append(rel_token) + + # Embed same entity ID + entity_id_same = left_entity_id == right_entity_id + rel_feats.append(entity_id_same.astype(rel_pos.dtype)[..., None]) + + # Embed relative chain ID inside each symmetry class + rel_sym_id = left_sym_id - right_sym_id + + max_rel_chain = max_relative_chain + + clipped_rel_chain = ops.clip( + rel_sym_id + max_rel_chain, min=0, max=2 * max_rel_chain + ) + + final_rel_chain = ops.where( + entity_id_same, + clipped_rel_chain, + (2 * max_rel_chain + 1) * ops.ones_like(clipped_rel_chain), + ) + rel_chain = ops.one_hot(final_rel_chain.astype( + ms.int64), 2 * max_relative_chain + 2) + + rel_feats.append(rel_chain) + + return ops.concat(rel_feats, axis=-1) + + +def shuffle_msa(key, msa): + """Shuffle MSA randomly, return batch with shuffled MSA. + + Args: + key: rng key for random number generation. + msa: MSA object to sample msa from. + + Returns: + Protein with sampled msa. + """ + key, sample_key = key, key + 1 + # Sample uniformly among sequences with at least one non-masked position. 
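The next two lines implement that sampling with a Gumbel-argsort trick: adding i.i.d. Gumbel noise to per-row logits and argsorting gives a random permutation in which rows with logit `-1e6` (MSA rows with no unmasked positions) are pushed to the end, while the remaining rows are ordered uniformly at random. A standalone NumPy illustration with toy data (not part of the patch):

```python
import numpy as np

rng = np.random.default_rng(0)
row_has_signal = np.array([1, 0, 1, 1, 0], dtype=float)    # 1 = at least one unmasked position
logits = (row_has_signal - 1.0) * 1e6                       # 0 for real rows, -1e6 for empty rows

gumbel = -np.log(-np.log(rng.uniform(size=logits.shape)))   # i.i.d. Gumbel noise
order = np.argsort(logits + gumbel)[::-1]                    # descending: real rows come first
print(order)                                                 # a permutation with rows 1 and 4 last
```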
+ logits = (ops.clip(ops.sum(msa.mask, dim=-1), 0.0, 1.0) - 1.0) * 1e6 + index_order = gumbel_argsort_sample_idx(sample_key, logits) + return msa.index_msa_rows(index_order), sample_key diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/load_ckpt.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/load_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..2f1cdfedf8a604a9e6949bd5059bd79ab651edea --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/load_ckpt.py @@ -0,0 +1,579 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + + +import pathlib +import mindspore as ms +from mindspore import ops +from alphafold3.model.params import get_model_af3_params + + +def np_slice(arr, i, j, dtype=ms.bfloat16): + if i is not None and j is not None: + return ms.Parameter(ms.Tensor(arr[i, j], dtype)) + if i is not None and j is None: + return ms.Parameter(ms.Tensor(arr[i], dtype)) + if i is None and j is not None: + return ms.Parameter(ms.Tensor(arr[j], dtype)) + return ms.Parameter(ms.Tensor(arr, dtype)) + + + +def load_adaptive_layernorm(adaptive_layernorm, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + if not ckpt.get(f'{path}single_cond_layer_norm'): + adaptive_layernorm.layernorm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}layer_norm']['scale'], i, j, dtype=ms.float32)) + adaptive_layernorm.layernorm.layernorm.beta.set_data( + np_slice(ckpt[f'{path}layer_norm']['offset'], i, j, dtype=ms.float32)) + else: + adaptive_layernorm.single_cond_layer_norm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}single_cond_layer_norm']['scale'], i, j, dtype=ms.float32)) + adaptive_layernorm.single_cond_scale.weight.set_data( + np_slice(ckpt[f'{path}single_cond_scale']['weights'], i, j, dtype=dtype)) + adaptive_layernorm.single_cond_scale.bias.set_data( + np_slice(ckpt[f'{path}single_cond_scale']['bias'], i, j, dtype=dtype)) + adaptive_layernorm.single_cond_bias.weight.set_data( + np_slice(ckpt[f'{path}single_cond_bias']['weights'], i, j, dtype=dtype)) + + +def load_adaptive_zero_init(adaptive_zero_init, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + adaptive_zero_init.cond_linear1.weight.set_data( + np_slice(ckpt[f'{path}transition2']['weights'], i, j, dtype=dtype)) + if ckpt.get(f'{path}adaptive_zero_cond'): + adaptive_zero_init.cond_linear2.weight.set_data( + np_slice(ckpt[f'{path}adaptive_zero_cond']['weights'], i, j, dtype=dtype)) + adaptive_zero_init.cond_linear2.bias.set_data( + np_slice(ckpt[f'{path}adaptive_zero_cond']['bias'], i, j, dtype=dtype)) + + +def load_transition(transition_block, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + load_adaptive_layernorm( + transition_block.adaptive_layernorm, f'{path}ffw_', ckpt, i, j, dtype=dtype) + transition_block.weights.set_data( + np_slice(ckpt[f'{path}ffw_transition1']['weights'], i, j, 
dtype=dtype).reshape( + (transition_block.weights.shape[0], 2, transition_block.num_intermediate))) + load_adaptive_zero_init( + transition_block.adaptive_zero_init, f'{path}ffw_', ckpt, i, j, dtype=dtype) + + +def load_self_attention(self_attention, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + load_adaptive_layernorm( + self_attention.adaptive_layernorm, path, ckpt, i, j) + self_attention.q_linear.weight.set_data( + np_slice(ckpt[f'{path}q_projection']['weights'], i, j, dtype=dtype)) + self_attention.q_linear.bias.set_data( + np_slice(ckpt[f'{path}q_projection']['bias'], i, j, dtype=dtype)) + self_attention.k_linear.weight.set_data( + np_slice(ckpt[f'{path}k_projection']['weights'], i, j, dtype=dtype)) + self_attention.v_linear.weight.set_data( + np_slice(ckpt[f'{path}v_projection']['weights'], i, j, dtype=dtype)) + self_attention.linear.weight.set_data( + np_slice(ckpt[f'{path}gating_query']['weights'], i, j, dtype=dtype)) + load_adaptive_zero_init( + self_attention.adaptive_zero_init, path, ckpt, i, j, dtype=dtype) + + +def load_transformer(transformer, path, ckpt, dtype=ms.bfloat16): + for i in range(6): + for j in range(4): + transformer_path = (path + + '/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformer') + load_self_attention(transformer.super_blocks[i].blocks[j].self_attention, + transformer_path, ckpt, i, j, dtype=dtype) + load_transition(transformer.super_blocks[i].blocks[j].transition_block, + transformer_path, ckpt, i, j, dtype=dtype) + if transformer.using_pair_act is True: + pair_projection_path = f'{path}/__layer_stack_with_per_layer/pair_logits_projection' + transformer.super_blocks[i].pair_linear.weight.set_data( + np_slice(ckpt[pair_projection_path]['weights'], i, None, dtype=dtype)) + if transformer.using_pair_act is True: + pair_norm_path = f'{path}/pair_input_layer_norm' + transformer.pair_layernorm.layernorm.gamma.set_data( + np_slice(ckpt[pair_norm_path]['scale'].T, dtype=ms.float32)) + + +def load_transition_block(transition_block, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + transition_block.glu_weight.set_data( + np_slice(ckpt[f'{path}/transition1']['weights'], i, j, dtype=dtype).reshape( + (-1, 2, transition_block.num_intermediate))) + transition_block.out_linear.weight.set_data( + np_slice(ckpt[f'{path}/transition2']['weights'], i, j, dtype=dtype)) + transition_block.layernorm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/input_layer_norm']['scale'], i, j, dtype=ms.float32)) + transition_block.layernorm.layernorm.beta.set_data( + np_slice(ckpt[f'{path}/input_layer_norm']['offset'], i, j, dtype=ms.float32)) + + +def load_grid_self_attention(grid_self_attention, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + grid_self_attention.q_projection.weight.set_data( + np_slice(ckpt[f'{path}/q_projection']['weights'], i, j, dtype=dtype).transpose(2, 0, 1)) + grid_self_attention.k_projection.weight.set_data( + np_slice(ckpt[f'{path}/k_projection']['weights'], i, j, dtype=dtype).transpose(2, 0, 1)) + grid_self_attention.v_projection.weight.set_data( + np_slice(ckpt[f'{path}/v_projection']['weights'], i, j, dtype=dtype)) + grid_self_attention.gating_query.weight.set_data( + np_slice(ckpt[f'{path}/gating_query']['weights'], i, j, dtype=dtype).T) + grid_self_attention.output_projection.weight.set_data( + np_slice(ckpt[f'{path}/output_projection']['weights'], i, j, dtype=dtype)) + grid_self_attention.pair_bias_projection.weight.set_data( + np_slice(ckpt[f'{path}/pair_bias_projection']['weights'], i, j, dtype=dtype)) + 
grid_self_attention.act_norm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/act_norm']['scale'], i, j, dtype=ms.float32)) + grid_self_attention.act_norm.layernorm.beta.set_data( + np_slice(ckpt[f'{path}/act_norm']['offset'], i, j, dtype=ms.float32)) + + +def load_outer_product_mean(outer_product_mean, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + outer_product_mean.outer_product_mean.o_biases.set_data( + np_slice(ckpt[path]['output_b'], i, j, dtype=dtype)) + outer_product_mean.outer_product_mean.linear_output_weight.set_data( + np_slice(ckpt[path]['output_w'], i, j, dtype=dtype)) + outer_product_mean.outer_product_mean.left_projection_weight.set_data( + np_slice(ckpt[f'{path}/left_projection']['weights'], i, j, dtype=dtype).T) + outer_product_mean.outer_product_mean.right_projection_weight.set_data( + np_slice(ckpt[f'{path}/right_projection']['weights'], i, j, dtype=dtype).T) + outer_product_mean.outer_product_mean.layer_norm_input_gamma.set_data( + np_slice(ckpt[f'{path}/layer_norm_input']['scale'], i, j, dtype=ms.float32)) + outer_product_mean.outer_product_mean.layer_norm_input_beta.set_data( + np_slice(ckpt[f'{path}/layer_norm_input']['offset'], i, j, dtype=ms.float32)) + + +def load_msa_attention(msa_attention, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + msa_attention.actnorm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/act_norm']['scale'], i, j, dtype=ms.float32)) + msa_attention.actnorm.layernorm.beta.set_data( + np_slice(ckpt[f'{path}/act_norm']['offset'], i, j, dtype=ms.float32)) + msa_attention.pairnorm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/pair_norm']['scale'], i, j, dtype=ms.float32)) + msa_attention.pairnorm.layernorm.beta.set_data( + np_slice(ckpt[f'{path}/pair_norm']['offset'], i, j, dtype=ms.float32)) + msa_attention.pair_logits.weight.set_data( + np_slice(ckpt[f'{path}/pair_logits']['weights'], i, j, dtype=dtype)) + msa_attention.v_projection.weight.set_data( + np_slice(ckpt[f'{path}/v_projection']['weights'], i, j, dtype=dtype)) + msa_attention.gating_query.weight.set_data( + np_slice(ckpt[f'{path}/gating_query']['weights'], i, j, dtype=dtype)) + msa_attention.output_projection.weight.set_data( + np_slice(ckpt[f'{path}/output_projection']['weights'], i, j, dtype=dtype)) + + +def load_triangle_multiplication(triangle_multiplication, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + triangle_multiplication.triangle_multi.gate.weight.set_data( + np_slice(ckpt[f'{path}/gate']['weights'], i, j, dtype=dtype).T) + triangle_multiplication.triangle_multi.projection.weight.set_data( + np_slice(ckpt[f'{path}/projection']['weights'], i, j, dtype=dtype).T) + triangle_multiplication.triangle_multi.weight_glu = ops.stack( + [triangle_multiplication.triangle_multi.gate.weight, + triangle_multiplication.triangle_multi.projection.weight], axis=1) + triangle_multiplication.triangle_multi.output_projection.weight.set_data( + np_slice(ckpt[f'{path}/output_projection']['weights'], i, j, dtype=dtype)) + triangle_multiplication.triangle_multi.gating_linear.weight.set_data( + np_slice(ckpt[f'{path}/gating_linear']['weights'], i, j, dtype=dtype)) + triangle_multiplication.triangle_multi.left_norm_input.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/left_norm_input']['scale'], i, j, dtype=ms.float32)) + triangle_multiplication.triangle_multi.left_norm_input.layernorm.beta.set_data( + np_slice(ckpt[f'{path}/left_norm_input']['offset'], i, j, dtype=ms.float32)) + triangle_multiplication.triangle_multi.center_norm.layernorm.gamma.set_data( + 
np_slice(ckpt[f'{path}/center_norm']['scale'], i, j, dtype=ms.float32)) + triangle_multiplication.triangle_multi.center_norm.layernorm.beta.set_data( + np_slice(ckpt[f'{path}/center_norm']['offset'], i, j, dtype=ms.float32)) + + +def load_pair_former(pair_former, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + load_grid_self_attention(pair_former.grid_self_attention1, f'{path}/pair_attention1', + ckpt, i, j, dtype=dtype) + load_grid_self_attention(pair_former.grid_self_attention2, f'{path}/pair_attention2', + ckpt, i, j, dtype=dtype) + load_triangle_multiplication(pair_former.triangle_multiplication1, + f'{path}/triangle_multiplication_outgoing', ckpt, i, j, dtype=dtype) + load_triangle_multiplication(pair_former.triangle_multiplication2, + f'{path}/triangle_multiplication_incoming', ckpt, i, j, dtype=dtype) + load_transition_block(pair_former.transition_block, f'{path}/pair_transition', + ckpt, i, j, dtype=dtype) + if pair_former.with_single: + pair_former.single_pair_logits_norm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/single_pair_logits_norm']['scale'], i, j, dtype=ms.float32)) + pair_former.single_pair_logits_norm.layernorm.beta.set_data( + np_slice(ckpt[f'{path}/single_pair_logits_norm']['offset'], i, j, dtype=ms.float32)) + pair_former.single_pair_logits_projection.weight.set_data( + np_slice(ckpt[f'{path}/single_pair_logits_projection']['weights'], i, j, dtype=dtype)) + load_self_attention(pair_former.single_attention, f'{path}/single_attention_', + ckpt, i, j, dtype=dtype) + load_transition_block(pair_former.single_transition, f'{path}/single_transition', + ckpt, i, j, dtype=dtype) + + +def load_evo_former(evo_former, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + load_outer_product_mean(evo_former.outer_product_mean, f'{path}/outer_product_mean', + ckpt, i, j, dtype=dtype) + load_msa_attention(evo_former.msa_attention, f'{path}/msa_attention1', + ckpt, i, j, dtype=dtype) + load_transition_block(evo_former.msa_transition, f'{path}/msa_transition', + ckpt, i, j, dtype=dtype) + load_triangle_multiplication(evo_former.triangle_multiplication1, + f'{path}/triangle_multiplication_outgoing', ckpt, i, j, dtype=dtype) + load_triangle_multiplication(evo_former.triangle_multiplication2, + f'{path}/triangle_multiplication_incoming', ckpt, i, j, dtype=dtype) + load_grid_self_attention(evo_former.pair_attention1, f'{path}/pair_attention1', + ckpt, i, j, dtype=dtype) + load_grid_self_attention(evo_former.pair_attention2, f'{path}/pair_attention2', + ckpt, i, j, dtype=dtype) + load_transition_block(evo_former.transition_block, f'{path}/pair_transition', + ckpt, i, j, dtype=dtype) + + +def load_single_template_embedding(single_template_embedding, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + num_layer = single_template_embedding.config.template_stack.num_layer + for ii in range(num_layer): + template_path = f'{path}/__layer_stack_no_per_layer/template_embedding_iteration' + load_pair_former(single_template_embedding.template_stack[ii], template_path, + ckpt, ii, dtype=dtype) + for jj in range(9): + template_pair_path = f'{path}/template_pair_embedding_{jj}' + single_template_embedding.template_pair_embedding[jj].weight.set_data( + np_slice(ckpt[template_pair_path]['weights'], None, None, dtype=dtype)) + single_template_embedding.output_layer_norm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/output_layer_norm']['scale'], i, j, dtype=ms.float32)) + single_template_embedding.output_layer_norm.layernorm.beta.set_data( + np_slice(ckpt[f'{path}/output_layer_norm']['offset'], i, 
j, dtype=ms.float32)) + single_template_embedding.query_embedding_norm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/query_embedding_norm']['scale'], i, j, dtype=ms.float32)) + single_template_embedding.query_embedding_norm.layernorm.beta.set_data( + np_slice(ckpt[f'{path}/query_embedding_norm']['offset'], i, j, dtype=ms.float32)) + + +def load_template_embedding(template_embedding, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + template_embedding.output_linear.weight.set_data( + np_slice(ckpt[f'{path}/output_linear']['weights'], i, j, dtype=dtype)) + load_single_template_embedding(template_embedding.template_embedder, + f'{path}/single_template_embedding', ckpt, i, j, dtype=dtype) + + +def load_distogram_head(distogram_head, path, ckpt, i=None, j=None, dtype=ms.float32): + distogram_head.linear.weight.set_data( + np_slice(ckpt[f'{path}/half_logits']['weights'], i, j, dtype=dtype)) + + +def load_evoformer(evoformer, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + relative_encoding_path = f'{path}/~_relative_encoding/position_activations' + evoformer.position_activations.weight.set_data( + np_slice(ckpt[relative_encoding_path]['weights'], i, j, dtype=dtype)) + evoformer.left_single.weight.set_data( + np_slice(ckpt[f'{path}/left_single']['weights'], i, j, dtype=dtype)) + evoformer.right_single.weight.set_data( + np_slice(ckpt[f'{path}/right_single']['weights'], i, j, dtype=dtype)) + evoformer.bond_embedding.weight.set_data( + np_slice(ckpt[f'{path}/bond_embedding']['weights'], i, j, dtype=dtype)) + evoformer.msa_activations.weight.set_data( + np_slice(ckpt[f'{path}/msa_activations']['weights'], i, j, dtype=dtype)) + evoformer.extra_msa_target_feat.weight.set_data( + np_slice(ckpt[f'{path}/extra_msa_target_feat']['weights'], i, j, dtype=dtype)) + evoformer.prev_embedding.weight.set_data( + np_slice(ckpt[f'{path}/prev_embedding']['weights'], i, j, dtype=dtype)) + evoformer.prev_embedding_layer_norm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/prev_embedding_layer_norm']['scale'], i, j, dtype=ms.float32)) + evoformer.prev_embedding_layer_norm.layernorm.beta.set_data( + np_slice(ckpt[f'{path}/prev_embedding_layer_norm']['offset'], i, j, dtype=ms.float32)) + evoformer.single_activations.weight.set_data( + np_slice(ckpt[f'{path}/single_activations']['weights'], i, j, dtype=dtype)) + evoformer.prev_single_embedding.weight.set_data( + np_slice(ckpt[f'{path}/prev_single_embedding']['weights'], i, j, dtype=dtype)) + evoformer.prev_single_embedding_layer_norm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/prev_single_embedding_layer_norm']['scale'], i, j, dtype=ms.float32)) + evoformer.prev_single_embedding_layer_norm.layernorm.beta.set_data( + np_slice(ckpt[f'{path}/prev_single_embedding_layer_norm']['offset'], i, j, dtype=ms.float32)) + load_template_embedding(evoformer.template_module, f'{path}/template_embedding', + ckpt, i, j, dtype=dtype) + for ii in range(evoformer.config.pairformer.num_layer): + pairformer_path = path+'/__layer_stack_no_per_layer_1/trunk_pairformer' + load_pair_former( + evoformer.pairformer_stack[ii], pairformer_path, ckpt, ii, dtype=dtype) + for jj in range(evoformer.config.msa_stack.num_layer): + msa_stack_path = path+'/__layer_stack_no_per_layer/msa_stack' + load_evo_former( + evoformer.evoformer_stack[jj], msa_stack_path, ckpt, jj, dtype=dtype) + + +def load_adaptive_layernorm_ms(adaptive_layernorm, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + if not ckpt.get(f'{path}single_cond_layer_norm'): + adaptive_layernorm.layernorm.layernorm.gamma.set_data( 
+ np_slice(ckpt[f'{path}layer_norm']['scale'], i, j, dtype=dtype)) + adaptive_layernorm.layernorm.layernorm.beta.set_data( + np_slice(ckpt[f'{path}layer_norm']['offset'], i, j, dtype=dtype)) + else: + adaptive_layernorm.single_cond_layer_norm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}single_cond_layer_norm']['scale'], i, j, dtype=dtype)) + adaptive_layernorm.single_cond_scale.weight.set_data( + np_slice(ckpt[f'{path}single_cond_scale']['weights'], i, j, dtype=dtype)) + adaptive_layernorm.single_cond_scale.bias.set_data( + np_slice(ckpt[f'{path}single_cond_scale']['bias'], i, j, dtype=dtype)) + adaptive_layernorm.single_cond_bias.weight.set_data( + np_slice(ckpt[f'{path}single_cond_bias']['weights'], i, j, dtype=dtype)) + + +def load_adaptive_zero_init_ms(adaptive_zero_init, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + adaptive_zero_init.cond_linear1.weight.set_data( + np_slice(ckpt[f'{path}transition2']['weights'], i, j, dtype=dtype)) + if ckpt.get(f'{path}adaptive_zero_cond'): + adaptive_zero_init.cond_linear2.weight.set_data( + np_slice(ckpt[f'{path}adaptive_zero_cond']['weights'], i, j, dtype=dtype)) + adaptive_zero_init.cond_linear2.bias.set_data( + np_slice(ckpt[f'{path}adaptive_zero_cond']['bias'], i, j, dtype=dtype)) + + +def load_transition_ms(transition_block, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + load_adaptive_layernorm_ms( + transition_block.adaptive_layernorm, f'{path}ffw_', ckpt, i, j, dtype=dtype) + transition_block.weights.set_data( + np_slice(ckpt[f'{path}ffw_transition1']['weights'], i, j, dtype=dtype).reshape( + (transition_block.weights.shape[0], 2, transition_block.num_intermediate))) + load_adaptive_zero_init_ms( + transition_block.adaptive_zero_init, f'{path}ffw_', ckpt, i, j, dtype=dtype) + + +def load_self_attention_ms(self_attention, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + load_adaptive_layernorm_ms( + self_attention.adaptive_layernorm, path, ckpt, i, j, dtype=dtype) + self_attention.q_linear.weight.set_data( + np_slice(ckpt[f'{path}q_projection']['weights'], i, j, dtype=dtype)) + self_attention.q_linear.bias.set_data( + np_slice(ckpt[f'{path}q_projection']['bias'], i, j, dtype=dtype)) + self_attention.k_linear.weight.set_data( + np_slice(ckpt[f'{path}k_projection']['weights'], i, j, dtype=dtype)) + self_attention.v_linear.weight.set_data( + np_slice(ckpt[f'{path}v_projection']['weights'], i, j, dtype=dtype)) + self_attention.linear.weight.set_data( + np_slice(ckpt[f'{path}gating_query']['weights'], i, j, dtype=dtype)) + load_adaptive_zero_init_ms( + self_attention.adaptive_zero_init, path, ckpt, i, j, dtype=dtype) + + +def load_transformer_ms(transformer, path, ckpt, dtype=ms.float16): + for i in range(6): + for j in range(4): + transformer_path = (path + + f'/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformer') + load_self_attention_ms(transformer.super_blocks[i].blocks[j].self_attention, + transformer_path, ckpt, i, j, dtype=dtype) + load_transition_ms(transformer.super_blocks[i].blocks[j].transition_block, + transformer_path, ckpt, i, j, dtype=dtype) + if transformer.using_pair_act: + pair_projection_path = path + f'/__layer_stack_with_per_layer/pair_logits_projection' + transformer.super_blocks[i].pair_linear.weight.set_data( + np_slice(ckpt[pair_projection_path]['weights'], i, None, dtype=dtype)) + if transformer.using_pair_act: + pair_norm_path = f'{path}/pair_input_layer_norm' + transformer.pair_layernorm.layernorm.gamma.set_data( + np_slice(ckpt[pair_norm_path]['scale'].T, None, None, dtype=ms.float32)) 
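+
+# Note: repeated layers are stored in the checkpoint as stacked arrays; for the
+# 6 x 4 diffusion transformer loaded above, np_slice(arr, i, j) builds a
+# Parameter from arr[i, j], i.e. the weights of super-block i and inner block j.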
+ + +def load_cross_attention(cross_attention, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + load_adaptive_layernorm_ms( + cross_attention.adaptive_layernorm_q, f'{path}q', ckpt, i, j, dtype=dtype) + load_adaptive_layernorm_ms( + cross_attention.adaptive_layernorm_k, f'{path}k', ckpt, i, j, dtype=dtype) + cross_attention.linear_q.weight.set_data( + np_slice(ckpt[f'{path}q_projection']['weights'], i, j, dtype=dtype)) + cross_attention.linear_q.bias.set_data( + np_slice(ckpt[f'{path}q_projection']['bias'], i, j, dtype=dtype)) + cross_attention.linear_k.weight.set_data( + np_slice(ckpt[f'{path}k_projection']['weights'], i, j, dtype=dtype)) + cross_attention.linear_v.weight.set_data( + np_slice(ckpt[f'{path}v_projection']['weights'], i, j, dtype=dtype)) + cross_attention.gating_query.weight.set_data( + np_slice(ckpt[f'{path}gating_query']['weights'], i, j, dtype=dtype)) + load_adaptive_zero_init_ms( + cross_attention.adaptive_zero_init, path, ckpt, i, j, dtype=dtype) + + +def load_cross_att_transformer_block(cross_att_transformer_block, path, ckpt, i=None, dtype=ms.bfloat16): + load_cross_attention( + cross_att_transformer_block.cross_attention, path, ckpt, i, dtype=dtype) + load_transition_ms(cross_att_transformer_block.transition, + path, ckpt, i, dtype=dtype) + + +def load_cross_attention_transformer(cross_attention_transformer, path, ckpt, last_name, i, j, dtype=ms.bfloat16): + cross_attention_transformer.pair_input_layer_norm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/pair_input_layer_norm']['scale'], i, j, dtype=dtype)) + cross_attention_transformer.pair_logits_projection.weight.set_data( + np_slice(ckpt[f'{path}/pair_logits_projection']['weights'], i, j, dtype=dtype)) + for ii in range(cross_attention_transformer.config.num_blocks): + block_path = path + f'/__layer_stack_with_per_layer/{last_name}' + load_cross_att_transformer_block(cross_attention_transformer.block[ii], block_path, + ckpt, ii, dtype=dtype) + + +def load_per_atom_conditioning(per_atom_conditioning, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + per_atom_conditioning.linear1.weight.set_data( + np_slice(ckpt[f'{path}_embed_ref_pos']['weights'].T, i, j, dtype=dtype)) + per_atom_conditioning.linear2.weight.set_data( + np_slice(ckpt[f'{path}_embed_ref_mask']['weights'].T, i, j, dtype=dtype)) + per_atom_conditioning.linear3.weight.set_data( + np_slice(ckpt[f'{path}_embed_ref_element']['weights'].T, i, j, dtype=dtype)) + per_atom_conditioning.linear4.weight.set_data( + np_slice(ckpt[f'{path}_embed_ref_charge']['weights'].T, i, j, dtype=dtype)) + per_atom_conditioning.linear5.weight.set_data( + np_slice(ckpt[f'{path}_embed_ref_atom_name']['weights'].T, i, j, dtype=dtype)) + per_atom_conditioning.linear_row_act.weight.set_data( + np_slice(ckpt[f'{path}_single_to_pair_cond_row']['weights'].T, i, j, dtype=dtype)) + per_atom_conditioning.linear_col_act.weight.set_data( + np_slice(ckpt[f'{path}_single_to_pair_cond_col']['weights'].T, i, j, dtype=dtype)) + per_atom_conditioning.linear_pair_act1.weight.set_data( + np_slice(ckpt[f'{path}_embed_pair_offsets']['weights'].T, i, j, dtype=dtype)) + per_atom_conditioning.linear_pair_act2.weight.set_data( + np_slice(ckpt[f'{path}_embed_pair_distances']['weights'].T, i, j, dtype=dtype)) + + +def load_atom_cross_encoder(atom_cross_att_encoder, path, ckpt, last_name, i=None, j=None, dtype=ms.bfloat16): + load_per_atom_conditioning( + atom_cross_att_encoder._per_atom_conditioning, path, ckpt, dtype=dtype) + if atom_cross_att_encoder.with_cond: + 
atom_cross_att_encoder._embed_trunk_single_cond.weight.set_data( + np_slice(ckpt[f'{path}_embed_trunk_single_cond']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._lnorm_trunk_single_cond.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}_lnorm_trunk_single_cond']['scale'], i, j, dtype=ms.float32)) + atom_cross_att_encoder._atom_positions_to_features.weight.set_data( + np_slice(ckpt[f'{path}_atom_positions_to_features']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._embed_trunk_pair_cond.weight.set_data( + np_slice(ckpt[f'{path}_embed_trunk_pair_cond']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._lnorm_trunk_pair_cond.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}_lnorm_trunk_pair_cond']['scale'], i, j, dtype=ms.float32)) + atom_cross_att_encoder._single_to_pair_cond_row.weight.set_data( + np_slice(ckpt[f'{path}_single_to_pair_cond_row_1']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._single_to_pair_cond_col.weight.set_data( + np_slice(ckpt[f'{path}_single_to_pair_cond_col_1']['weights'].T, i, j, dtype=dtype)) + if atom_cross_att_encoder.with_cond: + atom_cross_att_encoder._embed_pair_offsets.weight.set_data( + np_slice(ckpt[f'{path}_embed_pair_offsets_1']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._embed_pair_distances.weight.set_data( + np_slice(ckpt[f'{path}_embed_pair_distances_1']['weights'].T, i, j, dtype=dtype)) + else: + atom_cross_att_encoder._embed_pair_offsets.weight.set_data( + np_slice(ckpt[f'{path}_embed_pair_offsets']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._embed_pair_distances.weight.set_data( + np_slice(ckpt[f'{path}_embed_pair_distances']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._embed_pair_offsets_valid.weight.set_data( + np_slice(ckpt[f'{path}_embed_pair_offsets_valid']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._pair_mlp_1.weight.set_data( + np_slice(ckpt[f'{path}_pair_mlp_1']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._pair_mlp_2.weight.set_data( + np_slice(ckpt[f'{path}_pair_mlp_2']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._pair_mlp_3.weight.set_data( + np_slice(ckpt[f'{path}_pair_mlp_3']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._project_atom_features_for_aggr.weight.set_data( + np_slice(ckpt[f'{path}_project_atom_features_for_aggr']['weights'].T, i, j, dtype=dtype)) + load_cross_attention_transformer(atom_cross_att_encoder._atom_transformer_encoder, + f'{path}_atom_transformer_encoder', ckpt, + f"{last_name}_atom_transformer_encoder", i, j, dtype=dtype) + + +def load_atom_cross_decoder(atom_cross_att_decoder, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + atom_cross_att_decoder._project_token_features_for_broadcast.weight.set_data( + np_slice(ckpt[f'{path}_project_token_features_for_broadcast']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_decoder._atom_features_layer_norm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}_atom_features_layer_norm']['scale'], i, j, dtype=ms.float32)) + atom_cross_att_decoder._atom_features_to_position_update.weight.set_data( + np_slice(ckpt[f'{path}_atom_features_to_position_update']['weights'].T, i, j, dtype=dtype)) + load_cross_attention_transformer(atom_cross_att_decoder._atom_transformer_decoder, + f'{path}_atom_transformer_decoder', ckpt, + last_name='diffusion_atom_transformer_decoder', i=i, j=j, dtype=dtype) + + +def load_diffusion_head(diffusion_head, path, ckpt, i=None, j=None, dtype=ms.float32): + 
diffusion_head.pair_cond_initial_norm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/pair_cond_initial_norm']['scale'], i, j, dtype=ms.float32)) + diffusion_head.pair_cond_initial_projection.weight.set_data( + np_slice(ckpt[f'{path}/pair_cond_initial_projection']['weights'].T, i, j, dtype=ms.float32)) + load_transition_ms(diffusion_head.transition_block1, + f'{path}/pair_transition_0', ckpt, dtype=dtype) + load_transition_ms(diffusion_head.transition_block2, + f'{path}/pair_transition_1', ckpt, dtype=dtype) + diffusion_head.single_cond_initial_norm.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/single_cond_initial_norm']['scale'], i, j, dtype=ms.float32)) + diffusion_head.single_cond_initial_projection.weight.set_data( + np_slice(ckpt[f'{path}/single_cond_initial_projection']['weights'].T, i, j, dtype=dtype)) + diffusion_head.layer_norm_noise.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/noise_embedding_initial_norm']['scale'], i, j, dtype=ms.float32)) + diffusion_head.linear_noise.weight.set_data( + np_slice(ckpt[f'{path}/noise_embedding_initial_projection']['weights'].T, i, j, dtype=dtype)) + load_transition_ms(diffusion_head.single_transition1, + f'{path}/single_transition_0', ckpt, dtype=dtype) + load_transition_ms(diffusion_head.single_transition2, + f'{path}/single_transition_1', ckpt, dtype=dtype) + diffusion_head.layer_norm_act.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/single_cond_embedding_norm']['scale'], i, j, dtype=ms.float32)) + diffusion_head.linear_act.weight.set_data( + np_slice(ckpt[f'{path}/single_cond_embedding_projection']['weights'].T, i, j, dtype=dtype)) + diffusion_head.layer_norm_out.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/output_norm']['scale'], i, j, dtype=ms.float32)) + load_atom_cross_encoder(diffusion_head.atom_cross_att_encoder, f'{path}/diffusion', ckpt, + last_name="diffusion", dtype=dtype) + load_transformer_ms(diffusion_head.transformer, path + + '/transformer', ckpt, dtype=dtype) + load_atom_cross_decoder( + diffusion_head.atom_cross_att_decoder, f'{path}/diffusion', ckpt, dtype=dtype) + + +def load_confidence_head(confidence_head, path, ckpt, i=None, j=None, dtype=ms.float32): + confidence_head.left_target_feat_project.weight.set_data( + np_slice(ckpt[f'{path}/~_embed_features/left_target_feat_project']['weights'].T, i, j, dtype=dtype)) + confidence_head.right_target_feat_project.weight.set_data( + np_slice(ckpt[f'{path}/~_embed_features/right_target_feat_project']['weights'].T, i, j, dtype=dtype)) + confidence_head.distogram_feat_project.weight.set_data( + np_slice(ckpt[f'{path}/~_embed_features/distogram_feat_project']['weights'].T, i, j, dtype=dtype)) + for ii in range(confidence_head.config.pairformer.num_layer): + confidence_pairformer_path = path + \ + f'/__layer_stack_no_per_layer/confidence_pairformer' + load_pair_former(confidence_head.pairformer_block[ii], confidence_pairformer_path, + ckpt, ii, dtype=dtype) + confidence_head.left_half_distance_logits.weight.set_data( + np_slice(ckpt[f'{path}/left_half_distance_logits']['weights'].T, i, j, dtype=ms.float32)) + confidence_head.logits_ln.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/logits_ln']['scale'], i, j, dtype=ms.float32)) + confidence_head.logits_ln.layernorm.beta.set_data( + np_slice(ckpt[f'{path}/logits_ln']['offset'], i, j, dtype=ms.float32)) + confidence_head.pae_logits.weight.set_data( + np_slice(ckpt[f'{path}/pae_logits']['weights'].T, i, j, dtype=ms.float32)) + confidence_head.pae_logits_ln.layernorm.gamma.set_data( + 
np_slice(ckpt[f'{path}/pae_logits_ln']['scale'], i, j, dtype=ms.float32)) + confidence_head.pae_logits_ln.layernorm.beta.set_data( + np_slice(ckpt[f'{path}/pae_logits_ln']['offset'], i, j, dtype=ms.float32)) + confidence_head.plddt_logits.weight.set_data( + np_slice(ckpt[f'{path}/plddt_logits']['weights'], i, j, dtype=ms.float32)) + confidence_head.plddt_logits_ln.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/plddt_logits_ln']['scale'], i, j, dtype=ms.float32)) + confidence_head.plddt_logits_ln.layernorm.beta.set_data( + np_slice(ckpt[f'{path}/plddt_logits_ln']['offset'], i, j, dtype=ms.float32)) + confidence_head.experimentally_resolved_logits.weight.set_data( + np_slice(ckpt[f'{path}/experimentally_resolved_logits']['weights'], i, j, dtype=ms.float32)) + confidence_head.experimentally_resolved_ln.layernorm.gamma.set_data( + np_slice(ckpt[f'{path}/experimentally_resolved_ln']['scale'], i, j, dtype=ms.float32)) + confidence_head.experimentally_resolved_ln.layernorm.beta.set_data( + np_slice(ckpt[f'{path}/experimentally_resolved_ln']['offset'], i, j, dtype=ms.float32)) + + +def load_diffuser(diffuser, ckpt_dir, dtype=ms.bfloat16): + path = 'diffuser' + ckpt = get_model_af3_params(pathlib.Path(ckpt_dir)) + load_evoformer(diffuser.embedding_module, path + + '/evoformer', ckpt, dtype=dtype) + load_distogram_head(diffuser.distogram_head, path + + '/distogram_head', ckpt, dtype=ms.float32) + load_atom_cross_encoder(diffuser.create_target_feat_embedding.atom_cross_att_encoder, + f'{path}/evoformer_conditioning', ckpt, + last_name='evoformer_conditioning', dtype=ms.float32) + load_diffusion_head(diffuser.diffusion_module, path + + '/~/diffusion_head', ckpt, dtype=ms.float32) + load_confidence_head(diffuser.confidence_head, path + + '/confidence_head', ckpt, dtype=ms.float32) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/model.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/model.py new file mode 100644 index 0000000000000000000000000000000000000000..312c9613ab69e2fa282c3c62818a4f58bbc4a25f --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/model.py @@ -0,0 +1,758 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +from dataclasses import dataclass +import random +import concurrent +import functools +from absl import logging +import numpy as np +import mindspore as ms +from mindspore import ops, nn +from alphafold3.constants import residue_names +from alphafold3.model import base_config +from alphafold3.model import confidences +from alphafold3.model import model_config +from alphafold3.model.atom_layout import atom_layout +from alphafold3.model.components import base_model +from alphafold3.model.components import base_modules as bm +from alphafold3.model.diffusion import atom_cross_attention +from alphafold3.model.diffusion import confidence_head +from alphafold3.model.diffusion import diffusion_head +from alphafold3.model.diffusion import distogram_head +from alphafold3.model.diffusion import featurization +from alphafold3.model.diffusion import modules +from alphafold3.model.diffusion import template_modules +from alphafold3.structure import mmcif + + +def get_predicted_structure(result, batch): + """Creates the predicted structure and ion preditions. + + Args: + result: model output in a model specific layout + batch: model input batch + + Returns: + Predicted structure. + """ + model_output_coords = result['diffusion_samples']['atom_positions'] + + # Rearrange model output coordinates to the flat output layout. + model_output_to_flat = atom_layout.compute_gather_idxs( + source_layout=batch.convert_model_output.token_atoms_layout, + target_layout=batch.convert_model_output.flat_output_layout, + ) + pred_flat_atom_coords = atom_layout.convert( + gather_info=model_output_to_flat, + arr=model_output_coords.asnumpy(), + layout_axes=(-3, -2), + ) + + predicted_lddt = result.get('predicted_lddt') + + if predicted_lddt is not None: + pred_flat_b_factors = atom_layout.convert( + gather_info=model_output_to_flat, + arr=predicted_lddt.asnumpy(), + layout_axes=(-2, -1), + ) + else: + # Handle models which don't have predicted_lddt outputs. + pred_flat_b_factors = np.zeros(pred_flat_atom_coords.shape[:-1]) + + (missing_atoms_indices,) = np.nonzero( + model_output_to_flat.gather_mask == 0) + if missing_atoms_indices.shape[0] > 0: + missing_atoms_flat_layout = batch.convert_model_output.flat_output_layout[ + missing_atoms_indices + ] + missing_atoms_uids = list( + zip( + missing_atoms_flat_layout.chain_id, + missing_atoms_flat_layout.res_id, + missing_atoms_flat_layout.res_name, + missing_atoms_flat_layout.atom_name, + ) + ) + logging.warning( + 'Target %s: warning: %s atoms were not predicted by the ' + 'model, setting their coordinates to (0, 0, 0). ' + 'Missing atoms: %s', + batch.convert_model_output.empty_output_struc.name, + missing_atoms_indices.shape[0], + missing_atoms_uids, + ) + + # Put them into a structure + pred_struc = batch.convert_model_output.empty_output_struc + pred_struc = pred_struc.copy_and_update_atoms( + atom_x=pred_flat_atom_coords[..., 0], + atom_y=pred_flat_atom_coords[..., 1], + atom_z=pred_flat_atom_coords[..., 2], + atom_b_factor=pred_flat_b_factors, + # Always 1.0. + atom_occupancy=np.ones(pred_flat_atom_coords.shape[:-1]), + ) + # Set manually/differently when adding metadata. 
+ pred_struc = pred_struc.copy_and_update_globals(release_date=None) + return pred_struc + + +class CreateTargetFeatEmbedding(nn.Cell): + """ + A class that creates target feature embeddings by combining raw features with cross-attention encoded features. + + Args: + config (Config): Configuration object containing parameters for the target feature embedding. + global_config (GlobalConfig): Global configuration object. + + Inputs: + - **batch** (dict) - Dictionary containing batch features. + + Outputs: + - **target_feat** (Tensor) - Tensor of target feature embeddings. + """ + + def __init__(self, config, global_config, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.dtype = dtype + self.atom_cross_att_encoder = atom_cross_attention.AtomCrossAttEncoder( + self.config.per_atom_conditioning, self.global_config, '', with_cond=False, dtype=dtype + ) + + def construct(self, batch): + target_feat = featurization.create_target_feat( + batch, + append_per_atom_features=False, + dtype=ms.float32 + ).astype(self.dtype) + enc = self.atom_cross_att_encoder( + token_atoms_act=None, + trunk_single_cond=None, + trunk_pair_cond=None, + batch=batch, + ) + target_feat = ops.concat( + [target_feat, enc.token_act.astype(self.dtype)], axis=-1) + return target_feat + + +def _compute_ptm(result, num_tokens, asym_id, pae_single_mask, interface): + """Computes the pTM metrics from PAE.""" + return np.stack( + [ + confidences.predicted_tm_score( + tm_adjusted_pae=tm_adjusted_pae[:num_tokens, :num_tokens].asnumpy( + ), + asym_id=asym_id.asnumpy(), + pair_mask=pae_single_mask[:num_tokens, :num_tokens], + interface=interface, + ) + for tm_adjusted_pae in result['tmscore_adjusted_pae_global'] + ], + axis=0, + ) + + +def _compute_chain_pair_iptm( + num_tokens, + asym_ids, + mask, + tm_adjusted_pae): + """Computes the chain pair ipTM metrics from PAE.""" + return np.stack( + [ + confidences.chain_pairwise_predicted_tm_scores( + tm_adjusted_pae=sample_tm_adjusted_pae[:num_tokens], + asym_id=asym_ids[:num_tokens], + pair_mask=mask[:num_tokens, :num_tokens], + ) + for sample_tm_adjusted_pae in tm_adjusted_pae + ], + axis=0, + ) + + +class Diffuser(nn.Cell): + """ + Diffuser class for processing and generating diffusion samples, confidence scores, and distanceograms. + + Args: + config (Diffuser.Config): Configuration object containing parameters for the diffuser. + in_channel (int): Number of input channels. + feat_shape (tuple): Shape of the feature tensor. + act_shape (tuple): Shape of the activation tensor. + pair_shape (tuple): Shape of the pair tensor. + single_shape (tuple): Shape of the single tensor. + atom_shape (tuple): Shape of the atom tensor. + out_channel (int): Number of output channels. + num_templates (int): Number of templates. + + Inputs: + - **batch** (dict): Dictionary containing batch data. + - **key** (int): Random key generator. + + Outputs: + - **result** (dict): Dictionary containing diffusion samples, distanceogram, and confidence outputs. 
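+
+    Notes:
+        - The trunk embeddings are refined over `num_recycles` + 1 passes of
+          the Evoformer before diffusion sampling, and the confidence head is
+          run once per diffusion sample.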
+ """ + @dataclass + class HeadsConfig(base_config.BaseConfig): + diffusion: diffusion_head.DiffusionHead.Config = base_config.autocreate() + confidence: confidence_head.ConfidenceHead.Config = base_config.autocreate() + distogram: distogram_head.DistogramHead.Config = base_config.autocreate() + + @dataclass + class Config(base_config.BaseConfig): + evoformer: 'Evoformer.Config' = base_config.autocreate() + global_config: model_config.GlobalConfig = base_config.autocreate() + heads: 'Diffuser.HeadsConfig' = base_config.autocreate() + num_recycles: int = 10 + return_embeddings: bool = False + + def __init__(self, config, in_channel, feat_shape, act_shape, pair_shape, single_shape, atom_shape, + out_channel, num_templates, dtype=ms.float32, name="model"): + super().__init__(auto_prefix=True) + self.config = config + self.global_config = config.global_config + self.dtype = dtype + self.diffusion_module = diffusion_head.DiffusionHead( + self.config.heads.diffusion, self.global_config, pair_shape, dtype=ms.float32 + ) + self.embedding_module = Evoformer(self.config.evoformer, self.global_config, + feat_shape, act_shape, pair_shape, single_shape, num_templates, dtype=dtype) + self.create_target_feat_embedding = CreateTargetFeatEmbedding( + self.embedding_module.config, self.global_config, dtype=ms.float32) + self.confidence_head = confidence_head.ConfidenceHead( + self.config.heads.confidence, self.global_config, + pair_shape, single_shape, atom_shape, feat_shape[-1], out_channel, dtype=dtype + ) + self.distogram_head = distogram_head.DistogramHead( + self.config.heads.distogram, self.global_config, pair_shape[-1], dtype=ms.float32 + ) + + def _sample_diffusion(self, batch, embeddings, sample_config, key, init_positions=None): + denoising_step = functools.partial( + self.diffusion_module, + batch=batch, + embeddings=embeddings, + use_conditioning=True, + ) + sample = diffusion_head.sample( + denoising_step=denoising_step, + batch=batch, + key=key+1, + config=sample_config, + init_positions=init_positions, + ) + return sample + + def construct(self, batch, key): + if key is None: + # generate a random number + key = int(np.random.randint(100)) + # batch = feat_batch.Batch.from_data_dict(batch) + target_feat = self.create_target_feat_embedding( + batch) + + def recycle_body(prev, key): + key, subkey = random.randint(0, 1e6), key + embeddings = self.embedding_module( + batch=batch, + prev=prev, + target_feat=target_feat, + key=subkey, + ) + embeddings['pair'] = embeddings['pair'] + embeddings['single'] = embeddings['single'] + return embeddings, key + + num_res = batch.num_res + embeddings = { + 'pair': ops.zeros( + [num_res, num_res, self.config.evoformer.pair_channel], + dtype=ms.float32, + ), + 'single': ops.zeros( + [num_res, self.config.evoformer.seq_channel], dtype=ms.float32 + ), + 'target_feat': target_feat, + } + num_iter = self.config.num_recycles + 1 + for _ in range(num_iter): + embeddings, _ = recycle_body(embeddings, key) + + samples = self._sample_diffusion( + batch, + embeddings, + sample_config=self.config.heads.diffusion.eval, + key=key + ) + confidence_output = [] + for i in range(samples['atom_positions'].shape[0]): + confidence_output.append(self.confidence_head( + dense_atom_positions=samples['atom_positions'][i], + embeddings=embeddings, + seq_mask=batch.token_features.mask, + token_atoms_to_pseudo_beta=batch.pseudo_beta_info.token_atoms_to_pseudo_beta, + asym_id=batch.token_features.asym_id, + )) + for key in confidence_output[0].keys(): + confidence_output[0][key] = 
ops.stack( + [value[key] for value in confidence_output]) + confidence_output = confidence_output[0] + distogram = self.distogram_head(batch, embeddings) + output = { + 'diffusion_samples': samples, + 'distogram': distogram, + **confidence_output, + } + if self.config.return_embeddings: + output['single_embeddings'] = embeddings['single'] + output['pair_embeddings'] = embeddings['pair'] + return output + + @classmethod + def get_inference_result(cls, batch, result, target_name,): + """Get the predicted structure, scalars, and arrays for inference. + + This function also computes any inference-time quantities, which are not a + part of the forward-pass, e.g. additional confidence scores. Note that this + function is not serialized, so it should be slim if possible. + + Args: + batch: data batch used for model inference, incl. TPU invalid types. + result: output dict from the model's forward pass. + target_name: target name to be saved within structure. + + Yields: + inference_result: dataclass object that contains a predicted structure, + important inference-time scalars and arrays, as well as a slightly trimmed + dictionary of raw model result from the forward pass (for debugging). + """ + del target_name + # Retrieve structure and construct a predicted structure. + pred_structure = get_predicted_structure(result=result, batch=batch) + num_tokens = batch.token_features.seq_length.item() + pae_single_mask = np.tile( + batch.frames.mask[:, None], + [1, batch.frames.mask.shape[0]], + ) + ptm = _compute_ptm( + result=result, + num_tokens=num_tokens, + asym_id=batch.token_features.asym_id[:num_tokens], + pae_single_mask=pae_single_mask, + interface=False, + ) + iptm = _compute_ptm( + result=result, + num_tokens=num_tokens, + asym_id=batch.token_features.asym_id[:num_tokens], + pae_single_mask=pae_single_mask, + interface=True, + ) + ptm_iptm_average = 0.8 * iptm + 0.2 * ptm + + asym_ids = batch.token_features.asym_id[:num_tokens].asnumpy() + chain_ids = [mmcif.int_id_to_str_id(asym_id) for asym_id in asym_ids] + res_ids = batch.token_features.residue_index[:num_tokens] + + if len(np.unique(asym_ids)) > 1: + # There is more than one chain, hence interface pTM (i.e. ipTM) defined, + # so use it. + ranking_confidence = ptm_iptm_average + else: + # There is only one chain, hence ipTM=NaN, so use just pTM. + ranking_confidence = ptm + + contact_probs = result['distogram']['contact_probs'].astype(ms.float32) + # Compute PAE related summaries. 
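+        # chain_pair_pae / chain_pair_pde below reduce the full token-by-token
+        # error maps to per-chain-pair summaries used in the metadata.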
+ _, chain_pair_pae_min, _ = confidences.chain_pair_pae( + num_tokens=num_tokens, + asym_ids=batch.token_features.asym_id.asnumpy(), + full_pae=result['full_pae'].asnumpy(), + mask=pae_single_mask, + ) + chain_pair_pde_mean, chain_pair_pde_min = confidences.chain_pair_pde( + num_tokens=num_tokens, + asym_ids=batch.token_features.asym_id.asnumpy(), + full_pde=result['full_pde'].asnumpy(), + ) + intra_chain_single_pde, cross_chain_single_pde, _ = confidences.pde_single( + num_tokens, + batch.token_features.asym_id.asnumpy(), + result['full_pde'].asnumpy(), + contact_probs.asnumpy(), + ) + pae_metrics = confidences.pae_metrics( + num_tokens=num_tokens, + asym_ids=batch.token_features.asym_id.asnumpy(), + full_pae=result['full_pae'].asnumpy(), + mask=pae_single_mask, + contact_probs=contact_probs.asnumpy(), + tm_adjusted_pae=result['tmscore_adjusted_pae_interface'].asnumpy(), + ) + ranking_confidence_pae = confidences.rank_metric( + result['full_pae'].asnumpy(), + contact_probs.asnumpy() * batch.frames.mask[:, None].astype(float), + ) + chain_pair_iptm = _compute_chain_pair_iptm( + num_tokens=num_tokens, + asym_ids=batch.token_features.asym_id.asnumpy(), + mask=pae_single_mask, + tm_adjusted_pae=result['tmscore_adjusted_pae_interface'].asnumpy(), + ) + # iptm_ichain is a vector of per-chain ptm values. iptm_ichain[0], + # for example, is just the zeroth diagonal entry of the chain pair iptm + # matrix: + # [[x, , ], + # [ , , ], + # [ , , ]]] + iptm_ichain = chain_pair_iptm.diagonal(axis1=-2, axis2=-1) + # iptm_xchain is a vector of cross-chain interactions for each chain. + # iptm_xchain[0], for example, is an average of chain 0's interactions with + # other chains: + # [[ ,x,x], + # [x, , ], + # [x, , ]]] + iptm_xchain = confidences.get_iptm_xchain(chain_pair_iptm) + + predicted_distance_errors = result['average_pde'] + + # Computing solvent accessible area with dssp can be slow for large + # structures with lots of chains, so we parallelize the call. 
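+        # The per-sample has_clash / fraction_disordered values feed
+        # confidences.get_ranking_score for each diffusion sample below.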
+ pred_structures = pred_structure.unstack() + num_workers = len(pred_structures) + with concurrent.futures.ThreadPoolExecutor( + max_workers=num_workers + ) as executor: + has_clash = list(executor.map( + confidences.has_clash, pred_structures)) + fraction_disordered = list( + executor.map(confidences.fraction_disordered, pred_structures) + ) + for idx, pred_structure in enumerate(pred_structures): + ranking_score = confidences.get_ranking_score( + ptm=ptm[idx], + iptm=iptm[idx], + fraction_disordered_=fraction_disordered[idx], + has_clash_=has_clash[idx], + ) + print(f"####### result {idx} ######") + print(f"####### ranking_score {ranking_score} ######") + print(f"####### predicted_tm_score {ptm[idx]} ######") + print(f"####### interface_predicted_tm_score {iptm[idx]} ######") + yield base_model.InferenceResult( + predicted_structure=pred_structure, + numerical_data={ + 'full_pde': result['full_pde'][idx, :num_tokens, :num_tokens], + 'full_pae': result['full_pae'][idx, :num_tokens, :num_tokens], + 'contact_probs': contact_probs[:num_tokens, :num_tokens], + }, + metadata={ + 'predicted_distance_error': predicted_distance_errors[idx], + 'ranking_score': ranking_score, + 'fraction_disordered': fraction_disordered[idx], + 'has_clash': has_clash[idx], + 'predicted_tm_score': ptm[idx], + 'interface_predicted_tm_score': iptm[idx], + 'chain_pair_pde_mean': chain_pair_pde_mean[idx], + 'chain_pair_pde_min': chain_pair_pde_min[idx], + 'chain_pair_pae_min': chain_pair_pae_min[idx], + 'ptm': ptm[idx], + 'iptm': iptm[idx], + 'ptm_iptm_average': ptm_iptm_average[idx], + 'intra_chain_single_pde': intra_chain_single_pde[idx], + 'cross_chain_single_pde': cross_chain_single_pde[idx], + 'pae_ichain': pae_metrics['pae_ichain'][idx], + 'pae_xchain': pae_metrics['pae_xchain'][idx], + 'ranking_confidence': ranking_confidence[idx], + 'ranking_confidence_pae': ranking_confidence_pae[idx], + 'chain_pair_iptm': chain_pair_iptm[idx], + 'iptm_ichain': iptm_ichain[idx], + 'iptm_xchain': iptm_xchain[idx], + 'token_chain_ids': chain_ids, + 'token_res_ids': res_ids, + }, + ) + + +class Evoformer(nn.Cell): + """ + Evoformer class for generating 'single' and 'pair' embeddings in protein structure prediction. + + Args: + config (Evoformer.Config): Configuration object defining the parameters for the Evoformer module. + global_config (base_config.BaseConfig): Global configuration object containing general settings. + feat_shape (tuple): Shape of the feature tensor. + act_shape (tuple): Shape of the activation tensor. + pair_shape (tuple): Shape of the pair tensor. + single_shape (tuple): Shape of the single tensor. + num_templates (int): Number of templates used in the model. + + Inputs: + - **batch** (dict): Dictionary containing batch data including token features, MSA, and other + relevant information. + - **prev** (dict): Dictionary containing previous embeddings for 'single' and 'pair' activations. + - **target_feat** (Tensor): Target feature tensor used for generating embeddings. + - **key** (int): Random key for reproducibility. + + Outputs: + - **output** (dict): Dictionary containing the generated embeddings: + - **single** (Tensor): Single residue embeddings. + - **pair** (Tensor): Pairwise residue embeddings. + - **target_feat** (Tensor): Target feature tensor. + + Notes: + - The class processes input data through multiple modules including position encoding, bond embedding, + template embedding, MSA processing, and Pairformer iterations. 
+ - The `construct` method iteratively processes the input data to generate rich embeddings for + downstream tasks in protein structure prediction. + """ + @dataclass + # pytype: disable=invalid-function-definition + class PairformerConfig(modules.PairFormerIteration.Config): + block_remat: bool = False + remat_block_size: int = 8 + + @dataclass + class Config(base_config.BaseConfig): + """Configuration for Evoformer.""" + + max_relative_chain: int = 2 + msa_channel: int = 64 + seq_channel: int = 384 + max_relative_idx: int = 32 + num_msa: int = 1024 + pair_channel: int = 128 + pairformer: 'Evoformer.PairformerConfig' = base_config.autocreate( + single_transition=base_config.autocreate(), + single_attention=base_config.autocreate(), + num_layer=48, + ) + per_atom_conditioning: atom_cross_attention.AtomCrossAttEncoderConfig = ( + base_config.autocreate( + per_token_channels=384, + per_atom_channels=128, + atom_transformer=base_config.autocreate( + num_intermediate_factor=2, + num_blocks=3, + ), + per_atom_pair_channels=16, + ) + ) + template: template_modules.TemplateEmbedding.Config = ( + base_config.autocreate() + ) + msa_stack: modules.EvoformerIteration.Config = base_config.autocreate() + + def __init__(self, config, global_config, feat_shape, act_shape, pair_shape, single_shape, + num_templates, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + in_channel = feat_shape[-1] + position_activations_in = 4 * self.config.max_relative_idx + \ + 4 + 2 * self.config.max_relative_chain + 2 + 1 + self.position_activations = bm.CustomDense( + position_activations_in, self.config.pair_channel, ndim=3, dtype=dtype) + self.left_single = bm.CustomDense( + in_channel, self.config.pair_channel, ndim=2, dtype=dtype) + self.right_single = bm.CustomDense( + in_channel, self.config.pair_channel, ndim=2, dtype=dtype) + self.bond_embedding = bm.CustomDense( + 1, self.config.pair_channel, ndim=3, dtype=dtype) + self.template_module = template_modules.TemplateEmbedding( + self.config.template, self.global_config, num_templates, act_shape, dtype=dtype + ) + self.msa_activations = bm.CustomDense( + residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP + 3, self.config.msa_channel, ndim=3, dtype=dtype) + self.extra_msa_target_feat = bm.CustomDense( + in_channel, self.config.msa_channel, ndim=2, dtype=dtype) + evofromer_act_shape = (self.config.num_msa, + act_shape[1], self.config.msa_channel) + self.evoformer_stack = nn.CellList( + [ + modules.EvoformerIteration( + self.config.msa_stack, self.global_config, evofromer_act_shape, pair_shape, dtype=dtype + ) for _ in range(self.config.msa_stack.num_layer) + ] + ) + self.prev_embedding = bm.CustomDense( + pair_shape[-1], pair_shape[-1], ndim=3, dtype=dtype) + self.prev_embedding_layer_norm = bm.LayerNorm( + pair_shape, dtype=ms.float32) + self.single_activations = bm.CustomDense( + in_channel, self.config.seq_channel, ndim=2, dtype=dtype) + self.prev_single_embedding = bm.CustomDense( + self.config.seq_channel, self.config.seq_channel, ndim=2, dtype=dtype) + self.prev_single_embedding_layer_norm = bm.LayerNorm(act_shape[:-1] + + (self.config.seq_channel,), dtype=ms.float32) + self.pairformer_stack = nn.CellList( + [ + modules.PairFormerIteration( + self.config.pairformer, self.global_config, pair_shape, single_shape, with_single=True, dtype=dtype + ) for _ in range(self.config.pairformer.num_layer) + ] + ) + + def _relative_encoding(self, batch, pair_activations): + rel_feat = featurization.create_relative_encoding( + 
batch.token_features, + self.config.max_relative_idx, + self.config.max_relative_chain, + ) + rel_feat = rel_feat.astype(pair_activations.dtype) + pair_activations += self.position_activations(rel_feat) + return pair_activations + + def _seq_pair_embedding(self, token_features, target_feat): + left_single = self.left_single(target_feat)[:, None] + right_single = self.right_single(target_feat)[None] + dtype = left_single.dtype + pair_activations = left_single + right_single + num_residues = pair_activations.shape[0] + mask = token_features.mask + pair_mask = (mask[:, None] * mask[None, :]).astype(dtype) + assert pair_mask.shape == (num_residues, num_residues) + return pair_activations, pair_mask + + def _embed_bonds(self, batch, pair_activations): + """Embeds bond features and merges into pair activations.""" + # Construct contact matrix. + num_tokens = batch.token_features.token_index.shape[0] + contact_matrix = ops.zeros((num_tokens, num_tokens)) + + tokens_to_polymer_ligand_bonds = ( + batch.polymer_ligand_bond_info.tokens_to_polymer_ligand_bonds + ) + gather_idxs_polymer_ligand = tokens_to_polymer_ligand_bonds.gather_idxs + gather_mask_polymer_ligand = ( + tokens_to_polymer_ligand_bonds.gather_mask.prod(dim=1).astype( + gather_idxs_polymer_ligand.dtype + )[:, None] + ) + # If valid mask then it will be all 1's, so idxs should be unchanged. + gather_idxs_polymer_ligand = ( + gather_idxs_polymer_ligand * gather_mask_polymer_ligand + ) + tokens_to_ligand_ligand_bonds = ( + batch.ligand_ligand_bond_info.tokens_to_ligand_ligand_bonds + ) + gather_idxs_ligand_ligand = tokens_to_ligand_ligand_bonds.gather_idxs + gather_mask_ligand_ligand = tokens_to_ligand_ligand_bonds.gather_mask.prod( + dim=1 + ).astype(gather_idxs_ligand_ligand.dtype)[:, None] + gather_idxs_ligand_ligand = ( + gather_idxs_ligand_ligand * gather_mask_ligand_ligand + ) + gather_idxs = ops.concat( + [gather_idxs_polymer_ligand, gather_idxs_ligand_ligand] + ) + contact_matrix[gather_idxs[:, 0], gather_idxs[:, 1]] = 1.0 + contact_matrix[0, 0] = 0.0 + + bonds_act = self.bond_embedding( + contact_matrix[:, :, None].astype(pair_activations.dtype) + ) + return pair_activations + bonds_act + + def _embed_template_pair(self, batch, pair_activations, pair_mask, key): + """Embeds Templates and merges into pair activations.""" + dtype = pair_activations.dtype + key, subkey = key, key + 1 + + templates = batch.templates + asym_id = batch.token_features.asym_id + # Construct a mask such that only intra-chain template features are + # computed, since all templates are for each chain individually. + multichain_mask = (asym_id[:, None] == asym_id[None, :]).astype(dtype) + template_fn = functools.partial(self.template_module, key=subkey) + template_act = template_fn( + query_embedding=pair_activations, + templates=templates, + multichain_mask_2d=multichain_mask, + padding_mask_2d=pair_mask, + ) + return pair_activations + template_act, key + + def _embed_process_msa(self, msa_batch, pair_activations, pair_mask, key, target_feat): + """Processes MSA and returns updated pair activations.""" + dtype = pair_activations.dtype + msa_batch = featurization.truncate_msa_batch( + msa_batch, self.config.num_msa) + msa_feat = featurization.create_msa_feat(msa_batch).astype(dtype) + + msa_activations = self.msa_activations(msa_feat) + msa_activations += self.extra_msa_target_feat(target_feat)[None] + msa_mask = msa_batch.mask.astype(dtype) + # Evoformer MSA stack. 
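+        # Each EvoformerIteration updates both the MSA and pair activations;
+        # only the refined pair activations are returned to the trunk.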
+ evoformer_input = {'msa': msa_activations, 'pair': pair_activations} + mask = {'msa': msa_mask, 'pair': pair_mask} + for i in range(self.config.msa_stack.num_layer): + evoformer_input = self.evoformer_stack[i](evoformer_input, mask) + + return evoformer_input['pair'], key + + def construct(self, batch, prev, target_feat, key): + + dtype = (ms.bfloat16 if self.global_config.bfloat16 == + 'all' else ms.float32) + pair_activations, pair_mask = self._seq_pair_embedding( + batch.token_features, target_feat + ) + pair_activations += self.prev_embedding( + self.prev_embedding_layer_norm( + prev['pair'] + ).astype(pair_activations.dtype) + ) + pair_activations = self._relative_encoding(batch, pair_activations) + pair_activations = self._embed_bonds( + batch=batch, pair_activations=pair_activations + ) + pair_activations, key = self._embed_template_pair( + batch=batch, + pair_activations=pair_activations, + pair_mask=pair_mask, + key=key, + ) + pair_activations, key = self._embed_process_msa( + msa_batch=batch.msa, + pair_activations=pair_activations, + pair_mask=pair_mask, + key=key, + target_feat=target_feat, + ) + del key # Unused after this point. + single_activations = self.single_activations(target_feat) + + single_activations += self.prev_single_embedding( + self.prev_single_embedding_layer_norm( + prev['single'].astype(single_activations.dtype) + ) + ) + for i in range(self.config.pairformer.num_layer): + pair_activations, single_activations = self.pairformer_stack[i]( + pair_activations, pair_mask, single_act=single_activations, + seq_mask=batch.token_features.mask.astype(dtype) + ) + output = { + 'single': single_activations, + 'pair': pair_activations, + 'target_feat': target_feat, + } + + return output diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/modules.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba957636ecd456f42ec20ceefbe84b829986b9d --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/modules.py @@ -0,0 +1,562 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""modules for the Diffuser model.""" + +from dataclasses import dataclass +from typing import Literal + +import mindspore as ms +from mindspore import nn, ops, Tensor, mint +from mindchemistry.e3.utils import Ncon +from alphafold3.model import base_config +from alphafold3.utils.attention import attention +from alphafold3.utils.gated_linear_unit.gated_linear_unit import gated_linear_unit +from alphafold3.model.components import base_modules as bm +from alphafold3.model.components import mapping +from alphafold3.model.diffusion import diffusion_transformer +from alphafold3.model.diffusion.triangle import TriangleMultiplication as Triangle +from alphafold3.model.diffusion.triangle import OuterProductMean as ProductMean + + +def get_shard_size(num_residues, shard_spec): + shard_size = shard_spec[0][-1] + for num_residues_upper_bound, num_residues_shard_size in shard_spec: + shard_size = num_residues_shard_size + if ( + num_residues_upper_bound is None + or num_residues <= num_residues_upper_bound + ): + break + return shard_size + + +class TransitionBlock(nn.Cell): + """ + A transition block for transformer networks, implementing either a GLU-based or linear-based transformation. + + Args: + config (Config): Configuration object containing parameters for the transition block. + global_config (GlobalConfig): Global configuration object. + normalized_shape (tuple): Shape of the input tensor for normalization. + ndim (int): Number of dimensions of the input tensor. Default: ``3``. + + Inputs: + - **act** (Tensor) - Input activation tensor to be processed. + + Outputs: + - **output** (Tensor) - Output tensor after processing through the transition block. 
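+
+    Notes:
+        - The block applies LayerNorm, a SwiGLU-style gated linear unit
+          (silu(a) * b), and an output projection; when `use_glu_kernel` is
+          enabled the fused `gated_linear_unit` kernel replaces the plain
+          dense layer plus split.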
+ """ + @dataclass + class Config(base_config.BaseConfig): + num_intermediate_factor: int = 4 + use_glu_kernel: bool = True + + def __init__( + self, config, global_config, normalized_shape, ndim=3, dtype=ms.float32 + ): + super().__init__() + self.config = config + self.global_config = global_config + num_channels = normalized_shape[-1] + self.num_intermediate = int( + num_channels * self.config.num_intermediate_factor) + self.layernorm = bm.LayerNorm( + normalized_shape, name='input_layer_norm', dtype=ms.float32) + if self.config.use_glu_kernel: + self.glu_weight = bm.custom_initializer( + 'relu', (num_channels, 2 * self.num_intermediate), dtype=dtype) + self.glu_weight = ms.Parameter(Tensor(self.glu_weight).reshape( + num_channels, 2, self.num_intermediate)) + else: + self.linear = bm.CustomDense(num_channels, self.num_intermediate * 2, + weight_init='zeros', ndim=ndim, dtype=dtype) + self.linear.weight = bm.custom_initializer( + 'zeros', self.linear.weight.shape, dtype=dtype) + self.out_linear = bm.CustomDense(self.num_intermediate, num_channels, + weight_init=self.global_config.final_init, ndim=ndim, dtype=dtype) + + def construct(self, act, broadcast_dim=0): + act = self.layernorm(act) + if self.config.use_glu_kernel: + c = gated_linear_unit( + x=act, + weight=self.glu_weight, + implementation=None, + activation=mint.nn.functional.silu, + precision=None + ) + else: + act = self.linear(act) + a, b = mint.split(act, act.shape[-1]//2, axis=-1) + c = mint.nn.functional.silu(a) * b + return self.out_linear(c) + + +class MSAAttention(nn.Cell): + """ + Multi-Head Self-Attention (MSA) attention mechanism for processing sequence and pair data. + + Args: + config (Config): Configuration object containing parameters for the attention mechanism. + global_config (GlobalConfig): Global configuration object. + act_shape (tuple): Shape of the activation tensor. + pair_shape (tuple): Shape of the pair tensor. + + Inputs: + - **act** (Tensor) - Input activation tensor. + - **mask** (Tensor) - Mask tensor to prevent attention weights from focusing on invalid positions. + - **pair_act** (Tensor) - Pair activation tensor. + + Outputs: + - **output** (Tensor) - Output tensor after processing through the attention mechanism. 
+ """ + @dataclass + class Config(base_config.BaseConfig): + num_head: int = 8 + + def __init__(self, config, global_config, act_shape, pair_shape, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.actnorm = bm.LayerNorm(act_shape, dtype=ms.float32) + self.pairnorm = bm.LayerNorm(pair_shape, dtype=ms.float32) + num_channel = act_shape[-1] + value_dim = num_channel // self.config.num_head + self.pair_logits = bm.CustomDense(pair_shape[-1], self.config.num_head, use_bias=False, + weight_init='zeros', ndim=3, dtype=dtype) + self.v_projection = bm.CustomDense(num_channel, (self.config.num_head, value_dim), + use_bias=False, ndim=len(act_shape), dtype=dtype) + ncon_list1 = [-3, -2, 1] + ncon_list2 = [-1, 1, -3, -4] + self.ncon = Ncon([ncon_list1, ncon_list2]) + self.gating_query = bm.CustomDense( + num_channel, self.config.num_head * value_dim, weight_init='zeros', use_bias=False, ndim=3, dtype=dtype) + self.output_projection = bm.CustomDense(self.config.num_head * value_dim, num_channel, + weight_init=self.global_config.final_init, + use_bias=False, ndim=3, dtype=dtype) + + def construct(self, act, mask, pair_act): + act = self.actnorm(act) + pair_act = self.pairnorm(pair_act) + logits = self.pair_logits(pair_act).transpose([2, 0, 1]) + logits += 1e9 * (mint.max(mask, dim=0)[0] - 1.0) + weights = mint.softmax(logits, dim=-1) + v = self.v_projection(act) + v_avg = self.ncon([weights, v]) + v_avg = v_avg.reshape(v_avg.shape[:-2]+(-1,)) + gate_value = self.gating_query(act) + v_avg *= mint.sigmoid(gate_value) + out = self.output_projection(v_avg) + return out + + +class GridSelfAttention(nn.Cell): + """ + Self-attention mechanism that operates either per-sequence or per-residue. + + Args: + config (Config): Configuration object containing parameters for the attention mechanism. + global_config (GlobalConfig): Global configuration object. + transpose (bool): Whether to transpose the activation tensor during processing. + normalized_shape (tuple): Shape of the input tensor for normalization. + + Inputs: + - **act** (Tensor) - Input activation tensor. + - **pair_mask** (Tensor) - Mask tensor indicating valid regions in the input. + + Outputs: + - **output** (Tensor) - Output tensor after processing through the self-attention mechanism. 
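+
+    Note:
+        When ``transpose`` is True the activations are swapped between the
+        row and column axes before and after attention, so the same cell can
+        attend along either axis of the pair representation. The attention
+        bias is a per-head linear projection of the layer-normalized
+        activations.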
+ """ + @dataclass + class Config(base_config.BaseConfig): + num_head: int = 4 + + def __init__( + self, config, global_config, transpose, normalized_shape, dtype=ms.float32 + ): + super().__init__() + self.config = config + self.global_config = global_config + self.transpose = transpose + num_channels = normalized_shape[-1] + in_shape = normalized_shape[-1] + assert num_channels % self.config.num_head == 0 + qkv_dim = max(num_channels // self.config.num_head, 16) + qkv_shape = (self.config.num_head, qkv_dim) + self.q_projection = bm.CustomDense( + in_shape, qkv_shape, use_bias=False, ndim=3, dtype=dtype) + self.k_projection = bm.CustomDense( + in_shape, qkv_shape, use_bias=False, ndim=3, dtype=dtype) + self.v_projection = bm.CustomDense( + in_shape, qkv_shape, use_bias=False, ndim=3, dtype=dtype) + self.gating_query = bm.CustomDense( + num_channels, self.config.num_head * qkv_dim, weight_init='zeros', use_bias=False, ndim=3, dtype=dtype) + self.output_projection = bm.CustomDense(self.config.num_head * qkv_dim, num_channels, + weight_init=self.global_config.final_init, ndim=3, dtype=dtype) + self.act_norm = bm.LayerNorm(normalized_shape, dtype=ms.float32) + self.pair_bias_projection = bm.CustomDense( + num_channels, self.config.num_head, use_bias=False, weight_init='linear', ndim=3, dtype=dtype) + num_residues = normalized_shape[0] + self.chunk_size = get_shard_size( + num_residues, self.global_config.pair_attention_chunk_size + ) + + def _attention(self, act, mask, bias): + q = self.q_projection(act) + k = self.k_projection(act) + v = self.v_projection(act) + bias = ops.expand_dims(bias, 0) + weighted_avg = attention.dot_product_attention( + q, + k, + v, + mask=mask, + bias=bias, + logits_dtype=ms.float32, + precision=None, + implementation=self.global_config.flash_attention_implementation, + ) + weighted_avg = weighted_avg.reshape(weighted_avg.shape[:-2] + (-1,)) + gate_value = self.gating_query(act) + weighted_avg *= mint.sigmoid(gate_value) + return self.output_projection(weighted_avg) + + def construct(self, act, pair_mask): + """Builds a module. + + Arguments: + act: [num_seq, num_res, channels] activations tensor + pair_mask: [num_seq, num_res] mask of non-padded regions in the tensor. + Only used in inducing points attention currently. + + Returns: + Result of the self-attention operation. + """ + pair_mask = mint.swapaxes(pair_mask, -1, -2) + act = self.act_norm(act) + + non_batched_bias = self.pair_bias_projection(act) + non_batched_bias = non_batched_bias.transpose(2, 0, 1) + if self.transpose: + act = mint.swapaxes(act, -2, -3) + pair_mask = pair_mask[:, None, None, :].astype(ms.bool_) + act = self._attention(act, pair_mask, non_batched_bias) + if self.transpose: + act = mint.swapaxes(act, -2, -3) + return act + + +class TriangleMultiplication(nn.Cell): + """ + Implements triangle multiplication for tensor operations. + + Args: + config (Config): Configuration object specifying the equation and whether to use a GLU kernel. + global_config (GlobalConfig): Global configuration object. + in_channel (int): Number of input channels. + normalized_shape (tuple): Shape of the input tensor for normalization. + batch_size (int, optional): Batch size for processing. Default: ``None``. + + Inputs: + - **act** (Tensor) - Input activation tensor. + - **mask** (Tensor) - Mask tensor indicating valid regions in the input. + + Outputs: + - **out** (Tensor) - Output tensor after triangle multiplication. 
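+
+    Note:
+        This cell is a thin wrapper that delegates to the shared
+        ``TriangleMultiplication`` implementation in
+        ``alphafold3.model.diffusion.triangle``, which performs the masked,
+        gated projections and the pairwise contraction selected by
+        ``config.equation``.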
+ """ + @dataclass + class Config(base_config.BaseConfig): + equation: Literal['ikc,jkc->ijc', 'kjc,kic->ijc'] + use_glu_kernel: bool = True + + def __init__(self, config, global_config, in_channel, normalized_shape, batch_size=None, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.triangle_multi = Triangle( + self.config, + self.global_config, + num_intermediate_channel=in_channel, + equation=self.config.equation, + normalized_shape=normalized_shape, + batch_size=batch_size, + dtype=dtype) + + def construct(self, act, mask): + out = self.triangle_multi(act, mask) + return out + + +class OuterProductMean(nn.Cell): + """ + Implements the OuterProductMean operation for tensor computations. + + Args: + config (Config): Configuration object containing parameters for the operation. + global_config (GlobalConfig): Global configuration object. + num_output_channel (int): Number of output channels. + in_channel (int): Number of input channels. + + Inputs: + - **act** (Tensor) - Input activation tensor. + - **mask** (Tensor) - Mask tensor indicating valid regions in the input. + + Outputs: + - **out** (Tensor) - Output tensor after applying the outer product mean operation. + """ + @dataclass + class Config(base_config.BaseConfig): + chunk_size: int = 128 + num_outer_channel: int = 32 + + def __init__(self, config, global_config, num_output_channel, in_channel, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.num_output_channel = num_output_channel + self.outer_product_mean = ProductMean(self.config.num_outer_channel, + in_channel, + self.num_output_channel, + dtype=dtype) + + def construct(self, act, mask): + mask_norm = ops.expand_dims(mint.matmul(mask.T, mask), -1) + out = self.outer_product_mean(act, mask, mask_norm) + return out + + +class PairFormerIteration(nn.Cell): + """ + Single Iteration of PairFormer, which processes pairwise and single activations in a single iteration. + + Args: + config (PairFormerIteration.Config): Configuration for the PairFormerIteration module. + global_config: Global configuration for the model. + normalized_shape (tuple): Shape of the input tensor for normalization. + single_shape (tuple | None): Shape of the single activation tensor. Default: ``None``. + with_single (bool): Whether to include single activation processing. Default: ``False``. + + Inputs: + - **act** (Tensor) - Pairwise activations tensor. + - **pair_mask** (Tensor) - Padding mask for pairwise activations. + - **single_act** (Tensor | None) - Single activations tensor, optional. + - **seq_mask** (Tensor | None) - Sequence mask, optional. + + Outputs: + - **act** (Tensor) - Processed pairwise activations tensor. + - **single_act** (Tensor) - Processed single activations tensor (if `with_single` is True). 
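+
+    Note:
+        Each iteration applies residual updates in order: triangle
+        multiplication (outgoing, then incoming), grid self-attention over
+        rows and then columns, and a pair transition. When ``with_single`` is
+        True, pair-biased single attention and a single transition are applied
+        to the single activations as well.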
+ """ + @dataclass + class Config(base_config.BaseConfig): + """Config for PairFormerIteration.""" + num_layer: int = 1 + pair_attention: GridSelfAttention.Config = base_config.autocreate() + pair_transition: TransitionBlock.Config = base_config.autocreate() + single_attention: diffusion_transformer.SelfAttentionConfig | None = base_config.autocreate() + single_transition: TransitionBlock.Config | None = base_config.autocreate() + triangle_multiplication_incoming: TriangleMultiplication.Config = ( + base_config.autocreate(equation='kjc,kic->ijc') + ) + triangle_multiplication_outgoing: TriangleMultiplication.Config = ( + base_config.autocreate(equation='ikc,jkc->ijc') + ) + shard_transition_blocks: bool = True + + def __init__(self, config, global_config, normalized_shape, single_shape=None, with_single=False, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.with_single = with_single + num_channel = normalized_shape[-1] + self.triangle_multiplication1 = TriangleMultiplication( + self.config.triangle_multiplication_outgoing, + self.global_config, + num_channel, + normalized_shape, + dtype=dtype + ) + self.triangle_multiplication2 = TriangleMultiplication( + self.config.triangle_multiplication_incoming, + self.global_config, + num_channel, + normalized_shape, + dtype=dtype + ) + self.grid_self_attention1 = GridSelfAttention( + self.config.pair_attention, + self.global_config, + False, + normalized_shape, + dtype=dtype + ) + self.grid_self_attention2 = GridSelfAttention( + self.config.pair_attention, + self.global_config, + True, + normalized_shape, + dtype=dtype + ) + self.transition_block = TransitionBlock( + self.config.pair_transition, self.global_config, normalized_shape, dtype=dtype + ) + num_residues = normalized_shape[0] + if self.config.shard_transition_blocks: + self.transition_block = mapping.sharded_apply( + self.transition_block, + get_shard_size( + num_residues, self.global_config.pair_transition_shard_spec + ) + ) + if self.with_single: + assert self.config.single_attention is not None + self.single_pair_logits_projection = bm.CustomDense( + num_channel, self.config.single_attention.num_head, ndim=3, dtype=dtype + ) + self.single_pair_logits_norm = bm.LayerNorm(normalized_shape, dtype=ms.float32) + self.single_attention = diffusion_transformer.SelfAttention( + self.config.single_attention, self.global_config, + single_shape[-1], normalized_shape, with_single_cond=False, dtype=dtype) + self.single_transition = TransitionBlock( + self.config.single_transition, + self.global_config, + single_shape, + 2, + dtype=dtype + ) + + def construct(self, act, pair_mask, single_act=None, seq_mask=None): + act += self.triangle_multiplication1(act, pair_mask) + act += self.triangle_multiplication2(act, pair_mask) + act += self.grid_self_attention1(act, pair_mask) + act += self.grid_self_attention2(act, pair_mask) + act += self.transition_block(act) + if self.with_single: + norm_act = self.single_pair_logits_norm(act) + pair_logits = self.single_pair_logits_projection(norm_act) + pair_logits = pair_logits.transpose((2, 0, 1)) + single_act += self.single_attention( + single_act, seq_mask, None, pair_logits + ) + single_act += self.single_transition(single_act, + broadcast_dim=None) + return act, single_act + return act + + +class EvoformerIteration(nn.Cell): + """ + EvoformerIteration is a single iteration of the Evoformer main stack, which processes + activations and masks through a series of attention and transformation layers to + 
update the MSA (Multiple Sequence Alignment) and pair representations. + + Args: + config (EvoformerIteration.Config): Configuration for the EvoformerIteration. + global_config (base_config.BaseConfig): Global configuration for the model. + act_shape (tuple): Shape of the activation tensor. + pair_shape (tuple): Shape of the pair tensor. + + Inputs: + - **activations** (dict): A dictionary containing the MSA and pair activations. + - **masks** (dict): A dictionary containing the MSA and pair masks. + + Outputs: + - **activations** (dict): A dictionary containing the updated MSA and pair activations. + """ + @dataclass + class Config(base_config.BaseConfig): + """Configuration for EvoformerIteration.""" + + num_layer: int = 4 + msa_attention: MSAAttention.Config = base_config.autocreate() + outer_product_mean: OuterProductMean.Config = base_config.autocreate() + msa_transition: TransitionBlock.Config = base_config.autocreate() + pair_attention: GridSelfAttention.Config = base_config.autocreate() + pair_transition: TransitionBlock.Config = base_config.autocreate() + triangle_multiplication_incoming: TriangleMultiplication.Config = ( + base_config.autocreate(equation='kjc,kic->ijc') + ) + triangle_multiplication_outgoing: TriangleMultiplication.Config = ( + base_config.autocreate(equation='ikc,jkc->ijc') + ) + shard_transition_blocks: bool = False + + def __init__(self, config, global_config, act_shape, pair_shape, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + num_channel = pair_shape[-1] + self.outer_product_mean = OuterProductMean( + config=self.config.outer_product_mean, + global_config=self.global_config, + num_output_channel=num_channel, + in_channel=act_shape[-1], + dtype=dtype + ) + self.msa_attention = MSAAttention(self.config.msa_attention, + self.global_config, act_shape, pair_shape, dtype=dtype) + self.msa_transition = TransitionBlock( + self.config.msa_transition, self.global_config, act_shape, dtype=dtype + ) + self.triangle_multiplication1 = TriangleMultiplication( + self.config.triangle_multiplication_outgoing, + self.global_config, + num_channel, + pair_shape, + dtype=dtype + ) + self.triangle_multiplication2 = TriangleMultiplication( + self.config.triangle_multiplication_incoming, + self.global_config, + num_channel, + pair_shape, + dtype=dtype + ) + self.pair_attention1 = GridSelfAttention( + self.config.pair_attention, + self.global_config, + False, + pair_shape, + dtype=dtype + ) + self.pair_attention2 = GridSelfAttention( + self.config.pair_attention, + self.global_config, + True, + pair_shape, + dtype=dtype + ) + self.transition_block = TransitionBlock( + self.config.msa_transition, self.global_config, pair_shape, dtype=dtype + ) + num_residues = act_shape[0] + if self.config.shard_transition_blocks: + self.transition_block = mapping.sharded_apply( + self.transition_block, + get_shard_size( + num_residues, self.global_config.pair_transition_shard_spec + ) + ) + + def construct(self, activations, masks): + msa_act, pair_act = activations["msa"], activations["pair"] + msa_mask, pair_mask = masks['msa'], masks['pair'] + pair_act += self.outer_product_mean(msa_act, msa_mask) + msa_act += self.msa_attention(msa_act, msa_mask, pair_act) + msa_act += self.msa_transition(msa_act) + pair_act += self.triangle_multiplication1(pair_act, pair_mask) + pair_act += self.triangle_multiplication2(pair_act, pair_mask) + pair_act += self.pair_attention1(pair_act, pair_mask) + pair_act += self.pair_attention2(pair_act, pair_mask) + 
pair_act += self.transition_block(pair_act) + return {"msa": msa_act, "pair": pair_act} diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/random/bias.npy b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/random/bias.npy new file mode 100644 index 0000000000000000000000000000000000000000..c7cd7468f857d0a2849491f4a3de779471901d38 Binary files /dev/null and b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/random/bias.npy differ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/random/weight.npy b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/random/weight.npy new file mode 100644 index 0000000000000000000000000000000000000000..c595d2a5f23945d87a8589c43cb6879ee3a2de48 Binary files /dev/null and b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/random/weight.npy differ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/template_modules.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/template_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..1749b19b84d0eae0850dd91d60734ef6733b1071 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/template_modules.py @@ -0,0 +1,326 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""template modules""" +from dataclasses import dataclass +import mindspore as ms +from mindspore import nn, ops, Tensor, mint + +from alphafold3.model import base_config +from alphafold3.constants import residue_names +from alphafold3.utils import geometry +from alphafold3.model import protein_data_processing +from alphafold3.model.components import base_modules as bm +from alphafold3.model.diffusion import modules +from alphafold3.model.scoring import scoring + + +@dataclass +class DistogramFeaturesConfig(base_config.BaseConfig): + # The left edge of the first bin. + min_bin: float = 3.25 + # The left edge of the final bin. The final bin catches everything larger than + # `max_bin`. + max_bin: float = 50.75 + # The number of bins in the distogram. + num_bins: int = 39 + + +def dgram_from_positions(positions, config, dtype=ms.float32): + """Compute distogram from amino acid positions. + + Args: + positions: (num_res, 3) Position coordinates. + config: Distogram bin configuration. + + Returns: + Distogram with the specified number of bins. 
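+
+  Note:
+    Comparisons are made on squared distances, so for each residue pair at
+    most one bin is set to 1.0 (pairs closer than ``min_bin`` fall into no
+    bin, and pairs beyond ``max_bin`` fall into the final catch-all bin).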
+ """ + lower_breaks = mint.linspace( + config.min_bin, config.max_bin, config.num_bins) + lower_breaks = mint.square(lower_breaks) + upper_breaks = mint.concat( + [lower_breaks[1:], Tensor([1e8], dtype=ms.float32)], dim=-1) + dist2 = mint.sum(mint.square(ops.expand_dims(positions, axis=-2) + - ops.expand_dims(positions, axis=-3)), dim=-1, keepdim=True) + dgram = (dist2 > lower_breaks).astype(ms.float32) * \ + (dist2 < upper_breaks).astype(ms.float32) + return dgram + + +def slice_index(x, idx): + return ops.gather_d(x, 1, idx.reshape(-1, 1)).squeeze() + + +def make_backbone_rigid(positions, mask, group_indices,): + """Make backbone Rigid3Array and mask. + + Args: + positions: (num_res, num_atoms) of atom positions as Vec3Array. + mask: (num_res, num_atoms) for atom mask. + group_indices: (num_res, num_group, 3) for atom indices forming groups. + + Returns: + tuple of backbone Rigid3Array and mask (num_res,). + """ + backbone_indices = group_indices[:, 0] + + # main backbone frames differ in sidechain frame convention. + # for sidechain it's (C, CA, N), for backbone it's (N, CA, C) + # Hence using c, b, a, each of shape (num_res,). + c, b, a = [backbone_indices[..., i] for i in range(3)] + + rigid_mask = slice_index(mask, a) * \ + slice_index(mask, b) * slice_index(mask, c) + frame_positions = [] + for indices in [a, b, c]: + frame_positions.append(geometry.vector.tree_map( + lambda x, idx=indices: slice_index(x, idx), positions + )) + rotation = geometry.Rot3Array.from_two_vectors( + frame_positions[2] - frame_positions[1], + frame_positions[0] - frame_positions[1], + ) + rigid = geometry.Rigid3Array(rotation, frame_positions[1]) + return rigid, rigid_mask + + +class TemplateEmbedding(nn.Cell): + """ + Embed a set of templates. + + Args: + config (TemplateEmbedding.Config): Configuration for the template embedding. + global_config (base_config.BaseConfig): Global configuration for the model. + num_templates (int): Number of templates to process. + normalized_shape (tuple): Shape of the normalized input tensor. + num_atoms (int): Number of atoms per residue. Default: ``24``. + + Inputs: + - **query_embedding** (Tensor) - Query tensor of shape [num_res, num_res, num_channel]. + - **templates** (Templates) - Object containing template data. + - **padding_mask_2d** (Tensor) - Pair mask for attention operations of shape [num_res, num_res]. + - **multichain_mask_2d** (Tensor) - Pair mask for multichain operations of shape [num_res, num_res]. + - **key** (int) - Random key generator. + + Outputs: + - **embedding** (Tensor) - Output embedding tensor of shape [num_res, num_res, num_channels]. 
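+
+    Note:
+        Each template is embedded independently by ``SingleTemplateEmbedding``;
+        the per-template embeddings are summed, averaged over the number of
+        templates, passed through ReLU and projected back to the query
+        channel count.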
+ """ + @dataclass + class Config(base_config.BaseConfig): + num_channels: int = 64 + template_stack: modules.PairFormerIteration.Config = base_config.autocreate( + num_layer=2, + pair_transition=base_config.autocreate(num_intermediate_factor=2), + ) + dgram_features: DistogramFeaturesConfig = base_config.autocreate() + + def __init__(self, config, global_config, num_templates, normalized_shape, num_atoms=24, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.num_residues = normalized_shape[0] + self.num_templates = num_templates + self.query_num_channels = normalized_shape[2] + self.num_atoms = num_atoms + self.template_embedder = SingleTemplateEmbedding( + self.config, self.global_config, normalized_shape, dtype=dtype) + self.output_linear = bm.CustomDense( + self.config.num_channels, self.query_num_channels, ndim=3, dtype=dtype) + self.output_linear.weight = bm.custom_initializer( + 'relu', (self.config.num_channels, self.query_num_channels), dtype=dtype) + + def construct(self, query_embedding, templates, padding_mask_2d, + multichain_mask_2d, key): + """Generate an embedding for a set of templates. + + Args: + query_embedding: [num_res, num_res, num_channel] a query tensor that will + be used to attend over the templates to remove the num_templates + dimension. + templates: A 'Templates' object. + padding_mask_2d: [num_res, num_res] Pair mask for attention operations. + multichain_mask_2d: [num_res, num_res] Pair mask for multichain. + key: random key generator. + + Returns: + An embedding of size [num_res, num_res, num_channels] + """ + subkeys = mint.arange(key, key + self.num_templates, 1) + summed_template_embeddings = mint.zeros( + (self.num_residues, self.num_residues, + self.config.num_channels), dtype=query_embedding.dtype + ) + + def scan_fn(carry, x): + templates, key = x + embedding = self.template_embedder( + query_embedding, + templates, + padding_mask_2d, + multichain_mask_2d, + key, + ) + return carry + embedding + for i in range(len(subkeys)): + summed_template_embeddings = scan_fn( + summed_template_embeddings, (templates[i], subkeys[i])) + embedding = summed_template_embeddings / (1e-7 + self.num_templates) + embedding = mint.nn.functional.relu(embedding) + embedding = self.output_linear(embedding) + return embedding + + +class SingleTemplateEmbedding(nn.Cell): + """ + Embed a single template. + + Args: + config: Configuration object containing model parameters. + global_config: Global configuration object. + normalized_shape (tuple): Shape for normalization layers. + + Inputs: + - **query_embedding** (Tensor) - Query embedding tensor of shape (num_res, num_res, num_channels). + - **templates** (Templates object) - Object containing single template data. + - **padding_mask_2d** (Tensor) - Padding mask tensor. + - **multichain_mask_2d** (Tensor) - Mask indicating intra-chain residue pairs. + - **key** (random.KeyArray) - Random key generator. + + Outputs: + - **output** (Tensor) - Template embedding tensor of shape (num_res, num_res, num_channels). 
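+
+    Note:
+        The input features (distogram, pseudo-beta mask, one-hot ``aatype`` in
+        both orientations, backbone-frame unit vectors, backbone mask and the
+        normalized query embedding) are each embedded by their own linear
+        layer and summed, then refined by a small PairFormer stack and
+        layer-normalized.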
+ """ + + def __init__( + self, + config, + global_config, + normalized_shape, + dtype=ms.float32 + ): + super().__init__() + self.config = config + self.global_config = global_config + num_channels = self.config.num_channels + self.query_embedding_norm = bm.LayerNorm( + normalized_shape, dtype=ms.float32) + + # to be determined the shape of input, output and number of layers + num_layers = 9 + in_shape_list = [39, (), 31, 31, (), (), (), (), 128] + ndim_list = [3, 2, 3, 3, 2, 2, 2, 2, 3] + self.template_pair_embedding = ms.nn.CellList( + [ + bm.CustomDense( + in_shape_list[i], num_channels, weight_init="relu", ndim=ndim_list[i], dtype=dtype + ) + for i in range(num_layers) + ] + ) + self.template_stack = ms.nn.CellList( + [ + modules.PairFormerIteration( + self.config.template_stack, self.global_config, normalized_shape[:-1] + ( + num_channels,), dtype=dtype + ) + for _ in range(self.config.template_stack.num_layer) + ] + ) + self.output_layer_norm = bm.LayerNorm( + normalized_shape[:-1] + (num_channels,), dtype=ms.float32) + + def construct(self, query_embedding, templates, padding_mask_2d, multichain_mask_2d, key): + act = self.construct_input( + query_embedding, templates, multichain_mask_2d) + if self.config.template_stack.num_layer: + for i in range(self.config.template_stack.num_layer): + act = self.template_stack[i](act, padding_mask_2d) + act = self.output_layer_norm(act) + return act + + def construct_input(self, query_embedding, templates, multichain_mask_2d): + # Compute distogram feature for the template. + dtype = multichain_mask_2d.dtype + aatype = templates.aatype + dense_atom_mask = templates.atom_mask + dense_atom_positions = templates.atom_positions + dense_atom_positions *= dense_atom_mask[..., None] + pseudo_beta_positions, pseudo_beta_mask = [ms.Tensor(x) for x in scoring.pseudo_beta_fn( + templates.aatype, dense_atom_positions, dense_atom_mask + )] + pseudo_beta_mask_2d = ( + pseudo_beta_mask[:, None] * pseudo_beta_mask[None, :] + ) + pseudo_beta_mask_2d *= multichain_mask_2d + dgram = dgram_from_positions( + pseudo_beta_positions, self.config.dgram_features + ) + dgram *= pseudo_beta_mask_2d[..., None] + pseudo_beta_mask_2d = pseudo_beta_mask_2d.astype(dtype) + to_concat = [(dgram, 1), (pseudo_beta_mask_2d, 0)] + aatype = mint.nn.functional.one_hot( + aatype.astype(ms.int64), + residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP, + ).astype(dtype) + to_concat.append((aatype[None, :, :], 1)) + to_concat.append((aatype[:, None, :], 1)) + template_group_indices = mint.index_select( + ms.Tensor(protein_data_processing.RESTYPE_RIGIDGROUP_DENSE_ATOM_IDX), + 0, + templates.aatype, + ) + rigid, backbone_mask = make_backbone_rigid( + geometry.Vec3Array.from_array(dense_atom_positions), + dense_atom_mask, + template_group_indices, + ) + points = rigid.translation + x = rigid.translation.x.unsqueeze(-1) + y = rigid.translation.y.unsqueeze(-1) + z = rigid.translation.z.unsqueeze(-1) + xx = rigid.rotation.xx.unsqueeze(-1) + xy = rigid.rotation.xy.unsqueeze(-1) + xz = rigid.rotation.xz.unsqueeze(-1) + yx = rigid.rotation.yx.unsqueeze(-1) + yy = rigid.rotation.yy.unsqueeze(-1) + yz = rigid.rotation.yz.unsqueeze(-1) + zx = rigid.rotation.zx.unsqueeze(-1) + zy = rigid.rotation.zy.unsqueeze(-1) + zz = rigid.rotation.zz.unsqueeze(-1) + rigid = geometry.Rigid3Array(geometry.Rot3Array( + xx, xy, xz, yx, yy, yz, zx, zy, zz), geometry.Vec3Array(x, y, z)) + rigid_vec = rigid.inverse().apply_to_point(points) + + unit_vector = rigid_vec.normalized() + unit_vector = [unit_vector.x, 
unit_vector.y, unit_vector.z] + unit_vector = [x for x in unit_vector] + backbone_mask = backbone_mask + + backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :] + backbone_mask_2d *= multichain_mask_2d + unit_vector = [x * backbone_mask_2d for x in unit_vector] + + # Note that the backbone_mask takes into account C, CA and N (unlike + # pseudo beta mask which just needs CB) so we add both masks as features. + to_concat.extend([(x, 0) for x in unit_vector]) + to_concat.append((backbone_mask_2d, 0)) + query_embedding = self.query_embedding_norm(query_embedding) + # Allow the template embedder to see the query embedding. Note this + # contains the position relative feature, so this is how the network knows + # which residues are next to each other. + to_concat.append((query_embedding, 1)) + + act = 0 + for i, (x, _) in enumerate(to_concat): + act += self.template_pair_embedding[i](x) + return act diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/triangle.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/triangle.py new file mode 100644 index 0000000000000000000000000000000000000000..af006b7eac983a7ad4d8678568637caf0eb0de56 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/triangle.py @@ -0,0 +1,262 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""Triangle""" +import numpy as np +import mindspore as ms +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Parameter, mint +from mindspore.common.tensor import Tensor +import mindspore.ops as ops +from mindspore.ops import operations as P +from mindspore.common.initializer import initializer +from mindsponge.common.utils import _memory_reduce +from mindsponge.cell.initializer import lecun_init +from mindsponge.cell.mask import MaskedLayerNorm +from mindchemistry.e3.utils import Ncon + +from alphafold3.utils.gated_linear_unit import gated_linear_unit +from alphafold3.model.components.base_modules import LayerNorm, CustomDense + + +class TriangleMultiplication(nn.Cell): + r""" + Triangle multiplication layer. for the detailed implementation process, refer to + `TriangleMultiplication `_. + + The information between the amino acid pair is integrated through the information of three edges ij, ik, jk, and + the result of the dot product between ik and jk is added to the edge of ij. + + Args: + num_intermediate_channel (float): The number of intermediate channel. + equation (str): The equation used in triangle multiplication layer. edge update forms + corresponding to 'incoming' and 'outgoing', + :math:`(ikc,jkc->ijc, kjc,kic->ijc)`. + layer_norm_dim (int): The last dimension length of the layer norm. + batch_size (int): The batch size of parameters in triangle multiplication. Default: ``None``. + + Inputs: + - **pair_act** (Tensor) - Tensor of pair_act. 
shape :math:`(N{res}, N{res}, layer\_norm\_dim)`. + - **pair_mask** (Tensor) - The mask for TriangleAttention matrix with shape. shape :math:`(N{res}, N{res})`. + - **index** (Tensor) - The index of while loop, only used in case of while control + flow. + + Outputs: + Tensor, the float tensor of the pair_act of the layer with shape :math:`(N{res}, N{res}, layer\_norm\_dim)`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> from mindsponge.cell import TriangleMultiplication + >>> from mindspore import dtype as mstype + >>> from mindspore import Tensor + >>> model = TriangleMultiplication(num_intermediate_channel=64, + ... equation="ikc,jkc->ijc", layer_norm_dim=64, batch_size=0) + >>> input_0 = Tensor(np.ones((256, 256, 64)), mstype.float32) + >>> input_1 = Tensor(np.ones((256, 256)), mstype.float32) + >>> out = model(input_0, input_1, index=0) + >>> print(out.shape) + (256, 256, 64) + """ + + def __init__(self, config, global_config, num_intermediate_channel, equation, normalized_shape, + batch_size=None, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.num_intermediate_channel = num_intermediate_channel + self.left_norm_input = LayerNorm(normalized_shape, dtype=ms.float32) + self.center_norm = LayerNorm(normalized_shape, dtype=ms.float32) + self.projection = nn.Dense( + normalized_shape[-1], num_intermediate_channel * 2, has_bias=False, dtype=dtype) + self.gate = nn.Dense(normalized_shape[-1], num_intermediate_channel * 2, + weight_init=self.global_config.final_init, has_bias=False, dtype=dtype) + self.output_projection = CustomDense( + normalized_shape[-1], num_intermediate_channel, weight_init=self.global_config.final_init, + ndim=3, dtype=dtype) + self.gating_linear = CustomDense( + num_intermediate_channel, num_intermediate_channel, weight_init=self.global_config.final_init, + ndim=3, dtype=dtype) + self.weight_glu = mint.stack( + [self.gate.weight.T, self.projection.weight.T], dim=1) + if self.config.equation == "ikc,jkc->ijc": + ncon_list = [[-1, -2, 1], [-1, -3, 1]] + elif self.config.equation == "kjc,kic->ijc": + ncon_list = [[-1, 1, -3], [-1, 1, -2]] + else: + raise ValueError("Not support this equation.") + self.ncon = Ncon(ncon_list) + + def construct(self, act, mask, use_glu=True): + r""" + Builds triangle multiplication module. + + Args: + act(Tensor): Pair activations. Data type is float. + mask(Tensor): Pair mask. Data type is float. + + Returns: + act(Tensor), the shape is same as act_shape[:-1]. + """ + self.weight_glu = mint.stack( + [self.gate.weight.T, self.projection.weight.T], dim=1) + + mask = mask[None, ...] 
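+        # Layer-normalize the pair activations, then project them to paired
+        # (gate, value) channels; with use_glu=True the projection and the
+        # sigmoid gate are fused into a single gated_linear_unit call.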
+ act = self.left_norm_input(act) + input_act = act + + if use_glu is True: + projection = gated_linear_unit.gated_linear_unit( + x=act, + weight=self.weight_glu, + activation=ms.mint.sigmoid, + implementation=None, + precision=None, + ) + projection = ops.transpose(projection, (2, 0, 1)) + projection *= mask + else: + projection = self.projection(act) + projection = ops.transpose(projection, (2, 0, 1)) + projection *= mask + gate = self.gate(act) + gate = ops.transpose(gate, (2, 0, 1)) + projection *= ms.mint.sigmoid(gate) + projection = projection.reshape( + self.num_intermediate_channel, 2, *projection.shape[1:]) + a, b = projection[:, 0], projection[:, 1] + act = self.ncon([a, b]) + act = self.center_norm(act.transpose((1, 2, 0))) + act = self.output_projection(act) + gate_out = self.gating_linear(input_act) + act *= mint.sigmoid(gate_out) + return act + + +class OuterProductMean(nn.Cell): + r""" + Computing the correlation of the input tensor along its second dimension, the computed correlation + could be used to update the correlation features(e.g. the Pair representation). + + .. math:: + OuterProductMean(\mathbf{act}) = Linear(flatten(mean(\mathbf{act}\otimes\mathbf{act}))) + + Args: + num_outer_channel (float): The last dimension size of intermediate layer in OuterProductMean. + act_dim (int): The last dimension size of the input act. + num_output_channel (int): The last dimension size of output. + batch_size(int): The batch size of parameters in OuterProductMean, + used in while control flow. Default: "None". + slice_num (int): The slice num used in OuterProductMean layer + when the memory is overflow. Default: 0. + + Inputs: + - **act** (Tensor) - The input tensor with shape :math:`(dim_1, dim_2, act\_dim)`. + - **mask** (Tensor) - The mask for OuterProductMean with shape :math:`(dim_1, dim_2)`. + - **mask_norm** (Tensor) - Squared L2-norm along the first dimension of **mask**, + pre-computed to avoid re-computing, its shape is :math:`(dim_2, dim_2, 1)`. + - **index** (Tensor) - The index of while loop, only used in case of while control + flow. Default: "None". + + Outputs: + Tensor, the float tensor of the output of OuterProductMean layer with + shape :math:`(dim_2, dim_2, num\_output\_channel)`. 
+ + Supported Platforms: + ``Ascend`` ``GPU`` + + Examples: + >>> import numpy as np + >>> from mindsponge.cell import OuterProductMean + >>> from mindspore import dtype as mstype + >>> from mindspore import Tensor + >>> from mindspore.ops import operations as P + >>> model = OuterProductMean(num_outer_channel=32, act_dim=128, num_output_channel=256) + >>> act = Tensor(np.ones((32, 64, 128)), mstype.float32) + >>> mask = Tensor(np.ones((32, 64)), mstype.float32) + >>> mask_norm = P.ExpandDims()(P.MatMul(transpose_a=True)(mask, mask), -1) + >>> output= model(act, mask, mask_norm) + >>> print(output.shape) + (64, 64, 256) + """ + + def __init__(self, num_outer_channel, act_dim, num_output_channel, batch_size=None, slice_num=0, dtype=ms.float32): + super(OuterProductMean, self).__init__() + self.dtype = dtype + self.num_output_channel = num_output_channel + self.num_outer_channel = num_outer_channel + self.layer_norm_input = MaskedLayerNorm() + self.matmul_trans_b = P.MatMul(transpose_b=True) + self.matmul = P.MatMul() + self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True) + self.act_dim = act_dim + self.batch_size = batch_size + self.slice_num = slice_num + self.idx = Tensor(0, mstype.int32) + self._init_parameter() + + def construct(self, act, mask, mask_norm, index=None): + """Compute outer product mean.""" + mask = P.ExpandDims()(mask, -1) + act = self.layer_norm_input( + act, self.layer_norm_input_gamma, self.layer_norm_input_beta) + act_shape = P.Shape()(act) + if len(act_shape) != 2: + act = P.Reshape()(act, (-1, act_shape[-1])) + out_shape = act_shape[:-1] + (-1,) + left_act = mask * P.Reshape()( + P.BiasAdd()(self.matmul_trans_b(act, self.left_projection_weight), self.left_projection_bias), out_shape) + right_act = mask * P.Reshape()( + P.BiasAdd()(self.matmul_trans_b(act, self.right_projection_weight), self.right_projection_bias), out_shape) + _, d, e = right_act.shape + batched_inputs = (left_act,) + nonbatched_inputs = (right_act, self.linear_output_weight, + self.o_biases, d, e) + act = _memory_reduce(self._compute, batched_inputs, + nonbatched_inputs, self.slice_num, 1) + epsilon = 1e-3 + act = P.RealDiv()(act, epsilon + mask_norm) + return act + + def _init_parameter(self): + '''init parameter''' + self.layer_norm_input_gamma = Parameter( + Tensor(np.ones((self.act_dim)), self.dtype)) + self.layer_norm_input_beta = Parameter( + Tensor(np.zeros((self.act_dim)), self.dtype)) + self.left_projection_weight = Parameter( + initializer(lecun_init(self.act_dim), [self.num_outer_channel, self.act_dim], self.dtype)) + self.left_projection_bias = Tensor( + np.zeros((self.num_outer_channel)), self.dtype) + self.right_projection_weight = Parameter( + initializer(lecun_init(self.act_dim), [self.num_outer_channel, self.act_dim], self.dtype)) + self.right_projection_bias = Tensor( + np.zeros((self.num_outer_channel)), self.dtype) + self.linear_output_weight = Parameter( + Tensor(np.zeros((self.num_outer_channel, self.num_outer_channel, self.num_output_channel)), + self.dtype)) + self.o_biases = Parameter( + Tensor(np.zeros((self.num_output_channel)), self.dtype)) + + def _compute(self, left_act, right_act, linear_output_weight, linear_output_bias, d, e): + '''compute outer product mean''' + + left_act = left_act.transpose((0, 2, 1)) + act = Ncon([[1, -2, -4], [1, -1, -3]])([left_act, right_act]) + act = Ncon([[-1, 1, 2, -2], [1, 2, -3]] + )([act, linear_output_weight]) + linear_output_bias + act = P.Transpose()(act, (1, 0, 2)) + return act diff --git 
a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/feat_batch.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/feat_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..bc69b9d3e33bba22a6cbfcb4e815d6f338bb37fe --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/feat_batch.py @@ -0,0 +1,180 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Batch dataclass.""" + +import dataclasses +from typing import Self +import mindspore as ms +from mindspore import Tensor +from alphafold3.model import features + + +@dataclasses.dataclass +class Batch: + """Dataclass containing batch.""" + + msa: features.MSA + templates: features.Templates + token_features: features.TokenFeatures + ref_structure: features.RefStructure + predicted_structure_info: features.PredictedStructureInfo + polymer_ligand_bond_info: features.PolymerLigandBondInfo + ligand_ligand_bond_info: features.LigandLigandBondInfo + pseudo_beta_info: features.PseudoBetaInfo + atom_cross_att: features.AtomCrossAtt + convert_model_output: features.ConvertModelOutput + frames: features.Frames + + @property + def num_res(self) -> int: + return self.token_features.aatype.shape[-1] + + @staticmethod + def gather_to_tensor(input_feat): + input_feat.gather_idxs = Tensor(input_feat.gather_idxs) + input_feat.gather_mask = Tensor(input_feat.gather_mask) + input_feat.input_shape = Tensor(input_feat.input_shape) + + @classmethod + def from_data_dict(cls, batch: features.BatchDict) -> Self: + """Construct batch object from dictionary.""" + return cls( + msa=features.MSA.from_data_dict(batch), + templates=features.Templates.from_data_dict(batch), + token_features=features.TokenFeatures.from_data_dict(batch), + ref_structure=features.RefStructure.from_data_dict(batch), + predicted_structure_info=features.PredictedStructureInfo.from_data_dict( + batch + ), + polymer_ligand_bond_info=features.PolymerLigandBondInfo.from_data_dict( + batch + ), + ligand_ligand_bond_info=features.LigandLigandBondInfo.from_data_dict( + batch + ), + pseudo_beta_info=features.PseudoBetaInfo.from_data_dict(batch), + atom_cross_att=features.AtomCrossAtt.from_data_dict(batch), + convert_model_output=features.ConvertModelOutput.from_data_dict( + batch), + frames=features.Frames.from_data_dict(batch), + ) + + def as_data_dict(self) -> features.BatchDict: + """Converts batch object to dictionary.""" + output = { + **self.msa.as_data_dict(), + **self.templates.as_data_dict(), + **self.token_features.as_data_dict(), + **self.ref_structure.as_data_dict(), + **self.predicted_structure_info.as_data_dict(), + **self.polymer_ligand_bond_info.as_data_dict(), + **self.ligand_ligand_bond_info.as_data_dict(), + **self.pseudo_beta_info.as_data_dict(), + **self.atom_cross_att.as_data_dict(), + **self.convert_model_output.as_data_dict(), + **self.frames.as_data_dict(), + } + return output + + def convert_to_tensor(self, dtype=ms.float32): + # msa: features.MSA + self.msa.rows = 
Tensor(self.msa.rows, dtype=ms.int32)
+        self.msa.mask = Tensor(self.msa.mask, dtype=ms.int32)
+        self.msa.deletion_matrix = Tensor(
+            self.msa.deletion_matrix, dtype=dtype)
+        self.msa.deletion_mean = Tensor(self.msa.deletion_mean, dtype=dtype)
+        self.msa.profile = Tensor(self.msa.profile, dtype=dtype)
+        self.msa.num_alignments = Tensor(
+            self.msa.num_alignments, dtype=ms.int32)
+        # templates: features.Templates
+        self.templates.aatype = Tensor(self.templates.aatype, dtype=ms.int32)
+        self.templates.atom_mask = Tensor(
+            self.templates.atom_mask, dtype=ms.int32)
+        self.templates.atom_positions = Tensor(
+            self.templates.atom_positions, dtype=dtype)
+        # token_features: features.TokenFeatures
+        self.token_features.mask = Tensor(
+            self.token_features.mask, dtype=ms.int32)
+        self.token_features.token_index = Tensor(
+            self.token_features.token_index, dtype=ms.int32)
+        self.token_features.asym_id = Tensor(
+            self.token_features.asym_id, dtype=ms.int32)
+        self.token_features.aatype = Tensor(
+            self.token_features.aatype, dtype=ms.int32)
+        self.token_features.residue_index = Tensor(
+            self.token_features.residue_index, dtype=ms.int32)
+        self.token_features.entity_id = Tensor(
+            self.token_features.entity_id, dtype=ms.int32)
+        self.token_features.sym_id = Tensor(
+            self.token_features.sym_id, dtype=ms.int32)
+        # ref_structure: features.RefStructure
+        self.ref_structure.positions = Tensor(
+            self.ref_structure.positions, dtype=dtype)
+        self.ref_structure.mask = Tensor(self.ref_structure.mask, dtype=dtype)
+        self.ref_structure.element = Tensor(
+            self.ref_structure.element, dtype=ms.int32)
+        self.ref_structure.charge = Tensor(
+            self.ref_structure.charge, dtype=dtype)
+        self.ref_structure.atom_name_chars = Tensor(
+            self.ref_structure.atom_name_chars, dtype=ms.int32)
+        self.ref_structure.ref_space_uid = Tensor(
+            self.ref_structure.ref_space_uid, dtype=dtype)
+
+        # predicted_structure_info: features.PredictedStructureInfo
+        self.predicted_structure_info.atom_mask = Tensor(
+            self.predicted_structure_info.atom_mask, dtype=dtype)
+
+        # polymer_ligand_bond_info: features.PolymerLigandBondInfo
+        self.polymer_ligand_bond_info.tokens_to_polymer_ligand_bonds.gather_idxs = Tensor(
+            self.polymer_ligand_bond_info.tokens_to_polymer_ligand_bonds.gather_idxs, dtype=ms.int32
+        )
+        self.polymer_ligand_bond_info.tokens_to_polymer_ligand_bonds.gather_mask = Tensor(
+            self.polymer_ligand_bond_info.tokens_to_polymer_ligand_bonds.gather_mask, dtype=ms.int32
+        )
+        # ligand_ligand_bond_info: features.LigandLigandBondInfo
+        self.ligand_ligand_bond_info.tokens_to_ligand_ligand_bonds.gather_idxs = Tensor(
+            self.ligand_ligand_bond_info.tokens_to_ligand_ligand_bonds.gather_idxs, dtype=ms.int32
+        )
+        self.ligand_ligand_bond_info.tokens_to_ligand_ligand_bonds.gather_mask = Tensor(
+            self.ligand_ligand_bond_info.tokens_to_ligand_ligand_bonds.gather_mask, dtype=ms.int32
+        )
+
+        self.gather_to_tensor(self.pseudo_beta_info.token_atoms_to_pseudo_beta)
+        self.gather_to_tensor(self.atom_cross_att.queries_to_keys)
+        self.gather_to_tensor(self.atom_cross_att.tokens_to_queries)
+        self.gather_to_tensor(self.atom_cross_att.tokens_to_keys)
+        self.gather_to_tensor(self.atom_cross_att.token_atoms_to_queries)
+        self.gather_to_tensor(self.atom_cross_att.queries_to_token_atoms)
+
+        # frames: features.Frames
+
+    def astype(self, dtype=ms.float32):
+        # change dtype of float
+        # msa: features.MSA
+        self.msa.deletion_matrix = self.msa.deletion_matrix.astype(dtype)
+        self.msa.deletion_mean = self.msa.deletion_mean.astype(dtype)
+        self.msa.profile = 
self.msa.profile.astype(dtype) + # templates: features.Templates + self.templates.atom_positions = self.templates.atom_positions.astype( + dtype) + # ref_structure: features.RefStructure + self.ref_structure.positions = self.ref_structure.positions.astype( + dtype) + self.ref_structure.mask = self.ref_structure.mask.astype(dtype) + self.ref_structure.charge = self.ref_structure.charge.astype(dtype) + self.ref_structure.ref_space_uid = self.ref_structure.ref_space_uid.astype( + dtype) + + # predicted_structure_info: features.PredictedStructureInfo + self.predicted_structure_info.atom_mask = self.predicted_structure_info.atom_mask.astype( + dtype) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/features.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/features.py new file mode 100644 index 0000000000000000000000000000000000000000..da9f13069594ea4f958757e76ea062fbc403ab99 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/features.py @@ -0,0 +1,2103 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""Data-side of the input features processing.""" + +import dataclasses +import datetime +import itertools +import numpy as np +from typing_extensions import Any, Self, TypeAlias +from rdkit import Chem +from rdkit.Chem import AllChem +from absl import logging +from alphafold3 import structure +from alphafold3.common import folding_input +from alphafold3.constants import chemical_components +from alphafold3.constants import mmcif_names +from alphafold3.constants import periodic_table +from alphafold3.constants import residue_names +from alphafold3.data import msa as msa_module +from alphafold3.data import templates +from alphafold3.data.tools import rdkit_utils +from alphafold3.model import data3 +from alphafold3.model import data_constants +from alphafold3.model import merging_features +from alphafold3.model import msa_pairing +from alphafold3.model.atom_layout import atom_layout +from alphafold3.structure import chemical_components as struc_chem_comps + + +xnp_ndarray: TypeAlias = np.ndarray # pylint: disable=invalid-name +BatchDict: TypeAlias = dict[str, xnp_ndarray] + +_STANDARD_RESIDUES = frozenset({ + *residue_names.PROTEIN_TYPES_WITH_UNKNOWN, + *residue_names.NUCLEIC_TYPES_WITH_2_UNKS, +}) + + +@dataclasses.dataclass +class PaddingShapes: + num_tokens: int + msa_size: int + num_chains: int + num_templates: int + num_atoms: int + + +def _pad_to( + arr: np.ndarray, shape: tuple[int | None, ...], **kwargs +) -> np.ndarray: + """Pads an array to a given shape. Wrapper around np.pad(). + + Args: + arr: numpy array to pad + shape: target shape, use None for axes that should stay the same + **kwargs: additional args for np.pad, e.g. constant_values=-1 + + Returns: + the padded array + + Raises: + ValueError if arr and shape have a different number of axes. 
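+
+  Example (illustrative):
+    >>> _pad_to(np.zeros((3, 5)), (None, 8)).shape
+    (3, 8)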
+ """ + if arr.ndim != len(shape): + raise ValueError( + f'arr and shape have different number of axes. {arr.shape=}, {shape=}' + ) + + num_pad = [] + for axis, width in enumerate(shape): + if width is None: + num_pad.append((0, 0)) + else: + if width >= arr.shape[axis]: + num_pad.append((0, width - arr.shape[axis])) + else: + raise ValueError( + f'Can not pad to a smaller shape. {arr.shape=}, {shape=}' + ) + padded_arr = np.pad(arr, pad_width=num_pad, **kwargs) + return padded_arr + + +def _unwrap(obj): + """Unwrap an object from a zero-dim np.ndarray.""" + if isinstance(obj, np.ndarray) and obj.ndim == 0: + return obj.item() + else: + return obj + + +@dataclasses.dataclass +class Chains: + chain_id: np.ndarray + asym_id: np.ndarray + entity_id: np.ndarray + sym_id: np.ndarray + + +def _compute_asym_entity_and_sym_id( + all_tokens: atom_layout.AtomLayout, +) -> Chains: + """Compute asym_id, entity_id and sym_id. + + Args: + all_tokens: atom layout containing a representative atom for each token. + + Returns: + A Chains object + """ + + # Find identical sequences and assign entity_id and sym_id to every chain. + seq_to_entity_id_sym_id = {} + seen_chain_ids = set() + chain_ids = [] + asym_ids = [] + entity_ids = [] + sym_ids = [] + for chain_id in all_tokens.chain_id: + if chain_id not in seen_chain_ids: + asym_id = len(seen_chain_ids) + 1 + seen_chain_ids.add(chain_id) + seq = ','.join( + all_tokens.res_name[all_tokens.chain_id == chain_id]) + if seq not in seq_to_entity_id_sym_id: + entity_id = len(seq_to_entity_id_sym_id) + 1 + sym_id = 1 + else: + entity_id, sym_id = seq_to_entity_id_sym_id[seq] + sym_id += 1 + seq_to_entity_id_sym_id[seq] = (entity_id, sym_id) + + chain_ids.append(chain_id) + asym_ids.append(asym_id) + entity_ids.append(entity_id) + sym_ids.append(sym_id) + + return Chains( + chain_id=np.array(chain_ids), + asym_id=np.array(asym_ids), + entity_id=np.array(entity_ids), + sym_id=np.array(sym_ids), + ) + + +def tokenizer( + flat_output_layout: atom_layout.AtomLayout, + ccd: chemical_components.Ccd, + max_atoms_per_token: int, + flatten_non_standard_residues: bool, + logging_name: str, +) -> tuple[atom_layout.AtomLayout, atom_layout.AtomLayout, np.ndarray]: + """Maps a flat atom layout to tokens for evoformer. + + Creates the evoformer tokens as one token per polymer residue and one token + per ligand atom. The tokens are represented as AtomLayouts all_tokens + (1 representative atom per token) atoms per residue, and + all_token_atoms_layout (num_tokens, max_atoms_per_token). The atoms in a + residue token use the layout of the corresponding CCD entry + + Args: + flat_output_layout: flat AtomLayout containing all atoms that the model + wants to predict. + ccd: The chemical components dictionary. + max_atoms_per_token: number of slots per token. + flatten_non_standard_residues: whether to flatten non-standard residues, + i.e. whether to use one token per atom for non-standard residues. + logging_name: logging name for debugging (usually the mmcif_id). + + Returns: + A tuple (all_tokens, all_tokens_atoms_layout) with + all_tokens: AtomLayout shape (num_tokens,) containing one representative + atom per token. + all_token_atoms_layout: AtomLayout with shape + (num_tokens, max_atoms_per_token) containing all atoms per token. + standard_token_idxs: The token index that each token would have if not + flattening non standard resiudes. + """ + # Select the representative atom for each token. 
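+    # Standard polymer residues get one token each (represented by CA for
+    # proteins and C1' for nucleic acids, falling back to the first atom),
+    # while ligand atoms and flattened non-standard residues get one token
+    # per atom.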
+ token_idxs = [] + single_atom_token = [] + standard_token_idxs = [] + current_standard_token_id = 0 + # Iterate over residues, and provide a group_iter over the atoms of each + # residue. + for key, group_iter in itertools.groupby( + zip( + flat_output_layout.chain_type, + flat_output_layout.chain_id, + flat_output_layout.res_id, + flat_output_layout.res_name, + flat_output_layout.atom_name, + np.arange(flat_output_layout.shape[0]), + ), + key=lambda x: x[:3], + ): + + # Get chain type and chain id of this residue + chain_type, chain_id, _ = key + + # Get names and global idxs for all atoms of this residue + _, _, _, res_names, atom_names, idxs = zip(*group_iter) + + # As of March 2023, all OTHER CHAINs in pdb are artificial nucleics. + is_nucleic_backbone = ( + chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES + or chain_type == mmcif_names.OTHER_CHAIN + ) + if chain_type in mmcif_names.PEPTIDE_CHAIN_TYPES: + res_name = res_names[0] + if ( + flatten_non_standard_residues + and res_name not in residue_names.PROTEIN_TYPES_WITH_UNKNOWN + and res_name != residue_names.MSE + ): + # For non-standard protein residues take all atoms. + # NOTE: This may get very large if we include hydrogens. + token_idxs.extend(idxs) + single_atom_token += [True] * len(idxs) + standard_token_idxs.extend( + [current_standard_token_id] * len(idxs)) + else: + # For standard protein residues take 'CA' if it exists, else first atom. + if 'CA' in atom_names: + token_idxs.append(idxs[atom_names.index('CA')]) + else: + token_idxs.append(idxs[0]) + single_atom_token += [False] + standard_token_idxs.append(current_standard_token_id) + current_standard_token_id += 1 + elif is_nucleic_backbone: + res_name = res_names[0] + if ( + flatten_non_standard_residues + and res_name not in residue_names.NUCLEIC_TYPES_WITH_2_UNKS + ): + # For non-standard nucleic residues take all atoms. + token_idxs.extend(idxs) + single_atom_token += [True] * len(idxs) + standard_token_idxs.extend( + [current_standard_token_id] * len(idxs)) + else: + # For standard nucleic residues take C1' if it exists, else first atom. + if "C1'" in atom_names: + token_idxs.append(idxs[atom_names.index("C1'")]) + else: + token_idxs.append(idxs[0]) + single_atom_token += [False] + standard_token_idxs.append(current_standard_token_id) + current_standard_token_id += 1 + elif chain_type in mmcif_names.NON_POLYMER_CHAIN_TYPES: + # For non-polymers take all atoms + token_idxs.extend(idxs) + single_atom_token += [True] * len(idxs) + standard_token_idxs.extend([current_standard_token_id] * len(idxs)) + current_standard_token_id += len(idxs) + else: + # Chain type that we don't handle yet. + logging.warning( + '%s: ignoring chain %s with chain type %s.', + logging_name, + chain_id, + chain_type, + ) + + assert len(token_idxs) == len(single_atom_token) + assert len(token_idxs) == len(standard_token_idxs) + standard_token_idxs = np.array(standard_token_idxs, dtype=np.int32) + + # Create the list of all tokens, represented as a flat AtomLayout with 1 + # representative atom per token. + all_tokens = flat_output_layout[token_idxs] + + # Create the 2D atoms_per_token layout + num_tokens = all_tokens.shape[0] + + # Target lists. + target_atom_names = [] + target_atom_elements = [] + target_res_ids = [] + target_res_names = [] + target_chain_ids = [] + target_chain_types = [] + + # uids of all atoms in the flat layout, to check whether the dense atoms + # exist -- This is necessary for terminal atoms (e.g. 
'OP3' or 'OXT') + all_atoms_uids = set( + zip( + flat_output_layout.chain_id, + flat_output_layout.res_id, + flat_output_layout.atom_name, + ) + ) + + for idx, single_atom in enumerate(single_atom_token): + if not single_atom: + # Standard protein and nucleic residues have many atoms per token + chain_id = all_tokens.chain_id[idx] + res_id = all_tokens.res_id[idx] + res_name = all_tokens.res_name[idx] + atom_names = [] + atom_elements = [] + + res_atoms = struc_chem_comps.get_all_atoms_in_entry( + ccd=ccd, res_name=res_name + ) + atom_names_elements = list( + zip( + res_atoms['_chem_comp_atom.atom_id'], + res_atoms['_chem_comp_atom.type_symbol'], + strict=True, + ) + ) + + for atom_name, atom_element in atom_names_elements: + # Remove hydrogens if they are not in flat layout. + if atom_element in ['H', 'D'] and ( + (chain_id, res_id, atom_name) not in all_atoms_uids + ): + continue + elif (chain_id, res_id, atom_name) in all_atoms_uids: + atom_names.append(atom_name) + atom_elements.append(atom_element) + # Leave spaces for OXT etc. + else: + atom_names.append('') + atom_elements.append('') + + if len(atom_names) > max_atoms_per_token: + logging.warning( + 'Atom list for chain %s ' + 'residue %s %s is too long and will be truncated: ' + '%s to the max atoms limit %s. Dropped atoms: %s', + chain_id, + res_id, + res_name, + len(atom_names), + max_atoms_per_token, + list( + zip( + atom_names[max_atoms_per_token:], + atom_elements[max_atoms_per_token:], + strict=True, + ) + ), + ) + atom_names = atom_names[:max_atoms_per_token] + atom_elements = atom_elements[:max_atoms_per_token] + + num_pad = max_atoms_per_token - len(atom_names) + atom_names.extend([''] * num_pad) + atom_elements.extend([''] * num_pad) + + else: + # ligands have only 1 atom per token + padding = [''] * (max_atoms_per_token - 1) + atom_names = [all_tokens.atom_name[idx]] + padding + atom_elements = [all_tokens.atom_element[idx]] + padding + + # Append the atoms to the target lists. + target_atom_names.append(atom_names) + target_atom_elements.append(atom_elements) + target_res_names.append( + [all_tokens.res_name[idx]] * max_atoms_per_token) + target_res_ids.append([all_tokens.res_id[idx]] * max_atoms_per_token) + target_chain_ids.append( + [all_tokens.chain_id[idx]] * max_atoms_per_token) + target_chain_types.append( + [all_tokens.chain_type[idx]] * max_atoms_per_token + ) + + # Make sure to get the right shape also for 0 tokens + trg_shape = (num_tokens, max_atoms_per_token) + all_token_atoms_layout = atom_layout.AtomLayout( + atom_name=np.array(target_atom_names, dtype=object).reshape(trg_shape), + atom_element=np.array(target_atom_elements, dtype=object).reshape( + trg_shape + ), + res_name=np.array(target_res_names, dtype=object).reshape(trg_shape), + res_id=np.array(target_res_ids, dtype=int).reshape(trg_shape), + chain_id=np.array(target_chain_ids, dtype=object).reshape(trg_shape), + chain_type=np.array(target_chain_types, + dtype=object).reshape(trg_shape), + ) + + return all_tokens, all_token_atoms_layout, standard_token_idxs + + +@dataclasses.dataclass +class MSA: + """Dataclass containing MSA.""" + + rows: xnp_ndarray + mask: xnp_ndarray + deletion_matrix: xnp_ndarray + # Occurrence of each residue type along the sequence, averaged over MSA rows. + profile: xnp_ndarray + # Occurrence of deletions along the sequence, averaged over MSA rows. + deletion_mean: xnp_ndarray + # Number of MSA alignments. 
+ num_alignments: xnp_ndarray + + @classmethod + def compute_features( + cls, + *, + all_tokens: atom_layout.AtomLayout, + standard_token_idxs: np.ndarray, + padding_shapes: PaddingShapes, + fold_input: folding_input.Input, + logging_name: str, + max_paired_sequence_per_species: int, + ) -> Self: + """Compute the msa features.""" + seen_entities = {} + + substruct = atom_layout.make_structure( + flat_layout=all_tokens, + atom_coords=np.zeros(all_tokens.shape + (3,)), + name=logging_name, + ) + prot = substruct.filter_to_entity_type(protein=True) + num_unique_chains = len( + set(prot.chain_single_letter_sequence().values())) + need_msa_pairing = num_unique_chains > 1 + + np_chains_list = [] + input_chains_by_id = {chain.id: chain for chain in fold_input.chains} + nonempty_chain_ids = set(all_tokens.chain_id) + for asym_id, chain_info in enumerate(substruct.iter_chains(), start=1): + b_chain_id = chain_info['chain_id'] + chain_type = chain_info['chain_type'] + chain = input_chains_by_id[b_chain_id] + + # Generalised "sequence" for ligands (can't trust residue name) + chain_tokens = all_tokens[all_tokens.chain_id == b_chain_id] + assert chain_tokens.res_name is not None + three_letter_sequence = ','.join(chain_tokens.res_name.tolist()) + chain_num_tokens = len(chain_tokens.atom_name) + if chain_type in mmcif_names.POLYMER_CHAIN_TYPES: + sequence = substruct.chain_single_letter_sequence()[b_chain_id] + if chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES: + # Only allow nucleic residue types for nucleic chains (can have some + # protein residues in e.g. tRNA, but that causes MSA search failures). + # Replace non nucleic residue types by UNK_NUCLEIC. + nucleic_types_one_letter = ( + residue_names.DNA_TYPES_ONE_LETTER + + residue_names.RNA_TYPES_ONE_LETTER_WITH_UNKNOWN + ) + sequence = ''.join([ + base + if base in nucleic_types_one_letter + else residue_names.UNK_NUCLEIC_ONE_LETTER + for base in sequence + ]) + else: + sequence = 'X' * chain_num_tokens + + skip_chain = ( + chain_type not in mmcif_names.STANDARD_POLYMER_CHAIN_TYPES + or len(sequence) <= 4 + or b_chain_id not in nonempty_chain_ids + ) + if three_letter_sequence in seen_entities: + entity_id = seen_entities[three_letter_sequence] + else: + entity_id = len(seen_entities) + 1 + + if chain_type in mmcif_names.STANDARD_POLYMER_CHAIN_TYPES: + unpaired_a3m = '' + paired_a3m = '' + if not skip_chain: + if need_msa_pairing and isinstance(chain, folding_input.ProteinChain): + paired_a3m = chain.paired_msa + if isinstance( + chain, folding_input.RnaChain | folding_input.ProteinChain + ): + unpaired_a3m = chain.unpaired_msa + unpaired_msa = msa_module.Msa.from_a3m( + query_sequence=sequence, + chain_poly_type=chain_type, + a3m=unpaired_a3m, + deduplicate=True, + ) + + paired_msa = msa_module.Msa.from_a3m( + query_sequence=sequence, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + a3m=paired_a3m, + deduplicate=False, + ) + else: + unpaired_msa = msa_module.Msa.from_empty( + query_sequence='-' * len(sequence), + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ) + paired_msa = msa_module.Msa.from_empty( + query_sequence='-' * len(sequence), + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ) + + msa_features = unpaired_msa.featurize() + all_seqs_msa_features = paired_msa.featurize() + + msa_features = data3.fix_features(msa_features) + all_seqs_msa_features = data3.fix_features(all_seqs_msa_features) + + msa_features = msa_features | { + f'{k}_all_seq': v for k, v in all_seqs_msa_features.items() + } + feats = msa_features + feats['chain_id'] = 
b_chain_id + feats['asym_id'] = np.full(chain_num_tokens, asym_id) + feats['entity_id'] = entity_id + np_chains_list.append(feats) + + # Add profile features to each chain. + for chain in np_chains_list: + chain.update( + data3.get_profile_features( + chain['msa'], chain['deletion_matrix']) + ) + + # Allow 50% of the MSA to come from MSA pairing. + max_paired_sequences = padding_shapes.msa_size // 2 + if need_msa_pairing: + np_chains_list = list(map(dict, np_chains_list)) + np_chains_list = msa_pairing.create_paired_features( + np_chains_list, + max_paired_sequences=max_paired_sequences, + nonempty_chain_ids=nonempty_chain_ids, + max_hits_per_species=max_paired_sequence_per_species, + ) + np_chains_list = msa_pairing.deduplicate_unpaired_sequences( + np_chains_list + ) + + # Remove all gapped rows from all seqs. + nonempty_asym_ids = [] + for chain in np_chains_list: + if chain['chain_id'] in nonempty_chain_ids: + nonempty_asym_ids.append(chain['asym_id'][0]) + if 'msa_all_seq' in np_chains_list[0]: + np_chains_list = msa_pairing.remove_all_gapped_rows_from_all_seqs( + np_chains_list, asym_ids=nonempty_asym_ids + ) + + # Crop MSA rows. + cropped_chains_list = [] + for chain in np_chains_list: + unpaired_msa_size, paired_msa_size = ( + msa_pairing.choose_paired_unpaired_msa_crop_sizes( + unpaired_msa=chain['msa'], + paired_msa=chain.get('msa_all_seq'), + total_msa_crop_size=padding_shapes.msa_size, + max_paired_sequences=max_paired_sequences, + ) + ) + cropped_chain = { + 'asym_id': chain['asym_id'], + 'chain_id': chain['chain_id'], + 'profile': chain['profile'], + 'deletion_mean': chain['deletion_mean'], + } + for feat in data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES: + if feat in chain: + cropped_chain[feat] = chain[feat][:unpaired_msa_size] + if feat + '_all_seq' in chain: + cropped_chain[feat + '_all_seq'] = chain[feat + '_all_seq'][ + :paired_msa_size + ] + cropped_chains_list.append(cropped_chain) + + # Merge Chains. + # Make sure the chain order is unaltered before slicing with tokens. + curr_chain_order = [chain['chain_id'] for chain in cropped_chains_list] + orig_chain_order = [chain['chain_id'] + for chain in substruct.iter_chains()] + assert curr_chain_order == orig_chain_order + np_example = { + 'asym_id': np.concatenate( + [c['asym_id'] for c in cropped_chains_list], axis=0 + ), + } + for feature in data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES: + for feat in [feature, feature + '_all_seq']: + if feat in cropped_chains_list[0]: + np_example[feat] = merging_features.merge_msa_features( + feat, cropped_chains_list + ) + for feature in ['profile', 'deletion_mean']: + feature_list = [c[feature] for c in cropped_chains_list] + np_example[feature] = np.concatenate(feature_list, axis=0) + + # Crop MSA rows to maximum size given by chains participating in the crop. + max_allowed_unpaired = max([ + len(chain['msa']) + for chain in cropped_chains_list + if chain['asym_id'][0] in nonempty_asym_ids + ]) + np_example['msa'] = np_example['msa'][:max_allowed_unpaired] + if 'msa_all_seq' in np_example: + max_allowed_paired = max([ + len(chain['msa_all_seq']) + for chain in cropped_chains_list + if chain['asym_id'][0] in nonempty_asym_ids + ]) + np_example['msa_all_seq'] = np_example['msa_all_seq'][:max_allowed_paired] + + np_example = merging_features.merge_paired_and_unpaired_msa(np_example) + + # Crop MSA residues. Need to use the standard token indices, since msa does + # not expand non-standard residues. This means that for expanded residues, + # we get repeated msa columns. 
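+        # E.g. a modified residue that was flattened into several atom tokens
+        # repeats its single standard token index once per atom, so the same
+        # MSA column (and profile entry) is copied to each of those tokens.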
+ new_cropping_idxs = standard_token_idxs + for feature in data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES: + if feature in np_example: + np_example[feature] = np_example[feature][:, + new_cropping_idxs].copy() + for feature in ['profile', 'deletion_mean']: + np_example[feature] = np_example[feature][new_cropping_idxs] + + # Make MSA mask. + np_example['msa_mask'] = np.ones_like( + np_example['msa'], dtype=np.float32) + + # Count MSA size before padding. + num_alignments = np_example['msa'].shape[0] + + # Pad: + msa_size, num_tokens = padding_shapes.msa_size, padding_shapes.num_tokens + + def safe_cast_int8(x): + return np.clip(x, np.iinfo(np.int8).min, np.iinfo(np.int8).max).astype( + np.int8 + ) + + return MSA( + rows=_pad_to(safe_cast_int8( + np_example['msa']), (msa_size, num_tokens)), + mask=_pad_to( + np_example['msa_mask'].astype(bool), (msa_size, num_tokens) + ), + # deletion_matrix may be out of int8 range, but we mostly care about + # small values since we arctan it in the model. + deletion_matrix=_pad_to( + safe_cast_int8(np_example['deletion_matrix']), + (msa_size, num_tokens), + ), + profile=_pad_to(np_example['profile'], (num_tokens, None)), + deletion_mean=_pad_to(np_example['deletion_mean'], (num_tokens,)), + num_alignments=np.array(num_alignments, dtype=np.int32), + ) + + def index_msa_rows(self, indices: xnp_ndarray) -> Self: + assert indices.ndim == 1 + + return MSA( + rows=self.rows[indices, :], + mask=self.mask[indices, :], + deletion_matrix=self.deletion_matrix[indices, :], + profile=self.profile, + deletion_mean=self.deletion_mean, + num_alignments=self.num_alignments, + ) + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + output = cls( + rows=batch['msa'], + mask=batch['msa_mask'], + deletion_matrix=batch['deletion_matrix'], + profile=batch['profile'], + deletion_mean=batch['deletion_mean'], + num_alignments=batch['num_alignments'], + ) + return output + + def as_data_dict(self) -> BatchDict: + return { + 'msa': self.rows, + 'msa_mask': self.mask, + 'deletion_matrix': self.deletion_matrix, + 'profile': self.profile, + 'deletion_mean': self.deletion_mean, + 'num_alignments': self.num_alignments, + } + + +@dataclasses.dataclass +class Templates: + """Dataclass containing templates.""" + + # aatype of templates, int32 w shape [num_templates, num_res] + aatype: xnp_ndarray + # atom positions of templates, float32 w shape [num_templates, num_res, 24, 3] + atom_positions: xnp_ndarray + # atom mask of templates, bool w shape [num_templates, num_res, 24] + atom_mask: xnp_ndarray + def __getitem__(self, idx): + return Templates(self.aatype[idx], self.atom_positions[idx], self.atom_mask[idx]) + @classmethod + def compute_features( + cls, + all_tokens: atom_layout.AtomLayout, + standard_token_idxs: np.ndarray, + padding_shapes: PaddingShapes, + fold_input: folding_input.Input, + max_templates: int, + logging_name: str, + ) -> Self: + """Compute the template features.""" + + seen_entities = {} + polymer_entity_features = {True: {}, False: {}} + + substruct = atom_layout.make_structure( + flat_layout=all_tokens, + atom_coords=np.zeros(all_tokens.shape + (3,)), + name=logging_name, + ) + np_chains_list = [] + + input_chains_by_id = {chain.id: chain for chain in fold_input.chains} + + nonempty_chain_ids = set(all_tokens.chain_id) + for chain_info in substruct.iter_chains(): + chain_id = chain_info['chain_id'] + chain_type = chain_info['chain_type'] + chain = input_chains_by_id[chain_id] + + # Generalised "sequence" for ligands (can't trust residue name) + 
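+            # The comma-joined res_name string acts as the entity key: chains
+            # with an identical string share one entity_id below, so their
+            # template features are computed only once and then reused.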
chain_tokens = all_tokens[all_tokens.chain_id == chain_id] + assert chain_tokens.res_name is not None + three_letter_sequence = ','.join(chain_tokens.res_name.tolist()) + chain_num_tokens = len(chain_tokens.atom_name) + + # Don't compute features for chains not included in the crop, or ligands. + skip_chain = ( + chain_type != mmcif_names.PROTEIN_CHAIN + or chain_num_tokens <= 4 # not cache filled + or chain_id not in nonempty_chain_ids + ) + + if three_letter_sequence in seen_entities: + entity_id = seen_entities[three_letter_sequence] + else: + entity_id = len(seen_entities) + 1 + + if entity_id not in polymer_entity_features[skip_chain]: + if skip_chain: + template_features = data3.empty_template_features( + chain_num_tokens) + else: + assert isinstance(chain, folding_input.ProteinChain) + + sorted_features = [] + for template in chain.templates: + struct = structure.from_mmcif( + template.mmcif, + fix_mse_residues=True, + fix_arginines=True, + include_bonds=False, + include_water=False, + # For non-standard polymer chains. + include_other=True, + ) + hit_features = templates.get_polymer_features( + chain=struct, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + query_sequence_length=len(chain.sequence), + query_to_hit_mapping=dict( + template.query_to_template_map), + ) + sorted_features.append(hit_features) + + template_features = templates.package_template_features( + hit_features=sorted_features, + include_ligand_features=False, + ) + + template_features = data3.fix_template_features( + sequence=chain.sequence, + template_features=template_features, + ) + + template_features = _reduce_template_features( + template_features, max_templates + ) + polymer_entity_features[skip_chain][entity_id] = template_features + + seen_entities[three_letter_sequence] = entity_id + feats = polymer_entity_features[skip_chain][entity_id].copy() + feats['chain_id'] = chain_id + np_chains_list.append(feats) + + # We pad the num_templates dimension before merging, so that different + # chains can be concatenated on the num_res dimension. Masking will be + # applied so that each chains templates can't see each other. + for chain in np_chains_list: + chain['template_aatype'] = _pad_to( + chain['template_aatype'], (max_templates, None) + ) + chain['template_atom_positions'] = _pad_to( + chain['template_atom_positions'], ( + max_templates, None, None, None) + ) + chain['template_atom_mask'] = _pad_to( + chain['template_atom_mask'], (max_templates, None, None) + ) + + # Merge on token dimension. + np_example = { + ft: np.concatenate([c[ft] for c in np_chains_list], axis=1) + for ft in np_chains_list[0] + if ft in data_constants.TEMPLATE_FEATURES + } + + # Crop template data. Need to use the standard token indices, since msa does + # not expand non-standard residues. This means that for expanded residues, + # we get repeated template information. + for feature_name, v in np_example.items(): + np_example[feature_name] = v[:max_templates, + standard_token_idxs, ...] + + # Pad along the token dimension. 
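+        # A None entry in the target shape passed to _pad_to leaves that axis
+        # unchanged, so only the token axis is padded to num_tokens here; the
+        # num_templates and per-atom axes keep their current sizes.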
+ templates_features = Templates( + aatype=_pad_to( + np_example['template_aatype'], (None, + padding_shapes.num_tokens) + ), + atom_positions=_pad_to( + np_example['template_atom_positions'], + (None, padding_shapes.num_tokens, None, None), + ), + atom_mask=_pad_to( + np_example['template_atom_mask'].astype(bool), + (None, padding_shapes.num_tokens, None), + ), + ) + return templates_features + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + """Make Template from batch dictionary.""" + return cls( + aatype=batch['template_aatype'], + atom_positions=batch['template_atom_positions'], + atom_mask=batch['template_atom_mask'], + ) + + def as_data_dict(self) -> BatchDict: + return { + 'template_aatype': self.aatype, + 'template_atom_positions': self.atom_positions, + 'template_atom_mask': self.atom_mask, + } + + +def _reduce_template_features( + template_features: data3.FeatureDict, + max_templates: int, +) -> data3.FeatureDict: + """Reduces template features to max num templates and defined feature set.""" + num_templates = template_features['template_aatype'].shape[0] + template_keep_mask = np.arange(num_templates) < max_templates + template_fields = data_constants.TEMPLATE_FEATURES + ( + 'template_release_timestamp', + ) + template_features = { + k: v[template_keep_mask] + for k, v in template_features.items() + if k in template_fields + } + return template_features + + +@dataclasses.dataclass +class TokenFeatures: + """Dataclass containing features for tokens.""" + + residue_index: xnp_ndarray + token_index: xnp_ndarray + aatype: xnp_ndarray + mask: xnp_ndarray + seq_length: xnp_ndarray + + # Chain symmetry identifiers + # for an A3B2 stoichiometry the meaning of these features is as follows: + # asym_id: 1 2 3 4 5 + # entity_id: 1 1 1 2 2 + # sym_id: 1 2 3 1 2 + asym_id: xnp_ndarray + entity_id: xnp_ndarray + sym_id: xnp_ndarray + + # token type features + is_protein: xnp_ndarray + is_rna: xnp_ndarray + is_dna: xnp_ndarray + is_ligand: xnp_ndarray + is_nonstandard_polymer_chain: xnp_ndarray + is_water: xnp_ndarray + + @classmethod + def compute_features( + cls, + all_tokens: atom_layout.AtomLayout, + padding_shapes: PaddingShapes, + ) -> Self: + """Compute the per-token features.""" + + residue_index = all_tokens.res_id.astype(np.int32) + + token_index = np.arange( + 1, len(all_tokens.atom_name) + 1).astype(np.int32) + + aatype = [] + for res_name, chain_type in zip(all_tokens.res_name, all_tokens.chain_type): + if chain_type in mmcif_names.POLYMER_CHAIN_TYPES: + res_name = mmcif_names.fix_non_standard_polymer_res( + res_name=res_name, chain_type=chain_type + ) + if ( + chain_type == mmcif_names.DNA_CHAIN + and res_name == residue_names.UNK_DNA + ): + res_name = residue_names.UNK_NUCLEIC_ONE_LETTER + elif chain_type in mmcif_names.NON_POLYMER_CHAIN_TYPES: + res_name = residue_names.UNK + else: + raise ValueError( + f'Chain type {chain_type} not polymer or ligand.') + aa = residue_names.POLYMER_TYPES_ORDER_WITH_UNKNOWN_AND_GAP[res_name] + aatype.append(aa) + aatype = np.array(aatype, dtype=np.int32) + + mask = np.ones(all_tokens.shape[0], dtype=bool) + chains = _compute_asym_entity_and_sym_id(all_tokens) + m = dict(zip(chains.chain_id, chains.asym_id)) + asym_id = np.array([m[c] for c in all_tokens.chain_id], dtype=np.int32) + + m = dict(zip(chains.chain_id, chains.entity_id)) + entity_id = np.array([m[c] + for c in all_tokens.chain_id], dtype=np.int32) + + m = dict(zip(chains.chain_id, chains.sym_id)) + sym_id = np.array([m[c] for c in all_tokens.chain_id], 
dtype=np.int32) + + seq_length = np.array(all_tokens.shape[0], dtype=np.int32) + + is_protein = all_tokens.chain_type == mmcif_names.PROTEIN_CHAIN + is_rna = all_tokens.chain_type == mmcif_names.RNA_CHAIN + is_dna = all_tokens.chain_type == mmcif_names.DNA_CHAIN + is_ligand = np.isin( + all_tokens.chain_type, list(mmcif_names.LIGAND_CHAIN_TYPES) + ) + standard_polymer_chain = list(mmcif_names.NON_POLYMER_CHAIN_TYPES) + list( + mmcif_names.STANDARD_POLYMER_CHAIN_TYPES + ) + is_nonstandard_polymer_chain = np.isin( + all_tokens.chain_type, standard_polymer_chain, invert=True + ) + is_water = all_tokens.chain_type == mmcif_names.WATER + + return TokenFeatures( + residue_index=_pad_to(residue_index, (padding_shapes.num_tokens,)), + token_index=_pad_to(token_index, (padding_shapes.num_tokens,)), + aatype=_pad_to(aatype, (padding_shapes.num_tokens,)), + mask=_pad_to(mask, (padding_shapes.num_tokens,)), + asym_id=_pad_to(asym_id, (padding_shapes.num_tokens,)), + entity_id=_pad_to(entity_id, (padding_shapes.num_tokens,)), + sym_id=_pad_to(sym_id, (padding_shapes.num_tokens,)), + seq_length=seq_length, + is_protein=_pad_to(is_protein, (padding_shapes.num_tokens,)), + is_rna=_pad_to(is_rna, (padding_shapes.num_tokens,)), + is_dna=_pad_to(is_dna, (padding_shapes.num_tokens,)), + is_ligand=_pad_to(is_ligand, (padding_shapes.num_tokens,)), + is_nonstandard_polymer_chain=_pad_to( + is_nonstandard_polymer_chain, (padding_shapes.num_tokens,) + ), + is_water=_pad_to(is_water, (padding_shapes.num_tokens,)), + ) + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + return cls( + residue_index=batch['residue_index'], + token_index=batch['token_index'], + aatype=batch['aatype'], + mask=batch['seq_mask'], + entity_id=batch['entity_id'], + asym_id=batch['asym_id'], + sym_id=batch['sym_id'], + seq_length=batch['seq_length'], + is_protein=batch['is_protein'], + is_rna=batch['is_rna'], + is_dna=batch['is_dna'], + is_ligand=batch['is_ligand'], + is_nonstandard_polymer_chain=batch['is_nonstandard_polymer_chain'], + is_water=batch['is_water'], + ) + + def as_data_dict(self) -> BatchDict: + return { + 'residue_index': self.residue_index, + 'token_index': self.token_index, + 'aatype': self.aatype, + 'seq_mask': self.mask, + 'entity_id': self.entity_id, + 'asym_id': self.asym_id, + 'sym_id': self.sym_id, + 'seq_length': self.seq_length, + 'is_protein': self.is_protein, + 'is_rna': self.is_rna, + 'is_dna': self.is_dna, + 'is_ligand': self.is_ligand, + 'is_nonstandard_polymer_chain': self.is_nonstandard_polymer_chain, + 'is_water': self.is_water, + } + + +@dataclasses.dataclass +class PredictedStructureInfo: + """Contains information necessary to work with predicted structure.""" + + atom_mask: xnp_ndarray + residue_center_index: xnp_ndarray + + @classmethod + def compute_features( + cls, + all_tokens: atom_layout.AtomLayout, + all_token_atoms_layout: atom_layout.AtomLayout, + padding_shapes: PaddingShapes, + ) -> Self: + """Compute the PredictedStructureInfo features. + + Args: + all_tokens: flat AtomLayout with 1 representative atom per token, shape + (num_tokens,) + all_token_atoms_layout: AtomLayout for all atoms per token, shape + (num_tokens, max_atoms_per_token) + padding_shapes: padding shapes. + + Returns: + A PredictedStructureInfo object. 
+ """ + atom_mask = _pad_to( + all_token_atoms_layout.atom_name.astype(bool), + (padding_shapes.num_tokens, None), + ) + residue_center_index = np.zeros( + padding_shapes.num_tokens, dtype=np.int32) + for idx in range(all_tokens.shape[0]): + repr_atom = all_tokens.atom_name[idx] + atoms = list(all_token_atoms_layout.atom_name[idx, :]) + if repr_atom in atoms: + residue_center_index[idx] = atoms.index(repr_atom) + else: + # Representative atoms can be missing if cropping the number of atoms + # per residue. + logging.warning( + 'The representative atom in all_tokens (%s) is not in ' + 'all_token_atoms_layout (%s)', + all_tokens[idx: idx + 1], + all_token_atoms_layout[idx, :], + ) + residue_center_index[idx] = 0 + return cls(atom_mask=atom_mask, residue_center_index=residue_center_index) + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + return cls( + atom_mask=batch['pred_dense_atom_mask'], + residue_center_index=batch['residue_center_index'], + ) + + def as_data_dict(self) -> BatchDict: + return { + 'pred_dense_atom_mask': self.atom_mask, + 'residue_center_index': self.residue_center_index, + } + + +@dataclasses.dataclass +class PolymerLigandBondInfo: + """Contains information about polymer-ligand bonds.""" + + tokens_to_polymer_ligand_bonds: atom_layout.GatherInfo + # Gather indices to convert from cropped dense atom layout to bonds layout + # (num_tokens, 2) + token_atoms_to_bonds: atom_layout.GatherInfo + + @classmethod + def compute_features( + cls, + all_tokens: atom_layout.AtomLayout, + all_token_atoms_layout: atom_layout.AtomLayout, + bond_layout: atom_layout.AtomLayout | None, + padding_shapes: PaddingShapes, + ) -> Self: + """Computes the InterChainBondInfo features. + + Args: + all_tokens: AtomLayout for tokens; shape (num_tokens,). + all_token_atoms_layout: Atom Layout for all atoms (num_tokens, + max_atoms_per_token) + bond_layout: Bond layout for polymer-ligand bonds. + padding_shapes: Padding shapes. + + Returns: + A PolymerLigandBondInfo object. + """ + + if bond_layout is not None: + # Must convert to list before calling np.isin, will not work raw. + peptide_types = list(mmcif_names.PEPTIDE_CHAIN_TYPES) + nucleic_types = list(mmcif_names.NUCLEIC_ACID_CHAIN_TYPES) + [ + mmcif_names.OTHER_CHAIN + ] + # These atom renames are so that we can use the atom layout code with + # all_tokens, which only has a single atom per token. + atom_names = bond_layout.atom_name.copy() + atom_names[np.isin(bond_layout.chain_type, peptide_types)] = 'CA' + atom_names[np.isin(bond_layout.chain_type, nucleic_types)] = "C1'" + adjusted_bond_layout = atom_layout.AtomLayout( + atom_name=atom_names, + res_id=bond_layout.res_id, + chain_id=bond_layout.chain_id, + chain_type=bond_layout.chain_type, + ) + # Remove bonds that are not in the crop. + cropped_tokens_to_bonds = atom_layout.compute_gather_idxs( + source_layout=all_tokens, target_layout=adjusted_bond_layout + ) + bond_is_in_crop = np.all( + cropped_tokens_to_bonds.gather_mask, axis=1 + ).astype(bool) + adjusted_bond_layout = adjusted_bond_layout[bond_is_in_crop, :] + else: + # Create layout with correct shape when bond_layout is None. 
+ s = (0, 2) + adjusted_bond_layout = atom_layout.AtomLayout( + atom_name=np.array([], dtype=object).reshape(s), + res_id=np.array([], dtype=int).reshape(s), + chain_id=np.array([], dtype=object).reshape(s), + ) + adjusted_bond_layout = adjusted_bond_layout.copy_and_pad_to( + (padding_shapes.num_tokens, 2) + ) + tokens_to_polymer_ligand_bonds = atom_layout.compute_gather_idxs( + source_layout=all_tokens, target_layout=adjusted_bond_layout + ) + + # Stuff for computing the bond loss. + if bond_layout is not None: + # Pad to num_tokens (hoping that there are never more bonds than tokens). + padded_bond_layout = bond_layout.copy_and_pad_to( + (padding_shapes.num_tokens, 2) + ) + token_atoms_to_bonds = atom_layout.compute_gather_idxs( + source_layout=all_token_atoms_layout, target_layout=padded_bond_layout + ) + else: + token_atoms_to_bonds = atom_layout.GatherInfo( + gather_idxs=np.zeros( + (padding_shapes.num_tokens, 2), dtype=int), + gather_mask=np.zeros( + (padding_shapes.num_tokens, 2), dtype=bool), + input_shape=np.array(( + padding_shapes.num_tokens, + all_token_atoms_layout.shape[1], + )), + ) + + return cls( + tokens_to_polymer_ligand_bonds=tokens_to_polymer_ligand_bonds, + token_atoms_to_bonds=token_atoms_to_bonds, + ) + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + return cls( + tokens_to_polymer_ligand_bonds=atom_layout.GatherInfo.from_dict( + batch, key_prefix='tokens_to_polymer_ligand_bonds' + ), + token_atoms_to_bonds=atom_layout.GatherInfo.from_dict( + batch, key_prefix='token_atoms_to_polymer_ligand_bonds' + ), + ) + + def as_data_dict(self) -> BatchDict: + return { + **self.tokens_to_polymer_ligand_bonds.as_dict( + key_prefix='tokens_to_polymer_ligand_bonds' + ), + **self.token_atoms_to_bonds.as_dict( + key_prefix='token_atoms_to_polymer_ligand_bonds' + ), + } + + +@dataclasses.dataclass +class LigandLigandBondInfo: + """Contains information about the location of ligand-ligand bonds.""" + + tokens_to_ligand_ligand_bonds: atom_layout.GatherInfo + + @classmethod + def compute_features( + cls, + all_tokens: atom_layout.AtomLayout, + bond_layout: atom_layout.AtomLayout | None, + padding_shapes: PaddingShapes, + ) -> Self: + """Computes the InterChainBondInfo features. + + Args: + all_tokens: AtomLayout for tokens; shape (num_tokens,). + bond_layout: Bond layout for ligand-ligand bonds. + padding_shapes: Padding shapes. + + Returns: + A LigandLigandBondInfo object. + """ + + if bond_layout is not None: + # Discard any bonds that do not join to an existing atom. + keep_mask = [] + all_atom_ids = { + uid + for uid in zip( + all_tokens.chain_id, + all_tokens.res_id, + all_tokens.atom_name, + strict=True, + ) + } + for chain_id, res_id, atom_name in zip( + bond_layout.chain_id, + bond_layout.res_id, + bond_layout.atom_name, + strict=True, + ): + atom_a = (chain_id[0], res_id[0], atom_name[0]) + atom_b = (chain_id[1], res_id[1], atom_name[1]) + if atom_a in all_atom_ids and atom_b in all_atom_ids: + keep_mask.append(True) + else: + keep_mask.append(False) + keep_mask = np.array(keep_mask).astype(bool) + bond_layout = bond_layout[keep_mask] + # Remove any bonds to Hydrogen atoms. 
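+            # A bond is dropped if either of its two endpoints has an atom
+            # name starting with 'H' (the .any(axis=1) spans both partners).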
+ bond_layout = bond_layout[ + ~np.char.startswith(bond_layout.atom_name.astype(str), 'H').any( + axis=1 + ) + ] + atom_names = bond_layout.atom_name + adjusted_bond_layout = atom_layout.AtomLayout( + atom_name=atom_names, + res_id=bond_layout.res_id, + chain_id=bond_layout.chain_id, + chain_type=bond_layout.chain_type, + ) + else: + # Create layout with correct shape when bond_layout is None. + s = (0, 2) + adjusted_bond_layout = atom_layout.AtomLayout( + atom_name=np.array([], dtype=object).reshape(s), + res_id=np.array([], dtype=int).reshape(s), + chain_id=np.array([], dtype=object).reshape(s), + ) + # 10 x num_tokens as max_inter_bonds_ratio + max_intra_bonds_ration = 2.061. + adjusted_bond_layout = adjusted_bond_layout.copy_and_pad_to( + (padding_shapes.num_tokens * 10, 2) + ) + gather_idx = atom_layout.compute_gather_idxs( + source_layout=all_tokens, target_layout=adjusted_bond_layout + ) + return cls(tokens_to_ligand_ligand_bonds=gather_idx) + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + return cls( + tokens_to_ligand_ligand_bonds=atom_layout.GatherInfo.from_dict( + batch, key_prefix='tokens_to_ligand_ligand_bonds' + ) + ) + + def as_data_dict(self) -> BatchDict: + return { + **self.tokens_to_ligand_ligand_bonds.as_dict( + key_prefix='tokens_to_ligand_ligand_bonds' + ) + } + + +@dataclasses.dataclass +class PseudoBetaInfo: + """Contains information for extracting pseudo-beta and equivalent atoms.""" + + token_atoms_to_pseudo_beta: atom_layout.GatherInfo + + @classmethod + def compute_features( + cls, + all_token_atoms_layout: atom_layout.AtomLayout, + ccd: chemical_components.Ccd, + padding_shapes: PaddingShapes, + logging_name: str, + ) -> Self: + """Compute the PseudoBetaInfo features. + + Args: + all_token_atoms_layout: AtomLayout for all atoms per token, shape + (num_tokens, max_atoms_per_token) + ccd: The chemical components dictionary. + padding_shapes: padding shapes. + logging_name: logging name for debugging (usually the mmcif_id) + + Returns: + A PseudoBetaInfo object. + """ + token_idxs = [] + atom_idxs = [] + for token_idx in range(all_token_atoms_layout.shape[0]): + chain_type = all_token_atoms_layout.chain_type[token_idx, 0] + atom_names = list(all_token_atoms_layout.atom_name[token_idx, :]) + atom_idx = None + is_nucleic_backbone = ( + chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES + or chain_type == mmcif_names.OTHER_CHAIN + ) + if chain_type == mmcif_names.PROTEIN_CHAIN: + # Protein chains + if 'CB' in atom_names: + atom_idx = atom_names.index('CB') + elif 'CA' in atom_names: + atom_idx = atom_names.index('CA') + elif is_nucleic_backbone: + # RNA / DNA chains + res_name = all_token_atoms_layout.res_name[token_idx, 0] + cifdict = ccd.get(res_name) + if cifdict: + parent = cifdict['_chem_comp.mon_nstd_parent_comp_id'][0] + if parent != '?': + res_name = parent + if res_name in {'A', 'G', 'DA', 'DG'}: + if 'C4' in atom_names: + atom_idx = atom_names.index('C4') + else: + if 'C2' in atom_names: + atom_idx = atom_names.index('C2') + elif chain_type in mmcif_names.NON_POLYMER_CHAIN_TYPES: + # Ligands: there is only one atom per token + atom_idx = 0 + else: + logging.warning( + '%s: Unknown chain type for token %i. 
(%s)', + logging_name, + token_idx, + all_token_atoms_layout[token_idx: token_idx + 1], + ) + atom_idx = 0 + if atom_idx is None: + (valid_atom_idxs,) = np.nonzero( + all_token_atoms_layout.atom_name[token_idx, :] + ) + if valid_atom_idxs.shape[0] > 0: + atom_idx = valid_atom_idxs[0] + else: + atom_idx = 0 + logging.warning( + '%s token %i (%s), does not contain a pseudo-beta atom.' + 'Using first valid atom (%s) instead.', + logging_name, + token_idx, + all_token_atoms_layout[token_idx: token_idx + 1], + all_token_atoms_layout.atom_name[token_idx, atom_idx], + ) + + token_idxs.append(token_idx) + atom_idxs.append(atom_idx) + + pseudo_beta_layout = all_token_atoms_layout[token_idxs, atom_idxs] + pseudo_beta_layout = pseudo_beta_layout.copy_and_pad_to(( + padding_shapes.num_tokens, + )) + token_atoms_to_pseudo_beta = atom_layout.compute_gather_idxs( + source_layout=all_token_atoms_layout, target_layout=pseudo_beta_layout + ) + + return cls( + token_atoms_to_pseudo_beta=token_atoms_to_pseudo_beta, + ) + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + return cls( + token_atoms_to_pseudo_beta=atom_layout.GatherInfo.from_dict( + batch, key_prefix='token_atoms_to_pseudo_beta' + ), + ) + + def as_data_dict(self) -> BatchDict: + return { + **self.token_atoms_to_pseudo_beta.as_dict( + key_prefix='token_atoms_to_pseudo_beta' + ), + } + + +_DEFAULT_BLANK_REF = { + 'positions': np.zeros(3), + 'mask': 0, + 'element': 0, + 'charge': 0, + 'atom_name_chars': np.zeros(4), +} + + +def random_rotation(random_state: np.random.RandomState) -> np.ndarray: + # Create a random rotation (Gram-Schmidt orthogonalization of two + # random normal vectors) + v0, v1 = random_state.normal(size=(2, 3)) + e0 = v0 / np.maximum(1e-10, np.linalg.norm(v0)) + v1 = v1 - e0 * np.dot(v1, e0) + e1 = v1 / np.maximum(1e-10, np.linalg.norm(v1)) + e2 = np.cross(e0, e1) + return np.stack([e0, e1, e2]) + + +def random_augmentation( + positions: np.ndarray, + random_state: np.random.RandomState, +) -> np.ndarray: + """Center then apply random translation and rotation.""" + + center = np.mean(positions, axis=0) + rot = random_rotation(random_state) + positions_target = np.einsum('ij,kj->ki', rot, positions - center) + + translation = random_state.normal(size=(3,)) + positions_target = positions_target + translation + return positions_target + + +def get_reference( + res_name: str, + chemical_components_data: struc_chem_comps.ChemicalComponentsData, + ccd: chemical_components.Ccd, + random_state: np.random.RandomState, + ref_max_modified_date: datetime.date, + intra_ligand_ptm_bonds: bool, +) -> tuple[dict[str, Any], Any, Any]: + """Reference structure for residue from CCD or SMILES. + + Args: + res_name: ccd code of the residue. + chemical_components_data: ChemicalComponentsData for making ref structure. + ccd: The chemical components dictionary. + random_state: Numpy RandomState + ref_max_modified_date: date beyond which reference structures must not be + modefied. + intra_ligand_ptm_bonds: Whether to return intra ligand/ ptm bonds. + + Returns: + Mapping from atom names to features, from_atoms, dest_atoms. + """ + ccd_cif = ccd.get(res_name) + non_ccd_with_smiles = False + if not ccd_cif: + # If res name is non-CCD try to get SMILES from chem comp dict. 
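+        # Fallback order for the reference conformer: CCD entry first, then a
+        # user-supplied SMILES from chemical_components_data; if neither is
+        # available the residue gets empty reference features.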
+ has_smiles = ( + chemical_components_data.chem_comp + and res_name in chemical_components_data.chem_comp + and chemical_components_data.chem_comp[res_name].pdbx_smiles + ) + if has_smiles: + non_ccd_with_smiles = True + else: + # If no SMILES or CCD, return empty dictionary. + return dict(), None, None + + pos = [] + elements = [] + charges = [] + atom_names = [] + + mol_from_smiles = None # useless init to make pylint happy + if non_ccd_with_smiles: + smiles_string = chemical_components_data.chem_comp[res_name].pdbx_smiles + mol_from_smiles = Chem.MolFromSmiles(smiles_string) + if mol_from_smiles is None: + logging.warning( + 'Fail to construct RDKit Mol from the SMILES string: %s', + smiles_string, + ) + return dict(), None, None + # Note this does not contain ideal coordinates, just bonds. + ccd_cif = rdkit_utils.mol_to_ccd_cif( + mol_from_smiles, component_id='fake_cif' + ) + + # RDKit for non-CCD structure and if ref should be a random RDKit conformer. + try: + if non_ccd_with_smiles: + m = mol_from_smiles + m = Chem.AddHs(m) + m = rdkit_utils.assign_atom_names_from_graph( + m, keep_existing_names=True) + logging.info( + 'Success constructing SMILES reference structure for: %s', res_name + ) + else: + m = rdkit_utils.mol_from_ccd_cif(ccd_cif, remove_hydrogens=False) + # Stochastic conformer search method. + # V3 is the latest and supports macrocycles . + params = AllChem.ETKDGv3() + params.randomSeed = int(random_state.randint(1, 1 << 31)) + AllChem.EmbedMolecule(m, params) + conformer = m.GetConformer() + for i, atom in enumerate(m.GetAtoms()): + elements.append(atom.GetAtomicNum()) + charges.append(atom.GetFormalCharge()) + name = atom.GetProp('atom_name') + atom_names.append(name) + coords = conformer.GetAtomPosition(i) + pos.append([coords.x, coords.y, coords.z]) + pos = np.array(pos, dtype=np.float32) + except (rdkit_utils.MolFromMmcifError, ValueError): + logging.warning( + 'Failed to construct RDKit reference structure for: %s', res_name + ) + + if not atom_names: + # Get CCD ideal coordinates if RDKit fails. + atom_names = ccd_cif['_chem_comp_atom.atom_id'] + # If mol_from_smiles then it won't have ideal coordinates by default. + if '_chem_comp_atom.pdbx_model_Cartn_x_ideal' in ccd_cif: + atom_x = ccd_cif['_chem_comp_atom.pdbx_model_Cartn_x_ideal'] + atom_y = ccd_cif['_chem_comp_atom.pdbx_model_Cartn_y_ideal'] + atom_z = ccd_cif['_chem_comp_atom.pdbx_model_Cartn_z_ideal'] + else: + atom_x = np.array(['?'] * len(atom_names)) + atom_y = np.array(['?'] * len(atom_names)) + atom_z = np.array(['?'] * len(atom_names)) + type_symbols = ccd_cif['_chem_comp_atom.type_symbol'] + charges = ccd_cif['_chem_comp_atom.charge'] + elements = [ + periodic_table.ATOMIC_NUMBER.get(elem_type.capitalize(), 0) + for elem_type in type_symbols + ] + pos = np.array([[x, y, z] for x, y, z in zip(atom_x, atom_y, atom_z)]) + # Unknown reference coordinates are specified by '?' in chem comp dict. + # Replace unknown reference coords with 0. + if '?' in pos and '_chem_comp.pdbx_modified_date' in ccd_cif: + # Use reference coordinates if modified date is before cutoff. + modified_dates = [ + datetime.date.fromisoformat(date) + for date in ccd_cif['_chem_comp.pdbx_modified_date'] + ] + max_modified_date = max(modified_dates) + if max_modified_date < ref_max_modified_date: + atom_x = ccd_cif['_chem_comp_atom.model_Cartn_x'] + atom_y = ccd_cif['_chem_comp_atom.model_Cartn_y'] + atom_z = ccd_cif['_chem_comp_atom.model_Cartn_z'] + pos = np.array([[x, y, z] + for x, y, z in zip(atom_x, atom_y, atom_z)]) + if '?' 
in pos: + if np.all(pos == '?'): + logging.warning('All ref positions unknown for: %s', res_name) + else: + logging.warning('Some ref positions unknown for: %s', res_name) + pos[pos == '?'] = 0 + pos = np.array(pos, dtype=np.float32) + + pos = random_augmentation(pos, random_state) + + if intra_ligand_ptm_bonds: + assert ccd_cif is not None, 'CCD CIF is None' + from_atom = ccd_cif.get('_chem_comp_bond.atom_id_1', None) + dest_atom = ccd_cif.get('_chem_comp_bond.atom_id_2', None) + else: + from_atom = None + dest_atom = None + + features = {} + for atom_name in atom_names: + features[atom_name] = {} + idx = atom_names.index(atom_name) + charge = 0 if charges[idx] == '?' else int(charges[idx]) + atom_name_chars = np.array([ord(c) - 32 for c in atom_name], dtype=int) + atom_name_chars = _pad_to(atom_name_chars, (4,)) + features[atom_name]['positions'] = pos[idx] + features[atom_name]['mask'] = 1 + features[atom_name]['element'] = elements[idx] + features[atom_name]['charge'] = charge + features[atom_name]['atom_name_chars'] = atom_name_chars + return features, from_atom, dest_atom + + +@dataclasses.dataclass +class RefStructure: + """Contains ref structure information.""" + + # Array with positions, float32, shape [num_res, max_atoms_per_token, 3] + positions: xnp_ndarray + # Array with masks, bool, shape [num_res, max_atoms_per_token] + mask: xnp_ndarray + # Array with elements, int32, shape [num_res, max_atoms_per_token] + element: xnp_ndarray + # Array with charges, float32, shape [num_res, max_atoms_per_token] + charge: xnp_ndarray + # Array with atom name characters, int32, [num_res, max_atoms_per_token, 4] + atom_name_chars: xnp_ndarray + # Array with reference space uids, int32, [num_res, max_atoms_per_token] + ref_space_uid: xnp_ndarray + + @classmethod + def compute_features( + cls, + all_token_atoms_layout: atom_layout.AtomLayout, + ccd: chemical_components.Ccd, + padding_shapes: PaddingShapes, + chemical_components_data: struc_chem_comps.ChemicalComponentsData, + random_state: np.random.RandomState, + ref_max_modified_date: datetime.date, + intra_ligand_ptm_bonds: bool, + ligand_ligand_bonds: atom_layout.AtomLayout | None = None, + ) -> tuple[Self, Any]: + """Reference structure information for each residue.""" + + # Get features per atom + padded_shape = (padding_shapes.num_tokens, + all_token_atoms_layout.shape[1]) + result = { + 'positions': np.zeros((*padded_shape, 3), 'float32'), + 'mask': np.zeros(padded_shape, 'bool'), + 'element': np.zeros(padded_shape, 'int32'), + 'charge': np.zeros(padded_shape, 'float32'), + 'atom_name_chars': np.zeros((*padded_shape, 4), 'int32'), + 'ref_space_uid': np.zeros((*padded_shape,), 'int32'), + } + + atom_names_all = [] + chain_ids_all = [] + res_ids_all = [] + + # Cache reference conformations for each residue. 
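+        # Conformers are cached per (chain_id, res_id) so that every atom of a
+        # residue is read from the same randomly embedded conformer, and all
+        # of them share one ref_space_uid further below.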
+ conformations = {} + ref_space_uids = {} + for idx in np.ndindex(all_token_atoms_layout.shape): + chain_id = all_token_atoms_layout.chain_id[idx] + res_id = all_token_atoms_layout.res_id[idx] + res_name = all_token_atoms_layout.res_name[idx] + is_non_standard = res_name not in _STANDARD_RESIDUES + atom_name = all_token_atoms_layout.atom_name[idx] + if not atom_name: + ref = _DEFAULT_BLANK_REF + else: + if (chain_id, res_id) not in conformations: + conf, from_atom, dest_atom = get_reference( + res_name=res_name, + chemical_components_data=chemical_components_data, + ccd=ccd, + random_state=random_state, + ref_max_modified_date=ref_max_modified_date, + intra_ligand_ptm_bonds=intra_ligand_ptm_bonds, + ) + conformations[(chain_id, res_id)] = conf + + if ( + is_non_standard + and (from_atom is not None) + and (dest_atom is not None) + ): + # Add intra-ligand bond graph + atom_names_ligand = np.stack( + [from_atom, dest_atom], axis=1, dtype=object + ) + atom_names_all.append(atom_names_ligand) + res_ids_all.append( + np.full_like(atom_names_ligand, res_id, dtype=int) + ) + chain_ids_all.append( + np.full_like(atom_names_ligand, + chain_id, dtype=object) + ) + + conformation = conformations.get( + (chain_id, res_id), {atom_name: _DEFAULT_BLANK_REF} + ) + if atom_name not in conformation: + logging.warning( + 'Missing atom "%s" for CCD "%s"', + atom_name, + all_token_atoms_layout.res_name[idx], + ) + ref = conformation.get(atom_name, _DEFAULT_BLANK_REF) + for k in ref: + result[k][idx] = ref[k] + + # Assign a unique reference space id to each component, to determine which + # reference positions live in the same reference space. + space_str_id = ( + all_token_atoms_layout.chain_id[idx], + all_token_atoms_layout.res_id[idx], + ) + if space_str_id not in ref_space_uids: + ref_space_uids[space_str_id] = len(ref_space_uids) + result['ref_space_uid'][idx] = ref_space_uids[space_str_id] + + if atom_names_all: + atom_names_all = np.concatenate(atom_names_all, axis=0) + res_ids_all = np.concatenate(res_ids_all, axis=0) + chain_ids_all = np.concatenate(chain_ids_all, axis=0) + if ligand_ligand_bonds is not None: + adjusted_ligand_ligand_bonds = atom_layout.AtomLayout( + atom_name=np.concatenate( + [ligand_ligand_bonds.atom_name, atom_names_all], axis=0 + ), + chain_id=np.concatenate( + [ligand_ligand_bonds.chain_id, chain_ids_all], axis=0 + ), + res_id=np.concatenate( + [ligand_ligand_bonds.res_id, res_ids_all], axis=0 + ), + ) + else: + adjusted_ligand_ligand_bonds = atom_layout.AtomLayout( + atom_name=atom_names_all, + chain_id=chain_ids_all, + res_id=res_ids_all, + ) + else: + adjusted_ligand_ligand_bonds = ligand_ligand_bonds + + return cls(**result), adjusted_ligand_ligand_bonds + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + return cls( + positions=batch['ref_pos'], + mask=batch['ref_mask'], + element=batch['ref_element'], + charge=batch['ref_charge'], + atom_name_chars=batch['ref_atom_name_chars'], + ref_space_uid=batch['ref_space_uid'], + ) + + def as_data_dict(self) -> BatchDict: + return { + 'ref_pos': self.positions, + 'ref_mask': self.mask, + 'ref_element': self.element, + 'ref_charge': self.charge, + 'ref_atom_name_chars': self.atom_name_chars, + 'ref_space_uid': self.ref_space_uid, + } + + +@dataclasses.dataclass +class ConvertModelOutput: + """Contains atom layout info.""" + + cleaned_struc: structure.Structure + token_atoms_layout: atom_layout.AtomLayout + flat_output_layout: atom_layout.AtomLayout + empty_output_struc: structure.Structure + polymer_ligand_bonds: 
atom_layout.AtomLayout + ligand_ligand_bonds: atom_layout.AtomLayout + + @classmethod + def compute_features( + cls, + all_token_atoms_layout: atom_layout.AtomLayout, + padding_shapes: PaddingShapes, + cleaned_struc: structure.Structure, + flat_output_layout: atom_layout.AtomLayout, + empty_output_struc: structure.Structure, + polymer_ligand_bonds: atom_layout.AtomLayout, + ligand_ligand_bonds: atom_layout.AtomLayout, + ) -> Self: + """Pads the all_token_atoms_layout and stores other data.""" + # Crop and pad the all_token_atoms_layout. + token_atoms_layout = all_token_atoms_layout.copy_and_pad_to( + (padding_shapes.num_tokens, all_token_atoms_layout.shape[1]) + ) + + return cls( + cleaned_struc=cleaned_struc, + token_atoms_layout=token_atoms_layout, + flat_output_layout=flat_output_layout, + empty_output_struc=empty_output_struc, + polymer_ligand_bonds=polymer_ligand_bonds, + ligand_ligand_bonds=ligand_ligand_bonds, + ) + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + """Construct atom layout object from dictionary.""" + + return cls( + cleaned_struc=_unwrap(batch.get('cleaned_struc', None)), + token_atoms_layout=_unwrap(batch.get('token_atoms_layout', None)), + flat_output_layout=_unwrap(batch.get('flat_output_layout', None)), + empty_output_struc=_unwrap(batch.get('empty_output_struc', None)), + polymer_ligand_bonds=_unwrap( + batch.get('polymer_ligand_bonds', None)), + ligand_ligand_bonds=_unwrap( + batch.get('ligand_ligand_bonds', None)), + ) + + def as_data_dict(self) -> BatchDict: + return { + 'cleaned_struc': np.array(self.cleaned_struc, object), + 'token_atoms_layout': np.array(self.token_atoms_layout, object), + 'flat_output_layout': np.array(self.flat_output_layout, object), + 'empty_output_struc': np.array(self.empty_output_struc, object), + 'polymer_ligand_bonds': np.array(self.polymer_ligand_bonds, object), + 'ligand_ligand_bonds': np.array(self.ligand_ligand_bonds, object), + } + + +@dataclasses.dataclass +class AtomCrossAtt: + """Operate on flat atoms.""" + + token_atoms_to_queries: atom_layout.GatherInfo + tokens_to_queries: atom_layout.GatherInfo + tokens_to_keys: atom_layout.GatherInfo + queries_to_keys: atom_layout.GatherInfo + queries_to_token_atoms: atom_layout.GatherInfo + + @classmethod + def compute_features( + cls, + # (num_tokens, num_dense) + all_token_atoms_layout: atom_layout.AtomLayout, + queries_subset_size: int, + keys_subset_size: int, + padding_shapes: PaddingShapes, + ) -> Self: + """Computes gather indices and meta data to work with a flat atom list.""" + + token_atoms_layout = all_token_atoms_layout.copy_and_pad_to( + (padding_shapes.num_tokens, all_token_atoms_layout.shape[1]) + ) + token_atoms_mask = token_atoms_layout.atom_name.astype(bool) + flat_layout = token_atoms_layout[token_atoms_mask] + num_atoms = flat_layout.shape[0] + + padded_flat_layout = flat_layout.copy_and_pad_to(( + padding_shapes.num_atoms, + )) + + # Create the layout for queries + num_subsets = padding_shapes.num_atoms // queries_subset_size + lay_arr = padded_flat_layout.to_array() + queries_layout = atom_layout.AtomLayout.from_array( + lay_arr.reshape((6, num_subsets, queries_subset_size)) + ) + + # Create the layout for the keys (the key subsets are centered around the + # query subsets) + # Create initial gather indices (contain out-of-bound indices) + subset_centers = np.arange( + queries_subset_size / 2, padding_shapes.num_atoms, queries_subset_size + ) + flat_to_key_gathers = ( + subset_centers[:, None] + + np.arange(-keys_subset_size / 2, 
keys_subset_size / 2)[None, :] + ) + flat_to_key_gathers = flat_to_key_gathers.astype(int) + # Shift subsets with out-of-bound indices, such that they are fully within + # the bounds. + for row in range(flat_to_key_gathers.shape[0]): + if flat_to_key_gathers[row, 0] < 0: + flat_to_key_gathers[row, :] -= flat_to_key_gathers[row, 0] + elif flat_to_key_gathers[row, -1] > num_atoms - 1: + overflow = flat_to_key_gathers[row, -1] - (num_atoms - 1) + flat_to_key_gathers[row, :] -= overflow + # Create the keys layout. + keys_layout = padded_flat_layout[flat_to_key_gathers] + + # Create gather indices for conversion between token atoms layout, + # queries layout and keys layout. + token_atoms_to_queries = atom_layout.compute_gather_idxs( + source_layout=token_atoms_layout, target_layout=queries_layout + ) + + token_atoms_to_keys = atom_layout.compute_gather_idxs( + source_layout=token_atoms_layout, target_layout=keys_layout + ) + + queries_to_keys = atom_layout.compute_gather_idxs( + source_layout=queries_layout, target_layout=keys_layout + ) + + queries_to_token_atoms = atom_layout.compute_gather_idxs( + source_layout=queries_layout, target_layout=token_atoms_layout + ) + + # Create gather indices for conversion of tokens layout to + # queries and keys layout + token_idxs = np.arange(padding_shapes.num_tokens).astype(np.int64) + token_idxs = np.broadcast_to( + token_idxs[:, None], token_atoms_layout.shape) + tokens_to_queries = atom_layout.GatherInfo( + gather_idxs=atom_layout.convert( + token_atoms_to_queries, token_idxs, layout_axes=(0, 1) + ), + gather_mask=atom_layout.convert( + token_atoms_to_queries, token_atoms_mask, layout_axes=(0, 1) + ), + input_shape=np.array((padding_shapes.num_tokens,)), + ) + + tokens_to_keys = atom_layout.GatherInfo( + gather_idxs=atom_layout.convert( + token_atoms_to_keys, token_idxs, layout_axes=(0, 1) + ), + gather_mask=atom_layout.convert( + token_atoms_to_keys, token_atoms_mask, layout_axes=(0, 1) + ), + input_shape=np.array((padding_shapes.num_tokens,)), + ) + + return cls( + token_atoms_to_queries=token_atoms_to_queries, + tokens_to_queries=tokens_to_queries, + tokens_to_keys=tokens_to_keys, + queries_to_keys=queries_to_keys, + queries_to_token_atoms=queries_to_token_atoms, + ) + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + return cls( + token_atoms_to_queries=atom_layout.GatherInfo.from_dict( + batch, key_prefix='token_atoms_to_queries' + ), + tokens_to_queries=atom_layout.GatherInfo.from_dict( + batch, key_prefix='tokens_to_queries' + ), + tokens_to_keys=atom_layout.GatherInfo.from_dict( + batch, key_prefix='tokens_to_keys' + ), + queries_to_keys=atom_layout.GatherInfo.from_dict( + batch, key_prefix='queries_to_keys' + ), + queries_to_token_atoms=atom_layout.GatherInfo.from_dict( + batch, key_prefix='queries_to_token_atoms' + ), + ) + + def as_data_dict(self) -> BatchDict: + return { + **self.token_atoms_to_queries.as_dict( + key_prefix='token_atoms_to_queries' + ), + **self.tokens_to_queries.as_dict(key_prefix='tokens_to_queries'), + **self.tokens_to_keys.as_dict(key_prefix='tokens_to_keys'), + **self.queries_to_keys.as_dict(key_prefix='queries_to_keys'), + **self.queries_to_token_atoms.as_dict( + key_prefix='queries_to_token_atoms' + ), + } + + +@dataclasses.dataclass +class Frames: + """Features for backbone frames.""" + + mask: xnp_ndarray + + @classmethod + def compute_features( + cls, + all_tokens: atom_layout.AtomLayout, + all_token_atoms_layout: atom_layout.AtomLayout, + ref_structure: RefStructure, + padding_shapes: 
PaddingShapes, + ) -> Self: + """Computes features for backbone frames.""" + num_tokens = padding_shapes.num_tokens + all_token_atoms_layout = all_token_atoms_layout.copy_and_pad_to( + (num_tokens, all_token_atoms_layout.shape[1]) + ) + + all_token_atoms_to_all_tokens = atom_layout.compute_gather_idxs( + source_layout=all_token_atoms_layout, target_layout=all_tokens + ) + ref_coordinates = atom_layout.convert( + all_token_atoms_to_all_tokens, + ref_structure.positions.astype(np.float32), + layout_axes=(0, 1), + ) + ref_mask = atom_layout.convert( + all_token_atoms_to_all_tokens, + ref_structure.mask.astype(bool), + layout_axes=(0, 1), + ) + ref_mask = ref_mask & all_token_atoms_to_all_tokens.gather_mask.astype( + bool) + + all_frame_mask = [] + + # Iterate over tokens + for idx, args in enumerate( + zip(all_tokens.chain_type, all_tokens.chain_id, all_tokens.res_id) + ): + + chain_type, chain_id, res_id = args + + if chain_type in list(mmcif_names.PEPTIDE_CHAIN_TYPES): + frame_mask = True + elif chain_type in list(mmcif_names.NUCLEIC_ACID_CHAIN_TYPES): + frame_mask = True + elif chain_type in list(mmcif_names.NON_POLYMER_CHAIN_TYPES): + # For ligands, build frames from closest atoms from the same molecule. + (local_token_idxs,) = np.where( + (all_tokens.chain_type == chain_type) + & (all_tokens.chain_id == chain_id) + & (all_tokens.res_id == res_id) + ) + + if len(local_token_idxs) < 3: + frame_mask = False + + else: + # [local_tokens] + local_dist = np.linalg.norm( + ref_coordinates[idx] - ref_coordinates[local_token_idxs], axis=-1 + ) + local_mask = ref_mask[local_token_idxs] + cost = local_dist + 1e8 * ~local_mask + cost = cost + 1e8 * (idx == local_token_idxs) + # [local_tokens] + closest_idxs = np.argsort(cost, axis=0) + + # The closest indices index an array of local tokens. Convert this + # to indices of the full (num_tokens,) array. + global_closest_idxs = local_token_idxs[closest_idxs] + + # Construct frame by placing the current token at the origin and two + # nearest atoms on either side. + global_frame_idxs = np.array( + (global_closest_idxs[0], idx, global_closest_idxs[1]) + ) + + # Check that the frame atoms are not colinear. + a, b, c = ref_coordinates[global_frame_idxs] + vec1 = a - b + vec2 = c - b + # Reference coordinates can be all zeros, in which case we have + # to explicitly set colinearity. + if np.isclose(np.linalg.norm(vec1, axis=-1), 0) or np.isclose( + np.linalg.norm(vec2, axis=-1), 0 + ): + is_colinear = True + logging.info( + 'Found identical coordinates: Assigning as colinear.') + else: + vec1 = vec1 / np.linalg.norm(vec1, axis=-1) + vec2 = vec2 / np.linalg.norm(vec2, axis=-1) + cos_angle = np.einsum('...k,...k->...', vec1, vec2) + # <25 degree deviation is considered colinear. + is_colinear = 1 - np.abs(cos_angle) < 0.0937 + + frame_mask = not is_colinear + else: + # No frame for other chain types. 
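+                # Summary: peptide and nucleic tokens always get a frame,
+                # ligand tokens need at least three tokens in the same residue
+                # plus a non-colinear closest-atom triplet, and any other
+                # chain type is masked out here.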
+ frame_mask = False + + all_frame_mask.append(frame_mask) + + all_frame_mask = np.array(all_frame_mask, dtype=bool) + + mask = _pad_to(all_frame_mask, (padding_shapes.num_tokens,)) + + return cls(mask=mask) + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + return cls(mask=batch['frames_mask']) + + def as_data_dict(self) -> BatchDict: + return {'frames_mask': self.mask} diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/load_batch.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/load_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..41cf2bf4827f319c80b57f684d2627b97e8c1796 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/load_batch.py @@ -0,0 +1,22 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +"""load data 'batch' used in test""" +import pickle +import mindspore as ms +from alphafold3.model.feat_batch import Batch + + +def load_batch(dtype=ms.float32): + """Load batch data for test""" + with open('/data/zmmVol2/AF3/test/unit_tests/model/diffusion/example_np.pkl', 'rb') as f: + data = pickle.load(f) + batch = Batch.from_data_dict(data) + batch.convert_to_tensor(dtype=dtype) + return batch diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/merging_features.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/merging_features.py new file mode 100644 index 0000000000000000000000000000000000000000..3c1fab8996d7484cbaf1a4414a60f891a54b0fd4 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/merging_features.py @@ -0,0 +1,92 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Methods for merging existing features to create a new example. + +Covers: +- Merging features across chains. +- Merging the paired and unpaired parts of the MSA. +""" + +from typing import TypeAlias + +from alphafold3.model import data_constants +import numpy as np + +NUM_SEQ_NUM_RES_MSA_FEATURES = data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES +NUM_SEQ_MSA_FEATURES = data_constants.NUM_SEQ_MSA_FEATURES +MSA_PAD_VALUES = data_constants.MSA_PAD_VALUES + + +xnp_ndarray: TypeAlias = np.ndarray # pylint: disable=invalid-name +BatchDict: TypeAlias = dict[str, xnp_ndarray] + + +def _pad_features_to_max(feat_name: str, chains: list[BatchDict], axis: int): + """Pad a set of features to the maximum size amongst all chains. + + Args: + feat_name: The feature name to pad. + chains: A list of chains with associated features. + axis: Which axis to pad to the max. 
+ + Returns: + A list of features, all with the same size on the given axis. + """ + max_num_seq = np.max([chain[feat_name].shape[axis] for chain in chains]) + + padded_feats = [] + for chain in chains: + feat = chain[feat_name] + + padding = np.zeros_like(feat.shape) # pytype: disable=attribute-error + # pytype: disable=attribute-error + padding[axis] = max_num_seq - feat.shape[axis] + padding = [(0, p) for p in padding] + padded_feats.append( + np.pad( + feat, + padding, + mode='constant', + constant_values=MSA_PAD_VALUES[feat_name], + ) + ) + return padded_feats + + +def merge_msa_features(feat_name: str, chains: list[BatchDict]) -> np.ndarray: + """Merges MSA features with shape (NUM_SEQ, NUM_RES) across chains.""" + expected_dtype = chains[0][feat_name].dtype + if '_all_seq' in feat_name: + return np.concatenate( + [c.get(feat_name, np.array([], expected_dtype)) for c in chains], axis=1 + ) + else: + # Since each MSA can be of different lengths, we first need to pad them + # all to the size of the largest MSA before concatenating. + padded_feats = _pad_features_to_max(feat_name, chains, axis=0) + return np.concatenate(padded_feats, axis=1) + + +def merge_paired_and_unpaired_msa(example: BatchDict) -> BatchDict: + """Concatenates the paired (all_seq) MSA features with the unpaired ones.""" + new_example = dict(example) + + for feature_name in NUM_SEQ_NUM_RES_MSA_FEATURES + NUM_SEQ_MSA_FEATURES: + if feature_name in example and feature_name + '_all_seq' in example: + feat = example[feature_name] + feat_all_seq = example[feature_name + '_all_seq'] + merged_feat = np.concatenate([feat_all_seq, feat], axis=0) + new_example[feature_name] = merged_feat + + new_example['num_alignments'] = np.array( + new_example['msa'].shape[0], dtype=np.int32 + ) + return new_example diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.cc new file mode 100644 index 0000000000000000000000000000000000000000..663e7f3036e3817b4e339662184e276fb3755f59 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.cc @@ -0,0 +1,63 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include "alphafold3/model/mkdssp_pybind.h" + +#include + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" + +namespace alphafold3 { +namespace py = pybind11; + +void RegisterModuleMkdssp(pybind11::module m) { + py::module site = py::module::import("site"); + py::list paths = py::cast(site.attr("getsitepackages")()); + // Find the first path that contains the libcifpp components.cif file. 
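+  // libcifpp looks up its data directory (the one holding components.cif) via
+  // the LIBCIFPP_DATA_DIR environment variable; setenv() below passes
+  // overwrite=0, so a value already configured by the user is left untouched.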
+ bool found = false; + for (const auto& py_path : paths) { + auto path_str = + std::filesystem::path(py::cast(py_path)) / + "share/libcifpp/components.cif"; + if (std::filesystem::exists(path_str)) { + setenv("LIBCIFPP_DATA_DIR", path_str.parent_path().c_str(), 0); + found = true; + break; + } + } + if (!found) { + throw py::type_error("Could not find the libcifpp components.cif file."); + } + m.def( + "get_dssp", + [](absl::string_view mmcif, int model_no, + int min_poly_proline_stretch_length, + bool calculate_surface_accessibility) { + cif::file cif_file(mmcif.data(), mmcif.size()); + dssp result(cif_file.front(), model_no, min_poly_proline_stretch_length, + calculate_surface_accessibility); + std::stringstream sstream; + result.write_legacy_output(sstream); + return sstream.str(); + }, + py::arg("mmcif"), py::arg("model_no") = 1, + py::arg("min_poly_proline_stretch_length") = 3, + py::arg("calculate_surface_accessibility") = false, + py::doc("Gets secondary structure from an mmCIF file.")); +} + +} // namespace alphafold3 \ No newline at end of file diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..a1e4832b8d65ac28568424dbd2bf5001b896f646 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.h @@ -0,0 +1,26 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_MODEL_MKDSSP_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_MODEL_MKDSSP_PYBIND_H_ + + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleMkdssp(pybind11::module m); + +} + + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_MODEL_MKDSSP_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mmcif_metadata.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mmcif_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..31784589ac1a04de76848c03ad65e068b03734b8 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mmcif_metadata.py @@ -0,0 +1,202 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""Adds mmCIF metadata (to be ModelCIF-conformant) and author and legal info.""" + +from typing import Final + +from alphafold3.structure import mmcif +import numpy as np + +_LICENSE_URL: Final[str] = ( + 'https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md' +) + +_LICENSE: Final[str] = f"""\ +Non-commercial use only, by using this file you agree to the terms of use found +at {_LICENSE_URL}. +To request access to the AlphaFold 3 model parameters, follow the process set +out at https://github.com/google-deepmind/alphafold3. You may only use these if +received directly from Google. Use is subject to terms of use available at +https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md. +""" + +_DISCLAIMER: Final[str] = """\ +AlphaFold 3 and its output are not intended for, have not been validated for, +and are not approved for clinical use. They are provided "as-is" without any +warranty of any kind, whether expressed or implied. No warranty is given that +use shall not infringe the rights of any third party. +""" + +_MMCIF_PAPER_AUTHORS: Final[tuple[str, ...]] = ( + 'Google DeepMind', + 'Isomorphic Labs', +) + +# Authors of the mmCIF - we set them to be equal to the authors of the paper. +_MMCIF_AUTHORS: Final[tuple[str, ...]] = _MMCIF_PAPER_AUTHORS + + +def add_metadata_to_mmcif( + old_cif: mmcif.Mmcif, model_id: bytes +) -> mmcif.Mmcif: + """Adds metadata to a mmCIF to make it ModelCIF-conformant.""" + cif = {} + + # ModelCIF conformation dictionary. + cif['_audit_conform.dict_name'] = ['mmcif_ma.dic'] +# cif['_audit_conform.dict_version'] = ['1.4.5'] + cif['_audit_conform.dict_location'] = [ + 'https://raw.githubusercontent.com/ihmwg/ModelCIF/master/dist/mmcif_ma.dic' + ] + + cif['_pdbx_data_usage.id'] = ['1', '2'] + cif['_pdbx_data_usage.type'] = ['license', 'disclaimer'] + cif['_pdbx_data_usage.details'] = [_LICENSE, _DISCLAIMER] + cif['_pdbx_data_usage.url'] = [_LICENSE_URL, '?'] + + # Structure author details. + cif['_audit_author.name'] = [] + cif['_audit_author.pdbx_ordinal'] = [] + for author_index, author_name in enumerate(_MMCIF_AUTHORS, start=1): + cif['_audit_author.name'].append(author_name) + cif['_audit_author.pdbx_ordinal'].append(str(author_index)) + + # Paper author details. + cif['_citation_author.citation_id'] = [] + cif['_citation_author.name'] = [] + cif['_citation_author.ordinal'] = [] + for author_index, author_name in enumerate(_MMCIF_PAPER_AUTHORS, start=1): + cif['_citation_author.citation_id'].append('primary') + cif['_citation_author.name'].append(author_name) + cif['_citation_author.ordinal'].append(str(author_index)) + + # Paper citation details. 
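+  # Primary citation fields for the AlphaFold 3 paper (Nature 630, 493-500,
+  # 2024), consistent with the DOI and PubMed IDs set below.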
+ cif['_citation.id'] = ['primary'] + cif['_citation.title'] = [ + 'Accurate structure prediction of biomolecular interactions with' + ' AlphaFold 3' + ] + cif['_citation.journal_full'] = ['Nature'] + cif['_citation.journal_volume'] = ['630'] + cif['_citation.page_first'] = ['493'] + cif['_citation.page_last'] = ['500'] + cif['_citation.year'] = ['2024'] + cif['_citation.journal_id_ASTM'] = ['NATUAS'] + cif['_citation.country'] = ['UK'] + cif['_citation.journal_id_ISSN'] = ['0028-0836'] + cif['_citation.journal_id_CSD'] = ['0006'] + cif['_citation.book_publisher'] = ['?'] + cif['_citation.pdbx_database_id_PubMed'] = ['38718835'] + cif['_citation.pdbx_database_id_DOI'] = ['10.1038/s41586-024-07487-w'] + + # Type of data in the dataset including data used in the model generation. + cif['_ma_data.id'] = ['1'] + cif['_ma_data.name'] = ['Model'] + cif['_ma_data.content_type'] = ['model coordinates'] + + # Description of number of instances for each entity. + cif['_ma_target_entity_instance.asym_id'] = old_cif['_struct_asym.id'] + cif['_ma_target_entity_instance.entity_id'] = old_cif[ + '_struct_asym.entity_id' + ] + cif['_ma_target_entity_instance.details'] = ['.'] * len( + cif['_ma_target_entity_instance.entity_id'] + ) + + # Details about the target entities. + cif['_ma_target_entity.entity_id'] = cif[ + '_ma_target_entity_instance.entity_id' + ] + cif['_ma_target_entity.data_id'] = ['1'] * len( + cif['_ma_target_entity.entity_id'] + ) + cif['_ma_target_entity.origin'] = ['.'] * len( + cif['_ma_target_entity.entity_id'] + ) + + # Details of the models being deposited. + cif['_ma_model_list.ordinal_id'] = ['1'] + cif['_ma_model_list.model_id'] = ['1'] + cif['_ma_model_list.model_group_id'] = ['1'] + cif['_ma_model_list.model_name'] = ['Top ranked model'] + + cif['_ma_model_list.model_group_name'] = [ + f'AlphaFold-beta-20231127' + ] + cif['_ma_model_list.data_id'] = ['1'] + cif['_ma_model_list.model_type'] = ['Ab initio model'] + + # Software used. + cif['_software.pdbx_ordinal'] = ['1'] + cif['_software.name'] = ['AlphaFold'] +# cif['_software.version'] = [ +# f'AlphaFold-beta-20231127 ({model_id.decode("ascii")})' +# ] + cif['_software.type'] = ['package'] + cif['_software.description'] = ['Structure prediction'] + cif['_software.classification'] = ['other'] + cif['_software.date'] = ['?'] + + # Collection of software into groups. + cif['_ma_software_group.ordinal_id'] = ['1'] + cif['_ma_software_group.group_id'] = ['1'] + cif['_ma_software_group.software_id'] = ['1'] + + # Method description to conform with ModelCIF. + cif['_ma_protocol_step.ordinal_id'] = ['1', '2', '3'] + cif['_ma_protocol_step.protocol_id'] = ['1', '1', '1'] + cif['_ma_protocol_step.step_id'] = ['1', '2', '3'] + cif['_ma_protocol_step.method_type'] = [ + 'coevolution MSA', + 'template search', + 'modeling', + ] + + # Details of the metrics use to assess model confidence. + cif['_ma_qa_metric.id'] = ['1', '2'] + cif['_ma_qa_metric.name'] = ['pLDDT', 'pLDDT'] + # Accepted values are distance, energy, normalised score, other, zscore. + cif['_ma_qa_metric.type'] = ['pLDDT', 'pLDDT'] + cif['_ma_qa_metric.mode'] = ['global', 'local'] + cif['_ma_qa_metric.software_group_id'] = ['1', '1'] + + # Global model confidence metric value. 
+ cif['_ma_qa_metric_global.ordinal_id'] = ['1'] + cif['_ma_qa_metric_global.model_id'] = ['1'] + cif['_ma_qa_metric_global.metric_id'] = ['1'] + global_plddt = np.mean( + [float(v) for v in old_cif['_atom_site.B_iso_or_equiv']] + ) + cif['_ma_qa_metric_global.metric_value'] = [f'{global_plddt:.2f}'] + + cif['_atom_type.symbol'] = sorted(set(old_cif['_atom_site.type_symbol'])) + + return old_cif.copy_and_update(cif) + + +def add_legal_comment(cif: str) -> str: + """Adds legal comment at the top of the mmCIF.""" + # fmt: off + # pylint: disable=line-too-long + comment = ( + '# By using this file you agree to the legally binding terms of use found at\n' + f'# {_LICENSE_URL}.\n' + '# To request access to the AlphaFold 3 model parameters, follow the process set\n' + '# out at https://github.com/google-deepmind/alphafold3. You may only use these if\n' + '# received directly from Google. Use is subject to terms of use available at\n' + '# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.' + ) + # pylint: enable=line-too-long + # fmt: on + return f'{comment}\n{cif}' diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/model_config.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..83cf9ce756a815a4c5683e51feaf15be1fd2b4e9 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/model_config.py @@ -0,0 +1,32 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Config for the protein folding model and experiment.""" + +from collections.abc import Sequence +from typing import Literal, TypeAlias + +from alphafold3.model import base_config +from alphafold3.utils.attention import attention + + +_Shape2DType: TypeAlias = tuple[int | None, int | None] + + +class GlobalConfig(base_config.BaseConfig): + bfloat16: Literal['all', 'none', 'intermediate'] = 'none' + final_init: Literal['zeros', 'linear'] = 'zeros' + pair_attention_chunk_size: Sequence[_Shape2DType] = ( + (1536, 128), (None, 32)) + pair_transition_shard_spec: Sequence[_Shape2DType] = ( + (2048, None), + (None, 1024), + ) + flash_attention_implementation: attention.Implementation = 'ms' diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/msa_pairing.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/msa_pairing.py new file mode 100644 index 0000000000000000000000000000000000000000..6d563eabd274b6936bb4501b4d455291b18fc13e --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/msa_pairing.py @@ -0,0 +1,316 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Functions for producing "paired" and "unpaired" MSA features for each chain. + +The paired MSA: +- Is made from the result of the all_seqs MSA query. +- Is ordered such that you can concatenate features across chains and related + sequences will end up on the same row. Related here means "from the same + species". Gaps are added to facilitate this whenever a sequence has no + suitable pair. + +The unpaired MSA: +- Is made from the results of the remaining MSA queries. +- Has no special ordering properties. +- Is deduplicated such that it doesn't contain any sequences in the paired MSA. +""" + +from typing import Mapping, MutableMapping, Sequence +from alphafold3.model import data_constants +import numpy as np + + +def _align_species( + all_species: Sequence[bytes], + chains_species_to_rows: Sequence[Mapping[bytes, np.ndarray]], + min_hits_per_species: Mapping[bytes, int], +) -> np.ndarray: + """Aligns MSA row indices based on species. + + Within a species, MSAs are aligned based on their original order (the first + sequence for a species in the first chain's MSA is aligned to the first + sequence for the same species in the second chain's MSA). + + Args: + all_species: A list of all unique species identifiers. + chains_species_to_rows: A dictionary for each chain, that maps species to + the set of MSA row indices from that species in that chain. + min_hits_per_species: A mapping from species id, to the minimum MSA size + across chains for that species (ignoring chains with zero hits). + + Returns: + A matrix of size [num_msa_rows, num_chains], where the i,j element is an + index into the jth chains MSA. Each row consists of sequences from each + chain for the same species (or -1 if that chain has no sequences for that + species). + """ + # Each species block is of size [num_seqs x num_chains] and consists of + # indices into the respective MSAs that have been aligned and are all for the + # same species. + species_blocks = [] + for species in all_species: + chain_row_indices = [] + for species_to_rows in chains_species_to_rows: + min_msa_size = min_hits_per_species[species] + if species not in species_to_rows: + # If a given chain has no hits for a species then we pad it with -1's, + # later on these values are used to make sure each feature is padded + # with its appropriate pad value. + row_indices = np.full( + min_msa_size, fill_value=-1, dtype=np.int32) + else: + # We crop down to the smallest MSA for a given species across chains. + row_indices = species_to_rows[species][:min_msa_size] + chain_row_indices.append(row_indices) + species_block = np.stack(chain_row_indices, axis=1) + species_blocks.append(species_block) + aligned_matrix = np.concatenate(species_blocks, axis=0) + return aligned_matrix + + +def create_paired_features( + chains: Sequence[MutableMapping[str, np.ndarray]], + max_paired_sequences: int, + nonempty_chain_ids: set[str], + max_hits_per_species: int, +) -> Sequence[MutableMapping[str, np.ndarray]]: + """Creates per-chain MSA features where the MSAs have been aligned. + + Args: + chains: A list of feature dicts, one for each chain. + max_paired_sequences: No more than this many paired sequences will be + returned from this function. + nonempty_chain_ids: A set of chain ids (str) that are included in the crop + there is no reason to process chains not in this list. 
+ max_hits_per_species: No more than this number of sequences will be returned + for a given species. + + Returns: + An updated feature dictionary for each chain, where the {}_all_seq features + have been aligned so that the nth row in chain 1 is aligned to the nth row + in chain 2's features. + """ + # The number of chains that the given species appears in - we rank hits + # across more chains higher. + species_num_chains = {} + + # For each chain we keep a mapping from species to the row indices in the + # original MSA for that chain. + chains_species_to_rows = [] + + # Keep track of the minimum number of hits across chains for a given species. + min_hits_per_species = {} + + for chain in chains: + species_ids = chain['msa_species_identifiers_all_seq'] + + # The query gets an empty species_id, so no pairing happens for this row. + if ( + species_ids.size == 0 + or (species_ids.size == 1 and not species_ids[0]) + or chain['chain_id'] not in nonempty_chain_ids + ): + chains_species_to_rows.append({}) + continue + + # For each species keep track of which row indices in the original MSA are + # from this species. + row_indices = np.arange(len(species_ids)) + # The grouping np.split code requires that the input is already clustered + # by species id. + sort_idxs = species_ids.argsort() + species_ids = species_ids[sort_idxs] + row_indices = row_indices[sort_idxs] + + species, unique_row_indices = np.unique(species_ids, return_index=True) + grouped_row_indices = np.split(row_indices, unique_row_indices[1:]) + species_to_rows = dict(zip(species, grouped_row_indices, strict=True)) + chains_species_to_rows.append(species_to_rows) + + for s in species: + species_num_chains[s] = species_num_chains.get(s, 0) + 1 + + for species, row_indices in species_to_rows.items(): + min_hits_per_species[species] = min( + min_hits_per_species.get(species, max_hits_per_species), + len(row_indices), + ) + + # Construct a mapping from the number of chains a species appears in to + # the list of species with that count. + num_chains_to_species = {} + for species, num_chains in species_num_chains.items(): + if not species or num_chains <= 1: + continue + if num_chains not in num_chains_to_species: + num_chains_to_species[num_chains] = [] + num_chains_to_species[num_chains].append(species) + + num_rows_seen = 0 + # We always keep the first row as it is the query sequence. + all_rows = [np.array([[0] * len(chains)], dtype=np.int32)] + + # We prioritize species that have hits across more chains. + for num_chains in sorted(num_chains_to_species, reverse=True): + all_species = num_chains_to_species[num_chains] + + # Align all the per-chain row indices by species, so every paired row is + # for a single species. + rows = _align_species( + all_species, chains_species_to_rows, min_hits_per_species + ) + # Sort rows by the product of the original indices in the respective chain + # MSAS, so as to rank hits that appear higher in the original MSAs higher. 
+ rank_metric = np.abs(np.prod(rows.astype(np.float32), axis=1)) + sorted_rows = rows[np.argsort(rank_metric), :] + all_rows.append(sorted_rows) + num_rows_seen += rows.shape[0] + if num_rows_seen >= max_paired_sequences: + break + + all_rows = np.concatenate(all_rows, axis=0) + all_rows = all_rows[:max_paired_sequences, :] + + # Now we just have to select the relevant rows from the original msa and + # deletion matrix features + paired_chains = [] + for chain_idx, chain in enumerate(chains): + out_chain = {k: v for k, v in chain.items() if 'all_seq' not in k} + selected_row_indices = all_rows[:, chain_idx] + for feat_name in {'msa', 'deletion_matrix'}: + all_seq_name = f'{feat_name}_all_seq' + feat_value = chain[all_seq_name] + + # The selected row indices are padded to be the same shape for each chain, + # they are padded with -1's, so we add a single row onto the feature with + # the appropriate pad value. This has the effect that we correctly pad + # each feature since all padded indices will select this padding row. + pad_value = data_constants.MSA_PAD_VALUES[feat_name] + feat_value = np.concatenate([ + feat_value, + np.full((1, feat_value.shape[1]), pad_value, feat_value.dtype), + ]) + + feat_value = feat_value[selected_row_indices, :] + out_chain[all_seq_name] = feat_value + out_chain['num_alignments_all_seq'] = np.array( + out_chain['msa_all_seq'].shape[0] + ) + paired_chains.append(out_chain) + return paired_chains + + +def deduplicate_unpaired_sequences( + np_chains: Sequence[MutableMapping[str, np.ndarray]], +) -> Sequence[MutableMapping[str, np.ndarray]]: + """Deduplicates unpaired sequences based on paired sequences.""" + + feature_names = np_chains[0].keys() + msa_features = ( + data_constants.NUM_SEQ_MSA_FEATURES + + data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES + ) + + for chain in np_chains: + sequence_set = set( + hash(s.data.tobytes()) for s in chain['msa_all_seq'].astype(np.int8) + ) + keep_rows = [] + # Go through unpaired MSA seqs and remove any rows that correspond to the + # sequences that are already present in the paired MSA. + for row_num, seq in enumerate(chain['msa'].astype(np.int8)): + if hash(seq.data.tobytes()) not in sequence_set: + keep_rows.append(row_num) + for feature_name in feature_names: + if feature_name in msa_features: + chain[feature_name] = chain[feature_name][keep_rows] + chain['num_alignments'] = np.array( + chain['msa'].shape[0], dtype=np.int32) + return np_chains + + +def choose_paired_unpaired_msa_crop_sizes( + unpaired_msa: np.ndarray, + paired_msa: np.ndarray | None, + total_msa_crop_size: int, + max_paired_sequences: int, +) -> tuple[int, int | None]: + """Returns the sizes of the MSA crop and MSA_all_seq crop. + + NOTE: Unpaired + paired MSA sizes can exceed total_msa_size when + there are lots of gapped rows. Through the pairing logic another chain(s) + will have fewer than total_msa_size. + + Args: + unpaired_msa: The unpaired MSA array (not all_seq). + paired_msa: The paired MSA array (all_seq). + total_msa_crop_size: The maximum total number of sequences to crop to. + max_paired_sequences: The maximum number of sequences that can come from + MSA pairing. + + Returns: + A tuple of: + The size of the reduced MSA crop (not all_seq features). + The size of the unreduced MSA crop (for all_seq features) or None, if + paired_msa is None. 
+ """ + if paired_msa is not None: + paired_crop_size = np.minimum( + paired_msa.shape[0], max_paired_sequences) + + # We reduce the number of un-paired sequences, by the number of times a + # sequence from this chains MSA is included in the paired MSA. This keeps + # the MSA size for each chain roughly constant. + cropped_all_seq_msa = paired_msa[:max_paired_sequences] + num_non_gapped_pairs = cropped_all_seq_msa.shape[0] + + assert num_non_gapped_pairs <= max_paired_sequences + unpaired_crop_size = np.minimum( + unpaired_msa.shape[0], total_msa_crop_size - num_non_gapped_pairs + ) + assert unpaired_crop_size >= 0 + else: + unpaired_crop_size = np.minimum( + unpaired_msa.shape[0], total_msa_crop_size) + paired_crop_size = None + return unpaired_crop_size, paired_crop_size + + +def remove_all_gapped_rows_from_all_seqs( + chains_list: Sequence[dict[str, np.ndarray]], asym_ids: Sequence[float] +) -> Sequence[dict[str, np.ndarray]]: + """Removes all gapped rows from all_seq feat based on selected asym_ids.""" + + merged_msa_all_seq = np.concatenate( + [ + chain['msa_all_seq'] + for chain in chains_list + if chain['asym_id'][0] in asym_ids + ], + axis=1, + ) + + non_gapped_keep_rows = np.any( + merged_msa_all_seq != data_constants.MSA_GAP_IDX, axis=1 + ) + for chain in chains_list: + for feat_name in list(chains_list)[0]: + if '_all_seq' in feat_name: + feat_name_split = feat_name.split('_all_seq')[0] + if feat_name_split in ( + data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES + + data_constants.NUM_SEQ_MSA_FEATURES + ): + # For consistency we do this for all chains even though the + # gapped rows are based on a selected set asym_ids. + chain[feat_name] = chain[feat_name][non_gapped_keep_rows] + chain['num_alignments_all_seq'] = np.sum(non_gapped_keep_rows) + return chains_list diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/params.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/params.py new file mode 100644 index 0000000000000000000000000000000000000000..3c1d22df67ba6762caf2d0d9ab6adf52c0981972 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/params.py @@ -0,0 +1,218 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Model param loading.""" + +import bisect +import collections +from collections.abc import Iterator +import contextlib +import io +import os +import pathlib +import re +import struct +import sys +from typing import IO +import numpy as np + + +class RecordError(Exception): + """Error reading a record.""" + + +def encode_record(scope: str, name: str, arr: np.ndarray) -> bytes: + """Encodes a single haiku param as bytes, preserving non-numpy dtypes.""" + scope = scope.encode('utf-8') + name = name.encode('utf-8') + shape = arr.shape + dtype = str(arr.dtype).encode('utf-8') + arr = np.ascontiguousarray(arr) + if sys.byteorder == 'big': + arr = arr.byteswap() + arr_buffer = arr.tobytes('C') + header = struct.pack( + '<5i', len(scope), len(name), len(dtype), len(shape), len(arr_buffer) + ) + return header + b''.join( + (scope, name, dtype, struct.pack(f'{len(shape)}i', *shape), arr_buffer) + ) + + +def _read_record(stream: IO[bytes]) -> tuple[str, str, np.ndarray] | None: + """Reads a record encoded by `_encode_record` from a byte stream.""" + header_size = struct.calcsize('<5i') + header = stream.read(header_size) + if not header: + return None + if len(header) < header_size: + raise RecordError( + f'Incomplete header: {len(header)=} < {header_size=}') + (scope_len, name_len, dtype_len, shape_len, arr_buffer_len) = struct.unpack( + '<5i', header + ) + fmt = f'<{scope_len}s{name_len}s{dtype_len}s{shape_len}i' + payload_size = struct.calcsize(fmt) + arr_buffer_len + payload = stream.read(payload_size) + if len(payload) < payload_size: + raise RecordError( + f'Incomplete payload: {len(payload)=} < {payload_size=}') + scope, name, dtype, *shape = struct.unpack_from(fmt, payload) + scope = scope.decode('utf-8') + name = name.decode('utf-8') + dtype = dtype.decode('utf-8') + if dtype == 'bfloat16': + buffer = payload[-arr_buffer_len:] + if sys.byteorder == 'big': + buffer = buffer[::-1] + arr_uint16 = np.frombuffer(buffer, dtype=np.uint16) + arr_bf16 = arr_uint16.view('bfloat16') + arr = arr_bf16.astype(np.float32) + else: + arr = np.frombuffer(payload[-arr_buffer_len:], dtype=dtype) + if sys.byteorder == 'big': + arr = arr.byteswap() + arr = np.reshape(arr, shape) + if sys.byteorder == 'big': + arr = arr.byteswap() + return scope, name, arr + + +def read_records(stream: IO[bytes]) -> Iterator[tuple[str, str, np.ndarray]]: + """Fully reads the contents of a byte stream.""" + while record := _read_record(stream): + yield record + + +class _MultiFileIO(io.RawIOBase): + """A file-like object that presents a concatenated view of multiple files.""" + + def __init__(self, files: list[pathlib.Path]): + self._files = files + self._stack = contextlib.ExitStack() + self._handles = [ + self._stack.enter_context(file.open('rb')) for file in files + ] + self._sizes = [] + for handle in self._handles: + handle.seek(0, os.SEEK_END) + self._sizes.append(handle.tell()) + self._length = sum(self._sizes) + self._offsets = [0] + for s in self._sizes[:-1]: + self._offsets.append(self._offsets[-1] + s) + self._abspos = 0 + self._relpos = (0, 0) + + def _abs_to_rel(self, pos: int) -> tuple[int, int]: + idx = bisect.bisect_right(self._offsets, pos) - 1 + return idx, pos - self._offsets[idx] + + def close(self): + self._stack.close() + + def closed(self) -> bool: + return all(handle.closed for handle in self._handles) + + def fileno(self) -> int: + return -1 + + def readable(self) -> bool: + 
return True + + def tell(self) -> int: + return self._abspos + + def seek(self, pos: int, whence: int = os.SEEK_SET, /): + match whence: + case os.SEEK_SET: + pass + case os.SEEK_CUR: + pos += self._abspos + case os.SEEK_END: + pos = self._length - pos + case _: + raise ValueError(f'Invalid whence: {whence}') + self._abspos = pos + self._relpos = self._abs_to_rel(pos) + + def readinto(self, b: bytearray | memoryview) -> int: + result = 0 + mem = memoryview(b) + while mem: + self._handles[self._relpos[0]].seek(self._relpos[1]) + count = self._handles[self._relpos[0]].readinto(mem) + result += count + self._abspos += count + self._relpos = self._abs_to_rel(self._abspos) + mem = mem[count:] + if self._abspos == self._length: + break + return result + + +@contextlib.contextmanager +def open_for_reading(model_files: list[pathlib.Path], is_compressed: bool): + with contextlib.closing(_MultiFileIO(model_files)) as f: + yield f + + +def _match_model( + paths: list[pathlib.Path], pattern: re.Pattern[str] +) -> dict[str, list[pathlib.Path]]: + """Match files in a directory with a pattern, and group by model name.""" + models = collections.defaultdict(list) + for path in paths: + match = pattern.fullmatch(path.name) + if match: + models[match.group('model_name')].append(path) + return {k: sorted(v) for k, v in models.items()} + + +def select_model_files( + model_dir: pathlib.Path, model_name: str | None = None +) -> tuple[list[pathlib.Path], bool]: + """Select the model files from a model directory.""" + files = [file for file in model_dir.iterdir() if file.is_file()] + + for pattern, is_compressed in ( + (r'(?P.*)\.[0-9]+\.bin\.zst$', True), + (r'(?P.*)\.bin\.zst\.[0-9]+$', True), + (r'(?P.*)\.[0-9]+\.bin$', False), + (r'(?P.*)\.bin]\.[0-9]+$', False), + (r'(?P.*)\.bin\.zst$', True), + (r'(?P.*)\.bin$', False), + ): + models = _match_model(files, re.compile(pattern)) + if model_name is not None: + if model_name in models: + return models[model_name], is_compressed + else: + if models: + if len(models) > 1: + raise RuntimeError( + f'Multiple models matched in {model_dir}') + _, model_files = models.popitem() + return model_files, is_compressed + raise FileNotFoundError(f'No models matched in {model_dir}') + + +def get_model_af3_params(model_dir: pathlib.Path): + """Get the Haiku parameters from a model name.""" + params: dict[str, dict[str, np.array]] = {} + model_files, is_compressed = select_model_files(model_dir) + with open_for_reading(model_files, is_compressed) as stream: + for scope, name, arr in read_records(stream): + params.setdefault(scope, {})[name] = np.array(arr) + if not params: + raise FileNotFoundError(f'Model missing from "{model_dir}"') + return params diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/inter_chain_bonds.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/inter_chain_bonds.py new file mode 100644 index 0000000000000000000000000000000000000000..536c8b0caface73737af9f9dc83c236061612745 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/inter_chain_bonds.py @@ -0,0 +1,348 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. 
You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Functions for handling inter-chain bonds.""" + +from collections.abc import Collection +import functools +from typing import Final, NamedTuple +import numpy as np +from alphafold3 import structure +from alphafold3.constants import chemical_component_sets +from alphafold3.constants import mmcif_names +from alphafold3.model.atom_layout import atom_layout + + + +BOND_THRESHOLD_GLYCANS_ANGSTROM: Final[float] = 1.7 +# See https://pubs.acs.org/doi/10.1021/ja010331r for P-P atom bond distances. +BOND_THRESHOLD_ALL_ANGSTROM: Final[float] = 2.4 + + +class BondAtomArrays(NamedTuple): + chain_id: np.ndarray + chain_type: np.ndarray + res_id: np.ndarray + res_name: np.ndarray + atom_name: np.ndarray + coords: np.ndarray + + +def _get_bond_atom_arrays( + struct: structure.Structure, bond_atom_indices: np.ndarray +) -> BondAtomArrays: + return BondAtomArrays( + chain_id=struct.chain_id[bond_atom_indices], + chain_type=struct.chain_type[bond_atom_indices], + res_id=struct.res_id[bond_atom_indices], + res_name=struct.res_name[bond_atom_indices], + atom_name=struct.atom_name[bond_atom_indices], + coords=struct.coords[..., bond_atom_indices, :], + ) + + +@functools.lru_cache(maxsize=1) +def get_polymer_ligand_and_ligand_ligand_bonds( + struct: structure.Structure, + only_glycan_ligands: bool, + allow_multiple_bonds_per_atom: bool, +) -> tuple[atom_layout.AtomLayout, atom_layout.AtomLayout]: + """Return polymer-ligand & ligand-ligand inter-residue bonds. + + Args: + struct: Structure object to extract bonds from. + only_glycan_ligands: Whether to only include glycans in ligand category. + allow_multiple_bonds_per_atom: If not allowed, we greedily choose the first + bond seen per atom and discard the remaining on each atom.. + + Returns: + polymer_ligand, ligand_ligand_bonds: Each object is an AtomLayout object + [num_bonds, 2] for the bond-defining atoms. 
+ """ + if only_glycan_ligands: + allowed_res_names = list({ + *chemical_component_sets.GLYCAN_OTHER_LIGANDS, + *chemical_component_sets.GLYCAN_LINKING_LIGANDS, + }) + else: + allowed_res_names = None + all_bonds = get_bond_layout( + bond_threshold=BOND_THRESHOLD_GLYCANS_ANGSTROM + if only_glycan_ligands + else BOND_THRESHOLD_ALL_ANGSTROM, + struct=struct, + allowed_chain_types1=list({ + *mmcif_names.LIGAND_CHAIN_TYPES, + *mmcif_names.POLYMER_CHAIN_TYPES, + }), + allowed_chain_types2=list(mmcif_names.LIGAND_CHAIN_TYPES), + allowed_res_names=allowed_res_names, + allow_multiple_bonds_per_atom=allow_multiple_bonds_per_atom, + ) + ligand_ligand_bonds_mask = np.isin( + all_bonds.chain_type, list(mmcif_names.LIGAND_CHAIN_TYPES) + ) + polymer_ligand_bonds_mask = np.isin( + all_bonds.chain_type, list(mmcif_names.POLYMER_CHAIN_TYPES) + ) + polymer_ligand_bonds_mask = np.logical_and( + ligand_ligand_bonds_mask.any(axis=1), + polymer_ligand_bonds_mask.any(axis=1), + ) + ligand_ligand_bonds = all_bonds[ligand_ligand_bonds_mask.all(axis=1)] + polymer_ligand_bonds = all_bonds[polymer_ligand_bonds_mask] + return polymer_ligand_bonds, ligand_ligand_bonds + + +def _remove_multi_bonds( + bond_layout: atom_layout.AtomLayout, +) -> atom_layout.AtomLayout: + """Remove instances greedily.""" + uids = {} + keep_indx = [] + for chain_id, res_id, atom_name in zip( + bond_layout.chain_id, + bond_layout.res_id, + bond_layout.atom_name, + strict=True, + ): + key1 = (chain_id[0], res_id[0], atom_name[0]) + key2 = (chain_id[1], res_id[1], atom_name[1]) + keep_indx.append(bool(key1 not in uids) and bool(key2 not in uids)) + if key1 not in uids: + uids[key1] = None + if key2 not in uids: + uids[key2] = None + return bond_layout[np.array(keep_indx, dtype=bool)] + + +@functools.lru_cache(maxsize=1) +def get_ligand_ligand_bonds( + struct: structure.Structure, + only_glycan_ligands: bool, + allow_multiple_bonds_per_atom: bool = False, +) -> atom_layout.AtomLayout: + """Return ligand-ligand inter-residue bonds. + + Args: + struct: Structure object to extract bonds from. + only_glycan_ligands: Whether to only include glycans in ligand category. + allow_multiple_bonds_per_atom: If not allowed, we greedily choose the first + bond seen per atom and discard the remaining on each atom. + + Returns: + bond_layout: AtomLayout object [num_bonds, 2] for the bond-defining atoms. + """ + if only_glycan_ligands: + allowed_res_names = list({ + *chemical_component_sets.GLYCAN_OTHER_LIGANDS, + *chemical_component_sets.GLYCAN_LINKING_LIGANDS, + }) + else: + allowed_res_names = None + return get_bond_layout( + bond_threshold=BOND_THRESHOLD_GLYCANS_ANGSTROM + if only_glycan_ligands + else BOND_THRESHOLD_ALL_ANGSTROM, + struct=struct, + allowed_chain_types1=list(mmcif_names.LIGAND_CHAIN_TYPES), + allowed_chain_types2=list(mmcif_names.LIGAND_CHAIN_TYPES), + allowed_res_names=allowed_res_names, + allow_multiple_bonds_per_atom=allow_multiple_bonds_per_atom, + ) + + +@functools.lru_cache(maxsize=1) +def get_polymer_ligand_bonds( + struct: structure.Structure, + only_glycan_ligands: bool, + allow_multiple_bonds_per_atom: bool = False, + bond_threshold: float | None = None, +) -> atom_layout.AtomLayout: + """Return polymer-ligand interchain bonds. + + Args: + struct: Structure object to extract bonds from. + only_glycan_ligands: Whether to only include glycans in ligand category. + allow_multiple_bonds_per_atom: If not allowed, we greedily choose the first + bond seen per atom and discard the remaining on each atom. 
+ bond_threshold: Euclidean distance of max allowed bond. + + Returns: + bond_layout: AtomLayout object [num_bonds, 2] for the bond-defining atoms. + """ + if only_glycan_ligands: + allowed_res_names = list({ + *chemical_component_sets.GLYCAN_OTHER_LIGANDS, + *chemical_component_sets.GLYCAN_LINKING_LIGANDS, + }) + else: + allowed_res_names = None + if bond_threshold is None: + if only_glycan_ligands: + bond_threshold = BOND_THRESHOLD_GLYCANS_ANGSTROM + else: + bond_threshold = BOND_THRESHOLD_ALL_ANGSTROM + return get_bond_layout( + bond_threshold=bond_threshold, + struct=struct, + allowed_chain_types1=list(mmcif_names.POLYMER_CHAIN_TYPES), + allowed_chain_types2=list(mmcif_names.LIGAND_CHAIN_TYPES), + allowed_res_names=allowed_res_names, + allow_multiple_bonds_per_atom=allow_multiple_bonds_per_atom, + ) + + +def get_bond_layout( + bond_threshold: float = BOND_THRESHOLD_ALL_ANGSTROM, + *, + struct: structure.Structure, + allowed_chain_types1: Collection[str], + allowed_chain_types2: Collection[str], + include_bond_types: Collection[str] = ('covale',), + allowed_res_names: Collection[str] | None = None, + allow_multiple_bonds_per_atom: bool, +) -> atom_layout.AtomLayout: + """Get bond_layout for all bonds between two sets of chain types. + + There is a mask (all_mask) that runs through this script, and each bond pair + needs to maintain a True across all conditions in order to be preserved at the + end, otherwise the bond pair has invalidated a condition with a False and is + removed entirely. Note, we remove oxygen atom bonds as they are an edge case + that causes issues with scoring, due to multiple waters bonding with single + residues. + + Args: + bond_threshold: Maximum bond distance in Angstrom. + struct: Structure object to extract bonds from. + allowed_chain_types1: One end of the bonds must be an atom with one of these + chain types. + allowed_chain_types2: The other end of the bond must be an atom with one of + these chain types. + include_bond_types: Only include bonds with specified type e.g. hydrog, + metalc, covale, disulf. + allowed_res_names: Further restricts from chain_types. Either end of the + bonds must be an atom part of these res_names. If none all will be + accepted after chain and bond type filtering. + allow_multiple_bonds_per_atom: If not allowed, we greedily choose the first + bond seen per atom and discard the remaining on each atom. + + Returns: + bond_layout: AtomLayout object [num_bonds, 2] for the bond-defining atoms. 
+ """ + if not struct.bonds: + return atom_layout.AtomLayout( + atom_name=np.empty((0, 2), dtype=object), + res_id=np.empty((0, 2), dtype=int), + res_name=np.empty((0, 2), dtype=object), + chain_id=np.empty((0, 2), dtype=object), + chain_type=np.empty((0, 2), dtype=object), + atom_element=np.empty((0, 2), dtype=object), + ) + from_atom_idxs, dest_atom_idxs = struct.bonds.get_atom_indices( + struct.atom_key + ) + from_atoms = _get_bond_atom_arrays(struct, from_atom_idxs) + dest_atoms = _get_bond_atom_arrays(struct, dest_atom_idxs) + # Chain type + chain_mask = np.logical_or( + np.logical_and( + np.isin( + from_atoms.chain_type, + allowed_chain_types1, + ), + np.isin( + dest_atoms.chain_type, + allowed_chain_types2, + ), + ), + np.logical_and( + np.isin( + from_atoms.chain_type, + allowed_chain_types2, + ), + np.isin( + dest_atoms.chain_type, + allowed_chain_types1, + ), + ), + ) + if allowed_res_names: + # Res type + res_mask = np.logical_or( + np.isin(from_atoms.res_name, allowed_res_names), + np.isin(dest_atoms.res_name, allowed_res_names), + ) + # All mask + all_mask = np.logical_and(chain_mask, res_mask) + else: + all_mask = chain_mask + # Bond type mask + type_mask = np.isin(struct.bonds.type, list(include_bond_types)) + np.logical_and(all_mask, type_mask, out=all_mask) + # Bond length check. Work in square length to avoid taking many square roots. + bond_length_squared = np.square(from_atoms.coords - dest_atoms.coords).sum( + axis=1 + ) + bond_threshold_squared = bond_threshold * bond_threshold + np.logical_and( + all_mask, bond_length_squared < bond_threshold_squared, out=all_mask + ) + # Inter-chain and inter-residue bonds for ligands + ligand_types = list(mmcif_names.LIGAND_CHAIN_TYPES) + is_ligand = np.logical_or( + np.isin( + from_atoms.chain_type, + ligand_types, + ), + np.isin( + dest_atoms.chain_type, + ligand_types, + ), + ) + res_id_differs = from_atoms.res_id != dest_atoms.res_id + chain_id_differs = from_atoms.chain_id != dest_atoms.chain_id + is_inter_res = np.logical_or(res_id_differs, chain_id_differs) + is_inter_ligand_res = np.logical_and(is_inter_res, is_ligand) + is_inter_chain_not_ligand = np.logical_and(chain_id_differs, ~is_ligand) + # If ligand then inter-res & inter-chain bonds, otherwise inter-chain only. 
+ combined_allowed_bonds = np.logical_or( + is_inter_chain_not_ligand, is_inter_ligand_res + ) + np.logical_and(all_mask, combined_allowed_bonds, out=all_mask) + bond_layout = atom_layout.AtomLayout( + atom_name=np.stack( + [ + from_atoms.atom_name[all_mask], + dest_atoms.atom_name[all_mask], + ], + axis=1, + dtype=object, + ), + res_id=np.stack( + [from_atoms.res_id[all_mask], dest_atoms.res_id[all_mask]], + axis=1, + dtype=int, + ), + chain_id=np.stack( + [ + from_atoms.chain_id[all_mask], + dest_atoms.chain_id[all_mask], + ], + axis=1, + dtype=object, + ), + ) + if not allow_multiple_bonds_per_atom: + bond_layout = _remove_multi_bonds(bond_layout) + return atom_layout.fill_in_optional_fields( + bond_layout, + reference_atoms=atom_layout.atom_layout_from_structure(struct), + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/pipeline.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..1c08dccabe73c3cfc6706f3d4375c3fce875018b --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/pipeline.py @@ -0,0 +1,446 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""The main featurizer.""" + +import bisect +from collections.abc import Sequence +import datetime +import itertools + +from absl import logging +from alphafold3.common import base_config +from alphafold3.common import folding_input +from alphafold3.constants import chemical_components +from alphafold3.model import feat_batch +from alphafold3.model import features +from alphafold3.model.pipeline import inter_chain_bonds +from alphafold3.model.pipeline import structure_cleaning +from alphafold3.structure import chemical_components as struc_chem_comps +import numpy as np +from alphafold3.common.folding_input import Template + + +_DETERMINISTIC_FRAMES_RANDOM_SEED = 12312837 + + +def calculate_bucket_size( + num_tokens: int, buckets: Sequence[int] | None +) -> int: + """Calculates the bucket size to pad the data to.""" + if buckets is None: + return num_tokens + + if not buckets: + raise ValueError('Buckets must be non-empty.') + + if not all(prev < curr for prev, curr in itertools.pairwise(buckets)): + raise ValueError( + f'Buckets must be in strictly increasing order. Got {buckets=}.' + ) + + bucket_idx = bisect.bisect_left(buckets, num_tokens) + + if bucket_idx == len(buckets): + logging.warning( + 'Creating a new bucket of size %d since the input has more tokens than' + ' the largest bucket size %d. This may trigger a re-compilation of the' + ' model. 
Consider additional large bucket sizes to avoid excessive' + ' re-compilation.', + num_tokens, + buckets[-1], + ) + return num_tokens + + return buckets[bucket_idx] + + +class NanDataError(Exception): + """Raised if the data pipeline produces data containing nans.""" + + +class TotalNumResOutOfRangeError(Exception): + """Raised if total number of residues for all chains outside allowed range.""" + + +class MmcifNumChainsError(Exception): + """Raised if the mmcif file contains too many / too few chains.""" + + +class WholePdbPipeline: + """Processes an entire mmcif entity and merges the content.""" + + class Config(base_config.BaseConfig): + """Configuration object for `WholePdbPipeline`. + + Properties: + max_atoms_per_token: number of atom slots in one token (was called + num_dense, and semi-hardcoded to 24 before) + pad_num_chains: Size to pad NUM_CHAINS feature dimensions to, only for + protein chains. + buckets: Bucket sizes to pad the data to, to avoid excessive + re-compilation of the model. If None, calculate the appropriate bucket + size from the number of tokens. If not None, must be a sequence of at + least one integer, in strictly increasing order. Will raise an error if + the number of tokens is more than the largest bucket size. + max_total_residues: Any mmCIF with more total residues will be rejected. + If none, then no limit is applied. + min_total_residues: Any mmCIF with less total residues will be rejected. + msa_crop_size: Maximum size of MSA to take across all chains. + max_template_date: Optional max template date to prevent data leakage in + validation. + max_templates: The maximum number of templates to send through the network + set to 0 to switch off templates. + filter_clashes: If true then will remove clashing chains. + filter_crystal_aids: If true ligands in the cryal aid list are removed. + max_paired_sequence_per_species: The maximum number of sequences per + species that will be used for MSA pairing. + drop_ligand_leaving_atoms: Flag for handling leaving atoms for ligands. + intra_ligand_ptm_bonds: Whether to embed intra ligand covalent bond graph. + average_num_atoms_per_token: Target average number of atoms per token to + compute the padding size for flat atoms. + atom_cross_att_queries_subset_size: queries subset size in atom cross + attention + atom_cross_att_keys_subset_size: keys subset size in atom cross attention + flatten_non_standard_residues: Whether to expand non-standard polymer + residues into flat-atom format. + remove_nonsymmetric_bonds: Whether to remove nonsymmetric bonds from + symmetric polymer chains. + deterministic_frames: Whether to use fixed-seed reference positions to + construct deterministic frames. + """ + + max_atoms_per_token: int = 24 + pad_num_chains: int = 1000 + buckets: list[int] | None = None + max_total_residues: int | None = None + min_total_residues: int | None = None + msa_crop_size: int = 16384 + max_template_date: datetime.date | None = None + max_templates: int = 4 + filter_clashes: bool = False + filter_crystal_aids: bool = False + max_paired_sequence_per_species: int = 600 + drop_ligand_leaving_atoms: bool = True + intra_ligand_ptm_bonds: bool = True + average_num_atoms_per_token: int = 24 + atom_cross_att_queries_subset_size: int = 32 + atom_cross_att_keys_subset_size: int = 128 + flatten_non_standard_residues: bool = True + remove_nonsymmetric_bonds: bool = False + deterministic_frames: bool = True + + def __init__( + self, + *, + config: Config, + ): + """Init WholePdb. 
+ + Args: + config: Pipeline configuration. + """ + self._config = config + + def process_item( + self, + fold_input: folding_input.Input, + random_state: np.random.RandomState, + ccd: chemical_components.Ccd, + random_seed: int | None = None, + ) -> features.BatchDict: + """Takes requests from in_queue, adds (key, serialized ex) to out_queue.""" + if random_seed is None: + random_seed = random_state.randint(2**31) + + random_state = np.random.RandomState(seed=random_seed) + + logging_name = f'{fold_input.name}, random_seed={random_seed}' + logging.info('processing %s', logging_name) + struct = fold_input.to_structure(ccd=ccd) + + # Clean structure. + cleaned_struc, cleaning_metadata = structure_cleaning.clean_structure( + struct, + ccd=ccd, + drop_non_standard_atoms=True, + drop_missing_sequence=True, + filter_clashes=self._config.filter_clashes, + filter_crystal_aids=self._config.filter_crystal_aids, + filter_waters=True, + filter_hydrogens=True, + filter_leaving_atoms=self._config.drop_ligand_leaving_atoms, + only_glycan_ligands_for_leaving_atoms=True, + covalent_bonds_only=True, + remove_polymer_polymer_bonds=True, + remove_bad_bonds=True, + remove_nonsymmetric_bonds=self._config.remove_nonsymmetric_bonds, + ) + + num_clashing_chains_removed = cleaning_metadata[ + 'num_clashing_chains_removed' + ] + + if num_clashing_chains_removed: + logging.info( + 'Removed %d clashing chains from %s', + num_clashing_chains_removed, + logging_name, + ) + + # No chains after fixes + # if cleaned_struc.num_chains == 0: + # raise MmcifNumChainsError(f'{logging_name}: No chains in structure!') + + polymer_ligand_bonds, ligand_ligand_bonds = ( + inter_chain_bonds.get_polymer_ligand_and_ligand_ligand_bonds( + cleaned_struc, + only_glycan_ligands=False, + allow_multiple_bonds_per_atom=True, + ) + ) + + # If empty replace with None as this causes errors downstream. + if ligand_ligand_bonds and not ligand_ligand_bonds.atom_name.size: + ligand_ligand_bonds = None + if polymer_ligand_bonds and not polymer_ligand_bonds.atom_name.size: + polymer_ligand_bonds = None + + # Create the flat output AtomLayout + empty_output_struc, flat_output_layout = ( + structure_cleaning.create_empty_output_struct_and_layout( + struct=cleaned_struc, + ccd=ccd, + polymer_ligand_bonds=polymer_ligand_bonds, + ligand_ligand_bonds=ligand_ligand_bonds, + drop_ligand_leaving_atoms=self._config.drop_ligand_leaving_atoms, + ) + ) + + # Select the tokens for Evoformer. + # Each token (e.g. a residue) is encoded as one representative atom. This + # is flexible enough to allow the 1-token-per-atom ligand representation + # in the future. 
+ all_tokens, all_token_atoms_layout, standard_token_idxs = ( + features.tokenizer( + flat_output_layout, + ccd=ccd, + max_atoms_per_token=self._config.max_atoms_per_token, + flatten_non_standard_residues=self._config.flatten_non_standard_residues, + logging_name=logging_name, + ) + ) + total_tokens = len(all_tokens.atom_name) + if ( + self._config.max_total_residues + and total_tokens > self._config.max_total_residues + ): + raise TotalNumResOutOfRangeError( + 'Total Number of Residues > max_total_residues: ' + f'({total_tokens} > {self._config.max_total_residues})' + ) + + if ( + self._config.min_total_residues + and total_tokens < self._config.min_total_residues + ): + raise TotalNumResOutOfRangeError( + 'Total Number of Residues < min_total_residues: ' + f'({total_tokens} < {self._config.min_total_residues})' + ) + + logging.info( + 'Calculating bucket size for input with %d tokens.', total_tokens + ) + padded_token_length = calculate_bucket_size( + total_tokens, self._config.buckets + ) + logging.info( + 'Got bucket size %d for input with %d tokens, resulting in %d padded' + ' tokens.', + padded_token_length, + total_tokens, + padded_token_length - total_tokens, + ) + + # Padding shapes for all features. + num_atoms = padded_token_length * self._config.average_num_atoms_per_token + # Round up to next multiple of subset size. + num_atoms = int( + np.ceil(num_atoms / self._config.atom_cross_att_queries_subset_size) + * self._config.atom_cross_att_queries_subset_size + ) + padding_shapes = features.PaddingShapes( + num_tokens=padded_token_length, + msa_size=self._config.msa_crop_size, + num_chains=self._config.pad_num_chains, + num_templates=self._config.max_templates, + num_atoms=num_atoms, + ) + + # Create the atom layouts for flat atom cross attention + batch_atom_cross_att = features.AtomCrossAtt.compute_features( + all_token_atoms_layout=all_token_atoms_layout, + queries_subset_size=self._config.atom_cross_att_queries_subset_size, + keys_subset_size=self._config.atom_cross_att_keys_subset_size, + padding_shapes=padding_shapes, + ) + + # Extract per-token features + batch_token_features = features.TokenFeatures.compute_features( + all_tokens=all_tokens, + padding_shapes=padding_shapes, + ) + + # Create reference structure features + chemical_components_data = struc_chem_comps.populate_missing_ccd_data( + ccd=ccd, + chemical_components_data=cleaned_struc.chemical_components_data, + populate_pdbx_smiles=True, + ) + + # Add smiles info to empty_output_struc. + empty_output_struc = empty_output_struc.copy_and_update_globals( + chemical_components_data=chemical_components_data + ) + # Create layouts and store structures for model output conversion. 
+ batch_convert_model_output = features.ConvertModelOutput.compute_features( + all_token_atoms_layout=all_token_atoms_layout, + padding_shapes=padding_shapes, + cleaned_struc=cleaned_struc, + flat_output_layout=flat_output_layout, + empty_output_struc=empty_output_struc, + polymer_ligand_bonds=polymer_ligand_bonds, + ligand_ligand_bonds=ligand_ligand_bonds, + ) + + # Create the PredictedStructureInfo + batch_predicted_structure_info = ( + features.PredictedStructureInfo.compute_features( + all_tokens=all_tokens, + all_token_atoms_layout=all_token_atoms_layout, + padding_shapes=padding_shapes, + ) + ) + + # Create MSA features + batch_msa = features.MSA.compute_features( + all_tokens=all_tokens, + standard_token_idxs=standard_token_idxs, + padding_shapes=padding_shapes, + fold_input=fold_input, + logging_name=logging_name, + max_paired_sequence_per_species=self._config.max_paired_sequence_per_species, + ) + + # Create template features + batch_templates = features.Templates.compute_features( + all_tokens=all_tokens, + standard_token_idxs=standard_token_idxs, + padding_shapes=padding_shapes, + fold_input=fold_input, + max_templates=self._config.max_templates, + logging_name=logging_name, + ) + + ref_max_modified_date = self._config.max_template_date + batch_ref_structure, ligand_ligand_bonds = ( + features.RefStructure.compute_features( + all_token_atoms_layout=all_token_atoms_layout, + ccd=ccd, + padding_shapes=padding_shapes, + chemical_components_data=chemical_components_data, + random_state=random_state, + ref_max_modified_date=ref_max_modified_date, + intra_ligand_ptm_bonds=self._config.intra_ligand_ptm_bonds, + ligand_ligand_bonds=ligand_ligand_bonds, + ) + ) + deterministic_ref_structure = None + if self._config.deterministic_frames: + deterministic_ref_structure, _ = features.RefStructure.compute_features( + all_token_atoms_layout=all_token_atoms_layout, + ccd=ccd, + padding_shapes=padding_shapes, + chemical_components_data=chemical_components_data, + random_state=( + np.random.RandomState(_DETERMINISTIC_FRAMES_RANDOM_SEED) + ), + ref_max_modified_date=ref_max_modified_date, + intra_ligand_ptm_bonds=self._config.intra_ligand_ptm_bonds, + ligand_ligand_bonds=ligand_ligand_bonds, + ) + + # Create ligand-polymer bond features. + polymer_ligand_bond_info = features.PolymerLigandBondInfo.compute_features( + all_tokens=all_tokens, + all_token_atoms_layout=all_token_atoms_layout, + bond_layout=polymer_ligand_bonds, + padding_shapes=padding_shapes, + ) + # Create ligand-ligand bond features. + ligand_ligand_bond_info = features.LigandLigandBondInfo.compute_features( + all_tokens, + ligand_ligand_bonds, + padding_shapes, + ) + + # Create the Pseudo-beta layout for distogram head and distance error head. + batch_pseudo_beta_info = features.PseudoBetaInfo.compute_features( + all_token_atoms_layout=all_token_atoms_layout, + ccd=ccd, + padding_shapes=padding_shapes, + logging_name=logging_name, + ) + + # Frame construction. + batch_frames = features.Frames.compute_features( + all_tokens=all_tokens, + all_token_atoms_layout=all_token_atoms_layout, + ref_structure=( + deterministic_ref_structure + if self._config.deterministic_frames + else batch_ref_structure + ), + padding_shapes=padding_shapes, + ) + + # Assemble the Batch object. 
+ batch = feat_batch.Batch( + msa=batch_msa, + templates=batch_templates, + token_features=batch_token_features, + ref_structure=batch_ref_structure, + predicted_structure_info=batch_predicted_structure_info, + polymer_ligand_bond_info=polymer_ligand_bond_info, + ligand_ligand_bond_info=ligand_ligand_bond_info, + pseudo_beta_info=batch_pseudo_beta_info, + atom_cross_att=batch_atom_cross_att, + convert_model_output=batch_convert_model_output, + frames=batch_frames, + ) + + np_example = batch.as_data_dict() + if 'num_iter_recycling' in np_example: + del np_example['num_iter_recycling'] # that does not belong here + + for name, value in np_example.items(): + if ( + value.dtype.kind not in {'U', 'S'} + and value.dtype.name != 'object' + and np.isnan(np.sum(value)) + ): + raise NanDataError( + 'The output of the data pipeline contained nans. ' + f'nan feature: {name}, fold input name: {fold_input.name}, ' + f'random_seed {random_seed}' + ) + + return np_example diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/structure_cleaning.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/structure_cleaning.py new file mode 100644 index 0000000000000000000000000000000000000000..0ca2505a3204bc614965c2455952f4305ec7c06c --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/structure_cleaning.py @@ -0,0 +1,370 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Prepare PDB structure for training or inference.""" + +from typing import Any +import numpy as np +from absl import logging +from alphafold3 import structure +from alphafold3.constants import chemical_component_sets +from alphafold3.constants import chemical_components +from alphafold3.constants import mmcif_names +from alphafold3.model.atom_layout import atom_layout +from alphafold3.model.pipeline import inter_chain_bonds +from alphafold3.model.scoring import covalent_bond_cleaning +from alphafold3.structure import sterics + + +def _get_leaving_atom_mask( + struct: structure.Structure, + polymer_ligand_bonds: atom_layout.AtomLayout | None, + ligand_ligand_bonds: atom_layout.AtomLayout | None, + chain_id: str, + chain_type: str, + res_id: int, + res_name: str, +) -> np.ndarray: + """Updates a drop_leaving_atoms mask with new leaving atom locations.""" + bonded_atoms = atom_layout.get_bonded_atoms( + polymer_ligand_bonds, + ligand_ligand_bonds, + res_id, + chain_id, + ) + # Connect the amino-acids, i.e. remove OXT, HXT and H2. + drop_atoms = atom_layout.get_link_drop_atoms( + res_name=res_name, + chain_type=chain_type, + is_start_terminus=False, + is_end_terminus=False, + bonded_atoms=bonded_atoms, + drop_ligand_leaving_atoms=True, + ) + # Default mask where everything is false, which equates to being kept. 
+ drop_atom_filter_atoms = struct.chain_id != struct.chain_id
+ for drop_atom in drop_atoms:
+ drop_atom_filter_atom = np.logical_and(
+ np.logical_and(
+ struct.atom_name == drop_atom,
+ struct.chain_id == chain_id,
+ ),
+ struct.res_id == res_id,
+ )
+ drop_atom_filter_atoms = np.logical_or(
+ drop_atom_filter_atoms, drop_atom_filter_atom
+ )
+ return drop_atom_filter_atoms
+
+
+def clean_structure(
+ struct: structure.Structure,
+ ccd: chemical_components.Ccd,
+ *,
+ drop_missing_sequence: bool,
+ filter_clashes: bool,
+ drop_non_standard_atoms: bool,
+ filter_crystal_aids: bool,
+ filter_waters: bool,
+ filter_hydrogens: bool,
+ filter_leaving_atoms: bool,
+ only_glycan_ligands_for_leaving_atoms: bool,
+ covalent_bonds_only: bool,
+ remove_polymer_polymer_bonds: bool,
+ remove_bad_bonds: bool,
+ remove_nonsymmetric_bonds: bool,
+) -> tuple[structure.Structure, dict[str, Any]]:
+ """Cleans structure.
+
+ Args:
+ struct: Structure to clean.
+ ccd: The chemical components dictionary.
+ drop_missing_sequence: Whether to drop chains without specified sequences.
+ filter_clashes: Whether to drop clashing chains.
+ drop_non_standard_atoms: Whether to drop non CCD standard atoms.
+ filter_crystal_aids: Whether to drop ligands in the crystal aid set.
+ filter_waters: Whether to drop water chains.
+ filter_hydrogens: Whether to drop hydrogen atoms.
+ filter_leaving_atoms: Whether to drop leaving atoms based on heuristics.
+ only_glycan_ligands_for_leaving_atoms: Whether to only include glycan
+ ligands when filtering leaving atoms.
+ covalent_bonds_only: Only include covalent bonds.
+ remove_polymer_polymer_bonds: Remove polymer-polymer bonds.
+ remove_bad_bonds: Whether to remove badly bonded ligands.
+ remove_nonsymmetric_bonds: Whether to remove nonsymmetric polymer-ligand
+ bonds from symmetric polymer chains.
+
+ Returns:
+ Tuple of structure and metadata dict. The metadata dict has
+ information about what was cleaned from the original.
+ """
+
+ metadata = {}
+ # Crop crystallization aids.
+ if (
+ filter_crystal_aids
+ and struct.structure_method in mmcif_names.CRYSTALLIZATION_METHODS
+ ):
+ struct = struct.filter_out(
+ res_name=chemical_component_sets.COMMON_CRYSTALLIZATION_AIDS
+ )
+
+ # Drop chains without specified sequences.
+ if drop_missing_sequence:
+ chains_with_unk_sequence = struct.find_chains_with_unknown_sequence()
+ num_with_unk_sequence = len(chains_with_unk_sequence)
+ if chains_with_unk_sequence:
+ struct = struct.filter_out(chain_id=chains_with_unk_sequence)
+ else:
+ num_with_unk_sequence = 0
+ metadata['num_with_unk_sequence'] = num_with_unk_sequence
+
+ # Remove intersecting chains.
+ if filter_clashes and struct.num_chains > 1:
+ clashing_chains = sterics.find_clashing_chains(struct)
+ if clashing_chains:
+ struct = struct.filter_out(chain_id=clashing_chains)
+ else:
+ clashing_chains = []
+ metadata['num_clashing_chains_removed'] = len(clashing_chains)
+ metadata['chains_removed'] = clashing_chains
+
+ # Drop non-standard atoms
+ if drop_non_standard_atoms:
+ struct = struct.drop_non_standard_atoms(
+ ccd=ccd, drop_unk=False, drop_non_ccd=False
+ )
+
+ # Sort chains in "reverse-spreadsheet" order.
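+ # Note: `with_sorted_chains` is assumed to return a re-ordered copy of the
+ # structure; the filters below then operate on that copy.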
+ struct = struct.with_sorted_chains
+
+ if filter_hydrogens:
+ struct = struct.without_hydrogen()
+
+ if filter_waters:
+ struct = struct.filter_out(chain_type=mmcif_names.WATER)
+
+ if filter_leaving_atoms:
+ drop_leaving_atoms_all = struct.chain_id != struct.chain_id
+ polymer_ligand_bonds = inter_chain_bonds.get_polymer_ligand_bonds(
+ struct,
+ only_glycan_ligands=only_glycan_ligands_for_leaving_atoms,
+ )
+ ligand_ligand_bonds = inter_chain_bonds.get_ligand_ligand_bonds(
+ struct,
+ only_glycan_ligands=only_glycan_ligands_for_leaving_atoms,
+ )
+ all_glycans = {
+ *chemical_component_sets.GLYCAN_OTHER_LIGANDS,
+ *chemical_component_sets.GLYCAN_LINKING_LIGANDS,
+ }
+ # If only glycan ligands and no O1 atoms, we can do parallel drop.
+ if (
+ only_glycan_ligands_for_leaving_atoms
+ and (not (ligand_ligand_bonds.atom_name == 'O1').any())
+ and (not (polymer_ligand_bonds.atom_name == 'O1').any())
+ ):
+ drop_leaving_atoms_all = np.logical_and(
+ np.isin(struct.atom_name, 'O1'),
+ np.isin(struct.res_name, list(all_glycans)),
+ )
+ else:
+ substruct = struct.group_by_residue
+ glycan_mask = np.isin(substruct.res_name, list(all_glycans))
+ substruct = substruct.filter(glycan_mask)
+ # We need to iterate over all glycan residues for this.
+ for res in substruct.iter_residues():
+ # Only need to do drop leaving atoms for glycans depending on bonds.
+ if (res_name := res['res_name']) in all_glycans:
+ drop_atom_filter = _get_leaving_atom_mask(
+ struct=struct,
+ polymer_ligand_bonds=polymer_ligand_bonds,
+ ligand_ligand_bonds=ligand_ligand_bonds,
+ chain_id=res['chain_id'],
+ chain_type=res['chain_type'],
+ res_id=res['res_id'],
+ res_name=res_name,
+ )
+ drop_leaving_atoms_all = np.logical_or(
+ drop_leaving_atoms_all, drop_atom_filter
+ )
+
+ num_atoms_before = struct.num_atoms
+ struct = struct.filter_out(drop_leaving_atoms_all)
+ num_atoms_after = struct.num_atoms
+
+ if num_atoms_before > num_atoms_after:
+ logging.error(
+ 'Dropped %s atoms from GT struct: chain_id %s res_id %s res_name %s',
+ num_atoms_before - num_atoms_after,
+ struct.chain_id,
+ struct.res_id,
+ struct.res_name,
+ )
+
+ # Can filter by bond type without having to iterate over bonds.
+ if struct.bonds and covalent_bonds_only:
+ is_covalent = np.isin(struct.bonds.type, ['covale'])
+ if sum(is_covalent) > 0:
+ new_bonds = struct.bonds[is_covalent]
+ else:
+ new_bonds = structure.Bonds.make_empty()
+ struct = struct.copy_and_update(bonds=new_bonds)
+
+ # Other bond filters require iterating over individual bonds.
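+ # The loop below drops polymer-polymer bonds and, when `remove_bad_bonds` is
+ # set, bonds whose endpoints lie more than 2.4 Angstrom apart.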
+ if struct.bonds and (remove_bad_bonds or remove_polymer_polymer_bonds): + include_bond = [] + num_pp_bonds = 0 + num_bad_bonds = 0 + for bond in struct.iter_bonds(): + dest_atom = bond.dest_atom + from_atom = bond.from_atom + if remove_polymer_polymer_bonds: + if ( + from_atom['chain_type'] in mmcif_names.POLYMER_CHAIN_TYPES + and dest_atom['chain_type'] in mmcif_names.POLYMER_CHAIN_TYPES + ): + num_pp_bonds += 1 + include_bond.append(False) + continue + if remove_bad_bonds: + dest_coords = np.array( + [dest_atom['atom_x'], dest_atom['atom_y'], dest_atom['atom_z']] + ) + from_coords = np.array( + [from_atom['atom_x'], from_atom['atom_y'], from_atom['atom_z']] + ) + squared_dist = np.sum(np.square(dest_coords - from_coords)) + squared_threshold = 2.4 * 2.4 + if squared_dist > squared_threshold: + num_bad_bonds += 1 + include_bond.append(False) + continue + include_bond.append(True) + if sum(include_bond) < len(struct.bonds): + logging.info( + 'Reducing number of bonds for %s from %s to %s, of which %s are' + ' polymer-polymer bonds and %s are bad bonds.', + struct.name, + len(struct.bonds), + sum(include_bond), + num_pp_bonds, + num_bad_bonds, + ) + if sum(include_bond) > 0: + # Need to index bonds with bond keys or arrays of bools with same length + # as num bonds. In this case, we use array of bools (as elsewhere in the + # cleaning code). + new_bonds = struct.bonds[np.array(include_bond, dtype=bool)] + else: + new_bonds = structure.Bonds.make_empty() + struct = struct.copy_and_update(bonds=new_bonds) + + if struct.bonds and remove_nonsymmetric_bonds: + # Check for asymmetric polymer-ligand bonds and remove if these exist. + polymer_ligand_bonds = inter_chain_bonds.get_polymer_ligand_bonds( + struct, + only_glycan_ligands=False, + ) + if polymer_ligand_bonds: + if covalent_bond_cleaning.has_nonsymmetric_bonds_on_symmetric_polymer_chains( + struct, polymer_ligand_bonds + ): + from_atom_idxs, dest_atom_idxs = struct.bonds.get_atom_indices( + struct.atom_key + ) + poly_chain_types = list(mmcif_names.POLYMER_CHAIN_TYPES) + is_polymer_bond = np.logical_or( + np.isin( + struct.chain_type[from_atom_idxs], poly_chain_types), + np.isin( + struct.chain_type[dest_atom_idxs], poly_chain_types), + ) + struct = struct.copy_and_update( + bonds=struct.bonds[~is_polymer_bond]) + + return struct, metadata + + +def create_empty_output_struct_and_layout( + struct: structure.Structure, + ccd: chemical_components.Ccd, + *, + with_hydrogens: bool = False, + skip_unk: bool = False, + polymer_ligand_bonds: atom_layout.AtomLayout | None = None, + ligand_ligand_bonds: atom_layout.AtomLayout | None = None, + drop_ligand_leaving_atoms: bool = False, +) -> tuple[structure.Structure, atom_layout.AtomLayout]: + """Make zero-coordinate structure from all physical residues. + + Args: + struct: Structure object. + ccd: The chemical components dictionary. + with_hydrogens: Whether to keep hydrogen atoms in structure. + skip_unk: Whether to remove unknown residues from structure. + polymer_ligand_bonds: Bond information for polymer-ligand pairs. + ligand_ligand_bonds: Bond information for ligand-ligand pairs. + drop_ligand_leaving_atoms: Flag for handling leaving atoms for ligands. + + Returns: + Tuple of structure with all bonds, physical residues and coordinates set to + 0 and a flat atom layout of empty structure. 
+ """ + bonded_atom_pairs = [] + if polymer_ligand_bonds: + for chain_ids, res_ids, atom_names in zip( + polymer_ligand_bonds.chain_id, + polymer_ligand_bonds.res_id, + polymer_ligand_bonds.atom_name, + strict=True, + ): + bonded_atom_pairs.append(( + (chain_ids[0], res_ids[0], atom_names[0]), + (chain_ids[1], res_ids[1], atom_names[1]), + )) + if ligand_ligand_bonds: + for chain_ids, res_ids, atom_names in zip( + ligand_ligand_bonds.chain_id, + ligand_ligand_bonds.res_id, + ligand_ligand_bonds.atom_name, + strict=True, + ): + bonded_atom_pairs.append(( + (chain_ids[0], res_ids[0], atom_names[0]), + (chain_ids[1], res_ids[1], atom_names[1]), + )) + residues = atom_layout.residues_from_structure( + struct, include_missing_residues=True + ) + + flat_output_layout = atom_layout.make_flat_atom_layout( + residues, + ccd=ccd, + with_hydrogens=with_hydrogens, + skip_unk_residues=skip_unk, + polymer_ligand_bonds=polymer_ligand_bonds, + ligand_ligand_bonds=ligand_ligand_bonds, + drop_ligand_leaving_atoms=drop_ligand_leaving_atoms, + ) + + empty_output_struct = atom_layout.make_structure( + flat_layout=flat_output_layout, + atom_coords=np.zeros((flat_output_layout.shape[0], 3)), + name=struct.name, + atom_b_factors=None, + all_physical_residues=residues, + ) + if bonded_atom_pairs: + empty_output_struct = empty_output_struct.add_bonds( + bonded_atom_pairs, bond_type=mmcif_names.COVALENT_BOND + ) + + return empty_output_struct, flat_output_layout diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/post_processing.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/post_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..c56ed20fc6fcf18eda4841d4315b3a26acfe51fc --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/post_processing.py @@ -0,0 +1,114 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Post-processing utilities for AlphaFold inference results.""" + +import dataclasses +import datetime +import os + +# from alphafold3 import version +from alphafold3.model import confidence_types +from alphafold3.model import mmcif_metadata +from alphafold3.model.components import base_model +import numpy as np + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class ProcessedInferenceResult: + """Stores attributes of a processed inference result. + + Attributes: + cif: CIF file containing an inference result. + mean_confidence_1d: Mean 1D confidence calculated from confidence_1d. + ranking_score: Ranking score extracted from CIF metadata. + structure_confidence_summary_json: Content of JSON file with structure + confidences summary calculated from CIF file. + structure_full_data_json: Content of JSON file with structure full + confidences calculated from CIF file. + model_id: Identifier of the model that produced the inference result. 
+ """ + + cif: bytes + mean_confidence_1d: float + ranking_score: float + structure_confidence_summary_json: bytes + structure_full_data_json: bytes + model_id: bytes + + +def post_process_inference_result( + inference_result: base_model.InferenceResult, +) -> ProcessedInferenceResult: + """Returns cif, confidence_1d_json, confidence_2d_json, mean_confidence_1d, and ranking confidence.""" + + # Add mmCIF metadata fields. + timestamp = datetime.datetime.now().isoformat(sep=' ', timespec='seconds') + cif_with_metadata = mmcif_metadata.add_metadata_to_mmcif( + old_cif=inference_result.predicted_structure.to_mmcif_dict(), + # version=f'{version.__version__} @ {timestamp}', + # version=None, + model_id=inference_result.model_id, + ) + cif = mmcif_metadata.add_legal_comment(cif_with_metadata.to_string()) + cif = cif.encode('utf-8') + confidence_1d = confidence_types.AtomConfidence.from_inference_result( + inference_result + ) + mean_confidence_1d = np.mean(confidence_1d.confidence) + structure_confidence_summary_json = ( + confidence_types.StructureConfidenceSummary.from_inference_result( + inference_result + ) + .to_json() + .encode('utf-8') + ) + structure_full_data_json = ( + confidence_types.StructureConfidenceFull.from_inference_result( + inference_result + ) + .to_json() + .encode('utf-8') + ) + return ProcessedInferenceResult( + cif=cif, + mean_confidence_1d=mean_confidence_1d, + ranking_score=float(inference_result.metadata['ranking_score']), + structure_confidence_summary_json=structure_confidence_summary_json, + structure_full_data_json=structure_full_data_json, + model_id=inference_result.model_id, + ) + + +def write_output( + inference_result: base_model.InferenceResult, + output_dir: os.PathLike[str] | str, + terms_of_use: str | None = None, + name: str | None = None, +) -> None: + """Writes processed inference result to a directory.""" + processed_result = post_process_inference_result(inference_result) + + prefix = f'{name}_' if name is not None else '' + + with open(os.path.join(output_dir, f'{prefix}model.cif'), 'wb') as f: + f.write(processed_result.cif) + + with open( + os.path.join(output_dir, f'{prefix}summary_confidences.json'), 'wb' + ) as f: + f.write(processed_result.structure_confidence_summary_json) + + with open(os.path.join(output_dir, f'{prefix}confidences.json'), 'wb') as f: + f.write(processed_result.structure_full_data_json) + + if terms_of_use is not None: + with open(os.path.join(output_dir, 'TERMS_OF_USE.md'), 'wt') as f: + f.write(terms_of_use) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/protein_data_processing.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/protein_data_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..195db4c2775aedc37e820c25864c95d3ef0707ac --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/protein_data_processing.py @@ -0,0 +1,128 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Process Structure Data.""" + +from alphafold3.constants import atom_types +from alphafold3.constants import residue_names +from alphafold3.constants import side_chains +import numpy as np + + +NUM_DENSE = atom_types.DENSE_ATOM_NUM +NUM_AA = len(residue_names.PROTEIN_TYPES) +NUM_AA_WITH_UNK_AND_GAP = len( + residue_names.PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP +) +NUM_RESTYPES_WITH_UNK_AND_GAP = ( + residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP +) + + +def _make_restype_rigidgroup_dense_atom_idx(): + """Create Mapping from rigid_groups to dense_atom indices.""" + # Create an array with the atom names. + # shape (num_restypes, num_rigidgroups, 3_atoms): + # (31, 8, 3) + base_atom_indices = np.zeros( + (NUM_RESTYPES_WITH_UNK_AND_GAP, 8, 3), dtype=np.int32 + ) + + # 4,5,6,7: 'chi1,2,3,4-group' + for restype, restype_letter in enumerate( + residue_names.PROTEIN_TYPES_ONE_LETTER + ): + resname = residue_names.PROTEIN_COMMON_ONE_TO_THREE[restype_letter] + + dense_atom_names = atom_types.ATOM14[resname] + # 0: backbone frame + base_atom_indices[restype, 0, :] = [ + dense_atom_names.index(atom) for atom in ['C', 'CA', 'N'] + ] + + # 3: 'psi-group' + base_atom_indices[restype, 3, :] = [ + dense_atom_names.index(atom) for atom in ['CA', 'C', 'O'] + ] + for chi_idx in range(4): + if side_chains.CHI_ANGLES_MASK[restype][chi_idx]: + atom_names = side_chains.CHI_ANGLES_ATOMS[resname][chi_idx] + base_atom_indices[restype, chi_idx + 4, :] = [ + dense_atom_names.index(atom) for atom in atom_names[1:] + ] + dense_atom_names = atom_types.DENSE_ATOM['A'] + nucleic_rigid_atoms = [ + dense_atom_names.index(atom) for atom in ["C1'", "C3'", "C4'"] + ] + for nanum, _ in enumerate(residue_names.NUCLEIC_TYPES): + # 0: backbone frame only. + # we have aa + unk + gap, so we want to start after those + resnum = nanum + NUM_AA_WITH_UNK_AND_GAP + base_atom_indices[resnum, 0, :] = nucleic_rigid_atoms + + return base_atom_indices + + +RESTYPE_RIGIDGROUP_DENSE_ATOM_IDX = _make_restype_rigidgroup_dense_atom_idx() + + +def _make_restype_pseudobeta_idx(): + """Returns indices of residue's pseudo-beta.""" + restype_pseudobeta_index = np.zeros( + (NUM_RESTYPES_WITH_UNK_AND_GAP,), dtype=np.int32 + ) + for restype, restype_letter in enumerate( + residue_names.PROTEIN_TYPES_ONE_LETTER + ): + restype_name = residue_names.PROTEIN_COMMON_ONE_TO_THREE[restype_letter] + atom_names = list(atom_types.ATOM14[restype_name]) + if restype_name in {'GLY'}: + restype_pseudobeta_index[restype] = atom_names.index('CA') + else: + restype_pseudobeta_index[restype] = atom_names.index('CB') + for nanum, resname in enumerate(residue_names.NUCLEIC_TYPES): + atom_names = list(atom_types.DENSE_ATOM[resname]) + # 0: backbone frame only. 
+ # we have aa + unk , so we want to start after those + restype = nanum + NUM_AA_WITH_UNK_AND_GAP + if resname in {'A', 'G', 'DA', 'DG'}: + restype_pseudobeta_index[restype] = atom_names.index('C4') + else: + restype_pseudobeta_index[restype] = atom_names.index('C2') + return restype_pseudobeta_index + + +RESTYPE_PSEUDOBETA_INDEX = _make_restype_pseudobeta_idx() + + +def _make_aatype_dense_atom_to_atom37(): + """Map from dense_atom to atom37 per residue type.""" + restype_dense_atom_to_atom37 = [ + ] # mapping (restype, dense_atom) --> atom37 + for rt in residue_names.PROTEIN_TYPES_ONE_LETTER: + atom_names = list( + atom_types.ATOM14_PADDED[residue_names.PROTEIN_COMMON_ONE_TO_THREE[rt]] + ) + atom_names.extend([''] * (NUM_DENSE - len(atom_names))) + restype_dense_atom_to_atom37.append( + [(atom_types.ATOM37_ORDER[name] if name else 0) + for name in atom_names] + ) + # Add dummy mapping for restype 'UNK', '-' (gap), and nucleics [but not DN]. + for _ in range(2 + len(residue_names.NUCLEIC_TYPES_WITH_UNKNOWN)): + restype_dense_atom_to_atom37.append([0] * NUM_DENSE) + + restype_dense_atom_to_atom37 = np.array( + restype_dense_atom_to_atom37, dtype=np.int32 + ) + return restype_dense_atom_to_atom37 + + +PROTEIN_AATYPE_DENSE_ATOM_TO_ATOM37 = _make_aatype_dense_atom_to_atom37() diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/alignment.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..a4a8d225efc5599840de656ff711774959c809a7 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/alignment.py @@ -0,0 +1,146 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Alignment based metrics.""" + +import numpy as np + + +def transform_ls( + x: np.ndarray, + b: np.ndarray, + *, + allow_reflection: bool = False, +) -> np.ndarray: + """Find the least squares best fit rotation between two sets of N points. + + Solve Ax = b for A. Where A is the transform rotating x^T into b^T. + + Args: + x: NxD numpy array of coordinates. Usually dimension D is 3. + b: NxD numpy array of coordinates. Usually dimension D is 3. + allow_reflection: Whether the returned transformation can reflect as well as + rotate. + + Returns: + Matrix A transforming x into b, i.e. s.t. Ax^T = b^T. + """ + assert x.shape[1] >= b.shape[1] + assert b.shape[0] == x.shape[0], '%d, %d' % (b.shape[0], x.shape[0]) + # First postmultiply by x.; + # Axx^t = b x^t + bxt = np.dot(b.transpose(), x) / b.shape[0] + + u, _, v = np.linalg.svd(bxt) + + r = np.dot(u, v) + if not allow_reflection: + flip = np.ones((v.shape[1], 1)) + flip[v.shape[1] - 1, 0] = np.sign(np.linalg.det(r)) + r = np.dot(u, v * flip) + + return r + + +def align( + *, + x: np.ndarray, + y: np.ndarray, + x_indices: np.ndarray, + y_indices: np.ndarray, +) -> np.ndarray: + """Align x to y considering only included_idxs. + + Args: + x: NxD np array of coordinates. + y: NxD np array of coordinates. 
+ x_indices: An np array of indices for `x` that will be used in the + alignment. Must be of the same length as `y_included_idxs`. + y_indices: An np array of indices for `y` that will be used in the + alignment. Must be of the same length as `x_included_idxs`. + + Returns: + NxD np array of points obtained by applying a rigid transformation to x. + These points are aligned to y and the alignment is the optimal alignment + over the points in included_idxs. + + Raises: + ValueError: If the number of included indices is not the same for both + input arrays. + """ + if len(x_indices) != len(y_indices): + raise ValueError( + 'Number of included indices must be the same for both input arrays,' + f' but got for x: {len(x_indices)}, and for y: {len(y_indices)}.' + ) + + x_mean = np.mean(x[x_indices, :], axis=0) + y_mean = np.mean(y[y_indices, :], axis=0) + + centered_x = x - x_mean + centered_y = y - y_mean + t = transform_ls(centered_x[x_indices, :], centered_y[y_indices, :]) + transformed_x = np.dot(centered_x, t.transpose()) + y_mean + + return transformed_x + + +def deviations_from_coords( + decoy_coords: np.ndarray, + gt_coords: np.ndarray, + align_idxs: np.ndarray | None = None, + include_idxs: np.ndarray | None = None, +) -> np.ndarray: + """Returns the raw per-atom deviations used in RMSD computation.""" + if decoy_coords.shape != gt_coords.shape: + raise ValueError( + 'decoy_coords.shape and gt_coords.shape must match.Found: %s and %s.' + % (decoy_coords.shape, gt_coords.shape) + ) + # Include and align all residues unless specified otherwise. + if include_idxs is None: + include_idxs = np.arange(decoy_coords.shape[0]) + if align_idxs is None: + align_idxs = include_idxs + aligned_decoy_coords = align( + x=decoy_coords, + y=gt_coords, + x_indices=align_idxs, + y_indices=align_idxs, + ) + deviations = np.linalg.norm( + aligned_decoy_coords[include_idxs] - gt_coords[include_idxs], axis=1 + ) + return deviations + + +def rmsd_from_coords( + decoy_coords: np.ndarray | str, + gt_coords: np.ndarray | str, + align_idxs: np.ndarray | None = None, + include_idxs: np.ndarray | None = None, +) -> float: + """Computes the *aligned* RMSD of two Mx3 np arrays of coordinates. + + Args: + decoy_coords: [M, 3] np array of decoy atom coordinates. + gt_coords: [M, 3] np array of gt atom coordinates. + align_idxs: [M] np array of indices specifying coordinates to align on. + Defaults to None, in which case all the include_idx (see after) are used. + include_idxs: [M] np array of indices specifying coordinates to score. + Defaults to None, in which case all indices are used for scoring. + + Returns: + rmsd value of the aligned decoy and gt coordinates. + """ + deviations = deviations_from_coords( + decoy_coords, gt_coords, align_idxs, include_idxs + ) + return np.sqrt(np.mean(np.square(deviations))) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/covalent_bond_cleaning.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/covalent_bond_cleaning.py new file mode 100644 index 0000000000000000000000000000000000000000..abc38ce6e70b751dd4dd68b7cdbba2c356088824 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/covalent_bond_cleaning.py @@ -0,0 +1,265 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Some methods to compute metrics for PTMs."""
+
+import collections
+from collections.abc import Mapping
+import dataclasses
+import numpy as np
+from alphafold3 import structure
+from alphafold3.constants import mmcif_names
+from alphafold3.model.atom_layout import atom_layout
+
+
+@dataclasses.dataclass(frozen=True)
+class ResIdMapping:
+ old_res_ids: np.ndarray
+ new_res_ids: np.ndarray
+
+
+def _count_symmetric_chains(struct: structure.Structure) -> Mapping[str, int]:
+ """Returns a dict with each chain ID and count."""
+ chain_res_name_sequence_from_chain_id = struct.chain_res_name_sequence(
+ include_missing_residues=True, fix_non_standard_polymer_res=False
+ )
+ counts_for_chain_res_name_sequence = collections.Counter(
+ chain_res_name_sequence_from_chain_id.values()
+ )
+ chain_symmetric_count = {}
+ for chain_id, chain_res_name in chain_res_name_sequence_from_chain_id.items():
+ chain_symmetric_count[chain_id] = counts_for_chain_res_name_sequence[
+ chain_res_name
+ ]
+ return chain_symmetric_count
+
+
+def has_nonsymmetric_bonds_on_symmetric_polymer_chains(
+ struct: structure.Structure, polymer_ligand_bonds: atom_layout.AtomLayout
+) -> bool:
+ """Returns true if nonsymmetric bonds found on polymer chains."""
+ try:
+ _get_polymer_dim(polymer_ligand_bonds)
+ except ValueError:
+ return True
+ if _has_non_polymer_ligand_ptm_bonds(polymer_ligand_bonds):
+ return True
+ if _has_multiple_polymers_bonded_to_one_ligand(polymer_ligand_bonds):
+ return True
+ combined_struct, _ = _combine_polymer_ligand_ptm_chains(
+ struct, polymer_ligand_bonds
+ )
+ struct = struct.filter(chain_type=mmcif_names.POLYMER_CHAIN_TYPES)
+ combined_struct = combined_struct.filter(
+ chain_type=mmcif_names.POLYMER_CHAIN_TYPES
+ )
+ return _count_symmetric_chains(struct) != _count_symmetric_chains(
+ combined_struct
+ )
+
+
+def _has_non_polymer_ligand_ptm_bonds(
+ polymer_ligand_bonds: atom_layout.AtomLayout,
+):
+ """Checks if all bonds are between a polymer chain and a ligand chain type."""
+ for start_chain_type, end_chain_type in polymer_ligand_bonds.chain_type:
+ if (
+ start_chain_type in mmcif_names.POLYMER_CHAIN_TYPES
+ and end_chain_type in mmcif_names.LIGAND_CHAIN_TYPES
+ ):
+ continue
+ elif (
+ start_chain_type in mmcif_names.LIGAND_CHAIN_TYPES
+ and end_chain_type in mmcif_names.POLYMER_CHAIN_TYPES
+ ):
+ continue
+ else:
+ return True
+ return False
+
+
+def _combine_polymer_ligand_ptm_chains(
+ struct: structure.Structure,
+ polymer_ligand_bonds: atom_layout.AtomLayout,
+) -> tuple[structure.Structure, dict[tuple[str, str], ResIdMapping]]:
+ """Combines the ptm polymer-ligand chains together.
+
+ This will prevent them from being permuted away from each other when chains
+ are matched to the ground truth. This function also returns the res_id mapping
+ from the separate ligand res_ids to their res_ids in the combined
+ polymer-ligand chain; this information is needed to later separate the
+ combined polymer-ligand chain.
+
+ Args:
+ struct: Structure to be modified.
+ polymer_ligand_bonds: AtomLayout with polymer-ligand bond info.
+ + Returns: + A tuple of a Structure with each ptm polymer-ligand chain relabelled as one + chain and a dict from bond chain pair to the res_id mapping. + """ + if not _has_only_single_bond_from_each_chain(polymer_ligand_bonds): + if _has_multiple_ligands_bonded_to_one_polymer(polymer_ligand_bonds): + # For structures where a polymer chain is connected to multiple ligands, + # we need to sort the multiple bonds from the same chain by res_id to + # ensure that the combined polymer-ligand chain will always be the same + # when you have repeated symmetric polymer-ligand chains. + polymer_ligand_bonds = ( + _sort_polymer_ligand_bonds_by_polymer_chain_and_res_id( + polymer_ligand_bonds + ) + ) + else: + raise ValueError( + 'Code cannot handle multiple bonds from one chain unless' + ' its several ligands bonded to a polymer.' + ) + res_id_mappings_for_bond_chain_pair = dict() + for (start_chain_id, end_chain_id), (start_chain_type, end_chain_type) in zip( + polymer_ligand_bonds.chain_id, polymer_ligand_bonds.chain_type + ): + poly_info, ligand_info = _get_polymer_and_ligand_chain_ids_and_types( + start_chain_id, end_chain_id, start_chain_type, end_chain_type + ) + polymer_chain_id, polymer_chain_type = poly_info + ligand_chain_id, _ = ligand_info + + # Join the ligand chain to the polymer chain. + ligand_res_ids = struct.filter(chain_id=ligand_chain_id).res_id + new_res_ids = ligand_res_ids + \ + len(struct.all_residues[polymer_chain_id]) + res_id_mappings_for_bond_chain_pair[(polymer_chain_id, ligand_chain_id)] = ( + ResIdMapping(old_res_ids=ligand_res_ids, new_res_ids=new_res_ids) + ) + chain_groups = [] + chain_group_ids = [] + chain_group_types = [] + for chain_id, chain_type in zip( + struct.chains_table.id, struct.chains_table.type + ): + if chain_id == ligand_chain_id: + continue + elif chain_id == polymer_chain_id: + chain_groups.append([polymer_chain_id, ligand_chain_id]) + chain_group_ids.append(polymer_chain_id) + chain_group_types.append(polymer_chain_type) + else: + chain_groups.append([chain_id]) + chain_group_ids.append(chain_id) + chain_group_types.append(chain_type) + + struct = struct.merge_chains( + chain_groups=chain_groups, + chain_group_ids=chain_group_ids, + chain_group_types=chain_group_types, + ) + + return struct, res_id_mappings_for_bond_chain_pair + + +def _has_only_single_bond_from_each_chain( + polymer_ligand_bonds: atom_layout.AtomLayout, +) -> bool: + """Checks that there is at most one bond from each chain.""" + chain_ids = [] + for chains in polymer_ligand_bonds.chain_id: + chain_ids.extend(chains) + if len(chain_ids) != len(set(chain_ids)): + return False + return True + + +def _get_polymer_and_ligand_chain_ids_and_types( + start_chain_id: str, + end_chain_id: str, + start_chain_type: str, + end_chain_type: str, +) -> tuple[tuple[str, str], tuple[str, str]]: + """Finds polymer and ligand chain ids from chain types.""" + if ( + start_chain_type in mmcif_names.POLYMER_CHAIN_TYPES + and end_chain_type in mmcif_names.LIGAND_CHAIN_TYPES + ): + return (start_chain_id, start_chain_type), (end_chain_id, end_chain_type) + elif ( + start_chain_type in mmcif_names.LIGAND_CHAIN_TYPES + and end_chain_type in mmcif_names.POLYMER_CHAIN_TYPES + ): + return (end_chain_id, end_chain_type), (start_chain_id, start_chain_type) + else: + raise ValueError( + 'This code only handles PTM-bonds from polymer chain to ligands.' 
+ ) + + +def _get_polymer_dim(polymer_ligand_bonds: atom_layout.AtomLayout) -> int: + """Gets polymer dimension from the polymer-ligand bond layout.""" + start_chain_types = [] + end_chain_types = [] + for start_chain_type, end_chain_type in polymer_ligand_bonds.chain_type: + start_chain_types.append(start_chain_type) + end_chain_types.append(end_chain_type) + if set(start_chain_types).issubset( + set(mmcif_names.POLYMER_CHAIN_TYPES) + ) and set(end_chain_types).issubset(set(mmcif_names.LIGAND_CHAIN_TYPES)): + return 0 + elif set(start_chain_types).issubset(mmcif_names.LIGAND_CHAIN_TYPES) and set( + end_chain_types + ).issubset(set(mmcif_names.POLYMER_CHAIN_TYPES)): + return 1 + else: + raise ValueError( + 'Polymer and ligand dimensions are not consistent within the structure.' + ) + + +def _has_multiple_ligands_bonded_to_one_polymer(polymer_ligand_bonds): + """Checks if there are multiple ligands bonded to one polymer.""" + polymer_dim = _get_polymer_dim(polymer_ligand_bonds) + polymer_chain_ids = [ + chains[polymer_dim] for chains in polymer_ligand_bonds.chain_id + ] + if len(polymer_chain_ids) != len(set(polymer_chain_ids)): + return True + return False + + +def _has_multiple_polymers_bonded_to_one_ligand(polymer_ligand_bonds): + """Checks if there are multiple polymer chains bonded to one ligand.""" + polymer_dim = _get_polymer_dim(polymer_ligand_bonds) + ligand_dim = 1 - polymer_dim + ligand_chain_ids = [ + chains[ligand_dim] for chains in polymer_ligand_bonds.chain_id + ] + if len(ligand_chain_ids) != len(set(ligand_chain_ids)): + return True + return False + + +def _sort_polymer_ligand_bonds_by_polymer_chain_and_res_id( + polymer_ligand_bonds, +): + """Sorts bonds by res_id (for when a polymer chain has multiple bonded ligands).""" + + polymer_dim = _get_polymer_dim(polymer_ligand_bonds) + + polymer_chain_ids = [ + chains[polymer_dim] for chains in polymer_ligand_bonds.chain_id + ] + polymer_res_ids = [res[polymer_dim] for res in polymer_ligand_bonds.res_id] + + polymer_chain_and_res_id = zip(polymer_chain_ids, polymer_res_ids) + sorted_indices = [ + idx + for idx, _ in sorted( + enumerate(polymer_chain_and_res_id), key=lambda x: x[1] + ) + ] + return polymer_ligand_bonds[sorted_indices] diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/scoring.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/scoring.py new file mode 100644 index 0000000000000000000000000000000000000000..017210f92454a6e8ed1e8c265d9ffda476bde85e --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/scoring.py @@ -0,0 +1,67 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Library of scoring methods of the model outputs.""" + +from alphafold3.model import protein_data_processing +import numpy as np + + +Array = np.ndarray + + +def pseudo_beta_fn( + aatype: Array, + dense_atom_positions: Array, + dense_atom_masks: Array, + is_ligand: Array | None = None, + use_jax: bool | None = True, + ) -> tuple[Array, Array] | Array: + """Create pseudo beta atom positions and optionally mask. + + Args: + aatype: [num_res] amino acid types. + dense_atom_positions: [num_res, NUM_DENSE, 3] vector of all atom positions. + dense_atom_masks: [num_res, NUM_DENSE] mask. + is_ligand: [num_res] flag if something is a ligand. + use_jax: whether to use jax for the computations. + + Returns: + Pseudo beta dense atom positions and the corresponding mask. + """ + + if is_ligand is None: + is_ligand = np.zeros_like(aatype) + + pseudobeta_index_polymer = np.take( + protein_data_processing.RESTYPE_PSEUDOBETA_INDEX, aatype, axis=0 + ).astype(np.int32) + + pseudobeta_index = np.where( + is_ligand, + np.zeros_like(pseudobeta_index_polymer), + pseudobeta_index_polymer, + ) + + if not isinstance(dense_atom_positions, Array): + dense_atom_positions = dense_atom_positions.asnumpy() + if not isinstance(dense_atom_masks, Array): + dense_atom_masks = dense_atom_masks.asnumpy() + pseudo_beta = np.take_along_axis( + dense_atom_positions, pseudobeta_index[..., None, None], axis=-2 + ) + pseudo_beta = np.squeeze(pseudo_beta, axis=-2) + + pseudo_beta_mask = np.take_along_axis( + dense_atom_masks, pseudobeta_index[..., None], axis=-1 + ).astype(np.float32) + pseudo_beta_mask = np.squeeze(pseudo_beta_mask, axis=-1) + + return pseudo_beta, pseudo_beta_mask diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict.pyi new file mode 100644 index 0000000000000000000000000000000000000000..09d915c845fec7fe0235c7780092f90f1a262fd1 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict.pyi @@ -0,0 +1,125 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from typing import Any, ClassVar, Iterable, Iterator, TypeVar, overload + +import numpy as np + +_T = TypeVar('_T') + +class CifDict: + class ItemView: + def __iter__(self) -> Iterator[tuple[str, list[str]]]: ... + def __len__(self) -> int: ... + + class KeyView: + @overload + def __contains__(self, key: str) -> bool: ... + @overload + def __contains__(self, key: object) -> bool: ... + def __iter__(self) -> Iterator[str]: ... + def __len__(self) -> int: ... + + class ValueView: + def __iter__(self) -> Iterator[list[str]]: ... + def __len__(self) -> int: ... + + def __init__(self, d: dict[str, Iterable[str]]) -> None: ... + def copy_and_update(self, d: dict[str, Iterable[str]]) -> CifDict: ... 
+ def extract_loop_as_dict(self, prefix: str, index: str) -> dict: + """Extracts loop associated with a prefix from mmCIF data as a dict. + + For instance for an mmCIF with these fields: + '_a.ix': ['1', '2', '3'] + '_a.1': ['a.1.1', 'a.1.2', 'a.1.3'] + '_a.2': ['a.2.1', 'a.2.2', 'a.2.3'] + + this function called with prefix='_a.', index='_a.ix' extracts: + {'1': {'a.ix': '1', 'a.1': 'a.1.1', 'a.2': 'a.2.1'} + '2': {'a.ix': '2', 'a.1': 'a.1.2', 'a.2': 'a.2.2'} + '3': {'a.ix': '3', 'a.1': 'a.1.3', 'a.2': 'a.2.3'}} + + Args: + prefix: Prefix shared by each of the data items in the loop. The prefix + should include the trailing period. + index: Which item of loop data should serve as the key. + + Returns: + Dict of dicts; each dict represents 1 entry from an mmCIF loop, + indexed by the index column. + """ + + def extract_loop_as_list(self, prefix: str) -> list: + """Extracts loop associated with a prefix from mmCIF data as a list. + + Reference for loop_ in mmCIF: + http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html + + For instance for an mmCIF with these fields: + '_a.1': ['a.1.1', 'a.1.2', 'a.1.3'] + '_a.2': ['a.2.1', 'a.2.2', 'a.2.3'] + + this function called with prefix='_a.' extracts: + [{'_a.1': 'a.1.1', '_a.2': 'a.2.1'} + {'_a.1': 'a.1.2', '_a.2': 'a.2.2'} + {'_a.1': 'a.1.3', '_a.2': 'a.2.3'}] + + Args: + prefix: Prefix shared by each of the data items in the loop. The prefix + should include the trailing period. + + Returns: + A list of dicts; each dict represents 1 entry from an mmCIF loop. + """ + + def get(self, key: str, default_value: _T = ...) -> list[str] | _T: ... + def get_array( + self, key: str, dtype: object = ..., gather: object = ... + ) -> np.ndarray: + """Returns values looked up in dict converted to a NumPy array. + + Args: + key: Key in dictionary. + dtype: Optional (default `object`) Specifies output dtype of array. One of + [object, np.{int,uint}{8,16,32,64} np.float{32,64}]. As with NumPy use + `object` to return a NumPy array of strings. + gather: Optional one of [slice, np.{int,uint}{32,64}] non-intermediate + version of get_array(key, dtype)[gather]. + + Returns: + A NumPy array of given dtype. An optimised equivalent to + np.array(cif[key]).astype(dtype). With support of '.' being treated + as np.nan if dtype is one of np.float{32,64}. + Identical strings will all reference the same object to save space. + + Raises: + KeyError - if key is not found. + TypeError - if dtype is not valid or supported. + ValueError - if string cannot convert to dtype. + """ + + def get_data_name(self) -> str: ... + def items(self) -> CifDict.ItemView: ... + def keys(self) -> CifDict.KeyView: ... + def to_string(self) -> str: ... + def value_length(self, key: str) -> int: ... + def values(self) -> CifDict.ValueView: ... + def __bool__(self) -> bool: ... + def __contains__(self, key: str) -> bool: ... + def __getitem__(self, key: str) -> list[str]: ... + def __getstate__(self) -> tuple: ... + def __iter__(self) -> Iterator[str]: ... + def __len__(self) -> int: ... + def __setstate__(self, state: tuple) -> None: ... + +def tokenize(cif_string: str) -> list[str]: ... +def split_line(line: str) -> list[str]: ... +def from_string(mmcif_string: str | bytes) -> CifDict: ... +def parse_multi_data_cif(cif_string: str | bytes) -> dict[str, CifDict]: ... 
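+# Usage sketch (illustrative only, not part of the generated bindings): parse
+# an mmCIF string and read the atom table.
+#
+#   cif = from_string(mmcif_string)
+#   atom_rows = cif.extract_loop_as_list(prefix='_atom_site.')
+#   x_coords = cif.get_array('_atom_site.Cartn_x', dtype=np.float32)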
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d2675c757a9636578154d7869084873887ba043 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.cc @@ -0,0 +1,648 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include "alphafold3/parsers/cpp/cif_dict_lib.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" + +namespace alphafold3 { +namespace { + +bool IsQuote(const char symbol) { return symbol == '\'' || symbol == '"'; } +bool IsWhitespace(const char symbol) { return symbol == ' ' || symbol == '\t'; } + +// Splits line into tokens, returns whether successful. +bool SplitLineInline(absl::string_view line, + std::vector* tokens) { + // See https://www.iucr.org/resources/cif/spec/version1.1/cifsyntax + for (int i = 0, line_length = line.length(); i < line_length;) { + // Skip whitespace (spaces or tabs). + while (IsWhitespace(line[i])) { + if (++i == line_length) { + break; + } + } + if (i == line_length) { + break; + } + + // Skip comments (from # until the end of the line). If # is a non-comment + // character, it must be inside a quoted token. + if (line[i] == '#') { + break; + } + + int start_index; + int end_index; + if (IsQuote(line[i])) { + // Token in single or double quotes. CIF v1.1 specification considers a + // quote to be an opening quote only if it is at the beginning of a token. + // So e.g. A' B has tokens A' and B. Also, ""A" is a token "A. + const char quote_char = line[i++]; + start_index = i; + + // Find matching quote. The double loop is not strictly necessary, but + // optimises a bit better. + while (true) { + while (i < line_length && line[i] != quote_char) { + ++i; + } + if (i == line_length) { + // Reached the end of the line while still being inside a token. + return false; + } + if (i + 1 == line_length || IsWhitespace(line[i + 1])) { + break; + } + ++i; + } + end_index = i++; + } else { + // Non-quoted token. Read until reaching whitespace. 
+ start_index = i++; + while (i < line_length && !IsWhitespace(line[i])) { + ++i; + } + end_index = i; + } + + tokens->push_back(line.substr(start_index, end_index - start_index)); + } + + return true; +} + +using HeapStrings = std::vector>; + +// The majority of strings can be viewed on original cif_string. +// heap_strings store multi-line tokens that have internal white-space stripped. +absl::StatusOr> TokenizeInternal( + absl::string_view cif_string, HeapStrings* heap_strings) { + const std::vector lines = absl::StrSplit(cif_string, '\n'); + std::vector tokens; + // Heuristic: Most lines in an mmCIF are _atom_site lines with 21 tokens. + tokens.reserve(lines.size() * 21); + int line_num = 0; + while (line_num < lines.size()) { + auto line = lines[line_num]; + line_num++; + + if (line.empty() || line[0] == '#') { + // Skip empty lines or lines that contain only comments. + continue; + } else if (line[0] == ';') { + // Leading whitespace on each line must be preserved while trailing + // whitespace may be stripped. + std::vector multiline_tokens; + // Strip the leading ";". + multiline_tokens.push_back( + absl::StripTrailingAsciiWhitespace(line.substr(1))); + while (line_num < lines.size()) { + auto multiline = absl::StripTrailingAsciiWhitespace(lines[line_num]); + line_num++; + if (!multiline.empty() && multiline[0] == ';') { + break; + } + multiline_tokens.push_back(multiline); + } + heap_strings->push_back( + std::make_unique(absl::StrJoin(multiline_tokens, "\n"))); + tokens.emplace_back(*heap_strings->back()); + } else { + if (!SplitLineInline(line, &tokens)) { + return absl::InvalidArgumentError( + absl::StrCat("Line ended with quote open: ", line)); + } + } + } + return tokens; +} + +absl::string_view GetEscapeQuote(const absl::string_view value) { + // Empty values should not happen, but if so, they should be quoted. + if (value.empty()) { + return "\""; + } + + // Shortcut for the most common cases where no quoting needed. + if (std::all_of(value.begin(), value.end(), [](char c) { + return absl::ascii_isalnum(c) || c == '.' || c == '?' || c == '-'; + })) { + return ""; + } + + // The value must not start with one of these CIF keywords. + if (absl::StartsWithIgnoreCase(value, "data_") || + absl::StartsWithIgnoreCase(value, "loop_") || + absl::StartsWithIgnoreCase(value, "save_") || + absl::StartsWithIgnoreCase(value, "stop_") || + absl::StartsWithIgnoreCase(value, "global_")) { + return "\""; + } + + // The first character must not be a special character. + const char first = value.front(); + if (first == '_' || first == '#' || first == '$' || first == '[' || + first == ']' || first == ';') { + return "\""; + } + + // No quotes or whitespace allowed inside. + for (const char c : value) { + if (c == '"') { + return "'"; + } else if (c == '\'' || c == ' ' || c == '\t') { + return "\""; + } + } + return ""; +} + +int RecordIndex(absl::string_view record) { + if (record == "_entry") { + return 0; // _entry is always first. + } + if (record == "_atom_site") { + return 2; // _atom_site is always last. + } + return 1; // other records are between _entry and _atom_site. +} + +struct RecordOrder { + using is_transparent = void; // Enable heterogeneous lookup. + bool operator()(absl::string_view lhs, absl::string_view rhs) const { + std::size_t lhs_index = RecordIndex(lhs); + std::size_t rhs_index = RecordIndex(rhs); + return std::tie(lhs_index, lhs) < std::tie(rhs_index, rhs); + } +}; + +// Make sure the _atom_site loop columns are sorted in the PDB-standard way. 
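+// Columns not listed below still get a stable position: AtomSiteIndex returns
+// the array length for unknown tags, so they sort after the known columns,
+// with ties broken by name in AtomSiteOrder.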
+constexpr absl::string_view kAtomSiteSortOrder[] = { + "_atom_site.group_PDB", + "_atom_site.id", + "_atom_site.type_symbol", + "_atom_site.label_atom_id", + "_atom_site.label_alt_id", + "_atom_site.label_comp_id", + "_atom_site.label_asym_id", + "_atom_site.label_entity_id", + "_atom_site.label_seq_id", + "_atom_site.pdbx_PDB_ins_code", + "_atom_site.Cartn_x", + "_atom_site.Cartn_y", + "_atom_site.Cartn_z", + "_atom_site.occupancy", + "_atom_site.B_iso_or_equiv", + "_atom_site.pdbx_formal_charge", + "_atom_site.auth_seq_id", + "_atom_site.auth_comp_id", + "_atom_site.auth_asym_id", + "_atom_site.auth_atom_id", + "_atom_site.pdbx_PDB_model_num", +}; + +size_t AtomSiteIndex(absl::string_view atom_site) { + return std::distance(std::begin(kAtomSiteSortOrder), + absl::c_find(kAtomSiteSortOrder, atom_site)); +} + +struct AtomSiteOrder { + bool operator()(absl::string_view lhs, absl::string_view rhs) const { + auto lhs_index = AtomSiteIndex(lhs); + auto rhs_index = AtomSiteIndex(rhs); + return std::tie(lhs_index, lhs) < std::tie(rhs_index, rhs); + } +}; + +class Column { + public: + Column(absl::string_view key, const std::vector* values) + : key_(key), values_(values) { + int max_value_length = 0; + for (size_t i = 0; i < values->size(); ++i) { + absl::string_view value = (*values)[i]; + if (absl::StrContains(value, '\n')) { + values_with_newlines_.insert(i); + } else { + absl::string_view quote = GetEscapeQuote(value); + if (!quote.empty()) { + values_with_quotes_[i] = quote; + } + max_value_length = + std::max(max_value_length, value.size() + quote.size() * 2); + } + } + max_value_length_ = max_value_length; + } + + absl::string_view key() const { return key_; } + + const std::vector* values() const { return values_; } + + int max_value_length() const { return max_value_length_; } + + bool has_newlines(size_t index) const { + return values_with_newlines_.contains(index); + } + + absl::string_view quote(size_t index) const { + if (auto it = values_with_quotes_.find(index); + it != values_with_quotes_.end()) { + return it->second; + } + return ""; + } + + private: + absl::string_view key_; + const std::vector* values_; + int max_value_length_; + // Values with newlines or quotes are very rare in a typical CIF file. + absl::flat_hash_set values_with_newlines_; + absl::flat_hash_map values_with_quotes_; +}; + +struct GroupedKeys { + std::vector grouped_columns; + int max_key_length; + int value_size; +}; + +} // namespace + +absl::StatusOr CifDict::FromString(absl::string_view cif_string) { + CifDict::Dict cif; + + bool loop_flag = false; + absl::string_view key; + + HeapStrings heap_strings; + auto tokens = TokenizeInternal(cif_string, &heap_strings); + if (!tokens.ok()) { + return tokens.status(); + } + + if (tokens->empty()) { + return absl::InvalidArgumentError("The CIF file must not be empty."); + } + + // The first token should be data_XXX. Split into key = data, value = XXX. + absl::string_view first_token = tokens->front(); + if (!absl::ConsumePrefix(&first_token, "data_")) { + return absl::InvalidArgumentError( + "The CIF file does not start with the data_ field."); + } + cif["data_"].emplace_back(first_token); + + // Counters for CIF loop_ regions. + int loop_token_index = 0; + int num_loop_keys = 0; + // Loops have usually O(10) columns but could have up to O(10^6) rows. It is + // therefore wasteful to look up the cif vector where to add a loop value + // since that means doing `columns * rows` map lookups. 
If we save pointers to + // these loop column fields instead, we need only 1 cif lookup per column. + std::vector*> loop_column_values; + + // Skip the first element since we already processed it above. + for (auto token_itr = tokens->begin() + 1; token_itr != tokens->end(); + ++token_itr) { + auto token = *token_itr; + if (absl::EqualsIgnoreCase(token, "loop_")) { + // A new loop started, get rid of old loop's data. + loop_flag = true; + loop_column_values.clear(); + loop_token_index = 0; + num_loop_keys = 0; + continue; + } else if (loop_flag) { + // The second condition checks we are in the first column. Some mmCIF + // files (e.g. 4q9r) have values in later columns starting with an + // underscore and we don't want to read these as keys. + int token_column_index = + num_loop_keys == 0 ? 0 : loop_token_index % num_loop_keys; + if (token_column_index == 0 && !token.empty() && token[0] == '_') { + if (loop_token_index > 0) { + // We are out of the loop. + loop_flag = false; + } else { + // We are in the keys (column names) section of the loop. + auto& columns = cif[token]; + columns.clear(); + + // Heuristic: _atom_site is typically the largest table in an mmCIF + // with ~16 columns. Make sure we reserve enough space for its values. + if (absl::StartsWith(token, "_atom_site.")) { + columns.reserve(tokens->size() / 20); + } + + // Save the pointer to the loop column values. + loop_column_values.push_back(&columns); + num_loop_keys += 1; + continue; + } + } else { + // We are in the values section of the loop. We have a pointer to the + // loops' values, add the new token in there. + if (token_column_index >= loop_column_values.size()) { + return absl::InvalidArgumentError( + absl::StrCat("Too many columns at: '", token, + "' at column index: ", token_column_index, + " expected at most: ", loop_column_values.size())); + } + loop_column_values[token_column_index]->emplace_back(token); + loop_token_index++; + continue; + } + } + if (key.empty()) { + key = token; + } else { + cif[key].emplace_back(token); + key = ""; + } + } + return CifDict(std::move(cif)); +} + +absl::StatusOr CifDict::ToString() const { + std::string output; + + absl::string_view data_name; + // Check that the data_ field exists. + if (auto name_it = (*dict_).find("data_"); + name_it == (*dict_).end() || name_it->second.empty()) { + return absl::InvalidArgumentError( + "The CIF must contain a valid name for this data block in the special " + "data_ field."); + } else { + data_name = name_it->second.front(); + } + + if (absl::c_any_of(data_name, + [](char i) { return absl::ascii_isspace(i); })) { + return absl::InvalidArgumentError(absl::StrFormat( + "The CIF data block name must not contain any whitespace characters, " + "got '%s'.", + data_name)); + } + absl::StrAppend(&output, "data_", data_name, "\n#\n"); + + // Group keys by their prefix. Use btree_map to iterate in alphabetical order, + // but with some keys being placed at the end (e.g. _atom_site). + absl::btree_map grouped_keys; + for (const auto& [key, values] : *dict_) { + if (key == "data_") { + continue; // Skip the special data_ key, we are already done with it. 
+ } + const std::pair key_parts = + absl::StrSplit(key, absl::MaxSplits('.', 1)); + const absl::string_view key_prefix = key_parts.first; + auto [it, inserted] = grouped_keys.emplace(key_prefix, GroupedKeys{}); + GroupedKeys& grouped_key = it->second; + grouped_key.grouped_columns.push_back(Column(key, &values)); + if (inserted) { + grouped_key.max_key_length = key.length(); + grouped_key.value_size = values.size(); + } else { + grouped_key.max_key_length = + std::max(key.length(), grouped_key.max_key_length); + if (grouped_key.value_size != values.size()) { + return absl::InvalidArgumentError( + absl::StrFormat("Values for key %s have different length (%d) than " + "the other values with the same key prefix (%d).", + key, values.size(), grouped_key.value_size)); + } + } + } + + for (auto& [key_prefix, group_info] : grouped_keys) { + if (key_prefix == "_atom_site") { + // Make sure we sort the _atom_site loop in the standard way. + absl::c_sort(group_info.grouped_columns, + [](const Column& lhs, const Column& rhs) { + return AtomSiteOrder{}(lhs.key(), rhs.key()); + }); + } else { + // Make the key ordering within a key group deterministic. + absl::c_sort(group_info.grouped_columns, + [](const Column& lhs, const Column& rhs) { + return lhs.key() < rhs.key(); + }); + } + + // Force `_atom_site` field to always be a loop. This resolves issues with + // third party mmCIF parsers such as OpenBabel which always expect a loop + // even when there is only a single atom present. + if (group_info.value_size == 1 && key_prefix != "_atom_site") { + // Plain key-value pairs, output them as they are. + for (const Column& grouped_column : group_info.grouped_columns) { + int width = group_info.max_key_length + 1; + size_t start_pos = output.size(); + output.append(width, ' '); + auto out_it = output.begin() + start_pos; + absl::c_copy(grouped_column.key(), out_it); + // Append the value, handle multi-line/quoting. + absl::string_view value = grouped_column.values()->front(); + if (grouped_column.has_newlines(0)) { + absl::StrAppend(&output, "\n;", value, "\n;\n"); // Multi-line value. + } else { + const absl::string_view quote_char = grouped_column.quote(0); + absl::StrAppend(&output, quote_char, value, quote_char, "\n"); + } + } + } else { + // CIF loop. Output the column names, then the rows with data. + absl::StrAppend(&output, "loop_\n"); + for (Column& grouped_column : group_info.grouped_columns) { + absl::StrAppend(&output, grouped_column.key(), "\n"); + } + // Write the loop values, line by line. This is the most expensive part + // since this path is taken to write the entire atom site table which has + // about 20 columns, but thousands of rows. + for (int i = 0; i < group_info.value_size; i++) { + for (int column_index = 0; + column_index < group_info.grouped_columns.size(); ++column_index) { + const Column& grouped_column = + group_info.grouped_columns[column_index]; + const absl::string_view value = (*grouped_column.values())[i]; + if (grouped_column.has_newlines(i)) { + // Multi-line. This is very rarely taken path. + if (column_index == 0) { + // No extra newline before leading ;, already inserted. + absl::StrAppend(&output, ";", value, "\n;\n"); + } else if (column_index == group_info.grouped_columns.size() - 1) { + // No extra newline after trailing ;, will be inserted. 
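+ // (The row loop appends exactly one '\n' after the final column, so the
+ // closing ';' must not add its own trailing newline here.)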
+ absl::StrAppend(&output, "\n;", value, "\n;"); + } else { + absl::StrAppend(&output, "\n;", value, "\n;\n"); + } + } else { + size_t start_pos = output.size(); + output.append(grouped_column.max_value_length() + 1, ' '); + auto out_it = output.begin() + start_pos; + absl::string_view quote = grouped_column.quote(i); + if (!quote.empty()) { + out_it = absl::c_copy(quote, out_it); + out_it = absl::c_copy(value, out_it); + absl::c_copy(quote, out_it); + } else { + absl::c_copy(value, out_it); + } + } + } + absl::StrAppend(&output, "\n"); + } + } + absl::StrAppend(&output, "#\n"); // Comment token after every key group. + } + return output; +} + +absl::StatusOr< + std::vector>> +CifDict::ExtractLoopAsList(absl::string_view prefix) const { + std::vector column_names; + std::vector> column_data; + + for (const auto& element : *dict_) { + if (absl::StartsWith(element.first, prefix)) { + column_names.emplace_back(element.first); + auto& cells = column_data.emplace_back(); + cells.insert(cells.begin(), element.second.begin(), element.second.end()); + } + } + // Make sure all columns have the same number of rows. + const std::size_t num_rows = column_data.empty() ? 0 : column_data[0].size(); + for (const auto& column : column_data) { + if (column.size() != num_rows) { + return absl::InvalidArgumentError(absl::StrCat( + GetDataName(), + ": Columns do not have the same number of rows for prefix: '", prefix, + "'. One possible reason could be not including the trailing dot, " + "e.g. '_atom_site.'.")); + } + } + + std::vector> result; + result.reserve(num_rows); + CHECK_EQ(column_names.size(), column_data.size()); + for (std::size_t row_index = 0; row_index < num_rows; ++row_index) { + auto& row_dict = result.emplace_back(); + row_dict.reserve(column_names.size()); + for (int col_index = 0; col_index < column_names.size(); ++col_index) { + row_dict[column_names[col_index]] = column_data[col_index][row_index]; + } + } + return result; +} + +absl::StatusOr>> +CifDict::ExtractLoopAsDict(absl::string_view prefix, + absl::string_view index) const { + if (!absl::StartsWith(index, prefix)) { + return absl::InvalidArgumentError( + absl::StrCat(GetDataName(), ": The loop index '", index, + "' must start with the loop prefix '", prefix, "'.")); + } + absl::flat_hash_map> + result; + auto loop_as_list = ExtractLoopAsList(prefix); + if (!loop_as_list.ok()) { + return loop_as_list.status(); + } + result.reserve(loop_as_list->size()); + for (auto& entry : *loop_as_list) { + if (const auto it = entry.find(index); it != entry.end()) { + result[it->second] = entry; + } else { + return absl::InvalidArgumentError(absl::StrCat( + GetDataName(), ": The index column '", index, + "' could not be found in the loop with prefix '", prefix, "'.")); + } + } + return result; +} + +absl::StatusOr> Tokenize( + absl::string_view cif_string) { + HeapStrings heap_strings; + auto tokens = TokenizeInternal(cif_string, &heap_strings); + if (!tokens.ok()) { + return tokens.status(); + } + return std::vector(tokens->begin(), tokens->end()); +} + +absl::StatusOr> SplitLine( + absl::string_view line) { + std::vector tokens; + if (!SplitLineInline(line, &tokens)) { + return absl::InvalidArgumentError( + absl::StrCat("Line ended with quote open: ", line)); + } + return tokens; +} + +absl::StatusOr> ParseMultiDataCifDict( + absl::string_view cif_string) { + absl::flat_hash_map mapping; + constexpr absl::string_view delimiter = "data_"; + // Check cif_string starts with correct offset. 
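+ // The leading block must begin with "data_": the pointer arithmetic below
+ // re-extends every StrSplit fragment backwards by delimiter.size() to restore
+ // the stripped "data_" prefix, which is only safe if the delimiter really
+ // preceded that fragment.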
+ if (!cif_string.empty() && !absl::StartsWith(cif_string, delimiter)) { + return absl::InvalidArgumentError( + "Invalid format. MultiDataCifDict must start with 'data_'"); + } + for (absl::string_view data_block : + absl::StrSplit(cif_string, delimiter, absl::SkipEmpty())) { + absl::string_view block_with_delimitor( + data_block.data() - delimiter.size(), + data_block.size() + delimiter.size()); + absl::StatusOr parsed_block = + CifDict::FromString(block_with_delimitor); + if (!parsed_block.ok()) { + return parsed_block.status(); + } + absl::string_view data_name = parsed_block->GetDataName(); + mapping[data_name] = *std::move(parsed_block); + } + + return mapping; +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.h new file mode 100644 index 0000000000000000000000000000000000000000..5c16eaa87c2443061109673d06b6db24a1b998f9 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.h @@ -0,0 +1,149 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +// A C++ implementation of a CIF parser. For the format specification see +// https://www.iucr.org/resources/cif/spec/version1.1/cifsyntax +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_LIB_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_LIB_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +namespace alphafold3 { + +class CifDict { + public: + // Use absl::node_hash_map since it guarantees pointer stability. + using Dict = absl::node_hash_map>; + + CifDict() = default; + + explicit CifDict(Dict dict) + : dict_(std::make_shared(std::move(dict))) {} + + // Converts a CIF string into a dictionary mapping each CIF field to a list of + // values that field contains. + static absl::StatusOr FromString(absl::string_view cif_string); + + // Converts the CIF into into a string that is a valid CIF file. + absl::StatusOr ToString() const; + + // Extracts loop associated with a prefix from mmCIF data as a list. + // Reference for loop_ in mmCIF: + // http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html + // Args: + // prefix: Prefix shared by each of the data items in the loop. + // e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, + // _entity_poly_seq.mon_id. Should include the trailing period. + // + // Returns a list of dicts; each dict represents 1 entry from an mmCIF loop. + // Lifetime of string_views tied to this. + absl::StatusOr< + std::vector>> + ExtractLoopAsList(absl::string_view prefix) const; + + // Extracts loop associated with a prefix from mmCIF data as a dictionary. + // Args: + // prefix: Prefix shared by each of the data items in the loop. + // e.g. 
'_entity_poly_seq.', where the data items are _entity_poly_seq.num, + // _entity_poly_seq.mon_id. Should include the trailing period. + // index: Which item of loop data should serve as the key. + // + // Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop, + // indexed by the index column. + // Lifetime of string_views tied to this. + absl::StatusOr>> + ExtractLoopAsDict(absl::string_view prefix, absl::string_view index) const; + + // Returns value at key if present or an empty list. + absl::Span operator[](absl::string_view key) const { + auto it = dict_->find(key); + if (it != dict_->end()) { + return it->second; + } + return {}; + } + + // Returns boolean of whether dict contains key. + bool Contains(absl::string_view key) const { return dict_->contains(key); } + + // Returns number of values for the given key if present, 0 otherwise. + size_t ValueLength(absl::string_view key) const { + return (*this)[key].size(); + } + + // Returns the size of the underlying dictionary. + std::size_t Length() { return dict_->size(); } + + // Creates a copy of this CifDict object that will contain the original values + // but only if not updated by the given dictionary. + // E.g. if the CifDict = {a: [a1, a2], b: [b1]} and other = {a: [x], c: [z]}, + // you will get {a: [x], b: [b1], c: [z]}. + CifDict CopyAndUpdate(Dict other) const { + other.insert(dict_->begin(), dict_->end()); + return CifDict(std::move(other)); + } + + // Returns the value of the special CIF data_ field. + absl::string_view GetDataName() const { + // The data_ element has to be present by construction. + if (auto it = dict_->find("data_"); + it != dict_->end() && !it->second.empty()) { + return it->second.front(); + } else { + return ""; + } + } + + const std::shared_ptr& dict() const { return dict_; } + + private: + std::shared_ptr dict_; +}; + +// Tokenizes a CIF string into a list of string tokens. This is more involved +// than just a simple split on whitespace as CIF allows comments and quoting. +absl::StatusOr> Tokenize(absl::string_view cif_string); + +// Tokenizes a single line of a CIF string. +absl::StatusOr> SplitLine( + absl::string_view line); + +// Parses a CIF string with multiple data records and returns a mapping from +// record names to CifDict objects. For instance, the following CIF string: +// +// data_001 +// _foo bar +// +// data_002 +// _foo baz +// +// will be parsed as: +// {'001': CifDict({'_foo': ['bar']}), +// '002': CifDict({'_foo': ['baz']})} +absl::StatusOr> ParseMultiDataCifDict( + absl::string_view cif_string); + +} // namespace alphafold3 + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_LIB_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.cc new file mode 100644 index 0000000000000000000000000000000000000000..130a8215abf89b54be76c9b5973fec8bf5b11222 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.cc @@ -0,0 +1,652 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. 
Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "numpy/ndarrayobject.h" +#include "numpy/ndarraytypes.h" +#include "numpy/npy_common.h" +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "alphafold3/parsers/cpp/cif_dict_lib.h" +#include "pybind11/attr.h" +#include "pybind11/cast.h" +#include "pybind11/gil.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" + +namespace alphafold3 { +namespace { +namespace py = pybind11; + +template +bool GatherArray(size_t num_dims, npy_intp* shape_array, npy_intp* stride_array, + const char* data, absl::Span values, + ForEach&& for_each_cb) { + if (num_dims == 1) { + const npy_intp shape = shape_array[0]; + const npy_intp stride = stride_array[0]; + for (size_t i = 0; i < shape; ++i) { + Item index; + std::memcpy(&index, data + stride * i, sizeof(Item)); + if (index < 0 || index >= values.size()) { + PyErr_SetString(PyExc_IndexError, + absl::StrCat("index ", index, + " is out of bounds for column with size ", + values.size()) + .c_str()); + return false; + } + if (!for_each_cb(values[index])) { + return false; + } + } + } else if (num_dims == 0) { + Item index; + std::memcpy(&index, data, sizeof(Item)); + if (index < 0 || index >= values.size()) { + PyErr_SetString( + PyExc_IndexError, + absl::StrCat("index ", index, + " is out of bounds for column with size ", values.size()) + .c_str()); + return false; + } + if (!for_each_cb(values[index])) { + return false; + } + } else { + const npy_intp shape = shape_array[0]; + const npy_intp stride = stride_array[0]; + for (size_t i = 0; i < shape; ++i) { + if (!GatherArray(num_dims - 1, shape_array + 1, stride_array + 1, + data + stride * i, values, for_each_cb)) { + return false; + } + } + } + return true; +} + +template +bool Gather(PyObject* gather, absl::Span values, + Size&& size_cb, ForEach&& for_each_cb) { + if (gather == Py_None) { + npy_intp dim = static_cast(values.size()); + if (!size_cb(absl::MakeSpan(&dim, 1))) { + return false; + } + for (const std::string& v : values) { + if (!for_each_cb(v)) { + return false; + } + } + return true; + } + if (PySlice_Check(gather)) { + Py_ssize_t start, stop, step, slice_length; + if (PySlice_GetIndicesEx(gather, values.size(), &start, &stop, &step, + &slice_length) != 0) { + return false; + } + npy_intp dim = static_cast(slice_length); + if (!size_cb(absl::MakeSpan(&dim, 1))) { + return false; + } + for (size_t i = 0; i < slice_length; ++i) { + if (!for_each_cb(values[start + i * step])) { + return false; + } + } + return true; + } + if (PyArray_Check(gather)) { + PyArrayObject* gather_array = reinterpret_cast(gather); + auto shape = + absl::MakeSpan(PyArray_DIMS(gather_array), PyArray_NDIM(gather_array)); + switch (PyArray_TYPE(gather_array)) { + case NPY_INT16: + if (!size_cb(shape)) { + return false; + } + return GatherArray(shape.size(), shape.data(), + PyArray_STRIDES(gather_array), + PyArray_BYTES(gather_array), values, + std::forward(for_each_cb)); + case NPY_UINT16: + if (!size_cb(shape)) { + return false; + } + return GatherArray(shape.size(), shape.data(), + 
PyArray_STRIDES(gather_array), + PyArray_BYTES(gather_array), values, + std::forward(for_each_cb)); + case NPY_INT32: + if (!size_cb(shape)) { + return false; + } + return GatherArray(shape.size(), shape.data(), + PyArray_STRIDES(gather_array), + PyArray_BYTES(gather_array), values, + std::forward(for_each_cb)); + case NPY_UINT32: + if (!size_cb(shape)) { + return false; + } + return GatherArray(shape.size(), shape.data(), + PyArray_STRIDES(gather_array), + PyArray_BYTES(gather_array), values, + std::forward(for_each_cb)); + case NPY_INT64: + if (!size_cb(shape)) { + return false; + } + return GatherArray(shape.size(), shape.data(), + PyArray_STRIDES(gather_array), + PyArray_BYTES(gather_array), values, + std::forward(for_each_cb)); + case NPY_UINT64: + if (!size_cb(shape)) { + return false; + } + return GatherArray(shape.size(), shape.data(), + PyArray_STRIDES(gather_array), + PyArray_BYTES(gather_array), values, + std::forward(for_each_cb)); + default: + PyErr_SetString(PyExc_TypeError, "Unsupported NumPy array type."); + return false; + } + } + + PyErr_Format(PyExc_TypeError, "Invalid gather %R", gather); + return false; +} + +// Creates a NumPy array of objects of given strings. Reusing duplicates where +// possible. +PyObject* ConvertStrings(PyObject* gather, PyArray_Descr* type, + absl::Span values) { + absl::flat_hash_map existing; + + PyObject* ret = nullptr; + PyObject** dst; + if (Gather( + gather, values, + [&dst, &ret, type](absl::Span size) { + ret = PyArray_NewFromDescr( + /*subtype=*/&PyArray_Type, + /*type=*/type, + /*nd=*/size.size(), + /*dims=*/size.data(), + /*strides=*/nullptr, + /*data=*/nullptr, + /*flags=*/0, + /*obj=*/nullptr); + dst = static_cast( + PyArray_DATA(reinterpret_cast(ret))); + return true; + }, + [&dst, &existing](absl::string_view value) { + auto [it, inserted] = existing.emplace(value, nullptr); + if (inserted) { + it->second = + PyUnicode_FromStringAndSize(value.data(), value.size()); + PyUnicode_InternInPlace(&it->second); + } else { + Py_INCREF(it->second); + } + *dst++ = it->second; + return true; + })) { + return ret; + } else { + Py_XDECREF(ret); + return nullptr; + } +} + +// Creates NumPy array with given dtype given specified converter. +// `converter` shall have the following signature: +// bool converter(const std::string& value, T* result); +// It must return whether conversion is successful and store conversion in +// result. +template +inline PyObject* Convert(PyObject* gather, PyArray_Descr* type, + absl::Span values, C&& converter) { + py::object ret; + T* dst; + if (Gather( + gather, values, + [&dst, &ret, type](absl::Span size) { + // Construct uninitialised NumPy array of type T. 
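+ // PyArray_NewFromDescr steals the reference to `type`, and `dst` is left
+ // pointing at the array's raw buffer so the converter below can fill the
+ // elements in place.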
+ ret = py::reinterpret_steal(PyArray_NewFromDescr( + /*subtype=*/&PyArray_Type, + /*type=*/type, + /*nd=*/size.size(), + /*dims=*/size.data(), + /*strides=*/nullptr, + /*data=*/nullptr, + /*flags=*/0, + /*obj=*/nullptr)); + + dst = static_cast( + PyArray_DATA(reinterpret_cast(ret.ptr()))); + return true; + }, + [&dst, &converter](const std::string& value) { + if (!converter(value, dst++)) { + PyErr_SetString(PyExc_ValueError, value.c_str()); + return false; + } + return true; + })) { + return ret.release().ptr(); + } + return nullptr; +} + +PyObject* CifDictGetArray(const CifDict& self, absl::string_view key, + PyObject* dtype, PyObject* gather) { + import_array(); + PyArray_Descr* type = nullptr; + if (dtype == Py_None) { + type = PyArray_DescrFromType(NPY_OBJECT); + } else if (PyArray_DescrConverter(dtype, &type) == NPY_FAIL || !type) { + PyErr_Format(PyExc_TypeError, "Invalid dtype %R", dtype); + Py_XDECREF(type); + return nullptr; + } + auto entry = self.dict()->find(key); + if (entry == self.dict()->end()) { + Py_DECREF(type); + PyErr_SetObject(PyExc_KeyError, + PyUnicode_FromStringAndSize(key.data(), key.size())); + return nullptr; + } + + auto int_convert = [](absl::string_view str, auto* value) { + return absl::SimpleAtoi(str, value); + }; + + auto int_convert_bounded = [](absl::string_view str, auto* value) { + int64_t v; + if (absl::SimpleAtoi(str, &v)) { + using limits = + std::numeric_limits>; + if (limits::min() <= v && v <= limits::max()) { + *value = v; + return true; + } + } + return false; + }; + + absl::Span values = entry->second; + + switch (type->type_num) { + case NPY_DOUBLE: + return Convert( + gather, type, values, [](absl::string_view str, double* value) { + if (str == ".") { + *value = std::numeric_limits::quiet_NaN(); + return true; + } + return absl::SimpleAtod(str, value); + }); + case NPY_FLOAT: + return Convert( + gather, type, values, [](absl::string_view str, float* value) { + if (str == ".") { + *value = std::numeric_limits::quiet_NaN(); + return true; + } + return absl::SimpleAtof(str, value); + }); + case NPY_INT8: + return Convert(gather, type, values, int_convert_bounded); + case NPY_INT16: + return Convert(gather, type, values, int_convert_bounded); + case NPY_INT32: + return Convert(gather, type, values, int_convert); + case NPY_INT64: + return Convert(gather, type, values, int_convert); + case NPY_UINT8: + return Convert(gather, type, values, int_convert_bounded); + case NPY_UINT16: + return Convert(gather, type, values, int_convert_bounded); + case NPY_UINT32: + return Convert(gather, type, values, int_convert); + case NPY_UINT64: + return Convert(gather, type, values, int_convert); + case NPY_BOOL: + return Convert(gather, type, values, + [](absl::string_view str, bool* value) { + if (str == "n" || str == "no") { + *value = false; + return true; + } + if (str == "y" || str == "yes") { + *value = true; + return true; + } + return false; + }); + case NPY_OBJECT: + return ConvertStrings(gather, type, values); + default: { + PyErr_Format(PyExc_TypeError, "Unsupported dtype %R", dtype); + Py_XDECREF(type); + return nullptr; + } + } +} + +} // namespace + +void RegisterModuleCifDict(pybind11::module m) { + using Value = std::vector; + static absl::NoDestructor> empty_values; + + m.def( + "from_string", + [](absl::string_view s) { + absl::StatusOr dict = CifDict::FromString(s); + if (!dict.ok()) { + throw py::value_error(dict.status().ToString()); + } + return *dict; + }, + py::call_guard()); + + m.def( + "tokenize", + [](absl::string_view cif_string) { 
+ absl::StatusOr> tokens = Tokenize(cif_string); + if (!tokens.ok()) { + throw py::value_error(tokens.status().ToString()); + } + return *std::move(tokens); + }, + py::arg("cif_string")); + + m.def("split_line", [](absl::string_view line) { + absl::StatusOr> tokens = SplitLine(line); + if (!tokens.ok()) { + throw py::value_error(tokens.status().ToString()); + } + return *std::move(tokens); + }); + + m.def( + "parse_multi_data_cif", + [](absl::string_view cif_string) { + auto result = ParseMultiDataCifDict(cif_string); + if (!result.ok()) { + throw py::value_error(result.status().ToString()); + } + py::dict dict; + for (auto& [key, value] : *result) { + dict[py::cast(key)] = py::cast(value); + } + return dict; + }, + py::arg("cif_string")); + + auto cif_dict = + py::class_(m, "CifDict") + .def(py::init<>([](py::dict dict) { + CifDict::Dict result; + for (const auto& [key, value] : dict) { + result.emplace(py::cast(key), + py::cast>(value)); + } + return CifDict(std::move(result)); + }), + "Initialise with a map") + .def("copy_and_update", + [](const CifDict& self, py::dict dict) { + CifDict::Dict result; + for (const auto& [key, value] : dict) { + result.emplace(py::cast(key), + py::cast>(value)); + } + { + py::gil_scoped_release gil_release; + return self.CopyAndUpdate(std::move(result)); + } + }) + .def( + "__str__", + [](const CifDict& self) { + absl::StatusOr result = self.ToString(); + if (!result.ok()) { + throw py::value_error(result.status().ToString()); + } + return *result; + }, + "Serialize to a string", py::call_guard()) + .def( + "to_string", + [](const CifDict& self) { + absl::StatusOr result = self.ToString(); + if (!result.ok()) { + throw py::value_error(result.status().ToString()); + } + return *result; + }, + "Serialize to a string", py::call_guard()) + .def("value_length", &CifDict::ValueLength, py::arg("key"), + "Num elements in value") + .def("__len__", + [](const CifDict& self) { return self.dict()->size(); }) + .def( + "__bool__", + [](const CifDict& self) { return !self.dict()->empty(); }, + "Check whether the map is nonempty") + .def( + "__contains__", + [](const CifDict& self, absl::string_view k) { + return self.dict()->find(k) != self.dict()->end(); + }, + py::arg("key"), py::call_guard()) + .def("get_data_name", &CifDict::GetDataName) + .def( + "get", + [](const CifDict& self, absl::string_view k, + py::object default_value) -> py::object { + auto it = self.dict()->find(k); + if (it == self.dict()->end()) return default_value; + py::list result(it->second.size()); + size_t index = 0; + for (const std::string& v : it->second) { + result[index++] = py::cast(v); + } + return result; + }, + py::arg("key"), py::arg("default_value") = py::none()) + .def( + "get_array", + [](const CifDict& self, absl::string_view key, py::handle dtype, + py::handle gather) -> py::object { + PyObject* obj = + CifDictGetArray(self, key, dtype.ptr(), gather.ptr()); + if (obj == nullptr) { + throw py::error_already_set(); + } + return py::reinterpret_steal(obj); + }, + py::arg("key"), py::arg("dtype") = py::none(), + py::arg("gather") = py::none()) + .def( + "__getitem__", + [](const CifDict& self, absl::string_view k) -> const Value& { + auto it = self.dict()->find(k); + if (it == self.dict()->end()) { + throw py::key_error(std::string(k).c_str()); + } + return it->second; + }, + py::arg("key"), py::call_guard()) + .def( + "extract_loop_as_dict", + [](const CifDict& self, absl::string_view prefix, + absl::string_view index) { + absl::StatusOr>> + dict; + { + py::gil_scoped_release 
gil_release; + dict = self.ExtractLoopAsDict(prefix, index); + if (!dict.ok()) { + throw py::value_error(dict.status().ToString()); + } + } + py::dict key_value_dict; + for (const auto& [key, value] : *dict) { + py::dict value_dict; + for (const auto& [key2, value2] : value) { + value_dict[py::cast(key2)] = py::cast(value2); + } + key_value_dict[py::cast(key)] = std::move(value_dict); + } + return key_value_dict; + }, + py::arg("prefix"), py::arg("index")) + .def( + "extract_loop_as_list", + [](const CifDict& self, absl::string_view prefix) { + absl::StatusOr>> + list_dict; + { + py::gil_scoped_release gil_release; + list_dict = self.ExtractLoopAsList(prefix); + if (!list_dict.ok()) { + throw py::value_error(list_dict.status().ToString()); + } + } + py::list list_obj(list_dict->size()); + size_t index = 0; + for (const auto& value : *list_dict) { + py::dict value_dict; + for (const auto& [key, value] : value) { + value_dict[py::cast(key)] = py::cast(value); + } + list_obj[index++] = std::move(value_dict); + } + return list_obj; + }, + py::arg("prefix")) + .def(py::pickle( + [](const CifDict& self) { // __getstate__. + py::tuple result_tuple(1); + py::dict result; + for (const auto& [key, value] : *self.dict()) { + result[py::cast(key)] = py::cast(value); + } + result_tuple[0] = std::move(result); + return result_tuple; + }, + [](py::tuple t) { // __setstate__. + py::dict dict = t[0].cast(); + CifDict::Dict result; + for (const auto& [key, value] : dict) { + result.emplace(py::cast(key), + py::cast>(value)); + } + return CifDict(std::move(result)); + })); + + // Item, value, and key views + struct KeyView { + CifDict map; + }; + + struct ValueView { + CifDict map; + }; + struct ItemView { + CifDict map; + }; + + py::class_(cif_dict, "ItemView") + .def("__len__", [](const ItemView& v) { return v.map.dict()->size(); }) + .def( + "__iter__", + [](const ItemView& v) { + return py::make_iterator(v.map.dict()->begin(), + v.map.dict()->end()); + }, + py::keep_alive<0, 1>()); + + py::class_(cif_dict, "KeyView") + .def("__contains__", + [](const KeyView& v, absl::string_view k) { + return v.map.dict()->find(k) != v.map.dict()->end(); + }) + .def("__contains__", [](const KeyView&, py::handle) { return false; }) + .def("__len__", [](const KeyView& v) { return v.map.dict()->size(); }) + .def( + "__iter__", + [](const KeyView& v) { + return py::make_key_iterator(v.map.dict()->begin(), + v.map.dict()->end()); + }, + py::keep_alive<0, 1>()); + + py::class_(cif_dict, "ValueView") + .def("__len__", [](const ValueView& v) { return v.map.dict()->size(); }) + .def( + "__iter__", + [](const ValueView& v) { + return py::make_value_iterator(v.map.dict()->begin(), + v.map.dict()->end()); + }, + py::keep_alive<0, 1>()); + + cif_dict + .def( + "__iter__", + [](CifDict& self) { + return py::make_key_iterator(self.dict()->begin(), + self.dict()->end()); + }, + py::keep_alive<0, 1>()) + .def( + "keys", [](CifDict& self) { return KeyView{self}; }, + "Returns an iterable view of the map's keys.") + .def( + "values", [](CifDict& self) { return ValueView{self}; }, + "Returns an iterable view of the map's values.") + .def( + "items", [](CifDict& self) { return ItemView{self}; }, + "Returns an iterable view of the map's items."); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.h new file mode 100644 index 
0000000000000000000000000000000000000000..ca4f94702bc5b961160be78af151aaf756619f7e --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.h @@ -0,0 +1,24 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_PYBIND_H_ + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleCifDict(pybind11::module m); + +} + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator.pyi new file mode 100644 index 0000000000000000000000000000000000000000..d5da60ec8ade9fa4658ad71e99465beb271a48a5 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator.pyi @@ -0,0 +1,22 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +class FastaFileIterator: + def __init__(self, fasta_path: str) -> None: ... + def __iter__(self) -> FastaFileIterator: ... + def __next__(self) -> tuple[str,str]: ... + +class FastaStringIterator: + def __init__(self, fasta_string: str | bytes) -> None: ... + def __iter__(self) -> FastaStringIterator: ... + def __next__(self) -> tuple[str,str]: ... + +def parse_fasta(fasta_string: str | bytes) -> list[str]: ... +def parse_fasta_include_descriptions(fasta_string: str | bytes) -> tuple[list[str],list[str]]: ... diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_lib.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_lib.cc new file mode 100644 index 0000000000000000000000000000000000000000..82cac934313f2b9654f9aee54913ed5eb8f64dad --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_lib.cc @@ -0,0 +1,121 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. 
Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include "alphafold3/parsers/cpp/fasta_iterator_lib.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" + +namespace alphafold3 { + +// Parse FASTA string and return list of strings with amino acid sequences. +// Returns a list of amino acid sequences only. +std::vector ParseFasta(absl::string_view fasta_string) { + std::vector sequences; + std::string* sequence = nullptr; + for (absl::string_view line_raw : absl::StrSplit(fasta_string, '\n')) { + absl::string_view line = absl::StripAsciiWhitespace(line_raw); + if (absl::ConsumePrefix(&line, ">")) { + sequence = &sequences.emplace_back(); + } else if (!line.empty() && sequence != nullptr) { + absl::StrAppend(sequence, line); + } + } + return sequences; +} + +// Parse FASTA string and return list of strings with amino acid sequences. +// Returns two lists: The first one with amino acid sequences, the second with +// the descriptions associated with each sequence. +std::pair, std::vector> +ParseFastaIncludeDescriptions(absl::string_view fasta_string) { + std::pair, std::vector> result; + auto& [sequences, descriptions] = result; + std::string* sequence = nullptr; + for (absl::string_view line_raw : absl::StrSplit(fasta_string, '\n')) { + absl::string_view line = absl::StripAsciiWhitespace(line_raw); + if (absl::ConsumePrefix(&line, ">")) { + descriptions.emplace_back(line); + sequence = &sequences.emplace_back(); + } else if (!line.empty() && sequence != nullptr) { + absl::StrAppend(sequence, line); + } + } + return result; +} + +absl::StatusOr> FastaFileIterator::Next() { + std::string line_str; + while (std::getline(reader_, line_str)) { + absl::string_view line = line_str; + line = absl::StripAsciiWhitespace(line); + if (absl::ConsumePrefix(&line, ">")) { + if (!description_.has_value()) { + description_ = line; + } else { + std::pair output(sequence_, *description_); + description_ = line; + sequence_ = ""; + return output; + } + } else if (description_.has_value()) { + absl::StrAppend(&sequence_, line); + } + } + has_next_ = false; + reader_.close(); + if (description_.has_value()) { + return std::pair(sequence_, *description_); + } else { + return absl::InvalidArgumentError( + absl::StrCat("Invalid FASTA file: ", filename_)); + } +} + +absl::StatusOr> +FastaStringIterator::Next() { + size_t consumed = 0; + for (absl::string_view line_raw : absl::StrSplit(fasta_string_, '\n')) { + consumed += line_raw.size() + 1; // +1 for the newline character. 
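+ // `consumed` counts bytes already scanned so that, once a record is emitted
+ // below, remove_prefix(consumed) makes the next call resume right after the
+ // header line that triggered the emit.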
+ absl::string_view line = absl::StripAsciiWhitespace(line_raw); + if (absl::ConsumePrefix(&line, ">")) { + if (!description_.has_value()) { + description_ = line; + } else { + std::pair output(sequence_, *description_); + description_ = line; + sequence_ = ""; + fasta_string_.remove_prefix(consumed); + return output; + } + } else if (description_.has_value()) { + absl::StrAppend(&sequence_, line); + } + } + has_next_ = false; + if (description_.has_value()) { + return std::pair(sequence_, *description_); + } else { + return absl::InvalidArgumentError("Invalid FASTA string"); + } +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_lib.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_lib.h new file mode 100644 index 0000000000000000000000000000000000000000..486d05f20808a4a5566714b45b41e7f4c27e4f51 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_lib.h @@ -0,0 +1,94 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +// A C++ implementation of a FASTA parser. +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_FASTA_ITERATOR_LIB_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_FASTA_ITERATOR_LIB_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" + +namespace alphafold3 { + +// Parse FASTA string and return list of strings with amino acid sequences. +// Returns a list of amino acid sequences only. +std::vector ParseFasta(absl::string_view fasta_string); + +// Parse FASTA string and return list of strings with amino acid sequences. +// Returns two lists: The first one with amino acid sequences, the second with +// the descriptions associated with each sequence. +std::pair, std::vector> +ParseFastaIncludeDescriptions(absl::string_view fasta_string); + +// Lazy FASTA parser for memory efficient FASTA parsing from a path. +class FastaFileIterator { + public: + // Initialise FastaFileIterator with filename of fasta. If you initialize + // reader_ with an invalid path or empty file, it won't fail, only + // riegeli::ReadLine within the Next method will then return false. That will + // then trigger the "Invalid FASTA file" error. + explicit FastaFileIterator(absl::string_view fasta_path) + : filename_(fasta_path), + reader_(filename_, std::ios::in), + has_next_(true) {} + + // Returns whether there are more sequences. Returns true before first call to + // next even if the file is empty. + bool HasNext() const { return has_next_; } + + // Fetches the next (sequence, description) from the file. + absl::StatusOr> Next(); + + private: + // Use riegeli::FileReader instead of FileLineIterator for about 2x speedup. + std::string filename_; + std::fstream reader_; + std::optional description_; + std::string sequence_; + bool has_next_; +}; + +// Lazy FASTA parser for memory efficient FASTA parsing from a string. 
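+//
+// A minimal usage sketch (assuming the string backing the view outlives the
+// iterator, which holds for a string literal):
+//
+//   FastaStringIterator it(">seq1 desc\nACDE\n>seq2\nGH\nIK\n");
+//   while (it.HasNext()) {
+//     absl::StatusOr<std::pair<std::string, std::string>> record = it.Next();
+//     if (!record.ok()) break;  // E.g. input with no '>' header at all.
+//     const auto& [sequence, description] = *record;
+//     // Yields ("ACDE", "seq1 desc"), then ("GHIK", "seq2").
+//   }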
+class FastaStringIterator { + public: + // Initialise FastaStringIterator with a string_view of a FASTA. If you + // initialize it with an invalid FASTA string, it won't fail, the Next method + // will then return false. That will then trigger the "Invalid FASTA" error. + // WARNING: The object backing the fasta_string string_view must not be + // deleted while this Iterator is alive. + explicit FastaStringIterator(absl::string_view fasta_string) + : fasta_string_(fasta_string), has_next_(true) {} + + // Returns whether there are more sequences. Returns true before first call to + // next even if the string is empty. + bool HasNext() const { return has_next_; } + + // Fetches the next (sequence, description) from the string. + absl::StatusOr> Next(); + + private: + absl::string_view fasta_string_; + bool has_next_; + std::optional description_; + std::string sequence_; +}; + +} // namespace alphafold3 + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_FASTA_ITERATOR_LIB_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_pybind.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b47933d42e3d5ea162ec405e8aec94ebca05320 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_pybind.cc @@ -0,0 +1,127 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "alphafold3/parsers/cpp/fasta_iterator_lib.h" +#include "pybind11/attr.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" + +namespace alphafold3 { +namespace { + +namespace py = pybind11; + +template +T ValueOrThrowValueError(absl::StatusOr value) { + if (!value.ok()) throw py::value_error(value.status().ToString()); + return *std::move(value); +} + +constexpr char kFastaFileIteratorDoc[] = R"( +Lazy FASTA parser for memory efficient FASTA parsing from a path.)"; + +constexpr char kFastaStringIteratorDoc[] = R"( +Lazy FASTA parser for memory efficient FASTA parsing from a string. + +WARNING: The object backing the fasta_string string_view must not be +deleted while the FastaStringIterator is alive. E.g. this will break: + +``` +# Make sure the fasta_string is not interned. +fasta_string = '\n'.join(['>d\nS' for _ in range(10)]) +iterator = fasta_iterator.FastaStringIterator(fasta_string) +del fasta_string +iterator.next() # Heap use-after-free. +``` +)"; + +constexpr char kParseFastaDoc[] = R"( +Parses a FASTA string and returns a list of amino-acid sequences. + +Args: + fasta_string: The contents of a FASTA file. + +Returns: + List of sequences in the FASTA file. Descriptions are ignored. +)"; + +constexpr char kParseFastaIncludeDescriptionsDoc[] = R"( +Parses a FASTA string, returns amino-acid sequences with descriptions. + +Args: + fasta_string: The contents of a FASTA file. 
+ +Returns: + A tuple with two lists (sequences, descriptions): + * A list of sequences. + * A list of sequence descriptions taken from the comment lines. In the + same order as the sequences. +)"; + +class PythonFastaStringIterator : public FastaStringIterator { + public: + explicit PythonFastaStringIterator(py::object fasta_string) + : FastaStringIterator(py::cast(fasta_string)), + fasta_string_(std::move(fasta_string)) {} + + private: + py::object fasta_string_; +}; + +} // namespace + +void RegisterModuleFastaIterator(pybind11::module m) { + py::class_(m, "FastaFileIterator", kFastaFileIteratorDoc) + .def(py::init(), py::arg("fasta_path")) + .def("__iter__", + [](FastaFileIterator& iterator) -> FastaFileIterator& { + return iterator; + }) + .def( + "__next__", + [](FastaFileIterator& iterator) { + if (iterator.HasNext()) { + return ValueOrThrowValueError(iterator.Next()); + } else { + throw py::stop_iteration(); + } + }, + py::call_guard()); + + py::class_(m, "FastaStringIterator", + kFastaStringIteratorDoc) + .def(py::init(), py::arg("fasta_string")) + .def("__iter__", + [](PythonFastaStringIterator& iterator) + -> PythonFastaStringIterator& { return iterator; }) + .def( + "__next__", + [](PythonFastaStringIterator& iterator) { + if (iterator.HasNext()) { + return ValueOrThrowValueError(iterator.Next()); + } else { + throw py::stop_iteration(); + } + }, + py::call_guard()); + + m.def("parse_fasta", &ParseFasta, py::arg("fasta_string"), + py::call_guard(), py::doc(kParseFastaDoc + 1)); + m.def("parse_fasta_include_descriptions", &ParseFastaIncludeDescriptions, + py::arg("fasta_string"), py::call_guard(), + py::doc(kParseFastaIncludeDescriptionsDoc + 1)); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..091ea3fa21538a63aec87390583853a6964c6494 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_pybind.h @@ -0,0 +1,24 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. 
Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_FASTA_ITERATOR_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_FASTA_ITERATOR_PYBIND_H_ + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleFastaIterator(pybind11::module m); + +} + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_FASTA_ITERATOR_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion.pyi new file mode 100644 index 0000000000000000000000000000000000000000..3602032b91866c252af4f613904ed094a262a972 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion.pyi @@ -0,0 +1,26 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Type annotations for Python bindings for `msa_conversion`. + +The type annotations in this file were modified from the automatically generated +stubgen output. +""" + +from collections.abc import Iterable + + +def align_sequence_to_gapless_query( + sequence: str | bytes, + query_sequence: str | bytes, +) -> str: ... + + +def convert_a3m_to_stockholm(a3m_sequences: Iterable[str]) -> list[str]: ... diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.cc new file mode 100644 index 0000000000000000000000000000000000000000..c192052f02bda88b00ceafb2cdbe222b141fcef3 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.cc @@ -0,0 +1,162 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. 
Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include +#include +#include +#include +#include + +#include "absl/strings/ascii.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace { + +namespace py = pybind11; + +std::vector ConvertA3MToStockholm( + std::vector a3m_sequences) { + std::vector stockholm_sequences(a3m_sequences.size()); + auto max_length_element = + std::max_element(a3m_sequences.begin(), a3m_sequences.end(), + [](absl::string_view lhs, absl::string_view rhs) { + return lhs.size() < rhs.size(); + }); + + for (auto& out : stockholm_sequences) { + out.reserve(max_length_element->size()); + } + + // While any sequence has remaining columns. + while (std::any_of(a3m_sequences.begin(), a3m_sequences.end(), + [](absl::string_view in) { return !in.empty(); })) { + if (std::any_of(a3m_sequences.begin(), a3m_sequences.end(), + [](absl::string_view in) { + return !in.empty() && absl::ascii_islower(in.front()); + })) { + // Insertion(s) found at column. + for (std::size_t i = 0; i < a3m_sequences.size(); ++i) { + absl::string_view& in = a3m_sequences[i]; + std::string& out = stockholm_sequences[i]; + if (!in.empty() && absl::ascii_islower(in.front())) { + // Consume insertion. + out.push_back(absl::ascii_toupper(in.front())); + in.remove_prefix(1); + } else { + // Row requires padding. + out.push_back('-'); + } + } + } else { + // No insertions found. + for (std::size_t i = 0; i < a3m_sequences.size(); ++i) { + absl::string_view& in = a3m_sequences[i]; + std::string& out = stockholm_sequences[i]; + if (!in.empty()) { + // Consume entire column. + out.push_back(in.front()); + in.remove_prefix(1); + } else { + // One alignment is shorter than the others. Should not happen with + // valid A3M input. + throw std::invalid_argument(absl::StrFormat( + "a3m rows have inconsistent lengths; row %d has no columns left " + "but not all rows are exhausted", + i)); + } + } + } + } + return stockholm_sequences; +} + +std::string AlignSequenceToGaplessQuery(absl::string_view sequence, + absl::string_view query_sequence) { + if (sequence.size() != query_sequence.size()) { + throw py::value_error( + absl::StrFormat("The sequence (%d) and the query sequence (%d) don't " + "have the same length.", + sequence.size(), query_sequence.size())); + } + std::string output; + for (std::size_t residue_index = 0, sequence_length = sequence.size(); + residue_index < sequence_length; ++residue_index) { + const char query_residue = query_sequence[residue_index]; + const char residue = sequence[residue_index]; + if (query_residue != '-') { + // No gap in the query, so the residue is aligned. + output += residue; + } else if (residue == '-') { + // Gap in both sequence and query, simply skip. + continue; + } else { + // Gap only in the query, so this must be an inserted residue. + output += absl::ascii_tolower(residue); + } + } + return output; +} + +constexpr char kConvertA3mToStockholm[] = R"( +Converts a list of sequences in a3m format to stockholm format sequences. + +As an example if the input is: +abCD +CgD +fCDa + +Then the output will be: +ABC-D- +--CGD- +F-C-DA + +Args: + a3m_sequences: A list of strings in a3m format. + +Returns + A list of strings converted to stockholm format. +)"; + +constexpr char kAlignSequenceToGaplessQuery[] = R"( +Aligns a sequence to a gapless query sequence. 
+ +This is useful when converting Stockholm MSA to A3M MSA. Example: +Seq : AB--E +Query: A--DE +Output: Ab-E. + +Args: + sequence: A string containing to be aligned. + query_sequence: A string containing the reference sequence to align to. + +Returns + The input sequence with gaps dropped where both the `sequence` and + `query_sequence` have gaps, and sequence elements non-capitalized where the + `query_sequence` has a gap, but the `sequence` does not. +)"; + +} // namespace + +namespace alphafold3 { + +void RegisterModuleMsaConversion(pybind11::module m) { + m.def("convert_a3m_to_stockholm", &ConvertA3MToStockholm, + py::arg("a3m_sequences"), py::call_guard(), + py::doc(kConvertA3mToStockholm + 1)); + m.def("align_sequence_to_gapless_query", &AlignSequenceToGaplessQuery, + py::arg("sequence"), py::arg("query_sequence"), + py::call_guard(), + py::doc(kAlignSequenceToGaplessQuery + 1)); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..65f5fe99ec45f0199a32dd90d544e82fa7f21ea2 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.h @@ -0,0 +1,24 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_MSA_CONVERSION_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_MSA_CONVERSION_PYBIND_H_ + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleMsaConversion(pybind11::module m); + +} + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_MSA_CONVERSION_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/__init__.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..17f44cd06771a0f1d84b925ed1861bdb79a9af71 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/__init__.py @@ -0,0 +1,46 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Structure module initialization.""" + +# pylint: disable=g-importing-member +from alphafold3.structure.bioassemblies import BioassemblyData +from alphafold3.structure.bonds import Bonds +from alphafold3.structure.chemical_components import ChemCompEntry +from alphafold3.structure.chemical_components import ChemicalComponentsData +from alphafold3.structure.chemical_components import get_data_for_ccd_components +from alphafold3.structure.chemical_components import populate_missing_ccd_data +from alphafold3.structure.mmcif import BondParsingError +from alphafold3.structure.parsing import BondAtomId +from alphafold3.structure.parsing import from_atom_arrays +from alphafold3.structure.parsing import from_mmcif +from alphafold3.structure.parsing import from_parsed_mmcif +from alphafold3.structure.parsing import from_res_arrays +from alphafold3.structure.parsing import from_sequences_and_bonds +from alphafold3.structure.parsing import ModelID +from alphafold3.structure.parsing import SequenceFormat +from alphafold3.structure.structure import ARRAY_FIELDS +from alphafold3.structure.structure import AuthorNamingScheme +from alphafold3.structure.structure import Bond +from alphafold3.structure.structure import CascadeDelete +from alphafold3.structure.structure import concat +from alphafold3.structure.structure import enumerate_residues +from alphafold3.structure.structure import fix_non_standard_polymer_residues +from alphafold3.structure.structure import GLOBAL_FIELDS +from alphafold3.structure.structure import make_empty_structure +from alphafold3.structure.structure import MissingAtomError +from alphafold3.structure.structure import MissingAuthorResidueIdError +from alphafold3.structure.structure import multichain_residue_index +from alphafold3.structure.structure import stack +from alphafold3.structure.structure import Structure +from alphafold3.structure.structure_tables import Atoms +from alphafold3.structure.structure_tables import Chains +from alphafold3.structure.structure_tables import Residues diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/bioassemblies.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/bioassemblies.py new file mode 100644 index 0000000000000000000000000000000000000000..6c1d8e3ccf3dbdbfe983cb88efc17fe3162b5f01 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/bioassemblies.py @@ -0,0 +1,333 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Utilities for parsing and manipulating bioassembly data.""" + +from collections.abc import Mapping, Sequence +import copy +import dataclasses +from typing_extensions import Self + +from alphafold3.structure import mmcif +import numpy as np + + +@dataclasses.dataclass(frozen=True) +class Operation: + """A rigid transformation operation.""" + + trans: np.ndarray # shape: (3,) + rot: np.ndarray # shape: (3, 3) + + def apply_to_coords(self, coords: np.ndarray) -> np.ndarray: + """Applies the rotation followed by the translation to `coords`.""" + return np.dot(coords, self.rot.T) + self.trans[np.newaxis, :] + + +@dataclasses.dataclass(frozen=True) +class Transform: + """A rigid transformation composed of a sequence of `Operation`s.""" + + # The sequence of operations that form the transform. These will be applied + # right-to-left (last-to-first). + operations: Sequence[Operation] + + # The chain IDs that this transform should be applied to. These are + # label_asym_ids in the mmCIF spec. + chain_ids: Sequence[str] + + # A mapping from chain IDs (of chains that participate in this transform) + # to their new values in the bioassembly. + chain_id_rename_map: Mapping[str, str] + + def apply_to_coords(self, coords: np.ndarray) -> np.ndarray: + """Applies the `operations` in right-to-left order.""" + for operation in reversed(self.operations): + coords = operation.apply_to_coords(coords) + return coords + + +def _get_operation(oper_data: Mapping[str, str]) -> Operation: + """Parses an `Operation` from a mmCIF _pdbx_struct_oper_list row.""" + trans = np.zeros((3,), dtype=np.float32) + rot = np.zeros((3, 3), dtype=np.float32) + for i in range(3): + trans[i] = float(oper_data[f'_pdbx_struct_oper_list.vector[{i + 1}]']) + for i in range(3): + for j in range(3): + rot[i][j] = float( + oper_data[f'_pdbx_struct_oper_list.matrix[{i + 1}][{j + 1}]'] + ) + return Operation(trans=trans, rot=rot) + + +class MissingBioassemblyDataError(Exception): + """Raised when bioassembly data is missing from an mmCIF.""" + + +class BioassemblyData: + """Stores and processes bioassembly data from mmCIF tables.""" + + # Not all of these columns are required for internal operations, but all + # should be present whenever bioassemblies are defined in an mmCIF to stay + # consistent with external mmCIFs. 
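+  # As an illustrative sketch of the mmCIF layout these columns come from, a
+  # typical _pdbx_struct_oper_list row describing the identity transform has
+  #   matrix[1][1] = matrix[2][2] = matrix[3][3] = 1.0 (all other matrix items 0.0)
+  #   vector[1] = vector[2] = vector[3] = 0.0
+  # and `_get_operation` above parses it into an `Operation` whose `rot` is the
+  # 3x3 identity and whose `trans` is the zero vector.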
+ _REQUIRED_COLUMNS = ( + '_pdbx_struct_assembly.id', + '_pdbx_struct_assembly.details', + '_pdbx_struct_assembly.method_details', + '_pdbx_struct_assembly.oligomeric_details', + '_pdbx_struct_assembly.oligomeric_count', + '_pdbx_struct_assembly_gen.assembly_id', + '_pdbx_struct_assembly_gen.oper_expression', + '_pdbx_struct_assembly_gen.asym_id_list', + '_pdbx_struct_oper_list.id', + '_pdbx_struct_oper_list.type', + '_pdbx_struct_oper_list.name', + '_pdbx_struct_oper_list.symmetry_operation', + '_pdbx_struct_oper_list.matrix[1][1]', + '_pdbx_struct_oper_list.matrix[1][2]', + '_pdbx_struct_oper_list.matrix[1][3]', + '_pdbx_struct_oper_list.vector[1]', + '_pdbx_struct_oper_list.matrix[2][1]', + '_pdbx_struct_oper_list.matrix[2][2]', + '_pdbx_struct_oper_list.matrix[2][3]', + '_pdbx_struct_oper_list.vector[2]', + '_pdbx_struct_oper_list.matrix[3][1]', + '_pdbx_struct_oper_list.matrix[3][2]', + '_pdbx_struct_oper_list.matrix[3][3]', + '_pdbx_struct_oper_list.vector[3]', + ) + + def __init__( + self, + *, + pdbx_struct_assembly: Mapping[str, Mapping[str, str]], + pdbx_struct_assembly_gen: Mapping[str, Sequence[Mapping[str, str]]], + pdbx_struct_oper_list: Mapping[str, Mapping[str, str]], + assembly_ids: Sequence[str], + oper_ids: Sequence[str], + ): + for assembly_id in assembly_ids: + for table, table_name in ( + (pdbx_struct_assembly, '_pdbx_struct_assembly'), + (pdbx_struct_assembly_gen, '_pdbx_struct_assembly_gen'), + ): + if assembly_id not in table: + raise ValueError( + f'Assembly ID "{assembly_id}" missing from {table_name} ' + f'with keys: {table.keys()}' + ) + for oper_id in oper_ids: + if oper_id not in pdbx_struct_oper_list: + raise ValueError( + f'Oper ID "{oper_id}" missing from _pdbx_struct_oper_list ' + f'with keys: {pdbx_struct_oper_list.keys()}' + ) + + self._pdbx_struct_assembly = pdbx_struct_assembly + self._pdbx_struct_assembly_gen = pdbx_struct_assembly_gen + self._pdbx_struct_oper_list = pdbx_struct_oper_list + self._operations = { + oper_id: _get_operation(oper_data) + for oper_id, oper_data in self._pdbx_struct_oper_list.items() + } + self._assembly_ids = assembly_ids + self._oper_ids = oper_ids + + @classmethod + def from_mmcif(cls, cif: mmcif.Mmcif) -> Self: + """Constructs an instance of `BioassemblyData` from an `Mmcif` object.""" + for col in cls._REQUIRED_COLUMNS: + if col not in cif: + raise MissingBioassemblyDataError(col) + + pdbx_struct_assembly = cif.extract_loop_as_dict( + prefix='_pdbx_struct_assembly.', index='_pdbx_struct_assembly.id' + ) + pdbx_struct_oper_list = cif.extract_loop_as_dict( + prefix='_pdbx_struct_oper_list.', index='_pdbx_struct_oper_list.id' + ) + + # _pdbx_struct_assembly_gen is unlike the other two tables because it can + # have multiple rows share the same assembly ID. This can happen when an + # assembly is constructed by applying different sets of transforms to + # different sets of chain IDs. Each of these would have its own row. + # Here we group rows by their assembly_id. 
+ pdbx_struct_assembly_gen = {} + for assembly_id, oper_expression, asym_id_list in zip( + cif['_pdbx_struct_assembly_gen.assembly_id'], + cif['_pdbx_struct_assembly_gen.oper_expression'], + cif['_pdbx_struct_assembly_gen.asym_id_list'], + ): + pdbx_struct_assembly_gen.setdefault(assembly_id, []).append({ + '_pdbx_struct_assembly_gen.assembly_id': assembly_id, + '_pdbx_struct_assembly_gen.oper_expression': oper_expression, + '_pdbx_struct_assembly_gen.asym_id_list': asym_id_list, + }) + + # We provide these separately to keep track of the original order that they + # appear in the mmCIF. + assembly_ids = cif['_pdbx_struct_assembly.id'] + oper_ids = cif['_pdbx_struct_oper_list.id'] + return cls( + pdbx_struct_assembly=pdbx_struct_assembly, + pdbx_struct_assembly_gen=pdbx_struct_assembly_gen, + pdbx_struct_oper_list=pdbx_struct_oper_list, + assembly_ids=assembly_ids, + oper_ids=oper_ids, + ) + + @property + def assembly_ids(self) -> Sequence[str]: + return self._assembly_ids + + def asym_id_by_assembly_chain_id(self, assembly_id: str) -> Mapping[str, str]: + asym_id_by_assembly_chain_id = {} + for transform in self.get_transforms(assembly_id): + for asym_id, assembly_chain_id in transform.chain_id_rename_map.items(): + asym_id_by_assembly_chain_id[assembly_chain_id] = asym_id + return asym_id_by_assembly_chain_id + + def assembly_chain_ids_by_asym_id( + self, assembly_id: str + ) -> Mapping[str, set[str]]: + assembly_chain_ids_by_asym_id = {} + for transform in self.get_transforms(assembly_id): + for asym_id, assembly_chain_id in transform.chain_id_rename_map.items(): + assembly_chain_ids_by_asym_id.setdefault(asym_id, set()).add( + assembly_chain_id + ) + return assembly_chain_ids_by_asym_id + + def get_default_assembly_id(self) -> str: + """Gets a default assembly ID.""" + # The first assembly is usually (though not always) the best choice. + # If we find a better heuristic for picking bioassemblies then this + # method should be updated. + return min(self._assembly_ids) + + def get_assembly_info(self, assembly_id: str) -> Mapping[str, str]: + return { + k.replace('_pdbx_struct_assembly.', ''): v + for k, v in self._pdbx_struct_assembly[assembly_id].items() + } + + def get_transforms(self, assembly_id: str) -> Sequence[Transform]: + """Returns the transforms required to generate the given assembly.""" + partial_transforms = [] + all_chain_ids = set() + for row in self._pdbx_struct_assembly_gen[assembly_id]: + oper_expression = row['_pdbx_struct_assembly_gen.oper_expression'] + parsed_oper_id_seqs = mmcif.parse_oper_expr(oper_expression) + label_asym_ids = row['_pdbx_struct_assembly_gen.asym_id_list'].split( + ',') + all_chain_ids |= set(label_asym_ids) + for parsed_oper_id_seq in parsed_oper_id_seqs: + partial_transforms.append((parsed_oper_id_seq, label_asym_ids)) + + # We start assigning new chain IDs by finding the largest chain ID in + # the original structure that is involved in this bioassembly, and then + # starting from the next one. + max_int_chain_id = max(mmcif.str_id_to_int_id(c) + for c in all_chain_ids) + next_int_chain_id = max_int_chain_id + 1 + + transforms = [] + has_been_renamed = set() + for parsed_oper_id_seq, label_asym_ids in partial_transforms: + chain_id_rename_map = {} + for label_asym_id in label_asym_ids: + if label_asym_id not in has_been_renamed: + # The first time we see a label_asym_id we don't need to rename it. 
+ # This isn't strictly necessary since we don't provide any + # guarantees about chain naming after bioassembly extraction but + # can make it a bit easier to inspect and compare structures + # pre and post bioassembly extraction. + chain_id_rename_map[label_asym_id] = label_asym_id + has_been_renamed.add(label_asym_id) + else: + chain_id_rename_map[label_asym_id] = mmcif.int_id_to_str_id( + next_int_chain_id + ) + next_int_chain_id += 1 + transforms.append( + Transform( + operations=[ + self._operations[oper_id] for oper_id in parsed_oper_id_seq + ], + chain_ids=label_asym_ids, + chain_id_rename_map=chain_id_rename_map, + ) + ) + return transforms + + def to_mmcif_dict(self) -> Mapping[str, Sequence[str]]: + """Returns the bioassembly data as a dict suitable for `mmcif.Mmcif`.""" + mmcif_dict = {} + for assembly_id in self._assembly_ids: + for column, val in self._pdbx_struct_assembly[assembly_id].items(): + mmcif_dict.setdefault(column, []).append(val) + for row in self._pdbx_struct_assembly_gen[assembly_id]: + for column, val in row.items(): + mmcif_dict.setdefault(column, []).append(val) + for oper_id in self._oper_ids: + for column, val in self._pdbx_struct_oper_list[oper_id].items(): + mmcif_dict.setdefault(column, []).append(val) + return mmcif_dict + + def rename_label_asym_ids( + self, + mapping: Mapping[str, str], + present_chains: set[str], + ) -> Self: + """Returns a new BioassemblyData with renamed label_asym_ids. + + Args: + mapping: A mapping from original label_asym_ids to their new values. Any + label_asym_ids in this BioassemblyData that are not in this mapping will + remain unchanged. + present_chains: A set of label_asym_ids that are actually present in the + atom site list. All label_asym_ids that are in the BioassemblyData but + not in present_chains won't be included in the output BioassemblyData. + + Returns: + A new BioassemblyData with renamed label_asym_ids. + + Raises: + ValueError: If any two previously distinct chains do not have unique names + anymore after the rename. + """ + new_pdbx_struct_assembly_gen = copy.deepcopy( + self._pdbx_struct_assembly_gen) + for rows in new_pdbx_struct_assembly_gen.values(): + for row in rows: + old_asym_ids = row['_pdbx_struct_assembly_gen.asym_id_list'].split( + ',') + new_asym_ids = [ + mapping.get(label_asym_id, label_asym_id) + for label_asym_id in old_asym_ids + if label_asym_id in present_chains + ] + if len(set(old_asym_ids) & present_chains) != len(set(new_asym_ids)): + raise ValueError( + 'Can not rename chains, the new names are not unique: ' + f'{sorted(new_asym_ids)}.' + ) + row['_pdbx_struct_assembly_gen.asym_id_list'] = ','.join( + new_asym_ids) # pytype: disable=unsupported-operands + + return BioassemblyData( + pdbx_struct_assembly=copy.deepcopy(self._pdbx_struct_assembly), + pdbx_struct_assembly_gen=new_pdbx_struct_assembly_gen, + pdbx_struct_oper_list=copy.deepcopy(self._pdbx_struct_oper_list), + assembly_ids=copy.deepcopy(self._assembly_ids), + oper_ids=copy.deepcopy(self._oper_ids), + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/bonds.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/bonds.py new file mode 100644 index 0000000000000000000000000000000000000000..94689d797b8650b8810f53c161a9715e2efa18eb --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/bonds.py @@ -0,0 +1,237 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Bond representation for structure module.""" + +import collections +from collections.abc import Mapping, Sequence +import dataclasses +import typing +from typing_extensions import Self + +from alphafold3.structure import table +import numpy as np + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class Bonds(table.Table): + """Table of atomic bonds.""" + + # mmCIF column: _struct_conn.conn_type_id + # mmCIF desc: This data item is a pointer to _struct_conn_type.id in the + # STRUCT_CONN_TYPE category. + # E.g.: "covale", "disulf", "hydrog", "metalc". + type: np.ndarray + + # mmCIF column: _struct_conn.pdbx_role + # mmCIF desc: The chemical or structural role of the interaction. + # E.g.: "N-Glycosylation", "O-Glycosylation". + role: np.ndarray + + # mmCIF columns: _struct_conn.ptnr1_* + from_atom_key: np.ndarray + + # mmCIF columns: _struct_conn.ptnr2_* + dest_atom_key: np.ndarray + + @classmethod + def make_empty(cls) -> Self: + return cls( + key=np.empty((0,), dtype=np.int64), + from_atom_key=np.empty((0,), dtype=np.int64), + dest_atom_key=np.empty((0,), dtype=np.int64), + type=np.empty((0,), dtype=object), + role=np.empty((0,), dtype=object), + ) + + def get_atom_indices( + self, + atom_key: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray]: + """Returns the indices of the from/dest atoms in the atom_key array.""" + from_atom_missing = ~np.isin(self.from_atom_key, atom_key) + dest_atom_missing = ~np.isin(self.dest_atom_key, atom_key) + if np.any(from_atom_missing): + raise ValueError( + f'No atoms for from_atom_key {self.from_atom_key[from_atom_missing]}' + ) + if np.any(dest_atom_missing): + raise ValueError( + f'No atoms for dest_atom_key {self.dest_atom_key[dest_atom_missing]}' + ) + sort_indices = np.argsort(atom_key) + from_indices_sorted = np.searchsorted( + atom_key, self.from_atom_key, sorter=sort_indices + ) + dest_indices_sorted = np.searchsorted( + atom_key, self.dest_atom_key, sorter=sort_indices + ) + from_indices = sort_indices[from_indices_sorted] + dest_indices = sort_indices[dest_indices_sorted] + return from_indices, dest_indices + + def restrict_to_atoms(self, atom_key: np.ndarray) -> Self: + if not self.size: # Early-out for empty table. + return self + from_atom_mask = np.isin(self.from_atom_key, atom_key) + dest_atom_mask = np.isin(self.dest_atom_key, atom_key) + mask = np.logical_and(from_atom_mask, dest_atom_mask) + return typing.cast(Bonds, self.filter(mask=mask)) + + def to_mmcif_dict_from_atom_arrays( + self, + atom_key: np.ndarray, + chain_id: np.ndarray, + res_id: np.ndarray, + res_name: np.ndarray, + atom_name: np.ndarray, + auth_asym_id: np.ndarray, + auth_seq_id: np.ndarray, + insertion_code: np.ndarray, + ) -> Mapping[str, Sequence[str] | np.ndarray]: + """Returns a dict suitable for building a CifDict, representing bonds. + + Args: + atom_key: A (num_atom,) integer array of atom_keys. + chain_id: A (num_atom,) array of label_asym_id strings. + res_id: A (num_atom,) array of label_seq_id strings. + res_name: A (num_atom,) array of label_comp_id strings. 
+ atom_name: A (num_atom,) array of label_atom_id strings. + auth_asym_id: A (num_atom,) array of auth_asym_id strings. + auth_seq_id: A (num_atom,) array of auth_seq_id strings. + insertion_code: A (num_atom,) array of insertion code strings. + """ + mmcif_dict = collections.defaultdict(list) + ptnr1_indices, ptnr2_indices = self.get_atom_indices(atom_key) + + mmcif_dict['_struct_conn.ptnr1_label_asym_id'] = chain_id[ptnr1_indices] + mmcif_dict['_struct_conn.ptnr2_label_asym_id'] = chain_id[ptnr2_indices] + mmcif_dict['_struct_conn.ptnr1_label_comp_id'] = res_name[ptnr1_indices] + mmcif_dict['_struct_conn.ptnr2_label_comp_id'] = res_name[ptnr2_indices] + mmcif_dict['_struct_conn.ptnr1_label_seq_id'] = res_id[ptnr1_indices] + mmcif_dict['_struct_conn.ptnr2_label_seq_id'] = res_id[ptnr2_indices] + mmcif_dict['_struct_conn.ptnr1_label_atom_id'] = atom_name[ptnr1_indices] + mmcif_dict['_struct_conn.ptnr2_label_atom_id'] = atom_name[ptnr2_indices] + + mmcif_dict['_struct_conn.ptnr1_auth_asym_id'] = auth_asym_id[ptnr1_indices] + mmcif_dict['_struct_conn.ptnr2_auth_asym_id'] = auth_asym_id[ptnr2_indices] + mmcif_dict['_struct_conn.ptnr1_auth_seq_id'] = auth_seq_id[ptnr1_indices] + mmcif_dict['_struct_conn.ptnr2_auth_seq_id'] = auth_seq_id[ptnr2_indices] + mmcif_dict['_struct_conn.pdbx_ptnr1_PDB_ins_code'] = insertion_code[ + ptnr1_indices + ] + mmcif_dict['_struct_conn.pdbx_ptnr2_PDB_ins_code'] = insertion_code[ + ptnr2_indices + ] + + label_alt_id = ['?'] * self.size + mmcif_dict['_struct_conn.pdbx_ptnr1_label_alt_id'] = label_alt_id + mmcif_dict['_struct_conn.pdbx_ptnr2_label_alt_id'] = label_alt_id + + # We need to set this to make visualisation work in NGL/PyMOL. + mmcif_dict['_struct_conn.pdbx_value_order'] = ['?'] * self.size + + # We use a symmetry of 1_555 which is the no-op transformation. Other + # values are used when bonds involve atoms that only exist after expanding + # the bioassembly, but we don't support this kind of bond at the moment. + symmetry = ['1_555'] * self.size + mmcif_dict['_struct_conn.ptnr1_symmetry'] = symmetry + mmcif_dict['_struct_conn.ptnr2_symmetry'] = symmetry + bond_type_counter = collections.Counter() + for bond_row in self.iterrows(): + bond_type = bond_row['type'] + bond_type_counter[bond_type] += 1 + mmcif_dict['_struct_conn.id'].append( + f'{bond_type}{bond_type_counter[bond_type]}' + ) + mmcif_dict['_struct_conn.pdbx_role'].append(bond_row['role']) + mmcif_dict['_struct_conn.conn_type_id'].append(bond_type) + + bond_types = np.unique(self.type) + mmcif_dict['_struct_conn_type.id'] = bond_types + unknown = ['?'] * len(bond_types) + mmcif_dict['_struct_conn_type.criteria'] = unknown + mmcif_dict['_struct_conn_type.reference'] = unknown + + return dict(mmcif_dict) + + +def concat_with_atom_keys( + bonds_tables: Sequence[Bonds | None], + atom_key_arrays: Sequence[np.ndarray], +) -> tuple[Bonds | None, np.ndarray]: + """Concatenates bonds tables and atom keys simultaneously. + + Args: + bonds_tables: A sequence of `Bonds` instances to concatenate. If any are + None then these are skipped. + atom_key_arrays: A sequence of integer `atom_key` arrays, where the n-th + bonds_table refers to the atoms in the n-th atom_key array. These must + all be non-None. + + Returns: + A pair of (bonds, atom_key) where atom_key is a unique atom_key array with + length equal to the sum of the input atom array sizes, and the bonds table + contains all the bonds from the individual bonds table inputs. 
+ """ + if not bonds_tables or not atom_key_arrays: + if bonds_tables or atom_key_arrays: + raise ValueError( + 'bonds_tables and atom_keys must have same length but got' + f' {len(bonds_tables)=} and {len(atom_key_arrays)=}' + ) + return None, np.array([], dtype=np.int64) + max_key = -1 + atom_keys_to_concat = [] + types_to_concat = [] + roles_to_concat = [] + from_atom_keys_to_concat = [] + dest_atom_keys_to_concat = [] + for bonds, atom_key in zip(bonds_tables, atom_key_arrays, strict=True): + if not atom_key.size: + assert bonds is None or bonds.size == 0 + continue + # Should always be non-negative! + assert np.min(atom_key, initial=0) >= 0 + offset = max_key + 1 + offset_atom_key = atom_key + offset + atom_keys_to_concat.append(offset_atom_key) + max_key = np.max(offset_atom_key) + if bonds is not None: + types_to_concat.append(bonds.type) + roles_to_concat.append(bonds.role) + from_atom_keys_to_concat.append(bonds.from_atom_key + offset) + dest_atom_keys_to_concat.append(bonds.dest_atom_key + offset) + + if atom_keys_to_concat: + concatted_atom_keys = np.concatenate(atom_keys_to_concat, axis=0) + else: + concatted_atom_keys = np.array([], dtype=np.int64) + + if types_to_concat: + assert ( + len(types_to_concat) + == len(roles_to_concat) + == len(from_atom_keys_to_concat) + == len(dest_atom_keys_to_concat) + ) + num_bonds = sum(b.size for b in bonds_tables if b is not None) + concatted_bonds = Bonds( + key=np.arange(num_bonds, dtype=np.int64), + type=np.concatenate(types_to_concat, axis=0), + role=np.concatenate(roles_to_concat, axis=0), + from_atom_key=np.concatenate(from_atom_keys_to_concat, axis=0), + dest_atom_key=np.concatenate(dest_atom_keys_to_concat, axis=0), + ) + else: + concatted_bonds = None + + return concatted_bonds, concatted_atom_keys diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/chemical_components.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/chemical_components.py new file mode 100644 index 0000000000000000000000000000000000000000..a25e91017c10212a76c818b8ec960211c3609340 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/chemical_components.py @@ -0,0 +1,286 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Utilities for manipulating chemical components data.""" + +from collections.abc import Iterable, Mapping, Sequence +import dataclasses +import functools +from typing_extensions import Self + +from alphafold3.constants import chemical_components +from alphafold3.constants import residue_names +from alphafold3.structure import mmcif +import rdkit.Chem as rd_chem + + +@dataclasses.dataclass(frozen=True) +class ChemCompEntry: + """Items of _chem_comp category. + + For the full list of items and their semantics see + http://mmcif.rcsb.org/dictionaries/mmcif_pdbx_v50.dic/Categories/chem_comp.html + """ + + type: str + name: str = '?' + pdbx_synonyms: str = '?' + formula: str = '?' 
+ formula_weight: str = '?' + mon_nstd_flag: str = '?' + pdbx_smiles: str | None = None + + def __post_init__(self): + for field, value in vars(self).items(): + if not value and value is not None: + raise ValueError(f"{field} value can't be an empty string.") + + def extends(self, other: Self) -> bool: + """Checks whether this ChemCompEntry extends another one.""" + for field, value in vars(self).items(): + other_value = getattr(other, field) + if _value_is_missing(other_value): + continue + if value != other_value: + return False + return True + + @property + def rdkit_mol(self) -> rd_chem.Mol: + """Returns an RDKit Mol, created via RDKit from entry SMILES string.""" + if not self.pdbx_smiles: + raise ValueError( + 'Cannot construct RDKit Mol with empty pdbx_smiles') + return rd_chem.MolFromSmiles(self.pdbx_smiles) + + +_REQUIRED_MMCIF_COLUMNS = ('_chem_comp.id', '_chem_comp.type') + + +class MissingChemicalComponentsDataError(Exception): + """Raised when chemical components data is missing from an mmCIF.""" + + +@dataclasses.dataclass(frozen=True) +class ChemicalComponentsData: + """Extra information for chemical components occurring in mmCIF. + + Fields: + chem_comp: A mapping from _chem_comp.id to associated items in the + chem_comp category. + """ + + chem_comp: Mapping[str, ChemCompEntry] + + @classmethod + def from_mmcif( + cls, cif: mmcif.Mmcif, fix_mse: bool, fix_unknown_dna: bool + ) -> Self: + """Constructs an instance of ChemicalComponentsData from an Mmcif object.""" + for col in _REQUIRED_MMCIF_COLUMNS: + if col not in cif: + raise MissingChemicalComponentsDataError(col) + + id_ = cif['_chem_comp.id'] # Guaranteed to be present. + type_ = cif['_chem_comp.type'] # Guaranteed to be present. + name = cif.get('_chem_comp.name', ['?'] * len(id_)) + synonyms = cif.get('_chem_comp.pdbx_synonyms', ['?'] * len(id_)) + formula = cif.get('_chem_comp.formula', ['?'] * len(id_)) + weight = cif.get('_chem_comp.formula_weight', ['?'] * len(id_)) + mon_nstd_flag = cif.get('_chem_comp.mon_nstd_flag', ['?'] * len(id_)) + smiles = cif.get('_chem_comp.pdbx_smiles', ['?'] * len(id_)) + smiles = [None if s == '?' else s for s in smiles] + + chem_comp = { + component_name: ChemCompEntry(*entry) + for component_name, *entry in zip( + id_, type_, name, synonyms, formula, weight, mon_nstd_flag, smiles + ) + } + + if fix_mse and 'MSE' in chem_comp: + if 'MET' not in chem_comp: + chem_comp['MET'] = ChemCompEntry( + type='L-PEPTIDE LINKING', + name='METHIONINE', + pdbx_synonyms='?', + formula='C5 H11 N O2 S', + formula_weight='149.211', + mon_nstd_flag='y', + pdbx_smiles=None, + ) + + if fix_unknown_dna and 'N' in chem_comp: + # Do not delete 'N' as it may be needed for RNA in the system. 
+ if 'DN' not in chem_comp: + chem_comp['DN'] = ChemCompEntry( + type='DNA LINKING', + name="UNKNOWN 2'-DEOXYNUCLEOTIDE", + pdbx_synonyms='?', + formula='C5 H11 O6 P', + formula_weight='198.111', + mon_nstd_flag='y', + pdbx_smiles=None, + ) + + return ChemicalComponentsData(chem_comp) + + def to_mmcif_dict(self) -> Mapping[str, Sequence[str]]: + """Returns chemical components data as a dict suitable for `mmcif.Mmcif`.""" + mmcif_dict = {} + + mmcif_fields = set() + for entry in self.chem_comp.values(): + for field, value in vars(entry).items(): + if value: + mmcif_fields.add(field) + chem_comp_ids = [] + for component_id in sorted(self.chem_comp): + entry = self.chem_comp[component_id] + chem_comp_ids.append(component_id) + for field in mmcif_fields: + mmcif_dict.setdefault(f'_chem_comp.{field}', []).append( + getattr(entry, field) or '?' + ) + if chem_comp_ids: + mmcif_dict['_chem_comp.id'] = chem_comp_ids + return mmcif_dict + + +def _value_is_missing(value: str) -> bool: + return not value or value in ('.', '?') + + +def get_data_for_ccd_components( + ccd: chemical_components.Ccd, + chemical_component_ids: Iterable[str], + populate_pdbx_smiles: bool = False, +) -> ChemicalComponentsData: + """Returns `ChemicalComponentsData` for chemical components known by PDB.""" + chem_comp = {} + for chemical_component_id in chemical_component_ids: + chem_data = chemical_components.component_name_to_info( + ccd=ccd, res_name=chemical_component_id + ) + if not chem_data: + continue + chem_comp[chemical_component_id] = ChemCompEntry( + type=chem_data.type, + name=chem_data.name, + pdbx_synonyms=chem_data.pdbx_synonyms, + formula=chem_data.formula, + formula_weight=chem_data.formula_weight, + mon_nstd_flag=chem_data.mon_nstd_flag, + pdbx_smiles=( + chem_data.pdbx_smiles or None if populate_pdbx_smiles else None + ), + ) + return ChemicalComponentsData(chem_comp=chem_comp) + + +def populate_missing_ccd_data( + ccd: chemical_components.Ccd, + chemical_components_data: ChemicalComponentsData, + chemical_component_ids: Iterable[str] | None = None, + populate_pdbx_smiles: bool = False, +) -> ChemicalComponentsData: + """Populates missing data for the chemical components from CCD. + + Args: + ccd: The chemical components database. + chemical_components_data: ChemicalComponentsData to populate missing values + for. This function doesn't modify the object, extended version is provided + as a return value. + chemical_component_ids: chemical components to populate missing values for. + If not specified, the function will consider all chemical components which + are already present in `chemical_components_data`. + populate_pdbx_smiles: whether to populate `pdbx_smiles` field using SMILES + descriptors from _pdbx_chem_comp_descriptor CCD table. If CCD provides + multiple SMILES strings, any of them could be used. + + Returns: + New instance of ChemicalComponentsData without missing values for CCD + entries. 
+ """ + if chemical_component_ids is None: + chemical_component_ids = chemical_components_data.chem_comp.keys() + + ccd_data = get_data_for_ccd_components( + ccd, chemical_component_ids, populate_pdbx_smiles + ) + chem_comp = dict(chemical_components_data.chem_comp) + for component_id, ccd_entry in ccd_data.chem_comp.items(): + if component_id not in chem_comp: + chem_comp[component_id] = ccd_entry + else: + already_specified_fields = { + field: value + for field, value in vars(chem_comp[component_id]).items() + if not _value_is_missing(value) + } + chem_comp[component_id] = ChemCompEntry( + **{**vars(ccd_entry), **already_specified_fields} + ) + return ChemicalComponentsData(chem_comp=chem_comp) + + +def get_all_atoms_in_entry( + ccd: chemical_components.Ccd, res_name: str +) -> Mapping[str, Sequence[str]]: + """Get all possible atoms and bonds for this residue in a standard order. + + Args: + ccd: The chemical components dictionary. + res_name: Full CCD name. + + Returns: + A dictionary table of the atoms and bonds for this residue in this residue + type. + """ + # The CCD version of 'UNK' is weird. It has a CB and a CG atom. We just want + # the minimal amino-acid here which is GLY. + if res_name == 'UNK': + res_name = 'GLY' + ccd_data = ccd.get(res_name) + if not ccd_data: + raise ValueError(f'Unknown residue type {res_name}') + + keys = ( + '_chem_comp_atom.atom_id', + '_chem_comp_atom.type_symbol', + '_chem_comp_bond.atom_id_1', + '_chem_comp_bond.atom_id_2', + ) + + # Add terminal hydrogens for protonation of the N-terminal + if res_name == 'PRO': + res_atoms = {key: [*ccd_data.get(key, [])] for key in keys} + res_atoms['_chem_comp_atom.atom_id'].extend(['H2', 'H3']) + res_atoms['_chem_comp_atom.type_symbol'].extend(['H', 'H']) + res_atoms['_chem_comp_bond.atom_id_1'].extend(['N', 'N']) + res_atoms['_chem_comp_bond.atom_id_2'].extend(['H2', 'H3']) + elif res_name in residue_names.PROTEIN_TYPES_WITH_UNKNOWN: + res_atoms = {key: [*ccd_data.get(key, [])] for key in keys} + res_atoms['_chem_comp_atom.atom_id'].append('H3') + res_atoms['_chem_comp_atom.type_symbol'].append('H') + res_atoms['_chem_comp_bond.atom_id_1'].append('N') + res_atoms['_chem_comp_bond.atom_id_2'].append('H3') + else: + res_atoms = {key: ccd_data.get(key, []) for key in keys} + + return res_atoms + + +@functools.lru_cache(maxsize=128) +def get_res_atom_names(ccd: chemical_components.Ccd, res_name: str) -> set[str]: + """Gets the names of the atoms in a given CCD residue.""" + atoms = get_all_atoms_in_entry(ccd, res_name)['_chem_comp_atom.atom_id'] + return set(atoms) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation.pyi new file mode 100644 index 0000000000000000000000000000000000000000..8f4a8b37539b6b1996be654409fd1f1605cb53c2 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation.pyi @@ -0,0 +1,13 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from collections.abc import Sequence + +def indices_grouped_by_value(values: Sequence[int]) -> dict[int, list[int]]: ... diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation_pybind.cc new file mode 100644 index 0000000000000000000000000000000000000000..5ac46d62cd5872658595ed5ab5bd73793853ad7a --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation_pybind.cc @@ -0,0 +1,54 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" +#include "pybind11/cast.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11_abseil/absl_casters.h" + +namespace { + +namespace py = pybind11; + +absl::flat_hash_map> IndicesGroupedByValue( + absl::Span values) { + absl::flat_hash_map> group_indices; + for (int64_t i = 0, e = values.size(); i < e; ++i) { + group_indices[values[i]].push_back(i); + } + return group_indices; +} + +constexpr char kIndicesGroupedByValue[] = R"( +Returns a map from value to a list of indices this value occupies. + +E.g. indices_grouped_by_value([1, 1, 2, 3, 3, 1, 1]) returns: +{1: [0, 1, 5, 6], 2: [2], 3: [3, 4]} + +Args: + values: a list of values to group. +)"; + +} // namespace + +namespace alphafold3 { + +void RegisterModuleAggregation(py::module m) { + m.def("indices_grouped_by_value", &IndicesGroupedByValue, py::arg("values"), + py::doc(kIndicesGroupedByValue + 1), + py::call_guard()); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation_pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..9547b9448d4b929699dcb88e2178d39ff292e5b9 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation_pybind.h @@ -0,0 +1,24 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. 
Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_AGGREGATION_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_AGGREGATION_PYBIND_H_ + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleAggregation(pybind11::module m); + +} + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_AGGREGATION_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership.pyi new file mode 100644 index 0000000000000000000000000000000000000000..305f36600f4dbee357b29ae35442c3382bedea4e --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership.pyi @@ -0,0 +1,18 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +import numpy + + +def isin( + array: numpy.ndarray[numpy.int64], + test_elements: set[int], + invert: bool = ..., +) -> numpy.ndarray[bool]: ... diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership_pybind.cc new file mode 100644 index 0000000000000000000000000000000000000000..2b3faf8a2afc990c42b812d35aa346009229b638 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership_pybind.cc @@ -0,0 +1,82 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "pybind11/cast.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11_abseil/absl_casters.h" + +namespace { + +namespace py = pybind11; + +py::array_t IsIn(const py::array_t& array, + const absl::flat_hash_set& test_elements, + bool invert) { + const size_t num_elements = array.size(); + + py::array_t output(num_elements); + std::fill(output.mutable_data(), output.mutable_data() + output.size(), + invert); + + // Shortcut: The output will be trivially always false if test_elements empty. 
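+  // To be precise, every element then keeps the pre-filled value `invert`, so
+  // an empty test set yields all-false for invert=false and all-true for
+  // invert=true.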
+ if (test_elements.empty()) { + return output; + } + + for (size_t i = 0; i < num_elements; ++i) { + if (test_elements.contains(array.data()[i])) { + output.mutable_data()[i] = !invert; + } + } + if (array.ndim() > 1) { + auto shape = + std::vector(array.shape(), array.shape() + array.ndim()); + return output.reshape(shape); + } + return output; +} + +constexpr char kIsInDoc[] = R"( +Computes whether each element is in test_elements. + +Same use as np.isin, but much faster. If len(array) = n, len(test_elements) = m: +* This function has complexity O(n). +* np.isin with kind='sort' has complexity O(m*log(m) + n * log(m)). + +Args: + array: Input NumPy array with dtype=np.int64. + test_elements: The values against which to test each value of array. + invert: If True, the values in the returned array are inverted, as if + calculating `element not in test_elements`. Default is False. + `isin(a, b, invert=True)` is equivalent to but faster than `~isin(a, b)`. + +Returns + A boolean array of the same shape as the input array. Each value `val` is: + * `val in test_elements` if `invert=False`, + * `val not in test_elements` if `invert=True`. +)"; + +} // namespace + +namespace alphafold3 { + +void RegisterModuleMembership(pybind11::module m) { + m.def("isin", &IsIn, py::arg("array"), py::arg("test_elements"), + py::kw_only(), py::arg("invert") = false, py::doc(kIsInDoc + 1)); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership_pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..d224fb1f64c92d6f3753da7b2e231077edde436b --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership_pybind.h @@ -0,0 +1,24 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MEMBERSHIP_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MEMBERSHIP_PYBIND_H_ + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleMembership(pybind11::module m); + +} + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MEMBERSHIP_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.cc new file mode 100644 index 0000000000000000000000000000000000000000..cea9a1b1c9df6f27c707b1507439f7fba40b770f --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.cc @@ -0,0 +1,249 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. 
You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include "alphafold3/structure/cpp/mmcif_altlocs.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/log.h" +#include "absl/strings/numbers.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "alphafold3/structure/cpp/mmcif_layout.h" + +namespace alphafold3 { +namespace { + +float OccupancyToFloat(absl::string_view occupancy) { + float result = 0.0f; + LOG_IF(ERROR, !absl::SimpleAtof(occupancy, &result)) + << "Invalid Occupancy: " << occupancy; + return result; +} + +// Deuterium is the same atom as Hydrogen so keep equivalent for grouping. +bool AtomEquiv(absl::string_view lhs, absl::string_view rhs) { + if (lhs == rhs) return true; + if (lhs.empty() != rhs.empty()) return false; + // Both lhs and rhs are guaranteed to be non-empty after this. + char first_lhs = lhs.front(); + char second_rhs = rhs.front(); + if ((first_lhs == 'H' && second_rhs == 'D') || + (first_lhs == 'D' && second_rhs == 'H')) { + lhs.remove_prefix(1); + rhs.remove_prefix(1); + return lhs == rhs; + } + return false; +} + +// Calls group_callback with that start index and count for each group of +// equivalent values in `values`, starting at `start` and ending at `count`. +// Example: +// GroupBy({"B", "B", "B", "C", "C"}, 0, 5, [](size_t start, size_t count) { +// absl::Printf("start=%d, count=%d\n", start, count); +// }); +// Would print: +// start=0, count=3 +// start=3, count=2 +template > +void GroupBy(absl::Span values, std::size_t start, + std::size_t count, GroupCallback&& group_callback, + IsEqual&& is_equal = std::equal_to{}) { + std::size_t span_start = start; + if (count > 0) { + for (std::size_t i = start + 1; i < start + count; ++i) { + if (!is_equal(values[i], values[span_start])) { + group_callback(span_start, i - span_start); + span_start = i; + } + } + group_callback(span_start, start + count - span_start); + } +} + +void ProcessAltLocGroupsWhole(std::size_t alt_loc_start, + std::size_t alt_loc_count, + absl::Span comp_ids, + absl::Span atom_ids, + absl::Span alt_ids, + absl::Span occupancies, + std::vector& in_out_keep_indices) { + std::pair best_split = {alt_loc_start, + alt_loc_count}; + std::vector alt_loc_groups; + float best_occupancy = -std::numeric_limits::infinity(); + char best_group = alt_ids[alt_loc_start].front(); + std::vector> occupancy_stats; + + // Group by residue type. + GroupBy(comp_ids, alt_loc_start, alt_loc_count, + [&](std::size_t start, std::size_t count) { + // This callback selects the best residue group and the best + // Alt-loc char within that group. + alt_loc_groups.clear(); + occupancy_stats.clear(); + // Calculate total occupancy for residue type. 
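+            // Illustrative trace: alt-loc 'A' atoms with occupancies 0.6, 0.6
+            // and alt-loc 'B' atoms with 0.4, 0.4 accumulate occupancy_stats
+            // of (2, 1.2) and (2, 0.8), so total_occupancy below becomes
+            // 0.6 + 0.4 = 1.0 for this residue type.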
+ for (std::size_t i = 0; i < count; ++i) { + char alt_loc_id = alt_ids[start + i].front(); + float occupancy = OccupancyToFloat(occupancies[start + i]); + if (auto loc = absl::c_find(alt_loc_groups, alt_loc_id); + loc == alt_loc_groups.end()) { + occupancy_stats.emplace_back(1, occupancy); + alt_loc_groups.push_back(alt_loc_id); + } else { + auto& stat = + occupancy_stats[std::distance(alt_loc_groups.begin(), loc)]; + ++stat.first; + stat.second += occupancy; + } + } + float total_occupancy = 0.0; + for (auto& stat : occupancy_stats) { + total_occupancy += stat.second / stat.first; + } + char group = *absl::c_min_element(alt_loc_groups); + // Compares occupancy of residue to best seen so far. + // Tie breaks alphabetic. + if (total_occupancy > best_occupancy || + (total_occupancy == best_occupancy && group < best_group)) { + // Selects the best sub group. + best_group = alt_loc_groups.front(); + float best_amount = occupancy_stats.front().second / + occupancy_stats.front().first; + for (std::size_t i = 1; i < occupancy_stats.size(); ++i) { + float amount = + occupancy_stats[i].second / occupancy_stats[i].first; + char group = alt_loc_groups[i]; + if (amount > best_amount || + (amount == best_amount && group < best_group)) { + best_amount = amount; + best_group = group; + } + } + best_occupancy = total_occupancy; + best_split = {start, count}; + } + }); + + // Now that the best residue type has been selected and the best alt-loc + // within that has been selected add indices of indices to keep to the keep + // list. + auto [split_start, split_count] = best_split; + GroupBy( + atom_ids, split_start, split_count, + [&in_out_keep_indices, &alt_ids, best_group](std::size_t start, + std::size_t count) { + // This makes sure we select an atom for each atom id even if it does + // not have our selected alt-loc char. + std::size_t best_index = start; + for (std::size_t i = 1; i < count; ++i) { + if (alt_ids[start + i].front() == best_group) { + best_index = start + i; + break; + } + } + in_out_keep_indices.push_back(best_index); + }, + AtomEquiv); +} + +// Finds the alt-loc group with the highest score and pushes the indices on to +// the back of in_out_keep_indices. +void ProcessAltLocGroupPartial( + std::size_t alt_loc_start, std::size_t alt_loc_count, + absl::Span atom_ids, + absl::Span alt_ids, + absl::Span occupancies, + std::vector& in_out_keep_indices) { + GroupBy( + atom_ids, alt_loc_start, alt_loc_count, + [&](std::size_t start, std::size_t count) { + if (count == 1) { + in_out_keep_indices.push_back(start); + } else { + float best_occ = OccupancyToFloat(occupancies[start]); + std::size_t best_index = start; + char best_group = alt_ids[start].front(); + for (std::size_t i = 0; i < count; ++i) { + float occ = OccupancyToFloat(occupancies[start + i]); + char group = alt_ids[start + i].front(); + if (occ > best_occ || (occ == best_occ && group < best_group)) { + best_group = group; + best_index = start + i; + best_occ = occ; + } + } + in_out_keep_indices.push_back(best_index); + } + }, + AtomEquiv); +} + +} // namespace + +// Resolves alt-locs returning the atom indices that will be left. 
+std::vector ResolveMmcifAltLocs( + const MmcifLayout& layout, absl::Span comp_ids, + absl::Span atom_ids, + absl::Span alt_ids, + absl::Span occupancies, + absl::Span chain_indices) { + std::vector keep_indices; + keep_indices.reserve(layout.num_atoms()); + std::size_t alt_loc_start = 0; + for (std::size_t chain_index : chain_indices) { + auto [residues_start, residues_end] = layout.residue_range(chain_index); + for (std::size_t residue = residues_start; residue < residues_end; + ++residue) { + std::size_t alt_loc_count = 0; + auto [atom_start, atom_end] = layout.atom_range(residue); + for (std::size_t i = atom_start; i < atom_end; ++i) { + char alt_loc_id = alt_ids[i].front(); + if (alt_loc_id == '.' || alt_loc_id == '?') { + if (alt_loc_count > 0) { + ProcessAltLocGroupPartial(alt_loc_start, alt_loc_count, atom_ids, + alt_ids, occupancies, keep_indices); + alt_loc_count = 0; + } + keep_indices.push_back(i); + } else { + if (alt_loc_count == 0) { + alt_loc_start = i; + } + ++alt_loc_count; + } + } + if (alt_loc_count > 0) { + if (atom_end - atom_start == alt_loc_count) { + ProcessAltLocGroupsWhole(alt_loc_start, alt_loc_count, comp_ids, + atom_ids, alt_ids, occupancies, + keep_indices); + } else { + ProcessAltLocGroupPartial(alt_loc_start, alt_loc_count, atom_ids, + alt_ids, occupancies, keep_indices); + } + } + } + } + + return keep_indices; +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.h new file mode 100644 index 0000000000000000000000000000000000000000..fab57817c38b62d96b3370d4b67cc4358656cc19 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.h @@ -0,0 +1,51 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_ALTLOCS_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_ALTLOCS_H_ + +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "alphafold3/structure/cpp/mmcif_layout.h" + +namespace alphafold3 { + +// Returns the list of indices that should be kept after resolving alt-locs. +// 1) Partial Residue. Each cycle of alt-locs are resolved separately with the +// highest occupancy alt-loc. Tie-breaks are resolved alphabetically. See +// tests for examples. +// 2) Whole Residue. These are resolved in two passes. +// a) The residue with the highest occupancy is chosen. +// b) The locations for a given residue are resolved. +// All tie-breaks are resolved alphabetically. See tests for examples. +// +// Preconditions: layout and comp_ids, alt_ids, occupancies are all from same +// mmCIF file and chain_indices are monotonically increasing and less than +// layout.num_chains(). +// +// comp_ids from '_atom_site.label_comp_id'. +// alt_ids from '_atom_site.label_alt_id'. +// occupancies from '_atom_site.occupancy'. 
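+//
+// Call-site sketch (assuming `cif` is a parsed CifDict and `chain_indices`
+// enumerates the chains to process; both names are hypothetical):
+//   MmcifLayout layout = MmcifLayout::Create(cif).value();
+//   auto keep = ResolveMmcifAltLocs(
+//       layout, cif["_atom_site.label_comp_id"], cif["_atom_site.label_atom_id"],
+//       cif["_atom_site.label_alt_id"], cif["_atom_site.occupancy"],
+//       chain_indices);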
+std::vector ResolveMmcifAltLocs( + const MmcifLayout& layout, absl::Span comp_ids, + absl::Span atom_ids, + absl::Span alt_ids, + absl::Span occupancies, + absl::Span chain_indices); + +} // namespace alphafold3 + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_ALTLOCS_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site.pyi new file mode 100644 index 0000000000000000000000000000000000000000..5f0ba34b062eeabcdcfff9b52c97c54d693b06fa --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site.pyi @@ -0,0 +1,23 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from collections.abc import Callable +from alphafold3.cpp import cif_dict + + +def get_internal_to_author_chain_id_map( + mmcif: cif_dict.CifDict +) -> dict[str,str]: ... + + +def get_or_infer_type_symbol( + mmcif: cif_dict.CifDict, + atom_id_to_type_symbol: Callable[[str, str], str], +) -> list[str]: ... diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site_pybind.cc new file mode 100644 index 0000000000000000000000000000000000000000..6037fe08ba30d169047a7f6644f0a33cebfdc2e7 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site_pybind.cc @@ -0,0 +1,83 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "alphafold3/parsers/cpp/cif_dict_lib.h" +#include "pybind11/gil.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" +#include "pybind11_abseil/absl_casters.h" + +namespace alphafold3 { +namespace { +namespace py = pybind11; + +// If present, returns the _atom_site.type_symbol. If not, infers it using +// _atom_site.label_comp_id (residue name), _atom_site.label_atom_id (atom name) +// and the CCD. 
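+// For example, when the type_symbol column is absent, a row with
+// label_comp_id 'ALA' and label_atom_id 'CA' is resolved by calling the
+// Python-supplied atom_id_to_type_symbol('ALA', 'CA'), which would typically
+// return 'C' (illustrative values).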
+py::list GetOrInferTypeSymbol(const CifDict& mmcif, + const py::object& atom_id_to_type_symbol) { + const auto& type_symbol = mmcif["_atom_site.type_symbol"]; + const int num_atom = mmcif["_atom_site.id"].size(); + py::list patched_type_symbol(num_atom); + if (type_symbol.empty()) { + const auto& label_comp_id = mmcif["_atom_site.label_comp_id"]; + const auto& label_atom_id = mmcif["_atom_site.label_atom_id"]; + CHECK_EQ(label_comp_id.size(), num_atom); + CHECK_EQ(label_atom_id.size(), num_atom); + for (int i = 0; i < num_atom; i++) { + patched_type_symbol[i] = + atom_id_to_type_symbol(label_comp_id[i], label_atom_id[i]); + } + } else { + for (int i = 0; i < num_atom; i++) { + patched_type_symbol[i] = type_symbol[i]; + } + } + return patched_type_symbol; +} + +absl::flat_hash_map +GetInternalToAuthorChainIdMap(const CifDict& mmcif) { + const auto& label_asym_ids = mmcif["_atom_site.label_asym_id"]; + const auto& auth_asym_ids = mmcif["_atom_site.auth_asym_id"]; + CHECK_EQ(label_asym_ids.size(), auth_asym_ids.size()); + + absl::flat_hash_map mapping; + for (size_t i = 0, num_rows = label_asym_ids.size(); i < num_rows; ++i) { + // Use only the first internal_chain_id occurrence to generate the mapping. + // It should not matter as there should not be a case where a single + // internal chain ID would map to more than one author chain IDs (i.e. the + // mapping should be injective). Since we need this method to be fast, we + // choose not to check it. + mapping.emplace(label_asym_ids[i], auth_asym_ids[i]); + } + return mapping; +} + +} // namespace + +namespace py = pybind11; + +void RegisterModuleMmcifAtomSite(pybind11::module m) { + m.def("get_or_infer_type_symbol", &GetOrInferTypeSymbol, py::arg("mmcif"), + py::arg("atom_id_to_type_symbol")); + + m.def("get_internal_to_author_chain_id_map", &GetInternalToAuthorChainIdMap, + py::arg("mmcif"), py::call_guard()); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site_pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..1f2104ecf0de171483b67370a91ff1acef0e0e28 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site_pybind.h @@ -0,0 +1,24 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. 
Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_ATOM_SITE_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_ATOM_SITE_PYBIND_H_ + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleMmcifAtomSite(pybind11::module m); + +} + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_ATOM_SITE_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout.h new file mode 100644 index 0000000000000000000000000000000000000000..51c67c528f60b71d009fd6ed9da199443bb5c5d4 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout.h @@ -0,0 +1,146 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_LAYOUT_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_LAYOUT_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "alphafold3/parsers/cpp/cif_dict_lib.h" + +namespace alphafold3 { + +// Holds the layout of a parsed mmCIF file. +class MmcifLayout { + public: + MmcifLayout(std::vector chain_ends, + std::vector residues, std::size_t model_offset, + std::size_t num_models) + : chain_ends_(std::move(chain_ends)), + residue_ends_(std::move(residues)), + model_offset_(model_offset), + num_models_(num_models) {} + + // Reads a layout from a valid parsed mmCIF. If a valid model_id is provided + // the offsets will select that model from the mmCIF. + // If no model_id is specified, we calculate the layout of the first model + // only. Therefore it is a requirement that each model has identical atom + // layouts. An error is returned if the atom counts do not between models. + static absl::StatusOr Create(const CifDict& mmcif, + absl::string_view model_id = ""); + + std::string ToDebugString() const; + + // Returns the start index and one past the last residue index of a given + // chain. A chain_index of n refers to the n-th chain in the mmCIF. The + // returned residue indices are 0-based enumerations of residues in the + // _atom_site records, and therefore do not include missing residues. + std::pair residue_range( + std::size_t chain_index) const { + if (chain_index > 0) { + return {chain_ends_[chain_index - 1], chain_ends_[chain_index]}; + } else { + return {0, chain_ends_[0]}; + } + } + + // Returns the start index and one past the last index of a given residue. + // A residue_index of n refers to the n-th residue in the mmCIF, not + // including residues that are unresolved (i.e. only using _atom_site). 
+ std::pair atom_range( + std::size_t residue_index) const { + if (residue_index > 0) { + return {residue_ends_[residue_index - 1], residue_ends_[residue_index]}; + } else { + return {model_offset_, residue_ends_[residue_index]}; + } + } + + // If model_id was provided during construction then this is 1, otherwise + // it is the number of models present in the mmCIF. + std::size_t num_models() const { return num_models_; } + // The number of atoms in the chosen model. + std::size_t num_atoms() const { + return residue_ends_.empty() ? 0 : residue_ends_.back() - model_offset_; + } + // The number of chains in the chosen model. + std::size_t num_chains() const { return chain_ends_.size(); } + // The number of residues in the chosen model, not counting unresolved + // residues. + std::size_t num_residues() const { return residue_ends_.size(); } + + // Returns the first atom index that is part of the specified chain. + // The chain is specified using chain_index, which is a 0-based + // enumeration of the chains in the _atom_site table. + std::size_t atom_site_from_chain_index(std::size_t chain_index) const { + if (chain_index == 0) { + return model_offset_; + } + return atom_site_from_residue_index(chain_ends_[chain_index - 1]); + } + + // Returns the first atom index that is part of the specified residue. + // The residue is specified using residue_index, which is a 0-based + // enumeration of the residues in the _atom_site table. + std::size_t atom_site_from_residue_index(std::size_t residues_index) const { + if (residues_index == 0) { + return model_offset_; + } + return residue_ends_[residues_index - 1]; + } + + // One past last residue index of each chain. The residue index does not + // include unresolved residues and is a simple 0-based enumeration of the + // residues in _atom_site table. + const std::vector& chains() const { return chain_ends_; } + + // Indices of the first atom of each chain. Note that this returns atom + // indices (like residue_starts()), not residue indices (like chains()). + std::vector chain_starts() const; + + // One past last atom index of each residue. + const std::vector& residues() const { return residue_ends_; } + + // Indices of the first atom of each residue. + std::vector residue_starts() const { + std::vector residue_starts; + if (!residue_ends_.empty()) { + residue_starts.reserve(residue_ends_.size()); + residue_starts.push_back(model_offset_); + residue_starts.insert(residue_starts.end(), residue_ends_.begin(), + residue_ends_.end() - 1); + } + return residue_starts; + } + + // The first atom index that is part of the specified model. + std::size_t model_offset() const { return model_offset_; } + + void Filter(absl::Span keep_indices); + + private: + std::vector chain_ends_; + std::vector residue_ends_; + std::size_t model_offset_; + std::size_t num_models_; +}; + +} // namespace alphafold3 + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_LAYOUT_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout.pyi new file mode 100644 index 0000000000000000000000000000000000000000..add1b05ea89b2ed34b645daad9c709fec551cb43 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout.pyi @@ -0,0 +1,26 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from alphafold3.cpp import cif_dict + +class MmcifLayout: + def atom_range(self, residue_index: int) -> tuple[int, int]: ... + def chain_starts(self) -> list[int]: ... + def chains(self) -> list[int]: ... + def model_offset(self) -> int: ... + def num_atoms(self) -> int: ... + def num_chains(self) -> int: ... + def num_models(self) -> int: ... + def num_residues(self) -> int: ... + def residue_range(self, chain_index: int) -> tuple[int, int]: ... + def residue_starts(self) -> list[int]: ... + def residues(self) -> list[int]: ... + +def from_mmcif(mmcif: cif_dict.CifDict, model_id: str = ...) -> MmcifLayout: ... diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_lib.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_lib.cc new file mode 100644 index 0000000000000000000000000000000000000000..91ad70c0b7b4f5bc38d3bcb994593e11bb2b7616 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_lib.cc @@ -0,0 +1,213 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "alphafold3/parsers/cpp/cif_dict_lib.h" +#include "alphafold3/structure/cpp/mmcif_layout.h" + +namespace alphafold3 { + +std::string MmcifLayout::ToDebugString() const { + return absl::StrFormat( + "MmcifLayout(models=%d, chains=%d, num_residues=%d, atoms=%d)", + num_models(), num_chains(), num_residues(), num_atoms()); +} + +// Changes layout to match keep_indices removing empty chains/residues. +void MmcifLayout::Filter(absl::Span keep_indices) { + if (num_chains() == 0) { + return; + } + // Update residue indices. + auto keep_it = absl::c_lower_bound(keep_indices, residue_ends_.front()); + for (auto& residue : residue_ends_) { + while (keep_it != keep_indices.end() && *keep_it < residue) { + ++keep_it; + } + residue = std::distance(keep_indices.begin(), keep_it); + } + // Unique residue_ends_ with updating chains. 
+ auto first = residue_ends_.begin(); + auto tail = first; + std::size_t num_skipped = 0; + std::size_t current = 0; + for (std::size_t& chain_end : chain_ends_) { + for (auto e = residue_ends_.begin() + chain_end; first != e; ++first) { + std::size_t next = *first; + *tail = next; + if (current != next) { + current = next; + ++tail; + } else { + ++num_skipped; + } + } + chain_end -= num_skipped; + } + residue_ends_.erase(tail, residue_ends_.end()); + + current = 0; + chain_ends_.erase(std::remove_if(chain_ends_.begin(), chain_ends_.end(), + [¤t](std::size_t next) { + bool result = current == next; + current = next; + return result; + }), + chain_ends_.end()); + model_offset_ = 0; +} + +absl::StatusOr MmcifLayout::Create(const CifDict& mmcif, + absl::string_view model_id) { + auto model_ids = mmcif["_atom_site.pdbx_PDB_model_num"]; + auto chain_ids = mmcif["_atom_site.label_asym_id"]; // chain ID. + auto label_seq_ids = mmcif["_atom_site.label_seq_id"]; // residue ID. + auto auth_seq_ids = mmcif["_atom_site.auth_seq_id"]; // author residue ID. + auto insertion_codes = mmcif["_atom_site.pdbx_PDB_ins_code"]; + + if (model_ids.size() != chain_ids.size() || + model_ids.size() != label_seq_ids.size() || + (model_ids.size() != auth_seq_ids.size() && !auth_seq_ids.empty()) || + (model_ids.size() != insertion_codes.size() && + !insertion_codes.empty())) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid _atom_site table.", // + " len(_atom_site.pdbx_PDB_model_num): ", model_ids.size(), + " len(_atom_site.label_asym_id): ", chain_ids.size(), + " len(_atom_site.label_seq_id): ", label_seq_ids.size(), + " len(_atom_site.auth_seq_id): ", auth_seq_ids.size(), + " len(_atom_site.pdbx_PDB_ins_code): ", insertion_codes.size())); + } + std::size_t num_atoms = model_ids.size(); + if (num_atoms == 0) { + return MmcifLayout({}, {}, 0, 0); + } + std::size_t model_offset = 0; + std::size_t num_models; + std::size_t num_atoms_per_model; + if (model_id.empty()) { + absl::string_view first_model_id = model_ids.front(); + + // Binary search for where the first model ends. + num_atoms_per_model = std::distance( + model_ids.begin(), + absl::c_upper_bound(model_ids, first_model_id, std::not_equal_to<>{})); + if (num_atoms % num_atoms_per_model != 0) { + return absl::InvalidArgumentError(absl::StrCat( + "Each model must have the same number of atoms: (", num_atoms, " % ", + num_atoms_per_model, " == ", num_atoms % num_atoms_per_model, ").")); + } + num_models = num_atoms / num_atoms_per_model; + // Test boundary conditions for each model hold. 
+ for (std::size_t i = 1; i < num_models; ++i) { + if ((model_ids[i * num_atoms_per_model] != + model_ids[(i + 1) * num_atoms_per_model - 1]) || + (model_ids[i * num_atoms_per_model - 1] == + model_ids[i * num_atoms_per_model])) { + return absl::InvalidArgumentError( + absl::StrCat("Each model must have the same number of atoms: (", + num_atoms, " % ", num_atoms_per_model, + " == ", num_atoms % num_atoms_per_model, ").")); + } + } + } else { + num_models = 1; + model_offset = + std::distance(model_ids.begin(), absl::c_find(model_ids, model_id)); + if (model_offset == model_ids.size()) { + return absl::InvalidArgumentError( + absl::StrCat("Unknown model_id: ", model_id)); + } + model_ids.remove_prefix(model_offset); + chain_ids.remove_prefix(model_offset); + label_seq_ids.remove_prefix(model_offset); + if (!auth_seq_ids.empty()) auth_seq_ids.remove_prefix(model_offset); + if (!insertion_codes.empty()) insertion_codes.remove_prefix(model_offset); + + num_atoms_per_model = std::distance( + model_ids.begin(), std::upper_bound(model_ids.begin(), model_ids.end(), + model_id, std::not_equal_to<>{})); + num_atoms = num_atoms_per_model; + } + std::vector residues; + std::vector chains; + absl::string_view chain_id = chain_ids.front(); + if (!auth_seq_ids.empty() && !insertion_codes.empty()) { + // If author residue IDs are present then these are preferred to + // label residue IDs because they work for multi-residue ligands (which + // are given constant "." label residue IDs). + // NB: Author residue IDs require both the auth_seq_id and the insertion + // code to be unique. + absl::string_view auth_seq_id = auth_seq_ids.front(); + absl::string_view insertion_code = insertion_codes.front(); + for (std::size_t i = 1; i < num_atoms_per_model; ++i) { + if (absl::string_view current_chain_id = chain_ids[i]; + current_chain_id != chain_id) { + residues.push_back(i + model_offset); + chains.push_back(residues.size()); + chain_id = current_chain_id; + auth_seq_id = auth_seq_ids[i]; + insertion_code = insertion_codes[i]; + } else if (absl::string_view current_seq_id = auth_seq_ids[i], + current_insertion_code = insertion_codes[i]; + insertion_code != current_insertion_code || + auth_seq_id != current_seq_id) { + residues.push_back(i + model_offset); + auth_seq_id = current_seq_id; + insertion_code = current_insertion_code; + } + } + } else { + absl::string_view label_seq_id = label_seq_ids.front(); + for (std::size_t i = 1; i < num_atoms_per_model; ++i) { + if (absl::string_view current_chain_id = chain_ids[i]; + current_chain_id != chain_id) { + residues.push_back(i + model_offset); + chains.push_back(residues.size()); + chain_id = current_chain_id; + label_seq_id = label_seq_ids[i]; + } else if (absl::string_view current_seq_id = label_seq_ids[i]; + label_seq_id != current_seq_id) { + residues.push_back(i + model_offset); + label_seq_id = current_seq_id; + } + } + } + residues.push_back(num_atoms_per_model + model_offset); + chains.push_back(residues.size()); + return MmcifLayout(std::move(chains), std::move(residues), model_offset, + num_models); +} + +std::vector MmcifLayout::chain_starts() const { + std::vector chain_starts; + chain_starts.reserve(chain_ends_.size()); + for (std::size_t index = 0; index < chain_ends_.size(); ++index) { + chain_starts.push_back(atom_site_from_chain_index(index)); + } + return chain_starts; +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_pybind.cc 
b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_pybind.cc new file mode 100644 index 0000000000000000000000000000000000000000..8eb69befc0e93baf084333b055e1466e5902f35e --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_pybind.cc @@ -0,0 +1,49 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include "alphafold3/structure/cpp/mmcif_layout.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" + +namespace alphafold3 { + +namespace py = pybind11; + +void RegisterModuleMmcifLayout(pybind11::module m) { + py::class_(m, "MmcifLayout") + .def("__str__", &MmcifLayout::ToDebugString) + .def("num_models", &MmcifLayout::num_models) + .def("num_chains", &MmcifLayout::num_chains) + .def("num_residues", &MmcifLayout::num_residues) + .def("num_atoms", &MmcifLayout::num_atoms) + .def("residue_range", &MmcifLayout::residue_range, py::arg("chain_index")) + .def("atom_range", &MmcifLayout::atom_range, py::arg("residue_index")) + .def("chains", &MmcifLayout::chains, + py::doc("Returns a list of indices one past the last residue of " + "each chain.")) + .def( + "chain_starts", &MmcifLayout::chain_starts, + py::doc("Returns a list of indices of the first atom of each chain.")) + .def("residues", &MmcifLayout::residues, + py::doc("Returns a list of indices one past the last atom of each " + "residue.")) + .def("residue_starts", &MmcifLayout::residue_starts, + py::doc( + "Returns a list of indices of the first atom of each residue.")) + .def("model_offset", &MmcifLayout::model_offset, + py::doc("Returns the first atom index that is part of the specified " + "model.")); + + m.def("from_mmcif", &MmcifLayout::Create, py::arg("mmcif"), + py::arg("model_id") = ""); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..c79b2dd50e0e48bc6065ae3e5468b1142126e659 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_pybind.h @@ -0,0 +1,24 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. 
Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_LAYOUT_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_LAYOUT_PYBIND_H_ + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleMmcifLayout(pybind11::module m); + +} + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_LAYOUT_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn.h new file mode 100644 index 0000000000000000000000000000000000000000..821be658da838e72bd85bb884669414224a7a7f0 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn.h @@ -0,0 +1,34 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_STRUCT_CONN_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_STRUCT_CONN_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "alphafold3/parsers/cpp/cif_dict_lib.h" + +namespace alphafold3 { + +// Returns a pair of atom indices for each row in the bonds table (aka +// _struct_conn). The indices are simple 0-based indexes into the columns of +// the _atom_site table in the input mmCIF, and do not necessarily correspond +// to the values in _atom_site.id, or any other column. +absl::StatusOr, std::vector>> +GetBondAtomIndices(const CifDict& mmcif, absl::string_view model_id); + +} // namespace alphafold3 + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_STRUCT_CONN_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn.pyi new file mode 100644 index 0000000000000000000000000000000000000000..d293e666a3aab8dba274d4380c2ab24018a66f35 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn.pyi @@ -0,0 +1,13 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from alphafold3.cpp import cif_dict + +def get_bond_atom_indices(mmcif_dict: cif_dict.CifDict, model_id: str) -> tuple[list[int],list[int]]: ... 
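A hedged sketch of how `get_bond_atom_indices` would be used from Python, following the stub above. As before, `mmcif` is assumed to be an already-parsed `cif_dict.CifDict` and the import path is an assumption based on the stub's location; the returned values are plain 0-based row indices into the `_atom_site` table for the requested model, not `_atom_site.id` values.

```python
# Hedged usage sketch; `mmcif` is an already-parsed cif_dict.CifDict.
from alphafold3.structure.cpp import mmcif_struct_conn  # assumed import path

from_atoms, to_atoms = mmcif_struct_conn.get_bond_atom_indices(mmcif, model_id="1")
for bond_row, (i, j) in enumerate(zip(from_atoms, to_atoms)):
    # i and j are 0-based row indices into _atom_site, not _atom_site.id values.
    print(f"_struct_conn row {bond_row}: _atom_site rows {i} <-> {j}")
```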
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_lib.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_lib.cc new file mode 100644 index 0000000000000000000000000000000000000000..afb930fab350d1c39098c1f73f4deeb968b22a02 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_lib.cc @@ -0,0 +1,380 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "alphafold3/parsers/cpp/cif_dict_lib.h" +#include "alphafold3/structure/cpp/mmcif_struct_conn.h" + +namespace alphafold3 { + +namespace { + +struct AtomId { + absl::string_view chain_id; + absl::string_view res_id_1; + absl::string_view res_id_2; + absl::string_view atom_name; + absl::string_view alt_id; + + friend bool operator==(const AtomId&, const AtomId&) = default; + template + friend H AbslHashValue(H h, const AtomId& m) { + return H::combine(std::move(h), m.chain_id, m.res_id_1, m.res_id_2, + m.atom_name, m.alt_id); + } +}; + +using StringArrayRef = absl::Span; +using BondIndexByAtom = absl::flat_hash_map>; +using BondAtomIndices = std::vector; + +// Returns whether each container is the same size. +template +bool AreSameSize(const C& c, const Cs&... cs) { + return ((c.size() == cs.size()) && ...); +} + +struct ColumnSpec { + absl::string_view chain_id_col; + absl::string_view res_id_1_col; + absl::string_view res_id_2_col; + absl::string_view atom_name_col; + std::optional alt_id_col; // Not used by OpenMM. +}; + +class AtomColumns { + public: + static absl::StatusOr Create(const CifDict& mmcif, + const ColumnSpec& column_spec) { + StringArrayRef chain_id = mmcif[column_spec.chain_id_col]; + StringArrayRef res_id_1 = mmcif[column_spec.res_id_1_col]; + StringArrayRef res_id_2 = mmcif[column_spec.res_id_2_col]; + StringArrayRef atom_name = mmcif[column_spec.atom_name_col]; + if (!AreSameSize(chain_id, res_id_1, res_id_2, atom_name)) { + return absl::InvalidArgumentError(absl::StrCat( + "Atom columns are not the same size. ", // + "len(", column_spec.chain_id_col, ")=", chain_id.size(), // + ", len(", column_spec.res_id_1_col, ")=", res_id_1.size(), // + ", len(", column_spec.res_id_2_col, ")=", res_id_2.size(), // + ", len(", column_spec.atom_name_col, ")=", atom_name.size(), // + ".")); + } + if (column_spec.alt_id_col.has_value()) { + StringArrayRef alt_id = mmcif[*column_spec.alt_id_col]; + if (!AreSameSize(alt_id, chain_id)) { + return absl::InvalidArgumentError(absl::StrCat( + "Atom columns are not the same size. 
", // + "len(", column_spec.chain_id_col, ")=", chain_id.size(), // + ", len(", *column_spec.alt_id_col, ")=", alt_id.size(), // + ".")); + } + return AtomColumns(chain_id, res_id_1, res_id_2, atom_name, alt_id, + column_spec); + } else { + return AtomColumns(chain_id, res_id_1, res_id_2, atom_name, std::nullopt, + column_spec); + } + } + + inline std::size_t size() const { return size_; } + + absl::string_view GetNormalizedAltId(const std::size_t index) const { + constexpr absl::string_view kFullStop = "."; + if (alt_id_.has_value()) { + absl::string_view alt_id = (*alt_id_)[index]; + return alt_id == "?" ? kFullStop : alt_id; + } else { + return kFullStop; + } + } + + AtomId GetAtom(const std::size_t index) const { + return {.chain_id = chain_id_[index], + .res_id_1 = res_id_1_[index], + .res_id_2 = res_id_2_[index], + .atom_name = atom_name_[index], + .alt_id = GetNormalizedAltId(index)}; + } + + std::string GetAtomString(const std::size_t index) const { + std::string alt_id_col; + if (column_spec_.alt_id_col.has_value()) { + alt_id_col = *column_spec_.alt_id_col; + } else { + alt_id_col = "default label_alt_id"; + } + return absl::StrCat( + column_spec_.chain_id_col, "=", chain_id_[index], ", ", // + column_spec_.res_id_1_col, "=", res_id_1_[index], ", ", // + column_spec_.res_id_2_col, "=", res_id_2_[index], ", ", // + column_spec_.atom_name_col, "=", atom_name_[index], ", ", // + alt_id_col, "=", GetNormalizedAltId(index)); // + } + + private: + AtomColumns(StringArrayRef chain_id, StringArrayRef res_id_1, + StringArrayRef res_id_2, StringArrayRef atom_name, + std::optional alt_id, + const ColumnSpec& column_spec) + : chain_id_(chain_id), + res_id_1_(res_id_1), + res_id_2_(res_id_2), + atom_name_(atom_name), + alt_id_(alt_id), + column_spec_(column_spec), + size_(chain_id.size()) {} + StringArrayRef chain_id_; + StringArrayRef res_id_1_; + StringArrayRef res_id_2_; + StringArrayRef atom_name_; + std::optional alt_id_; + ColumnSpec column_spec_; + std::size_t size_; +}; + +// Adds the atom index to any rows in the bond table involving that atom. +absl::Status FillInBondsForAtom(const BondIndexByAtom& bond_index_by_atom, + const AtomId& atom, + const std::size_t atom_index, + BondAtomIndices& bond_atom_indices) { + if (auto bond_index_it = bond_index_by_atom.find(atom); + bond_index_it != bond_index_by_atom.end()) { + for (std::size_t bond_index : bond_index_it->second) { + if (bond_index < 0 || bond_index >= bond_atom_indices.size()) { + return absl::OutOfRangeError( + absl::StrCat("Bond index out of range: ", bond_index)); + } + bond_atom_indices[bond_index] = atom_index; + } + } + return absl::OkStatus(); +} + +// Checks that the CifDict has all of the columns in the column spec. +bool HasAllColumns(const CifDict& mmcif, const ColumnSpec& columns) { + return mmcif.Contains(columns.chain_id_col) && + mmcif.Contains(columns.res_id_1_col) && + mmcif.Contains(columns.res_id_2_col) && + mmcif.Contains(columns.atom_name_col) && + (!columns.alt_id_col.has_value() || + mmcif.Contains(*columns.alt_id_col)); +} + +// Fully specified ptnr1 atom. +constexpr ColumnSpec kStructConnPtnr1ColumnsFull{ + .chain_id_col = "_struct_conn.ptnr1_label_asym_id", + .res_id_1_col = "_struct_conn.ptnr1_auth_seq_id", + .res_id_2_col = "_struct_conn.pdbx_ptnr1_PDB_ins_code", + .atom_name_col = "_struct_conn.ptnr1_label_atom_id", + .alt_id_col = "_struct_conn.pdbx_ptnr1_label_alt_id", +}; + +// Fully specified ptnr2 atom. 
+constexpr ColumnSpec kStructConnPtnr2ColumnsFull{ + .chain_id_col = "_struct_conn.ptnr2_label_asym_id", + .res_id_1_col = "_struct_conn.ptnr2_auth_seq_id", + .res_id_2_col = "_struct_conn.pdbx_ptnr2_PDB_ins_code", + .atom_name_col = "_struct_conn.ptnr2_label_atom_id", + .alt_id_col = "_struct_conn.pdbx_ptnr2_label_alt_id", +}; + +// Columns used by OpenMM for ptnr1 atoms. +constexpr ColumnSpec kStructConnPtnr1OpenMM{ + .chain_id_col = "_struct_conn.ptnr1_label_asym_id", + .res_id_1_col = "_struct_conn.ptnr1_label_seq_id", + .res_id_2_col = "_struct_conn.ptnr1_label_comp_id", + .atom_name_col = "_struct_conn.ptnr1_label_atom_id", + .alt_id_col = std::nullopt, +}; + +// Columns used by OpenMM for ptnr2 atoms. +constexpr ColumnSpec kStructConnPtnr2OpenMM{ + .chain_id_col = "_struct_conn.ptnr2_label_asym_id", + .res_id_1_col = "_struct_conn.ptnr2_label_seq_id", + .res_id_2_col = "_struct_conn.ptnr2_label_comp_id", + .atom_name_col = "_struct_conn.ptnr2_label_atom_id", + .alt_id_col = std::nullopt, +}; + +// Fully specified atom sites. +constexpr ColumnSpec kAtomSiteColumnsFull{ + .chain_id_col = "_atom_site.label_asym_id", + .res_id_1_col = "_atom_site.auth_seq_id", + .res_id_2_col = "_atom_site.pdbx_PDB_ins_code", + .atom_name_col = "_atom_site.label_atom_id", + .alt_id_col = "_atom_site.label_alt_id", +}; + +// Atom site columns used to match OpenMM _struct_conn tables. +constexpr ColumnSpec kAtomSiteColumnsOpenMM{ + .chain_id_col = "_atom_site.label_asym_id", + .res_id_1_col = "_atom_site.label_seq_id", + .res_id_2_col = "_atom_site.label_comp_id", + .atom_name_col = "_atom_site.label_atom_id", + .alt_id_col = "_atom_site.label_alt_id", +}; + +} // namespace + +absl::StatusOr> GetBondAtomIndices( + const CifDict& mmcif, absl::string_view model_id) { + ColumnSpec ptnr1_columns, ptnr2_columns, atom_site_columns; + + if (HasAllColumns(mmcif, kStructConnPtnr1ColumnsFull) && + HasAllColumns(mmcif, kStructConnPtnr2ColumnsFull)) { + ptnr1_columns = kStructConnPtnr1ColumnsFull; + ptnr2_columns = kStructConnPtnr2ColumnsFull; + atom_site_columns = kAtomSiteColumnsFull; + } else { + ptnr1_columns = kStructConnPtnr1OpenMM; + ptnr2_columns = kStructConnPtnr2OpenMM; + atom_site_columns = kAtomSiteColumnsOpenMM; + } + + absl::StatusOr ptnr1_atoms = + AtomColumns::Create(mmcif, ptnr1_columns); + if (!ptnr1_atoms.ok()) { + return ptnr1_atoms.status(); + } + absl::StatusOr ptnr2_atoms = + AtomColumns::Create(mmcif, ptnr2_columns); + if (!ptnr2_atoms.ok()) { + return ptnr2_atoms.status(); + } + StringArrayRef struct_conn_id = mmcif["_struct_conn.id"]; + if (!AreSameSize(struct_conn_id, *ptnr1_atoms, *ptnr2_atoms)) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid '_struct_conn.' loop. ", // + "len(id) = ", struct_conn_id.size(), ", ", // + "len(ptnr1_atoms) = ", ptnr1_atoms->size(), ", ", // + "len(ptnr2_atoms) = ", ptnr2_atoms->size(), "." // + )); + } + + absl::StatusOr atoms = + AtomColumns::Create(mmcif, atom_site_columns); + if (!atoms.ok()) { + return atoms.status(); + } + StringArrayRef atom_site_id = mmcif["_atom_site.id"]; + StringArrayRef atom_site_model_id = mmcif["_atom_site.pdbx_PDB_model_num"]; + if (!AreSameSize(atom_site_id, atom_site_model_id, *atoms)) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid '_atom_site.' loop. 
", // + "len(id)= ", atom_site_id.size(), ", ", // + "len(pdbx_PDB_model_num)= ", atom_site_model_id.size(), ", ", // + "len(atoms)= ", atoms->size(), ".")); // + } + + // Build maps from atom ID tuples to the rows in _struct_conn where that + // atom appears (NB could be multiple). + const std::size_t struct_conn_size = struct_conn_id.size(); + BondIndexByAtom ptnr1_rows_by_atom(struct_conn_size); + BondIndexByAtom ptnr2_rows_by_atom(struct_conn_size); + for (std::size_t i = 0; i < struct_conn_size; ++i) { + ptnr1_rows_by_atom[ptnr1_atoms->GetAtom(i)].push_back(i); + ptnr2_rows_by_atom[ptnr2_atoms->GetAtom(i)].push_back(i); + } + + // Allocate two output arrays with one element per row in struct_conn, where + // each element will be the index of that atom in the atom_site table. + // Fill the arrays with atom_site_size, which is an invalid value, so that + // we can check at the end that each atom has been found. + const std::size_t atom_site_size = atom_site_id.size(); + BondAtomIndices ptnr1_atom_indices(struct_conn_size, atom_site_size); + BondAtomIndices ptnr2_atom_indices(struct_conn_size, atom_site_size); + + bool model_id_ecountered = false; + absl::flat_hash_set seen_alt_ids; + for (std::size_t atom_i = 0; atom_i < atom_site_size; ++atom_i) { + if (atom_site_model_id[atom_i] != model_id) { + if (!model_id_ecountered) { + continue; + } else { + // Models are contiguous so once we see a different model ID after + // encountering our model ID then we can exit early. + break; + } + } else { + model_id_ecountered = true; + } + AtomId atom = atoms->GetAtom(atom_i); + seen_alt_ids.insert(atom.alt_id); + + if (auto fill_in_bonds_status1 = FillInBondsForAtom( + ptnr1_rows_by_atom, atom, atom_i, ptnr1_atom_indices); + !fill_in_bonds_status1.ok()) { + return fill_in_bonds_status1; + } + if (auto fill_in_bonds_status2 = FillInBondsForAtom( + ptnr2_rows_by_atom, atom, atom_i, ptnr2_atom_indices); + !fill_in_bonds_status2.ok()) { + return fill_in_bonds_status2; + } + } + // The seen_alt_ids check is a workaround for a known PDB issue: some mmCIFs + // (2evw, 2g0v, 2g0x, 2g0z, 2g10, 2g11, 2g12, 2g14, 2grz, 2ntw as of 2024) + // have multiple models and they set different whole-chain altloc in each + // model. The bond table however doesn't distinguish between models, so there + // are bonds that are valid only for some models. E.g. 2grz has model 1 with + // chain A with altloc A, and model 2 with chain A with altloc B. The bonds + // table lists a bond for each of these. + + // Check that a ptnr1 atom was found for every bond. + if (auto row_it = absl::c_find(ptnr1_atom_indices, atom_site_size); + row_it != ptnr1_atom_indices.end()) { + if (seen_alt_ids.size() > 1 || seen_alt_ids.contains(".") || + seen_alt_ids.contains("?")) { + std::size_t i = std::distance(ptnr1_atom_indices.begin(), row_it); + return absl::InvalidArgumentError( + absl::StrCat("Error parsing \"", mmcif.GetDataName(), "\". ", + "Cannot find atom for bond ID ", struct_conn_id[i], ": ", + ptnr1_atoms->GetAtomString(i))); + } + } + + // Check that a ptnr2 atom was found for every bond. + if (auto row_it = absl::c_find(ptnr2_atom_indices, atom_site_size); + row_it != ptnr2_atom_indices.end()) { + if (seen_alt_ids.size() > 1 || seen_alt_ids.contains(".") || + seen_alt_ids.contains("?")) { + std::size_t i = std::distance(ptnr2_atom_indices.begin(), row_it); + return absl::InvalidArgumentError( + absl::StrCat("Error parsing \"", mmcif.GetDataName(), "\". 
", + "Cannot find atom for bond ID ", struct_conn_id[i], ": ", + ptnr2_atoms->GetAtomString(i))); + } + } + + if (!model_id_ecountered) { + return absl::InvalidArgumentError(absl::StrCat( + "Error parsing \"", mmcif.GetDataName(), "\". model_id \"", model_id, + "\" not found in _atom_site.pdbx_PDB_model_num.")); + } + + return std::make_pair(std::move(ptnr1_atom_indices), + std::move(ptnr2_atom_indices)); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_pybind.cc new file mode 100644 index 0000000000000000000000000000000000000000..111715ab5b5d0bce3ea735b85302be3e5e852beb --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_pybind.cc @@ -0,0 +1,68 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include + +#include "absl/strings/string_view.h" +#include "alphafold3/parsers/cpp/cif_dict_lib.h" +#include "alphafold3/structure/cpp/mmcif_struct_conn.h" +#include "pybind11/gil.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" + +namespace alphafold3 { + +namespace py = pybind11; + +constexpr char kGetBondAtomIndices[] = R"( +Extracts the indices of the atoms that participate in bonds. + +This function has a workaround for a known PDB issue: some mmCIFs have +(2evw, 2g0v, 2g0x, 2g0z, 2g10, 2g11, 2g12, 2g14, 2grz, 2ntw as of 2024) +multiple models and they set different whole-chain altloc in each model. +The bond table however doesn't distinguish between models, so there are +bonds that are valid only for some models. E.g. 2grz has model 1 with +chain A with altloc A, and model 2 with chain A with altloc B. The bonds +table lists a bond for each of these. This case is rather rare (10 cases +in PDB as of 2024). For the offending bonds, the returned atom index is +set to the size of the atom_site table, i.e. it is an invalid index. + +Args: + mmcif: The mmCIF object to process. + model_id: The ID of the model that the returned atoms will belong to. This + should be a value in the mmCIF's _atom_site.pdbx_PDB_model_num column. + +Returns: + Two lists of atom indices, `from_atoms` and `to_atoms`, each one having + length num_bonds (as defined by _struct_conn, the bonds table). The bond + i, defined by the i'th row in _struct_conn, is a bond from atom at index + from_atoms[i], to the atom at index to_atoms[i]. The indices are simple + 0-based indexes into the columns of the _atom_site table in the input + mmCIF, and do not necessarily correspond to the values in _atom_site.id, + or any other column. 
+)"; + +void RegisterModuleMmcifStructConn(pybind11::module m) { + m.def( + "get_bond_atom_indices", + [](const CifDict& mmcif, absl::string_view model_id) { + auto result = GetBondAtomIndices(mmcif, model_id); + if (result.ok()) { + return *result; + } + throw py::value_error(std::string(result.status().message())); + }, + py::arg("mmcif_dict"), py::arg("model_id"), + py::doc(kGetBondAtomIndices + 1), + py::call_guard()); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..acdbf7b773ba65c779bb88dffb7dc8b69ca8ee60 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_pybind.h @@ -0,0 +1,24 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_STRUCT_CONN_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_STRUCT_CONN_PYBIND_H_ + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleMmcifStructConn(pybind11::module m); + +} + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_STRUCT_CONN_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils.pyi new file mode 100644 index 0000000000000000000000000000000000000000..aa2dc23e90af680004747226cf88578693920177 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils.pyi @@ -0,0 +1,71 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from collections.abc import Sequence + +import numpy as np + +from alphafold3.cpp import cif_dict +from alphafold3.structure.python import mmcif_layout + + +def filter( + mmcif: cif_dict.CifDict, + include_nucleotides: bool, + include_ligands: bool = ..., + include_water: bool = ..., + include_other: bool = ..., + model_id: str = ..., +) -> tuple[np.ndarray[int], mmcif_layout.MmcifLayout]: ... + + +def fix_residues( + layout: mmcif_layout.MmcifLayout, + comp_id: Sequence[str], + atom_id: Sequence[str], + atom_x: Sequence[float], + atom_y: Sequence[float], + atom_z: Sequence[float], + fix_arg: bool = ..., +) -> None: ... + + +def read_layout( + mmcif: cif_dict.CifDict, model_id: str = ... +) -> mmcif_layout.MmcifLayout: ... 
+ + +def selected_ligand_residue_mask( + layout: mmcif_layout.MmcifLayout, + atom_site_label_asym_ids: list[str], + atom_site_label_seq_ids: list[str], + atom_site_auth_seq_ids: list[str], + atom_site_label_comp_ids: list[str], + atom_site_pdbx_pdb_ins_codes: list[str], + nonpoly_asym_ids: list[str], + nonpoly_auth_seq_ids: list[str], + nonpoly_pdb_ins_codes: list[str], + nonpoly_mon_ids: list[str], + branch_asym_ids: list[str], + branch_auth_seq_ids: list[str], + branch_pdb_ins_codes: list[str], + branch_mon_ids: list[str], +) -> tuple[list[bool], list[bool]]: ... + + +def selected_polymer_residue_mask( + layout: mmcif_layout.MmcifLayout, + atom_site_label_asym_ids: list[str], + atom_site_label_seq_ids: list[str], + atom_site_label_comp_ids: list[str], + poly_seq_asym_ids: list[str], + poly_seq_seq_ids: list[str], + poly_seq_mon_ids: list[str], +) -> list[bool]: ... diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.cc new file mode 100644 index 0000000000000000000000000000000000000000..52bd039b2984e6d8c599124bfc0c4b201c0a7041 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.cc @@ -0,0 +1,787 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "numpy/ndarrayobject.h" +#include "numpy/ndarraytypes.h" +#include "numpy/npy_common.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "alphafold3/parsers/cpp/cif_dict_lib.h" +#include "alphafold3/structure/cpp/mmcif_altlocs.h" +#include "alphafold3/structure/cpp/mmcif_layout.h" +#include "pybind11/cast.h" +#include "pybind11/gil.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" +#include "pybind11_abseil/absl_casters.h" + +namespace alphafold3 { +namespace { +namespace py = pybind11; + +struct PyObjectDeleter { + inline void operator()(PyObject* obj) const { Py_CLEAR(obj); } +}; + +using ScopedPyObject = std::unique_ptr; + +using StringArrayRef = absl::Span; +using Indexer = absl::flat_hash_map; + +// Returns the reverse look-up map of name to index. +Indexer MakeIndex(StringArrayRef col) { + Indexer index; + index.reserve(col.size()); + for (std::size_t i = 0; i < col.size(); ++i) { + index[col[i]] = i; + } + return index; +} + +// Returns whether each container is the same size. +template +bool AreSameSize(C c, const Cs&... cs) { + return ((c.size() == cs.size()) && ...); +} + +// Stores references to columns in `_atom_site` ensuring they all exist and +// are the same size. 
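These parallel `_atom_site` columns are consumed alongside the layout computed earlier. As a hedged illustration of that pairing, the Python sketch below walks chains, then residues, then atoms using only the `MmcifLayout` methods declared in `mmcif_layout.pyi`; `layout` is assumed to have been obtained from the same mmCIF (e.g. via `read_layout`), and the printed indices are half-open row ranges into `_atom_site`.

```python
# Hedged sketch: walk chains -> residues -> atoms with the MmcifLayout API
# declared in mmcif_layout.pyi. `layout` is assumed to come from
# mmcif_utils.read_layout(...) on the same mmCIF.
for chain_index in range(layout.num_chains()):
    res_start, res_end = layout.residue_range(chain_index)
    for residue_index in range(res_start, res_end):
        atom_start, atom_end = layout.atom_range(residue_index)
        # atom_start/atom_end are half-open row indices into _atom_site.
        print(chain_index, residue_index, atom_start, atom_end)
```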
+struct AtomSiteLoop { + explicit AtomSiteLoop(const CifDict& cif_dict) + : id(cif_dict["_atom_site.id"]), + model_id(cif_dict["_atom_site.pdbx_PDB_model_num"]), + chain_id(cif_dict["_atom_site.label_asym_id"]), + seq_id(cif_dict["_atom_site.label_seq_id"]), + + comp_id(cif_dict["_atom_site.label_comp_id"]), + atom_id(cif_dict["_atom_site.label_atom_id"]), + + alt_id(cif_dict["_atom_site.label_alt_id"]), + occupancy(cif_dict["_atom_site.occupancy"]) + + { + if (!AreSameSize(id, model_id, chain_id, seq_id, comp_id, atom_id, alt_id, + occupancy)) { + throw py::value_error( + absl::StrCat("Invalid '_atom_site.' loop. ", // + "len(id)=", id.size(), ", ", // + "len(pdbx_PDB_model_num)=", model_id.size(), ", ", // + "len(label_asym_id)=", chain_id.size(), ", ", // + "len(label_seq_id)=", seq_id.size(), ", ", // + "len(label_comp_id)=", comp_id.size(), ", ", // + "len(atom_id)=", atom_id.size(), ", ", // + "len(label_alt_id)=", alt_id.size(), ", ", // + "len(occupancy)=", occupancy.size())); + } + } + StringArrayRef id; + StringArrayRef model_id; + StringArrayRef chain_id; + StringArrayRef seq_id; + StringArrayRef comp_id; + StringArrayRef atom_id; + StringArrayRef alt_id; + StringArrayRef occupancy; +}; + +// Stores references to columns in `_entity` ensuring they all exist and are the +// same size. +struct EntityLoop { + explicit EntityLoop(const CifDict& cif_dict) + : id(cif_dict["_entity.id"]), type(cif_dict["_entity.type"]) { + if (!AreSameSize(id, type)) { + throw py::value_error(absl::StrCat("Invalid '_entity.' loop. ", // + "len(id)=", id.size(), ", ", // + "len(type)=", type.size())); + } + } + StringArrayRef id; + StringArrayRef type; +}; + +// Stores references to columns in `_entity_poly` ensuring they all exist and +// are the same size. +struct EntityPolyLoop { + explicit EntityPolyLoop(const CifDict& cif_dict) + : entity_id(cif_dict["_entity_poly.entity_id"]), + type(cif_dict["_entity_poly.type"]) { + if (!AreSameSize(entity_id, type)) { + throw py::value_error(absl::StrCat("Invalid '_entity_poly.' loop. ", // + "len(entity_id)=", entity_id.size(), + ", ", // + "len(type)=", type.size())); + } + } + StringArrayRef entity_id; + StringArrayRef type; +}; + +// Returns a set of entity names removing ones not included by the flags +// specified. 
+absl::flat_hash_set SelectChains(const CifDict& mmcif, + bool include_nucleotides, + bool include_ligands, + bool include_water, + bool include_other) { + EntityLoop entity_loop(mmcif); + EntityPolyLoop entity_poly(mmcif); + absl::flat_hash_set permitted_polymers{"polypeptide(L)"}; + absl::flat_hash_set forbidden_polymers; + for (absl::string_view type : + {"polydeoxyribonucleotide", "polyribonucleotide", + "polydeoxyribonucleotide/polyribonucleotide hybrid"}) { + if (include_nucleotides) { + permitted_polymers.emplace(type); + } else { + forbidden_polymers.emplace(type); + } + } + + absl::flat_hash_set permitted_nonpoly_entity_types; + absl::flat_hash_set forbidden_nonpoly_entity_types; + for (absl::string_view type : {"non-polymer", "branched"}) { + if (include_ligands) { + permitted_nonpoly_entity_types.emplace(type); + } else { + forbidden_nonpoly_entity_types.emplace(type); + } + } + absl::string_view water_type = "water"; + if (include_water) { + permitted_nonpoly_entity_types.emplace(water_type); + } else { + forbidden_nonpoly_entity_types.emplace(water_type); + } + + StringArrayRef chain_ids = mmcif["_struct_asym.id"]; + StringArrayRef entity_ids = mmcif["_struct_asym.entity_id"]; + Indexer chain_index = MakeIndex(chain_ids); + Indexer entity_poly_index = MakeIndex(entity_poly.entity_id); + Indexer entity_id_to_index = MakeIndex(entity_loop.id); + + absl::flat_hash_set keep_chain_id; + for (std::size_t i = 0; i < chain_ids.size(); ++i) { + absl::string_view chain_id = chain_ids[i]; + absl::string_view entity_id = entity_ids[i]; + if (entity_id_to_index.empty() || + entity_loop.type[entity_id_to_index[entity_id]] == "polymer") { + if (auto it = entity_poly_index.find(entity_id); + it != entity_poly_index.end()) { + absl::string_view poly_type = entity_poly.type[it->second]; + if (include_other) { + if (!forbidden_polymers.contains(poly_type)) { + keep_chain_id.insert(chain_id); + } + } else { + if (permitted_polymers.contains(poly_type)) { + keep_chain_id.insert(chain_id); + } + } + } + } else { + absl::string_view entity_type = + entity_loop.type[entity_id_to_index[entity_id]]; + if (include_other) { + if (!forbidden_nonpoly_entity_types.contains(entity_type)) { + keep_chain_id.insert(chain_id); + continue; + } + } else { + if (permitted_nonpoly_entity_types.contains(entity_type)) { + keep_chain_id.insert(chain_id); + continue; + } + } + } + } + return keep_chain_id; +} + +class ProcessResidue { + public: + explicit ProcessResidue(const char* residue) + : residue_(PyUnicode_InternFromString(residue)) {} + bool IsResidue(PyObject* residue) { + return ArePyObjectsEqual(residue_.get(), residue); + } + + static bool ArePyObjectsEqual(PyObject* lhs, PyObject* rhs) { + switch (PyObject_RichCompareBool(lhs, rhs, Py_EQ)) { + case -1: + PyErr_Clear(); + return false; + case 0: + return false; + default: + return true; + } + } + + private: + ScopedPyObject residue_; +}; + +struct Position3 { + float x; + float y; + float z; +}; + +float DistanceSquared(Position3 v1, Position3 v2) { + float dx = v1.x - v2.x; + float dy = v1.y - v2.y; + float dz = v1.z - v2.z; + return dx * dx + dy * dy + dz * dz; +} + +class FixArginine : public ProcessResidue { + public: + FixArginine() + : ProcessResidue("ARG"), + cd_(PyUnicode_InternFromString("CD")), + nh1_(PyUnicode_InternFromString("NH1")), + nh2_(PyUnicode_InternFromString("NH2")), + hh11_(PyUnicode_InternFromString("HH11")), + hh21_(PyUnicode_InternFromString("HH21")), + hh12_(PyUnicode_InternFromString("HH12")), + 
hh22_(PyUnicode_InternFromString("HH22")) {} + void Fix(absl::Span atom_ids, absl::Span atom_x, + absl::Span atom_y, absl::Span atom_z) { + std::ptrdiff_t cd_index = -1; + std::ptrdiff_t nh1_index = -1; + std::ptrdiff_t nh2_index = -1; + std::ptrdiff_t hh11_index = -1; + std::ptrdiff_t hh21_index = -1; + std::ptrdiff_t hh12_index = -1; + std::ptrdiff_t hh22_index = -1; + for (std::ptrdiff_t index = 0; index < atom_ids.size(); ++index) { + PyObject* atom_id = atom_ids[index]; + if (cd_index == -1 && ArePyObjectsEqual(atom_id, cd_.get())) { + cd_index = index; + } else if (nh1_index == -1 && ArePyObjectsEqual(atom_id, nh1_.get())) { + nh1_index = index; + } else if (nh2_index == -1 && ArePyObjectsEqual(atom_id, nh2_.get())) { + nh2_index = index; + } else if (hh11_index == -1 && ArePyObjectsEqual(atom_id, hh11_.get())) { + hh11_index = index; + } else if (hh21_index == -1 && ArePyObjectsEqual(atom_id, hh21_.get())) { + hh21_index = index; + } else if (hh12_index == -1 && ArePyObjectsEqual(atom_id, hh12_.get())) { + hh12_index = index; + } else if (hh22_index == -1 && ArePyObjectsEqual(atom_id, hh22_.get())) { + hh22_index = index; + } + } + if (cd_index < 0 || nh1_index < 0 || nh2_index < 0) { + return; + } + Position3 cd_pos(atom_x[cd_index], atom_y[cd_index], atom_z[cd_index]); + Position3 nh1_pos(atom_x[nh1_index], atom_y[nh1_index], atom_z[nh1_index]); + Position3 nh2_pos(atom_x[nh2_index], atom_y[nh2_index], atom_z[nh2_index]); + if (DistanceSquared(nh1_pos, cd_pos) <= DistanceSquared(nh2_pos, cd_pos)) { + return; + } + std::swap(atom_ids[nh1_index], atom_ids[nh2_index]); + if (hh11_index >= 0 && hh21_index >= 0) { + std::swap(atom_ids[hh11_index], atom_ids[hh21_index]); + } else if (hh11_index >= 0) { + Py_DECREF(atom_ids[hh11_index]); + Py_INCREF(hh21_.get()); + atom_ids[hh11_index] = hh21_.get(); + } else if (hh21_index >= 0) { + Py_DECREF(atom_ids[hh21_index]); + Py_INCREF(hh11_.get()); + atom_ids[hh21_index] = hh11_.get(); + } + if (hh12_index >= 0 && hh22_index >= 0) { + std::swap(atom_ids[hh12_index], atom_ids[hh22_index]); + } else if (hh12_index >= 0) { + Py_DECREF(atom_ids[hh12_index]); + Py_INCREF(hh22_.get()); + atom_ids[hh12_index] = hh22_.get(); + } else if (hh22_index >= 0) { + Py_DECREF(atom_ids[hh22_index]); + Py_INCREF(hh21_.get()); + atom_ids[hh22_index] = hh21_.get(); + } + } + + private: + ScopedPyObject cd_; + ScopedPyObject nh1_; + ScopedPyObject nh2_; + ScopedPyObject hh11_; + ScopedPyObject hh21_; + ScopedPyObject hh12_; + ScopedPyObject hh22_; +}; + +// Returns the layout of the mmCIF `_atom_site` table. 
+inline MmcifLayout ReadMmcifLayout(const CifDict& mmcif, + absl::string_view model_id = "") { + py::gil_scoped_release release; + auto mmcif_layout = MmcifLayout::Create(mmcif, model_id); + if (mmcif_layout.ok()) { + return *mmcif_layout; + } + + throw py::value_error(std::string(mmcif_layout.status().message())); +} + +std::pair MmcifFilter( // + const CifDict& mmcif, // + bool include_nucleotides, // + bool include_ligands, // + bool include_water, // + bool include_other, // + absl::string_view model_id) { + if (_import_array() < 0) { + throw py::import_error("Failed to import NumPy."); + } + auto layout = ReadMmcifLayout(mmcif, model_id); + std::unique_ptr> keep_indices; + size_t new_num_atoms; + + { + py::gil_scoped_release release; + + AtomSiteLoop atom_site(mmcif); + + auto keep_chain_ids = + SelectChains(mmcif, include_nucleotides, include_ligands, include_water, + include_other); + + std::vector chain_indices; + chain_indices.reserve(keep_chain_ids.size()); + for (std::size_t i = 0; i < layout.num_chains(); ++i) { + if (keep_chain_ids.contains( + atom_site.chain_id[layout.atom_site_from_chain_index(i)])) { + chain_indices.push_back(i); + } + } + + keep_indices = + absl::WrapUnique(new std::vector(ResolveMmcifAltLocs( + layout, atom_site.comp_id, atom_site.atom_id, atom_site.alt_id, + atom_site.occupancy, chain_indices))); + new_num_atoms = keep_indices->size(); + + if (layout.num_models() > 1) { + keep_indices->reserve(layout.num_models() * new_num_atoms); + std::uint64_t* start = &(*keep_indices->begin()); + std::size_t num_atom = keep_indices->size(); + // Copy first model indices into all model indices offsetting each copy. + for (std::size_t i = 1; i < layout.num_models(); ++i) { + std::size_t offset = i * layout.num_atoms(); + std::transform(start, start + num_atom, + std::back_inserter(*keep_indices), + [offset](std::size_t v) { return v + offset; }); + } + } + } + + layout.Filter(*keep_indices); + + npy_intp shape[] = {static_cast(layout.num_models()), + static_cast(new_num_atoms)}; + PyObject* arr = + PyArray_SimpleNewFromData(2, shape, NPY_INT64, keep_indices->data()); + // Create a capsule to hold the memory of the buffer so NumPy knows how to + // delete it when done with it. + PyObject* capsule = PyCapsule_New( + keep_indices.release(), nullptr, +[](PyObject* capsule_cleanup) { + void* memory = PyCapsule_GetPointer(capsule_cleanup, nullptr); + delete static_cast*>(memory); + }); + PyArray_SetBaseObject(reinterpret_cast(arr), capsule); + + return std::make_pair(py::reinterpret_steal(arr), + std::move(layout)); +} + +void MmcifFixResidues( // + const MmcifLayout& layout, // + absl::Span comp_id, // + absl::Span atom_id, // + absl::Span atom_x, // + absl::Span atom_y, // + absl::Span atom_z, // + bool fix_arginine // +) { + std::optional arginine; + std::size_t num_atoms = layout.num_atoms(); + if (comp_id.size() != num_atoms || atom_id.size() != num_atoms || + atom_x.size() != num_atoms || atom_y.size() != num_atoms || + atom_z.size() != num_atoms) { + throw py::value_error( + absl::StrCat("Sizes must match. 
", // + "num_atoms=", num_atoms, ", ", // + "len(comp_id)=", comp_id.size(), ", ", // + "len(atom_id)=", atom_id.size(), ", ", // + "len(atom_x)=", atom_x.size(), ", ", // + "len(atom_y)=", atom_y.size(), ", ", // + "len(atom_z)=", atom_z.size())); + } + + if (fix_arginine) { + arginine.emplace(); + } + if (!arginine.has_value()) { + return; + } + + for (std::size_t res_index = 0; res_index < layout.num_residues(); + ++res_index) { + auto [atom_start, atom_end] = layout.atom_range(res_index); + std::size_t atom_count = atom_end - atom_start; + PyObject* resname = comp_id[atom_start]; + if (arginine.has_value() && arginine->IsResidue(resname)) { + arginine->Fix(atom_id.subspan(atom_start, atom_count), + atom_x.subspan(atom_start, atom_count), + atom_y.subspan(atom_start, atom_count), + atom_z.subspan(atom_start, atom_count)); + } + } +} + +std::vector SelectedPolymerResidueMask( + const MmcifLayout& layout, + const std::vector& atom_site_label_asym_ids, // + const std::vector& atom_site_label_seq_ids, // + const std::vector& atom_site_label_comp_ids, // + const std::vector& poly_seq_asym_ids, // + const std::vector& poly_seq_seq_ids, // + const std::vector& poly_seq_mon_ids // +) { + absl::flat_hash_map, + absl::string_view> + selected; + selected.reserve(layout.num_residues()); + // layout.residues() is O(1) while layout.residue_starts() is O(num_res). + const std::vector& residue_starts = layout.residue_starts(); + for (int i = 0; i < layout.residues().size(); ++i) { + std::size_t res_start = residue_starts[i]; + std::size_t res_end = layout.residues()[i]; + if (res_start == res_end) { + continue; // Skip empty residues (containing no atoms). + } + + absl::string_view label_seq_id = atom_site_label_seq_ids[i]; + if (label_seq_id == ".") { + continue; // Skip non-polymers. + } + + absl::string_view label_asym_id = atom_site_label_asym_ids[i]; + absl::string_view label_comp_id = atom_site_label_comp_ids[i]; + selected[std::make_pair(label_asym_id, label_seq_id)] = label_comp_id; + } + + std::vector mask; + mask.reserve(poly_seq_mon_ids.size()); + for (int i = 0; i < poly_seq_mon_ids.size(); ++i) { + absl::string_view poly_seq_asym_id = poly_seq_asym_ids[i]; + absl::string_view poly_seq_seq_id = poly_seq_seq_ids[i]; + absl::string_view poly_seq_mon_id = poly_seq_mon_ids[i]; + + auto it = selected.find(std::make_pair(poly_seq_asym_id, poly_seq_seq_id)); + if (it != selected.end()) { + mask.push_back(it->second == poly_seq_mon_id); + } else { + mask.push_back(true); // Missing residues are not heterogeneous. + } + } + return mask; +} + +std::pair, std::vector> SelectedLigandResidueMask( + const MmcifLayout& layout, // + const std::vector& atom_site_label_asym_ids, // + const std::vector& atom_site_label_seq_ids, // + const std::vector& atom_site_auth_seq_ids, // + const std::vector& atom_site_label_comp_ids, // + const std::vector& atom_site_pdbx_pdb_ins_codes, // + const std::vector& nonpoly_asym_ids, // + const std::vector& nonpoly_auth_seq_ids, // + const std::vector& nonpoly_pdb_ins_codes, // + const std::vector& nonpoly_mon_ids, // + const std::vector& branch_asym_ids, // + const std::vector& branch_auth_seq_ids, // + const std::vector& branch_pdb_ins_codes, // + const std::vector& branch_mon_ids) { + absl::flat_hash_map< + std::tuple, + absl::string_view> + selected; + selected.reserve(layout.num_residues()); + // layout.residues() is O(1) while layout.residue_starts() is O(num_res). 
+ const std::vector& residue_starts = layout.residue_starts(); + for (int i = 0; i < layout.residues().size(); ++i) { + std::size_t res_start = residue_starts[i]; + std::size_t res_end = layout.residues()[i]; + if (res_start == res_end) { + continue; // Skip empty residues (containing no atoms). + } + + absl::string_view label_seq_id = atom_site_label_seq_ids[i]; + if (label_seq_id != ".") { + continue; // Skip polymers. + } + + absl::string_view label_asym_id = atom_site_label_asym_ids[i]; + absl::string_view auth_seq_id = atom_site_auth_seq_ids[i]; + absl::string_view ins_code = atom_site_pdbx_pdb_ins_codes[i]; + ins_code = ins_code == "?" ? "." : ins_code; // Remap unknown to unset. + absl::string_view label_comp_id = atom_site_label_comp_ids[i]; + selected[std::make_tuple(label_asym_id, auth_seq_id, ins_code)] = + label_comp_id; + } + + std::vector nonpoly_mask; + nonpoly_mask.reserve(nonpoly_asym_ids.size()); + for (int i = 0; i < nonpoly_asym_ids.size(); ++i) { + absl::string_view nonpoly_asym_id = nonpoly_asym_ids[i]; + absl::string_view nonpoly_auth_seq_id = nonpoly_auth_seq_ids[i]; + absl::string_view nonpoly_ins_code = nonpoly_pdb_ins_codes[i]; + // Remap unknown to unset. + nonpoly_ins_code = nonpoly_ins_code == "?" ? "." : nonpoly_ins_code; + absl::string_view nonpoly_mon_id = nonpoly_mon_ids[i]; + + auto it = selected.find(std::make_tuple( + nonpoly_asym_id, nonpoly_auth_seq_id, nonpoly_ins_code)); + if (it != selected.end()) { + nonpoly_mask.push_back(it->second == nonpoly_mon_id); + } else { + nonpoly_mask.push_back(true); // Missing residues are not heterogeneous. + } + } + + std::vector branch_mask; + branch_mask.reserve(branch_asym_ids.size()); + for (int i = 0; i < branch_asym_ids.size(); ++i) { + absl::string_view branch_asym_id = branch_asym_ids[i]; + absl::string_view branch_auth_seq_id = branch_auth_seq_ids[i]; + + // Insertion codes in _pdbx_branch_scheme are not required and can be + // missing. Default to unset ('.') in such case. + absl::string_view branch_ins_code; + if (i < branch_pdb_ins_codes.size()) { + branch_ins_code = branch_pdb_ins_codes[i]; + // Remap unknown to unset. + branch_ins_code = branch_ins_code == "?" ? "." : branch_ins_code; + } else { + branch_ins_code = "."; + } + + absl::string_view branch_mon_id = branch_mon_ids[i]; + + auto it = selected.find( + std::make_tuple(branch_asym_id, branch_auth_seq_id, branch_ins_code)); + if (it != selected.end()) { + branch_mask.push_back(it->second == branch_mon_id); + } else { + branch_mask.push_back(true); // Missing residues are not heterogeneous. + } + } + + return std::make_pair(nonpoly_mask, branch_mask); +} + +constexpr char kReadMmcifLayout[] = R"( +Returns the layout of the cif_dict. + +Args: + mmcif: mmCIF to calculate the layout for. + model_id: If non-empty the layout of the given model is returned + otherwise the layout of all models are returned. +Raises: + ValueError: if the mmCIF is malformed or the number of atoms in each + model are inconsistent. +)"; + +constexpr char kMmcifFilter[] = R"( +Returns NumpyArray of selected rows in `_atom_site` and new layout. + +Args: + mmcif: mmCIF to filter. + include_nucleotides: Whether to include polymer entities of type: + "polypeptide(L)\", "polydeoxyribonucleotide", "polyribonucleotide". + Otherwise only "polypeptide(L)\". ("polypeptide(D)\" is never included.) + include_ligands: Whether to include non-polymer entities of type: + "non-polymer", "branched". + include_water: Whether to include entities of type water. 
+  include_other: Whether to include other (non-standard) entity types
+    that are not covered by any of the above parameters.
+  model_id: If non-empty the model with the given name is selected, otherwise
+    all models are selected.
+
+Returns:
+  A tuple containing a numpy array with shape (num_models, num_atoms)
+  holding the selected atom_site indices, and the new layout.
+
+Raises:
+  ValueError: If the mmCIF dict does not have all required fields.
+)";
+
+constexpr char kMmcifFixResidues[] = R"(
+Fixes residue columns in-place.
+
+Args:
+  layout: Layout returned by the filter command.
+  comp_id: '_atom_site.label_comp_id' of first model.
+  atom_id: '_atom_site.label_atom_id' of first model.
+  atom_x: '_atom_site.Cartn_x' of first model.
+  atom_y: '_atom_site.Cartn_y' of first model.
+  atom_z: '_atom_site.Cartn_z' of first model.
+  fix_arg: Whether to ensure the atoms in ARG are in the correct order.
+
+Raises:
+  ValueError: If shapes are invalid.
+)";
+
+constexpr char kSelectedPolymerResidueMask[] = R"(
+Returns a _pdbx_poly_seq_scheme mask for selected hetero residues.
+
+Should be called after filtering the layout using mmcif_utils.filter.
+
+Args:
+  layout: Layout defining the _atom_site residue selection.
+  atom_site_label_asym_ids: Internal (label) chain ID, per selected residue.
+  atom_site_label_seq_ids: Internal (label) residue ID, per selected residue.
+  atom_site_label_comp_ids: Residue name, per selected residue.
+  poly_seq_asym_ids: Internal (label) chain ID, per residue.
+  poly_seq_seq_ids: Internal (label) residue ID, per residue.
+  poly_seq_mon_ids: Residue name, per residue.
+
+Returns:
+  A mask for the _pdbx_poly_seq_scheme table. If residues are selected
+  using this mask, they will have consistent heterogeneous residue
+  selection with the _atom_site table.
+)";
+
+constexpr char kSelectedLigandResidueMask[] = R"(
+Returns masks for selected ligand hetero residues.
+
+Should be called after filtering the layout using mmcif_utils.filter.
+
+Args:
+  layout: Layout defining the _atom_site residue selection.
+  atom_site_label_asym_ids: Internal (label) chain ID, per selected residue.
+  atom_site_label_seq_ids: Internal (label) residue ID, per selected residue.
+  atom_site_auth_seq_ids: External (author) residue ID, per selected residue.
+  atom_site_label_comp_ids: Residue name, per selected residue.
+  atom_site_pdbx_pdb_ins_codes: Insertion code, per selected residue.
+  nonpoly_asym_ids: Internal (label) chain ID, per residue from
+    _pdbx_nonpoly_scheme.
+  nonpoly_auth_seq_ids: External (author) residue ID, per residue from
+    _pdbx_nonpoly_scheme.
+  nonpoly_pdb_ins_codes: Insertion code, per residue from
+    _pdbx_nonpoly_scheme.
+  nonpoly_mon_ids: Residue name, per residue from _pdbx_nonpoly_scheme.
+  branch_asym_ids: Internal (label) chain ID, per residue from
+    _pdbx_branch_scheme.
+  branch_auth_seq_ids: External (author) residue ID, per residue from
+    _pdbx_branch_scheme.
+  branch_pdb_ins_codes: Insertion code, per residue from _pdbx_branch_scheme.
+  branch_mon_ids: Residue name, per residue from _pdbx_branch_scheme.
+
+Returns:
+  A tuple with masks for _pdbx_nonpoly_scheme and _pdbx_branch_scheme.
If + residues are selected using these masks, they will have consistent + heterogeneous residue selection with the _atom_site table. +)"; + +} // namespace + +void RegisterModuleMmcifUtils(pybind11::module m) { + m.def("read_layout", ReadMmcifLayout, + py::arg("mmcif"), // + py::arg("model_id") = "", // + py::doc(kReadMmcifLayout + 1) // + ); + + m.def("filter", MmcifFilter, // + py::arg("mmcif"), // + py::arg("include_nucleotides"), // + py::arg("include_ligands") = false, // + py::arg("include_water") = false, // + py::arg("include_other") = false, // + py::arg("model_id") = "", // + py::doc(kMmcifFilter + 1) // + ); + + m.def("fix_residues", MmcifFixResidues, + py::arg("layout"), // + py::arg("comp_id"), // + py::arg("atom_id"), // + py::arg("atom_x"), // + py::arg("atom_y"), // + py::arg("atom_z"), // + py::arg("fix_arg") = false, // + py::doc(kMmcifFixResidues + 1) // + ); + + m.def("selected_polymer_residue_mask", SelectedPolymerResidueMask, + py::arg("layout"), // + py::arg("atom_site_label_asym_ids"), // + py::arg("atom_site_label_seq_ids"), // + py::arg("atom_site_label_comp_ids"), // + py::arg("poly_seq_asym_ids"), // + py::arg("poly_seq_seq_ids"), // + py::arg("poly_seq_mon_ids"), // + py::call_guard(), // + py::doc(kSelectedPolymerResidueMask + 1) // + ); + + m.def("selected_ligand_residue_mask", SelectedLigandResidueMask, + py::arg("layout"), // + py::arg("atom_site_label_asym_ids"), // + py::arg("atom_site_label_seq_ids"), // + py::arg("atom_site_auth_seq_ids"), // + py::arg("atom_site_label_comp_ids"), // + py::arg("atom_site_pdbx_pdb_ins_codes"), // + py::arg("nonpoly_asym_ids"), // + py::arg("nonpoly_auth_seq_ids"), // + py::arg("nonpoly_pdb_ins_codes"), // + py::arg("nonpoly_mon_ids"), // + py::arg("branch_asym_ids"), // + py::arg("branch_auth_seq_ids"), // + py::arg("branch_pdb_ins_codes"), // + py::arg("branch_mon_ids"), // + py::call_guard(), // + py::doc(kSelectedLigandResidueMask + 1) // + ); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..7ba19420b228b6ad3fb81ce55d47608437ddc45e --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.h @@ -0,0 +1,24 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. 
Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_UTILS_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_UTILS_PYBIND_H_ + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleMmcifUtils(pybind11::module m); + +} + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_UTILS_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array.pyi new file mode 100644 index 0000000000000000000000000000000000000000..b4b76c27f267ce43d2a126d9a53db580cf464772 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array.pyi @@ -0,0 +1,50 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from collections.abc import Sequence +from typing import Any, overload + +import numpy as np + + +def format_float_array( + values: Sequence[float], num_decimal_places: int +) -> list[str]: ... + + +def isin( + array: np.ndarray[object], + test_elements: set[str | bytes], + *, + invert: bool = ..., +) -> np.ndarray[bool]: ... + + +@overload +def remap( + array: np.ndarray[object], + mapping: dict[str, str], + default_value: str, + inplace: bool = ..., +) -> np.ndarray[object]: ... + + +@overload +def remap( + array: np.ndarray[object], + mapping: dict[str, str], + inplace: bool = ..., +) -> np.ndarray[object]: ... + + +def remap_multiple( + arrays: Sequence[np.ndarray[object]], + mapping: dict[tuple[Any], int], +) -> np.ndarray[int]: ... diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.cc new file mode 100644 index 0000000000000000000000000000000000000000..29fac727a3e689aa116bc82aba1d16d7a190e391 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.cc @@ -0,0 +1,329 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. 
Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "numpy/arrayobject.h" +#include "numpy/ndarrayobject.h" +#include "numpy/ndarraytypes.h" +#include "numpy/npy_common.h" +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "pybind11/cast.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11_abseil/absl_casters.h" + +namespace { + +namespace py = pybind11; + +PyObject* RemapNumpyArrayObjects(PyObject* array, PyObject* mapping, + bool inplace, PyObject* default_value) { + import_array(); + if (!PyArray_Check(array)) { + PyErr_SetString(PyExc_TypeError, "'array' must be a np.ndarray."); + return nullptr; + } + if (!PyDict_Check(mapping)) { + PyErr_SetString(PyExc_TypeError, "'mapping' must be a Python dict."); + return nullptr; + } + + PyArrayObject* array_obj = reinterpret_cast(array); + if (PyArray_TYPE(array_obj) != NPY_OBJECT) { + PyErr_SetString(PyExc_TypeError, "`array` must be an array of objects."); + return nullptr; + } + + if (inplace) { + // We are returning original array so we need to increase the ref count. + Py_INCREF(array); + } else { + // We are returning a fresh copy. + array = PyArray_NewCopy(array_obj, NPY_CORDER); + if (array == nullptr) { + PyErr_SetString(PyExc_MemoryError, "Out of memory!"); + return nullptr; + } + array_obj = reinterpret_cast(array); + } + + if (PyArray_SIZE(array_obj) == 0) { + return array; + } + + if (default_value == nullptr && PyDict_Size(mapping) == 0) { + return array; + } + + NpyIter* iter = NpyIter_New( + array_obj, NPY_ITER_READWRITE | NPY_ITER_EXTERNAL_LOOP | NPY_ITER_REFS_OK, + NPY_KEEPORDER, NPY_NO_CASTING, nullptr); + if (iter == nullptr) { + PyErr_SetString(PyExc_MemoryError, "Out of memory!"); + Py_XDECREF(array); + return nullptr; + } + + NpyIter_IterNextFunc* iter_next = NpyIter_GetIterNext(iter, nullptr); + if (iter_next == nullptr) { + NpyIter_Deallocate(iter); + Py_XDECREF(array); + PyErr_SetString(PyExc_MemoryError, "Out of memory!"); + return nullptr; + } + + // Iterating arrays taken from: + // https://numpy.org/doc/stable/reference/c-api/iterator.html + char** data_pointer = NpyIter_GetDataPtrArray(iter); + npy_intp* stride_pointer = NpyIter_GetInnerStrideArray(iter); + npy_intp* inner_size_pointer = NpyIter_GetInnerLoopSizePtr(iter); + do { + char* data = *data_pointer; + npy_intp stride = *stride_pointer; + npy_intp count = *inner_size_pointer; + for (size_t i = 0; i < count; ++i) { + PyObject* entry; + std::memcpy(&entry, data, sizeof(PyObject*)); + PyObject* result = PyDict_GetItem(mapping, entry); + if (result != nullptr) { + // Replace entry. + Py_INCREF(result); + Py_XDECREF(entry); + std::memcpy(data, &result, sizeof(PyObject*)); + } else if (default_value != nullptr) { + // Replace entry with a default value. + Py_INCREF(default_value); + Py_XDECREF(entry); + std::memcpy(data, &default_value, sizeof(PyObject*)); + } + data += stride; + } + } while (iter_next(iter)); + + NpyIter_Deallocate(iter); + return array; +} + +// Convert 1D Numpy float array to a list of strings where each string has fixed +// number of decimal points. This is faster than Python list comprehension. 
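+// For example, calling the Python binding registered below as
+// string_array.format_float_array([1.0, 2.5], 2) is expected to return
+// ['1.00', '2.50'].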
+std::vector FormatFloatArray(absl::Span values, + int num_decimal_places) { + std::vector output; + output.reserve(values.size()); + + absl::c_transform(values, std::back_inserter(output), + [num_decimal_places](float value) { + return absl::StrFormat("%.*f", num_decimal_places, value); + }); + return output; +} + +py::array_t IsIn( + const py::array_t& array, + const absl::flat_hash_set& test_elements, bool invert) { + const size_t num_elements = array.size(); + py::array_t output(num_elements); + std::fill(output.mutable_data(), output.mutable_data() + output.size(), + invert); + + // Shortcut: The output will be trivially always false if test_elements empty. + if (test_elements.empty()) { + return output; + } + + for (size_t i = 0; i < num_elements; ++i) { + // Compare the string values instead of comparing just object pointers. + py::handle handle = array.data()[i]; + if (!PyUnicode_Check(handle.ptr()) && !PyBytes_Check(handle.ptr())) { + continue; + } + if (test_elements.contains(py::cast(handle))) { + output.mutable_data()[i] = !invert; + } + } + if (array.ndim() > 1) { + auto shape = + std::vector(array.shape(), array.shape() + array.ndim()); + return output.reshape(shape); + } + return output; +} + +py::array RemapMultipleArrays( + const std::vector>& arrays, + const py::dict& mapping) { + size_t array_size = arrays[0].size(); + for (const auto& array : arrays) { + if (array.size() != array_size) { + throw py::value_error("All arrays must have the same length."); + } + } + + // Create a result buffer. + auto result = py::array_t(array_size); + absl::Span result_buffer(result.mutable_data(), array_size); + PyObject* entry = PyTuple_New(arrays.size()); + if (entry == nullptr) { + throw py::error_already_set(); + } + std::vector> array_spans; + array_spans.reserve(arrays.size()); + for (const auto& array : arrays) { + array_spans.emplace_back(array.data(), array.size()); + } + + // Iterate over arrays and look up elements in the `py_dict`. + bool fail = false; + for (size_t i = 0; i < array_size; ++i) { + for (size_t j = 0; j < array_spans.size(); ++j) { + PyTuple_SET_ITEM(entry, j, array_spans[j][i]); + } + PyObject* result = PyDict_GetItem(mapping.ptr(), entry); + if (result != nullptr) { + int64_t result_value = PyLong_AsLongLong(result); + if (result_value == -1 && PyErr_Occurred()) { + fail = true; + break; + } + if (result_value > std::numeric_limits::max() || + result_value < std::numeric_limits::lowest()) { + PyErr_SetString(PyExc_OverflowError, "Result value too large."); + fail = true; + break; + } + result_buffer[i] = result_value; + } else { + PyErr_Format(PyExc_KeyError, "%R", entry); + fail = true; + break; + } + } + + for (size_t j = 0; j < array_spans.size(); ++j) { + PyTuple_SET_ITEM(entry, j, nullptr); + } + Py_XDECREF(entry); + if (fail) { + throw py::error_already_set(); + } + return result; +} + +constexpr char kRemapNumpyArrayObjects[] = R"( +Replace objects in NumPy array of objects using mapping. + +Args: + array: NumPy array with dtype=object. + mapping: Dict mapping old values to new values. + inplace: Bool (default False) whether to replace values inplace or to + create a new array. + default_value: If given, what value to map to if the mapping is missing + for that particular item. If not given, such items are left unchanged. + +Returns + NumPy array of dtype object with values replaced according to mapping. + If inplace is True the original array is modified inplace otherwise a + new array is returned. 
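+
+  For example (illustrative): remap(np.array(['A', 'B'], dtype=object),
+  {'A': 'X'}, default_value='?') maps 'A' to 'X' and, because 'B' is missing
+  from the mapping, replaces 'B' with the default '?'.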
+)"; + +constexpr char kFormatFloatArrayDoc[] = R"( +Converts float -> string array with given number of decimal places. +)"; + +constexpr char kIsInDoc[] = R"( +Computes whether each element is in test_elements. + +Same use as np.isin, but much faster. If len(array) = n, len(test_elements) = m: +* This function has complexity O(n). +* np.isin with arrays of objects has complexity O(m*log(m) + n * log(m)). + +Args: + array: Input NumPy array with dtype=object. + test_elements: The values against which to test each value of array. + invert: If True, the values in the returned array are inverted, as if + calculating `element not in test_elements`. Default is False. + `isin(a, b, invert=True)` is equivalent to but faster than `~isin(a, b)`. + +Returns + A boolean array of the same shape as the input array. Each value `val` is: + * `val in test_elements` if `invert=False`, + * `val not in test_elements` if `invert=True`. +)"; + +constexpr char kRemapMultipleDoc[] = R"( +Maps keys from multiple aligned arrays to a single array. + +Args: + arrays: Numpy arrays of the same length. The tuple of aligned entries is used + as key for the mapping. + mapping: Dict mapping from tuples to integer values. + +Returns + NumPy array of dtype `int` with values looked up in mapping according to the + tuple of aligned array entries as keys. +)"; + +} // namespace + +namespace alphafold3 { + +void RegisterModuleStringArray(pybind11::module m) { + m.def( + "remap", + [](py::object array, py::object mapping, bool inplace, + py::object default_value) -> py::object { + PyObject* result = RemapNumpyArrayObjects(array.ptr(), mapping.ptr(), + inplace, default_value.ptr()); + if (result == nullptr) { + throw py::error_already_set(); + } + return py::reinterpret_steal(result); + }, + py::return_value_policy::take_ownership, py::arg("array"), + py::arg("mapping"), py::arg("inplace") = false, py::arg("default_value"), + py::doc(kRemapNumpyArrayObjects + 1)); + m.def( + "remap", + [](py::object array, py::object mapping, bool inplace) -> py::object { + PyObject* result = RemapNumpyArrayObjects(array.ptr(), mapping.ptr(), + inplace, nullptr); + if (result == nullptr) { + throw py::error_already_set(); + } + return py::reinterpret_steal(result); + }, + py::return_value_policy::take_ownership, py::arg("array"), + py::arg("mapping"), py::arg("inplace") = false, + py::doc(kRemapNumpyArrayObjects + 1)); + m.def("format_float_array", &FormatFloatArray, py::arg("values"), + py::arg("num_decimal_places"), py::doc(kFormatFloatArrayDoc + 1), + py::call_guard()); + m.def("isin", &IsIn, py::arg("array"), py::arg("test_elements"), + py::kw_only(), py::arg("invert") = false, py::doc(kIsInDoc + 1)); + m.def("remap_multiple", &RemapMultipleArrays, py::arg("arrays"), + py::arg("mapping"), py::doc(kRemapMultipleDoc + 1)); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..85790ddd831cd560b9340c2c8b6047e50250ec94 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.h @@ -0,0 +1,24 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_STRING_ARRAY_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_STRING_ARRAY_PYBIND_H_ + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleStringArray(pybind11::module m); + +} + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_STRING_ARRAY_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/mmcif.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/mmcif.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b71c0281a3c4f252232183cf589cc72aebefa5 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/mmcif.py @@ -0,0 +1,333 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Low level mmCIF parsing operations and wrappers for nicer C++/Py errors. + +Note that the cif_dict.CifDict class has many useful methods to help with data +extraction which are not shown in this file. You can find them in cif_dict.clif +together with docstrings. The cif_dict.CifDict class behaves like an immutable +Python dictionary (some methods are not implemented though). +""" +from collections.abc import Callable, Mapping, Sequence +import functools +import itertools +import re +from typing import ParamSpec, TypeAlias, TypeVar + +from alphafold3.constants import chemical_components +from alphafold3.cpp import cif_dict +from alphafold3.cpp import mmcif_atom_site +from alphafold3.cpp import mmcif_struct_conn +from alphafold3.cpp import string_array +import numpy as np + +Mmcif = cif_dict.CifDict + + +_P = ParamSpec('_P') +_T = TypeVar('_T') +_WappedFn: TypeAlias = Callable[_P, _T] + + +@functools.lru_cache(maxsize=256) +def int_id_to_str_id(num: int) -> str: + """Encodes a number as a string, using reverse spreadsheet style naming. + + Args: + num: A positive integer. + + Returns: + A string that encodes the positive integer using reverse spreadsheet style, + naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the + usual way to encode chain IDs in mmCIF files. + """ + if num <= 0: + raise ValueError(f'Only positive integers allowed, got {num}.') + + num = num - 1 # 1-based indexing. + output = [] + while num >= 0: + output.append(chr(num % 26 + ord('A'))) + num = num // 26 - 1 + return ''.join(output) + + +@functools.lru_cache(maxsize=256) +def str_id_to_int_id(str_id: str) -> int: + """Encodes an mmCIF-style string chain ID as an integer. 
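+
+  For example, 'A' maps to 1, 'B' to 2 and 'AA' to 27.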
+ + The integer IDs are one based so this function is the inverse of + int_id_to_str_id. + + Args: + str_id: A string chain ID consisting only of upper case letters A-Z. + + Returns: + An integer that can be used to order mmCIF chain IDs in the standard + (reverse spreadsheet style) ordering. + """ + if not re.match('^[A-Z]+$', str_id): + raise ValueError( + f'String ID must be upper case letters, got {str_id}.') + + offset = ord('A') - 1 + output = 0 + for i, c in enumerate(str_id): + output += (ord(c) - offset) * int(26**i) + return output + + +def from_string(mmcif_string: str | bytes) -> Mmcif: + return cif_dict.from_string(mmcif_string) + + +def parse_multi_data_cif(cif_string: str) -> dict[str, Mmcif]: + """Parses a CIF string with multiple data records. + + For instance, the CIF string: + + ``` + data_001 + _foo bar + # + data_002 + _foo baz + ``` + + is parsed as: + + ``` + {'001': Mmcif({'_foo': ['bar']}), '002': Mmcif({'_foo': ['baz']})} + ``` + + Args: + cif_string: The multi-data CIF string to be parsed. + + Returns: + A dictionary mapping record names to Mmcif objects with data. + """ + return cif_dict.parse_multi_data_cif(cif_string) + + +def tokenize(mmcif_string: str) -> list[str]: + return cif_dict.tokenize(mmcif_string) + + +def split_line(line: str) -> list[str]: + return cif_dict.split_line(line) + + +class BondParsingError(Exception): + """Exception raised by errors when getting bond atom indices.""" + + +def get_bond_atom_indices( + mmcif: Mmcif, + model_id: str = '1', +) -> tuple[Sequence[int], Sequence[int]]: + """Extracts the indices of the atoms that participate in bonds. + + Args: + mmcif: The mmCIF object to process. + model_id: The ID of the model that the returned atoms will belong to. This + should be a value in the mmCIF's _atom_site.pdbx_PDB_model_num column. + + Returns: + Two lists of atom indices, `from_atoms` and `to_atoms`, each one having + length num_bonds (as defined by _struct_conn, the bonds table). The bond + i, defined by the i'th row in _struct_conn, is a bond from atom at index + from_atoms[i], to the atom at index to_atoms[i]. The indices are simple + 0-based indexes into the columns of the _atom_site table in the input + mmCIF, and do not necessarily correspond to the values in _atom_site.id, + or any other column. + + Raises: + BondParsingError: If any of the required tables or columns are not present + in + the mmCIF, or if the _struct_conn table refers to atoms that cannot + be found in the _atom_site table. + """ + try: + return mmcif_struct_conn.get_bond_atom_indices(mmcif, model_id) + except ValueError as e: + raise BondParsingError(str(e)) from e + + +def get_or_infer_type_symbol( + mmcif: Mmcif, ccd: chemical_components.Ccd | None = None +) -> Sequence[str]: + """Returns the type symbol (element) for all of the atoms. + + Args: + mmcif: A parsed mmCIF file in the Mmcif format. + ccd: The chemical component dictionary. If not provided, defaults to the + cached CCD. + + If present, returns the _atom_site.type_symbol. If not, infers it using + _atom_site.label_comp_id (residue name), _atom_site.label_atom_id (atom name) + and the CCD. + """ + ccd = ccd or chemical_components.cached_ccd() + + def type_symbol_fn(res_name, atom_name): return chemical_components.type_symbol( + ccd, res_name, atom_name + ) + return mmcif_atom_site.get_or_infer_type_symbol(mmcif, type_symbol_fn) + + +def get_chain_type_by_entity_id(mmcif: Mmcif) -> Mapping[str, str]: + """Returns mapping from entity ID to its type or polymer type if available. 
+ + If the entity is in the _entity_poly table, returns its polymer chain type. + If not, returns the type as specified in the _entity table. + + Args: + mmcif: CifDict holding the mmCIF. + """ + poly_entity_id = mmcif.get('_entity_poly.entity_id', []) + poly_type = mmcif.get('_entity_poly.type', []) + poly_type_by_entity_id = dict(zip(poly_entity_id, poly_type, strict=True)) + + chain_type_by_entity_id = {} + for entity_id, entity_type in zip( + mmcif.get('_entity.id', []), mmcif.get('_entity.type', []), strict=True + ): + chain_type = poly_type_by_entity_id.get(entity_id) or entity_type + chain_type_by_entity_id[entity_id] = chain_type + + return chain_type_by_entity_id + + +def get_internal_to_author_chain_id_map(mmcif: Mmcif) -> Mapping[str, str]: + """Returns a mapping from internal chain ID to the author chain ID. + + Note that this is not a bijection. One author chain ID can map to multiple + internal chain IDs. For example, a protein chain and a ligand bound to it will + share the same author chain ID, but they will each have a unique internal + chain ID). + + Args: + mmcif: CifDict holding the mmCIF. + """ + return mmcif_atom_site.get_internal_to_author_chain_id_map(mmcif) + + +def get_experimental_method(mmcif: Mmcif) -> str | None: + field = '_exptl.method' + return ','.join(mmcif[field]).lower() if field in mmcif else None + + +def get_release_date(mmcif: Mmcif) -> str | None: + """Returns the oldest revision date.""" + if '_pdbx_audit_revision_history.revision_date' not in mmcif: + return None + + # Release dates are ISO-8601, hence sort well. + return min(mmcif['_pdbx_audit_revision_history.revision_date']) + + +def get_resolution(mmcif: Mmcif) -> float | None: + """Returns the resolution of the structure. + + More than one resolution can be reported in an mmCIF. This function returns + the first one (in the order _refine.ls_d_res_high, + _em_3d_reconstruction.resolution, _reflns.d_resolution_high) that appears + in the mmCIF as is parseable as a float. + + Args: + mmcif: An `Mmcif` object. + + Returns: + The resolution as reported in the mmCIF. + """ + for res_key in ('_refine.ls_d_res_high', + '_em_3d_reconstruction.resolution', + '_reflns.d_resolution_high'): + if res_key in mmcif: + try: + raw_resolution = mmcif[res_key][0] + return float(raw_resolution) + except ValueError: + continue + return None + + +def parse_oper_expr(oper_expression: str) -> list[tuple[str, ...]]: + """Determines which transforms to apply based on an MMCIF oper_expression str. + + Args: + oper_expression: the field oper_expression from MMCIF format data. + Transform ids may be either numbers or single letters. Hyphens are used to + denote a numeric range of transforms to apply, and commas are used to + delimit a sequence of transforms. Where two sets of parentheses are + adjacent without a comma, the two sets of transforms should be combined as + a cartesian product, i.e. all possible pairs. + example 1,2,3 -> generate 3 copies of each chain by applying 1, 2 or 3. + example (1-3) -> generate 3 copies of each chain by applying 1, 2 or 3. + example (1-3)(4-6) -> generate 9 copies of each chain by applying one of + [(1,4), (1,5), (1,6), + (2,4), (2,5), (2,6), + (3,4), (3,5), (3,6)] + example (P) -> apply transform with id P. + + Raises: + ValueError: Failure to parse oper_expression. + + Returns: + A list with one element for each chain copy that should be generated. + Each element is a list of transform ids to apply. + """ + # Expand ranges, e.g. 1-4 -> 1,2,3,4. 
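+  # For instance, '(1-3)(4-6)' first becomes '(1,2,3)(4,5,6)' here and is then
+  # expanded below into the 9 transform pairs of the Cartesian product.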
+ def range_expander(match): + return ','.join( + [str(i) for i in range(int(match.group(1)), + int(match.group(2)) + 1)]) + + ranges_expanded = re.sub(r'\b(\d+)-(\d+)', range_expander, oper_expression) + + if re.fullmatch(r'(\w+,)*\w+', ranges_expanded): + # No brackets, just a single range, e.g. "1,2,3". + return [(t,) for t in ranges_expanded.split(',')] + elif re.fullmatch(r'\((\w+,)*\w+\)', ranges_expanded): + # Single range in brackets, e.g. "(1,2,3)". + return [(t,) for t in ranges_expanded[1:-1].split(',')] + elif re.fullmatch(r'\((\w+,)*\w+\)\((\w+,)*\w+\)', ranges_expanded): + # Cartesian product of two ranges, e.g. "(1,2,3)(4,5)". + part1, part2 = ranges_expanded[1:-1].split(')(') + return list(itertools.product(part1.split(','), part2.split(','))) + else: + raise ValueError( + f'Unsupported oper_expression format: {oper_expression}') + + +def format_float_array( + values: np.ndarray, num_decimal_places: int) -> Sequence[str]: + """Converts 1D array to a list of strings with the given number of decimals. + + This function is faster than converting via Python list comprehension, e.g.: + atoms_x = ['%.3f' % x for x in atoms_x] + + Args: + values: A numpy array with values to convert. This array is casted to + float32 before doing the conversion. + num_decimal_places: The number of decimal points to keep, including trailing + zeros. E.g. for 1.07 and num_decimal_places=1: 1.1, + num_decimal_places=2: 1.07, num_decimal_places=3: 1.070. + + Returns: + A list of formatted strings. + """ + if values.ndim != 1: + raise ValueError(f'The given array must be 1D, got {values.ndim}D') + + return string_array.format_float_array( + values=values.astype(np.float32), num_decimal_places=num_decimal_places + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/parsing.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/parsing.py new file mode 100644 index 0000000000000000000000000000000000000000..e449cbcf6a3bcf4ab1fd240efed079f4b347b8f3 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/parsing.py @@ -0,0 +1,1805 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Module for parsing various data sources and producing Structures.""" + +from collections.abc import Collection, Mapping, MutableMapping, Sequence +import dataclasses +import datetime +import enum +import functools +import itertools +from typing import TypeAlias +import numpy as np +from alphafold3.constants import chemical_components +from alphafold3.constants import mmcif_names +from alphafold3.constants import residue_names +from alphafold3.cpp import mmcif_utils +from alphafold3.cpp import string_array +from alphafold3.structure import bioassemblies +from alphafold3.structure import bonds +from alphafold3.structure import chemical_components as struc_chem_comps +from alphafold3.structure import mmcif +from alphafold3.structure import structure +from alphafold3.structure import structure_tables + + +ChainIndex: TypeAlias = int +ResIndex: TypeAlias = int +AtomName: TypeAlias = str +BondAtomId: TypeAlias = tuple[ChainIndex, ResIndex, AtomName] + +_INSERTION_CODE_REMAP: Mapping[str, str] = {'.': '?'} + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class BondIndices: + from_indices: list[int] + dest_indices: list[int] + + +@enum.unique +class ModelID(enum.Enum): + """Values for specifying model IDs when parsing.""" + + FIRST = 1 # The first model in the file. + ALL = 2 # All models in the file. + + +@enum.unique +class SequenceFormat(enum.Enum): + """The possible formats for an input sequence.""" + + FASTA = 'fasta' # One-letter code used in FASTA. + # Multiple-letter chemical components dictionary ids. + CCD_CODES = 'ccd_codes' + LIGAND_SMILES = 'ligand_smiles' # SMILES string defining a molecule. + + +def _create_bond_lookup( + bonded_atom_pairs: Sequence[tuple[BondAtomId, BondAtomId]], +) -> Mapping[tuple[ChainIndex, ResIndex], Mapping[AtomName, BondIndices]]: + """Creates maps to help find bonds during a loop over residues.""" + bond_lookup = {} + for bond_i, (from_atom_id, dest_atom_id) in enumerate(bonded_atom_pairs): + from_chain_i, from_res_i, from_atom_name = from_atom_id + dest_chain_i, dest_res_i, dest_atom_name = dest_atom_id + bonds_by_from_atom_name = bond_lookup.setdefault( + (from_chain_i, from_res_i), {} + ) + bonds_by_dest_atom_name = bond_lookup.setdefault( + (dest_chain_i, dest_res_i), {} + ) + bonds_by_from_atom_name.setdefault( + from_atom_name, BondIndices(from_indices=[], dest_indices=[]) + ).from_indices.append(bond_i) + bonds_by_dest_atom_name.setdefault( + dest_atom_name, BondIndices(from_indices=[], dest_indices=[]) + ).dest_indices.append(bond_i) + return bond_lookup + + +def _get_atom_element( + ccd: chemical_components.Ccd, res_name: str, atom_name: str +) -> str: + return ( + chemical_components.type_symbol( + ccd=ccd, res_name=res_name, atom_name=atom_name + ) + or '?' + ) + + +def _get_representative_atom( + ccd: chemical_components.Ccd, + res_name: str, + chain_type: str, + sequence_format: SequenceFormat, +) -> tuple[str, str]: + match sequence_format: + case SequenceFormat.CCD_CODES: + atom_name = _get_first_non_leaving_atom(ccd=ccd, res_name=res_name) + atom_element = _get_atom_element( + ccd=ccd, res_name=res_name, atom_name=atom_name + ) + return atom_name, atom_element + case SequenceFormat.LIGAND_SMILES: + return '', '?' 
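+      # SMILES-defined ligands have no CCD entry to query, so an empty atom
+      # name and an unknown element are used as placeholders.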
+ case SequenceFormat.FASTA: + if chain_type in mmcif_names.PEPTIDE_CHAIN_TYPES: + return 'CA', 'C' + if chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES: + return "C1'", 'C' + else: + raise ValueError(chain_type) + case _: + raise ValueError(sequence_format) + + +@functools.lru_cache(maxsize=128) +def _get_first_non_leaving_atom( + ccd: chemical_components.Ccd, res_name: str +) -> str: + """Returns first definitely non-leaving atom if exists, as a stand-in.""" + all_atoms = struc_chem_comps.get_all_atoms_in_entry(ccd, res_name=res_name)[ + '_chem_comp_atom.atom_id' + ] + representative_atom = all_atoms[0] + if representative_atom == 'O1' and len(all_atoms) > 1: + representative_atom = all_atoms[1] + return representative_atom + + +def _add_ligand_to_chem_comp( + chem_comp: MutableMapping[str, struc_chem_comps.ChemCompEntry], + ligand_id: str, + ligand_smiles: str, +): + """Adds a ligand to chemical components. Raises ValueError on mismatch.""" + new_entry = struc_chem_comps.ChemCompEntry( + type='non-polymer', pdbx_smiles=ligand_smiles + ) + + existing_entry = chem_comp.get(ligand_id) + if existing_entry is None: + chem_comp[ligand_id] = new_entry + elif existing_entry != new_entry: + raise ValueError( + f'Mismatching data for ligand {ligand_id}: ' + f'{new_entry} != {existing_entry}' + ) + + +def _get_first_model_id(cif: mmcif.Mmcif) -> str: + """Returns cheaply the first model ID from the mmCIF.""" + return cif.get_array( + '_atom_site.pdbx_PDB_model_num', dtype=object, gather=slice(1) + )[0] + + +def _get_str_model_id( + cif: mmcif.Mmcif, + model_id: ModelID | int, +) -> str: + """Converts a user-specified model_id argument into a string.""" + match model_id: + case int(): + str_model_id = str(model_id) + case enum.Enum(): + # We compare the enum's value attribute since regular enum comparison + # breaks when adhoc importing. + match model_id.value: + case ModelID.FIRST.value: + str_model_id = _get_first_model_id(cif) + case ModelID.ALL.value: + str_model_id = '' + case _: + raise ValueError( + f'Model ID {model_id} with value {model_id.value} not recognized.' + ) + case _: + raise ValueError( + f'Model ID {model_id} with type {type(model_id)} not recognized.' + ) + return str_model_id + + +def _parse_bonds( + cif: mmcif.Mmcif, + atom_key: np.ndarray, + model_id: str, +) -> bonds.Bonds: + """Returns the bonds table extracted from the mmCIF. + + Args: + cif: The raw mmCIF to extract the bond information from. + atom_key: A numpy array defining atom key for each atom in _atom_site. Note + that the atom key must be computed before resolving alt-locs since this + function operates on the raw mmCIF! + model_id: The ID of the model to get bonds for. + """ + if '_struct_conn.id' not in cif: + # This is the category key item for the _struct_conn table, therefore + # we use it to determine whether to parse bond info. + return bonds.Bonds.make_empty() + from_atom, dest_atom = mmcif.get_bond_atom_indices(cif, model_id) + from_atom = np.array(from_atom, dtype=np.int64) + dest_atom = np.array(dest_atom, dtype=np.int64) + num_bonds = from_atom.shape[0] + bond_key = np.arange(num_bonds, dtype=np.int64) + bond_type = cif.get_array('_struct_conn.conn_type_id', dtype=object) + if '_struct_conn.pdbx_role' in cif: # This column isn't always present. 
+ bond_role = cif.get_array('_struct_conn.pdbx_role', dtype=object) + else: + bond_role = np.full((num_bonds,), '?', dtype=object) + + bonds_mask = np.ones((num_bonds,), dtype=bool) + # Symmetries other than 1_555 imply the atom is not part of the asymmetric + # unit, and therefore this is a bond that only exists in the expanded + # bioassembly. + # We do not currently support parsing these types of bonds. + if '_struct_conn.ptnr1_symmetry' in cif: + ptnr1_symmetry = cif.get_array( + '_struct_conn.ptnr1_symmetry', dtype=object) + np.logical_and(bonds_mask, ptnr1_symmetry == '1_555', out=bonds_mask) + if '_struct_conn.ptnr2_symmetry' in cif: + ptnr2_symmetry = cif.get_array( + '_struct_conn.ptnr2_symmetry', dtype=object) + np.logical_and(bonds_mask, ptnr2_symmetry == '1_555', out=bonds_mask) + # Remove bonds that involve atoms that are not part of the structure, + # e.g. waters if include_water=False. In a rare case this also removes invalid + # bonds that are indicated by a key that is set to _atom_site size. + np.logical_and(bonds_mask, np.isin(from_atom, atom_key), out=bonds_mask) + np.logical_and(bonds_mask, np.isin(dest_atom, atom_key), out=bonds_mask) + return bonds.Bonds( + key=bond_key[bonds_mask], + type=bond_type[bonds_mask], + role=bond_role[bonds_mask], + from_atom_key=from_atom[bonds_mask], + dest_atom_key=dest_atom[bonds_mask], + ) + + +@dataclasses.dataclass(frozen=True, slots=True) +class _MmcifHeader: + name: str + resolution: float | None + release_date: datetime.date | None + structure_method: str | None + bioassembly_data: bioassemblies.BioassemblyData | None + chemical_components_data: struc_chem_comps.ChemicalComponentsData | None + + +def _get_mmcif_header( + cif: mmcif.Mmcif, + fix_mse: bool, + fix_unknown_dna: bool, +) -> _MmcifHeader: + """Extract header fields from an mmCIF object.""" + name = cif.get_data_name() + resolution = mmcif.get_resolution(cif) + + release_date = mmcif.get_release_date(cif) + if release_date is not None: + release_date = datetime.date.fromisoformat(release_date) + + experiments = cif.get('_exptl.method') + structure_method = ','.join(experiments) if experiments else None + + try: + bioassembly_data = bioassemblies.BioassemblyData.from_mmcif(cif) + except bioassemblies.MissingBioassemblyDataError: + bioassembly_data = None + + try: + chemical_components_data = ( + struc_chem_comps.ChemicalComponentsData.from_mmcif( + cif, fix_mse=fix_mse, fix_unknown_dna=fix_unknown_dna + ) + ) + except struc_chem_comps.MissingChemicalComponentsDataError: + chemical_components_data = None + + return _MmcifHeader( + name=name, + resolution=resolution, + release_date=release_date, + structure_method=structure_method, + bioassembly_data=bioassembly_data, + chemical_components_data=chemical_components_data, + ) + + +def from_parsed_mmcif( + mmcif_object: mmcif.Mmcif, + *, + name: str | None = None, + fix_mse_residues: bool = False, + fix_arginines: bool = False, + fix_unknown_dna: bool = False, + include_water: bool = False, + include_other: bool = False, + include_bonds: bool = False, + model_id: int | ModelID = ModelID.FIRST, +) -> structure.Structure: + """Construct a Structure from a parsed mmCIF object. + + This function is called by `from_mmcif` but can be useful when an mmCIF has + already been parsed e.g. to extract extra information from the header before + then converting to Structure for further manipulation. + + Args: + mmcif_object: A parsed mmcif.Mmcif object. + name: Optional name for the structure. 
If not provided, the name will be + taken from the mmCIF data_ field. + fix_mse_residues: If True, selenium atom sites (SE) in selenomethionine + (MSE) residues will be changed to sulphur atom sites (SD). This is because + methionine (MET) residues are often replaced with MSE to aid X-Ray + crystallography. If False, the SE MSE atom sites won't be modified. + fix_arginines: If True, NH1 and NH2 in arginine will be swapped if needed so + that NH1 is always closer to CD than NH2. If False, no atom sites in + arginine will be touched. Note that HH11, HH12, HH21, HH22 are fixed too. + fix_unknown_dna: If True, residues with name N in DNA chains will have their + res_name replaced with DN. Atoms are not changed. + include_water: If True, water (HOH) molecules will be parsed. Water + molecules may be grouped into chains, where number of residues > 1. Water + molecules are usually grouped into chains but do not necessarily all share + the same chain ID. + include_other: If True, all other atoms that are not included by any of the + above parameters will be included. This covers e.g. "polypeptide(D)" and + "macrolide" entities, as well as all other non-standard types. + include_bonds: If True, bond information will be parsed from the mmCIF and + stored in the Structure. + model_id: Either the integer model ID to parse, or one of ModelID.FIRST to + parse the first model, or ModelID.ALL to parse all models. + + Returns: + A Structure representation of the mmCIF object. + """ + str_model_id = _get_str_model_id(cif=mmcif_object, model_id=model_id) + header = _get_mmcif_header( + mmcif_object, fix_mse=fix_mse_residues, fix_unknown_dna=fix_unknown_dna + ) + + chains, residues, atoms = get_tables( + cif=mmcif_object, + fix_mse_residues=fix_mse_residues, + fix_arginines=fix_arginines, + fix_unknown_dna=fix_unknown_dna, + include_water=include_water, + include_other=include_other, + model_id=str_model_id, + ) + + if include_bonds: + # NB: parsing the atom table before the bonds table allows for a more + # informative error message when dealing with bad multi-model mmCIFs. + # We also ensure that we always use a specific model ID, even when parsing + # all models. + if str_model_id == '': # pylint: disable=g-explicit-bool-comparison + bonds_model_id = _get_first_model_id(mmcif_object) + else: + bonds_model_id = str_model_id + + bonds_table = _parse_bonds( + mmcif_object, + atom_key=atoms.key, + model_id=bonds_model_id, + ) + else: + bonds_table = bonds.Bonds.make_empty() + + return structure.Structure( + name=name if name is not None else header.name, + resolution=header.resolution, + release_date=header.release_date, + structure_method=header.structure_method, + bioassembly_data=header.bioassembly_data, + chemical_components_data=header.chemical_components_data, + bonds=bonds_table, + chains=chains, + residues=residues, + atoms=atoms, + ) + + +def from_mmcif( + mmcif_string: str | bytes, + *, + name: str | None = None, + fix_mse_residues: bool = False, + fix_arginines: bool = False, + fix_unknown_dna: bool = False, + include_water: bool = False, + include_other: bool = False, + include_bonds: bool = False, + model_id: int | ModelID = ModelID.FIRST, +) -> structure.Structure: + """Construct a Structure from a mmCIF string. + + Args: + mmcif_string: The string contents of an mmCIF file. + name: Optional name for the structure. If not provided, the name will be + taken from the mmCIF data_ field. 
+ fix_mse_residues: If True, selenium atom sites (SE) in selenomethionine + (MSE) residues will be changed to sulphur atom sites (SD). This is because + methionine (MET) residues are often replaced with MSE to aid X-Ray + crystallography. If False, the SE MSE atom sites won't be modified. + fix_arginines: If True, NH1 and NH2 in arginine will be swapped if needed so + that NH1 is always closer to CD than NH2. If False, no atom sites in + arginine will be touched. Note that HH11, HH12, HH21, HH22 are fixed too. + fix_unknown_dna: If True, residues with name N in DNA chains will have their + res_name replaced with DN. Atoms are not changed. + include_water: If True, water (HOH) molecules will be parsed. Water + molecules may be grouped into chains, where number of residues > 1. Water + molecules are usually grouped into chains but do not necessarily all share + the same chain ID. + include_other: If True, all other atoms that are not included by any of the + above parameters will be included. This covers e.g. "polypeptide(D)" and + "macrolide" entities, as well as all other non-standard types. + include_bonds: If True, bond information will be parsed from the mmCIF and + stored in the Structure. + model_id: Either the integer model ID to parse, or one of ModelID.FIRST to + parse the first model, or ModelID.ALL to parse all models. + + Returns: + A Structure representation of the mmCIF string. + """ + mmcif_object = mmcif.from_string(mmcif_string) + + return from_parsed_mmcif( + mmcif_object, + name=name, + fix_mse_residues=fix_mse_residues, + fix_arginines=fix_arginines, + fix_unknown_dna=fix_unknown_dna, + include_water=include_water, + include_other=include_other, + include_bonds=include_bonds, + model_id=model_id, + ) + + +def from_res_arrays(atom_mask: np.ndarray, **kwargs) -> structure.Structure: + """Returns Structure created from from arrays with a residue dimension. + + All unset fields are filled with defaults (e.g. 1.0 for occupancy) or + unset/unknown values (e.g. UNK for residue type, or '.' for atom element). + + Args: + atom_mask: A array with shape (num_res, num_atom). This is used to decide + which atoms in the atom dimension are present in a given residue. Present + atoms should have a nonzero value, e.g. 1.0 or True. + **kwargs: A mapping from field name to values. For all array-valued fields + these arrays must have a dimension of length num_res. Chain and residue + fields should have this as their only dimension and atom fields should be + shaped (num_res, num_atom). Coordinate fields may also have arbitrary + leading dimensions (they must be the same across all coordinate fields). + See structure.{CHAIN,RESIDUE,ATOM}_FIELDS for a list of allowed fields. + """ + num_res, num_atom = atom_mask.shape + included_indices = np.flatnonzero(atom_mask) + + array_fields = ( + structure.CHAIN_FIELDS.keys() + | structure.RESIDUE_FIELDS.keys() + | structure.ATOM_FIELDS.keys() + ) + initializer_kwargs = {} + fields = {} + for k, val in kwargs.items(): + if k not in array_fields: + # The kwarg key isn't an array field name. Such kwargs are forwarded as-is + # to the constructor. They are expected to be global fields (e.g. name). + # Other values will raise an error when the constructor is called. + if k in structure.TABLE_FIELDS: + raise ValueError(f'Table fields must not be set. 
Got {k}.') + initializer_kwargs[k] = val + continue + elif val is None: + raise ValueError(f'{k} must be non-None.') + + if not isinstance(val, np.ndarray): + raise TypeError( + f'Value for {k} must be a NumPy array. Got {type(val)}.') + if k in structure.CHAIN_FIELDS or k in structure.RESIDUE_FIELDS: + if val.shape != (num_res,): + raise ValueError( + f'{k} must have shape ({num_res=},). Got {val.shape=}.' + ) + # Do not reshape the chain/residue arrays, they have the shape we need. + fields[k] = val + else: + assert k in structure.ATOM_FIELDS + if val.shape[-2:] != (num_res, num_atom): + raise ValueError( + f'{k} must have final two dimensions of length ' + f'{(num_res, num_atom)=}. Got {val.shape=}.' + ) + leading_dims = val.shape[:-2] + flat_val = val.reshape(leading_dims + (-1,), order='C') + masked_val = flat_val[..., included_indices] + fields[k] = masked_val + + # Get chain IDs or assume this is a single-chain structure. + chain_id = kwargs.get('chain_id', np.array(['A'] * num_res, dtype=object)) + # Find chain starts in res-sized arrays, use these to make chain-sized arrays. + chain_start = np.concatenate( + ([0], np.where(chain_id[1:] != chain_id[:-1])[0] + 1) + ) + if len(set(chain_id)) != len(chain_start): + raise ValueError(f'Chain IDs must be contiguous, but got {chain_id}') + + chain_lengths = np.diff(chain_start, append=len(chain_id)) + chain_key = np.repeat(np.arange(len(chain_start)), chain_lengths) + + chain_entity_id = fields.get('chain_entity_id') + if chain_entity_id is not None: + entity_id = chain_entity_id[chain_entity_id] + else: + entity_id = np.array( + [str(mmcif.str_id_to_int_id(cid)) + for cid in chain_id[chain_start]], + dtype=object, + ) + chain_str_empty = np.full((num_res,), '.', dtype=object) + chains_table = structure_tables.Chains( + key=chain_key[chain_start], + id=chain_id[chain_start], + type=fields.get('chain_type', chain_str_empty)[chain_start], + auth_asym_id=fields.get('chain_auth_asym_id', chain_id)[chain_start], + entity_id=entity_id, + entity_desc=fields.get('chain_entity_desc', chain_str_empty)[ + chain_start], + ) + + # Since all arrays are residue-shaped, we can use them directly. + res_key = np.arange(num_res, dtype=np.int64) + res_id = fields.get('res_id', res_key + 1).astype(np.int32) + residues_table = structure_tables.Residues( + key=res_key, + chain_key=chain_key, + id=res_id, + name=fields.get('res_name', np.full(num_res, 'UNK', dtype=object)), + auth_seq_id=fields.get( + 'res_auth_seq_id', np.char.mod('%d', res_id).astype(object) + ), + insertion_code=fields.get( + 'res_insertion_code', np.full(num_res, '?', dtype=object) + ), + ) + + # The atom-sized arrays have already been masked and reshaped. + num_atoms_per_res = np.sum(atom_mask, axis=1, dtype=np.int32) + num_atoms_total = np.sum(num_atoms_per_res, dtype=np.int32) + # Structure is immutable, so use the same array multiple times to save RAM. 
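The flattening and masking step above is easier to follow on a tiny input. A minimal NumPy sketch, not part of the patch, with made-up values:

```python
import numpy as np

# Toy illustration (not part of the patch): how atom-level fields shaped
# (num_res, num_atom) are flattened and reduced to only the present atoms.
atom_mask = np.array([[1, 1, 0],
                      [1, 0, 0]], dtype=bool)           # 2 residues, 3 atom slots
atom_x = np.array([[1.0, 2.0, 9.9],
                   [3.0, 9.9, 9.9]], dtype=np.float32)  # 9.9 marks unused slots

included_indices = np.flatnonzero(atom_mask)  # flat indices of present atoms
flat_x = atom_x.reshape(-1, order='C')
present_x = flat_x[included_indices]          # array([1., 2., 3.])

num_atoms_per_res = atom_mask.sum(axis=1)     # array([2, 1])
print(present_x, num_atoms_per_res)
```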
+ atom_str_empty = np.full(num_atoms_total, '.', dtype=object) + atom_float32_zeros = np.zeros(num_atoms_total, dtype=np.float32) + atom_float32_ones = np.ones(num_atoms_total, dtype=np.float32) + atoms_table = structure_tables.Atoms( + key=np.arange(num_atoms_total, dtype=np.int64), + chain_key=np.repeat(chain_key, num_atoms_per_res), + res_key=np.repeat(res_key, num_atoms_per_res), + name=fields.get('atom_name', atom_str_empty), + element=fields.get('atom_element', atom_str_empty), + x=fields.get('atom_x', atom_float32_zeros), + y=fields.get('atom_y', atom_float32_zeros), + z=fields.get('atom_z', atom_float32_zeros), + b_factor=fields.get('atom_b_factor', atom_float32_zeros), + occupancy=fields.get('atom_occupancy', atom_float32_ones), + ) + + return structure.Structure( + chains=chains_table, + residues=residues_table, + atoms=atoms_table, + bonds=structure_tables.Bonds.make_empty(), # Currently not set. + **initializer_kwargs, + ) + + +def expand_sequence( + sequence: str, chain_type: str, sequence_format: SequenceFormat +) -> Sequence[str]: + """Returns full residue names based on a sequence string. + + Args: + sequence: A string representing the sequence. + chain_type: The chain type of the sequence. + sequence_format: The format of the sequence argument. + """ + match sequence_format: + case SequenceFormat.FASTA: + if not all(c.isalpha() for c in sequence): + raise ValueError( + f'Sequence "{sequence}" has non-alphabetic characters') + match chain_type: + case mmcif_names.PROTEIN_CHAIN: + res_name_map = residue_names.PROTEIN_COMMON_ONE_TO_THREE + default_res_name = residue_names.UNK + case mmcif_names.RNA_CHAIN: + res_name_map = {r: r for r in residue_names.RNA_TYPES} + default_res_name = residue_names.UNK_RNA + case mmcif_names.DNA_CHAIN: + res_name_map = residue_names.DNA_COMMON_ONE_TO_TWO + default_res_name = residue_names.UNK_DNA + case _: + raise ValueError( + f'{chain_type=} not supported for FASTA format.') + return [ + res_name_map.get(one_letter_res, default_res_name) + for one_letter_res in sequence + ] + case SequenceFormat.CCD_CODES: + return sequence.strip('()').split(')(') + case SequenceFormat.LIGAND_SMILES: + ligand_id, _ = sequence.split(':', maxsplit=1) + return [ligand_id] + + +def from_sequences_and_bonds( + sequences: Sequence[str], + chain_types: Sequence[str], + sequence_formats: Sequence[SequenceFormat], + bonded_atom_pairs: Sequence[tuple[BondAtomId, BondAtomId]] | None, + ccd: chemical_components.Ccd, + name: str = 'from_sequences_and_bonds', + bond_type: str | None = None, + **constructor_args, +) -> structure.Structure: + """Returns a minimal structure for the input sequences and bonds. + + The returned structure will have at least one atom per residue. If the + residue has any bonded atoms, according to `bonded_atom_pairs`, then + all (and only) those atoms will be present for that residue. If the residue + is not involved in any bond then an arbitrary atom will be created. + + Args: + sequences: A sequence of strings, each one representing a single chain. + chain_types: The types of each chain, e.g. polypeptide(L). The n-th element + describes the n-th sequence in `sequences`. + sequence_formats: The format of each sequence. The n-th element describes + the n-th sequence in `sequences`. + bonded_atom_pairs: A sequence of bonded atom pairs. Each atom is described + as a tuple of (chain_index, res_index, atom_name), where the first two + values are 0-based indices. 
The chain_index is the index of the chain in + the `sequences` argument, and the res_index is the index of the residue in + that sequence. The atom_name is the name of the atom in the residue, e.g. + CA. If the atom is not found in the standard atoms for that residue + (according to the CCD) then an error is raised. + ccd: The chemical components dictionary. + name: A name for the returned structure. + bond_type: This type will be used for all bonds in the structure, where type + follows PDB scheme, e.g. unknown (?), hydrog, metalc, covale, disulf. + **constructor_args: These arguments are passed directly to the + structure.Structure constructor. + """ + chain_id = [] + chain_type = [] + chain_res_count = [] + res_id = [] + res_name = [] + res_atom_count = [] + atom_name = [] + atom_element = [] + chem_comp = {} + + num_bonds = len(bonded_atom_pairs or ()) + from_atom_key = np.full((num_bonds,), -1, dtype=np.int64) + dest_atom_key = np.full((num_bonds,), -1, dtype=np.int64) + + # Create map (chain_i, res_i) -> {atom_name -> (from_idxs dest_idxs)}. + # This allows quick lookup of whether a residue has any bonded atoms, and + # which bonds those atoms participate in. + bond_lookup = _create_bond_lookup(bonded_atom_pairs or ()) + + current_atom_key = 0 + for chain_i, (sequence, curr_chain_type, sequence_format) in enumerate( + zip(sequences, chain_types, sequence_formats, strict=True) + ): + current_chain_id = mmcif.int_id_to_str_id(chain_i + 1) + num_chain_residues = 0 + for res_i, full_res_name in enumerate( + expand_sequence(sequence, curr_chain_type, sequence_format) + ): + current_res_id = res_i + 1 + num_res_atoms = 0 + + # Look for bonded atoms in the bond lookup and if any are found, add + # their atom keys to the bond atom_key columns. + if bond_indices_by_atom_name := bond_lookup.get((chain_i, res_i)): + for bond_atom_name, bond_indices in bond_indices_by_atom_name.items(): + atom_name.append(bond_atom_name) + atom_element.append( + _get_atom_element( + ccd=ccd, res_name=full_res_name, atom_name=bond_atom_name + ) + ) + for from_bond_i in bond_indices.from_indices: + from_atom_key[from_bond_i] = current_atom_key + for dest_bond_i in bond_indices.dest_indices: + dest_atom_key[dest_bond_i] = current_atom_key + current_atom_key += 1 + num_res_atoms += 1 + else: + # If this residue has no bonded atoms then we need to add one atom + # like in from_sequences. + assert num_res_atoms == 0 + rep_atom_name, rep_atom_element = _get_representative_atom( + ccd=ccd, + res_name=full_res_name, + chain_type=curr_chain_type, + sequence_format=sequence_format, + ) + atom_name.append(rep_atom_name) + atom_element.append(rep_atom_element) + num_res_atoms += 1 + current_atom_key += 1 + + if sequence_format == SequenceFormat.LIGAND_SMILES: + # Sequence expect to be in the format :, + # which always corresponds to a single-residue chain. + ligand_id, ligand_smiles = sequence.split(':', maxsplit=1) + if ccd.get(ligand_id) is not None: + raise ValueError( + f'Ligand name {ligand_id} is in CCD - it is not supported to give' + ' ligands created from SMILES the same name as CCD components.' + ) + # We need to provide additional chemical components metadata for + # ligands specified via SMILES strings since they might not be in CCD. 
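For reference, the three sequence formats accepted here reduce to plain string handling. The sketch below is not part of the patch, and the 'LIG'/'CCO' values are invented for illustration:

```python
# Toy illustration (not part of the patch) of the sequence formats handled by
# expand_sequence, reduced to plain string operations.

# CCD_CODES: residues written as '(CODE)(CODE)...' are split on ')('.
ccd_sequence = '(ALA)(GLY)(MSE)'
print(ccd_sequence.strip('()').split(')('))   # ['ALA', 'GLY', 'MSE']

# LIGAND_SMILES: 'ligand_id:smiles' yields a single-residue chain named after
# the ligand_id; the SMILES part is kept for the chemical components data.
ligand_sequence = 'LIG:CCO'                   # made-up ligand id and SMILES
ligand_id, ligand_smiles = ligand_sequence.split(':', maxsplit=1)
print([ligand_id], ligand_smiles)             # ['LIG'] CCO

# FASTA: one-letter residues map to full names, unknowns fall back to UNK.
one_to_three = {'A': 'ALA', 'G': 'GLY'}       # tiny assumed subset of the map
print([one_to_three.get(c, 'UNK') for c in 'AGX'])  # ['ALA', 'GLY', 'UNK']
```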
+ _add_ligand_to_chem_comp(chem_comp, ligand_id, ligand_smiles) + + assert num_res_atoms >= 1 + res_atom_count.append(num_res_atoms) + num_chain_residues += 1 + res_id.append(current_res_id) + res_name.append(full_res_name) + + chain_id.append(current_chain_id) + chain_type.append(curr_chain_type) + chain_res_count.append(num_chain_residues) + + chem_comp_data = struc_chem_comps.ChemicalComponentsData(chem_comp) + chem_comp_data = struc_chem_comps.populate_missing_ccd_data( + ccd=ccd, + chemical_components_data=chem_comp_data, + chemical_component_ids=set(res_name), + ) + + if bonded_atom_pairs is not None: + unknown_bond_col = np.full((num_bonds,), '?', dtype=object) + if bond_type is None: + bond_type_col = unknown_bond_col + else: + bond_type_col = np.full((num_bonds,), bond_type, dtype=object) + bonds_table = bonds.Bonds( + key=np.arange(num_bonds, dtype=np.int64), + type=bond_type_col, + role=unknown_bond_col, + from_atom_key=from_atom_key, + dest_atom_key=dest_atom_key, + ) + else: + bonds_table = structure_tables.Bonds.make_empty() + + # 1 chain per sequence. + chain_key = np.arange(len(sequences), dtype=np.int64) + chain_id = np.array(chain_id, dtype=object) + chains_table = structure_tables.Chains( + key=chain_key, + id=chain_id, + type=np.array(chain_type, dtype=object), + auth_asym_id=chain_id, + entity_id=np.char.mod('%d', chain_key + 1).astype(object), + entity_desc=np.array(['.'] * len(chain_key), dtype=object), + ) + + res_key = np.arange(len(res_name), dtype=np.int64) + res_chain_key = np.repeat(chain_key, chain_res_count) + residues_table = structure_tables.Residues( + key=res_key, + chain_key=res_chain_key, + id=np.array(res_id, dtype=np.int32), + name=np.array(res_name, dtype=object), + auth_seq_id=np.char.mod('%d', res_id).astype(object), + insertion_code=np.full(len(res_name), '?', dtype=object), + ) + + num_atoms = current_atom_key + atom_float32_zeros = np.zeros(num_atoms, dtype=np.float32) + atoms_table = structure_tables.Atoms( + key=np.arange(num_atoms, dtype=np.int64), + chain_key=np.repeat(res_chain_key, res_atom_count), + res_key=np.repeat(res_key, res_atom_count), + name=np.array(atom_name, dtype=object), + element=np.array(atom_element, dtype=object), + x=atom_float32_zeros, + y=atom_float32_zeros, + z=atom_float32_zeros, + b_factor=atom_float32_zeros, + occupancy=np.ones(num_atoms, np.float32), + ) + + return structure.Structure( + name=name, + atoms=atoms_table, + residues=residues_table, + chains=chains_table, + bonds=bonds_table, + chemical_components_data=chem_comp_data, + **constructor_args, + ) + + +class _ChainResBuilder: + """Class for incrementally building chain and residue tables.""" + + def __init__( + self, + *, + chain_key_by_chain_id: Mapping[str, int], + entity_id_by_chain_id: Mapping[str, str], + chain_type_by_entity_id: Mapping[str, str], + entity_desc_by_entity_id: Mapping[str, str], + fix_mse_residues: bool, + fix_unknown_dna: bool, + ): + # Len: num_chains. + self.chain_key = [] + self.chain_id = [] + self.chain_type = [] + self.chain_auth_asym_id = [] + self.chain_entity_id = [] + self.chain_entity_desc = [] + + # Len: num_residues. 
+ self.res_key = [] + self.res_chain_key = [] + self.res_id = [] + self.res_name = [] + self.res_auth_seq_id = [] + self.res_insertion_code = [] + + self.chain_key_by_chain_id = chain_key_by_chain_id + self.entity_id_by_chain_id = entity_id_by_chain_id + self.chain_type_by_entity_id = chain_type_by_entity_id + self.entity_desc_by_entity_id = entity_desc_by_entity_id + self.key_for_res: dict[tuple[str, str, str, str], int] = {} + + self._fix_mse_residues = fix_mse_residues + self._fix_unknown_dna = fix_unknown_dna + + def add_residues( + self, + *, + chain_ids: np.ndarray, + chain_auth_asym_ids: np.ndarray, + res_ids: np.ndarray, + res_names: np.ndarray, + res_auth_seq_ids: np.ndarray, + res_ins_codes: np.ndarray, + ): + """Adds a residue (and its chain) to the tables.""" + # Create chain table data. + if chain_ids.size == 0: + return + + chain_ids_with_prev = np.concatenate( + (([self.chain_id[-1] if self.chain_id else None], chain_ids)) + ) + chain_change_mask = chain_ids_with_prev[:-1] != chain_ids_with_prev[1:] + chain_change_ids = chain_ids[chain_change_mask] + chain_keys = string_array.remap( + chain_change_ids, self.chain_key_by_chain_id, inplace=False + ) + self.chain_key.extend(chain_keys) + self.chain_id.extend(chain_change_ids) + self.chain_auth_asym_id.extend(chain_auth_asym_ids[chain_change_mask]) + chain_entity_id = string_array.remap( + chain_change_ids, self.entity_id_by_chain_id, inplace=False + ) + self.chain_entity_id.extend(chain_entity_id) + chain_type = string_array.remap( + chain_entity_id, self.chain_type_by_entity_id, inplace=False + ) + self.chain_type.extend(chain_type) + chain_entity_desc = string_array.remap( + chain_entity_id, self.entity_desc_by_entity_id, inplace=False + ) + self.chain_entity_desc.extend(chain_entity_desc) + + # Create residue table data. + num_prev_res = len(self.res_id) + res_keys = np.arange(num_prev_res, num_prev_res + len(res_ids)) + res_iter = zip( + chain_ids, + res_auth_seq_ids, + res_names, + res_ins_codes, + strict=True, + ) + key_for_res_update = { + res_unique_id: res_key + for res_key, res_unique_id in enumerate(res_iter, num_prev_res) + } + self.key_for_res.update(key_for_res_update) + self.res_key.extend(res_keys) + self.res_chain_key.extend( + string_array.remap( + chain_ids, self.chain_key_by_chain_id, inplace=False) + ) + self.res_id.extend(res_ids) + self.res_name.extend(res_names) + self.res_auth_seq_id.extend(res_auth_seq_ids) + self.res_insertion_code.extend(res_ins_codes) + + def make_chains_table(self) -> structure_tables.Chains: + """Returns the Structure chains table.""" + chain_key = np.array(self.chain_key, dtype=np.int64) + if not np.all(chain_key[:-1] <= chain_key[1:]): + # If the order is inconsistent with the atoms table, sort so that it is. 
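The reordering mentioned above relies on one stable argsort applied to every parallel column. A minimal sketch, not part of the patch:

```python
import numpy as np

# Toy illustration (not part of the patch): reordering parallel columns with a
# stable argsort so that the chain keys become non-decreasing.
chain_key = np.array([2, 0, 1])
chain_id = np.array(['C', 'A', 'B'], dtype=object)

order = np.argsort(chain_key, kind='stable')
print(chain_key[order])   # [0 1 2]
print(chain_id[order])    # ['A' 'B' 'C']
```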
+ order = np.argsort(self.chain_key, kind='stable') + return structure_tables.Chains( + key=chain_key[order], + id=np.array(self.chain_id, dtype=object)[order], + type=np.array(self.chain_type, dtype=object)[order], + auth_asym_id=np.array( + self.chain_auth_asym_id, dtype=object)[order], + entity_id=np.array(self.chain_entity_id, dtype=object)[order], + entity_desc=np.array( + self.chain_entity_desc, dtype=object)[order], + ) + return structure_tables.Chains( + key=chain_key, + id=np.array(self.chain_id, dtype=object), + type=np.array(self.chain_type, dtype=object), + auth_asym_id=np.array(self.chain_auth_asym_id, dtype=object), + entity_id=np.array(self.chain_entity_id, dtype=object), + entity_desc=np.array(self.chain_entity_desc, dtype=object), + ) + + def make_residues_table(self) -> structure_tables.Residues: + """Returns the Structure residues table.""" + res_name = np.array(self.res_name, dtype=object) + res_chain_key = np.array(self.res_chain_key, dtype=np.int64) + + if self._fix_mse_residues: + string_array.remap(res_name, mapping={'MSE': 'MET'}, inplace=True) + + if self._fix_unknown_dna: + # Remap residues from N -> DN in DNA chains only. + dna_chain_mask = ( + np.array(self.chain_type, dtype=object) == mmcif_names.DNA_CHAIN + ) + dna_chain_key = np.array(self.chain_key, dtype=object)[ + dna_chain_mask] + res_name[(res_name == 'N') & np.isin( + res_chain_key, dna_chain_key)] = 'DN' + + if not np.all(res_chain_key[:-1] <= res_chain_key[1:]): + # If the order is inconsistent with the atoms table, sort so that it is. + order = np.argsort(res_chain_key, kind='stable') + return structure_tables.Residues( + key=np.array(self.res_key, dtype=np.int64)[order], + chain_key=res_chain_key[order], + id=np.array(self.res_id, dtype=np.int32)[order], + name=res_name[order], + auth_seq_id=np.array(self.res_auth_seq_id, + dtype=object)[order], + insertion_code=np.array( + self.res_insertion_code, dtype=object)[order], + ) + return structure_tables.Residues( + key=np.array(self.res_key, dtype=np.int64), + chain_key=res_chain_key, + id=np.array(self.res_id, dtype=np.int32), + name=res_name, + auth_seq_id=np.array(self.res_auth_seq_id, dtype=object), + insertion_code=np.array(self.res_insertion_code, dtype=object), + ) + + +def _get_string_array_default(cif: mmcif.Mmcif, key: str, default: list[str]): + try: + return cif.get_array(key, dtype=object) + except KeyError: + return default + + +def _generate_required_tables_if_missing( + cif: mmcif.Mmcif, +) -> Mapping[str, Sequence[str]]: + """Generates all required tables and columns if missing.""" + update = {} + + atom_site_entities = _get_string_array_default( + cif, '_atom_site.label_entity_id', [] + ) + + # OpenMM produces files that don't have any of the tables and also have + # _atom_site.label_entity_id set to '?' for all atoms. We infer the entities + # based on the _atom_site.label_asym_id column. We start with cheaper O(1) + # checks to prevent running the expensive O(n) check on most files. + if ( + len(atom_site_entities) > 0 # pylint: disable=g-explicit-length-test + and '_entity.id' not in cif # Ignore if the _entity table exists. + and atom_site_entities[0] == '?' # Cheap check. + and set(atom_site_entities) == {'?'} # Expensive check. + ): + label_asym_ids = cif.get_array( + '_atom_site.label_asym_id', dtype=object) + atom_site_entities = [ + str(mmcif.str_id_to_int_id(cid)) for cid in label_asym_ids + ] + # Update _atom_site.label_entity_id to be consistent with the new tables. 
+ update['_atom_site.label_entity_id'] = atom_site_entities + + # Check table existence by checking the presence of its primary key. + if '_struct_asym.id' not in cif: + # Infer the _struct_asym table using the _atom_site table. + asym_ids = _get_string_array_default( + cif, '_atom_site.label_asym_id', []) + + if len(atom_site_entities) == 0 or len(asym_ids) == 0: # pylint: disable=g-explicit-length-test + raise ValueError( + 'Could not parse an mmCIF with no _struct_asym table and also no ' + '_atom_site.label_entity_id or _atom_site.label_asym_id columns.' + ) + + # Deduplicate, but keep the order intact - dict.fromkeys maintains order. + entity_id_chain_id_pairs = list( + dict.fromkeys(zip(atom_site_entities, asym_ids, strict=True)) + ) + update['_struct_asym.entity_id'] = [ + e for e, _ in entity_id_chain_id_pairs] + update['_struct_asym.id'] = [c for _, c in entity_id_chain_id_pairs] + + if '_entity.id' not in cif: + # Infer the _entity_poly and _entity tables using the _atom_site table. + residues = _get_string_array_default( + cif, '_atom_site.label_comp_id', []) + group_pdb = _get_string_array_default(cif, '_atom_site.group_PDB', []) + if '_atom_site.label_entity_id' in cif: + entities = atom_site_entities + else: + # If _atom_site.label_entity_id not set, use the asym_id -> entity_id map. + asym_to_entity = dict( + zip( + cif['_struct_asym.id'], cif['_struct_asym.entity_id'], strict=True + ) + ) + entities = string_array.remap( + cif.get_array('_atom_site.label_asym_id', dtype=object), + mapping=asym_to_entity, + ) + + entity_ids = [] + entity_types = [] + entity_poly_entity_ids = [] + entity_poly_types = [] + entity_poly_table_missing = '_entity_poly.entity_id' not in cif + for entity_id, group in itertools.groupby( + zip(entities, residues, group_pdb, strict=True), key=lambda e: e[0] + ): + _, entity_residues, entity_group_pdb = zip(*group, strict=True) + entity_type = _guess_entity_type( + chain_residues=entity_residues, atom_types=entity_group_pdb + ) + entity_ids.append(entity_id) + entity_types.append(entity_type) + + if entity_poly_table_missing and entity_type == mmcif_names.POLYMER_CHAIN: + polymer_type = mmcif_names.guess_polymer_type(entity_residues) + entity_poly_entity_ids.append(entity_id) + entity_poly_types.append(polymer_type) + + update['_entity.id'] = entity_ids + update['_entity.type'] = entity_types + if entity_poly_table_missing: + update['_entity_poly.entity_id'] = entity_poly_entity_ids + update['_entity_poly.type'] = entity_poly_types + + if '_atom_site.type_symbol' not in cif: + update['_atom_site.type_symbol'] = mmcif.get_or_infer_type_symbol(cif) + + return update + + +def _maybe_add_missing_scheme_tables( + cif: mmcif.Mmcif, + res_starts: Sequence[int], + label_asym_ids: np.ndarray, + label_seq_ids: np.ndarray, + label_comp_ids: np.ndarray, + auth_seq_ids: np.ndarray, + pdb_ins_codes: np.ndarray, +) -> Mapping[str, Sequence[str]]: + """If missing, infers the scheme tables from the _atom_site table.""" + update = {} + + required_poly_seq_scheme_cols = ( + '_pdbx_poly_seq_scheme.asym_id', + '_pdbx_poly_seq_scheme.pdb_seq_num', + '_pdbx_poly_seq_scheme.pdb_ins_code', + '_pdbx_poly_seq_scheme.seq_id', + '_pdbx_poly_seq_scheme.mon_id', + '_pdbx_poly_seq_scheme.pdb_strand_id', + ) + if not all(col in cif for col in required_poly_seq_scheme_cols): + # Create a mask for atoms where each polymer residue start. 
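The _struct_asym reconstruction above keeps the first occurrence of each (entity_id, asym_id) pair in _atom_site order via dict.fromkeys. A minimal sketch, not part of the patch:

```python
# Toy illustration (not part of the patch): order-preserving deduplication of
# (entity_id, asym_id) pairs, as used to rebuild _struct_asym from _atom_site.
entities = ['1', '1', '1', '2', '2', '1']
asym_ids = ['A', 'A', 'A', 'B', 'B', 'C']

pairs = list(dict.fromkeys(zip(entities, asym_ids)))
print(pairs)  # [('1', 'A'), ('2', 'B'), ('1', 'C')]
```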
+ entity_id_by_chain_id = dict( + zip(cif['_struct_asym.id'], + cif['_struct_asym.entity_id'], strict=True) + ) + chain_type_by_entity_id = dict( + zip(cif['_entity.id'], cif['_entity.type'], strict=True) + ) + # Remap asym ID -> entity ID. + chain_type = string_array.remap( + label_asym_ids, mapping=entity_id_by_chain_id, inplace=False + ) + # Remap entity ID -> chain type. + string_array.remap( + chain_type, mapping=chain_type_by_entity_id, inplace=True + ) + res_mask = np.zeros_like(label_seq_ids, dtype=bool) + res_mask[res_starts] = True + res_mask &= chain_type == mmcif_names.POLYMER_CHAIN + + entity_poly_seq_cols = ( + '_entity_poly_seq.entity_id', + '_entity_poly_seq.num', + '_entity_poly_seq.mon_id', + ) + if all(col in cif for col in entity_poly_seq_cols): + # Use _entity_poly_seq if available. + poly_seq_num = cif.get_array('_entity_poly_seq.num', dtype=object) + poly_seq_mon_id = cif.get_array( + '_entity_poly_seq.mon_id', dtype=object) + poly_seq_entity_id = cif.get_array( + '_entity_poly_seq.entity_id', dtype=object + ) + label_seq_id_to_auth_seq_id = dict( + zip(label_seq_ids[res_mask], + auth_seq_ids[res_mask], strict=True) + ) + scheme_pdb_seq_num = string_array.remap( + poly_seq_num, mapping=label_seq_id_to_auth_seq_id, default_value='.' + ) + label_seq_id_to_ins_code = dict( + zip(label_seq_ids[res_mask], + pdb_ins_codes[res_mask], strict=True) + ) + scheme_pdb_ins_code = string_array.remap( + poly_seq_num, mapping=label_seq_id_to_ins_code, default_value='.' + ) + + # The _entity_poly_seq table is entity-based, while _pdbx_poly_seq_scheme + # is chain-based. A single entity could mean multiple chains (asym_ids), + # we therefore need to replicate each entity for all of the chains. + scheme_asym_id = [] + select = [] + indices = np.arange(len(poly_seq_entity_id), dtype=np.int32) + for asym_id, entity_id in zip( + cif['_struct_asym.id'], cif['_struct_asym.entity_id'], strict=True + ): + entity_mask = poly_seq_entity_id == entity_id + select.extend(indices[entity_mask]) + scheme_asym_id.extend([asym_id] * sum(entity_mask)) + + scheme_pdb_strand_id = string_array.remap( + np.array(scheme_asym_id, dtype=object), + mapping=mmcif.get_internal_to_author_chain_id_map(cif), + inplace=False, + ) + + update['_pdbx_poly_seq_scheme.asym_id'] = scheme_asym_id + update['_pdbx_poly_seq_scheme.pdb_strand_id'] = scheme_pdb_strand_id + update['_pdbx_poly_seq_scheme.pdb_seq_num'] = scheme_pdb_seq_num[select] + update['_pdbx_poly_seq_scheme.pdb_ins_code'] = scheme_pdb_ins_code[select] + update['_pdbx_poly_seq_scheme.seq_id'] = poly_seq_num[select] + update['_pdbx_poly_seq_scheme.mon_id'] = poly_seq_mon_id[select] + else: + # _entity_poly_seq not available, fallback to _atom_site. 
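The _entity_poly_seq branch above replicates entity-level rows once per chain that instantiates the entity. A minimal sketch, not part of the patch, with invented IDs:

```python
import numpy as np

# Toy illustration (not part of the patch): replicating entity-level
# _entity_poly_seq rows for every chain that shares the entity, as done above
# when rebuilding _pdbx_poly_seq_scheme.
poly_seq_entity_id = np.array(['1', '1', '2'], dtype=object)   # one row per residue
poly_seq_mon_id = np.array(['ALA', 'GLY', 'HEM'], dtype=object)
struct_asym = [('A', '1'), ('B', '1'), ('C', '2')]             # (asym_id, entity_id)

select, scheme_asym_id = [], []
indices = np.arange(len(poly_seq_entity_id))
for asym_id, entity_id in struct_asym:
    entity_mask = poly_seq_entity_id == entity_id
    select.extend(indices[entity_mask].tolist())
    scheme_asym_id.extend([asym_id] * int(entity_mask.sum()))

print(scheme_asym_id)                    # ['A', 'A', 'B', 'B', 'C']
print(poly_seq_mon_id[select].tolist())  # ['ALA', 'GLY', 'ALA', 'GLY', 'HEM']
```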
+ res_asym_ids = label_asym_ids[res_mask] + res_strand_ids = string_array.remap( + array=res_asym_ids, + mapping=mmcif.get_internal_to_author_chain_id_map(cif), + inplace=False, + ) + update['_pdbx_poly_seq_scheme.asym_id'] = res_asym_ids + update['_pdbx_poly_seq_scheme.pdb_seq_num'] = auth_seq_ids[res_mask] + update['_pdbx_poly_seq_scheme.pdb_ins_code'] = pdb_ins_codes[res_mask] + update['_pdbx_poly_seq_scheme.seq_id'] = label_seq_ids[res_mask] + update['_pdbx_poly_seq_scheme.mon_id'] = label_comp_ids[res_mask] + update['_pdbx_poly_seq_scheme.pdb_strand_id'] = res_strand_ids + + required_nonpoly_scheme_cols = ( + '_pdbx_nonpoly_scheme.mon_id', + '_pdbx_nonpoly_scheme.asym_id', + '_pdbx_nonpoly_scheme.pdb_seq_num', + '_pdbx_nonpoly_scheme.pdb_ins_code', + ) + required_branch_scheme_cols = ( + '_pdbx_branch_scheme.mon_id', + '_pdbx_branch_scheme.asym_id', + '_pdbx_branch_scheme.pdb_seq_num', + ) + + # Generate _pdbx_nonpoly_scheme only if both tables are missing. + if not ( + all(col in cif for col in required_nonpoly_scheme_cols) + or all(col in cif for col in required_branch_scheme_cols) + ): + # To be strictly semantically correct, multi-residue ligands should be + # written in _pdbx_branch_scheme. However, Structure parsing handles + # correctly multi-residue ligands in _pdbx_nonpoly_scheme and the tables + # constructed here live only while parsing, hence this is unnecessary. + entity_id_by_chain_id = dict( + zip(cif['_struct_asym.id'], + cif['_struct_asym.entity_id'], strict=True) + ) + chain_type_by_entity_id = dict( + zip(cif['_entity.id'], cif['_entity.type'], strict=True) + ) + # Remap asym ID -> entity ID. + chain_type = string_array.remap( + label_asym_ids, mapping=entity_id_by_chain_id, inplace=False + ) + # Remap entity ID -> chain type. + string_array.remap( + chain_type, mapping=chain_type_by_entity_id, inplace=True + ) + res_mask = np.zeros_like(label_seq_ids, dtype=bool) + res_mask[res_starts] = True + res_mask &= chain_type != mmcif_names.POLYMER_CHAIN + + if not np.any(res_mask): + return update # Shortcut: no non-polymer residues. + + ins_codes = string_array.remap( + pdb_ins_codes[res_mask], mapping={'?': '.'}, inplace=False + ) + + update['_pdbx_nonpoly_scheme.asym_id'] = label_asym_ids[res_mask] + update['_pdbx_nonpoly_scheme.pdb_seq_num'] = auth_seq_ids[res_mask] + update['_pdbx_nonpoly_scheme.pdb_ins_code'] = ins_codes + update['_pdbx_nonpoly_scheme.mon_id'] = label_comp_ids[res_mask] + + return update + + +def _get_chain_key_by_chain_id( + resolved_chain_ids: np.ndarray, struct_asym_chain_ids: np.ndarray +) -> Mapping[str, int]: + """Returns chain key for each chain ID respecting resolved chain ordering.""" + # Check that all chain IDs found in the (potentially filtered) _atom_site + # table are present in the _struct_asym table. + unique_resolved_chain_ids = set(resolved_chain_ids) + if not unique_resolved_chain_ids.issubset(set(struct_asym_chain_ids)): + unique_resolved_chain_ids = sorted(unique_resolved_chain_ids) + unique_struct_asym_chain_ids = sorted(set(struct_asym_chain_ids)) + raise ValueError( + 'Bad mmCIF: chain IDs in _atom_site.label_asym_id ' + f'{unique_resolved_chain_ids} is not a subset of chain IDs in ' + f'_struct_asym.id {unique_struct_asym_chain_ids}.' + ) + + resolved_mask = string_array.isin( + struct_asym_chain_ids, unique_resolved_chain_ids + ) + # For all resolved chains, use the _atom_site order they appear in. E.g. 
+ # resolved_chain_ids = [B A E D F] + # struct_asym_chain_ids = [A B C D E F] + # consistent_chain_order = [B A C E D F] + # chain_keys = [0 1 2 3 4 5] + consistent_chain_order = struct_asym_chain_ids.copy() + consistent_chain_order[resolved_mask] = resolved_chain_ids + return dict(zip(consistent_chain_order, range(len(struct_asym_chain_ids)))) + + +def get_tables( + cif: mmcif.Mmcif, + fix_mse_residues: bool, + fix_arginines: bool, + fix_unknown_dna: bool, + include_water: bool, + include_other: bool, + model_id: str, +) -> tuple[ + structure_tables.Chains, structure_tables.Residues, structure_tables.Atoms +]: + """Returns chain, residue, and atom tables from a parsed mmcif. + + Args: + cif: A parsed mmcif.Mmcif. + fix_mse_residues: See from_mmcif. + fix_arginines: See from_mmcif. + fix_unknown_dna: See from_mmcif. + include_water: See from_mmcif. + include_other: See from_mmcif. + model_id: A string defining which model ID to use. If set, only coordinates, + b-factors and occupancies for the given model are returned. If empty, + coordinates, b-factors and occupanciesall for models are returned with a + leading dimension of num_models. Note that the model_id argument in + from_mmcif is an integer and has slightly different use (see from_mmcif). + """ + # Add any missing tables and columns we require for parsing. + if cif_update := _generate_required_tables_if_missing(cif): + cif = cif.copy_and_update(cif_update) + + # Resolve alt-locs, selecting only a single option for each residue. Also + # computes the layout, which defines where chain and residue boundaries are. + atom_site_all_models, layout = mmcif_utils.filter( + cif, + include_nucleotides=True, + include_ligands=True, + include_water=include_water, + include_other=include_other, + model_id=model_id, + ) + atom_site_first_model = atom_site_all_models[0] + + # Get atom information from the _atom_site table. + def _first_model_string_array(col: str) -> np.ndarray: + return cif.get_array(col, dtype=object, gather=atom_site_first_model) + + def _requested_models_float_array(col: str) -> np.ndarray: + if not model_id: + # Return data for all models with a leading dimension of num_models. + return cif.get_array(col, dtype=np.float32, gather=atom_site_all_models) + else: + # Return data only for the single requested model. + return cif.get_array(col, dtype=np.float32, gather=atom_site_first_model) + + # These columns are the same for all models, fetch them just for the 1st one. + label_comp_ids = _first_model_string_array('_atom_site.label_comp_id') + label_asym_ids = _first_model_string_array('_atom_site.label_asym_id') + label_seq_ids = _first_model_string_array('_atom_site.label_seq_id') + label_atom_ids = _first_model_string_array('_atom_site.label_atom_id') + if '_atom_site.auth_seq_id' in cif: + auth_seq_ids = _first_model_string_array('_atom_site.auth_seq_id') + else: + # auth_seq_id unset, fallback to label_seq_id. + auth_seq_ids = label_seq_ids + type_symbols = _first_model_string_array('_atom_site.type_symbol') + pdbx_pdb_ins_codes = _first_model_string_array( + '_atom_site.pdbx_PDB_ins_code') + + # These columns are different for all models, fetch them as requested. 
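The per-model gathering described above amounts to indexing the flat _atom_site columns with either every model's row indices or only the first model's. A minimal NumPy sketch, not part of the patch:

```python
import numpy as np

# Toy illustration (not part of the patch): with no model_id requested the
# coordinate arrays keep a leading num_models dimension; with a specific model
# only that model's rows are gathered.
cartn_x = np.arange(12, dtype=np.float32)      # flat _atom_site.Cartn_x column
all_models = np.array([[0, 1, 2, 3],           # row indices of model 1
                       [4, 5, 6, 7],           # row indices of model 2
                       [8, 9, 10, 11]])        # row indices of model 3
first_model = all_models[0]

print(cartn_x[all_models].shape)   # (3, 4): num_models x num_atoms
print(cartn_x[first_model].shape)  # (4,): single requested model
```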
+ atom_x = _requested_models_float_array('_atom_site.Cartn_x') + atom_y = _requested_models_float_array('_atom_site.Cartn_y') + atom_z = _requested_models_float_array('_atom_site.Cartn_z') + atom_b_factor = _requested_models_float_array('_atom_site.B_iso_or_equiv') + atom_occupancy = _requested_models_float_array('_atom_site.occupancy') + + # Make sure the scheme (residue) tables exist in case they are not present. + if cif_update := _maybe_add_missing_scheme_tables( + cif, + res_starts=layout.residue_starts(), + label_asym_ids=label_asym_ids, + label_seq_ids=label_seq_ids, + label_comp_ids=label_comp_ids, + auth_seq_ids=auth_seq_ids, + pdb_ins_codes=pdbx_pdb_ins_codes, + ): + cif = cif.copy_and_update(cif_update) + + # Fix common issues found in mmCIF files, like swapped arginine NH atoms. + mmcif_utils.fix_residues( + layout, + comp_id=label_comp_ids, + atom_id=label_atom_ids, + atom_x=atom_x[0] if not model_id else atom_x, + atom_y=atom_y[0] if not model_id else atom_y, + atom_z=atom_z[0] if not model_id else atom_z, + fix_arg=fix_arginines, + ) + + # Get keys for chains in the order they appear in _atom_site while also + # dealing with empty chains. + resolved_chain_ids = label_asym_ids[layout.chain_starts()] + struct_asym_chain_ids = cif.get_array('_struct_asym.id', dtype=object) + + chain_key_by_chain_id = _get_chain_key_by_chain_id( + resolved_chain_ids=resolved_chain_ids, + struct_asym_chain_ids=struct_asym_chain_ids, + ) + entity_id_by_chain_id = dict( + zip(struct_asym_chain_ids, cif['_struct_asym.entity_id'], strict=True) + ) + entity_description = cif.get( + '_entity.pdbx_description', ['?'] * len(cif['_entity.id']) + ) + entity_desc_by_entity_id = dict( + zip(cif['_entity.id'], entity_description, strict=True) + ) + chain_type_by_entity_id = mmcif.get_chain_type_by_entity_id(cif) + auth_asym_id_by_chain_id = mmcif.get_internal_to_author_chain_id_map(cif) + + chain_res_builder = _ChainResBuilder( + chain_key_by_chain_id=chain_key_by_chain_id, + entity_id_by_chain_id=entity_id_by_chain_id, + chain_type_by_entity_id=chain_type_by_entity_id, + entity_desc_by_entity_id=entity_desc_by_entity_id, + fix_mse_residues=fix_mse_residues, + fix_unknown_dna=fix_unknown_dna, + ) + + # Collect data for polymer chain and residue tables. _pdbx_poly_seq_scheme is + # guaranteed to be present thanks to _maybe_add_missing_scheme_tables. + def _get_poly_seq_scheme_col(col: str) -> np.ndarray: + return cif.get_array(key=f'_pdbx_poly_seq_scheme.{col}', dtype=object) + + poly_seq_asym_ids = _get_poly_seq_scheme_col('asym_id') + poly_seq_pdb_seq_nums = _get_poly_seq_scheme_col('pdb_seq_num') + poly_seq_seq_ids = _get_poly_seq_scheme_col('seq_id') + poly_seq_mon_ids = _get_poly_seq_scheme_col('mon_id') + poly_seq_pdb_strand_ids = _get_poly_seq_scheme_col('pdb_strand_id') + poly_seq_pdb_ins_codes = _get_poly_seq_scheme_col('pdb_ins_code') + string_array.remap( + poly_seq_pdb_ins_codes, mapping=_INSERTION_CODE_REMAP, inplace=True + ) + + # We resolved alt-locs earlier for the atoms table. In cases of heterogeneous + # residues (a residue with an alt-loc that is of different residue type), we + # need to also do the same resolution in the residues table. Compute a mask + # for the residues that were selected in the atoms table. 
+ poly_seq_mask = mmcif_utils.selected_polymer_residue_mask( + layout=layout, + atom_site_label_asym_ids=label_asym_ids[layout.residue_starts()], + atom_site_label_seq_ids=label_seq_ids[layout.residue_starts()], + atom_site_label_comp_ids=label_comp_ids[layout.residue_starts()], + poly_seq_asym_ids=poly_seq_asym_ids, + poly_seq_seq_ids=poly_seq_seq_ids, + poly_seq_mon_ids=poly_seq_mon_ids, + ) + + if not include_other and poly_seq_mask: + # Mask filtered-out residues so that they are not treated as missing. + # Instead, we don't want them included in the chains/residues tables at all. + keep_mask = string_array.remap( + poly_seq_asym_ids, + mapping={cid: True for cid in resolved_chain_ids}, + default_value=False, + inplace=False, + ).astype(bool) + poly_seq_mask &= keep_mask + + chain_res_builder.add_residues( + chain_ids=poly_seq_asym_ids[poly_seq_mask], + chain_auth_asym_ids=poly_seq_pdb_strand_ids[poly_seq_mask], + res_ids=poly_seq_seq_ids[poly_seq_mask].astype(np.int32), + res_names=poly_seq_mon_ids[poly_seq_mask], + res_auth_seq_ids=poly_seq_pdb_seq_nums[poly_seq_mask], + res_ins_codes=poly_seq_pdb_ins_codes[poly_seq_mask], + ) + + # Collect data for ligand chain and residue tables. _pdbx_nonpoly_scheme + # could be empty/unset if there are only branched ligands. + def _get_nonpoly_scheme_col(col: str) -> np.ndarray: + key = f'_pdbx_nonpoly_scheme.{col}' + if f'_pdbx_nonpoly_scheme.{col}' in cif: + return cif.get_array(key=key, dtype=object) + else: + return np.array([], dtype=object) + + nonpoly_asym_ids = _get_nonpoly_scheme_col('asym_id') + nonpoly_auth_seq_ids = _get_nonpoly_scheme_col('pdb_seq_num') + nonpoly_pdb_ins_codes = _get_nonpoly_scheme_col('pdb_ins_code') + nonpoly_mon_ids = _get_nonpoly_scheme_col('mon_id') + nonpoly_auth_asym_id = string_array.remap( + nonpoly_asym_ids, mapping=auth_asym_id_by_chain_id, inplace=False + ) + + def _get_branch_scheme_col(col: str) -> np.ndarray: + key = f'_pdbx_branch_scheme.{col}' + if f'_pdbx_branch_scheme.{col}' in cif: + return cif.get_array(key=key, dtype=object) + else: + return np.array([], dtype=object) + + branch_asym_ids = _get_branch_scheme_col('asym_id') + branch_auth_seq_ids = _get_branch_scheme_col('pdb_seq_num') + branch_pdb_ins_codes = _get_branch_scheme_col('pdb_ins_code') + branch_mon_ids = _get_branch_scheme_col('mon_id') + branch_auth_asym_id = string_array.remap( + branch_asym_ids, mapping=auth_asym_id_by_chain_id, inplace=False + ) + + if branch_asym_ids.size > 0 and branch_pdb_ins_codes.size == 0: + branch_pdb_ins_codes = np.array( + ['.'] * branch_asym_ids.size, dtype=object) + + # Compute the heterogeneous residue masks as above, this time for ligands. 
+ nonpoly_mask, branch_mask = mmcif_utils.selected_ligand_residue_mask( + layout=layout, + atom_site_label_asym_ids=label_asym_ids[layout.residue_starts()], + atom_site_label_seq_ids=label_seq_ids[layout.residue_starts()], + atom_site_auth_seq_ids=auth_seq_ids[layout.residue_starts()], + atom_site_label_comp_ids=label_comp_ids[layout.residue_starts()], + atom_site_pdbx_pdb_ins_codes=pdbx_pdb_ins_codes[layout.residue_starts( + )], + nonpoly_asym_ids=nonpoly_asym_ids, + nonpoly_auth_seq_ids=nonpoly_auth_seq_ids, + nonpoly_pdb_ins_codes=nonpoly_pdb_ins_codes, + nonpoly_mon_ids=nonpoly_mon_ids, + branch_asym_ids=branch_asym_ids, + branch_auth_seq_ids=branch_auth_seq_ids, + branch_pdb_ins_codes=branch_pdb_ins_codes, + branch_mon_ids=branch_mon_ids, + ) + + if not include_water: + if nonpoly_mask: + nonpoly_mask &= (nonpoly_mon_ids != 'HOH') & ( + nonpoly_mon_ids != 'DOD') + if branch_mask: + # Fix for bad mmCIFs that have water in the branch scheme table. + branch_mask &= (branch_mon_ids != 'HOH') & ( + branch_mon_ids != 'DOD') + + string_array.remap( + pdbx_pdb_ins_codes, mapping=_INSERTION_CODE_REMAP, inplace=True + ) + string_array.remap( + nonpoly_pdb_ins_codes, mapping=_INSERTION_CODE_REMAP, inplace=True + ) + string_array.remap( + branch_pdb_ins_codes, mapping=_INSERTION_CODE_REMAP, inplace=True + ) + + def _ligand_residue_ids(chain_ids: np.ndarray) -> np.ndarray: + """Computes internal residue ID for ligand residues that don't have it.""" + + # E.g. chain_ids=[A, A, A, B, C, C, D, D, D] -> [1, 2, 3, 1, 1, 2, 1, 2, 3]. + indices = np.arange(chain_ids.size, dtype=np.int32) + return (indices + 1) - np.maximum.accumulate( + indices * (chain_ids != np.roll(chain_ids, 1)) + ) + + branch_residue_ids = _ligand_residue_ids(branch_asym_ids[branch_mask]) + nonpoly_residue_ids = _ligand_residue_ids(nonpoly_asym_ids[nonpoly_mask]) + + chain_res_builder.add_residues( + chain_ids=branch_asym_ids[branch_mask], + chain_auth_asym_ids=branch_auth_asym_id[branch_mask], + res_ids=branch_residue_ids, + res_names=branch_mon_ids[branch_mask], + res_auth_seq_ids=branch_auth_seq_ids[branch_mask], + res_ins_codes=branch_pdb_ins_codes[branch_mask], + ) + + chain_res_builder.add_residues( + chain_ids=nonpoly_asym_ids[nonpoly_mask], + chain_auth_asym_ids=nonpoly_auth_asym_id[nonpoly_mask], + res_ids=nonpoly_residue_ids, + res_names=nonpoly_mon_ids[nonpoly_mask], + res_auth_seq_ids=nonpoly_auth_seq_ids[nonpoly_mask], + res_ins_codes=nonpoly_pdb_ins_codes[nonpoly_mask], + ) + + chains = chain_res_builder.make_chains_table() + residues = chain_res_builder.make_residues_table() + + # Construct foreign residue keys for the atoms table. + res_ends = np.array(layout.residues(), dtype=np.int32) + res_starts = np.array(layout.residue_starts(), dtype=np.int32) + res_lengths = res_ends - res_starts + + # Check just for HOH, DOD can be part e.g. of hydroxycysteine. + if include_water: + res_chain_types = chains.apply_array_to_column( + column_name='type', arr=residues.chain_key + ) + water_mask = res_chain_types != mmcif_names.WATER + if 'HOH' in set(residues.name[water_mask]): + raise ValueError( + 'Bad mmCIF file: non-water entity has water molecules.') + else: + # Include resolved and unresolved residues. 
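The vectorised counter restart inside _ligand_residue_ids above can be checked on a small input. A minimal sketch, not part of the patch:

```python
import numpy as np

# Toy illustration (not part of the patch): restarting a 1-based residue
# counter at every chain boundary without a Python loop.
chain_ids = np.array(['A', 'A', 'A', 'B', 'C', 'C'], dtype=object)
indices = np.arange(chain_ids.size, dtype=np.int32)
res_ids = (indices + 1) - np.maximum.accumulate(
    indices * (chain_ids != np.roll(chain_ids, 1))
)
print(res_ids)  # [1 2 3 1 1 2]
```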
+ if 'HOH' in set(residues.name) | set(label_comp_ids[res_starts]): + raise ValueError( + 'Bad mmCIF file: non-water entity has water molecules.') + + atom_chain_key = string_array.remap( + label_asym_ids, mapping=chain_res_builder.chain_key_by_chain_id + ).astype(int) + + # If any of the residue lookups failed, the mmCIF is corrupted. + try: + atom_res_key_per_res = string_array.remap_multiple( + ( + label_asym_ids[res_starts], + auth_seq_ids[res_starts], + label_comp_ids[res_starts], + pdbx_pdb_ins_codes[res_starts], + ), + mapping=chain_res_builder.key_for_res, + ) + except KeyError as e: + raise ValueError( + 'Lookup for the following atom from the _atom_site table failed: ' + f'(atom_id, auth_seq_id, res_name, ins_code)={e}. This is ' + 'likely due to a known issue with some multi-model mmCIFs that only ' + 'match the first model in _atom_site table to the _pdbx_poly_scheme, ' + '_pdbx_nonpoly_scheme, or _pdbx_branch_scheme tables.' + ) from e + + # The residue ID will be shared for all atoms within that residue. + atom_res_key = np.repeat(atom_res_key_per_res, repeats=res_lengths) + + if fix_mse_residues: + met_residues_mask = (residues.name == 'MET')[atom_res_key] + unfixed_mse_selenium_mask = met_residues_mask & ( + label_atom_ids == 'SE') + label_atom_ids[unfixed_mse_selenium_mask] = 'SD' + type_symbols[unfixed_mse_selenium_mask] = 'S' + + atoms = structure_tables.Atoms( + key=atom_site_first_model, + chain_key=atom_chain_key, + res_key=atom_res_key, + name=label_atom_ids, + element=type_symbols, + x=atom_x, + y=atom_y, + z=atom_z, + b_factor=atom_b_factor, + occupancy=atom_occupancy, + ) + + return chains, residues, atoms + + +def from_atom_arrays( + *, + res_id: np.ndarray, + name: str = 'unset', + release_date: datetime.date | None = None, + resolution: float | None = None, + structure_method: str | None = None, + all_residues: Mapping[str, Sequence[tuple[str, int]]] | None = None, + bioassembly_data: bioassemblies.BioassemblyData | None = None, + chemical_components_data: ( + struc_chem_comps.ChemicalComponentsData | None + ) = None, + bond_table: structure_tables.Bonds | None = None, + chain_id: np.ndarray | None = None, + chain_type: np.ndarray | None = None, + res_name: np.ndarray | None = None, + atom_key: np.ndarray | None = None, + atom_name: np.ndarray | None = None, + atom_element: np.ndarray | None = None, + atom_x: np.ndarray | None = None, + atom_y: np.ndarray | None = None, + atom_z: np.ndarray | None = None, + atom_b_factor: np.ndarray | None = None, + atom_occupancy: np.ndarray | None = None, +) -> structure.Structure: + """Returns a Structure constructed from atom array level data. + + All fields except name and, res_id are optional, all array fields consist of a + value for each atom in the structure - so residue and chain values should hold + the same value for each atom in the chain or residue. Fields which are not + defined are filled with default values. + + Validation is performed by the Structure constructor where possible - but + author_naming scheme and all_residues must be checked in this function. + + It is not possible to construct structures with chains that do not contain + any resolved residues using this function. If this is necessary, use the + structure.Structure constructor directly. + + Args: + res_id: Integer array of shape [num_atom]. The unique residue identifier for + each residue. mmCIF field - _atom_site.label_seq_id. + name: The name of the structure. E.g. a PDB ID. 
+ release_date: The release date of the structure as a `datetime.date`. + resolution: The resolution of the structure in Angstroms. + structure_method: The method used to solve this structure's coordinates. + all_residues: An optional mapping from each chain ID (i.e. label_asym_id) to + a sequence of (label_comp_id, label_seq_id) tuples, one per residue. This + can contain residues that aren't present in the atom arrays. This is + common in experimental data where some residues are not resolved but are + known to be present. + bioassembly_data: An optional instance of bioassembly.BioassemblyData. If + present then a new Structure representing a specific bioassembly can be + extracted using `Structure.generate_bioassembly(assembly_id)`. + chemical_components_data: An optional instance of ChemicalComponentsData. + Its content will be used for providing metadata about chemical components + in this Structure instance. If not specified information will be retrieved + from the standard chemical component dictionary (CCD, for more details see + https://www.wwpdb.org/data/ccd). + bond_table: A table representing manually-specified bonds. This corresponds + to the _struct_conn table in an mmCIF. Atoms are identified by their key, + as specified by the atom_key column. If this table is provided then the + atom_key column must also be defined. + chain_id: String array of shape [num_atom] of unique chain identifiers. + mmCIF field - _atom_site.label_asym_id. + chain_type: String array of shape [num_atom]. The molecular type of the + current chain (e.g. polyribonucleotide). mmCIF field - _entity_poly.type + OR _entity.type (for non-polymers). + res_name: String array of shape [num_atom].. The name of each residue, + typically a 3 letter string for polypeptides or 1-2 letter strings for + polynucleotides. mmCIF field - _atom_site.label_comp_id. + atom_key: A unique sorted integer array, used only by the bonds table to + identify the atoms participating in each bond. If the bonds table is + specified then this column must be non-None. + atom_name: String array of shape [num_atom]. The name of each atom (e.g CA, + O2', etc.). mmCIF field - _atom_site.label_atom_id. + atom_element: String array of shape [num_atom]. The element type of each + atom (e.g. C, O, N, etc.). mmCIF field - _atom_site.type_symbol. + atom_x: Float array of shape [..., num_atom] of atom x coordinates. May have + arbitrary leading dimensions, provided that these are consistent across + all coordinate fields. + atom_y: Float array of shape [..., num_atom] of atom y coordinates. May have + arbitrary leading dimensions, provided that these are consistent across + all coordinate fields. + atom_z: Float array of shape [..., num_atom] of atom z coordinates. May have + arbitrary leading dimensions, provided that these are consistent across + all coordinate fields. + atom_b_factor: Float array of shape [..., num_atom] or [num_atom] of atom + b-factors or equivalent. If there are no extra leading dimensions then + these values are assumed to apply to all coordinates for a given atom. If + there are leading dimensions then these must match those used by the + coordinate fields. + atom_occupancy: Float array of shape [..., num_atom] or [num_atom] of atom + occupancies or equivalent. If there are no extra leading dimensions then + these values are assumed to apply to all coordinates for a given atom. If + there are leading dimensions then these must match those used by the + coordinate fields. 
+ """ + + atoms, residues, chains = structure_tables.tables_from_atom_arrays( + res_id=res_id, + all_residues=all_residues, + chain_id=chain_id, + chain_type=chain_type, + res_name=res_name, + atom_key=atom_key, + atom_name=atom_name, + atom_element=atom_element, + atom_x=atom_x, + atom_y=atom_y, + atom_z=atom_z, + atom_b_factor=atom_b_factor, + atom_occupancy=atom_occupancy, + ) + + return structure.Structure( + name=name, + release_date=release_date, + resolution=resolution, + structure_method=structure_method, + bioassembly_data=bioassembly_data, + chemical_components_data=chemical_components_data, + atoms=atoms, + chains=chains, + residues=residues, + bonds=bond_table or structure_tables.Bonds.make_empty(), + ) + + +def _guess_entity_type( + chain_residues: Collection[str], atom_types: Collection[str] +) -> str: + """Guess the entity type (polymer/non-polymer/water) based on residues/atoms. + + We treat both arguments as unordered collections since we care only whether + all elements satisfy come conditions. The chain_residues can be either + grouped by residue (length num_res), or it can be raw (length num_atoms). + Atom type is unique for each atom in a residue, so don't group atom_types. + + Args: + chain_residues: A sequence of full residue name (1-letter for DNA, 2-letters + for RNA, 3 for protein). The _atom_site.label_comp_id column in mmCIF. + atom_types: Atom type: ATOM or HETATM. The _atom_site.group_PDB column in + mmCIF. + + Returns: + One of polymer/non-polymer/water based on the following criteria: + * If all atoms are HETATMs and all residues are water -> water. + * If all atoms are HETATMs and not all residues are water -> non-polymer. + * Otherwise -> polymer. + """ + if not chain_residues or not atom_types: + raise ValueError( + f'chain_residues (len {len(chain_residues)}) and atom_types (len ' + f'{len(atom_types)}) must be both non-empty. Got: {chain_residues=} ' + f'and {atom_types=}' + ) + + if all(a == 'HETATM' for a in atom_types): + if all(c in residue_names.WATER_TYPES for c in chain_residues): + return mmcif_names.WATER + return mmcif_names.NON_POLYMER_CHAIN + return mmcif_names.POLYMER_CHAIN diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/sterics.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/sterics.py new file mode 100644 index 0000000000000000000000000000000000000000..55c7c1783a0360a8f19311b23b47f2ca8a275c41 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/sterics.py @@ -0,0 +1,142 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Functions relating to spatial locations of atoms within a structure.""" + +from collections.abc import Collection, Sequence + +from alphafold3 import structure +from alphafold3.structure import mmcif +import numpy as np +import scipy + + +def _make_atom_has_clash_mask( + kd_query_result: np.ndarray, + struct: structure.Structure, + ignore_chains: Collection[str], +) -> np.ndarray: + """Returns a boolean NumPy array representing whether each atom has a clash. + + Args: + kd_query_result: NumPy array containing N-atoms arrays, each array + containing indices to atoms that clash with the N'th atom. + struct: Structure over which clashes were detected. + ignore_chains: Collection of chains that should not be considered clashing. + A boolean NumPy array of length N atoms. + """ + atom_is_clashing = np.zeros((struct.num_atoms,), dtype=bool) + for atom_index, clashes in enumerate(kd_query_result): + chain_i = struct.chain_id[atom_index] + if chain_i in ignore_chains: + continue + islig_i = struct.is_ligand_mask[atom_index] + for clashing_atom_index in clashes: + chain_c = struct.chain_id[clashing_atom_index] + if chain_c in ignore_chains: + continue + islig_c = struct.is_ligand_mask[clashing_atom_index] + if ( + clashing_atom_index == atom_index + or chain_i == chain_c + or islig_i != islig_c + ): + # Ignore clashes within chain or between ligand and polymer. + continue + atom_is_clashing[atom_index] = True + return atom_is_clashing + + +def find_clashing_chains( + struct: structure.Structure, + clash_thresh_angstrom: float = 1.7, + clash_thresh_fraction: float = 0.3, +) -> Sequence[str]: + """Finds chains that clash with others. + + Clashes are defined by polymer backbone atoms and all ligand atoms. + Ligand-polymer clashes are not dropped. + + Will not find clashes if all coordinates are 0. Coordinates are all 0s if + the structure is generated from sequences only, as done for inference in + dendro for example. + + Args: + struct: The structure defining the chains and atom positions. + clash_thresh_angstrom: Below this distance, atoms are considered clashing. + clash_thresh_fraction: Chains with more than this fraction of their atoms + considered clashing will be dropped. This value should be in the range (0, + 1]. + + Returns: + A sequence of chain ids for chains that clash. + + Raises: + ValueError: If `clash_thresh_fraction` is not in range (0,1]. + """ + if not 0 < clash_thresh_fraction <= 1: + raise ValueError('clash_thresh_fraction must be in range (0,1]') + + struct_backbone = struct.filter_polymers_to_single_atom_per_res() + if struct_backbone.num_chains == 0: + return [] + + # If the coordinates are all 0, do not search for clashes. + if not np.any(struct_backbone.coords): + return [] + + coord_kdtree = scipy.spatial.cKDTree(struct_backbone.coords) + + # For each atom coordinate, find all atoms within the clash thresh radius. + clashing_per_atom = coord_kdtree.query_ball_point( + struct_backbone.coords, r=clash_thresh_angstrom + ) + chain_ids = struct_backbone.chains + if struct_backbone.atom_occupancy is not None: + chain_occupancy = np.array([ + np.mean(struct_backbone.atom_occupancy[start:end]) + for start, end in struct_backbone.iter_chain_ranges() + ]) + else: + chain_occupancy = None + + # Remove chains until no more significant clashing. 
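The clash search above queries a k-d tree once per atom using the clash radius. A minimal sketch, not part of the patch, of the scipy call on three points:

```python
import numpy as np
from scipy.spatial import cKDTree

# Toy illustration (not part of the patch): for each coordinate, find all
# coordinates within the clash radius, as done above for the single-atom-per-
# residue backbone representation.
coords = np.array([[0.0, 0.0, 0.0],
                   [1.0, 0.0, 0.0],
                   [5.0, 0.0, 0.0]])
tree = cKDTree(coords)
clashing_per_atom = tree.query_ball_point(coords, r=1.7)
# Each entry lists the indices within 1.7 angstroms, including the atom itself.
print(list(clashing_per_atom))  # [[0, 1], [0, 1], [2]]
```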
+ chains_to_remove = set() + for _ in range(len(chain_ids)): + # Calculate maximally clashing. + atom_has_clash = _make_atom_has_clash_mask( + clashing_per_atom, struct_backbone, chains_to_remove + ) + clashes_per_chain = np.array([ + atom_has_clash[start:end].mean() + for start, end in struct_backbone.iter_chain_ranges() + ]) + max_clash = np.max(clashes_per_chain) + if max_clash <= clash_thresh_fraction: + # None of the remaining chains exceed the clash fraction threshold, so + # we can exit. + break + + # Greedily remove worst with the lowest occupancy. + most_clashes = np.nonzero(clashes_per_chain == max_clash)[0] + if chain_occupancy is not None: + occupancy_clashing = chain_occupancy[most_clashes] + last_lowest_occupancy = ( + len(occupancy_clashing) - + np.argmin(occupancy_clashing[::-1]) - 1 + ) + worst_and_last = most_clashes[last_lowest_occupancy] + else: + worst_and_last = most_clashes[-1] + + chains_to_remove.add(chain_ids[worst_and_last]) + + return sorted(chains_to_remove, key=mmcif.str_id_to_int_id) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure.py new file mode 100644 index 0000000000000000000000000000000000000000..8d66b52ec832ada58b885cca2c042ca8beafcafe --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure.py @@ -0,0 +1,3180 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Structure class for representing and processing molecular structures.""" + +import collections +from collections.abc import Callable, Collection, Iterable, Iterator, Mapping, Sequence, Set +import dataclasses +import datetime +import enum +import functools +import itertools +import typing +from typing_extensions import Any, ClassVar, Final, Literal, NamedTuple, Self, TypeAlias, TypeVar +import numpy as np +from alphafold3.constants import atom_types +from alphafold3.constants import chemical_components +from alphafold3.constants import mmcif_names +from alphafold3.constants import residue_names +from alphafold3.cpp import membership +from alphafold3.cpp import string_array +from alphafold3.structure import bioassemblies +from alphafold3.structure import chemical_components as struct_chem_comps +from alphafold3.structure import mmcif +from alphafold3.structure import structure_tables +from alphafold3.structure import table + +# Controls the default number of decimal places for coordinates when writing to +# mmCIF. 
+_COORDS_DECIMAL_PLACES: Final[int] = 3 + + +@enum.unique +class CascadeDelete(enum.Enum): + NONE = 0 + FULL = 1 + CHAINS = 2 + + +# See www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions +class _UnsetSentinel(enum.Enum): + UNSET = object() + + +_UNSET = _UnsetSentinel.UNSET + + +class Bond(NamedTuple): + """Describes a bond between two atoms.""" + + from_atom: Mapping[str, str | int | float | np.ndarray] + dest_atom: Mapping[str, str | int | float | np.ndarray] + bond_info: Mapping[str, str | int] + + +class MissingAtomError(Exception): + """Error raised when an atom is missing during alignment.""" + + +class MissingAuthorResidueIdError(Exception): + """Raised when author naming data is missing for a residue. + + This can occur in certain edge cases where missing residue data is provided + without also providing author IDs for those missing residues. + """ + + +# AllResidues is a mapping from label_asym_id to a sequence of (label_comp_id, +# label_seq_id) pairs. These represent the full sequence including residues +# that might be missing (e.g. unresolved residues in X-ray data). +AllResidues: TypeAlias = Mapping[str, Sequence[tuple[str, int]]] +AuthorNamingScheme: TypeAlias = structure_tables.AuthorNamingScheme + + +# External residue ID given to missing residues that don't have an ID +# already provided. In mmCIFs this data is found in _pdbx_poly_seq_scheme. +MISSING_AUTH_SEQ_ID: Final[str] = '.' + + +# Maps from structure fields to column names in the relevant table. +CHAIN_FIELDS: Final[Mapping[str, str]] = { + 'chain_id': 'id', + 'chain_type': 'type', + 'chain_auth_asym_id': 'auth_asym_id', + 'chain_entity_id': 'entity_id', + 'chain_entity_desc': 'entity_desc', +} + + +RESIDUE_FIELDS: Final[Mapping[str, str]] = { + 'res_id': 'id', + 'res_name': 'name', + 'res_auth_seq_id': 'auth_seq_id', + 'res_insertion_code': 'insertion_code', +} + +ATOM_FIELDS: Final[Mapping[str, str]] = { + 'atom_name': 'name', + 'atom_element': 'element', + 'atom_x': 'x', + 'atom_y': 'y', + 'atom_z': 'z', + 'atom_b_factor': 'b_factor', + 'atom_occupancy': 'occupancy', + 'atom_key': 'key', +} + +# Fields in structure. +ARRAY_FIELDS = frozenset({ + 'atom_b_factor', + 'atom_element', + 'atom_key', + 'atom_name', + 'atom_occupancy', + 'atom_x', + 'atom_y', + 'atom_z', + 'chain_id', + 'chain_type', + 'res_id', + 'res_name', +}) + +GLOBAL_FIELDS = frozenset({ + 'name', + 'release_date', + 'resolution', + 'structure_method', + 'bioassembly_data', + 'chemical_components_data', +}) + +# Fields which can be updated in copy_and_update. +_UPDATEABLE_FIELDS: Final[Set[str]] = frozenset({ + 'all_residues', + 'atom_b_factor', + 'atom_element', + 'atom_key', + 'atom_name', + 'atom_occupancy', + 'atom_x', + 'atom_y', + 'atom_z', + 'bioassembly_data', + 'bonds', + 'chain_id', + 'chain_type', + 'chemical_components_data', + 'name', + 'release_date', + 'res_id', + 'res_name', + 'resolution', + 'structure_method', +}) + + +def fix_non_standard_polymer_residues( + res_names: np.ndarray, chain_type: str +) -> np.ndarray: + """Remaps residue names to the closest standard protein/RNA/DNA residue. + + If residue name is already a standard type, it is not altered. + If a match cannot be found, returns 'UNK' for protein chainresidues and 'N' + for RNA/DNA chain residue. + + Args: + res_names: A numpy array of string residue names (CCD monomer codes). E.g. + 'ARG' (protein), 'DT' (DNA), 'N' (RNA). + chain_type: The type of the chain, must be PROTEIN_CHAIN, RNA_CHAIN or + DNA_CHAIN. 
+ + Returns: + An array remapped so that its elements are all from + PROTEIN_TYPES_WITH_UNKNOWN | RNA_TYPES | DNA_TYPES | {'N'}. + + Raises: + ValueError: If chain_type not in PEPTIDE_CHAIN_TYPES or + {OTHER_CHAIN, RNA_CHAIN, DNA_CHAIN, DNA_RNA_HYBRID_CHAIN}. + """ + # Map to one letter code, then back to common res_names. + one_letter_codes = string_array.remap( + res_names, mapping=residue_names.CCD_NAME_TO_ONE_LETTER, default_value='X' + ) + + if ( + chain_type in mmcif_names.PEPTIDE_CHAIN_TYPES + or chain_type == mmcif_names.OTHER_CHAIN + ): + mapping = residue_names.PROTEIN_COMMON_ONE_TO_THREE + default_value = 'UNK' + elif chain_type == mmcif_names.RNA_CHAIN: + # RNA has single-letter CCD monomer codes. + mapping = {r: r for r in residue_names.RNA_TYPES} + default_value = 'N' + elif chain_type == mmcif_names.DNA_CHAIN: + mapping = residue_names.DNA_COMMON_ONE_TO_TWO + default_value = 'N' + elif chain_type == mmcif_names.DNA_RNA_HYBRID_CHAIN: + mapping = {r: r for r in residue_names.NUCLEIC_TYPES_WITH_UNKNOWN} + default_value = 'N' + else: + raise ValueError( + f'Expected a protein/DNA/RNA chain but got {chain_type}') + + return string_array.remap( + one_letter_codes, mapping=mapping, default_value=default_value + ) + + +def _get_change_indices(arr: np.ndarray) -> np.ndarray: + if arr.size == 0: + return np.array([], dtype=np.int32) + else: + changing_idxs = np.where(arr[1:] != arr[:-1])[0] + 1 + return np.concatenate(([0], changing_idxs), axis=0) + + +def _unpack_filter_predicates( + predicate_by_field_name: Mapping[str, table.FilterPredicate], +) -> tuple[ + Mapping[str, table.FilterPredicate], + Mapping[str, table.FilterPredicate], + Mapping[str, table.FilterPredicate], +]: + """Unpacks filter kwargs into predicates for each table.""" + chain_predicates = {} + res_predicates = {} + atom_predicates = {} + for k, pred in predicate_by_field_name.items(): + if col := CHAIN_FIELDS.get(k): + chain_predicates[col] = pred + elif col := RESIDUE_FIELDS.get(k): + res_predicates[col] = pred + elif col := ATOM_FIELDS.get(k): + atom_predicates[col] = pred + else: + raise ValueError(k) + return chain_predicates, res_predicates, atom_predicates + + +_T = TypeVar('_T') + + +SCALAR_FIELDS: Final[Collection[str]] = frozenset({ + 'name', + 'release_date', + 'resolution', + 'structure_method', + 'bioassembly_data', + 'chemical_components_data', +}) + + +TABLE_FIELDS: Final[Collection[str]] = frozenset( + {'chains', 'residues', 'atoms', 'bonds'} +) + + +V2_FIELDS: Final[Collection[str]] = frozenset({*SCALAR_FIELDS, *TABLE_FIELDS}) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class StructureTables: + chains: structure_tables.Chains + residues: structure_tables.Residues + atoms: structure_tables.Atoms + bonds: structure_tables.Bonds + + +class Structure(table.Database): + """Structure class for representing and processing molecular structures.""" + + tables: ClassVar[Collection[str]] = TABLE_FIELDS + + foreign_keys: ClassVar[Mapping[str, Collection[tuple[str, str]]]] = { + 'residues': (('chain_key', 'chains'),), + 'atoms': (('chain_key', 'chains'), ('res_key', 'residues')), + 'bonds': (('from_atom_key', 'atoms'), ('dest_atom_key', 'atoms')), + } + + def __init__( + self, + *, + name: str = 'unset', + release_date: datetime.date | None = None, + resolution: float | None = None, + structure_method: str | None = None, + bioassembly_data: bioassemblies.BioassemblyData | None = None, + chemical_components_data: ( + struct_chem_comps.ChemicalComponentsData | None + ) = None, + chains: 
structure_tables.Chains, + residues: structure_tables.Residues, + atoms: structure_tables.Atoms, + bonds: structure_tables.Bonds, + skip_validation: bool = False, + ): + # Version number is written to mmCIF and should be incremented when changes + # are made to mmCIF writing or internals that affect this. + # b/345221494 Rename this variable when structure_v1 compatibility code + # is removed. + self._VERSION = '2.0.0' # pylint: disable=invalid-name + self._name = name + self._release_date = release_date + self._resolution = resolution + self._structure_method = structure_method + self._bioassembly_data = bioassembly_data + self._chemical_components_data = chemical_components_data + + self._chains = chains + self._residues = residues + self._atoms = atoms + self._bonds = bonds + + if not skip_validation: + self._validate_table_foreign_keys() + self._validate_consistent_table_ordering() + + def _validate_table_foreign_keys(self): + """Validates that all foreign keys are present in the referred tables.""" + residue_keys = set(self._residues.key) + chain_keys = set(self._chains.key) + if np.any(membership.isin(self._atoms.res_key, residue_keys, invert=True)): + raise ValueError( + 'Atom residue keys not in the residues table: ' + f'{set(self._atoms.res_key).difference(self._residues.key)}' + ) + if np.any(membership.isin(self._atoms.chain_key, chain_keys, invert=True)): + raise ValueError( + 'Atom chain keys not in the chains table: ' + f'{set(self._atoms.chain_key).difference(self._chains.key)}' + ) + if np.any( + membership.isin(self._residues.chain_key, chain_keys, invert=True) + ): + raise ValueError( + 'Residue chain keys not in the chains table: ' + f'{set(self._residues.chain_key).difference(self._chains.key)}' + ) + + def _validate_consistent_table_ordering(self): + """Validates that all tables have the same ordering.""" + atom_chain_keys = self._atoms.chain_key[self.chain_boundaries] + atom_res_keys = self._atoms.res_key[self.res_boundaries] + + if not np.array_equal(self.present_chains.key, atom_chain_keys): + raise ValueError( + f'Atom table chain order\n{atom_chain_keys}\ndoes not match the ' + f'chain table order\n{self._chains.key}' + ) + if not np.array_equal(self.present_residues.key, atom_res_keys): + raise ValueError( + f'Atom table residue order\n{atom_res_keys}\ndoes not match the ' + f'present residue table order\n{self.present_residues.key}' + ) + + def get_table(self, table_name: str) -> table.Table: + match table_name: + case 'chains': + return self.chains_table + case 'residues': + return self.residues_table + case 'atoms': + return self.atoms_table + case 'bonds': + return self.bonds_table + case _: + raise ValueError(table_name) + + @property + def chains_table(self) -> structure_tables.Chains: + """Chains table.""" + return self._chains + + @property + def residues_table(self) -> structure_tables.Residues: + """Residues table.""" + return self._residues + + @property + def atoms_table(self) -> structure_tables.Atoms: + """Atoms table.""" + return self._atoms + + @property + def bonds_table(self) -> structure_tables.Bonds: + """Bonds table.""" + return self._bonds + + @property + def name(self) -> str: + return self._name + + @property + def release_date(self) -> datetime.date | None: + return self._release_date + + @property + def resolution(self) -> float | None: + return self._resolution + + @property + def structure_method(self) -> str | None: + return self._structure_method + + @property + def bioassembly_data(self) -> bioassemblies.BioassemblyData | None: + 
return self._bioassembly_data + + @property + def chemical_components_data( + self, + ) -> struct_chem_comps.ChemicalComponentsData | None: + return self._chemical_components_data + + @property + def bonds(self) -> structure_tables.Bonds: + return self._bonds + + @functools.cached_property + def author_naming_scheme(self) -> AuthorNamingScheme: + auth_asym_id = {} + entity_id = {} + entity_desc = {} + auth_seq_id = collections.defaultdict(dict) + insertion_code = collections.defaultdict(dict) + + for chain_i in range(self._chains.size): + chain_id = self._chains.id[chain_i] + auth_asym_id[chain_id] = self._chains.auth_asym_id[chain_i] + chain_entity_id = self._chains.entity_id[chain_i] + entity_id[chain_id] = chain_entity_id + entity_desc[chain_entity_id] = self._chains.entity_desc[chain_i] + + chain_index_by_key = self._chains.index_by_key + for res_i in range(self._residues.size): + chain_key = self._residues.chain_key[res_i] + chain_id = self._chains.id[chain_index_by_key[chain_key]] + res_id = self._residues.id[res_i] + res_auth_seq_id = self._residues.auth_seq_id[res_i] + if res_auth_seq_id == MISSING_AUTH_SEQ_ID: + continue + auth_seq_id[chain_id][res_id] = res_auth_seq_id + ins_code = self._residues.insertion_code[res_i] + # Compatibility with Structure v1 which used None to represent . or ?. + insertion_code[chain_id][res_id] = ( + ins_code if ins_code not in {'.', '?'} else None + ) + + return AuthorNamingScheme( + auth_asym_id=auth_asym_id, + entity_id=entity_id, + entity_desc=entity_desc, + auth_seq_id=dict(auth_seq_id), + insertion_code=dict(insertion_code), + ) + + @functools.cached_property + def all_residues(self) -> AllResidues: + chain_id_by_key = dict(zip(self._chains.key, self._chains.id)) + residue_chain_boundaries = _get_change_indices( + self._residues.chain_key) + boundaries = self._iter_residue_ranges( + residue_chain_boundaries, count_unresolved=True + ) + return { + chain_id_by_key[self._residues.chain_key[start]]: list( + zip(self._residues.name[start:end], + self._residues.id[start:end]) + ) + for start, end in boundaries + } + + @functools.cached_property + def label_asym_id_to_entity_id(self) -> Mapping[str, str]: + return dict(zip(self._chains.id, self._chains.entity_id)) + + @functools.cached_property + def chain_entity_id(self) -> np.ndarray: + """Returns the entity ID for each atom in the structure.""" + return self.chains_table.apply_array_to_column( + 'entity_id', self._atoms.chain_key + ) + + @functools.cached_property + def chain_entity_desc(self) -> np.ndarray: + """Returns the entity description for each atom in the structure.""" + return self.chains_table.apply_array_to_column( + 'entity_desc', self._atoms.chain_key + ) + + @functools.cached_property + def chain_auth_asym_id(self) -> np.ndarray: + """Returns the chain auth asym ID for each atom in the structure.""" + return self.chains_table.apply_array_to_column( + 'auth_asym_id', self._atoms.chain_key + ) + + @functools.cached_property + def chain_id(self) -> np.ndarray: + chain_index_by_key = self._chains.index_by_key + return self._chains.id[chain_index_by_key[self._atoms.chain_key]] + + @functools.cached_property + def chain_type(self) -> np.ndarray: + chain_index_by_key = self._chains.index_by_key + return self._chains.type[chain_index_by_key[self._atoms.chain_key]] + + @functools.cached_property + def res_id(self) -> np.ndarray: + return self._residues['id', self._atoms.res_key] + + @functools.cached_property + def res_name(self) -> np.ndarray: + return self._residues['name', 
self._atoms.res_key] + + @functools.cached_property + def res_auth_seq_id(self) -> np.ndarray: + """Returns the residue auth seq ID for each atom in the structure.""" + return self.residues_table.apply_array_to_column( + 'auth_seq_id', self._atoms.res_key + ) + + @functools.cached_property + def res_insertion_code(self) -> np.ndarray: + """Returns the residue insertion code for each atom in the structure.""" + return self.residues_table.apply_array_to_column( + 'insertion_code', self._atoms.res_key + ) + + @property + def atom_key(self) -> np.ndarray: + return self._atoms.key + + @property + def atom_name(self) -> np.ndarray: + return self._atoms.name + + @property + def atom_element(self) -> np.ndarray: + return self._atoms.element + + @property + def atom_x(self) -> np.ndarray: + return self._atoms.x + + @property + def atom_y(self) -> np.ndarray: + return self._atoms.y + + @property + def atom_z(self) -> np.ndarray: + return self._atoms.z + + @property + def atom_b_factor(self) -> np.ndarray: + return self._atoms.b_factor + + @property + def atom_occupancy(self) -> np.ndarray: + return self._atoms.occupancy + + @functools.cached_property + def chain_boundaries(self) -> np.ndarray: + """The indices in the atom fields where each chain begins.""" + return _get_change_indices(self._atoms.chain_key) + + @functools.cached_property + def res_boundaries(self) -> np.ndarray: + """The indices in the atom fields where each residue begins.""" + return _get_change_indices(self._atoms.res_key) + + @functools.cached_property + def present_chains(self) -> structure_tables.Chains: + """Returns table of chains which have at least 1 resolved atom.""" + is_present_mask = np.isin(self._chains.key, self._atoms.chain_key) + return typing.cast(structure_tables.Chains, self._chains[is_present_mask]) + + @functools.cached_property + def present_residues(self) -> structure_tables.Residues: + """Returns table of residues which have at least 1 resolved atom.""" + is_present_mask = np.isin(self._residues.key, self._atoms.res_key) + return typing.cast( + structure_tables.Residues, self._residues[is_present_mask] + ) + + @functools.cached_property + def unresolved_residues(self) -> structure_tables.Residues: + """Returns table of residues which have at least 1 resolved atom.""" + is_unresolved_mask = np.isin( + self._residues.key, self._atoms.res_key, invert=True + ) + return typing.cast( + structure_tables.Residues, self._residues[is_unresolved_mask] + ) + + def __getitem__(self, field: str) -> Any: + """Gets raw field data using field name as a string.""" + if field in TABLE_FIELDS: + return self.get_table(field) + else: + return getattr(self, field) + + def __getstate__(self) -> dict[str, Any]: + """Pickle calls this on dump. + + Returns: + Members with cached properties removed. + """ + cached_props = { + k + for k, v in self.__class__.__dict__.items() + if isinstance(v, functools.cached_property) + } + return {k: v for k, v in self.__dict__.items() if k not in cached_props} + + def __repr__(self): + return ( + f'Structure({self._name}: {self.num_chains} chains, ' + f'{self.num_residues(count_unresolved=False)} residues, ' + f'{self.num_atoms} atoms)' + ) + + @property + def num_atoms(self) -> int: + return self._atoms.size + + def num_residues(self, *, count_unresolved: bool) -> int: + """Returns the number of residues in this Structure. + + Args: + count_unresolved: Whether to include unresolved (empty) residues. + + Returns: + Number of residues in the Structure. 
+ """ + if count_unresolved: + return self._residues.size + else: + return self.present_residues.size + + @property + def num_chains(self) -> int: + return self._chains.size + + @property + def num_models(self) -> int: + """The number of models of this Structure.""" + return self._atoms.num_models + + def _atom_mask(self, entities: Set[str]) -> np.ndarray: + """Boolean label indicating if each atom is from entities or not.""" + mask = np.zeros(self.num_atoms, dtype=bool) + chain_index_by_key = self._chains.index_by_key + for start, end in self.iter_chain_ranges(): + chain_index = chain_index_by_key[self._atoms.chain_key[start]] + chain_type = self._chains.type[chain_index] + mask[start:end] = chain_type in entities + return mask + + @functools.cached_property + def is_protein_mask(self) -> np.ndarray: + """Boolean label indicating if each atom is from protein or not.""" + return self._atom_mask(entities={mmcif_names.PROTEIN_CHAIN}) + + @functools.cached_property + def is_dna_mask(self) -> np.ndarray: + """Boolean label indicating if each atom is from DNA or not.""" + return self._atom_mask(entities={mmcif_names.DNA_CHAIN}) + + @functools.cached_property + def is_rna_mask(self) -> np.ndarray: + """Boolean label indicating if each atom is from RNA or not.""" + return self._atom_mask(entities={mmcif_names.RNA_CHAIN}) + + @functools.cached_property + def is_nucleic_mask(self) -> np.ndarray: + """Boolean label indicating if each atom is a nucleic acid or not.""" + return self._atom_mask(entities=mmcif_names.NUCLEIC_ACID_CHAIN_TYPES) + + @functools.cached_property + def is_ligand_mask(self) -> np.ndarray: + """Boolean label indicating if each atom is a ligand or not.""" + return self._atom_mask(entities=mmcif_names.LIGAND_CHAIN_TYPES) + + @functools.cached_property + def is_water_mask(self) -> np.ndarray: + """Boolean label indicating if each atom is from water or not.""" + return self._atom_mask(entities={mmcif_names.WATER}) + + def iter_atoms(self) -> Iterator[Mapping[str, Any]]: + """Iterates over the atoms in the structure.""" + if self._atoms.size == 0: + return + + current_chain = self._chains.get_row_by_key( + column_name_map=CHAIN_FIELDS, key=self._atoms.chain_key[0] + ) + current_chain_key = self._atoms.chain_key[0] + current_res = self._residues.get_row_by_key( + column_name_map=RESIDUE_FIELDS, key=self._atoms.res_key[0] + ) + current_res_key = self._atoms.res_key[0] + for atom_i in range(self._atoms.size): + atom_chain_key = self._atoms.chain_key[atom_i] + atom_res_key = self._atoms.res_key[atom_i] + + if atom_chain_key != current_chain_key: + chain_index = self._chains.index_by_key[atom_chain_key] + current_chain = { + 'chain_id': self._chains.id[chain_index], + 'chain_type': self._chains.type[chain_index], + 'chain_auth_asym_id': self._chains.auth_asym_id[chain_index], + 'chain_entity_id': self._chains.entity_id[chain_index], + 'chain_entity_desc': self._chains.entity_desc[chain_index], + } + current_chain_key = atom_chain_key + if atom_res_key != current_res_key: + res_index = self._residues.index_by_key[atom_res_key] + current_res = { + 'res_id': self._residues.id[res_index], + 'res_name': self._residues.name[res_index], + 'res_auth_seq_id': self._residues.auth_seq_id[res_index], + 'res_insertion_code': self._residues.insertion_code[res_index], + } + current_res_key = atom_res_key + + yield { + 'atom_name': self._atoms.name[atom_i], + 'atom_element': self._atoms.element[atom_i], + 'atom_x': self._atoms.x[..., atom_i], + 'atom_y': self._atoms.y[..., atom_i], + 'atom_z': 
self._atoms.z[..., atom_i], + 'atom_b_factor': self._atoms.b_factor[..., atom_i], + 'atom_occupancy': self._atoms.occupancy[..., atom_i], + 'atom_key': self._atoms.key[atom_i], + **current_res, + **current_chain, + } + + def iter_residues( + self, + include_unresolved: bool = False, + ) -> Iterator[Mapping[str, Any]]: + """Iterates over the residues in the structure.""" + res_table = self._residues if include_unresolved else self.present_residues + if res_table.size == 0: + return + + current_chain = self._chains.get_row_by_key( + column_name_map=CHAIN_FIELDS, key=res_table.chain_key[0] + ) + current_chain_key = res_table.chain_key[0] + for res_i in range(res_table.size): + res_chain_key = res_table.chain_key[res_i] + + if res_chain_key != current_chain_key: + current_chain = self._chains.get_row_by_key( + column_name_map=CHAIN_FIELDS, key=res_table.chain_key[res_i] + ) + current_chain_key = res_chain_key + + row = { + 'res_id': res_table.id[res_i], + 'res_name': res_table.name[res_i], + 'res_auth_seq_id': res_table.auth_seq_id[res_i], + 'res_insertion_code': res_table.insertion_code[res_i], + } + yield row | current_chain + + def _iter_atom_ranges( + self, boundaries: Sequence[int] + ) -> Iterator[tuple[int, int]]: + """Iterator for (start, end) pairs from an array of start indices.""" + yield from itertools.pairwise(boundaries) + # Use explicit length test as boundaries can be a NumPy array. + if len(boundaries) > 0: # pylint: disable=g-explicit-length-test + yield boundaries[-1], self.num_atoms + + def _iter_residue_ranges( + self, + boundaries: Sequence[int], + *, + count_unresolved: bool, + ) -> Iterator[tuple[int, int]]: + """Iterator for (start, end) pairs from an array of start indices.""" + yield from itertools.pairwise(boundaries) + # Use explicit length test as boundaries can be a NumPy array. + if len(boundaries) > 0: # pylint: disable=g-explicit-length-test + yield boundaries[-1], self.num_residues(count_unresolved=count_unresolved) + + def iter_chain_ranges(self) -> Iterator[tuple[int, int]]: + """Iterates pairs of (chain_start, chain_end) indices. + + Yields: + Pairs of (start, end) indices for each chain, where end is not inclusive. + i.e. struct.chain_id[start:end] would be a constant array with length + equal to the number of atoms in the chain. + """ + yield from self._iter_atom_ranges(self.chain_boundaries) + + def iter_residue_ranges(self) -> Iterator[tuple[int, int]]: + """Iterates pairs of (residue_start, residue_end) indices. + + Yields: + Pairs of (start, end) indices for each residue, where end is not + inclusive. i.e. struct.res_id[start:end] would be a constant array with + length equal to the number of atoms in the residue. + """ + yield from self._iter_atom_ranges(self.res_boundaries) + + def iter_chains(self) -> Iterator[Mapping[str, Any]]: + """Iterates over the chains in the structure.""" + for chain_i in range(self.present_chains.size): + yield { + 'chain_id': self.present_chains.id[chain_i], + 'chain_type': self.present_chains.type[chain_i], + 'chain_auth_asym_id': self.present_chains.auth_asym_id[chain_i], + 'chain_entity_id': self.present_chains.entity_id[chain_i], + 'chain_entity_desc': self.present_chains.entity_desc[chain_i], + } + + def iter_bonds(self) -> Iterator[Bond]: + """Iterates over the atoms and bond information. + + Example usage: + + ``` + for from_atom, dest_atom, bond_info in struct.iter_bonds(): + print( + f'From atom: name={from_atom["atom_name"]}, ' + f'chain={from_atom["chain_id"]}, ...' 
+ ) + # Same for dest_atom + print(f'Bond info: type={bond_info["type"]}, role={bond_info["role"]}') + ``` + + Yields: + A `Bond` NamedTuple for each bond in the bonds table. + These have fields `from_atom`, `dest_atom`, `bond_info` where each + is a dictionary. The first two have the same keys as the atom dicts + returned by self.iter_atoms() -- i.e. one key per non-None field. + The final dict has the same keys as self.bonds.iterrows() -- i.e. one + key per column in the bonds table. + """ + from_atom_iter = self._atoms.iterrows( + row_keys=self._bonds.from_atom_key, + column_name_map=ATOM_FIELDS, + chain_key=self._chains.with_column_names(CHAIN_FIELDS), + res_key=self._residues.with_column_names(RESIDUE_FIELDS), + ) + dest_atom_iter = self._atoms.iterrows( + row_keys=self._bonds.dest_atom_key, + column_name_map=ATOM_FIELDS, + chain_key=self._chains.with_column_names(CHAIN_FIELDS), + res_key=self._residues.with_column_names(RESIDUE_FIELDS), + ) + + for from_atom, dest_atom, bond_info in zip( + from_atom_iter, dest_atom_iter, self._bonds.iterrows(), strict=True + ): + yield Bond(from_atom=from_atom, dest_atom=dest_atom, bond_info=bond_info) + + def _apply_atom_index_array( + self, + index_arr: np.ndarray, + chain_boundaries: np.ndarray | None = None, + res_boundaries: np.ndarray | None = None, + skip_validation: bool = False, + ) -> Self: + """Applies index_arr to the atom table using NumPy-style array indexing. + + Args: + index_arr: A 1D NumPy array that will be used to index into the atoms + table. This can either be a boolean array to act as a mask, or an + integer array to perform a gather operation. + chain_boundaries: Unused in structure v2. + res_boundaries: Unused in structure v2. + skip_validation: Whether to skip the validation step that checks internal + consistency after applying atom index array. Do not set to True unless + you are certain the transform is safe, e.g. when the order of atoms is + guaranteed to not change. + + Returns: + A new Structure with an updated atoms table. + """ + del chain_boundaries, res_boundaries + + if index_arr.ndim != 1: + raise ValueError( + f'index_arr must be a 1D NumPy array, but has shape {index_arr.shape}' + ) + + if index_arr.dtype == bool and np.all(index_arr): + # Shortcut: The operation is a no-op, so just return itself. + return self + + atoms = structure_tables.Atoms( + **{col: self._atoms[col][..., index_arr] for col in self._atoms.columns} + ) + updated_tables = self._cascade_delete(atoms=atoms) + return self.copy_and_update( + atoms=updated_tables.atoms, + bonds=updated_tables.bonds, + skip_validation=skip_validation, + ) + + @property + def group_by_residue(self) -> Self: + """Returns a Structure with one atom per residue. + + e.g. restypes = struct.group_by_residue['res_id'] + + Returns: + A new Structure with one atom per residue such that per-atom arrays + such as res_name (i.e. Structure v1 fields) have one element per residue. + """ + # This use of _apply_atom_index_array is safe because the chain/residue/atom + # ordering won't change (essentially applying a residue start mask). + return self._apply_atom_index_array( + self.res_boundaries, skip_validation=True + ) + + @property + def group_by_chain(self) -> Self: + """Returns a Structure where all fields are per-chain. + + e.g. chains = struct.group_by_chain['chain_id'] + + Returns: + A new Structure with one atom per chain such that per-atom arrays + such as res_name (i.e. Structure v1 fields) have one element per chain. 
+ """ + # This use of _apply_atom_index_array is safe because the chain/residue/atom + # ordering won't change (essentially applying a chain start mask). + return self._apply_atom_index_array( + self.chain_boundaries, skip_validation=True + ) + + @property + def with_sorted_chains(self) -> Self: + """Returns a new structure with the chains are in reverse spreadsheet style. + + This is the usual order to write chains in an mmCIF: + (A < B < ... < AA < BA < CA < ... < AB < BB < CB ...) + + NB: this method will fail if chains do not conform to this mmCIF naming + convention. + + Only to be used for third party metrics that rely on the chain order. + Elsewhere chains should be identified by name and code should be agnostic to + the order. + """ + sorted_chains = sorted(self.chains, key=mmcif.str_id_to_int_id) + return self.reorder_chains(new_order=sorted_chains) + + @functools.cached_property + def atom_ids(self) -> Sequence[tuple[str, str, None, str]]: + """Gets a list of atom ID tuples from Structure class arrays. + + Returns: + A list of tuples of (chain_id, res_id, insertion_code, atom_name) where + insertion code is always None. There is one element per atom, and the + list is ordered according to the order of atoms in the input arrays. + """ + # Convert to Numpy strings, then to Python strings (dtype=object). + res_ids = self.residues_table.id.astype(str).astype(object) + res_ids = res_ids[ + self.residues_table.index_by_key[self.atoms_table.res_key] + ] + ins_codes = [None] * self.num_atoms + return list( + zip(self.chain_id, res_ids, ins_codes, self.atom_name, strict=True) + ) + + def order_and_drop_atoms_to_match( + self, + other: 'Structure', + *, + allow_missing_atoms: bool = False, + ) -> Self: + """Returns a new structure with atoms ordered & dropped to match another's. + + This performs two operations simultaneously: + * Ordering the atoms in this structure to match the order in the other. + * Dropping atoms in this structure that do not appear in the other. + + Example: + Consider a prediction and ground truth with the following atoms, described + using tuples of `(chain_id, res_id, atom_name)`: + * `prediction: [(A, 1, CA), (A, 1, N), (A, 2, CA), (B, 1, CA)]` + * `ground_truth: [(B, 1, CA), (A, 1, N), (A, 1, CA)]` + Note how the ground truth is missing the `(A, 2, CA)` atom and also + has the atoms in a different order. This method returns a modified + prediction that has reordered atoms and without any atoms not in the ground + truth so that its atom list looks the same as the ground truth atom list. + This means `prediction.coords` and `ground_truth.coords` now have the + same shape and can be compared across the atom dimension. + + Note that matching residues with no atoms and matching chains with no + residues will also be kept. E.g. in the example above, if prediction and + ground truth both had an unresolved residue (A, 3), the output structure + will also have an unresolved residue (A, 3). + + Args: + other: Another `Structure`. This provides the reference ordering that is + used to sort this structure's atom arrays. + allow_missing_atoms: Whether to skip atoms present in `other` but not this + structure and return a structure containing a subset of the atoms in the + other structure. + + Returns: + A new `Structure`, based on this structure, which, if + `allow_missing_atoms` is False, contains exactly the same atoms as in + the `other` structure and which matches the `other` structure in terms + of the order of the atoms in the field arrays. 
Otherwise, if missing + atoms are allowed then the resulting structure contains a subset of + those atoms in the other structure. + + Raises: + MissingAtomError: If there are atoms present in the other structure that + cannot be found in this structure. + """ + atom_index_map = {atom_id: i for i, + atom_id in enumerate(self.atom_ids)} + try: + if allow_missing_atoms: + # Only include atoms that were found in the other structure. + atom_indices = [ + atom_index + for atom_id in other.atom_ids + if (atom_index := atom_index_map.get(atom_id)) is not None + ] + else: + atom_indices = [ + atom_index_map[atom_id] # Hard fail on missing. + for atom_id in other.atom_ids + ] + except KeyError as e: + if len(e.args[0]) == 4: + chain_id, res_id, ins_code, atom_name = e.args[0] + raise MissingAtomError( + f'No atom in this structure (name: {self._name}) matches atom in ' + f'other structure (name: {other.name}) with internal (label) chain ' + f'ID {chain_id}, residue ID {res_id}, insertion code {ins_code} ' + f'and atom name {atom_name}.' + ) from e + else: + raise + + def _iter_residues(struct: Self) -> Iterable[tuple[str, str]]: + yield from zip( + struct.chains_table['id', struct.residues_table.chain_key], + struct.residues_table.id, + strict=True, + ) + + chain_index_map = { + chain_id: i for i, chain_id in enumerate(self._chains.id) + } + chain_indices = [ + chain_index + for chain_id in other.chains_table.id + if (chain_index := chain_index_map.get(chain_id)) is not None + ] + residue_index_map = { + res_id: i for i, res_id in enumerate(_iter_residues(self)) + } + res_indices = [ + residue_index + for res_id in _iter_residues(other) + if (residue_index := residue_index_map.get(res_id)) is not None + ] + + # Reorder all tables. + chains = self._chains.apply_index( + np.array(chain_indices, dtype=np.int64)) + residues = self._residues.apply_index( + np.array(res_indices, dtype=np.int64)) + atoms = self._atoms.apply_index(np.array(atom_indices, dtype=np.int64)) + + # Get chain keys in the order they appear in the atoms table. + new_chain_boundaries = _get_change_indices(atoms.chain_key) + new_chain_key_order = atoms.chain_key[new_chain_boundaries] + if len(new_chain_key_order) != len(set(new_chain_key_order)): + raise ValueError( + f'Chain keys not contiguous after reordering: {new_chain_key_order}' + ) + + # Get residue keys in the order they appear in the atoms table. + new_res_boundaries = _get_change_indices(atoms.res_key) + new_res_key_order = atoms.res_key[new_res_boundaries] + if len(new_res_key_order) != len(set(new_res_key_order)): + raise ValueError( + f'Residue keys not contiguous after reordering: {new_res_key_order}' + ) + + # If any atoms were deleted, propagate that into the bonds table. 
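+    # _cascade_delete drops bond rows whose endpoint atoms are no longer in the
+    # atoms table, so the bonds table stays consistent with the reordered (and
+    # possibly reduced) chains/residues/atoms tables built above.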
+ updated_tables = self._cascade_delete( + chains=chains, + residues=residues, + atoms=atoms, + ) + return self.copy_and_update( + chains=chains, + residues=residues, + atoms=updated_tables.atoms, + bonds=updated_tables.bonds, + ) + + def copy_and_update( + self, + *, + name: str | Literal[_UNSET] = _UNSET, + release_date: datetime.date | None | Literal[_UNSET] = _UNSET, + resolution: float | None | Literal[_UNSET] = _UNSET, + structure_method: str | None | Literal[_UNSET] = _UNSET, + bioassembly_data: ( + bioassemblies.BioassemblyData | None | Literal[_UNSET] + ) = _UNSET, + chemical_components_data: ( + struct_chem_comps.ChemicalComponentsData | None | Literal[_UNSET] + ) = _UNSET, + chains: structure_tables.Chains | None | Literal[_UNSET] = _UNSET, + residues: structure_tables.Residues | None | Literal[_UNSET] = _UNSET, + atoms: structure_tables.Atoms | None | Literal[_UNSET] = _UNSET, + bonds: structure_tables.Bonds | None | Literal[_UNSET] = _UNSET, + skip_validation: bool = False, + ) -> Self: + """Performs a shallow copy but with specified fields updated.""" + + def all_unset(fields): + return all(field == _UNSET for field in fields) + + if all_unset((chains, residues, atoms, bonds)): + if all_unset(( + name, + release_date, + resolution, + structure_method, + bioassembly_data, + chemical_components_data, + )): + raise ValueError( + 'Unnecessary call to copy_and_update with no changes. As Structure' + ' and its component tables are immutable, there is no need to copy' + ' it. Any subsequent operation that modifies structure will return' + ' a new object.' + ) + else: + raise ValueError( + 'When only changing global fields, prefer to use the specialised ' + 'copy_and_update_globals.' + ) + + def select(field, default): + return field if field != _UNSET else default + + return Structure( + name=select(name, self.name), + release_date=select(release_date, self.release_date), + resolution=select(resolution, self.resolution), + structure_method=select(structure_method, self.structure_method), + bioassembly_data=select(bioassembly_data, self.bioassembly_data), + chemical_components_data=select( + chemical_components_data, self.chemical_components_data + ), + chains=select(chains, self._chains), + residues=select(residues, self._residues), + atoms=select(atoms, self._atoms), + bonds=select(bonds, self._bonds), + skip_validation=skip_validation, + ) + + def _copy_and_update( + self, skip_validation: bool = False, **changes: Any + ) -> Self: + """Performs a shallow copy but with specified fields updated.""" + if not changes: + raise ValueError( + 'Unnecessary call to copy_and_update with no changes. As Structure ' + 'and its component tables are immutable, there is no need to copy ' + 'it. Any subsequent operation that modifies structure will return a ' + 'new object.' + ) + + if 'author_naming_scheme' in changes: + raise ValueError( + 'Updating using author_naming_scheme is not supported. Update ' + 'auth_asym_id, entity_id, entity_desc fields directly in the chains ' + 'table and auth_seq_id, insertion_code in the residues table.' + ) + + if all(k in GLOBAL_FIELDS for k in changes): + raise ValueError( + 'When only changing global fields, prefer to use the specialised ' + 'copy_and_update_globals.' 
+ ) + + if all(k in V2_FIELDS for k in changes): + constructor_kwargs = {field: self[field] for field in V2_FIELDS} + constructor_kwargs.update(changes) + elif any(k in ('atoms', 'residues', 'chains') for k in changes): + raise ValueError( + 'Cannot specify atoms/chains/residues table changes with non-v2' + f' constructor params: {changes.keys()}' + ) + elif all(k in ATOM_FIELDS for k in changes): + if 'atom_key' not in changes: + raise ValueError( + 'When only changing atom fields, prefer to use the specialised ' + 'copy_and_update_atoms.' + ) + # Only atom fields are being updated, do that directly on the atoms table. + updated_atoms = self._atoms.copy_and_update( + **{ATOM_FIELDS[k]: v for k, v in changes.items()} + ) + constructor_kwargs = { + field: self[field] for field in V2_FIELDS if field != 'atoms' + } + constructor_kwargs['atoms'] = updated_atoms + else: + constructor_kwargs = {field: self[field] + for field in _UPDATEABLE_FIELDS} + constructor_kwargs.update(changes) + return Structure(skip_validation=skip_validation, **constructor_kwargs) + + def copy_and_update_coords(self, coords: np.ndarray) -> Self: + """Performs a shallow copy but with coordinates updated.""" + if coords.shape[-2:] != (self.num_atoms, 3): + raise ValueError( + f'{coords.shape=} does not have last dimensions ({self.num_atoms}, 3)' + ) + updated_atoms = self._atoms.copy_and_update_coords(coords) + return self.copy_and_update(atoms=updated_atoms, skip_validation=True) + + def copy_and_update_from_res_arrays(self, **changes: np.ndarray) -> Self: + """Like copy_and_update but changes are arrays of length num_residues. + + These changes are first scattered into arrays of length num_atoms such + that each value is repeated across the residue at that index, then they + are used as the new values of these fields. + + E.g. + * This structure's res_id: 1, 1, 1, 2, 3, 3 (3 res, 6 atoms) + * new atom_b_factor: 7, 8, 9 + * Returned structure's atom_b_factor: 7, 7, 7, 8, 9, 9 + + Args: + **changes: kwargs corresponding to atom array fields, e.g. atom_x or + atom_b_factor, but with length num_residues rather than num_atoms. Note + that changing atom_key this way is is not supported. + + Returns: + A new `Structure` with all fields other than those specified as kwargs + shallow copied from this structure. The values of the kwargs are + scattered across the atom arrays and then used to overwrite these + fields for the returned structure. + """ + # We create scatter indices by (1) starting from zeros, then (2) setting + # the position where each residue starts to 1 and then (3) doing a + # cumulative sum. Finally, since self.res_boundaries always starts with 0 + # the result of the cumulative sum will start from 1, so (4) we subtract + # 1 to get the final array of zero-based indices. + # Example, 6 atoms, 3 residues at indices 0, 2 and 5. 
+ # (1) 0 0 0 0 0 0 + # (2) 1 0 1 0 0 1 + # (3) 1 1 2 2 2 3 + # (4) 0 0 1 1 1 2 + if not all(c in set(ATOM_FIELDS) - {'atom_key'} for c in changes): + raise ValueError( + 'Changes must only be to atom fields, got changes to' + f' {changes.keys()}' + ) + scatter_idxs = np.zeros((self.num_atoms,), dtype=int) + scatter_idxs[self.res_boundaries] = 1 + scatter_idxs = scatter_idxs.cumsum() - 1 + atom_array_changes = { + ATOM_FIELDS[field]: new_val[scatter_idxs] + for field, new_val in changes.items() + } + updated_atoms = self._atoms.copy_and_update(**atom_array_changes) + return self.copy_and_update(atoms=updated_atoms, skip_validation=True) + + def copy_and_update_globals( + self, + *, + name: str | Literal[_UNSET] = _UNSET, + release_date: datetime.date | Literal[_UNSET] | None = _UNSET, + resolution: float | Literal[_UNSET] | None = _UNSET, + structure_method: str | Literal[_UNSET] | None = _UNSET, + bioassembly_data: ( + bioassemblies.BioassemblyData | Literal[_UNSET] | None + ) = _UNSET, + chemical_components_data: ( + struct_chem_comps.ChemicalComponentsData | Literal[_UNSET] | None + ) = _UNSET, + ) -> Self: + """Returns a shallow copy with the global columns updated.""" + + def select(field, default): + return field if field != _UNSET else default + + name = select(name, self.name) + release_date = select(release_date, self.release_date) + resolution = select(resolution, self.resolution) + structure_method = select(structure_method, self.structure_method) + bioassembly_data = select(bioassembly_data, self.bioassembly_data) + chem_data = select(chemical_components_data, + self.chemical_components_data) + + return Structure( + name=name, + release_date=release_date, + resolution=resolution, + structure_method=structure_method, + bioassembly_data=bioassembly_data, + chemical_components_data=chem_data, + atoms=self._atoms, + residues=self._residues, + chains=self._chains, + bonds=self._bonds, + ) + + def copy_and_update_atoms( + self, + *, + atom_name: np.ndarray | None = None, + atom_element: np.ndarray | None = None, + atom_x: np.ndarray | None = None, + atom_y: np.ndarray | None = None, + atom_z: np.ndarray | None = None, + atom_b_factor: np.ndarray | None = None, + atom_occupancy: np.ndarray | None = None, + ) -> Self: + """Returns a shallow copy with the atoms table updated.""" + new_atoms = structure_tables.Atoms( + key=self._atoms.key, + res_key=self._atoms.res_key, + chain_key=self._atoms.chain_key, + name=atom_name if atom_name is not None else self.atom_name, + element=atom_element if atom_element is not None else self.atom_element, + x=atom_x if atom_x is not None else self.atom_x, + y=atom_y if atom_y is not None else self.atom_y, + z=atom_z if atom_z is not None else self.atom_z, + b_factor=( + atom_b_factor if atom_b_factor is not None else self.atom_b_factor + ), + occupancy=( + atom_occupancy + if atom_occupancy is not None + else self.atom_occupancy + ), + ) + return self.copy_and_update(atoms=new_atoms) + + def _cascade_delete( + self, + *, + chains: structure_tables.Chains | None = None, + residues: structure_tables.Residues | None = None, + atoms: structure_tables.Atoms | None = None, + bonds: structure_tables.Bonds | None = None, + ) -> StructureTables: + """Performs a cascade delete operation on the structure's tables. + + Cascade delete ensures all the tables are consistent after any table fields + are being updated by cascading any deletions down the hierarchy of tables: + chains > residues > atoms > bonds. 
+ + E.g.: if a row from residues table is removed then all the atoms in that + residue will also be removed from the atoms table. In turn this cascades + also to the bond table, by removing any bond row which involves any of those + removed atoms. However the chains table will not be modified, even if + that was the only residue in its chain, because the chains table is above + the residues table in the hierarchy. + + Args: + chains: An optional new chains table. + residues: An optional new residues table. + atoms: An optional new atoms table. + bonds: An optional new bonds table. + + Returns: + A StructureTables object with the updated tables. + """ + if chains_unchanged := chains is None: + chains = self._chains + if residues_unchanged := residues is None: + residues = self._residues + if atoms_unchanged := atoms is None: + atoms = self._atoms + if bonds is None: + bonds = self._bonds + + if not chains_unchanged: + residues_mask = membership.isin(residues.chain_key, set( + chains.key)) # pylint:disable=attribute-error + if not np.all(residues_mask): # Only apply if this is not a no-op. + residues = residues[residues_mask] + residues_unchanged = False + if not residues_unchanged: + atoms_mask = membership.isin(atoms.res_key, set( + residues.key)) # pylint:disable=attribute-error + if not np.all(atoms_mask): # Only apply if this is not a no-op. + atoms = atoms[atoms_mask] + atoms_unchanged = False + if not atoms_unchanged: + bonds = bonds.restrict_to_atoms(atoms.key) + return StructureTables( + chains=chains, residues=residues, atoms=atoms, bonds=bonds + ) + + def filter( + self, + mask: np.ndarray | None = None, + *, + apply_per_element: bool = False, + invert: bool = False, + cascade_delete: CascadeDelete = CascadeDelete.CHAINS, + **predicate_by_field_name: table.FilterPredicate, + ) -> Self: + """Filters the structure by field values and returns a new structure. + + Predicates are specified as keyword arguments, with names following the + pattern: _, where table_name := (chain|res|atom). + For instance the auth_seq_id column in the residues table can be filtered + by passing `res_auth_seq_id=pred_value`. The full list of valid options + are defined in the `col_by_field_name` fields on the different Table + dataclasses. + + Predicate values can be either: + 1. A constant value, e.g. 'CA'. In this case then only rows that match + this value for the given field are retained. + 2. A (non-string) iterable e.g. ('A', 'B'). In this + case then rows are retained if they match any of the provided values for + the given field. + 3. A boolean function e.g. lambda b_fac: b_fac < 100.0. + In this case then only rows that evaluate to True are retained. By + default this function's parameter is expected to be an array, unless + apply_per_element=True. + + Example usage: + # Filter to backbone atoms in residues up to 100 in chain B. + filtered_struct = struct.filter( + chain_id='B', + atom_name=('N', 'CA', 'C'), + res_id=lambda res_id: res_id < 100) + + Example usage where predicate must be applied per-element: + # Filter to residues with IDs in either [1, 100) or [300, 400). + ranges = ((1, 100), (300, 400)) + filtered_struct = struct.filter( + res_id=lambda i: np.any([start <= i < end for start, end in ranges]), + apply_per_element=True) + + Example usage of providing a raw mask: + filtered_struct = struct.filter(struct.atom_b_factor < 10.0) + + Args: + mask: An optional boolean NumPy array with length equal to num_atoms. 
If + provided then this will be combined with the other predicates so that an + atom is included if it is masked-in *and* matches all the predicates. + apply_per_element: Whether apply predicates to each element individually, + or to pass the whole column array to the predicate. + invert: Whether to remove, rather than retain, the entities which match + the specified predicates. + cascade_delete: Whether to remove residues and chains which are left + unresolved in a cascade. filter operates on the atoms table, removing + atoms which match the predicate. If all atoms in a residue are removed, + the residue is "unresolved". The value of this argument then determines + whether such residues and their parent chains should be deleted. FULL + implies that all unresolved residues should be deleted, and any chains + which are left with no resolved residues should be deleted. CHAINS is + the default behaviour - only chains with no resolved residues, and their + child residues are deleted. Unresolved residues in partially resolved + chains remain. NONE implies that no unresolved residues or chains should + be deleted. + **predicate_by_field_name: A mapping from field name to a predicate. + Filtered columns must be 1D arrays. If multiple fields are provided as + keyword arguments then each predicate is applied and the results are + combined using a boolean AND operation, so an atom is only retained if + it passes all predicates. + + Returns: + A new structure representing a filtered version of the current structure. + + Raises: + ValueError: If mask is provided and is not a bool array with shape + (num_atoms,). + """ + chain_predicates, res_predicates, atom_predicates = ( + _unpack_filter_predicates(predicate_by_field_name) + ) + # Get boolean masks for each table. These are None if none of the filter + # parameters affect the table in question. + chain_mask = self._chains.make_filter_mask( + **chain_predicates, apply_per_element=apply_per_element + ) + res_mask = self._residues.make_filter_mask( + **res_predicates, apply_per_element=apply_per_element + ) + atom_mask = self._atoms.make_filter_mask( + mask, **atom_predicates, apply_per_element=apply_per_element + ) + if atom_mask is None: + atom_mask = np.ones((self._atoms.size,), dtype=bool) + + # Remove atoms that belong to filtered out chains. + if chain_mask is not None: + atom_chain_mask = membership.isin( + self._atoms.chain_key, set(self._chains.key[chain_mask]) + ) + np.logical_and(atom_mask, atom_chain_mask, out=atom_mask) + + # Remove atoms that belong to filtered out residues. + if res_mask is not None: + atom_res_mask = membership.isin( + self._atoms.res_key, set(self._residues.key[res_mask]) + ) + np.logical_and(atom_mask, atom_res_mask, out=atom_mask) + + final_atom_mask = ~atom_mask if invert else atom_mask + + if cascade_delete == CascadeDelete.NONE and np.all(final_atom_mask): + # Shortcut: The filter is a no-op, so just return itself. 
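+      # Structure and its component tables are immutable, so returning self is
+      # equivalent to returning an unchanged copy.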
+ return self + + filtered_atoms = typing.cast( + structure_tables.Atoms, self._atoms[final_atom_mask] + ) + + match cascade_delete: + case CascadeDelete.FULL: + nonempty_residues_mask = np.isin( + self._residues.key, filtered_atoms.res_key + ) + filtered_residues = self._residues[nonempty_residues_mask] + nonempty_chain_mask = np.isin( + self._chains.key, filtered_atoms.chain_key + ) + filtered_chains = self._chains[nonempty_chain_mask] + updated_tables = self._cascade_delete( + chains=filtered_chains, + residues=filtered_residues, + atoms=filtered_atoms, + ) + case CascadeDelete.CHAINS: + # To match v1 behavior we remove chains that have no atoms remaining, + # and we remove residues in those chains. + # NB we do not remove empty residues. + nonempty_chain_mask = membership.isin( + self._chains.key, set(filtered_atoms.chain_key) + ) + filtered_chains = self._chains[nonempty_chain_mask] + updated_tables = self._cascade_delete( + chains=filtered_chains, atoms=filtered_atoms + ) + case CascadeDelete.NONE: + updated_tables = self._cascade_delete(atoms=filtered_atoms) + case _: + raise ValueError( + f'Unknown cascade_delete behaviour: {cascade_delete}') + return self.copy_and_update( + chains=updated_tables.chains, + residues=updated_tables.residues, + atoms=updated_tables.atoms, + bonds=updated_tables.bonds, + skip_validation=True, + ) + + def filter_out(self, *args, **kwargs) -> Self: + """Returns a new structure with the specified elements removed.""" + return self.filter(*args, invert=True, **kwargs) + + def filter_to_entity_type( + self, + *, + protein: bool = False, + rna: bool = False, + dna: bool = False, + dna_rna_hybrid: bool = False, + ligand: bool = False, + water: bool = False, + ) -> Self: + """Filters the structure to only include the selected entity types. + + This convenience method abstracts away the specifics of mmCIF entity + type names which, especially for ligands, are non-trivial. + + Args: + protein: Whether to include protein (polypeptide(L)) chains. + rna: Whether to include RNA chains. + dna: Whether to include DNA chains. + dna_rna_hybrid: Whether to include DNA RNA hybrid chains. + ligand: Whether to include ligand (i.e. not polymer) chains. + water: Whether to include water chains. + + Returns: + The filtered structure. + """ + include_types = [] + if protein: + include_types.append(mmcif_names.PROTEIN_CHAIN) + if rna: + include_types.append(mmcif_names.RNA_CHAIN) + if dna: + include_types.append(mmcif_names.DNA_CHAIN) + if dna_rna_hybrid: + include_types.append(mmcif_names.DNA_RNA_HYBRID_CHAIN) + if ligand: + include_types.extend(mmcif_names.LIGAND_CHAIN_TYPES) + if water: + include_types.append(mmcif_names.WATER) + return self.filter(chain_type=include_types) + + def get_stoichiometry( + self, *, fix_non_standard_polymer_res: bool = False + ) -> Sequence[int]: + """Returns the structure's stoichiometry using chain_res_name_sequence. + + Note that everything is considered (protein, RNA, DNA, ligands) except for + water molecules. If you are interested only in a certain type of entities, + filter them out before calling this method. + + Args: + fix_non_standard_polymer_res: If True, maps non standard residues in + protein / RNA / DNA chains to standard residues (e.g. MSE -> MET) or UNK + / N if a match is not found. + + Returns: + A list of integers, one for each unique chain in the structure, + determining the number of that chain appearing in the structure. The + numbers are sorted highest to lowest. E.g. for an A3B2 protein this method + will return [3, 2]. 
+ """ + filtered = self.filter_to_entity_type( + protein=True, + rna=True, + dna=True, + dna_rna_hybrid=True, + ligand=True, + water=False, + ) + seqs = filtered.chain_res_name_sequence( + include_missing_residues=True, + fix_non_standard_polymer_res=fix_non_standard_polymer_res, + ) + + unique_seq_counts = collections.Counter(seqs.values()) + return sorted(unique_seq_counts.values(), reverse=True) + + def without_hydrogen(self) -> Self: + """Returns the structure without hydrogen atoms.""" + return self.filter( + np.logical_and(self._atoms.element != 'H', + self._atoms.element != 'D') + ) + + def without_terminal_oxygens(self) -> Self: + """Returns the structure without terminal oxygen atoms.""" + terminal_oxygen_filter = np.zeros(self.num_atoms, dtype=bool) + for chain_type, atom_name in mmcif_names.TERMINAL_OXYGENS.items(): + chain_keys = self._chains.key[self._chains.type == chain_type] + chain_atom_filter = np.logical_and( + self._atoms.name == atom_name, + np.isin(self._atoms.chain_key, chain_keys), + ) + np.logical_or( + terminal_oxygen_filter, chain_atom_filter, out=terminal_oxygen_filter + ) + return self.filter_out(terminal_oxygen_filter) + + def reset_author_naming_scheme(self) -> Self: + """Remove author chain/residue ids, entity info and use internal ids.""" + new_chains = structure_tables.Chains( + key=self._chains.key, + id=self._chains.id, + type=self._chains.type, + auth_asym_id=self._chains.id, + entity_id=np.arange(1, self.num_chains + + 1).astype(str).astype(object), + entity_desc=np.full(self.num_chains, '.', dtype=object), + ) + new_residues = structure_tables.Residues( + key=self._residues.key, + chain_key=self._residues.chain_key, + id=self._residues.id, + name=self._residues.name, + auth_seq_id=self._residues.id.astype(str).astype(object), + insertion_code=np.full( + self.num_residues(count_unresolved=True), '?', dtype=object + ), + ) + return self.copy_and_update( + chains=new_chains, residues=new_residues, skip_validation=True + ) + + def filter_residues(self, res_mask: np.ndarray) -> Self: + """Filter resolved residues using a boolean mask.""" + required_shape = (self.num_residues(count_unresolved=False),) + if res_mask.shape != required_shape: + raise ValueError( + f'res_mask must have shape {required_shape}. Got: {res_mask.shape}.' + ) + if res_mask.dtype != bool: + raise ValueError( + f'res_mask must have dtype bool. Got: {res_mask.dtype}.') + + filtered_residues = self.present_residues.filter(res_mask) + atom_mask = np.isin(self._atoms.res_key, filtered_residues.key) + return self.filter(atom_mask) + + def filter_coords( + self, coord_predicate: Callable[[np.ndarray], bool] + ) -> Self: + """Filter a structure's atoms by a function of their coordinates. + + Args: + coord_predicate: A boolean function of coordinate vectors (shape (3,)). + + Returns: + A Structure filtered so that only atoms with coords passing the predicate + function are present. + + Raises: + ValueError: If the coords are not shaped (num_atom, 3). + """ + coords = self.coords + if coords.ndim != 2 or coords.shape[-1] != 3: + raise ValueError( + f'coords should have shape (num_atom, 3). Got {coords.shape}.' + ) + mask = np.vectorize(coord_predicate, signature='(n)->()')(coords) + # This use of _apply_atom_index_array is safe because a boolean mask is + # used, which means the chain/residue/atom ordering will stay unchanged. 
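+    # Illustrative usage (hypothetical values): keep only atoms within 10 Å of
+    # the origin:
+    #   struct.filter_coords(lambda xyz: bool(np.linalg.norm(xyz) < 10.0))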
+ return self._apply_atom_index_array(mask, skip_validation=True) + + def filter_polymers_to_single_atom_per_res( + self, + representative_atom_by_chain_type: Mapping[ + str, str + ] = mmcif_names.RESIDUE_REPRESENTATIVE_ATOMS, + ) -> Self: + """Filter to one representative atom per polymer residue, ligands unchanged. + + Args: + representative_atom_by_chain_type: Chain type str to atom name, only atoms + with this name will be kept for this chain type. Chains types from the + structure not found in this mapping will keep all their atoms. + + Returns: + A Structure filtered so that per chain types, only specified atoms are + present. + """ + polymer_chain_keys = self._chains.key[ + string_array.isin( + self._chains.type, set(representative_atom_by_chain_type) + ) + ] + polymer_atoms_mask = np.isin(self._atoms.chain_key, polymer_chain_keys) + + wanted_atom_by_chain_key = { + chain_key: representative_atom_by_chain_type.get(chain_type, None) + for chain_key, chain_type in zip(self._chains.key, self._chains.type) + } + wanted_atoms = string_array.remap( + self._atoms.chain_key.astype(object), mapping=wanted_atom_by_chain_key + ) + + representative_polymer_atoms_mask = polymer_atoms_mask & ( + wanted_atoms == self._atoms.name + ) + + return self.filter(representative_polymer_atoms_mask | ~polymer_atoms_mask) + + def drop_non_standard_protein_atoms(self, *, drop_oxt: bool = True) -> Self: + """Drops non-standard atom names from protein chains. + + Args: + drop_oxt: If True, also drop terminal oxygens (OXT). + + Returns: + A new Structure object where the protein chains have been filtered to + only contain atoms with names listed in `atom_types` + (including OXT unless `drop_oxt` is `True`). Non-protein chains are + unaltered. + """ + allowed_names = set(atom_types.ATOM37) + if drop_oxt: + allowed_names = {n for n in allowed_names if n != atom_types.OXT} + + return self.filter_out( + chain_type=mmcif_names.PROTEIN_CHAIN, + atom_name=lambda n: string_array.isin( + n, allowed_names, invert=True), + ) + + def drop_non_standard_atoms( + self, + *, + ccd: chemical_components.Ccd, + drop_unk: bool, + drop_non_ccd: bool, + drop_terminal_oxygens: bool = False, + ) -> Self: + """Drops atoms that are not in the CCD for the given residue type.""" + + # We don't remove any atoms in UNL, as it has no standard atoms. 
+ def _keep(atom_index: int) -> bool: + atom_name = self._atoms.name[atom_index] + res_name = self._residues.name[ + self._residues.index_by_key[self._atoms.res_key[atom_index]] + ] + if drop_unk and res_name in residue_names.UNKNOWN_TYPES: + return False + else: + return ( + (not drop_non_ccd and not ccd.get(res_name)) + or atom_name in struct_chem_comps.get_res_atom_names(ccd, res_name) + or res_name == residue_names.UNL + ) + + standard_atom_mask = np.array( + [_keep(atom_i) for atom_i in range(self.num_atoms)], dtype=bool + ) + standard_atoms = self.filter(mask=standard_atom_mask) + if drop_terminal_oxygens: + standard_atoms = standard_atoms.without_terminal_oxygens() + return standard_atoms + + def find_chains_with_unknown_sequence(self) -> Sequence[str]: + """Returns a sequence of chain IDs that contain only unknown residues.""" + unknown_sequences = [] + for start, end in self.iter_chain_ranges(): + try: + unknown_id = residue_names.UNKNOWN_TYPES.index( + self.res_name[start]) + if start + 1 == end or np.all( + self.res_name[start + 1: end] + == residue_names.UNKNOWN_TYPES[unknown_id] + ): + unknown_sequences.append(self.chain_id[start]) + except ValueError: + pass + return unknown_sequences + + def add_bonds( + self, + bonded_atom_pairs: Sequence[ + tuple[tuple[str, int, str], tuple[str, int, str]], + ], + bond_type: str | None = None, + ) -> Self: + """Returns a structure with new bonds added. + + Args: + bonded_atom_pairs: A sequence of pairs of atoms, with one pair per bond. + Each element of the pair is a tuple of (chain_id, res_id, atom_name), + matching values from the respective fields of this structure. The first + element is the start atom, and the second atom is the end atom of the + bond. + bond_type: This type will be used for all bonds in the structure, where + type follows PDB scheme, e.g. unknown (?), hydrog, metalc, covale, + disulf. + + Returns: + A copy of this structure with the new bonds added. If this structure has + bonds already then the new bonds are concatenated onto the end of the + old bonds. NB: bonds are not deduplicated. + """ + atom_key_lookup: dict[tuple[str, str, None, str], int] = dict( + zip(self.atom_ids, self._atoms.key, strict=True) + ) + + # iter_atoms returns a 4-tuple (chain_id, res_id, ins_code, atom_name) but + # the insertion code is always None. It also uses string residue IDs. 
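+    # Illustrative usage with hypothetical chain/residue IDs, adding a single
+    # disulfide bond between two SG atoms:
+    #   struct.add_bonds([(('A', 12, 'SG'), ('B', 34, 'SG'))], bond_type='disulf')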
+ def _to_internal_res_id( + bonded_atom_id: tuple[str, int, str], + ) -> tuple[str, str, None, str]: + return bonded_atom_id[0], str(bonded_atom_id[1]), None, bonded_atom_id[2] + + from_atom_key = [] + dest_atom_key = [] + for from_atom, dest_atom in bonded_atom_pairs: + from_atom_key.append( + atom_key_lookup[_to_internal_res_id(from_atom)]) + dest_atom_key.append( + atom_key_lookup[_to_internal_res_id(dest_atom)]) + num_bonds = len(bonded_atom_pairs) + bonds_key = np.arange(num_bonds, dtype=np.int64) + from_atom_key = np.array(from_atom_key, dtype=np.int64) + dest_atom_key = np.array(dest_atom_key, dtype=np.int64) + all_unk_col = np.array(['?'] * num_bonds, dtype=object) + if bond_type is None: + bond_type_col = all_unk_col + else: + bond_type_col = np.full((num_bonds,), bond_type, dtype=object) + + max_key = -1 if not self._bonds.size else np.max(self._bonds.key) + new_bonds = structure_tables.Bonds( + key=np.concatenate([self._bonds.key, bonds_key + max_key + 1]), + from_atom_key=np.concatenate( + [self._bonds.from_atom_key, from_atom_key] + ), + dest_atom_key=np.concatenate( + [self._bonds.dest_atom_key, dest_atom_key] + ), + type=np.concatenate([self._bonds.type, bond_type_col]), + role=np.concatenate([self._bonds.role, all_unk_col]), + ) + return self.copy_and_update(bonds=new_bonds) + + @property + def coords(self) -> np.ndarray: + """A [..., num_atom, 3] shaped array of atom coordinates.""" + return np.stack([self._atoms.x, self._atoms.y, self._atoms.z], axis=-1) + + def chain_single_letter_sequence( + self, include_missing_residues: bool = True + ) -> Mapping[str, str]: + """Returns a mapping from chain ID to a single letter residue sequence. + + Args: + include_missing_residues: Whether to include residues that have no atoms. + """ + res_table = ( + self._residues if include_missing_residues else self.present_residues + ) + residue_chain_boundaries = _get_change_indices(res_table.chain_key) + boundaries = self._iter_residue_ranges( + residue_chain_boundaries, + count_unresolved=include_missing_residues, + ) + chain_keys = res_table.chain_key[residue_chain_boundaries] + chain_ids = self._chains.apply_array_to_column('id', chain_keys) + chain_types = self._chains.apply_array_to_column('type', chain_keys) + chain_seqs = {} + for idx, (start, end) in enumerate(boundaries): + chain_id = chain_ids[idx] + chain_type = chain_types[idx] + chain_res = res_table.name[start:end] + if chain_type in mmcif_names.PEPTIDE_CHAIN_TYPES: + unknown_default = 'X' + elif chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES: + unknown_default = 'N' + else: + chain_seqs[chain_id] = 'X' * chain_res.size + continue + + chain_res = string_array.remap( + chain_res, + mapping=residue_names.CCD_NAME_TO_ONE_LETTER, + inplace=False, + default_value=unknown_default, + ) + chain_seqs[chain_id] = ''.join(chain_res.tolist()) + + return chain_seqs + + def polymer_auth_asym_id_to_label_asym_id( + self, + *, + protein: bool = True, + rna: bool = True, + dna: bool = True, + other: bool = True, + ) -> Mapping[str, str]: + """Mapping from author chain ID to internal chain ID, polymers only. + + This mapping is well defined only for polymers (protein, DNA, RNA), but not + for ligands or water. + + E.g. if a structure had the following internal chain IDs (label_asym_id): + A (protein), B (DNA), C (ligand bound to A), D (ligand bound to A), + E (ligand bound to B). 
+ + Such structure would have this internal chain ID (label_asym_id) -> author + chain ID (auth_asym_id) mapping: + A -> A, B -> B, C -> A, D -> A, E -> B + + This is a bijection only for polymers (A, B), but not for ligands. + + Args: + protein: Whether to include protein (polypeptide(L)) chains. + rna: Whether to include RNA chains. + dna: Whether to include DNA chains. + other: Whether to include other polymer chains, e.g. RNA/DNA hybrid or + polypeptide(D). Note that include_other=True must be set in from_mmcif. + + Returns: + A mapping from author chain ID to the internal (label) chain ID for the + given polymer types in the Structure, ligands/water are ignored. + + Raises: + ValueError: If the mapping from internal chain IDs to author chain IDs is + not a bijection for polymer chains. + """ + allowed_types = set() + if protein: + allowed_types.add(mmcif_names.PROTEIN_CHAIN) + if rna: + allowed_types.add(mmcif_names.RNA_CHAIN) + if dna: + allowed_types.add(mmcif_names.DNA_CHAIN) + if other: + non_standard_chain_types = ( + mmcif_names.POLYMER_CHAIN_TYPES + - mmcif_names.STANDARD_POLYMER_CHAIN_TYPES + ) + allowed_types |= non_standard_chain_types + + auth_asym_id_to_label_asym_id = {} + for chain in self.iter_chains(): + if chain['chain_type'] not in allowed_types: + continue + label_asym_id = chain['chain_id'] + auth_asym_id = chain['chain_auth_asym_id'] + # The mapping from author chain id to label chain id is only one-to-one if + # we restrict our attention to polymers. But check nevertheless. + if auth_asym_id in auth_asym_id_to_label_asym_id: + raise ValueError( + f'Author chain ID "{auth_asym_id}" does not have a unique mapping ' + f'to internal chain ID "{label_asym_id}", it is already mapped to ' + f'"{auth_asym_id_to_label_asym_id[auth_asym_id]}".' + ) + auth_asym_id_to_label_asym_id[auth_asym_id] = label_asym_id + + return auth_asym_id_to_label_asym_id + + def polymer_author_chain_single_letter_sequence( + self, + *, + include_missing_residues: bool = True, + protein: bool = True, + rna: bool = True, + dna: bool = True, + other: bool = True, + ) -> Mapping[str, str]: + """Mapping from author chain ID to single letter aa sequence, polymers only. + + This mapping is well defined only for polymers (protein, DNA, RNA), but not + for ligands or water. + + Args: + include_missing_residues: If True then all residues will be returned for + each polymer chain present in the structure. This uses the all_residues + field and will include residues missing due to filtering operations as + well as e.g. unresolved residues specified in an mmCIF header. + protein: Whether to include protein (polypeptide(L)) chains. + rna: Whether to include RNA chains. + dna: Whether to include DNA chains. + other: Whether to include other polymer chains, e.g. RNA/DNA hybrid or + polypeptide(D). Note that include_other=True must be set in from_mmcif. + + Returns: + A mapping from (author) chain IDs to their single-letter sequences for all + polymers in the Structure, ligands/water are ignored. + + Raises: + ValueError: If the mapping from internal chain IDs to author chain IDs is + not a bijection for polymer chains. 
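+
+    Example return value (hypothetical chains): {'A': 'MKTAYIAK', 'B': 'ACGU'},
+    i.e. a protein chain A and an RNA chain B, keyed by author chain ID.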
+ """ + label_chain_id_to_seq = self.chain_single_letter_sequence( + include_missing_residues=include_missing_residues + ) + auth_to_label = self.polymer_auth_asym_id_to_label_asym_id( + protein=protein, rna=rna, dna=dna, other=other + ) + return { + auth: label_chain_id_to_seq[label] + for auth, label in auth_to_label.items() + } + + def chain_res_name_sequence( + self, + *, + include_missing_residues: bool = True, + fix_non_standard_polymer_res: bool = False, + ) -> Mapping[str, Sequence[str]]: + """A mapping from internal chain ID to a sequence of residue names. + + The residue names are the full residue names rather than single letter + codes. For instance, for proteins these are the 3 letter CCD codes. + + Args: + include_missing_residues: Whether to include residues with no atoms in the + returned sequences. + fix_non_standard_polymer_res: Whether to map non standard residues in + protein / RNA / DNA chains to standard residues (e.g. MSE -> MET) or UNK + / N if a match is not found. + + Returns: + A mapping from (internal) chain IDs to a sequence of residue names. + """ + res_table = ( + self._residues if include_missing_residues else self.present_residues + ) + residue_chain_boundaries = _get_change_indices(res_table.chain_key) + boundaries = self._iter_residue_ranges( + residue_chain_boundaries, count_unresolved=include_missing_residues + ) + chain_keys = res_table.chain_key[residue_chain_boundaries] + chain_ids = self._chains.apply_array_to_column('id', chain_keys) + chain_types = self._chains.apply_array_to_column('type', chain_keys) + chain_seqs = {} + for idx, (start, end) in enumerate(boundaries): + chain_id = chain_ids[idx] + chain_type = chain_types[idx] + chain_res = res_table.name[start:end] + if ( + fix_non_standard_polymer_res + and chain_type in mmcif_names.POLYMER_CHAIN_TYPES + ): + chain_seqs[chain_id] = tuple( + fix_non_standard_polymer_residues( + res_names=chain_res, chain_type=chain_type + ) + ) + else: + chain_seqs[chain_id] = tuple(chain_res) + + return chain_seqs + + def fix_non_standard_polymer_res( + self, + res_mapper: Callable[ + [np.ndarray, str], np.ndarray + ] = fix_non_standard_polymer_residues, + ) -> Self: + """Replaces non-standard polymer residues with standard alternatives or UNK. + + e.g. maps 'ACE' -> 'UNK', 'MSE' -> 'MET'. + + NB: Only fixes the residue names, but does not fix the atom names. + E.g., 'MSE' will be renamed to 'MET' but its 'SE' atom will not be renamed + to 'S'. Fixing MSE should be done during conversion from mmcif with the + `fix_mse_residues` flag. + + Args: + res_mapper: An optional function that accepts a numpy array of residue + names and chain_type, and returns an array with fixed res_names. This + defaults to fix_non_standard_polymer_residues. + + Returns: + A Structure containing only standard residue types (or 'UNK') in its + polymer chains. + """ + fixed_res_name = self._residues.name.copy() + chain_change_indices = _get_change_indices(self._residues.chain_key) + for start, end in self._iter_atom_ranges(chain_change_indices): + chain_key = self._residues.chain_key[start] + chain_type = self._chains.type[self._chains.index_by_key[chain_key]] + if chain_type not in mmcif_names.POLYMER_CHAIN_TYPES: + continue # We don't need to change anything for non-polymers. 
+ fixed_res_name[start:end] = res_mapper( + fixed_res_name[start:end], chain_type + ) + fixed_residues = self._residues.copy_and_update(name=fixed_res_name) + return self.copy_and_update(residues=fixed_residues, skip_validation=True) + + @property + def slice_leading_dims(self) -> '_LeadingDimSlice': + """Used to create a new Structure by slicing into the leading dimensions. + + Example usage 1: + + ``` + final_state = multi_state_struct.slice_leading_dims[-1] + ``` + + Example usage 2: + + ``` + # Structure has leading batch and time dimensions. + # Get final 3 time frames from first two batch elements. + sliced_strucs = batched_trajectories.slice_leading_dims[:2, -3:] + ``` + """ + return _LeadingDimSlice(self) + + def unstack(self, axis: int = 0) -> Sequence[Self]: + """Unstacks a multi-model structure into a list of Structures. + + This method is the inverse of `stack`. + + Example usage: + ``` + structs = multi_dim_struct.unstack(axis=0) + ``` + + Args: + axis: The axis to unstack over. The structures in the returned list won't + have this axis in their coordinate of b-factor fields. + + Returns: + A list of `Structure`s with length equal to the size of the specified + axis in the coordinate field arrays. + + Raises: + IndexError: If axis does not refer to one of the leading dimensions of + `self.atoms_table.size`. + """ + ndim = self._atoms.ndim + if not (-ndim <= axis < ndim): + raise IndexError( + f'{axis=} is out of range for atom coordinate fields with {ndim=}.' + ) + elif axis < 0: + axis += ndim + if axis == ndim - 1: + raise IndexError( + 'axis must refer to one of the leading dimensions, not the final ' + f'dimension. The atom fields have {ndim=} and {axis=} was specified.' + ) + unstacked = [] + leading_dim_slice = self.slice_leading_dims # Compute once here. + for i in range(self._atoms.shape[axis]): + slice_i = (slice(None),) * axis + (i,) + unstacked.append(leading_dim_slice[slice_i]) + return unstacked + + def split_by_chain(self) -> Sequence[Self]: + """Splits a Structure into single-chain Structures, one for each chain. + + The obtained structures can be merged back together into the original + structure using the `concat` function. + + Returns: + A list of `Structure`s, one for each chain. The order is the same as the + chain order in the original Structure. + """ + return [self.filter(chain_id=chain_id) for chain_id in self.chains] + + def transform_states_to_chains(self) -> Self: + """Transforms states to chains. + + A multi-state protein structure will be transformed to a multi-chain + single-state protein structure. Useful for visualising multiples states to + examine diversity. This structure's coordinate fields must have shape + `(num_states, num_atoms)`. + + Returns: + A new `Structure`, based on this structure, but with the multiple states + now represented as `num_states * num_chains` chains in a + single-state protein. + + Raises: + ValueError: If this structure's array fields don't have shape + `(num_states, num_atoms)`. + """ + if self._atoms.ndim != 2: + raise ValueError( + 'Coordinate field tensor must have 2 dimensions: ' + f'(num_states, num_atoms), got {self._atoms.ndim}.' + ) + return concat(self.unstack(axis=0)) + + def merge_chains( + self, + *, + chain_groups: Sequence[Sequence[str]], + chain_group_ids: Sequence[str] | None = None, + chain_group_types: Sequence[str] | None = None, + ) -> Self: + """Merges chains in each group into a single chain. 
+ + If a Structure has chains A, B, C, D, E, and + `merge_chains([[A, C], [B, D], [E]])` is called, the new Structure will have + 3 chains A, B, C, the first being concatenation of A+C, the second B+D, the + third just the original chain E. + + Args: + chain_groups: Each group defines what chains should be merged into a + single chain. The output structure will therefore have len(chain_groups) + chains. Residue IDs are renumbered to preserve uniqueness within new + chains. Order of chain groups and within each group matters. + chain_group_ids: Optional sequence of new chain IDs for each group. If not + given, the new internal chain IDs (label_asym_id) are assigned in the + standard mmCIF order (i.e. A, B, ..., Z, AA, BA, CA, ...). Author chain + names (auth_asym_id) are set to be equal to the new internal chain IDs. + chain_group_types: Optional sequence of new chain types for each group. If + not given, only chains with the same type can be merged. + + Returns: + A new `Structure` with chains merged together into a single chain within + each chain group. + + Raises: + ValueError: If chain_group_ids or chain_group_types are given but don't + match the length of chain_groups. + ValueError: If the chain IDs in the flattened chain_groups don't match the + chain IDs in the Structure. + ValueError: If chains in any of the groups don't have the same chain type. + """ + if chain_group_ids and len(chain_group_ids) != len(chain_groups): + raise ValueError( + 'chain_group_ids must the same length as chain_groups: ' + f'{len(chain_group_ids)=} != {len(chain_groups)=}' + ) + if chain_group_types and len(chain_group_types) != len(chain_groups): + raise ValueError( + 'chain_group_types must the same length as chain_groups: ' + f'{len(chain_group_types)=} != {len(chain_groups)=}' + ) + flattened = sorted(itertools.chain.from_iterable(chain_groups)) + if flattened != sorted(self.chains): + raise ValueError( + 'IDs in chain groups do not match Structure chain IDs: ' + f'{chain_groups=}, chains={self.chains}' + ) + + new_chain_key_by_chain_id = {} + for new_chain_key, group_chain_ids in enumerate(chain_groups): + for chain_id in group_chain_ids: + new_chain_key_by_chain_id[chain_id] = new_chain_key + + chain_key_remap = {} + new_chain_type_by_chain_key = {} + for old_chain_key, old_chain_id, old_chain_type in zip( + self._chains.key, self._chains.id, self._chains.type + ): + new_chain_key = new_chain_key_by_chain_id[old_chain_id] + chain_key_remap[old_chain_key] = new_chain_key + + if new_chain_key not in new_chain_type_by_chain_key: + new_chain_type_by_chain_key[new_chain_key] = old_chain_type + elif not chain_group_types: + if new_chain_type_by_chain_key[new_chain_key] != old_chain_type: + bad_types = [ + f'{cid}: {self._chains.type[np.where(self._chains.id == cid)][0]}' + for cid in chain_groups[new_chain_key] + ] + raise ValueError( + 'Inconsistent chain types within group:\n' + + '\n'.join(bad_types) + ) + + new_chain_key = np.arange(len(chain_groups), dtype=np.int64) + if chain_group_ids: + new_chain_id = np.array(chain_group_ids, dtype=object) + else: + new_chain_id = np.array( + [mmcif.int_id_to_str_id(k) for k in new_chain_key + 1], dtype=object + ) + if chain_group_types: + new_chain_type = np.array(chain_group_types, dtype=object) + else: + new_chain_type = np.array( + [new_chain_type_by_chain_key[k] for k in new_chain_key], dtype=object + ) + new_chains = structure_tables.Chains( + key=new_chain_key, + id=new_chain_id, + type=new_chain_type, + auth_asym_id=new_chain_id, + 
entity_id=np.char.mod('%d', new_chain_key + 1).astype(object), + entity_desc=np.full(len(chain_groups), + fill_value='.', dtype=object), + ) + + # Remap chain keys and sort residues to match the chain table order. + new_residues = self._residues.copy_and_remap(chain_key=chain_key_remap) + new_residues = new_residues.apply_index( + np.argsort(new_residues.chain_key, kind='stable') + ) + # Renumber uniquely residues in each chain. + indices = np.arange(new_residues.chain_key.size, dtype=np.int32) + new_res_ids = (indices + 1) - np.maximum.accumulate( + indices * (new_residues.chain_key != + np.roll(new_residues.chain_key, 1)) + ) + new_residues = new_residues.copy_and_update(id=new_res_ids) + + # Remap chain keys and sort atoms to match the chain table order. + new_atoms = self._atoms.copy_and_remap(chain_key=chain_key_remap) + new_atoms = new_atoms.apply_index( + np.argsort(new_atoms.chain_key, kind='stable') + ) + + return self.copy_and_update( + chains=new_chains, + residues=new_residues, + atoms=new_atoms, + bonds=self._bonds, + ) + + def to_res_arrays( + self, + *, + include_missing_residues: bool, + atom_order: Mapping[str, int] = atom_types.ATOM37_ORDER, + ) -> tuple[np.ndarray, np.ndarray]: + """Returns an atom position and atom mask array with a num_res dimension. + + NB: All residues in the structure will appear in the residue + dimension but atoms will only have a True (1.0) mask value if + they are defined in `atom_order`. + + Args: + include_missing_residues: If True then the res arrays will include rows + for missing residues where all atoms will be masked out. Otherwise these + will simply be skipped. + atom_order: Atom order mapping atom names to their index in the atom + dimension of the returned arrays. Default is atom_order for proteins, + choose atom_types.ATOM29_ORDER for nucleics. + + Returns: + A pair of arrays: + * atom_positions: [num_res, atom_type_num, 3] float32 array of coords. + * atom_mask: [num_res, atom_type_num] float32 atom mask denoting + which atoms are present in this Structure. + """ + num_res = self.num_residues(count_unresolved=include_missing_residues) + atom_type_num = len(atom_order) + atom_positions = np.zeros( + (num_res, atom_type_num, 3), dtype=np.float32) + atom_mask = np.zeros((num_res, atom_type_num), dtype=np.float32) + + all_residues = None if not include_missing_residues else self.all_residues + for i, atom in enumerate_residues(self.iter_atoms(), all_residues): + atom_idx = atom_order.get(atom['atom_name']) + if atom_idx is not None: + atom_positions[i, atom_idx, 0] = atom['atom_x'] + atom_positions[i, atom_idx, 1] = atom['atom_y'] + atom_positions[i, atom_idx, 2] = atom['atom_z'] + atom_mask[i, atom_idx] = 1.0 + + return atom_positions, atom_mask + + def to_res_atom_lists( + self, *, include_missing_residues: bool + ) -> Sequence[Sequence[Mapping[str, Any]]]: + """Returns list of atom dictionaries grouped by residue. + + If this is a multi-model structure, each atom will store its fields + atom_x, atom_y, atom_z, and atom_b_factor as Numpy arrays of shape of the + leading dimension(s). If this is a single-mode structure, these fields will + just be scalars. + + Args: + include_missing_residues: If True, then the output list will contain an + empty list of atoms for missing residues. Otherwise missing residues + will simply be skipped. + + Returns: + A list of size `num_res`. Each element in the list represents atoms of one + residue. 
If a residue is present is present, the list will contain an atom + dictionary for every atom present in that residue. If a residue is missing + and `include_missing_residues=True`, the list for that missing residue + will be empty. + """ + num_res = self.num_residues(count_unresolved=include_missing_residues) + residue_atoms = [[] for _ in range(num_res)] + all_residues = None if not include_missing_residues else self.all_residues + + # We could yield directly in this loop but the code would be more complex. + # Let's optimise if memory usage is an issue. + for res_index, atom in enumerate_residues(self.iter_atoms(), all_residues): + residue_atoms[res_index].append(atom) + + return residue_atoms + + def reorder_chains(self, new_order: Sequence[str]) -> Self: + """Reorders tables so that the label_asym_ids are in the given order. + + This method changes the order of the chains, residues, and atoms tables so + that they are all consistent with each other. Moreover, it remaps chain keys + so that they stay monotonically increasing in chains/residues/atoms tables. + + Args: + new_order: The order in which the chain IDs (label_asym_id) should be. + This must be a permutation of the current chain IDs. + + Returns: + A structure with chains reordered. + """ + if len(new_order) != len(self.chains): + raise ValueError( + f'The new number of chains ({len(new_order)}) does not match the ' + f'current number of chains ({len(self.chains)}).' + ) + new_chain_set = set(new_order) + if len(new_chain_set) != len(new_order): + raise ValueError( + f'The new order {new_order} contains non-unique IDs.') + if new_chain_set.symmetric_difference(set(self.chains)): + raise ValueError( + f'New chain IDs {new_order} do not match the old {set(self.chains)}' + ) + + if self.chains == tuple(new_order): + # Shortcut: the new order is the same as the current one. + return self + + desired_chain_id_pos = {chain_id: i for i, + chain_id in enumerate(new_order)} + + current_chain_index_order = np.empty(self.num_chains, dtype=np.int64) + for index, old_chain_id in enumerate(self._chains.id): + current_chain_index_order[index] = desired_chain_id_pos[old_chain_id] + chain_reorder = np.argsort(current_chain_index_order, kind='stable') + chain_key_map = dict( + zip(self._chains.key[chain_reorder], range(self.num_chains)) + ) + chains = self._chains.apply_index(chain_reorder) + chains = chains.copy_and_remap(key=chain_key_map) + + # The stable sort keeps the original residue ordering within each chain. + residues = self._residues.copy_and_remap(chain_key=chain_key_map) + residue_reorder = np.argsort(residues.chain_key, kind='stable') + residues = residues.apply_index(residue_reorder) + + # The stable sort keeps the original atom ordering within each chain. + atoms = self._atoms.copy_and_remap(chain_key=chain_key_map) + atoms_reorder = np.argsort(atoms.chain_key, kind='stable') + atoms = atoms.apply_index(atoms_reorder) + + # Bonds unchanged - each references 2 atom keys, hence ordering not defined. + return self.copy_and_update(chains=chains, residues=residues, atoms=atoms) + + def rename_auth_asym_ids(self, new_id_by_old_id: Mapping[str, str]) -> Self: + """Returns a new structure with renamed auth_asym_ids. + + Args: + new_id_by_old_id: A mapping from original auth_asym_ids to their new + values. Any auth_asym_ids in this structure that are not in the mapping + will remain unchanged. + + Raises: + ValueError: If any two previously distinct polymer chains do not have + unique names anymore after the rename. 
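+
+    Example usage (hypothetical author chain IDs):
+      renamed = struct.rename_auth_asym_ids({'A': 'H', 'B': 'L'})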
+ """ + mapped_chains = self._chains.copy_and_remap( + auth_asym_id=new_id_by_old_id) + mapped_polymer_ids = mapped_chains.filter( + type=mmcif_names.POLYMER_CHAIN_TYPES + ).auth_asym_id + if len(mapped_polymer_ids) != len(set(mapped_polymer_ids)): + raise ValueError( + 'The new polymer auth_asym_ids are not unique:' + f' {sorted(mapped_polymer_ids)}.' + ) + return self.copy_and_update(chains=mapped_chains, skip_validation=True) + + def rename_chain_ids(self, new_id_by_old_id: Mapping[str, str]) -> Self: + """Returns a new structure with renamed chain IDs (label_asym_ids). + + The chains' auth_asym_ids will be updated to be identical to the chain ID + since there isn't one unambiguous way to maintain the auth_asym_ids after + renaming the chain IDs (depending on whether you view the auth_asym_id as + more strongly associated with a given physical chain, or with a given + chain ID). + + The residues' auth_seq_id will be updated to be identical to the residue ID + since they are strongly tied to the original author chain naming and keeping + them would be misleading. + + Args: + new_id_by_old_id: A mapping from original chain ID to their new values. + Any chain IDs in this structure that are not in this mapping will remain + unchanged. + + Returns: + A new structure with renamed chains (and bioassembly data if it is + present). + + Raises: + ValueError: If any two previously distinct chains do not have unique names + anymore after the rename. + """ + new_chain_id = string_array.remap(self._chains.id, new_id_by_old_id) + if len(new_chain_id) != len(set(new_chain_id)): + raise ValueError( + f"New chain names aren't unique: {sorted(new_chain_id)}") + + # Map label_asym_ids in the bioassembly data. + if self._bioassembly_data is None: + new_bioassembly_data = None + else: + new_bioassembly_data = self._bioassembly_data.rename_label_asym_ids( + new_id_by_old_id, present_chains=set(self.present_chains.id) + ) + + # Set author residue IDs to be the string version of internal residue IDs. + new_residues = self._residues.copy_and_update( + auth_seq_id=self._residues.id.astype(str).astype(object) + ) + + new_chains = self._chains.copy_and_update( + id=new_chain_id, auth_asym_id=new_chain_id + ) + + return self.copy_and_update( + bioassembly_data=new_bioassembly_data, + chains=new_chains, + residues=new_residues, + skip_validation=True, + ) + + @functools.cached_property + def chains(self) -> tuple[str, ...]: + """Ordered internal chain IDs (label_asym_id) present in the Structure.""" + return tuple(self._chains.id) + + def rename_res_name( + self, + res_name_map: Mapping[str, str], + fail_if_not_found: bool = True, + ) -> Self: + """Returns a copy of this structure with residues renamed. + + Residue names in chemical components data will also be renamed. + + Args: + res_name_map: A mapping from old residue names to new residue names. Any + residues that are not in this mapping will be left unchanged. + fail_if_not_found: Whether to fail if keys in the res_name_map mapping are + not found in this structure's residues' `name` column. + + Raises: + ValueError: If `fail_if_not_found=True` and a residue name isn't found in + the residues table's `name` field. 
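+
+    Example usage (illustrative), mapping selenomethionine to methionine:
+      renamed = struct.rename_res_name({'MSE': 'MET'}, fail_if_not_found=False)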
+ """ + res_name_set = set(self._residues.name) + if fail_if_not_found: + for res_name in res_name_map: + if res_name not in res_name_set: + raise ValueError( + f'"{res_name}" not found in this structure.') + new_residues = self._residues.copy_and_remap(name=res_name_map) + + if self._chemical_components_data is not None: + chem_comp = { + res_name_map.get(res_name, res_name): data + for res_name, data in self._chemical_components_data.chem_comp.items() + } + new_chem_comp = struct_chem_comps.ChemicalComponentsData(chem_comp) + else: + new_chem_comp = None + + return self.copy_and_update( + residues=new_residues, + chemical_components_data=new_chem_comp, + skip_validation=True, + ) + + def rename_chains_to_match( + self, + other: 'Structure', + *, + fuzzy_match_non_standard_res: bool = True, + ) -> Self: + """Returns a new structure with renamed chains to match another's. + + Example: + This structure has chains: {'A': 'DEEP', 'B': 'MIND', 'C': 'MIND'} + Other structure has chains: {'X': 'DEEP', 'Z': 'MIND', 'Y': 'MIND'} + + After calling this method, you will get a structure that has chains named: + {'X': 'DEEP', 'Z': 'MIND', Y: 'MIND'} + + Args: + other: Another `Structure`. This provides the reference chain names that + is used to rename this structure's chains. + fuzzy_match_non_standard_res: If True, protein/RNA/DNA chains with the + same one letter sequence will be matched. e.g. "MET-MET-UNK1" will match + "MET-MSE-UNK2", since both will be mapped to "MMX". If False, we require + the full res_names to match. + + Returns: + A new `Structure`, based on this structure, which has chains renamed to + match the other structure. + """ + sequences = self.chain_res_name_sequence( + include_missing_residues=True, + fix_non_standard_polymer_res=fuzzy_match_non_standard_res, + ) + + other_sequences = other.chain_res_name_sequence( + include_missing_residues=True, + fix_non_standard_polymer_res=fuzzy_match_non_standard_res, + ) + + # Check that the sequences are the same. + sequence_counts = collections.Counter(sequences.values()) + other_sequence_counts = collections.Counter(other_sequences.values()) + if other_sequence_counts != sequence_counts: + raise ValueError( + 'The other structure does not have the same sequences\n' + f' other: {other_sequence_counts}\n self: {sequence_counts}' + ) + + new_decoy_id_by_old_id = {} + used_chain_ids = set() + # Sort self keys and take min over other to make matching deterministic. + # The matching is arbitrary but this helps debugging. + for self_chain_id, self_seq in sorted(sequences.items()): + # Find corresponding chains in the other structure. + other_chain_id = min( + k + for k, v in other_sequences.items() + if v == self_seq and k not in used_chain_ids + ) + + new_decoy_id_by_old_id[self_chain_id] = other_chain_id + used_chain_ids.add(other_chain_id) + + return self.rename_chain_ids(new_decoy_id_by_old_id) + + def _apply_bioassembly_transform( + self, transform: bioassemblies.Transform + ) -> Self: + """Applies a bioassembly transform to this structure.""" + base_struct = self.filter(chain_id=transform.chain_ids) + transformed_atoms = base_struct.atoms_table.copy_and_update_coords( + transform.apply_to_coords(base_struct.coords) + ) + transformed_chains = base_struct.chains_table.copy_and_remap( + id=transform.chain_id_rename_map + ) + # Set the transformed author chain ID to match the label chain ID. 
+ transformed_chains = transformed_chains.copy_and_update( + auth_asym_id=transformed_chains.id + ) + return base_struct.copy_and_update( + chains=transformed_chains, + atoms=transformed_atoms, + skip_validation=True, + ) + + def generate_bioassembly(self, assembly_id: str | None = None) -> Self: + """Generates a biological assembly as a new `Structure`. + + When no assembly ID is provided this method produces a default assembly. + If this structure has no `bioassembly_data` then this returns itself + unchanged. Otherwise a default assembly ID is picked with + `BioassemblyData.get_default_assembly_id()`. + + Args: + assembly_id: The assembly ID to generate, or None to generate a default + bioassembly. + + Returns: + A new `Structure`, based on this one, representing the specified + bioassembly. Note that if the bioassembly contains copies of chains + in the original structure then they will be given new unique chain IDs. + + Raises: + ValueError: If this structure's `bioassembly_data` is `None` and + `assembly_id` is not `None`. + """ + if self._bioassembly_data is None: + if assembly_id is None: + return self + else: + raise ValueError( + f'Unset bioassembly_data, cannot generate assembly {assembly_id}' + ) + + if assembly_id is None: + assembly_id = self._bioassembly_data.get_default_assembly_id() + + transformed_structs = [ + self._apply_bioassembly_transform(transform) + for transform in self._bioassembly_data.get_transforms(assembly_id) + ] + + # We don't need to assign unique chain IDs because the bioassembly + # transform takes care of remapping chain IDs to be unique. + concatenated = concat(transformed_structs, + assign_unique_chain_ids=False) + + # Copy over all scalar fields (e.g. name, release date, etc.) other than + # bioassembly_data because it relates only to the pre-transformed structure. + return concatenated.copy_and_update_globals( + name=self.name, + release_date=self.release_date, + resolution=self.resolution, + structure_method=self.structure_method, + bioassembly_data=None, + chemical_components_data=self.chemical_components_data, + ) + + def _to_mmcif_header(self) -> Mapping[str, Sequence[str]]: + raw_mmcif = collections.defaultdict(list) + raw_mmcif['data_'] = [self._name] + raw_mmcif['_entry.id'] = [self._name] + + if self._release_date is not None: + date = [datetime.datetime.strftime(self._release_date, '%Y-%m-%d')] + raw_mmcif['_pdbx_audit_revision_history.revision_date'] = date + raw_mmcif['_pdbx_database_status.recvd_initial_deposition_date'] = date + + if self._resolution is not None: + raw_mmcif['_refine.ls_d_res_high'] = ['%.2f' % self._resolution] + + if self._structure_method is not None: + for method in self._structure_method.split(','): + raw_mmcif['_exptl.method'].append(method) + + if self._bioassembly_data is not None: + raw_mmcif.update(self._bioassembly_data.to_mmcif_dict()) + + # Populate chemical components data for all residues of this Structure. + if self._chemical_components_data: + raw_mmcif.update(self._chemical_components_data.to_mmcif_dict()) + + # Add _software table to store version number used to generate mmCIF. + # Only required data items are used (+ _software.version). + raw_mmcif['_software.pdbx_ordinal'] = ['1'] + raw_mmcif['_software.name'] = ['DeepMind Structure Class'] + raw_mmcif['_software.version'] = [self._VERSION] + raw_mmcif['_software.classification'] = ['other'] # Required. 
+ + return raw_mmcif + + def to_mmcif_dict( + self, + *, + coords_decimal_places: int = _COORDS_DECIMAL_PLACES, + ) -> mmcif.Mmcif: + """Returns an Mmcif representing the structure.""" + header = self._to_mmcif_header() + sequence_tables = structure_tables.to_mmcif_sequence_and_entity_tables( + self._chains, self._residues, self._atoms.res_key + ) + atom_and_bond_tables = structure_tables.to_mmcif_atom_site_and_bonds_table( + chains=self._chains, + residues=self._residues, + atoms=self._atoms, + bonds=self._bonds, + coords_decimal_places=coords_decimal_places, + ) + return mmcif.Mmcif({**header, **sequence_tables, **atom_and_bond_tables}) + + def to_mmcif( + self, *, coords_decimal_places: int = _COORDS_DECIMAL_PLACES + ) -> str: + """Returns an mmCIF string representing the structure. + + Args: + coords_decimal_places: The number of decimal places to keep for atom + coordinates, including trailing zeros. + """ + return self.to_mmcif_dict( + coords_decimal_places=coords_decimal_places + ).to_string() + + +class _LeadingDimSlice: + """Helper class for slicing the leading dimensions of a `Structure`. + + Wraps a `Structure` instance and applies a slice operation to the coordinate + fields and other fields that may have leading dimensions (e.g. b_factor). + + Example usage: + t0_struct = multi_state_struct.slice_leading_dims[0] + """ + + def __init__(self, struct: Structure): + self._struct = struct + + def __getitem__(self, *args, **kwargs) -> Structure: + sliced_atom_cols = {} + for col_name in structure_tables.Atoms.multimodel_cols: + if (col := self._struct.atoms_table.get_column(col_name)).ndim > 1: + sliced_col = col.__getitem__(*args, **kwargs) + if ( + not sliced_col.shape + or sliced_col.shape[-1] != self._struct.num_atoms + ): + raise ValueError( + 'Coordinate slice cannot change final (atom) dimension.' + ) + sliced_atom_cols[col_name] = sliced_col + sliced_atoms = self._struct.atoms_table.copy_and_update( + **sliced_atom_cols) + return self._struct.copy_and_update(atoms=sliced_atoms, skip_validation=True) + + +def stack(structs: Sequence[Structure], axis: int = 0) -> Structure: + """Stacks multiple structures into a single multi-model Structure. + + This function is the inverse of `Structure.unstack()`. + + NB: this function assumes that every structure in `structs` is identical + other than the coordinates and b-factors. Under this assumption we can safely + copy all these identical fields from the first element of structs w.l.o.g. + However this is not checked in full detail as full comparison is expensive. + Instead this only checks that the `atom_name` field is identical, and that + the coordinates have the same shape. + + Usage example: + ``` + multi_model_struct = structure.stack(structs, axis=0) + ``` + + Args: + structs: A sequence of structures, each with the same atoms, but they may + have different coordinates and b-factors. If any b-factors are not None + then they must have the same shape as each of the coordinate fields. + axis: The axis in the returned structure that represents the different + structures in `structs` and will have size `len(structs)`. This cannot be + the final dimension as this is reserved for `num_atoms`. + + Returns: + A `Structure` with the same atoms as the structures in `structs` but with + all of their coordinates stacked into a new leading axis. + + Raises: + ValueError: If `structs` is empty. + ValueError: If `structs` do not all have the same `atom_name` field. 
+ """ + if not structs: + raise ValueError('Need at least one Structure to stack.') + struct_0, *other_structs = structs + for i, struct in enumerate(other_structs, start=1): + # Check that every structure has the same atom name column. + # This check is intended to catch cases where the input structures might + # contain the same atoms, but in different orders. This won't catch every + # such case, e.g. if these are carbon-alpha-only structures, but should + # catch most cases. + if np.any(struct.atoms_table.name != struct_0.atoms_table.name): + raise ValueError( + f'structs[0] and structs[{i}] have mismatching atom name columns.' + ) + + stacked_atoms = struct_0.atoms_table.copy_and_update( + x=np.stack([s.atoms_table.x for s in structs], axis=axis), + y=np.stack([s.atoms_table.y for s in structs], axis=axis), + z=np.stack([s.atoms_table.z for s in structs], axis=axis), + b_factor=np.stack([s.atoms_table.b_factor for s in structs], axis=axis), + occupancy=np.stack( + [s.atoms_table.occupancy for s in structs], axis=axis), + ) + return struct_0.copy_and_update(atoms=stacked_atoms, skip_validation=True) + + +def _assign_unique_chain_ids( + structs: Iterable[Structure], +) -> Sequence[Structure]: + """Creates a sequence of `Structure` objects with unique chain IDs. + + Let e.g. [A, B] denote a structure of two chains A and B, then this function + performs the following kind of renaming operation: + + e.g.: [Z], [C], [B, C] -> [A], [B], [C, D] + + NB: This function uses Structure.rename_chain_ids which will define each + structure's chains.auth_asym_id to be identical to its chains.id columns. + + Args: + structs: Structures whose chains ids are to be uniquified. + + Returns: + A sequence with the same number of elements as `structs` but where each + element has had its chains renamed so that they aren't shared with any + other `Structure` in the sequence. + """ + # Start counting at 1 because mmcif.int_id_to_str_id expects integers >= 1. + chain_counter = 1 + structs_with_new_chain_ids = [] + for struct in structs: + rename_map = {} + for chain_id in struct.chains: + rename_map[chain_id] = mmcif.int_id_to_str_id(chain_counter) + chain_counter += 1 + renamed = struct.rename_chain_ids(rename_map) + structs_with_new_chain_ids.append(renamed) + return structs_with_new_chain_ids + + +def concat( + structs: Sequence[Structure], + *, + name: str | None = None, + assign_unique_chain_ids: bool = True, +) -> Structure: + """Concatenates structures along the atom dimension. + + NB: By default this function will first assign unique chain IDs to all chains + in `structs` so that the resulting structure does not contain duplicate chain + IDs. This will also fix entity IDs and author chain IDs. If this is disabled + via `assign_unique_chain_ids=False` the user must ensure that there are no + duplicate chains (label_asym_id). However, duplicate entity IDs and author + chain IDs are allowed as that might be the desired behavior. + + If `assign_unique_chain_ids=True`, note also that the chain_ids may be + overwritten even if they are already unique. + + Let e.g. 
[A, B] denote a structure of two chains A and B, then this function + performs the following kind of concatenation operation: + + assign_unique_chain_ids=True: + label chain IDS : [Z], [C], [B, C] -> [A, B, C, D] + author chain IDS: [U], [V], [V, C] -> [A, B, C, D] + entity IDs : [1], [1], [3, 3] -> [1, 2, 3, 4] + assign_unique_chain_ids=False: + label chain IDS : [D], [B], [C, A] -> [D, B, C, A] (inputs must be unique) + author chain IDS: [U], [V], [V, A] -> [U, V, V, A] + entity IDs : [1], [1], [3, 3] -> [1, 1, 3, 3] + + NB: This operation loses some information from the elements of `structs`, + namely the `name`, `resolution`, `release_date` and `bioassembly_data` fields. + + Args: + structs: The `Structure` instances to concatenate. These should all have the + same number and shape of leading dimensions (i.e. if any are multi-model + structures then they should all have the same number of models). + name: Optional name to give to the concatenated structure. If None, the name + will be concatenation of names of all concatenated structures. + assign_unique_chain_ids: Whether this function will first assign new unique + chain IDs, entity IDs and author chain IDs to every chain in `structs`. If + `False` then users must ensure chain IDs are already unique, otherwise an + exception is raised. See `_assign_unique_chain_ids` for more information + on how this is performed. + + Returns: + A new concatenated `Structure` with all of the chains in `structs` combined + into one new structure. The new structure will be named by joining the + names of `structs` with underscores. + + Raises: + ValueError: If `structs` is empty. + ValueError: If `assign_unique_chain_ids=False` and not all chains in + `structs` have unique chain IDs. + """ + if not structs: + raise ValueError('Need at least one Structure to concatenate.') + if assign_unique_chain_ids: + structs = _assign_unique_chain_ids(structs) + + chemical_components_data = {} + seen_label_chain_ids = set() + for i, struct in enumerate(structs): + if not assign_unique_chain_ids: + if seen_cid := seen_label_chain_ids.intersection(struct.chains): + raise ValueError( + f'Chain IDs {seen_cid} from structs[{i}] also exist in other' + ' members of structs. All given structures must have unique chain' + ' IDs. Consider setting assign_unique_chain_ids=True.' + ) + seen_label_chain_ids.update(struct.chains) + + if struct.chemical_components_data is not None: + # pytype: disable=attribute-error # always-use-property-annotation + chemical_components_data.update( + struct.chemical_components_data.chem_comp) + + concatted_struct = table.concat_databases(structs) + name = name if name is not None else '_'.join(s.name for s in structs) + # Chain IDs (label and author) are fixed at this point, fix also entity IDs. + if assign_unique_chain_ids: + entity_id = np.char.mod('%d', np.arange( + 1, concatted_struct.num_chains + 1)) + chains = concatted_struct.chains_table.copy_and_update( + entity_id=entity_id) + else: + chains = concatted_struct.chains_table + return concatted_struct.copy_and_update( + name=name, + release_date=None, + resolution=None, + structure_method=None, + bioassembly_data=None, + chemical_components_data=( + struct_chem_comps.ChemicalComponentsData(chemical_components_data) + if chemical_components_data + else None + ), + chains=chains, + skip_validation=True, # Already validated by table.concat_databases. 
+ ) + + +def multichain_residue_index( + struct: Structure, chain_offset: int = 9000, between_chain_buffer: int = 1000 +) -> np.ndarray: + """Compute a residue index array that is monotonic across all chains. + + Lots of metrics (lddt, l1_long, etc) require computing a + distance-along-chain between two residues. For multimers we want to ensure + that any residues on different chains have a high along-chain distance + (i.e. they should always count as long-range contacts for example). To + do this we add 10000 to the residue indices of each chain, and enforce that + the residue index is monotonically increasing across the whole complex. + + Note: This returns the same as struct.res_id for monomers. + + Args: + struct: The structure to make a multichain residue index for. + chain_offset: The start of each chain is offset by at least this amount. + This must be larger than the absolute range of standard residue IDs. + between_chain_buffer: The final residue in one chain will have at least this + much of a buffer before the first residue in the next chain. + + Returns: + A monotonically increasing residue index, with at least + `between_chain_buffer` residues in between each chain. + """ + if struct.num_atoms: + res_id_range = np.max(struct.res_id) - np.min(struct.res_id) + assert res_id_range < chain_offset + chain_id_int = struct.chain_id + monotonic_chain_id_int = np.concatenate( + ([0], np.cumsum(chain_id_int[1:] != chain_id_int[:-1])) + ) + return struct.res_id + monotonic_chain_id_int * ( + chain_offset + between_chain_buffer + ) + + +def make_empty_structure() -> Structure: + """Returns a new structure consisting of empty array fields.""" + return Structure( + chains=structure_tables.Chains.make_empty(), + residues=structure_tables.Residues.make_empty(), + atoms=structure_tables.Atoms.make_empty(), + bonds=structure_tables.Bonds.make_empty(), + ) + + +def enumerate_residues( + atom_iter: Iterable[Mapping[str, Any]], + all_residues: AllResidues | None = None, +) -> Iterator[tuple[int, Mapping[str, Any]]]: + """Provides a zero-indexed enumeration of residues in an atom iterable. + + Args: + atom_iter: An iterable of atom dicts as returned by Structure.iter_atoms(). + all_residues: (Optional) A structure's all_residues field. If present then + this will be used to count missing residues by adding appropriate gaps in + the residue enumeration. + + Yields: + (res_i, atom) pairs where atom is the unmodified atom dict and res_i is a + zero-based index for the residue that the atom belongs to. + """ + if all_residues is None: + prev_res = None + res_i = -1 + for atom in atom_iter: + res = (atom['chain_id'], atom['res_id']) + if res != prev_res: + prev_res = res + res_i += 1 + yield res_i, atom + else: + all_res_seq = [] # Sequence of (chain_id, res_id) for all chains. + prev_chain = None + res_i = 0 + for atom in atom_iter: + chain_id = atom['chain_id'] + if chain_id not in all_residues: + raise ValueError( + f'Atom {atom} does not belong to any residue in all_residues.' + ) + if chain_id != prev_chain: + prev_chain = chain_id + all_res_seq.extend( + (chain_id, res_id) for (_, res_id) in all_residues[chain_id] + ) + res = (chain_id, atom['res_id']) + while res_i < len(all_res_seq) and res != all_res_seq[res_i]: + res_i += 1 + if res_i == len(all_res_seq): + raise ValueError( + f'Atom {atom} does not belong to a residue in all_residues.' 
+ ) + yield res_i, atom diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure_tables.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure_tables.py new file mode 100644 index 0000000000000000000000000000000000000000..2867d8582e4ad7c8ec748be9b29f618517130fc8 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure_tables.py @@ -0,0 +1,842 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Table implementations for the Structure class.""" + +import collections +from collections.abc import Mapping, Sequence +import dataclasses +import functools +import itertools +import typing +from typing_extensions import Any, ClassVar, Self +import numpy as np +from alphafold3.constants import mmcif_names +from alphafold3.constants import residue_names +from alphafold3.cpp import aggregation +from alphafold3.cpp import string_array +from alphafold3.structure import bonds as bonds_module +from alphafold3.structure import mmcif +from alphafold3.structure import table + + +Bonds = bonds_module.Bonds + + +def _residue_name_to_record_name( + residue_name: np.ndarray, + polymer_mask: np.ndarray, +) -> np.ndarray: + """Returns record names (ATOM/HETATM) given residue names and polymer mask.""" + record_name = np.array(['HETATM'] * len(residue_name), dtype=object) + record_name[polymer_mask] = string_array.remap( + residue_name[polymer_mask], + mapping={r: 'ATOM' for r in residue_names.STANDARD_POLYMER_TYPES}, + default_value='HETATM', + ) + return record_name + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class AuthorNamingScheme: + """A mapping from internal values to author values in a mmCIF. + + Fields: + auth_asym_id: A mapping from label_asym_id to auth_asym_id. + auth_seq_id: A mapping from label_asym_id to a mapping from + label_seq_id to auth_seq_id. + insertion_code: A mapping from label_asym_id to a mapping from + label_seq_id to insertion codes. + entity_id: A mapping from label_asym_id to _entity.id. + entity_desc: A mapping from _entity.id to _entity.pdbx_description. 
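+
+  Example (hypothetical values):
+    AuthorNamingScheme(
+        auth_asym_id={'A': 'H'},
+        auth_seq_id={'A': {1: '100'}},
+        insertion_code={'A': {1: None}},
+        entity_id={'A': '1'},
+        entity_desc={'1': 'heavy chain'},
+    )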
+ """ + + auth_asym_id: Mapping[str, str] + auth_seq_id: Mapping[str, Mapping[int, str]] + insertion_code: Mapping[str, Mapping[int, str | None]] + entity_id: Mapping[str, str] + entity_desc: Mapping[str, str] + + +def _default( + candidate_value: np.ndarray | None, default_value: Sequence[Any], dtype: Any +) -> np.ndarray: + if candidate_value is None: + return np.array(default_value, dtype=dtype) + return np.array(candidate_value, dtype=dtype) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class Atoms(table.Table): + """Table of atoms in a Structure.""" + + chain_key: np.ndarray + res_key: np.ndarray + name: np.ndarray + element: np.ndarray + x: np.ndarray + y: np.ndarray + z: np.ndarray + b_factor: np.ndarray + occupancy: np.ndarray + multimodel_cols: ClassVar[tuple[str, ...]] = ( + 'x', + 'y', + 'z', + 'b_factor', + 'occupancy', + ) + + def __post_init__(self): + # Validates that the atom coordinates, b-factors and occupancies are finite. + for column_name in ('x', 'y', 'z', 'b_factor', 'occupancy'): + column = self.get_column(column_name) + if not np.isfinite(column).all(): + raise ValueError( + f'Column {column_name} must not contain NaN/inf values.' + ) + # super().__post_init__() can't be used as that causes the following error: + # TypeError: super(type, obj): obj must be an instance or subtype of type + super(Atoms, self).__post_init__() + + @classmethod + def make_empty(cls) -> Self: + return cls( + key=np.array([], dtype=np.int64), + chain_key=np.array([], dtype=np.int64), + res_key=np.array([], dtype=np.int64), + name=np.array([], dtype=object), + element=np.array([], dtype=object), + x=np.array([], dtype=np.float32), + y=np.array([], dtype=np.float32), + z=np.array([], dtype=np.float32), + b_factor=np.array([], dtype=np.float32), + occupancy=np.array([], dtype=np.float32), + ) + + @classmethod + def from_defaults( + cls, + *, + chain_key: np.ndarray, + res_key: np.ndarray, + key: np.ndarray | None = None, + name: np.ndarray | None = None, + element: np.ndarray | None = None, + x: np.ndarray | None = None, + y: np.ndarray | None = None, + z: np.ndarray | None = None, + b_factor: np.ndarray | None = None, + occupancy: np.ndarray | None = None, + ) -> Self: + """Create an Atoms table with minimal user inputs.""" + num_atoms = len(chain_key) + if not num_atoms: + return cls.make_empty() + return Atoms( + chain_key=chain_key, + res_key=res_key, + key=_default(key, np.arange(num_atoms), np.int64), + name=_default(name, ['?'] * num_atoms, object), + element=_default(element, ['?'] * num_atoms, object), + x=_default(x, [0.0] * num_atoms, np.float32), + y=_default(y, [0.0] * num_atoms, np.float32), + z=_default(z, [0.0] * num_atoms, np.float32), + b_factor=_default(b_factor, [0.0] * num_atoms, np.float32), + occupancy=_default(occupancy, [1.0] * num_atoms, np.float32), + ) + + def get_value_by_index( + self, column_name: str, index: int + ) -> table.TableEntry | np.ndarray: + if column_name in self.multimodel_cols: + return self.get_column(column_name)[..., index] + else: + return self.get_column(column_name)[index] + + def copy_and_update_coords(self, coords: np.ndarray) -> Self: + """Returns a copy with the x, y and z columns updated.""" + if coords.shape[-1] != 3: + raise ValueError( + f'Expecting 3-dimensional coordinates, got {coords.shape}' + ) + return typing.cast( + Atoms, + self.copy_and_update( + x=coords[..., 0], y=coords[..., 1], z=coords[..., 2] + ), + ) + + @property + def shape(self) -> tuple[int, ...]: + return self.x.shape + + @property + def 
ndim(self) -> int: + return len(self.shape) + + @functools.cached_property + def num_models(self) -> int: + """The number of models of this Structure.""" + leading_dims = self.shape[:-1] + match leading_dims: + case(): + return 1 + case(single_leading_dim_size,): + return single_leading_dim_size + case _: + raise ValueError( + 'num_models not defined for atom tables with more than one ' + 'leading dimension.' + ) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class Residues(table.Table): + """Table of residues in a Structure.""" + + chain_key: np.ndarray + id: np.ndarray + name: np.ndarray + auth_seq_id: np.ndarray + insertion_code: np.ndarray + + @classmethod + def make_empty(cls) -> Self: + return cls( + key=np.array([], dtype=np.int64), + chain_key=np.array([], dtype=np.int64), + id=np.array([], dtype=np.int32), + name=np.array([], dtype=object), + auth_seq_id=np.array([], dtype=object), + insertion_code=np.array([], dtype=object), + ) + + @classmethod + def from_defaults( + cls, + *, + id: np.ndarray, # pylint:disable=redefined-builtin + chain_key: np.ndarray, + key: np.ndarray | None = None, + name: np.ndarray | None = None, + auth_seq_id: np.ndarray | None = None, + insertion_code: np.ndarray | None = None, + ) -> Self: + """Create a Residues table with minimal user inputs.""" + num_res = len(id) + if not num_res: + return cls.make_empty() + return Residues( + key=_default(key, np.arange(num_res), np.int64), + id=id, + chain_key=chain_key, + name=_default(name, ['UNK'] * num_res, object), + auth_seq_id=_default(auth_seq_id, id.astype(str), object), + insertion_code=_default(insertion_code, ['?'] * num_res, object), + ) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class Chains(table.Table): + """Table of chains in a Structure.""" + + id: np.ndarray + type: np.ndarray + auth_asym_id: np.ndarray + entity_id: np.ndarray + entity_desc: np.ndarray + + @classmethod + def make_empty(cls) -> Self: + return cls( + key=np.array([], dtype=np.int64), + id=np.array([], dtype=object), + type=np.array([], dtype=object), + auth_asym_id=np.array([], dtype=object), + entity_id=np.array([], dtype=object), + entity_desc=np.array([], dtype=object), + ) + + @classmethod + def from_defaults( + cls, + *, + id: np.ndarray, # pylint:disable=redefined-builtin + key: np.ndarray | None = None, + type: np.ndarray | None = None, # pylint:disable=redefined-builtin + auth_asym_id: np.ndarray | None = None, + entity_id: np.ndarray | None = None, + entity_desc: np.ndarray | None = None, + ) -> Self: + """Create a Chains table with minimal user inputs.""" + num_chains = len(id) + if not num_chains: + return cls.make_empty() + + return Chains( + key=_default(key, np.arange(num_chains), np.int64), + id=id, + type=_default(type, [mmcif_names.PROTEIN_CHAIN] + * num_chains, object), + auth_asym_id=_default(auth_asym_id, id, object), + entity_id=_default( + entity_id, np.arange(1, num_chains + 1).astype(str), object + ), + entity_desc=_default(entity_desc, ['.'] * num_chains, object), + ) + + +def to_mmcif_sequence_and_entity_tables( + chains: Chains, + residues: Residues, + atom_res_key: np.ndarray, +) -> Mapping[str, Sequence[str]]: + """Returns raw sequence and entity mmCIF tables.""" + raw_mmcif = collections.defaultdict(list) + chains_by_entity_id = {} + written_entity_poly_seq_ids = set() + present_res_keys = set(atom_res_key) + + # Performance optimisation: Find residue indices for each chain in advance, so + # that we don't have to do redundant masking work for each chain. 
+ res_indices_for_chain = aggregation.indices_grouped_by_value( + residues.chain_key + ) + + for chain in chains.iterrows(): + # Add all chain information to the _struct_asym table. + chain_id = chain['id'] # Saves multiple dict lookups. + auth_asym_id = chain['auth_asym_id'] + entity_id = chain['entity_id'] + chains_by_entity_id.setdefault(entity_id, []).append(chain) + raw_mmcif['_struct_asym.id'].append(chain_id) + raw_mmcif['_struct_asym.entity_id'].append(entity_id) + + res_chain_indices = res_indices_for_chain[chain['key']] + chain_type = chain['type'] + is_polymer = chain_type in mmcif_names.POLYMER_CHAIN_TYPES + is_water = chain_type == mmcif_names.WATER + is_branched = len( + res_chain_indices) > 1 and not is_polymer and not is_water + write_entity_poly_seq = entity_id not in written_entity_poly_seq_ids + + # Iterate over the individual masked residue table columns, as that doesn't + # create a copy (only a view), while residues[res_chain_indices] does. + for res_key, res_name, res_id, pdb_seq_num, res_ins_code in zip( + residues.key[res_chain_indices], + residues.name[res_chain_indices], + residues.id[res_chain_indices], + residues.auth_seq_id[res_chain_indices], + residues.insertion_code[res_chain_indices], + strict=True, + ): + is_missing = res_key not in present_res_keys + str_res_id = str(res_id) + # While atom_site uses "?" for insertion codes, scheme tables use ".". + ins_code = (res_ins_code or '.').replace('?', '.') + auth_seq_num = '?' if is_missing else pdb_seq_num + + if is_polymer: + raw_mmcif['_pdbx_poly_seq_scheme.asym_id'].append(chain_id) + raw_mmcif['_pdbx_poly_seq_scheme.entity_id'].append(entity_id) + raw_mmcif['_pdbx_poly_seq_scheme.seq_id'].append(str_res_id) + raw_mmcif['_pdbx_poly_seq_scheme.mon_id'].append(res_name) + raw_mmcif['_pdbx_poly_seq_scheme.pdb_seq_num'].append( + pdb_seq_num) + raw_mmcif['_pdbx_poly_seq_scheme.auth_seq_num'].append( + auth_seq_num) + raw_mmcif['_pdbx_poly_seq_scheme.pdb_strand_id'].append( + auth_asym_id) + raw_mmcif['_pdbx_poly_seq_scheme.pdb_ins_code'].append( + ins_code) + # Structure doesn't support heterogeneous sequences. + raw_mmcif['_pdbx_poly_seq_scheme.hetero'].append('n') + if write_entity_poly_seq: + raw_mmcif['_entity_poly_seq.entity_id'].append(entity_id) + raw_mmcif['_entity_poly_seq.num'].append(str_res_id) + raw_mmcif['_entity_poly_seq.mon_id'].append(res_name) + # Structure doesn't support heterogeneous sequences. + raw_mmcif['_entity_poly_seq.hetero'].append('n') + written_entity_poly_seq_ids.add(entity_id) + elif is_branched: + raw_mmcif['_pdbx_branch_scheme.asym_id'].append(chain_id) + raw_mmcif['_pdbx_branch_scheme.entity_id'].append(entity_id) + raw_mmcif['_pdbx_branch_scheme.mon_id'].append(res_name) + raw_mmcif['_pdbx_branch_scheme.num'].append(str_res_id) + raw_mmcif['_pdbx_branch_scheme.pdb_asym_id'].append( + auth_asym_id) + raw_mmcif['_pdbx_branch_scheme.pdb_seq_num'].append( + pdb_seq_num) + raw_mmcif['_pdbx_branch_scheme.auth_asym_id'].append( + auth_asym_id) + raw_mmcif['_pdbx_branch_scheme.auth_seq_num'].append( + auth_seq_num) + raw_mmcif['_pdbx_branch_scheme.pdb_ins_code'].append(ins_code) + # Structure doesn't support heterogeneous sequences. 
+ raw_mmcif['_pdbx_branch_scheme.hetero'].append('n') + else: + raw_mmcif['_pdbx_nonpoly_scheme.asym_id'].append(chain_id) + raw_mmcif['_pdbx_nonpoly_scheme.entity_id'].append(entity_id) + raw_mmcif['_pdbx_nonpoly_scheme.mon_id'].append(res_name) + raw_mmcif['_pdbx_nonpoly_scheme.pdb_seq_num'].append( + pdb_seq_num) + raw_mmcif['_pdbx_nonpoly_scheme.auth_seq_num'].append( + auth_seq_num) + raw_mmcif['_pdbx_nonpoly_scheme.pdb_strand_id'].append( + auth_asym_id) + raw_mmcif['_pdbx_nonpoly_scheme.pdb_ins_code'].append(ins_code) + + # Add _entity and _entity_poly tables. + for entity_id, chains in chains_by_entity_id.items(): + # chains should always be a non-empty list because of how we constructed + # chains_by_entity_id. + assert chains + # All chains for a given entity should have the same type and sequence + # so we can pick the first one without losing information. + key_chain = chains[0] + raw_mmcif['_entity.id'].append(entity_id) + raw_mmcif['_entity.pdbx_description'].append(key_chain['entity_desc']) + entity_type = key_chain['type'] + if entity_type not in mmcif_names.POLYMER_CHAIN_TYPES: + raw_mmcif['_entity.type'].append(entity_type) + else: + raw_mmcif['_entity.type'].append('polymer') + raw_mmcif['_entity_poly.entity_id'].append(entity_id) + raw_mmcif['_entity_poly.type'].append(entity_type) + + # _entity_poly.pdbx_strand_id is a comma-separated list of + # auth_asym_ids that are part of the entity. + raw_mmcif['_entity_poly.pdbx_strand_id'].append( + ','.join(chain['auth_asym_id'] for chain in chains) + ) + return raw_mmcif + + +def to_mmcif_atom_site_and_bonds_table( + *, + chains: Chains, + residues: Residues, + atoms: Atoms, + bonds: Bonds, + coords_decimal_places: int, +) -> Mapping[str, Sequence[str]]: + """Returns raw _atom_site and _struct_conn mmCIF tables.""" + raw_mmcif = collections.defaultdict(list) + # Use [value] * num wherever possible since it is about 10x faster than list + # comprehension in such cases. Also use f-strings instead of str() - faster. + total_atoms = atoms.size * atoms.num_models + raw_mmcif['_atom_site.id'] = [f'{i}' for i in range(1, total_atoms + 1)] + raw_mmcif['_atom_site.label_alt_id'] = ['.'] * total_atoms + # Use format_float_array instead of list comprehension for performance. + raw_mmcif['_atom_site.Cartn_x'] = mmcif.format_float_array( + values=atoms.x.ravel(), num_decimal_places=coords_decimal_places + ) + raw_mmcif['_atom_site.Cartn_y'] = mmcif.format_float_array( + values=atoms.y.ravel(), num_decimal_places=coords_decimal_places + ) + raw_mmcif['_atom_site.Cartn_z'] = mmcif.format_float_array( + values=atoms.z.ravel(), num_decimal_places=coords_decimal_places + ) + + # atoms.b_factor or atoms.occupancy can be flat even when the coordinates have + # leading dimensions. In this case we tile it to match. 
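+    # (Illustrative shapes: with 2 models and 3 atoms, atoms.x has shape
+    # (2, 3) while a flat atoms.b_factor has shape (3,); np.tile then yields
+    # the (6,)-long per-model column written below.)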
+ if atoms.b_factor.ndim == 1: + atom_b_factor = np.tile(atoms.b_factor, atoms.num_models) + else: + atom_b_factor = atoms.b_factor.ravel() + raw_mmcif['_atom_site.B_iso_or_equiv'] = mmcif.format_float_array( + values=atom_b_factor, num_decimal_places=2 + ) + + if atoms.occupancy.ndim == 1: + atom_occupancy = np.tile(atoms.occupancy, atoms.num_models) + else: + atom_occupancy = atoms.occupancy.ravel() + raw_mmcif['_atom_site.occupancy'] = mmcif.format_float_array( + values=atom_occupancy.ravel(), num_decimal_places=2 + ) + + label_atom_id = atoms.name + type_symbol = atoms.element + label_comp_id = residues.apply_array_to_column('name', atoms.res_key) + label_asym_id = chains.apply_array_to_column('id', atoms.chain_key) + label_entity_id = chains.apply_array_to_column( + 'entity_id', atoms.chain_key) + # Performance optimisation: Do the int->str conversion on num_residue-sized, + # array, then select instead of selecting and then converting. + label_seq_id = residues.id.astype('str').astype(object)[ + ..., residues.index_by_key[atoms.res_key] + ] + + # _atom_site.label_seq_id is '.' for non-polymers. + non_polymer_chain_mask = string_array.isin( + chains.type, mmcif_names.POLYMER_CHAIN_TYPES, invert=True + ) + non_polymer_chain_keys = chains.key[non_polymer_chain_mask] + non_polymer_atom_mask = np.isin(atoms.chain_key, non_polymer_chain_keys) + label_seq_id[non_polymer_atom_mask] = '.' + + auth_asym_id = chains.apply_array_to_column( + 'auth_asym_id', atoms.chain_key) + auth_seq_id = residues.apply_array_to_column('auth_seq_id', atoms.res_key) + pdbx_pdb_ins_code = residues.apply_array_to_column( + 'insertion_code', atoms.res_key + ) + string_array.remap(pdbx_pdb_ins_code, mapping={None: '?'}, inplace=True) + + group_pdb = _residue_name_to_record_name( + residue_name=label_comp_id, polymer_mask=~non_polymer_atom_mask + ) + + def tile_for_models(arr: np.ndarray) -> list[str]: + if atoms.num_models == 1: + # Memory optimisation: np.tile(arr, 1) does a copy. 
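+            # e.g. tile_for_models(np.array(['N', 'CA'], dtype=object)) gives
+            # ['N', 'CA'] here, and ['N', 'CA', 'N', 'CA'] when there are two
+            # models, matching the model-major order of the raveled coordinates.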
+ return arr.tolist() + return np.tile(arr, atoms.num_models).tolist() + + raw_mmcif['_atom_site.group_PDB'] = tile_for_models(group_pdb) + raw_mmcif['_atom_site.label_atom_id'] = tile_for_models(label_atom_id) + raw_mmcif['_atom_site.type_symbol'] = tile_for_models(type_symbol) + raw_mmcif['_atom_site.label_comp_id'] = tile_for_models(label_comp_id) + raw_mmcif['_atom_site.label_asym_id'] = tile_for_models(label_asym_id) + raw_mmcif['_atom_site.label_entity_id'] = tile_for_models(label_entity_id) + raw_mmcif['_atom_site.label_seq_id'] = tile_for_models(label_seq_id) + raw_mmcif['_atom_site.auth_asym_id'] = tile_for_models(auth_asym_id) + raw_mmcif['_atom_site.auth_seq_id'] = tile_for_models(auth_seq_id) + raw_mmcif['_atom_site.pdbx_PDB_ins_code'] = tile_for_models( + pdbx_pdb_ins_code) + model_id = np.array( + [str(i + 1) for i in range(atoms.num_models)], dtype=object + ) + raw_mmcif['_atom_site.pdbx_PDB_model_num'] = np.repeat( + model_id, [atoms.size] * atoms.num_models + ).tolist() + + if bonds.key.size > 0: + raw_mmcif.update( + bonds.to_mmcif_dict_from_atom_arrays( + atom_key=atoms.key, + chain_id=label_asym_id, + res_id=label_seq_id, + res_name=label_comp_id, + atom_name=label_atom_id, + auth_asym_id=auth_asym_id, + auth_seq_id=auth_seq_id, + insertion_code=np.array(pdbx_pdb_ins_code), + ) + ) + return raw_mmcif + + +def _flatten_author_naming_scheme_table( + res_table: Mapping[str, Mapping[int, str]], + chain_ids: np.ndarray, + res_chain_ids: np.ndarray, + res_ids: np.ndarray, + default_if_missing: str, + table_name: str, +) -> np.ndarray: + """Flattens an author naming scheme table consistently with res_ids.""" + if not set(chain_ids).issubset(res_table): + raise ValueError( + f'Chain IDs in the chain_id array must be a subset of {table_name} in ' + 'author naming scheme:\n' + f'chain_ids: {sorted(chain_ids)}\n' + f'{table_name} keys: {sorted(res_table.keys())}' + ) + + chain_change_mask = res_chain_ids[1:] != res_chain_ids[:-1] + res_chain_boundaries = np.concatenate( + ([0], np.where(chain_change_mask)[0] + 1, [len(res_chain_ids)]) + ) + + flat_vals = np.empty(len(res_ids), dtype=object) + for chain_start, chain_end in itertools.pairwise(res_chain_boundaries): + chain_id = res_chain_ids[chain_start] + chain_res_ids = res_ids[chain_start:chain_end] + chain_mapping = res_table[chain_id] + flat_vals[chain_start:chain_end] = [ + chain_mapping.get(r, default_if_missing) for r in chain_res_ids + ] + + return flat_vals + + +def tables_from_atom_arrays( + *, + res_id: np.ndarray, + author_naming_scheme: AuthorNamingScheme | None = None, + all_residues: Mapping[str, Sequence[tuple[str, int]]] | None = None, + chain_id: np.ndarray | None = None, + chain_type: np.ndarray | None = None, + res_name: np.ndarray | None = None, + atom_key: np.ndarray | None = None, + atom_name: np.ndarray | None = None, + atom_element: np.ndarray | None = None, + atom_x: np.ndarray | None = None, + atom_y: np.ndarray | None = None, + atom_z: np.ndarray | None = None, + atom_b_factor: np.ndarray | None = None, + atom_occupancy: np.ndarray | None = None, +) -> tuple[Atoms, Residues, Chains]: + """Returns Structure tables constructed from atom array level data. + + All fields except name and, res_id are optional, all array fields consist of a + value for each atom in the structure - so residue and chain values should hold + the same value for each atom in the chain or residue. Fields which are not + defined are filled with default values. 
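+
+  A minimal call (illustrative values, assuming three atoms in one protein
+  chain) might look like:
+    tables_from_atom_arrays(
+        res_id=np.array([1, 1, 2], dtype=np.int32),
+        chain_id=np.array(['A', 'A', 'A'], dtype=object),
+        res_name=np.array(['GLY', 'GLY', 'ALA'], dtype=object),
+        atom_name=np.array(['N', 'CA', 'N'], dtype=object),
+    )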
+
+  Validation is performed by the Structure constructor where possible - but
+  author_naming_scheme and all_residues must be checked in this function.
+
+  It is not possible to construct structures with chains that do not contain
+  any resolved residues using this function. If this is necessary, use the
+  structure.Structure constructor directly.
+
+  Args:
+    res_id: Integer array of shape [num_atom]. The unique residue identifier
+      for each residue. mmCIF field - _atom_site.label_seq_id.
+    author_naming_scheme: An optional instance of AuthorNamingScheme to use
+      when converting this structure to mmCIF.
+    all_residues: An optional mapping from each chain ID (i.e. label_asym_id)
+      to a sequence of (label_comp_id, label_seq_id) tuples, one per residue.
+      This can contain residues that aren't present in the atom arrays. This
+      is common in experimental data where some residues are not resolved but
+      are known to be present.
+    chain_id: String array of shape [num_atom] of unique chain identifiers.
+      mmCIF field - _atom_site.label_asym_id.
+    chain_type: String array of shape [num_atom]. The molecular type of the
+      current chain (e.g. polyribonucleotide). mmCIF field - _entity_poly.type
+      OR _entity.type (for non-polymers).
+    res_name: String array of shape [num_atom]. The name of each residue,
+      typically a 3 letter string for polypeptides or 1-2 letter strings for
+      polynucleotides. mmCIF field - _atom_site.label_comp_id.
+    atom_key: A unique sorted integer array, used only by the bonds table to
+      identify the atoms participating in each bond. If the bonds table is
+      specified then this column must be non-None.
+    atom_name: String array of shape [num_atom]. The name of each atom
+      (e.g. CA, O2', etc.). mmCIF field - _atom_site.label_atom_id.
+    atom_element: String array of shape [num_atom]. The element type of each
+      atom (e.g. C, O, N, etc.). mmCIF field - _atom_site.type_symbol.
+    atom_x: Float array of shape [..., num_atom] of atom x coordinates. May
+      have arbitrary leading dimensions, provided that these are consistent
+      across all coordinate fields.
+    atom_y: Float array of shape [..., num_atom] of atom y coordinates. May
+      have arbitrary leading dimensions, provided that these are consistent
+      across all coordinate fields.
+    atom_z: Float array of shape [..., num_atom] of atom z coordinates. May
+      have arbitrary leading dimensions, provided that these are consistent
+      across all coordinate fields.
+    atom_b_factor: Float array of shape [..., num_atom] or [num_atom] of atom
+      b-factors or equivalent. If there are no extra leading dimensions then
+      these values are assumed to apply to all coordinates for a given atom.
+      If there are leading dimensions then these must match those used by the
+      coordinate fields.
+    atom_occupancy: Float array of shape [..., num_atom] or [num_atom] of atom
+      occupancies or equivalent. If there are no extra leading dimensions then
+      these values are assumed to apply to all coordinates for a given atom.
+      If there are leading dimensions then these must match those used by the
+      coordinate fields.
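+
+  Returns:
+    A tuple of (Atoms, Residues, Chains) tables describing the same structure,
+    ready to be wrapped in a Structure.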
+ """ + num_atoms = len(res_id) + + for arr_name, array, dtype in ( + ('chain_id', chain_id, object), + ('chain_type', chain_type, object), + ('res_id', res_id, np.int32), + ('res_name', res_name, object), + ('atom_key', atom_key, np.int64), + ('atom_name', atom_name, object), + ('atom_element', atom_element, object), + ): + if array is not None and array.shape != (num_atoms,): + raise ValueError( + f'{arr_name} shape {array.shape} != ({num_atoms},)') + if array is not None and array.dtype != dtype: + raise ValueError(f'{arr_name} dtype {array.dtype} != {dtype}') + + for arr_name, array in ( + ('atom_x', atom_x), + ('atom_y', atom_y), + ('atom_z', atom_z), + ('atom_b_factor', atom_b_factor), + ('atom_occupancy', atom_occupancy), + ): + if array is not None and array.shape[-1] != num_atoms: + raise ValueError( + f'{arr_name} last dim {array.shape[-1]} != {num_atoms=}') + if ( + array is not None + and array.dtype != np.float32 + and array.dtype != np.float64 + ): + raise ValueError( + f'{arr_name} must be np.float32 or np.float64, got {array.dtype=}' + ) + + if all_residues is not None and (res_name is None or res_id is None): + raise ValueError( + 'If all_residues != None, res_name and res_id must not be None either.' + ) + + if num_atoms == 0: + return Atoms.make_empty(), Residues.make_empty(), Chains.make_empty() + + if chain_id is None: + chain_id = np.full(shape=num_atoms, fill_value='A', dtype=object) + if res_name is None: + res_name = np.full(shape=num_atoms, fill_value='UNK', dtype=object) + + chain_change_mask = chain_id[1:] != chain_id[:-1] + chain_start = np.concatenate(([0], np.where(chain_change_mask)[0] + 1)) + res_start = np.concatenate( + ([0], np.where((res_id[1:] != res_id[:-1]) | chain_change_mask)[0] + 1) + ) + + if len(set(chain_id)) != len(chain_start): + raise ValueError(f'Chain IDs must be contiguous, but got {chain_id}') + + # We do not support chains with unresolved residues-only in this function. + chain_ids = chain_id[chain_start] + if all_residues and set(all_residues.keys()) != set(chain_ids): + raise ValueError( + 'all_residues must contain the same set of chain IDs as the chain_id ' + f'array:\nall_residues keys: {sorted(all_residues.keys())}\n' + f'chain_ids: {sorted(chain_ids)}.' + ) + # Make sure all_residue ordering is consistent with chain_id. + if all_residues and np.any(list(all_residues.keys()) != chain_ids): + all_residues = {cid: all_residues[cid] for cid in chain_ids} + + # Create the chains table. + num_chains = len(chain_ids) + chain_keys = np.arange(num_chains, dtype=np.int64) + chain_key_by_chain_id = dict(zip(chain_ids, chain_keys, strict=True)) + + if chain_type is not None: + chain_types = chain_type[chain_start] + else: + chain_types = np.full( + num_chains, mmcif_names.PROTEIN_CHAIN, dtype=object) + + if author_naming_scheme is not None: + auth_asym_id = string_array.remap( + chain_ids, author_naming_scheme.auth_asym_id + ) + entity_id = string_array.remap( + chain_ids, author_naming_scheme.entity_id, default_value='.' + ) + entity_desc = string_array.remap( + entity_id, author_naming_scheme.entity_desc, default_value='.' + ) + else: + auth_asym_id = chain_ids + entity_id = (chain_keys + 1).astype(str).astype(object) + entity_desc = np.full(num_chains, '.', dtype=object) + + chains = Chains( + key=chain_keys, + id=chain_ids, + type=chain_types, + auth_asym_id=auth_asym_id, + entity_id=entity_id, + entity_desc=entity_desc, + ) + + # Create the residues table. 
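+    # (residue_order below ends up as (chain_id, res_name, res_id) triples,
+    # e.g. [('A', 'GLY', 1), ('A', 'ALA', 2), ...]; the position of each
+    # triple becomes that residue's key in the atoms table further down.)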
+ if all_residues is not None: + residue_order = [] + for cid, residues in all_residues.items(): + residue_order.extend((cid, rname, int(rid)) + for (rname, rid) in residues) + res_chain_ids, res_names, res_ids = zip(*residue_order) + res_chain_ids = np.array(res_chain_ids, dtype=object) + res_ids = np.array(res_ids, dtype=np.int32) + res_names = np.array(res_names, dtype=object) + else: + res_chain_ids = chain_id[res_start] + res_ids = res_id[res_start] + res_names = res_name[res_start] + residue_order = list(zip(res_chain_ids, res_names, res_ids)) + + if author_naming_scheme is not None and author_naming_scheme.auth_seq_id: + auth_seq_id = _flatten_author_naming_scheme_table( + author_naming_scheme.auth_seq_id, + chain_ids=chain_ids, + res_chain_ids=res_chain_ids, + res_ids=res_ids, + default_if_missing='.', + table_name='auth_seq_id', + ) + else: + auth_seq_id = res_ids.astype(str).astype(object) + + if author_naming_scheme is not None and author_naming_scheme.insertion_code: + insertion_code = _flatten_author_naming_scheme_table( + author_naming_scheme.insertion_code, + chain_ids=chain_ids, + res_chain_ids=res_chain_ids, + res_ids=res_ids, + default_if_missing='?', + table_name='insertion_code', + ) + # Make sure insertion code of None is mapped to '.'. + insertion_code = string_array.remap(insertion_code, {None: '?'}) + else: + insertion_code = np.full( + shape=len(res_ids), fill_value='?', dtype=object) + + res_key_by_res = {res: i for i, res in enumerate(residue_order)} + res_keys = np.arange(len(residue_order), dtype=np.int64) + res_chain_keys = string_array.remap( + res_chain_ids, chain_key_by_chain_id + ).astype(np.int64) + residues = Residues( + chain_key=res_chain_keys, + key=res_keys, + id=res_ids, + name=res_names, + auth_seq_id=auth_seq_id, + insertion_code=insertion_code, + ) + + if atom_key is None: + atom_key = np.arange(num_atoms, dtype=np.int64) + + atom_chain_keys = string_array.remap(chain_id, chain_key_by_chain_id).astype( + np.int64 + ) + + try: + atom_res_keys = [res_key_by_res[r] + for r in zip(chain_id, res_name, res_id)] + except KeyError as e: + missing_chain_id, missing_res_name, missing_res_id = e.args[0] + raise ValueError( + 'Inconsistent res_name, res_id and all_residues. Could not find ' + f'residue with chain_id={missing_chain_id}, ' + f'res_name={missing_res_name}, res_id={missing_res_id} in all_residues.' + ) from e + + atoms = Atoms( + key=atom_key, + chain_key=atom_chain_keys, + res_key=np.array(atom_res_keys, dtype=np.int64), + name=_default(atom_name, ['?'] * num_atoms, object), + element=_default(atom_element, ['?'] * num_atoms, object), + x=_default(atom_x, [0.0] * num_atoms, np.float32), + y=_default(atom_y, [0.0] * num_atoms, np.float32), + z=_default(atom_z, [0.0] * num_atoms, np.float32), + b_factor=_default(atom_b_factor, [0.0] * num_atoms, np.float32), + occupancy=_default(atom_occupancy, [1.0] * num_atoms, np.float32), + ) + return atoms, residues, chains diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/table.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/table.py new file mode 100644 index 0000000000000000000000000000000000000000..7cad4a27d0be0732f7a71122594993d46148a2b0 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/table.py @@ -0,0 +1,565 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Table module for atom/residue/chain tables in Structure. + +Tables are intended to be lightweight collections of columns, loosely based +on a pandas dataframe, for use in the Structure class. +""" + +import abc +from collections.abc import Callable, Collection, Iterable, Iterator, Mapping, Sequence +import dataclasses +import functools +import graphlib +import typing +from typing_extensions import Any, Protocol, Self, TypeAlias, TypeVar, overload + +from alphafold3.cpp import string_array +import numpy as np + + +TableEntry: TypeAlias = str | int | float | None +FilterPredicate: TypeAlias = ( + TableEntry + | Iterable[Any] # Workaround for b/326384670. Tighten once fixed. + | Callable[[Any], bool] # Workaround for b/326384670. Tighten once fixed. + | Callable[[np.ndarray], bool] +) + + +class RowLookup(Protocol): + + def get_row_by_key( + self, + key: int, + column_name_map: Mapping[str, str] | None = None, + ) -> Mapping[str, Any]: + ... + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class Table: + """Parent class for structure tables. + + A table is a collection of columns of equal length, where one column is the + key. The key uniquely identifies each row in the table. + + A table can refer to other tables by including a foreign key column, whose + values are key values from the other table's key column. These column can have + arbitrary names and are treated like any other integer-valued column. + + See the `Database` class in this module for utilities for handing sets of + tables that are related via foreign keys. + + NB: This does not correspond to an mmCIF table. + """ + + key: np.ndarray + + def __post_init__(self): + for col_name in self.columns: + if (col_len := self.get_column(col_name).shape[-1]) != self.size: + raise ValueError( + f'All columns should have length {self.size} but got "{col_name}"' + f' with length {col_len}.' + ) + # Make col immutable. + self.get_column(col_name).flags.writeable = False + if self.key.size and self.key.min() < 0: + raise ValueError( + 'Key values must be non-negative. Got negative values:' + f' {set(self.key[self.key < 0])}' + ) + self.key.flags.writeable = False # Make key immutable. + + def __getstate__(self) -> dict[str, Any]: + """Returns members with cached properties removed for pickling.""" + cached_props = { + k + for k, v in self.__class__.__dict__.items() + if isinstance(v, functools.cached_property) + } + return {k: v for k, v in self.__dict__.items() if k not in cached_props} + + @functools.cached_property + def index_by_key(self) -> np.ndarray: + """Mapping from key values to their index in the column arrays. 
+ + i.e.: self.key[index_by_key[k]] == k + """ + if not self.key.size: + return np.array([], dtype=np.int64) + else: + index_by_key = np.zeros(np.max(self.key) + 1, dtype=np.int64) + index_by_key[self.key] = np.arange(self.size) + return index_by_key + + @functools.cached_property + def columns(self) -> tuple[str, ...]: + """The names of the columns in the table, including the key column.""" + return tuple(field.name for field in dataclasses.fields(self)) + + @functools.cached_property + def items(self) -> Mapping[str, np.ndarray]: + """Returns the mapping from column names to column values.""" + return {col: getattr(self, col) for col in self.columns} + + @functools.cached_property + def size(self) -> int: + """The number of rows in the table.""" + return self.key.shape[-1] + + def __len__(self) -> int: + return self.size + + def get_column(self, column_name: str) -> np.ndarray: + """Gets a column by name.""" + # Performance optimisation: use the cached columns, instead of getattr. + return self.items[column_name] + + def apply_array(self, arr: np.ndarray) -> Self: + """Returns a sliced table using a key (!= index) array or a boolean mask.""" + if arr.dtype == bool and np.all(arr): + return self # Shortcut: No-op, so just return. + + return self.copy_and_update(**{ + column_name: self.apply_array_to_column(column_name, arr) + for column_name in self.columns + }) + + def apply_index(self, index_arr: np.ndarray) -> Self: + """Returns a sliced table using an index (!= key) array.""" + if index_arr.dtype == bool: + raise ValueError('The index array must not be a boolean mask.') + + return self.copy_and_update( + **{col: self.get_column(col)[..., index_arr] for col in self.columns} + ) + + def apply_array_to_column( + self, + column_name: str, + arr: np.ndarray, + ) -> np.ndarray: + """Returns a sliced column array using a key array or a boolean mask.""" + if arr.dtype == bool: + return self.get_column(column_name)[..., arr] + else: + return self.get_column(column_name)[..., self.index_by_key[arr]] + + def get_value_by_index(self, column_name: str, index: int) -> Any: + return self.get_column(column_name)[index] + + def get_value_by_key( + self, + column_name: str, + key: int | np.integer, + ) -> TableEntry: + """Gets the value of a column at the row with specified key value.""" + return self.get_value_by_index(column_name, self.index_by_key[key]) + + @overload + def __getitem__(self, key: str) -> np.ndarray: + ... + + @overload + def __getitem__(self, key: np.ndarray) -> 'Table': + ... + + @overload + def __getitem__(self, key: tuple[str, int | np.integer]) -> TableEntry: + ... + + @overload + def __getitem__(self, key: tuple[str, np.ndarray]) -> np.ndarray: + ... 
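+    # A sketch of the four access patterns dispatched below (hypothetical
+    # `chains` table):
+    #   chains['id']                    -> whole column array
+    #   chains[np.array([0, 2])]        -> sub-table at keys 0 and 2
+    #   chains['id', 3]                 -> single value at key 3
+    #   chains['id', np.array([0, 2])]  -> column values at keys 0 and 2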
+ + def __getitem__(self, key): + match key: + case str(): + return self.get_column(key) + case np.ndarray() as key_arr_or_mask: + return self.apply_array(key_arr_or_mask) + case str() as col, int() | np.integer() as key_val: + return self.get_value_by_key(col, key_val) + case str() as col, np.ndarray() as key_arr_or_mask: + return self.apply_array_to_column(col, key_arr_or_mask) + case _: + if isinstance(key, tuple): + err_msg = f'{key}, type: tuple({[type(v) for v in key]})' + else: + err_msg = f'{key}, type: {type(key)}' + raise KeyError(err_msg) + + def get_row_by_key( + self, + key: int, + column_name_map: Mapping[str, str] | None = None, + ) -> dict[str, Any]: + """Gets the row with specified key value.""" + return self.get_row_by_index( + self.index_by_key[key], column_name_map=column_name_map + ) + + def get_row_by_index( + self, + index: int, + column_name_map: Mapping[str, str] | None = None, + ) -> dict[str, Any]: + """Gets the row at the specified index.""" + if column_name_map is not None: + return { + renamed_col: self.get_value_by_index(col, index) + for renamed_col, col in column_name_map.items() + } + else: + return {col: self.get_value_by_index(col, index) for col in self.columns} + + def iterrows( + self, + *, + row_keys: np.ndarray | None = None, + column_name_map: Mapping[str, str] | None = None, + **table_by_foreign_key_col: RowLookup, + ) -> Iterator[Mapping[str, Any]]: + """Yields rows from the table. + + Args: + row_keys: An optional array of keys of rows to yield. If None, all rows + will be yielded. + column_name_map: An optional mapping from desired keys in the row dicts to + the names of the columns they correspond to. + **table_by_foreign_key_col: An optional mapping from column names in this + table, which are expected to be columns of foreign keys, to the table + that the foreign keys point into. If provided, then the yielded rows + will include data from the foreign tables at the appropriate key. + """ + if row_keys is not None: + row_indices = self.index_by_key[row_keys] + else: + row_indices = range(self.size) + for i in row_indices: + row = self.get_row_by_index(i, column_name_map=column_name_map) + for key_col, table in table_by_foreign_key_col.items(): + foreign_key = self[key_col][i] + foreign_row = table.get_row_by_key(foreign_key) + row.update(foreign_row) + yield row + + def with_column_names( + self, column_name_map: Mapping[str, str] + ) -> 'RenamedTableView': + """Returns a view of this table with mapped column names.""" + return RenamedTableView(self, column_name_map=column_name_map) + + def make_filter_mask( + self, + mask: np.ndarray | None = None, + *, + apply_per_element: bool = False, + **predicate_by_col: FilterPredicate, + ) -> np.ndarray | None: + """Returns a boolean array of rows to keep, or None if all can be kept. + + Args: + mask: See `Table.filter`. + apply_per_element: See `Table.filter`. + **predicate_by_col: See `Table.filter`. + + Returns: + Either a boolean NumPy array of length `(self.size,)` denoting which rows + should be kept according to the input mask and predicates, or None. None + implies there is no filtering required, and is used where possible + instead of an all-True array to save time and space. + """ + if mask is None: + if not predicate_by_col: + return None + else: + mask = np.ones((self.size,), dtype=bool) + else: + if mask.shape != (self.size,): + raise ValueError( + f'mask must have shape ({self.size},). Got: {mask.shape}.' + ) + if mask.dtype != bool: + raise ValueError( + f'mask must have dtype bool. 
Got: {mask.dtype}.') + + for col, predicate in predicate_by_col.items(): + if self[col].ndim > 1: + raise ValueError( + f'Cannot filter by column {col} with more than 1 dimension.' + ) + + callable_predicates = [] + if not callable(predicate): + if isinstance(predicate, Iterable) and not isinstance(predicate, str): + target_vals = predicate + else: + target_vals = [predicate] + for target_val in target_vals: + callable_predicates.append( + lambda x, target=target_val: x == target) + else: + callable_predicates.append(predicate) + + field_mask = np.zeros_like(mask) + for callable_predicate in callable_predicates: + if not apply_per_element: + callable_predicate = typing.cast( + Callable[[np.ndarray], bool], callable_predicate + ) + predicate_result = callable_predicate(self.get_column(col)) + else: + predicate_result = np.array( + [callable_predicate(elem) + for elem in self.get_column(col)] + ) + np.logical_or(field_mask, predicate_result, out=field_mask) + np.logical_and(mask, field_mask, out=mask) # Update in-place. + return mask + + def filter( + self, + mask: np.ndarray | None = None, + *, + apply_per_element: bool = False, + invert: bool = False, + **predicate_by_col: FilterPredicate, + ) -> Self: + """Filters the table using mask and/or predicates and returns a new table. + + Predicates can be either: + 1. A constant value, e.g. `'CA'`. In this case then only rows that match + this value for the given column are retained. + 2. A (non-string) iterable e.g. `('A', 'B')`. In this + case then rows are retained if they match any of the provided values for + the given column. + 3. A boolean function e.g. `lambda b_fac: b_fac < 100.0`. + In this case then only rows that evaluate to `True` are retained. By + default this function's parameter is expected to be an array, unless + `apply_per_element=True`. + + Args: + mask: An optional boolean NumPy array with length equal to the table size. + If provided then this will be combined with the other predicates so that + a row is included if it is masked-in *and* matches all the predicates. + apply_per_element: Whether apply predicates to each element in the column + individually, or to pass the whole column array to the predicate. + invert: If True then the returned table will contain exactly those rows + that would be removed if this was `False`. + **predicate_by_col: A mapping from column name to a predicate. Filtered + columns must be 1D arrays. If multiple columns are provided as keyword + arguments then each predicate is applied and the results are combined + using a boolean AND operation, so an atom is only retained if it passes + all predicates. + + Returns: + A new table with the desired rows retained (or filtered out if + `invert=True`). + + Raises: + ValueError: If mask is provided and is not a bool array with shape + `(num_atoms,)`. + """ + filter_mask = self.make_filter_mask( + mask, apply_per_element=apply_per_element, **predicate_by_col + ) + if filter_mask is None: + # No mask or predicate was specified, so we can return early. 
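+            # e.g. table.filter() just returns self, while
+            # table.filter(invert=True) yields an empty table of the same type.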
+ if not invert: + return self + else: + return self[np.array((), dtype=np.int64)] + else: + return self[~filter_mask if invert else filter_mask] + + def _validate_keys_are_column_names(self, keys: Collection[str]) -> None: + """Raises an error if any of the keys are not column names.""" + if mismatches := set(keys) - set(self.columns): + raise ValueError(f'Invalid column names: {sorted(mismatches)}.') + + def copy_and_update(self, **new_column_by_column_name: np.ndarray) -> Self: + """Returns a copy of this table with the specified changes applied. + + Args: + **new_column_by_column_name: New values for the specified columns. + + Raises: + ValueError: If a specified column name is not a column in this table. + """ + self._validate_keys_are_column_names(new_column_by_column_name) + return dataclasses.replace(self, **new_column_by_column_name) + + def copy_and_remap( + self, **mapping_by_col: Mapping[TableEntry, TableEntry] + ) -> Self: + """Returns a copy of the table with the specified columns remapped. + + Args: + **mapping_by_col: Each kwarg key should be the name of one of this table's + columns, and each value should be a mapping. The values in the column + will be looked up in the mapping and replaced with the result if one is + found. + + Raises: + ValueError: If a specified column name is not a column in this table. + """ + self._validate_keys_are_column_names(mapping_by_col) + if not self.size: + return self + remapped_cols = {} + for column_name, mapping in mapping_by_col.items(): + col_arr = self.get_column(column_name) + if col_arr.dtype == object: + remapped = string_array.remap(col_arr, mapping) + else: + remapped = np.vectorize(lambda x: mapping.get(x, x))( + col_arr) # pylint: disable=cell-var-from-loop + remapped_cols[column_name] = remapped + return self.copy_and_update(**remapped_cols) + + +class RenamedTableView: + """View of a table with renamed column names.""" + + def __init__(self, table: Table, column_name_map: Mapping[str, str]): + self._table = table + self._column_name_map = column_name_map + + def get_row_by_key( + self, + key: int, + column_name_map: Mapping[str, str] | None = None, + ) -> Mapping[str, Any]: + del column_name_map + return self._table.get_row_by_key( + key, column_name_map=self._column_name_map + ) + + +_DatabaseT = TypeVar('_DatabaseT', bound='Database') + + +class Database(abc.ABC): + """Relational database base class.""" + + @property + @abc.abstractmethod + def tables(self) -> Collection[str]: + """The names of the tables in this database.""" + + @abc.abstractmethod + def get_table(self, table_name: str) -> Table: + """Gets the table with the given name.""" + + @property + @abc.abstractmethod + def foreign_keys(self) -> Mapping[str, Collection[tuple[str, str]]]: + """Describes the relationship between keys in the database. + + Returns: + A map from table names to pairs of `(column_name, foreign_table_name)` + where `column_name` is a column containing foreign keys in the table named + by the key, and the `foreign_table_name` is the name of the table that + those foreign keys refer to. + """ + + @abc.abstractmethod + def copy_and_update( + self: _DatabaseT, + **new_field_by_field_name: ..., + ) -> _DatabaseT: + """Returns a copy of this database with the specified changes applied.""" + + +def table_dependency_order(db: Database) -> Iterable[str]: + """Yields the names of the tables in the database in dependency order. + + This order guarantees that a table appears after all other tables that + it refers to using foreign keys. 
Specifically A < B implies that A contains + no column that refers to B.key as a foreign key. + + Args: + db: The database that defines the table names and foreign keys. + """ + connections: dict[str, set[str]] = {} + for table_name in db.tables: + connection_set = set() + for _, foreign_table in db.foreign_keys.get(table_name, ()): + connection_set.add(foreign_table) + connections[table_name] = connection_set + yield from graphlib.TopologicalSorter(connections).static_order() + + +def concat_databases(dbs: Sequence[_DatabaseT]) -> _DatabaseT: + """Concatenates the tables across a sequence of databases. + + Args: + dbs: A non-empty sequence of database instances of the same type. + + Returns: + A new database containing the concatenated tables from the input databases. + + Raises: + ValueError: If `dbs` is empty or `dbs` contains different Database + types. + """ + if not dbs: + raise ValueError('Need at least one value to concatenate.') + distinct_db_types = {type(db) for db in dbs} + if len(distinct_db_types) > 1: + raise ValueError( + f'All `dbs` must be of the same type, got: {distinct_db_types}' + ) + + first_db, *other_dbs = dbs + concatted_tables: dict[str, Table] = {} + key_offsets: dict[str, list[int]] = {} + for table_name in table_dependency_order(first_db): + first_table = first_db.get_table(table_name) + columns: dict[str, list[np.ndarray]] = { + column_name: [first_table.get_column(column_name)] + for column_name in first_table.columns + } + key_offsets[table_name] = [ + first_table.key.max() + 1 if first_table.size else 0 + ] + + for prev_index, db in enumerate(other_dbs): + table = db.get_table(table_name) + for col_name in table.columns: + columns[col_name].append(table.get_column(col_name)) + key_offset = key_offsets[table_name][prev_index] + offset_key = table.key + key_offset + columns['key'][-1] = offset_key + if table.size: + key_offsets[table_name].append(offset_key.max() + 1) + else: + key_offsets[table_name].append( + key_offsets[table_name][prev_index]) + for fkey_col_name, foreign_table_name in first_db.foreign_keys.get( + table_name, [] + ): + fkey_columns = columns[fkey_col_name] + fkey_columns[-1] = ( + fkey_columns[-1] + + key_offsets[foreign_table_name][prev_index] + ) + + concatted_columns = { + column_name: np.concatenate(values, axis=-1) + for column_name, values in columns.items() + } + concatted_tables[table_name] = (type(first_table))(**concatted_columns) + return first_db.copy_and_update(**concatted_tables) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/test_utils.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc6ec49853e60de4f807818c918a52451304f11 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/test_utils.py @@ -0,0 +1,358 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Utilities for structure module testing.""" + +import dataclasses + +from absl.testing import parameterized +from alphafold3 import structure +from alphafold3.common.testing import data +import numpy as np + +import os +import contextlib +import datetime +import difflib +import functools +import hashlib +import shutil +import pathlib +from typing import Any +from absl.testing import absltest +import mindspore as ms +from alphafold3.common.testing import data as testing_data +from alphafold3.common import resources +from alphafold3.data import pipeline +from alphafold3.model.atom_layout import atom_layout + +_JACKHMMER_BINARY_PATH = shutil.which('jackhmmer') +_NHMMER_BINARY_PATH = shutil.which('nhmmer') +_HMMALIGN_BINARY_PATH = shutil.which('hmmalign') +_HMMSEARCH_BINARY_PATH = shutil.which('hmmsearch') +_HMMBUILD_BINARY_PATH = shutil.which('hmmbuild') + +@contextlib.contextmanager +def _output(name: str): + with open(result_path := f'{absltest.TEST_TMPDIR.value}/{name}', "wb") as f: + yield result_path, f + + +@functools.singledispatch +def _hash_data(x: Any, /) -> str: + if x is None: + return '<>' + return _hash_data(json.dumps(x).encode('utf-8')) + + +@_hash_data.register +def _(x: bytes, /) -> str: + return hashlib.sha256(x).hexdigest() + + +@_hash_data.register +def _(x: ms.Tensor) -> str: + return _hash_data(x.asnumpy()) + + +@_hash_data.register +def _(x: np.ndarray) -> str: + if x.dtype == object: + return ';'.join(map(_hash_data, x.ravel().tolist())) + return _hash_data(x.tobytes()) + + +@_hash_data.register +def _(_: structure.Structure) -> str: + return '<>' + + +@_hash_data.register +def _(_: atom_layout.AtomLayout) -> str: + return '<>' + + +def _generate_diff(actual: str, expected: str) -> str: + return '\n'.join( + difflib.unified_diff( + expected.split('\n'), + actual.split('\n'), + fromfile='expected', + tofile='actual', + lineterm='', + ) + ) + + +def tree_map(func, dict_tree): + if isinstance(dict_tree, dict): + return {k: tree_map(func, v) for k, v in dict_tree.items()} + else: + if func == "asnumpy": + return dict_tree.asnumpy() + elif func == "float32": + return dict_tree.astype(ms.float32) + elif func == "bfloat16": + return dict_tree.astype(ms.bfloat16) + else: + return func(dict_tree) + +class StructureTestCase(parameterized.TestCase): + """Testing utilities for working with structure.Structure.""" + + def set_path(self, use_full_database=False): + if use_full_database: + small_bfd_database_path = testing_data.Data( + '/data/zmmVol2/AF3/public_databases/bfd-first_non_consensus_sequences.fasta' + ).path() + mgnify_database_path = testing_data.Data( + '/data/zmmVol2/AF3/public_databases/mgy_clusters_2022_05.fa' + ).path() + uniprot_cluster_annot_database_path = testing_data.Data( + '/data/zmmVol2/AF3/public_databases/uniprot_all_2021_04.fa' + ).path() + uniref90_database_path = testing_data.Data( + '/data/zmmVol2/AF3/public_databases/uniref90_2022_05.fa' + ).path() + ntrna_database_path = testing_data.Data( + '/data/zmmVol2/AF3/public_databases/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta' + ).path() + rfam_database_path = testing_data.Data( + '/data/zmmVol2/AF3/public_databases/rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta' + ).path() + rna_central_database_path = testing_data.Data( + 
'/data/zmmVol2/AF3/public_databases/rnacentral_active_seq_id_90_cov_80_linclust.fasta' + ).path() + pdb_database_path = testing_data.Data( + '/data/zmmVol2/AF3/public_databases/mmcif_files' + ).path() + seqres_database_path = testing_data.Data( + '/data/zmmVol2/AF3/public_databases/pdb_seqres_2022_09_28.fasta' + ).path() + else: + small_bfd_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/bfd-first_non_consensus_sequences__subsampled_1000.fasta' + ).path() + mgnify_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/mgy_clusters__subsampled_1000.fa' + ).path() + uniprot_cluster_annot_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/uniprot_all__subsampled_1000.fasta' + ).path() + uniref90_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/uniref90__subsampled_1000.fasta' + ).path() + ntrna_database_path = testing_data.Data( + resources.ROOT + / ('test_data/miniature_databases/' + 'nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq__subsampled_1000.fasta') + ).path() + rfam_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/rfam_14_4_clustered_rep_seq__subsampled_1000.fasta' + ).path() + rna_central_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/rnacentral_active_seq_id_90_cov_80_linclust__subsampled_1000.fasta' + ).path() + pdb_database_path = testing_data.Data( + resources.ROOT / 'test_data/miniature_databases/pdb_mmcif' + ).path() + seqres_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/pdb_seqres_2022_09_28__subsampled_1000.fasta' + ).path() + + self._data_pipeline_config = pipeline.DataPipelineConfig( + jackhmmer_binary_path=_JACKHMMER_BINARY_PATH, + nhmmer_binary_path=_NHMMER_BINARY_PATH, + hmmalign_binary_path=_HMMALIGN_BINARY_PATH, + hmmsearch_binary_path=_HMMSEARCH_BINARY_PATH, + hmmbuild_binary_path=_HMMBUILD_BINARY_PATH, + small_bfd_database_path=small_bfd_database_path, + mgnify_database_path=mgnify_database_path, + uniprot_cluster_annot_database_path=uniprot_cluster_annot_database_path, + uniref90_database_path=uniref90_database_path, + ntrna_database_path=ntrna_database_path, + rfam_database_path=rfam_database_path, + rna_central_database_path=rna_central_database_path, + pdb_database_path=pdb_database_path, + seqres_database_path=seqres_database_path, + max_template_date=datetime.date(2021, 9, 30), + ) + self.data_path = "/data/zmmVol2/AF3/run_test/src/alphafold3/test_data" + + def compare_golden(self, result_path: str, golden_path) -> None: + filename = os.path.split(result_path)[1] + golden_path = pathlib.Path(golden_path) + with open(golden_path, 'r') as golden_file: + golden_text = golden_file.read() + with open(result_path, 'r') as result_file: + result_text = result_file.read() + + diff = _generate_diff(result_text, golden_text) + + self.assertEqual(diff, "", f"Result differs from golden:\n{diff}") + + def assertAuthorNamingSchemeEqual(self, ans1, ans2): # pylint: disable=invalid-name + """Walks naming scheme, making sure all elements are equal.""" + if ans1 is None or ans2 is None: + self.assertIsNone(ans1) + self.assertIsNone(ans2) + return + flat_ans1 = dict(tree.flatten_with_path(dataclasses.asdict(ans1))) + flat_ans2 = dict(tree.flatten_with_path(dataclasses.asdict(ans2))) + for k, v in flat_ans1.items(): + self.assertEqual(v, flat_ans2[k], msg=str(k)) + for k, v in flat_ans2.items(): + self.assertEqual(v, 
flat_ans1[k], msg=str(k)) + + def assertAllResiduesEqual(self, all_res1, all_res2): # pylint: disable=invalid-name + """Walks all residues, making sure alll elements are equal.""" + if all_res1 is None or all_res2 is None: + self.assertIsNone(all_res1) + self.assertIsNone(all_res2) + return + self.assertSameElements(all_res1.keys(), all_res2.keys()) + for chain_id, chain_res in all_res1.items(): + self.assertSequenceEqual( + chain_res, all_res2[chain_id], msg=chain_id) + + def assertBioassemblyDataEqual(self, data1, data2): # pylint: disable=invalid-name + if data1 is None or data2 is None: + self.assertIsNone(data1) + self.assertIsNone(data2) + return + self.assertDictEqual(data1.to_mmcif_dict(), data2.to_mmcif_dict()) + + def assertChemicalComponentsDataEqual( # pylint: disable=invalid-name + self, + data1, + data2, + allow_chem_comp_data_extension, + ): + """Checks whether two ChemicalComponentData objects are considered equal.""" + if data1 is None or data2 is None: + self.assertIsNone(data1) + self.assertIsNone(data2) + return + if (not allow_chem_comp_data_extension) or ( + data1.chem_comp.keys() ^ data2.chem_comp.keys() + ): + self.assertDictEqual(data1.chem_comp, data2.chem_comp) + else: + mismatching_values = [] + for component_id in data1.chem_comp: + found = data1.chem_comp[component_id] + expected = data2.chem_comp[component_id] + if not found.extends(expected): + mismatching_values.append((component_id, expected, found)) + + if mismatching_values: + mismatch_err_msgs = '\n'.join( + f'{component_id}: {expected} or its extension expected,' + f' but {found} found.' + for component_id, expected, found in mismatching_values + ) + self.fail( + f'Mismatching values for `_chem_comp` table: {mismatch_err_msgs}', + ) + + def assertBondsEqual(self, bonds1, bonds2, atom_key1, atom_key2): # pylint: disable=invalid-name + """Checks whether two Bonds objects are considered equal.""" + # An empty bonds table is functionally equivalent to an empty bonds table. + # NB: this can only ever be None in structure v1. + if bonds1 is None or not bonds1.size or bonds2 is None or not bonds2.size: + self.assertTrue(bonds1 is None or not bonds1.size, + msg=f'{bonds1=}') + self.assertTrue(bonds2 is None or not bonds2.size, + msg=f'{bonds2=}') + return + + ptnr1_indices1, ptnr2_indices1 = bonds1.get_atom_indices(atom_key1) + ptnr1_indices2, ptnr2_indices2 = bonds2.get_atom_indices(atom_key2) + np.testing.assert_array_equal(ptnr1_indices1, ptnr1_indices2) + np.testing.assert_array_equal(ptnr2_indices1, ptnr2_indices2) + np.testing.assert_array_equal(bonds1.type, bonds2.type) + np.testing.assert_array_equal(bonds1.role, bonds2.role) + + def assertStructuresEqual( # pylint: disable=invalid-name + self, + struc1, + struc2, + *, + ignore_fields=None, + allow_chem_comp_data_extension=False, + atol=0, + ): + """Checks whether two Structure objects could be considered equal. + + Args: + struc1: First Structure object. + struc2: Second Structure object. + ignore_fields: Fields not taken into account during comparison. + allow_chem_comp_data_extension: Whether to allow data of `_chem_comp` + table to differ if `struc2` is missing some fields, but `struc1` has + specific values for them. + atol: Absolute tolerance for floating point comparisons (in + np.testing.assert_allclose). 
+ """ + for field in sorted(structure.GLOBAL_FIELDS): + if ignore_fields and field in ignore_fields: + continue + if field == 'author_naming_scheme': + self.assertAuthorNamingSchemeEqual( + struc1[field], struc2[field]) + elif field == 'all_residues': + self.assertAllResiduesEqual(struc1[field], struc2[field]) + elif field == 'bioassembly_data': + self.assertBioassemblyDataEqual(struc1[field], struc2[field]) + elif field == 'chemical_components_data': + self.assertChemicalComponentsDataEqual( + struc1[field], struc2[field], allow_chem_comp_data_extension + ) + elif field == 'bonds': + self.assertBondsEqual( + struc1.bonds, struc2.bonds, struc1.atom_key, struc2.atom_key + ) + else: + self.assertEqual(struc1[field], struc2[field], msg=field) + + # The chain order within a structure is arbitrary so in order to + # directly compare arrays we first align struc1 to struc2 and check that + # the number of atoms doesn't change. + num_atoms = struc1.num_atoms + self.assertEqual(struc2.num_atoms, num_atoms) + struc1 = struc1.order_and_drop_atoms_to_match(struc2) + self.assertEqual(struc1.num_atoms, num_atoms) + + for field in sorted(structure.ARRAY_FIELDS): + if field == 'atom_key': + # atom_key has no external meaning, so it doesn't matter whether it + # differs between two structures. + continue + if ignore_fields and field in ignore_fields: + continue + self.assertEqual(struc1[field] is None, + struc2[field] is None, msg=field) + + if np.issubdtype(struc1[field].dtype, np.inexact): + np.testing.assert_allclose( + struc1[field], struc2[field], err_msg=field, atol=atol + ) + else: + np.testing.assert_array_equal( + struc1[field], struc2[field], err_msg=field + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..4d397750856bb4dbe83823cc2199b61f6f9fdfb4 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention.py @@ -0,0 +1,77 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from typing import Literal, TypeAlias +import typing +import alphafold3.utils.attention.attention_base as base +import alphafold3.utils.attention.ms_attention as ms_attention + +Implementation: TypeAlias = Literal["ms"] + + +def dot_product_attention(query, key, value, *, bias, mask, implementation, + logits_dtype=None, precision=None): + """Performs scaled dot-product attention. + + Scaled dot-product attention from "Attention is all you need" + https://arxiv.org/abs/1706.03762. + + Computes self- or cross-attention. The following is computed: + softmax(qk_scale * query @ key^T + bias) @ value. + + Supports both multi-head and multi-query attention + (https://arxiv.org/abs/1911.02150). + + Arguments: + query: Query array of shape `[batch, seq_len_q, num_heads, head_dim]`. + key: Key array of shape `[batch, seq_len_kv, num_heads, head_dim]`. 
+ `num_heads` can be 1 for multi-query attention. + value: Value array of shape `[batch, seq_len_kv, num_heads, head_dim]`. + `num_heads` can be 1 for multi-query attention. + bias: Optional bias array, broadcastable to shape `[batch, num_heads, + seq_len_q, seq_len_kv]`. + mask: Optional boolean mask, broadcastable to `[batch, num_heads, seq_len_q, + seq_len_kv]`. Attention weights are masked out if the corresponding mask + value is `False`. + implementation: if `None` (default), an implementation is automatically + chosen. 'ms' will use standard MS and work on any platform. + logits_dtype: Data type for attention logits (`query @ key^T`). If `None` is + passed (the default), the accumulator type from the `query @ key^T` dot + product will be used, which is FP32 for BF16/FP16/FP32 inputs. Note that + this default increases the memory usage for BF16/FP16 inputs when using + `implementation='ms'`. + precision: The precision for the dot products. Either a single or a tuple + of `DEFAULT` precision. + + Returns: + An array with the same shape as `query`. + """ + + if implementation is not None: + named_args = typing.get_args(Implementation) + if implementation not in named_args: + raise ValueError( + f"Unsupported named implementation. Must be one of {named_args}." + ) + + logits_dtype = base.AUTO if logits_dtype is None else logits_dtype + precision = "DEFAULT" if precision is None else precision + + args = (query, key, value) + kwargs = dict( + precision=precision, + logits_dtype=logits_dtype, + bias=bias, + mask=mask, + ) + + return ms_attention.MsDotProductAttention()(*args, **kwargs) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention_base.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2bd41f5409c85acb34ecaf5cf5b701a3f0c44854 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention_base.py @@ -0,0 +1,269 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +import abc +import enum +import math +import dataclasses +import functools +from dataclasses import dataclass, KW_ONLY +from typing import Any +import numpy as np +import mindspore as ms +from mindspore import ops, Tensor +from alphafold3.utils.common import precision as precision_lib + + +@dataclasses.dataclass(frozen=True) +class Mask: + """An attention mask. + + `k_start` (inclusive) and `k_end` (exclusive) define range of enabled + k-sequence values for each row of logits. 
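+
+    `q_start`/`q_end` constrain the query rows in the same way. Every
+    constraint that is set (`bool_mask`, the four range fields and
+    `is_causal`) is combined with a logical AND when the mask is materialised
+    via `as_array`; a field left as `None` imposes no restriction.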
+ + For example, a local attention mask could be defined as follows: + ``` + seq_len_q = seq_len_k = 4 + window_size = 2 + k_start = Tensor(np.maximum(0, np.arange(seq_len_q) + 1 - window_size)) + mask = Mask(k_start=k_start, is_causal=True) + assert mask.as_array(seq_len_q, seq_len_k) == Tensor(np.array( + [[1, 0, 0, 0], + [1, 1, 0, 0], + [0, 1, 1, 0], + [0, 0, 1, 1]], dtype=bool)) + ``` + """ + bool_mask: ms.Tensor | None = None + _: dataclasses.KW_ONLY + q_start: ms.Tensor | None = None + q_end: ms.Tensor | None = None + k_start: ms.Tensor | None = None + k_end: ms.Tensor | None = None + is_causal: bool = False + + def tree_flatten(self): + return ( + self.bool_mask, + self.q_start, + self.q_end, + self.k_start, + self.k_end, + ), (self.is_causal,) + + @classmethod + def tree_unflatten(cls, aux, children): + (is_causal,) = aux + bool_mask, q_start, q_end, k_start, k_end = children + return cls( + bool_mask, + q_start=q_start, + q_end=q_end, + k_start=k_start, + k_end=k_end, + is_causal=is_causal, + ) + + def as_array(self, q_len_or_indices, k_len_or_indices): + """Returns the mask as a boolean array.""" + q_indices = ops.arange(q_len_or_indices) if isinstance( + q_len_or_indices, int) else q_len_or_indices + q_indices = q_indices[..., None] + + k_indices = ops.arange(k_len_or_indices) if isinstance( + k_len_or_indices, int) else k_len_or_indices + k_indices = k_indices[..., None, :] + + mask = [] + if self.bool_mask is not None: + mask.append(self.bool_mask) + + if self.q_start is not None: + mask.append(q_indices >= self.q_start[..., None, :]) + + if self.q_end is not None: + mask.append(q_indices < self.q_end[..., None, :]) + + if self.k_start is not None: + mask.append(k_indices >= self.k_start[..., None]) + + if self.k_end is not None: + mask.append(k_indices < self.k_end[..., None]) + + if self.is_causal: + mask.append(q_indices >= k_indices) + + logical_and = functools.partial(functools.reduce, ops.logical_and) + + if mask: + return logical_and(mask) + else: + return None + + def take(self, *attrs): + """Returns a mask with attrs removed and the removed attrs.""" + default_mask = type(self)() + replacements = {attr: getattr(default_mask, attr) for attr in attrs} + values = (getattr(self, attr) for attr in attrs) + return dataclasses.replace(self, **replacements), *values + + def __and__(self, other): + """Returns the intersection of two masks.""" + if not isinstance(other, Mask): + other = Mask(other) + + def combine(op): + return lambda a, b: b if a is None else a if b is None else op(a, b) + + return Mask( + bool_mask=combine(ops.logical_and)( + self.bool_mask, other.bool_mask), + q_end=combine(ops.minimum)(self.q_end, other.q_end), + k_start=combine(ops.maximum)(self.k_start, other.k_start), + k_end=combine(ops.minimum)(self.k_end, other.k_end), + is_causal=self.is_causal or other.is_causal, + ) + + +CAUSAL_MASK = Mask(is_causal=True) + + +@enum.unique +class SoftmaxResidualMode(enum.Enum): + """The mode of storing softmax residuals for the backwards pass. + + The stable softmax calculation performs two reductions calculating: + - the maximum input value (`x_max`), + - the sum of exponentiated values (`denom`). + + We can store these values as residuals to avoid the need to recompute them + in the backwards pass. + + It is also possible to combine the two residuals into a single residual, + `res = x_max + log(denom)`, as `exp(x - res) === exp(x - x_max - log(denom)) + === exp(x - x_max) / denom`. 
Combining the residuals reduces the memory usage + of the residuals, but will reduce the accuracy of the backwards pass if + `abs(x_max) >> log(denom)`. + """ + + SEPARATE = "separate" + COMBINED = "combined" + + def conform(self, aux): + match self, aux: + case None, _: + return None + case SoftmaxResidualMode.SEPARATE, (_, _): + return aux + case SoftmaxResidualMode.SEPARATE, _: # pytype: disable=redundant-match # b/300135240 + raise ValueError("`aux` has been combined.") + case SoftmaxResidualMode.COMBINED, (x_max, denom): + return x_max + ops.log(denom) + case SoftmaxResidualMode.COMBINED, _: # pytype: disable=redundant-match # b/300135240 + return aux + + +class DotProductAttention(abc.ABC): + """Dot product attention function.""" + + def __call__(self, query, key, value, *, precision, logits_dtype, bias, mask, q_indices=None, k_indices=None): + """Performs scaled dot-product attention. + + Scaled dot-product attention from "Attention is all you need" + https://arxiv.org/abs/1706.03762. + + Computes self- or cross-attention. The following is computed: + softmax(qk_scale * query @ key^T + bias) @ value. + + Supports both multi-head and multi-query attention + (https://arxiv.org/abs/1911.02150). + + Arguments: + query: Query array of shape `[batch, seq_len_q, num_heads_q, head_dim]`. + It must be a multiple of num_heads_kv. + Here's an example of how q/kv heads are interleaved: + For 8 key/value heads and 4 query heads: + - key/value heads [0, 1] see query head 0 + - key/value heads [2, 3] see query head 1 + - key/value heads [4, 5] see query head 2 + key: Key array of shape `[batch, seq_len_kv, num_heads_kv, head_dim]`. It + must be divisible by num_heads_q. + value: Value array of shape `[batch, seq_len_kv, num_heads_kv, head_dim]`. + precision: The precision for the dot products. Either a tuple `( + query_key_dot_precision, weights_value_dot_precision)` or a single + precision applied to both dot products. + logits_dtype: Data type for attention logits (`query @ key^T`). If `AUTO` + is passed (the default), the accumulator type from the `query @ key^T` + dot product will be used. + bias: Optional bias array, broadcastable to shape `[batch, num_heads, + seq_len_q, seq_len_kv]`. + mask: Optional boolean mask, broadcastable to `[batch, num_heads, + seq_len_q, seq_len_kv]`. Attention weights are masked out if the + corresponding mask value is `False`. + q_indices: Optional indices for each token in query sequence. + k_indices: Optional indices for each token in key/value sequence. + + Returns: + An array with the same shape as `query`. 
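+
+        Example (illustrative shapes; `dot_product_attention` is the public
+        wrapper defined in `attention.py`, which forwards to the MindSpore
+        implementation of this class):
+
+          from alphafold3.utils.attention.attention import dot_product_attention
+
+          q = ops.zeros((2, 128, 8, 64), ms.float32)  # [batch, seq_q, heads, dim]
+          k = ops.zeros((2, 256, 8, 64), ms.float32)  # [batch, seq_kv, heads, dim]
+          v = ops.zeros((2, 256, 8, 64), ms.float32)
+          out = dot_product_attention(
+              q, k, v, bias=None, mask=None, implementation='ms')
+          # out.shape == (2, 128, 8, 64), i.e. the same shape as `query`.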
+ """ + return self.fwd( + query, + key, + value, + precision=precision, + logits_dtype=logits_dtype, + bias=bias, + mask=mask, + q_indices=q_indices, + k_indices=k_indices, + ) + + def fwd(self, query, key, value, *, precision, logits_dtype, bias, mask, q_indices, k_indices): + """Performs attention.""" + if not isinstance(precision, tuple): + precision = (precision, precision) + + q_k_dot_precision, weights_v_dot_precision = precision + + if not isinstance(q_k_dot_precision, precision_lib.DotPrecision): + q_k_dot_precision = precision_lib.get_equivalent_dot_precision( + query.dtype, key.dtype, q_k_dot_precision + ) + + if not isinstance(weights_v_dot_precision, precision_lib.DotPrecision): + weights_v_dot_precision = precision_lib.get_equivalent_dot_precision( + value.dtype, value.dtype, weights_v_dot_precision + ) + + if not isinstance(mask, Mask): + mask = Mask(mask) + + return self._fwd( + Tensor(query), + Tensor(key), + Tensor(value), + q_k_dot_precision=q_k_dot_precision, + logits_dtype=logits_dtype, + logits_scale=1 / math.sqrt(query.shape[-1]), + bias=bias, + mask=mask, + weights_v_dot_precision=weights_v_dot_precision, + q_indices=q_indices, + k_indices=k_indices, + ) + + @abc.abstractmethod + def _fwd(self, q, k, v, *, q_k_dot_precision, logits_dtype, logits_scale, bias, mask, + weights_v_dot_precision, q_indices, k_indices): + """Performs attention.""" + ... diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention_call_arg_specs.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention_call_arg_specs.py new file mode 100644 index 0000000000000000000000000000000000000000..6e8db1bdea05cfc6c5351732d5a59bb48e5668c7 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention_call_arg_specs.py @@ -0,0 +1,61 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Attention call argument specifications. + +Attention argument specifications used by users of the library. +They are the most important test cases, and also cases for optimize +performance of via autotuning. +""" + +from typing import Any + + +def _make_argspec( + *, + q_shape, + dtype, + k_shape=None, + v_shape=None, + bias_shape=None, + mask_shape=None, + **kwargs, +) -> dict[str, Any]: + """Make argspec from shapes and kwargs.""" + if k_shape is None: + k_shape = q_shape + if v_shape is None: + v_shape = k_shape + + return dict( + query=q_shape, + key=k_shape, + value=v_shape, + bias=bias_shape, + mask=mask_shape, + dtype=dtype, + **kwargs, + ) + + +# A subset of the full set of argument specifications. Useful for tap-tests and +# microbenchmarks. 
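+# Illustrative expansion (not part of the benchmark set): `_make_argspec`
+# fills the missing key/value shapes from the query shape, so
+#   _make_argspec(q_shape=(2, 256, 8, 64), dtype='bfloat16')
+# evaluates to
+#   dict(query=(2, 256, 8, 64), key=(2, 256, 8, 64), value=(2, 256, 8, 64),
+#        bias=None, mask=None, dtype='bfloat16')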
+CALL_ARG_SPECS = dict( + vanilla_f32=_make_argspec(q_shape=(8, 1024, 4, 128), dtype='float32'), + vanilla_bf16=_make_argspec(q_shape=(8, 1024, 4, 128), dtype='bfloat16'), + alphafold=_make_argspec( + q_shape=(384, 384, 4, 32), + bias_shape=(1, 4, 384, 384), + mask_shape=(384, 1, 1, 384), + dtype='bfloat16', + ), +) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/ms_attention.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/ms_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..835d0864489f71dcc19bd4413b07f07e00f8cca4 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/ms_attention.py @@ -0,0 +1,96 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +import dataclasses +import mindspore as ms +from mindspore import ops +import alphafold3.utils.attention.attention_base as base + + +def _softmax(x): + """Computes softmax.""" + dtype = ms.float32 + x_max, _ = ops.max(x.astype(dtype), axis=-1, keepdims=True) + unnormalized = ops.exp(x - x_max) + denom = ops.sum(unnormalized, dim=-1, keepdim=True) + return (unnormalized / denom).astype(x.dtype) + + +def cal_logits(q, k, use_bf16=False): + # ...qhd,...khd->...hqk + dtype = q.dtype + if use_bf16: + q = q.astype(ms.bfloat16) + k = k.astype(ms.bfloat16) + q_trans = ops.transpose(q, (0, 2, 1, 3)) # ...qhd -> ...hqd + k_trans = ops.transpose(k, (0, 2, 3, 1)) # ...khd -> ...hdk + logits = ops.matmul(q_trans, k_trans) + if use_bf16: + logits = logits.astype(dtype) + return logits + + +def cal_out(weights, v, use_bf16=False): + # ...hqk,...khd->...qhd + if use_bf16: + weights = weights.astype(ms.bfloat16) + v = v.astype(ms.bfloat16) + v_trans = ops.transpose(v, (0, 2, 1, 3)) # ...khd -> ...hkd + out_temp = ops.matmul(weights, v_trans) # ...hqk,...hkd->...hqd + out = ops.transpose(out_temp, (0, 2, 1, 3)) + return out + + +def _attend( + q, k, v, *, q_k_dot_precision, logits_dtype, logits_scale, + bias, mask, weights_v_dot_precision, q_indices, k_indices, +): + logits = cal_logits(q, k) + + logits *= logits_scale + + if bias is not None: + logits += bias + + if mask is not None: + q_len_or_indices = q.shape[-3] if q_indices is None else q_indices + k_len_or_indices = k.shape[-3] if k_indices is None else k_indices + mask = mask.as_array(q_len_or_indices, k_len_or_indices) + + if mask is not None: # TBD in ms + mask_value = -3.4028235e+37 # a small value close to min of bfloat16 + logits = ops.where(mask.bool(), logits, mask_value) + + weights = _softmax(logits) + + out = cal_out(weights, v) + + return out + + +@dataclasses.dataclass(frozen=True) +class MsDotProductAttention(base.DotProductAttention): + """MS dot product attention function.""" + + _: dataclasses.KW_ONLY + + def _fwd( + self, q, k, v, *, q_k_dot_precision, logits_dtype, logits_scale, + bias, mask, weights_v_dot_precision, q_indices, 
k_indices, + ): + + return _attend( + q, k, v, bias=bias, mask=mask, q_indices=q_indices, k_indices=k_indices, + q_k_dot_precision=q_k_dot_precision, logits_dtype=logits_dtype, logits_scale=logits_scale, + weights_v_dot_precision=weights_v_dot_precision, + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/common/precision.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/common/precision.py new file mode 100644 index 0000000000000000000000000000000000000000..b4b299dcd856732147c6d22cc962bdff6330e1e9 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/common/precision.py @@ -0,0 +1,91 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""Precision classes and utilities.""" + +import enum +import mindspore as ms + + +@enum.unique +class DotPrecision(enum.Enum): + """Precision for `dot` operation. + + Naming scheme: {OPERAND_DTYPE}_{ACCUMULATOR_DTYPE}[_{NUM_PASSES}x] + """ + + BF16_F32 = "bf16_f32" + + # NPU only precisions. + F32_F32 = "f32_f32" # Full f32 precision (doesn't use TensorCores). + F16_F16 = "f16_f16" + F16_F32 = "f16_f32" + + @property + def operand_dtype(self) -> ms.dtype: + match self: + case DotPrecision.BF16_F32: + return ms.bfloat16 + case DotPrecision.F16_F16 | DotPrecision.F16_F32: + return ms.float16 + case _: + return ms.float32 + + @property + def accumulator_dtype(self) -> ms.dtype: + return ms.float16 if (self == DotPrecision.F16_F16) else ms.float32 + + +_MS_NPU_PRECISION_MAP = { + (ms.float16, "DEFAULT"): DotPrecision.F16_F32, + (ms.bfloat16, "DEFAULT"): DotPrecision.BF16_F32, + (ms.float32, "DEFAULT"): DotPrecision.F32_F32, + (ms.float32, "HIGH"): DotPrecision.F32_F32, + (ms.float32, "HIGHEST"): DotPrecision.F32_F32, +} + +_MS_CPU_PRECISION_MAP = { + (ms.float16, "DEFAULT"): DotPrecision.F16_F32, + (ms.bfloat16, "DEFAULT"): DotPrecision.F32_F32, + (ms.float32, "DEFAULT"): DotPrecision.F32_F32, + (ms.float32, "HIGH"): DotPrecision.F32_F32, + (ms.float32, "HIGHEST"): DotPrecision.F32_F32, +} + + +def _create_ms_precision_map(): + precision_map = {} + for (dtype, ms_precision), dot_precision in _MS_NPU_PRECISION_MAP.items(): + precision_map[("ascend", dtype, ms_precision)] = dot_precision + for (dtype, ms_precision), dot_precision in _MS_CPU_PRECISION_MAP.items(): + precision_map[("cpu", dtype, ms_precision)] = dot_precision + return precision_map + + +_MS_PRECISION_MAP = _create_ms_precision_map() + + +def get_equivalent_dot_precision( + a_dtype: ms.dtype, b_dtype: ms.dtype, ms_precision: str +) -> DotPrecision: + """Returns `DotPrecision` replicating default behaviour.""" + if a_dtype != b_dtype: + raise ValueError("Cannot infer precision if operand types differ.") + + backend = ms.context.get_context("device_target").lower() + if (ms_precision != "DEFAULT") and (a_dtype != ms.float32): + raise ValueError( + "`Precision` values other than `DEFAULT` only have an effect 
if" + " the operand type is `float32`." + ) + return _MS_PRECISION_MAP[(backend, a_dtype, ms_precision)] diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit.py new file mode 100644 index 0000000000000000000000000000000000000000..5e7f3718ead378760ecee4b8f7e2e29313eeec5e --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit.py @@ -0,0 +1,66 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Public API for gated linear unit functions.""" + +import typing +from typing import Literal, TypeAlias +from alphafold3.utils.gated_linear_unit import gated_linear_unit_base + +Implementation: TypeAlias = Literal['ms'] + + +def gated_linear_unit(x, weight, *, activation, precision, implementation=None): + """Applies a gated linear unit (https://arxiv.org/abs/1612.08083). + + Computes `activation(x @ weight[:, 0]) * x @ weight[:, 1]`. + + This is SwiGLU when `activation=swish`, GEGLU when + `activation=gelu`, REGLU when `activation=relu`, and GLU when + `activation=sigmoid` (https://arxiv.org/abs/2002.05202). + + Args: + x: the input array. + weight: the combined weight array. + activation: optional activation function. + precision: specifies the matrix multiplication precision. Either `None` + (default), which means the default precision for the backend, or an + enum of "DEFAULT/HIGH/...". + implementation: if `None` (default), an implementation is automatically + chosen. 'ms' will use standard MS and work on any platform. + + Raises: + ValueError: if the arguments are invalid. + + Returns: + The output array. + """ + + if x.dtype != weight.dtype: + raise ValueError( + f'Input and weight must have the same dtype. {x.dtype} !=' + f' {weight.dtype}' + ) + + if implementation is not None: + named_args = typing.get_args(Implementation) + if implementation not in named_args: + raise ValueError( + f'Unsupported named implementation. Must be one of {named_args}.' + ) + + return gated_linear_unit_base.gated_linear_unit_ms( + x=x, + weight=weight, + activation=activation, + precision=precision, + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit_base.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit_base.py new file mode 100644 index 0000000000000000000000000000000000000000..afa6406f903bfe05efed789a6e61d9414ff82b73 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit_base.py @@ -0,0 +1,84 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""Common types for gated linear unit kernels.""" +import abc +import mindspore as ms +from mindspore import mint + + +class GatedLinearUnit(abc.ABC): + """Gated linear unit.""" + + def __call__(self, x, weight, *, activation, precision, **kwargs): + """Applies a gated linear unit (https://arxiv.org/abs/1612.08083). + + Computes `activation(x @ weight[:, 0]) * x @ weight[:, 1]`. + + This is SwiGLU when `activation=swish`, GEGLU when + `activation=gelu`, REGLU when `activation=relu`, and GLU when + `activation=sigmoid` (https://arxiv.org/abs/2002.05202). + + Args: + x: the input array. + weight: the combined weight array. + activation: optional activation function. + precision: specifies the matrix multiplication precision. Either `None` + (default), which means the default precision for the backend, or an + enum of "DEFAULT/HIGH/...". + + Returns: + The output array. + """ + + return self._fwd( + x, weight, activation=activation, precision=precision, **kwargs + ) + + @abc.abstractmethod + def _fwd(self, x, weight, *, activation, precision): + """Gated linear unit.""" + ... + + +def gated_linear_unit_ms(x, weight, *, activation, precision=None): + """Applies a gated linear unit (https://arxiv.org/abs/1612.08083). + + Computes `activation(x @ weight[:, 0]) * x @ weight[:, 1]`. + + This is SwiGLU when `activation=swish`, GEGLU when + `activation=gelu`, REGLU when `activation=relu`, and GLU when + `activation=sigmoid` (https://arxiv.org/abs/2002.05202). + + Args: + x: the input array. + weight: the combined weight array. + activation: optional activation function. + precision: specifies the matrix multiplication precision. Either `None` + (default), which means the default precision for the backend, or an + enum of "DEFAULT/HIGH/...". + + Returns: + The output array. + """ + + weight_reshaped = mint.reshape( + weight, (-1, weight.shape[-2] * weight.shape[-1])) + # y = ops.dot(x.astype('float32'), weight_reshaped.astype('float32')) + y1 = mint.matmul(x, weight_reshaped) + y = y1.astype(ms.float32) + a, b = y.split(y.shape[-1] // 2, axis=-1) + out = mint.mul(a, b) if activation is None else mint.mul(activation(a), b) + out = out.astype(x.dtype) + + return out diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/__init__.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..910ccfe9f425dd70390b93bbe2950e9b6a593ae8 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. 
You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from alphafold3.utils.geometry import rigid_matrix_vector +from alphafold3.utils.geometry import rotation_matrix +from alphafold3.utils.geometry import struct_of_array +from alphafold3.utils.geometry import vector + +Rot3Array = rotation_matrix.Rot3Array +Rigid3Array = rigid_matrix_vector.Rigid3Array + +StructOfArray = struct_of_array.StructOfArray + +Vec3Array = vector.Vec3Array +square_euclidean_distance = vector.square_euclidean_distance +euclidean_distance = vector.euclidean_distance +dihedral_angle = vector.dihedral_angle +dot = vector.dot +cross = vector.cross diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/rigid_matrix_vector.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/rigid_matrix_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..6faf4e062bf293b864ce96660ca1b624b6bbab32 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/rigid_matrix_vector.py @@ -0,0 +1,194 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""Rigid3Array Transformations represented by a Matrix and a Vector.""" + +from typing import Any, Final, TypeAlias +import mindspore as ms +import mindspore.numpy as mnp +from mindspore import Tensor +from mindspore.ops import operations as P + +from alphafold3.utils.geometry import rotation_matrix, struct_of_array, utils, vector + +Float: TypeAlias = float | Tensor + +VERSION: Final[str] = '0.1' + + +def _compute_covariance_matrix( + row_values: vector.Vec3Array, + col_values: vector.Vec3Array, + weights: Tensor, + epsilon=1e-6, +) -> Tensor: + """Compute covariance matrix.""" + weights = mnp.asarray(weights) + + weights = mnp.broadcast_to(weights, row_values.shape) + + normalized_weights = weights / \ + (mnp.sum(weights, axis=-1, keepdims=True) + epsilon) + + def weighted_average(x): + return mnp.sum(normalized_weights * x, axis=-1) + + out = [ + mnp.stack( + ( + weighted_average(row_values.x * col_values.x), + weighted_average(row_values.x * col_values.y), + weighted_average(row_values.x * col_values.z), + ), + axis=-1, + ) + ] + + out.append( + mnp.stack( + ( + weighted_average(row_values.y * col_values.x), + weighted_average(row_values.y * col_values.y), + weighted_average(row_values.y * col_values.z), + ), + axis=-1, + ) + ) + + out.append( + mnp.stack( + ( + weighted_average(row_values.z * col_values.x), + weighted_average(row_values.z * col_values.y), + weighted_average(row_values.z * col_values.z), + ), + axis=-1, + ) + ) + + return mnp.stack(out, axis=-2) + + +@struct_of_array.StructOfArray(same_dtype=True) +class Rigid3Array: + """Rigid Transformation, i.e. 
element of special euclidean group.""" + + rotation: rotation_matrix.Rot3Array + translation: vector.Vec3Array + + def __matmul__(self, other: 'Rigid3Array') -> 'Rigid3Array': + new_rotation = self.rotation @ other.rotation + new_translation = self.apply_to_point(other.translation) + return Rigid3Array(new_rotation, new_translation) + + def inverse(self) -> 'Rigid3Array': + """Return Rigid3Array corresponding to inverse transform.""" + inv_rotation = self.rotation.inverse() + inv_translation = inv_rotation.apply_to_point(-self.translation) + return Rigid3Array(inv_rotation, inv_translation) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply Rigid3Array transform to point.""" + return self.rotation.apply_to_point(point) + self.translation + + def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply inverse Rigid3Array transform to point.""" + new_point = point - self.translation + return self.rotation.apply_inverse_to_point(new_point) + + def compose_rotation(self, other_rotation: rotation_matrix.Rot3Array) -> 'Rigid3Array': + rot = self.rotation @ other_rotation + trans = P.BroadcastTo(rot.shape)(self.translation) + return Rigid3Array(rot, trans) + + @classmethod + def identity(cls, shape: Any, dtype: ms.dtype = ms.float32) -> 'Rigid3Array': + """Return identity Rigid3Array of given shape.""" + + return cls( + rotation_matrix.Rot3Array.identity(shape, dtype=dtype), + vector.Vec3Array.zeros(shape, dtype=dtype), + ) + + def scale_translation(self, factor: Float) -> 'Rigid3Array': + """Scale translation in Rigid3Array by 'factor'.""" + return Rigid3Array(self.rotation, self.translation * factor) + + def to_array(self): + rot_array = self.rotation.to_array() + vec_array = self.translation.to_array() + return mnp.concatenate([rot_array, vec_array[..., None]], axis=-1) + + @classmethod + def from_array(cls, array): + rot = rotation_matrix.Rot3Array.from_array(array[..., :3]) + vec = vector.Vec3Array.from_array(array[..., -1]) + return cls(rot, vec) + + @classmethod + def from_array4x4(cls, array: Tensor) -> 'Rigid3Array': + """Construct Rigid3Array from homogeneous 4x4 array.""" + if array.shape[-2:] != (4, 4): + raise ValueError(f'array.shape({array.shape}) must be [..., 4, 4]') + rotation = rotation_matrix.Rot3Array( + *(array[..., 0, 0], array[..., 0, 1], array[..., 0, 2]), + *(array[..., 1, 0], array[..., 1, 1], array[..., 1, 2]), + *(array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]), + ) + translation = vector.Vec3Array( + array[..., 0, 3], array[..., 1, 3], array[..., 2, 3] + ) + return cls(rotation, translation) + + @classmethod + def from_point_alignment( + cls, + points_to: vector.Vec3Array, + points_from: vector.Vec3Array, + weights: Float | None = None, + epsilon: float = 1e-6, + ) -> 'Rigid3Array': + """Constructs Rigid3Array by finding transform aligning points.""" + if weights is None: + weights = 1.0 + + def compute_center(value): + return utils.weighted_mean(value=value, weights=weights, axis=-1) + + points_to_center = P.Map()(compute_center, points_to) + points_from_center = P.Map()(compute_center, points_from) + centered_points_to = points_to - points_to_center[..., None] + centered_points_from = points_from - points_from_center[..., None] + cov_mat = _compute_covariance_matrix( + centered_points_to, + centered_points_from, + weights=weights, + epsilon=epsilon, + ) + rots = rotation_matrix.Rot3Array.from_svd( + mnp.reshape(cov_mat, cov_mat.shape[:-2] + (9,)) + ) + + translations = points_to_center - \ + 
rots.apply_to_point(points_from_center) + + return cls(rots, translations) + + def __getstate__(self): + return (VERSION, (self.rotation, self.translation)) + + def __setstate__(self, state): + version, (rot, trans) = state + del version + object.__setattr__(self, 'rotation', rot) + object.__setattr__(self, 'translation', trans) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/rotation_matrix.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/rotation_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..b9c91a59811315c52abed4f55b3d04158c806420 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/rotation_matrix.py @@ -0,0 +1,255 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""Rot3Array Matrix Class.""" + +import dataclasses +from typing import Any, Final +import numpy as np +import mindspore as ms +import mindspore.numpy as mnp +from mindspore import ops, mint +from mindspore import Tensor +from alphafold3.utils.geometry import struct_of_array, utils, vector + +COMPONENTS: Final[tuple[str, ...]] = ( + *('xx', 'xy', 'xz'), + *('yx', 'yy', 'yz'), + *('zx', 'zy', 'zz'), +) +VERSION: Final[str] = '0.1' + + +def make_matrix_svd_factors() -> Tensor: + """Generates factors for converting 3x3 matrix to symmetric 4x4 matrix.""" + factors = mnp.zeros((16, 9), dtype=ms.float32) + + indices = [(0, [0, 4, 8]), ([1, 4], 5), ([1, 4], 7), ([2, 8], 6), ([2, 8], 2), + ([3, 12], 1), ([3, 12], 3), (5, 0), (5, [4, 8]), + ([6, 9], 1), ([6, 9], 3), ([7, 13], 2), ([7, 13], 6), + (10, 4), (10, [0, 8]), ([11, 14], 5), ([11, 14], 7), (15, 8), (15, [0, 4])] + + values = [[1.0], [1.0, -1.0], [1.0, -1.0], [1.0, -1.0], [1.0, -1.0], + [1.0, -1.0], [1.0, 1.0], [1.0, -1.0], [-1.0, -1.0], + [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], + [1.0, -1.0], [1.0, -1.0], [1.0, 1.0], [1.0, 1.0]] + + for idx, val in zip(indices, values): + if isinstance(idx[1], list): + for i in idx[1]: + factors[idx[0], i] = val[i % len(val)] + else: + factors[idx[0], idx[1]] = val[0] + + return factors + + +def largest_evec(m): + _, eigvecs = np.linalg.eigh(m.asnumpy()) + + return Tensor(eigvecs[..., -1]) + + +MATRIX_SVD_QUAT_FACTORS = make_matrix_svd_factors() + + +@struct_of_array.StructOfArray(same_dtype=True) +class Rot3Array: + """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.""" + + xx: Tensor = dataclasses.field(metadata={'dtype': ms.float32}) + xy: Tensor + xz: Tensor + yx: Tensor + yy: Tensor + yz: Tensor + zx: Tensor + zy: Tensor + zz: Tensor + + __array_ufunc__ = None + + def inverse(self): + """Returns inverse of Rot3Array.""" + return Rot3Array( + *(self.xx, self.yx, self.zx), + *(self.xy, self.yy, self.zy), + *(self.xz, self.yz, self.zz), + ) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Applies Rot3Array to point.""" + x 
= self.xx * point.x + self.xy * point.y + self.xz * point.z + y = self.yx * point.x + self.yy * point.y + self.yz * point.z + z = self.zx * point.x + self.zy * point.y + self.zz * point.z + return vector.Vec3Array(x, y, z) + + def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Applies inverse Rot3Array to point.""" + return self.inverse().apply_to_point(point) + + def __matmul__(self, other): + """Composes two Rot3Arrays.""" + c0 = self.apply_to_point( + vector.Vec3Array(other.xx, other.yx, other.zx)) + c1 = self.apply_to_point( + vector.Vec3Array(other.xy, other.yy, other.zy)) + c2 = self.apply_to_point( + vector.Vec3Array(other.xz, other.yz, other.zz)) + return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z) + + @classmethod + def identity(cls, shape: Any, dtype: ms.dtype = ms.float32): + """Returns identity of given shape.""" + ones = mint.ones(shape, dtype=dtype) + zeros = mint.zeros(shape, dtype=dtype) + + temp = cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) + return temp + + @classmethod + def from_two_vectors(cls, e0: vector.Vec3Array, e1: vector.Vec3Array): + """Construct Rot3Array from two Vectors. + + Rot3Array is constructed such that in the corresponding frame 'e0' lies on + the positive x-Axis and 'e1' lies in the xy plane with positive sign of y. + + Args: + e0: Vector + e1: Vector + + Returns: + Rot3Array + """ + # Normalize the unit vector for the x-axis, e0. + e0 = e0.normalized() + # Make e1 perpendicular to e0. + c = e1.dot(e0) + e1 = (e1 - e0 * c).normalized() + # Compute e2 as cross product of e0 and e1. + e2 = e0.cross(e1) + return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) + + @classmethod + def from_array(cls, array: Tensor): + """Construct Rot3Array Matrix from array of shape [..., 3, 3].""" + unstacked = utils.unstack(array, axis=-2) + unstacked = sum([utils.unstack(x, axis=-1) for x in unstacked], []) + return cls(*unstacked) + + def to_array(self) -> Tensor: + """Convert Rot3Array to array of shape [..., 3, 3].""" + return ops.stack( + [ + ops.stack([self.xx, self.xy, self.xz], axis=-1), + ops.stack([self.yx, self.yy, self.yz], axis=-1), + ops.stack([self.zx, self.zy, self.zz], axis=-1), + ], + axis=-2, + ) + + @classmethod + def from_quaternion( + cls, + w: Tensor, + x: Tensor, + y: Tensor, + z: Tensor, + normalize: bool = True, + epsilon: float = 1e-6, + ): + """Construct Rot3Array from components of quaternion.""" + if normalize: + inv_norm = ops.rsqrt(ops.maximum( + epsilon, w**2 + x**2 + y**2 + z**2)) + w *= inv_norm + x *= inv_norm + y *= inv_norm + z *= inv_norm + xx = 1 - 2 * (y**2 + z**2) + xy = 2 * (x * y - w * z) + xz = 2 * (x * z + w * y) + yx = 2 * (x * y + w * z) + yy = 1 - 2 * (x**2 + z**2) + yz = 2 * (y * z - w * x) + zx = 2 * (x * z - w * y) + zy = 2 * (y * z + w * x) + zz = 1 - 2 * (x**2 + y**2) + return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) + + @classmethod + def from_svd(cls, mat: Tensor, use_quat_formula: bool = True): + """Constructs Rot3Array from arbitrary array of shape [3 * 3] using SVD. + + The case when 'use_quat_formula' is False rephrases the problem of + projecting the matrix to a rotation matrix as a problem of finding the + largest eigenvector of a certain 4x4 matrix. This has the advantage of + having fewer numerical issues. 
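+
+        (In the implementation below the 4x4 eigenvector route corresponds to
+        `use_quat_formula=True`: the nine entries of `mat` are mapped to a
+        symmetric 4x4 matrix via `MATRIX_SVD_QUAT_FACTORS`, its leading
+        eigenvector is read as a quaternion and converted with
+        `from_quaternion(...).inverse()`. With `use_quat_formula=False` an
+        explicit NumPy SVD is used instead.)
+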
+ This approach follows: + https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.65.971&rep=rep1&type=pdf + In the other case we construct it via svd following + https://arxiv.org/pdf/2006.14616.pdf + In that case [∂L/∂M] is large if the two smallest singular values are close + to each other, or if they are close to 0. + + Args: + mat: Array of shape [..., 3 * 3] + use_quat_formula: Whether to construct matrix via 4x4 eigenvalue problem. + + Returns: + Rot3Array of shape [...] + """ + assert mat.shape[-1] == 9 + if use_quat_formula: + symmetric_4by4 = ops.einsum( + 'ji, ...i -> ...j', + MATRIX_SVD_QUAT_FACTORS, + mat, + ) + symmetric_4by4 = ops.reshape( + symmetric_4by4, mat.shape[:-1] + (4, 4)) + largest_eigvec = largest_evec(symmetric_4by4) + return cls.from_quaternion( + *utils.unstack(largest_eigvec, axis=-1) + ).inverse() + + mat = ops.reshape(mat, mat.shape[:-1] + (3, 3)) + u, _, v_t = np.linalg.svd(mat.asnumpy(), full_matrices=False) + u = Tensor(u) + v_t = Tensor(v_t) + det_uv_t = ops.det(ops.matmul(u, v_t)) + ones = ops.ones_like(det_uv_t) + diag_array = ops.stack([ones, ones, det_uv_t], axis=-1) + # This is equivalent to making diag_array into a diagonal array and matrix + # multiplying + diag_times_v_t = diag_array[..., None] * v_t + out = ops.matmul(u, diag_times_v_t) + return cls.from_array(out) + + @classmethod + def random_uniform(cls, key, shape, dtype=ms.float32): + """Samples uniform random Rot3Array according to Haar Measure.""" + stdnormal = ops.StandardNormal(seed=key) + quat_array = stdnormal(shape + (4,)).astype(dtype) + # quat_array = ops.StandardNormal()(shape=(tuple(shape) + (4,)), seed=key) + quats = utils.unstack(quat_array) + return cls.from_quaternion(*quats) + + def __getstate__(self): + return (VERSION, [getattr(self, field) for field in COMPONENTS]) + + def __setstate__(self, state): + version, state = state + del version + for i, field in enumerate(COMPONENTS): + object.__setattr__(self, field, state[i]) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/struct_of_array.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/struct_of_array.py new file mode 100644 index 0000000000000000000000000000000000000000..80e872c172feb3d5e47cbeef06fd0856611cec96 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/struct_of_array.py @@ -0,0 +1,280 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# +# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend + +"""Class decorator to represent (nested) struct of arrays.""" + +import dataclasses +import mindspore as ms + +def get_item(instance, key): + sliced = {} + for field in get_array_fields(instance): + num_trailing_dims = field.metadata.get('num_trailing_dims', 0) + this_key = key + if isinstance(key, tuple) and Ellipsis in this_key: + this_key += (slice(None),) * num_trailing_dims + + def apply_slice(x): + if isinstance(x, ms.Tensor): + return x[this_key] + elif isinstance(x, dict): + return {k: apply_slice(v) for k, v in x.items()} + elif isinstance(x, list): + return [apply_slice(item) for item in x] + else: + return x + + sliced[field.name] = apply_slice(getattr(instance, field.name)) + + return dataclasses.replace(instance, **sliced) + + +@property +def get_shape(instance): + """Returns Shape for given instance of dataclass.""" + first_field = dataclasses.fields(instance)[0] + num_trailing_dims = first_field.metadata.get('num_trailing_dims', None) + value = getattr(instance, first_field.name) + if num_trailing_dims: + return value.shape[:-num_trailing_dims] + else: + return value.shape + + +def get_len(instance): + """Returns length for given instance of dataclass.""" + shape = instance.shape + if shape: + return shape[0] + else: + # Match utils.numpy behavior. + raise TypeError('len() of unsized object') + + +@property +def get_dtype(instance): + """Returns Dtype for given instance of dataclass.""" + fields = dataclasses.fields(instance) + sets_dtype = [ + field.name for field in fields if field.metadata.get('sets_dtype', False) + ] + if sets_dtype: + assert len(sets_dtype) == 1, 'at most one field can set dtype' + field_value = getattr(instance, sets_dtype[0]) + elif instance.same_dtype: + field_value = getattr(instance, fields[0].name) + else: + raise AttributeError( + 'Trying to access Dtype on Struct of Array without' + 'either "same_dtype" or field setting dtype' + ) + + if hasattr(field_value, 'dtype'): + return field_value.dtype + else: + raise AttributeError(f'field_value {field_value} does not have dtype') + + +def replace(instance, **kwargs): + return dataclasses.replace(instance, **kwargs) + + +def post_init(instance): + """Validate instance has same shapes & dtypes.""" + array_fields = get_array_fields(instance) + arrays = list(get_array_fields(instance, return_values=True).values()) + first_field = array_fields[0] + try: + dtype = instance.dtype + except AttributeError: + dtype = None + if dtype is not None: + first_shape = instance.shape + for array, field in zip(arrays, array_fields, strict=True): + num_trailing_dims = field.metadata.get('num_trailing_dims', None) + if num_trailing_dims: + array_shape = array.shape + field_shape = array_shape[:-num_trailing_dims] + msg = ( + f'field {field} should have number of trailing dims' + ' {num_trailing_dims}' + ) + assert len(array_shape) == len( + first_shape) + num_trailing_dims, msg + else: + + field_shape = array.shape + + shape_msg = ( + f"Stripped Shape {field_shape} of field {field} doesn't " + f'match shape {first_shape} of field {first_field}' + ) + + assert field_shape == first_shape, shape_msg + + field_dtype = array.dtype + + allowed_metadata_dtypes = field.metadata.get('allowed_dtypes', []) + if allowed_metadata_dtypes: + msg = f'Dtype is {field_dtype} but must be in {allowed_metadata_dtypes}' + assert 
field_dtype in allowed_metadata_dtypes, msg + + if 'dtype' in field.metadata: + target_dtype = field.metadata['dtype'] + else: + target_dtype = dtype + + msg = f'Dtype is {field_dtype} but must be {target_dtype}' + assert field_dtype == target_dtype, msg + + +def flatten(instance): + """Flatten Struct Of Array instance.""" + array_likes = get_array_fields(instance, return_values=True).values() + flat_array_likes = [] + inner_treedefs = [] + num_arrays = [] + for array_like in array_likes: + flat_array_like, inner_treedef = tree_flatten(array_like) + inner_treedefs.append(inner_treedef) + flat_array_likes += flat_array_like + num_arrays.append(len(flat_array_like)) + metadata = get_metadata_fields(instance, return_values=True) + metadata = type(instance).metadata_cls(**metadata) + return flat_array_likes, (inner_treedefs, metadata, num_arrays) + + +def make_metadata_class(cls): + metadata_fields = get_fields( + cls, lambda x: x.metadata.get('is_metadata', False) + ) + metadata_cls = dataclasses.make_dataclass( + cls_name='Meta' + cls.__name__, + fields=[(field.name, field.type, field) for field in metadata_fields], + frozen=True, + eq=True, + ) + return metadata_cls + + +def get_fields(cls_or_instance, filterfn, return_values=False): + fields = dataclasses.fields(cls_or_instance) + fields = [field for field in fields if filterfn(field)] + if return_values: + return { + field.name: getattr(cls_or_instance, field.name) for field in fields + } + else: + return fields + + +def get_array_fields(cls, return_values=False): + return get_fields( + cls, + lambda x: not x.metadata.get('is_metadata', False), + return_values=return_values, + ) + + +def get_metadata_fields(cls, return_values=False): + return get_fields( + cls, + lambda x: x.metadata.get('is_metadata', False), + return_values=return_values, + ) + + +def tree_flatten(pytree): + """Custom tree flattening function for MindSpore tensors.""" + if isinstance(pytree, ms.Tensor): + return [pytree], None + elif isinstance(pytree, dict): + keys, values = zip(*pytree.items()) + flat_values, treedefs = zip(*(tree_flatten(v) for v in values)) + return sum(flat_values, []), {'keys': keys, 'treedefs': treedefs} + elif isinstance(pytree, list): + flat_items, treedefs = zip(*(tree_flatten(item) for item in pytree)) + return sum(flat_items, []), {'treedefs': treedefs} + else: + return [], None + + +def tree_unflatten(treedef, leaves): + """Custom tree unflattening function for MindSpore tensors.""" + if treedef is None: + return leaves[0] + elif isinstance(treedef, dict): + if 'keys' in treedef: + keys = treedef['keys'] + treedefs = treedef['treedefs'] + items = [tree_unflatten(td, leaves[i:i+1]) + for i, td in enumerate(treedefs)] + return dict(zip(keys, items)) + else: + treedefs = treedef['treedefs'] + start = 0 + items = [] + for td in treedefs: + size = len(tree_flatten(tree_unflatten( + td, leaves[start:start+1]))[0]) + items.append(tree_unflatten(td, leaves[start:start+size])) + start += size + return items + else: + return [] + + +class StructOfArray: + """Class Decorator for Struct Of Arrays.""" + + def __init__(self, same_dtype=True): + self.same_dtype = same_dtype + + def __call__(self, cls): + cls.__array_ufunc__ = None + cls.replace = replace + cls.same_dtype = self.same_dtype + cls.dtype = get_dtype + cls.shape = get_shape + cls.__len__ = get_len + cls.__getitem__ = get_item + cls.__post_init__ = post_init + new_cls = dataclasses.dataclass(cls, frozen=True, eq=False) + # pytree claims to require metadata to be hashable, not sure why, + # But 
making derived dataclass that can just hold metadata + new_cls.metadata_cls = make_metadata_class(new_cls) + + def unflatten(cls, params): + aux, data = params + inner_treedefs, metadata, num_arrays = aux + array_fields = [field.name for field in get_array_fields(new_cls)] + value_dict = {} + array_start = 0 + for num_array, inner_treedef, array_field in zip( + num_arrays, inner_treedefs, array_fields, strict=True + ): + value_dict[array_field] = tree_unflatten( + inner_treedef, data[array_start: array_start + num_array] + ) + array_start += num_array + metadata_fields = get_metadata_fields(new_cls) + for field in metadata_fields: + value_dict[field.name] = getattr(metadata, field.name) + + return new_cls(**value_dict) + + # Override __flatten__ and __unflatten__ methods + new_cls.__flatten__ = flatten + new_cls.__unflatten__ = unflatten + + return new_cls diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/utils.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b332b5dea8827f3c01173150ce3c194e8fe3a289 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/utils.py @@ -0,0 +1,149 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Utils for geometry library.""" + +from collections.abc import Iterable +import numbers + +import mindspore as ms +import mindspore.ops as ops +import mindspore.numpy as mnp + + +def safe_select(condition, true_fn, false_fn): + """Safe version of selection (i.e. `where`). + + This applies the double-where trick. + Like jnp.where, this function will still execute both branches and is + expected to be more lightweight than lax.cond. Other than NaN-semantics, + safe_select(condition, true_fn, false_fn) is equivalent to + + utils.tree.map(lambda x, y: jnp.where(condition, x, y), + true_fn(), + false_fn()), + + Compared to the naive implementation above, safe_select provides the + following guarantee: in either the forward or backward pass, a NaN produced + *during the execution of true_fn()* will not propagate to the rest of the + computation and similarly for false_fn. It is very important to note that + while true_fn and false_fn will typically close over other tensors (i.e. they + use values computed prior to the safe_select function), there is no NaN-safety + for the backward pass of closed over values. It is important than any NaN's + are produced within the branch functions and not before them. For example, + + safe_select(x < eps, lambda: 0., lambda: jnp.sqrt(x)) + + will not produce NaN on the backward pass even if x == 0. since sqrt happens + within the false_fn, but the very similar + + y = jnp.sqrt(x) + safe_select(x < eps, lambda: 0., lambda: y) + + will produce a NaN on the backward pass if x == 0 because the sqrt happens + prior to the false_fn. 
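+
+    For instance (illustrative, restating the sqrt case above with the
+    MindSpore ops actually used in this module):
+
+      safe_select(x < eps,
+                  lambda: ops.zeros_like(x),
+                  lambda: ops.sqrt(x))
+
+    stays NaN-free on the backward pass even where `x == 0`, because the sqrt
+    is evaluated inside the branch function (see `safe_arctan2` below for the
+    same pattern).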
+ + Args: + condition: Boolean array to use in where + true_fn: Zero-argument function to construct the values used in the True + condition. Tensors that this function closes over will be extracted + automatically to implement the double-where trick to suppress spurious NaN + propagation. + false_fn: False branch equivalent of true_fn + + Returns: + Resulting PyTree equivalent to tree_map line above. + """ + true_result = true_fn() + false_result = false_fn() + + # Apply the double-where trick + true_part = ops.select(condition, true_result, + ops.stop_gradient(true_result)) + false_part = ops.select( + condition, ops.stop_gradient(false_result), false_result) + + return ops.select(condition, true_part, false_part) + + +def unstack(value: ms.Tensor, axis: int = -1) -> list[ms.Tensor]: + if len(value.shape) == 3: + if axis == -1: + split_tensors = [value[:, :, i] for i in range(value.shape[axis])] + elif axis == -2: + split_tensors = [value[:, i, :] for i in range(value.shape[axis])] + else: + split_tensors = [value[i, :, :] for i in range(value.shape[axis])] + elif len(value.shape) == 2: + if axis == -1: + split_tensors = [value[:, i] for i in range(value.shape[axis])] + else: + split_tensors = [value[i, :] for i in range(value.shape[axis])] + return split_tensors + + +def angdiff(alpha: ms.Tensor, beta: ms.Tensor) -> ms.Tensor: + """Compute absolute difference between two angles.""" + d = alpha - beta + d = (d + mnp.pi) % (2 * mnp.pi) - mnp.pi + return d + + +def safe_arctan2( + x1: ms.Tensor, x2: ms.Tensor, eps: float = 1e-8 +) -> ms.Tensor: + """Safe version of arctan2 that avoids NaN gradients when x1=x2=0.""" + + return safe_select( + ops.abs(x1) + ops.abs(x2) < eps, + lambda: ops.zeros_like(ops.atan2(x1, x2)), + lambda: ops.atan2(x1, x2), + ) + + +def weighted_mean( + *, + weights: ms.Tensor, + value: ms.Tensor, + axis: int | Iterable[int] | None=None, + eps: float = 1e-10, +) -> ms.Tensor: + """Computes weighted mean in a safe way that avoids NaNs. + + This is equivalent to jnp.average for the case eps=0.0, but adds a small + constant to the denominator of the weighted average to avoid NaNs. + 'weights' should be broadcastable to the shape of value. + + Args: + weights: Weights to weight value by. + value: Values to average + axis: Axes to average over. + eps: Epsilon to add to the denominator. + + Returns: + Weighted average. + """ + + weights = ops.cast(weights, value.dtype) + weights = ops.broadcast_to(weights, value.shape) + + weights_shape = weights.shape + + if isinstance(axis, numbers.Integral): + axis = [axis] + elif axis is None: + axis = list(range(len(weights_shape))) + + numerator = ops.reduce_sum(weights * value, axis=tuple(axis)) + denominator = ops.reduce_sum(weights, axis=tuple(axis)) + eps + + return numerator / denominator diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/vector.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/vector.py new file mode 100644 index 0000000000000000000000000000000000000000..76ec2bd73b654d85f7d064910d36076c9750e696 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/vector.py @@ -0,0 +1,255 @@ +# Copyright 2024 DeepMind Technologies Limited +# Copyright (C) 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+#
+# Modifications by Huawei Technologies Co., Ltd: Adapt to run by MindSpore on Ascend
+
+"""Vec3Array Class."""
+
+import dataclasses
+from typing import Final, TypeVar, TypeAlias
+
+import mindspore as ms
+from mindspore import ops, mint
+from alphafold3.utils.geometry import struct_of_array
+
+Self = TypeVar('Self', bound='Vec3Array')
+Float: TypeAlias = float | ms.Tensor
+VERSION: Final[str] = '0.1'
+
+
+def tree_map(func, *trees):
+    """
+    Recursively applies a function to each leaf of the input trees.
+
+    Args:
+        func: A function to apply to each leaf.
+        *trees: One or more tree structures (nested lists/tuples/dicts).
+
+    Returns:
+        A new tree with the same structure where `func` has been applied to each leaf.
+    """
+    if isinstance(trees[0], Vec3Array):
+        return Vec3Array(
+            x=tree_map(func, *(t.x for t in trees)),
+            y=tree_map(func, *(t.y for t in trees)),
+            z=tree_map(func, *(t.z for t in trees))
+        )
+    if isinstance(trees[0], dict):
+        return {key: tree_map(func, *(t[key] for t in trees)) for key in trees[0]}
+    if isinstance(trees[0], (list, tuple)):
+        return type(trees[0])(tree_map(func, *args) for args in zip(*trees))
+    return func(*trees)
+
+
+@struct_of_array.StructOfArray(same_dtype=True)
+class Vec3Array:
+    """Vec3Array in 3 dimensional Space implemented as struct of arrays.
+    This is done in order to improve performance and precision.
+ """ + + x: ms.Tensor = dataclasses.field(metadata={'dtype': ms.float32}) + y: ms.Tensor + z: ms.Tensor + + def __post_init__(self): + if hasattr(self.x, 'dtype'): + if not self.x.dtype == self.y.dtype == self.z.dtype: + raise ValueError( + f'Type mismatch: {self.x.dtype}, {self.y.dtype}, {self.z.dtype}' + ) + if not self.x.shape == self.y.shape == self.z.shape: + raise ValueError( + f'Shape mismatch: {self.x.shape}, {self.y.shape}, {self.z.shape}' + ) + + @property + def shape(self): + """Return the shape of the Vec3Array.""" + return self.x.shape + + def __add__(self, other: Self) -> Self: + return tree_map(ops.add, self, other) + + def __sub__(self, other: Self) -> Self: + return tree_map(ops.sub, self, other) + + def __mul__(self, other: Float | ms.Tensor) -> Self: + if isinstance(other, float): + return tree_map(lambda x: ops.mul(x, other), self) + x = ops.mul(self.x, other) + y = ops.mul(self.y, other) + z = ops.mul(self.z, other) + return Vec3Array(x, y, z) + + def __rmul__(self, other: Float | ms.Tensor) -> Self: + if isinstance(other, float): + return self * other + x = ops.mul(self.x, other) + y = ops.mul(self.y, other) + z = ops.mul(self.z, other) + return Vec3Array(x, y, z) + + def __truediv__(self, other: Float) -> Self: + return tree_map(lambda x: ops.div(x, other), self) + + def __neg__(self) -> Self: + return tree_map(lambda x: -x, self) + + def __pos__(self) -> Self: + return tree_map(lambda x: x, self) + + def cross(self, other: Self) -> Self: + """Compute cross product between 'self' and 'other'.""" + new_x = ops.sub(ops.mul(self.y, other.z), ops.mul(self.z, other.y)) + new_y = ops.sub(ops.mul(self.z, other.x), ops.mul(self.x, other.z)) + new_z = ops.sub(ops.mul(self.x, other.y), ops.mul(self.y, other.x)) + return Vec3Array(new_x, new_y, new_z) + + def dot(self, other: Self) -> ms.Tensor: + """Compute dot product between 'self' and 'other'.""" + return ops.add(ops.add(ops.mul(self.x, other.x), ops.mul(self.y, other.y)), ops.mul(self.z, other.z)) + + def norm(self, epsilon: float = 1e-6) -> ms.Tensor: + """Compute Norm of Vec3Array, clipped to epsilon.""" + # To avoid NaN on the backward pass, we must use maximum before the sqrt + norm2 = self.dot(self) + if epsilon: + norm2 = ops.maximum(norm2, epsilon**2) + return ops.sqrt(norm2) + + def norm2(self) -> ms.Tensor: + return self.dot(self) + + def normalized(self, epsilon: float = 1e-6) -> Self: + """Return unit vector with optional clipping.""" + return self / self.norm(epsilon) + + @classmethod + def zeros(cls, shape, dtype=ms.float32): + """Return Vec3Array corresponding to zeros of given shape.""" + return cls( + mint.zeros(shape, dtype=dtype), + mint.zeros(shape, dtype=dtype), + mint.zeros(shape, dtype=dtype), + ) + + def to_array(self) -> ms.Tensor: + return ops.stack([self.x, self.y, self.z], axis=-1) + + @classmethod + def from_array(cls, array): + unstacked = ops.unstack(array, axis=-1) + return cls(unstacked[0], unstacked[1], unstacked[2]) + + def __getstate__(self): + return ( + VERSION, + [self.x.asnumpy(), self.y.asnumpy(), self.z.asnumpy()], + ) + + def __setstate__(self, state): + version, state = state + del version + for i, letter in enumerate('xyz'): + object.__setattr__(self, letter, ms.Tensor(state[i])) + + +def square_euclidean_distance( + vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6 +) -> Float: + """Computes square of euclidean distance between 'vec1' and 'vec2'. 
+ + Args: + vec1: Vec3Array to compute distance to + vec2: Vec3Array to compute distance from, should be broadcast compatible + with 'vec1' + epsilon: distance is clipped from below to be at least epsilon + + Returns: + Array of square euclidean distances; + shape will be result of broadcasting 'vec1' and 'vec2' + """ + difference = vec1 - vec2 + distance = difference.dot(difference) + if epsilon: + distance = ops.maximum(distance, epsilon) + return distance + + +def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float: + return vector1.dot(vector2) + + +def cross(vector1: Vec3Array, vector2: Vec3Array) -> Float: + return vector1.cross(vector2) + + +def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float: + return vector.norm(epsilon) + + +def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array: + return vector.normalized(epsilon) + + +def euclidean_distance( + vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6 +) -> Float: + """Computes euclidean distance between 'vec1' and 'vec2'. + + Args: + vec1: Vec3Array to compute euclidean distance to + vec2: Vec3Array to compute euclidean distance from, should be broadcast + compatible with 'vec1' + epsilon: distance is clipped from below to be at least epsilon + + Returns: + Array of euclidean distances; + shape will be result of broadcasting 'vec1' and 'vec2' + """ + distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2) + distance = ops.sqrt(distance_sq) + return distance + + +def dihedral_angle( + a: Vec3Array, b: Vec3Array, c: Vec3Array, d: Vec3Array +) -> Float: + """Computes torsion angle for a quadruple of points. + + For points (a, b, c, d), this is the angle between the planes defined by + points (a, b, c) and (b, c, d). It is also known as the dihedral angle. + + Arguments: + a: A Vec3Array of coordinates. + b: A Vec3Array of coordinates. + c: A Vec3Array of coordinates. + d: A Vec3Array of coordinates. + + Returns: + A tensor of angles in radians: [-pi, pi]. + """ + v1 = a - b + v2 = b - c + v3 = d - c + + c1 = v1.cross(v2) + c2 = v3.cross(v2) + c3 = c2.cross(c1) + + v2_mag = v2.norm() + return ops.atan2(c3.dot(v2), v2_mag * c1.dot(c2)) + + +def random_gaussian_vector(shape, key=None, dtype=ms.float32) -> Vec3Array: + stdnormal = ops.StandardNormal(seed=key) + vec_array = stdnormal(shape + (3,)).astype(dtype) + return Vec3Array.from_array(vec_array) diff --git a/MindSPONGE/applications/research/medformer/README.md b/MindSPONGE/applications/research/medformer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ae6dee737d8ccd9073a6714fe1beff79735cfeea --- /dev/null +++ b/MindSPONGE/applications/research/medformer/README.md @@ -0,0 +1,42 @@ +# MedFormer: Transformer-based Drug Perturbation Prediction + +MedFormer is a drug perturbation prediction framework based on the Transformer architecture, designed to predict the transcriptional responses of small molecule drugs under different cellular states. By integrating drug molecular fingerprints, baseline transcriptional states, and gene embeddings, it achieves high-precision predictions for unseen drugs and cell types, and is scalable to single-cell data. + +This project is based on [MindSPONGE](https://gitee.com/mindspore/mindscience/tree/master/MindSPONGE) and implemented in Python. 
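+
+For a quick feel of the model interface defined in `module/MedFormer.py`, the sketch below runs a single forward pass; the gene count, batch size, and zero-filled inputs are illustrative placeholders only:
+
+```python
+import numpy as np
+import mindspore as ms
+
+from module.MedFormer import GenePertFormer
+
+num_genes = 978  # example value; use the size of your gene vocabulary
+model = GenePertFormer(gene_vocab_size=num_genes, drug_dim=1024, cell_dim=82,
+                       hidden_dim=256, n_layers=2, n_heads=1, use_cell_expr=False)
+
+gene_ids = ms.Tensor(np.arange(num_genes, dtype=np.int32)[None, :])   # [batch, genes]
+gene_expr = ms.Tensor(np.zeros((1, num_genes, 1), dtype=np.float32))  # baseline expression
+drug_fp = ms.Tensor(np.zeros((1, 1024), dtype=np.float32))            # Morgan fingerprint
+cell_feat = ms.Tensor(np.zeros((1, 82), dtype=np.float32))            # one-hot cell line
+
+pred_gene, cls_out, recon = model(gene_ids, gene_expr, drug_fp, cell_feat)
+print(pred_gene.shape)  # (1, 978): predicted post-perturbation expression per gene
+```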
+ +--- + +## 🔧 requirement + +- Python 3.8+ + +- mindspore >= 3.9.0 + +- numpy + +- pandas + +- scikit-learn + +- rdkit + +- tqdm + +--- + +## Quick start + +Raw data link: +https://zenodo.org/records/14230870 + +Essential data link: +https://pan.baidu.com/s/1AKJT6gvSf05PgYit6SPbYQ?pwd=f5iy + +Run: +`python train.py --split_key drug_split_0 --ablation False --device_id 0` + +`split_key` indicates which fold of the k-fold cross-validation should be used as the training set. + +`ablation` indicates whether an ablation experiment is to be conducted. + +`device_id` represents the ID of the computing card being used. It can be filled in according to the actual situation. By default, the idle computing card among all available ones will be selected automatically. \ No newline at end of file diff --git a/MindSPONGE/applications/research/medformer/__init__.py b/MindSPONGE/applications/research/medformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d00e2d8a6d1bdc5fe082ba097bbc44179b1908ce --- /dev/null +++ b/MindSPONGE/applications/research/medformer/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2025 Yuanhanyu Luo & Linchang Zhu +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/MindSPONGE/applications/research/medformer/module/MedFormer.py b/MindSPONGE/applications/research/medformer/module/MedFormer.py new file mode 100644 index 0000000000000000000000000000000000000000..5dc8026ce4a34e57355c6eb6c20fa514e00ad542 --- /dev/null +++ b/MindSPONGE/applications/research/medformer/module/MedFormer.py @@ -0,0 +1,110 @@ +# Copyright 2025 Yuanhanyu Luo & Linchang Zhu + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module defines the MedFormer model for gene expression prediction. +""" + +import mindspore as ms +from mindspore import nn +from mindspore import Parameter +from mindspore import ops + +class GenePertFormer(ms.nn.Cell): + """ + GenePertFormer model for gene expression prediction, combining gene, drug, and cell features. 
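+
+    Expected inputs of construct: gene_ids [B, G] (int), gene_expr [B, G, 1],
+    drug_fp [B, drug_dim], and cell_feat ([B, cell_input_dim, 1] when
+    use_cell_expr is True, otherwise [B, cell_dim]); it returns the tuple
+    (pred_gene, cls_out, recon).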
+ """ + def __init__(self, gene_vocab_size=23185, drug_dim=1024, cell_dim=82, + hidden_dim=256, n_layers=4, n_heads=1, dropout=0.1, + cell_input_dim=978, use_cell_expr=False): + super().__init__() + self.hidden_dim = hidden_dim + + # Embeddings + self.gene_embedding = nn.Embedding(gene_vocab_size, hidden_dim) + self.expr_embedding = nn.Dense(1, hidden_dim) + self.drug_embedding = nn.Dense(drug_dim, hidden_dim) + + self.use_cell_expr = use_cell_expr + if use_cell_expr: + self.cell_embedding = nn.SequentialCell([ + nn.Dense(cell_input_dim, 512), + nn.ReLU(), + nn.Dropout(0.1), + nn.Dense(512, hidden_dim) + ]) + else: + self.cell_embedding = nn.Dense(cell_dim, hidden_dim) + + # CLS Token & positional embedding + self.cls_token = Parameter(ops.StandardNormal()((1, 1, hidden_dim)), name='cls_token') + # Line too long fixed by splitting the long line + self.pos_embedding = \ + Parameter(ops.StandardNormal()((1, gene_vocab_size + 3, hidden_dim)), name='pos_embedding') + + # Transformer Encoder + encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=n_heads, + dim_feedforward=4 * hidden_dim, dropout=dropout, + batch_first=True) + self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) + + # Prediction Heads + self.to_gene_pred = nn.Dense(hidden_dim, 1) + self.cls_head = nn.Dense(hidden_dim, hidden_dim) + self.recon_head = nn.SequentialCell([ + nn.Dense(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Dense(hidden_dim, cell_input_dim) + ]) + + def construct(self, gene_ids, gene_expr, drug_fp, cell_feat, mask=None): + """ + Forward pass for the GenePertFormer model. + + Args: + gene_ids (Tensor): Tensor of gene IDs. + gene_expr (Tensor): Tensor of gene expressions. + drug_fp (Tensor): Tensor of drug fingerprints. + cell_feat (Tensor): Tensor of cell features. + mask (Tensor, optional): Mask for the transformer encoder. Defaults to None. + + Returns: + Tuple[Tensor, Tensor, Tensor]: Predicted gene expression, CLS token output, and reconstructed cell features. 
+ """ + batch_size, _ = gene_ids.shape # Renamed B to batch_size, G to _ (unused) + + id_embed = self.gene_embedding(gene_ids) # [batch_size, G, H] + expr_embed = self.expr_embedding(gene_expr) # [batch_size, G, H] + gene_embed = id_embed + expr_embed # [batch_size, G, H] + + drug_token = self.drug_embedding(drug_fp).expand_dims(1) # [batch_size, 1, H] + + if self.use_cell_expr: + cell_raw = ops.Squeeze(-1)(cell_feat) # [batch_size, G] + cell_embed = self.cell_embedding(cell_raw) # [batch_size, H] + else: + cell_embed = self.cell_embedding(cell_feat) # [batch_size, H] + cell_token = cell_embed.expand_dims(1) # [batch_size, 1, H] + + cls = ops.BroadcastTo((batch_size, 1, self.hidden_dim))(self.cls_token) + tokens = ops.Concat(axis=1)((cls, drug_token, cell_token, gene_embed)) + tokens = tokens + self.pos_embedding[:, :tokens.shape[1], :] + + x = self.encoder(tokens, src_key_padding_mask=mask) + + pred_gene = self.to_gene_pred(x[:, 3:, :]).squeeze(-1) # [batch_size, G] + cls_out = self.cls_head(x[:, 0, :]) + recon = self.recon_head(cell_embed) + + return pred_gene, cls_out, recon diff --git a/MindSPONGE/applications/research/medformer/requirements.txt b/MindSPONGE/applications/research/medformer/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9f9d297eb3c1f24cc0cbe20af25da1dea6187c82 --- /dev/null +++ b/MindSPONGE/applications/research/medformer/requirements.txt @@ -0,0 +1,7 @@ +python +mindspore==3.9.0 +numpy +pandas +scikit-learn +rdkit +tqdm diff --git a/MindSPONGE/applications/research/medformer/train.py b/MindSPONGE/applications/research/medformer/train.py new file mode 100644 index 0000000000000000000000000000000000000000..4198ed922b0a2b93d05433185ced02e8c45d5421 --- /dev/null +++ b/MindSPONGE/applications/research/medformer/train.py @@ -0,0 +1,324 @@ +# Copyright 2025 Yuanhanyu Luo & Linchang Zhu + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module implements the training and evaluation pipeline for the GenePertFormer model. +It handles data loading, preprocessing, model definition, training loop, and result visualization. 
+""" +import argparse +import os +import json +import logging +from datetime import datetime + +import numpy as np +import scanpy as sc +import matplotlib.pyplot as plt +import seaborn as sns + +import mindspore as ms +from mindspore import nn, context +import mindspore.dataset as ds + +# Grouped scipy imports +from scipy import sparse +from scipy.stats import pearsonr +from sklearn.metrics import r2_score + +import wandb + +from rdkit import Chem +from rdkit.Chem import AllChem + +from module.MedFormer import GenePertFormer + + +## Execution Mode +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +# Configure logging +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) +logger.addHandler(handler) + + +# utils +def parse_args(): + """ + Parses command line arguments for the perturbation model. + + Returns: + argparse.Namespace: An object containing the parsed arguments. + """ + parser = argparse.ArgumentParser(description="MindSpore version of perturbation model") + parser.add_argument("--split_key", default="drug_split_0", type=str) + parser.add_argument("--ablation", default=None, type=str) + return parser.parse_args() + + +def shuffle_adata(adata_obj): + """ + Shuffles the AnnData object in place. + + Args: + adata_obj (anndata.AnnData): The AnnData object to be shuffled. + + Returns: + anndata.AnnData: The shuffled AnnData object. + """ + if sparse.issparse(adata_obj.X): + adata_obj.X = adata_obj.X.A + perm = np.random.permutation(adata_obj.shape[0]) + return adata_obj[perm, :] + + +def train_valid_test_split(adata_obj, split_key): + """ + Splits the AnnData object into training, validation, and test sets, + including control samples in all sets. + + Args: + adata_obj (anndata.AnnData): The AnnData object containing the data. + split_key (str): The observation key used for splitting (e.g., "drug_split_0"). + + Returns: + Tuple[anndata.AnnData, anndata.AnnData, anndata.AnnData]: + Train, validation, and test AnnData objects. + """ + shuffled = shuffle_adata(adata_obj) + adata_ctrl0 = adata_obj[adata_obj.obs["control"] == 0] + train_idx = adata_ctrl0.obs[adata_ctrl0.obs[split_key] == "train"].index.tolist() + valid_idx = adata_ctrl0.obs[adata_ctrl0.obs[split_key] == "valid"].index.tolist() + test_idx = adata_ctrl0.obs[adata_ctrl0.obs[split_key] == "test"].index.tolist() + ctrl_idx = adata_obj.obs[adata_obj.obs["control"] == 1].index.tolist() + + def subset(idx_list): + return shuffled[idx_list + ctrl_idx] + return subset(train_idx), subset(valid_idx), subset(test_idx) + + +# ------------ Dataset & DataLoader ------------ +def drug_smiles_encode(drug_smiles_list: list, num_bits=1024): + """ + Encodes a list of drug SMILES strings into Morgan fingerprints. + + Args: + drug_smiles_list (list): A list of SMILES strings. + num_bits (int): The number of bits for the Morgan fingerprint. + + Returns: + numpy.ndarray: A NumPy array of drug fingerprints. + + Raises: + ValueError: If an invalid SMILES string is encountered. 
+ """ + arr = np.zeros((len(drug_smiles_list), num_bits), dtype=np.float32) + for i, smiles in enumerate(drug_smiles_list): + mol = Chem.MolFromSmiles(smiles) + if mol is None: + raise ValueError("Invalid SMILES") # Changed to lazy formatting + bits = AllChem.GetMorganFingerprintAsBitVect(mol, 2, useFeatures=True, nBits=num_bits).ToBitString() + arr[i] = np.array(list(bits), dtype=np.float32) + return arr + + +class GenePertAnnDatasetMS: + """ + A MindSpore Dataset for GenePert data, wrapping an AnnData object. + """ + def __init__(self, adata_obj, gene2id_path, control_key="condition", smiles_key="SMILES", cell_key="cell_id"): + """ + Initializes the GenePertAnnDatasetMS. + + Args: + adata_obj (anndata.AnnData): The AnnData object containing the data. + gene2id_path (str): Path to the JSON file mapping gene names to IDs. + control_key (str): Observation key for control samples. + smiles_key (str): Observation key for drug SMILES strings. + cell_key (str): Observation key for cell IDs. + """ + self.adata = adata_obj + self.control_key = control_key + self.smiles_key = smiles_key + self.cell_key = cell_key + self.drug_dim = 1024 + + with open(gene2id_path, "r", encoding='utf-8') as f: + self.gene2id = json.load(f) + self.idx = np.where(self.adata.obs[self.control_key] != "control")[0] + self.gene_order = self.adata.var_names.tolist() + cells = self.adata.obs[self.cell_key].unique().tolist() + self.cell2id = {cid: i for i, cid in enumerate(sorted(cells))} + + def __len__(self): + """Returns the number of samples in the dataset.""" + return len(self.idx) + + def __getitem__(self, idx): + """ + Retrieves a single sample from the dataset. + + Args: + idx (int): Index of the sample to retrieve. + + Returns: + Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]: + gene_ids, treated gene expression, drug fingerprint, cell one-hot encoding, control gene expression. + """ + si = self.idx[idx] + sample = self.adata[si] + treat = sample.X.toarray().flatten() + ctrl_id = sample.obs["paired_control_index"].values[0] + ctrl_i = self.adata.obs_names.get_loc(ctrl_id) + ctrl = self.adata[ctrl_i].X.toarray().flatten() + gene_ids = np.array([self.gene2id[g] for g in self.gene_order], dtype=np.int32) + treat = treat.astype(np.float32).reshape(-1, 1) + ctrl = ctrl.astype(np.float32) + smiles = sample.obs[self.smiles_key].values[0] + drug_fp = drug_smiles_encode([smiles], self.drug_dim)[0] + cid = sample.obs[self.cell_key].values[0] + cell_onehot = np.eye(len(self.cell2id), dtype=np.float32)[self.cell2id[cid]] + return gene_ids, treat, drug_fp, cell_onehot, ctrl + + +def build_ms_dataset(adata_obj, gene2id_path, batch_size=64, shuffle_data=True): + """ + Builds a MindSpore GeneratorDataset from an AnnData object. + + Args: + adata_obj (anndata.AnnData): The AnnData object to build the dataset from. + gene2id_path (str): Path to the JSON file mapping gene names to IDs. + batch_size (int): Batch size for the dataset. + shuffle_data (bool): Whether to shuffle the dataset. + + Returns: + mindspore.dataset.engine.datasets.GeneratorDataset: The built MindSpore dataset. 
+ """ + ds_src = GenePertAnnDatasetMS(adata_obj, gene2id_path) + ms_ds = ds.GeneratorDataset(ds_src, + ["gene_ids", "gene_expr", "drug_fp", "cell_feat", "control_expr"], + shuffle=shuffle_data) + ms_ds = ms_ds.batch(batch_size, drop_remainder=True) + return ms_ds + + +# ------------ training process ------------ +args = parse_args() +original_adata = sc.read_h5ad("./Lincs_L1000.h5ad") +train_data, valid_data, test_data = train_valid_test_split(original_adata, args.split_key) + +timestamp = datetime.now().strftime("%Y%m%d_%H%M") +save_dir = f"./MSmodel_{args.split_key}_{timestamp}" +os.makedirs(save_dir, exist_ok=True) + +train_ds = build_ms_dataset(train_data, "./data/gene2id.json", batch_size=64, shuffle_data=True) +valid_ds = build_ms_dataset(valid_data, "./data/gene2id.json", batch_size=64, shuffle_data=False) +test_ds = build_ms_dataset(test_data, "./data/gene2id.json", batch_size=64, shuffle_data=False) + +# define model +model = GenePertFormer(drug_dim=1024, cell_dim=82, hidden_dim=256, use_cell_expr=True, cell_input_dim=978) + +# loss +loss_fn = nn.MSELoss() +optimizer = nn.Adam(model.trainable_params(), learning_rate=0.0005) + +net = nn.WithLossCell(model, loss_fn) +train_net = nn.TrainOneStepCell(net, optimizer) +train_net.set_train() + +# wandb init +wandb.init(project="GenePertFormerMS", name=f"ms_{args.split_key}_{datetime.now().strftime('%Y%m%d_%H%M')}") + +best_val = 1e9 +patience, counter = 5, 0 +train_losses, val_losses = [], [] + +for ep in range(100): + total, count = 0.0, 0 + for b in train_ds.create_tuple_iterator(): + _, loss = train_net(*b) + total += loss.asnumpy() + count += 1 + avg_train = total / count + train_losses.append(avg_train) + + # validation + model.set_train(False) + total, count = 0.0, 0 + for b in valid_ds.create_tuple_iterator(): + out = model(*b[:-1]) + v_loss = loss_fn(out[0], b[1]) + total += v_loss.asnumpy() + count += 1 + avg_val = total / count + val_losses.append(avg_val) + model.set_train(True) + + wandb.log({"epoch": ep, "train_loss": avg_train, "val_loss": avg_val}) + if avg_val < best_val: + best_val = avg_val + counter = 0 + ms.save_checkpoint(model, os.path.join(save_dir, "best.ckpt")) + else: + counter += 1 + if counter >= patience: + break + +# === results === +pred_list, true_list = [], [] + +for batch in test_ds.create_tuple_iterator(): + pred_batch, _, _ = model(*batch[:-1]) + true_batch = batch[1].asnumpy() + pred_list.append(pred_batch.asnumpy()) + true_list.append(true_batch) + +pred_array = np.vstack(pred_list) +true_array = np.vstack(true_list) + +r2 = np.nanmean([r2_score(t, p) for t, p in zip(true_array, pred_array)]) +pcc = np.nanmean([pearsonr(t, p)[0] for t, p in zip(true_array, pred_array)]) +logger.info("Test R²: %.4f, Pearson: %.4f", r2, pcc) # Changed to lazy formatting + +np.savez(os.path.join(save_dir, f"{args.split_key}_test_result_ms.npz"), + pred=pred_array, + true=true_array, + r2=r2, pcc=pcc) + +wandb.log({"test_r2": r2, "test_pearson": pcc}) + +# === visualization === +flat_true = true_array.flatten() +flat_pred = pred_array.flatten() +mask = ~np.isnan(flat_true) & ~np.isnan(flat_pred) +flat_true, flat_pred = flat_true[mask], flat_pred[mask] + +sns.set_theme(style="ticks") +fig, ax = plt.subplots(figsize=(6, 6)) +ax.scatter(flat_true, flat_pred, alpha=0.4, s=3, color='steelblue') +ax.plot([flat_true.min(), flat_true.max()], [flat_true.min(), flat_true.max()], + 'r--', linewidth=1.2) +ax.set_xlabel("True Expression", fontsize=12) +ax.set_ylabel("Predicted Expression", fontsize=12) +ax.set_title(f"True vs 
Predicted (MS)\nR² = {r2:.3f}, PCC = {pcc:.3f}", fontsize=13) +sns.despine() +fig.tight_layout() + +fig_path = os.path.join(save_dir, f"{args.split_key}_true_vs_pred_ms.png") +fig.savefig(fig_path, dpi=300) +plt.close(fig) +wandb.log({"true_vs_pred_scatter_ms": wandb.Image(fig_path)}) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/esm_if1/module/features.py b/MindSPONGE/src/mindsponge/pipeline/models/esm_if1/module/features.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb79ddd1c40ed54369d1f56aa1a47bb9db10b19 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/esm_if1/module/features.py @@ -0,0 +1,373 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Feature extraction""" + +import math +import numpy as np +import mindspore as ms +import mindspore.ops as ops +import mindspore.nn as nn +from mindspore import context +# pylint: disable=relative-beyond-top-level +from .basic_modules import GVP, LayerNorm, Dense +from .util import normalize, norm, nan_to_num, rbf, flatten_graph, ms_transpose, ms_padding_without_val + + +class GVPInputFeaturizer(nn.Cell): + """Input feature extraction for GVP""" + + @staticmethod + def get_node_features(coords, coord_mask, with_coord_mask=True): + """Get node features""" + node_scalar_features = GVPInputFeaturizer._dihedrals(coords) + if with_coord_mask: + coord_mask = ops.ExpandDims()(ops.Cast()(coord_mask, ms.float32), -1) + node_scalar_features = ops.Concat(axis=-1)([node_scalar_features, coord_mask]) + x_ca = coords[:, :, 1] + orientations = GVPInputFeaturizer._orientations(x_ca) + sidechains = GVPInputFeaturizer._sidechains(coords) + node_vector_features = ops.Concat(axis=-2)([orientations, ops.ExpandDims()(sidechains, -2)]) + return node_scalar_features, node_vector_features + + @staticmethod + def _orientations(x): + + forward = normalize(x[:, 1:] - x[:, :-1]) + backward = normalize(x[:, :-1] - x[:, 1:]) + forward = ops.concat((forward, ops.Zeros()((forward.shape[0], 1, forward.shape[2]), ms.float32)), 1) + backward = ops.concat((ops.Zeros()((backward.shape[0], 1, backward.shape[2]), ms.float32), backward), 1) + + output = ops.Concat(axis=-2)([ops.ExpandDims()(forward, -2), ops.ExpandDims()(backward, -2)]) + return output + + @staticmethod + def _sidechains(x): + n, origin, c = x[:, :, 0], x[:, :, 1], x[:, :, 2] + c, n = normalize(c - origin), normalize(n - origin) + bisector = normalize(c + n) + perp = normalize(ms.numpy.cross(c, n)) + vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3) + return vec + + @staticmethod + def _dihedrals(x, eps=1e-7): + """Dihedron""" + + y = x[:, :, :3].reshape((x.shape[0], (x.shape[1] * x.shape[2]), x.shape[3])) + bsz = x.shape[0] + dx = y[:, 1:] - y[:, :-1] + u = normalize(dx, dim=-1) + u_2 = u[:, :-2] + u_1 = u[:, 1:-1] + u_0 = u[:, 2:] + + # Backbone normals + n_2 = normalize(ms.numpy.cross(u_2, u_1), dim=-1) + n_1 = 
normalize(ms.numpy.cross(u_1, u_0), dim=-1) + + # Angle between normals + cosd = ops.ReduceSum()(n_2 * n_1, -1) + + min_value = ms.Tensor((-1 + eps), ms.float32) + max_value = ms.Tensor((1 - eps), ms.float32) + cosd = ops.clip_by_value(cosd, clip_value_min=min_value, clip_value_max=max_value) + d = ops.Sign()((u_2 * n_1).sum(-1)) * ops.ACos()(cosd) + + # This scheme will remove phi[0], psi[-1], omega[-1] + d = ms_padding_without_val(d, [1, 2]) + d = ops.Reshape()(d, (bsz, -1, 3)) + # Lift angle representations to the circle + d_features = ops.Concat(axis=-1)([ops.Cos()(d), ops.Sin()(d)]) + return d_features + + @staticmethod + def _positional_embeddings(edge_index, + num_embeddings=None, + num_positional_embeddings=16): + """Positional embeddings""" + + num_embeddings = num_embeddings or num_positional_embeddings or [] + d = edge_index[0] - edge_index[1] + + frequency = ops.Exp()( + ms.numpy.arange(0, num_embeddings, 2, dtype=ms.float32) + * -(np.log(10000.0) / num_embeddings) + ) + angles = ops.ExpandDims()(d, -1) * frequency + e = ops.Concat(-1)((ops.Cos()(angles), ops.Sin()(angles))) + return e + + @staticmethod + def _dist(x, coord_mask, padding_mask, top_k_neighbors): + """ Pairwise euclidean distances """ + bsz, maxlen = x.shape[0], x.shape[1] + coord_mask = ops.Cast()(coord_mask, ms.float32) + coord_mask_2d = ops.ExpandDims()(coord_mask, 1) * ops.ExpandDims()(coord_mask, 2) + residue_mask = ~padding_mask + residue_mask = ops.Cast()(residue_mask, ms.float32) + residue_mask_2d = ops.ExpandDims()(residue_mask, 1) * ops.ExpandDims()(residue_mask, 2) + dx = ops.ExpandDims()(x, 1) - ops.ExpandDims()(x, 2) + d = coord_mask_2d * norm(dx, dim=-1) + + # sorting preference: first those with coords, then among the residues that + # exist but are masked use distance in sequence as tie breaker, and then the + # residues that came from padding are last + seqpos = ms.numpy.arange(maxlen) + seqpos_1 = ops.ExpandDims()(seqpos, 1) + seqpos_0 = ops.ExpandDims()(seqpos, 0) + d_seq = ops.Abs()(seqpos_1 - seqpos_0) + if bsz != 1: + d_seq = ms.numpy.tile(d_seq, (bsz, 1, 1)) + coord_mask_2d = ops.Cast()(coord_mask_2d, ms.bool_) + residue_mask_2d = ops.Cast()(residue_mask_2d, ms.bool_) + verse_coord_mask_2d = ops.Cast()(~coord_mask_2d, ms.float32) + verse_residue_mask_2d = ops.Cast()(~residue_mask_2d, ms.float32) + d_adjust = nan_to_num(d) + (verse_coord_mask_2d) * (1e8 + d_seq * 1e6) + ( + verse_residue_mask_2d) * (1e10) + + if top_k_neighbors == -1: + d_neighbors = d_adjust / 1e4 + e_idx = seqpos.repeat( + *d_neighbors.shape[:-1], 1) + else: + d_adjust = d_adjust / 1e4 + if context.get_context("device_target") == "GPU": + d_neighbors, e_idx = ops.Sort(axis=-1, descending=True)(d_adjust) + else: + d_neighbors, e_idx = ops.TopK(sorted=True)(d_adjust, d_adjust.shape[-1]) + d_neighbors, e_idx = ms.mint.flip(d_neighbors, [-1]), ms.mint.flip(e_idx, [-1]) + d_neighbors, e_idx = d_neighbors[:, :, 0:int(min(top_k_neighbors, x.shape[1]))], \ + e_idx[:, :, 0:int(min(top_k_neighbors, x.shape[1]))] + d_neighbors = ms.Tensor(d_neighbors, ms.float32)*1e4 + coord_mask_neighbors = (d_neighbors < 5e7) + residue_mask_neighbors = (d_neighbors < 5e9) + output = [d_neighbors, e_idx, coord_mask_neighbors, residue_mask_neighbors] + return output + + +class Normalize(nn.Cell): + """Normalization""" + + def __init__(self, features, epsilon=1e-6): + super(Normalize, self).__init__() + self.gain = ms.Parameter(ops.Ones()(features, ms.float32)) + self.bias = ms.Parameter(ops.Zeros()(features, ms.float32)) + self.epsilon = epsilon + + 
def construct(self, x, dim=-1): + """Normalization construction""" + + mu = x.mean(dim, keep_dims=True) + sigma = ops.Sqrt()(x.var(dim, keepdims=True) + self.epsilon) + gain = self.gain + bias = self.bias + # Reshape + if dim != -1: + shape = [1] * len(mu.size()) + shape[dim] = self.gain.size()[0] + gain = gain.view(shape) + bias = bias.view(shape) + return gain * (x - mu) / (sigma + self.epsilon) + bias + + +class DihedralFeatures(nn.Cell): + """Dihedral features""" + + def __init__(self, node_embed_dim): + """ Embed dihedral angle features. """ + super(DihedralFeatures, self).__init__() + # 3 dihedral angles; sin and cos of each angle + node_in = 6 + # Normalization and embedding + self.node_embedding = Dense(node_in, node_embed_dim, has_bias=True) + self.norm_nodes = Normalize(node_embed_dim) + + @staticmethod + def _dihedrals(x, eps=1e-7, return_angles=False): + """Dihedron in DihedralFeatures""" + + # First 3 coordinates are N, CA, C + x = x[:, :, :3, :].reshape(x.shape[0], 3 * x.shape[1], 3) + + # Shifted slices of unit vectors + dx = x[:, 1:, :] - x[:, :-1, :] + u = ops.L2Normalize(axis=-1)(dx) + u_2 = u[:, :-2, :] + u_1 = u[:, 1:-1, :] + u_0 = u[:, 2:, :] + # Backbone normals + n_2 = ops.L2Normalize(axis=-1)(ms.numpy.cross(u_2, u_1)) + n_1 = ops.L2Normalize(axis=-1)(ms.numpy.cross(u_1, u_0)) + + # Angle between normals + cosd = (n_2 * n_1).sum(-1) + min_value = ms.Tensor((-1 + eps), ms.float32) + max_value = ms.Tensor((1 - eps), ms.float32) + cosd = ops.clip_by_value(cosd, clip_value_min=min_value, clip_value_max=max_value) + d = ops.Sign()((u_2 * n_1).sum(-1)) * ops.ACos()(cosd) + + # This scheme will remove phi[0], psi[-1], omega[-1] + d = ms_padding_without_val(d, [1, 2]) + d = d.view((d.shape[0], int(d.shape[1] / 3), 3)) + phi, psi, omega = ops.Unstack(axis=-1)(d) + + if return_angles: + return phi, psi, omega + + # Lift angle representations to the circle + d_features = ops.Concat(axis=2)((ops.Cos()(d), ops.Sin()(d))) + return d_features + + def construct(self, x): + """ Featurize coordinates as an attributed graph """ + v = self._dihedrals(x) + v = self.node_embedding(v) + v = self.norm_nodes(v) + return v + + +class GVPGraphEmbedding(GVPInputFeaturizer): + """GVP graph embedding""" + + def __init__(self, args): + super().__init__() + self.top_k_neighbors = args.top_k_neighbors + self.num_positional_embeddings = 16 + self.remove_edges_without_coords = True + node_input_dim = (7, 3) + edge_input_dim = (34, 1) + node_hidden_dim = (args.node_hidden_dim_scalar, + args.node_hidden_dim_vector) + edge_hidden_dim = (args.edge_hidden_dim_scalar, + args.edge_hidden_dim_vector) + self.embed_node = nn.SequentialCell( + [GVP(node_input_dim, node_hidden_dim, activations=(None, None)), + LayerNorm(node_hidden_dim, eps=1e-4)] + ) + self.embed_edge = nn.SequentialCell( + [GVP(edge_input_dim, edge_hidden_dim, activations=(None, None)), + LayerNorm(edge_hidden_dim, eps=1e-4)] + ) + self.embed_confidence = Dense(16, args.node_hidden_dim_scalar) + + def construct(self, coords, coord_mask, padding_mask, confidence): + """GVP graph embedding construction""" + + node_features = self.get_node_features(coords, coord_mask) + + edge_features, edge_index = self.get_edge_features( + coords, coord_mask, padding_mask) + node_embeddings_scalar, node_embeddings_vector = self.embed_node(node_features) + edge_embeddings = self.embed_edge(edge_features) + + rbf_rep = rbf(confidence, 0., 1.) 
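+        # Per-residue confidence in [0, 1] is expanded with an RBF encoding and
+        # its projection is added to the scalar channel of the node embeddings.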
+ + node_embeddings = ( + node_embeddings_scalar + self.embed_confidence(rbf_rep), + node_embeddings_vector + ) + + + node_embeddings, edge_embeddings, edge_index = flatten_graph( + node_embeddings, edge_embeddings, edge_index) + return node_embeddings, edge_embeddings, edge_index + + def get_edge_features(self, coords, coord_mask, padding_mask): + """Get edge features""" + + x_ca = coords[:, :, 1] + + # Get distances to the top k neighbors + e_dist, e_idx, e_coord_mask, e_residue_mask = GVPInputFeaturizer._dist( + x_ca, coord_mask, padding_mask, self.top_k_neighbors) + # Flatten the graph to be batch size 1 for torch_geometric package + dest = e_idx + e_idx_b, e_idx_l, k = e_idx.shape[:3] + + src = ms.numpy.arange(e_idx_l).view((1, e_idx_l, 1)) + src = ops.BroadcastTo((e_idx_b, e_idx_l, k))(src) + + + edge_index = ops.Stack(axis=0)([src, dest]) + + edge_index = edge_index.reshape((edge_index.shape[0], edge_index.shape[1], + (edge_index.shape[2] * edge_index.shape[3]))) + + # After flattening, [B, E] + e_dist = e_dist.reshape((e_dist.shape[0], (e_dist.shape[1] * e_dist.shape[2]))) + + e_coord_mask = e_coord_mask.reshape((e_coord_mask.shape[0], (e_coord_mask.shape[1] * e_coord_mask.shape[2]))) + e_coord_mask = ops.ExpandDims()(e_coord_mask, -1) + e_residue_mask = e_residue_mask.reshape((e_residue_mask.shape[0], + (e_residue_mask.shape[1] * e_residue_mask.shape[2]))) + + # Calculate relative positional embeddings and distance RBF + pos_embeddings = GVPInputFeaturizer._positional_embeddings( + edge_index, + num_positional_embeddings=self.num_positional_embeddings, + ) + d_rbf = rbf(e_dist, 0., 20.) + + # Calculate relative orientation + x_src = ops.ExpandDims()(x_ca, 2) + x_src = ops.BroadcastTo((-1, -1, k, -1))(x_src) + x_src = x_src.reshape((x_src.shape[0], (x_src.shape[1] * x_src.shape[2]), x_src.shape[3])) + + a = ops.ExpandDims()(edge_index[1, :, :], -1) + a = ops.BroadcastTo((e_idx_b, e_idx_l * k, 3))(a) + x_dest = ops.GatherD()( + x_ca, + 1, + a + ) + coord_mask_src = ops.ExpandDims()(coord_mask, 2) + coord_mask_src = ops.BroadcastTo((-1, -1, k))(coord_mask_src) + coord_mask_src = coord_mask_src.reshape((coord_mask_src.shape[0], + (coord_mask_src.shape[1] * coord_mask_src.shape[2]))) + + b = ops.BroadcastTo((e_idx_b, e_idx_l * k))(edge_index[1, :, :]) + + coord_mask_dest = ops.GatherD()( + coord_mask, + 1, + b + ) + e_vectors = x_src - x_dest + # For the ones without coordinates, substitute in the average vector + e_coord_mask_fl = ops.Cast()(e_coord_mask, ms.float32) + e_vector_mean_top = ops.ReduceSum(keep_dims=True)(e_vectors * e_coord_mask_fl, axis=1) + e_vector_mean_bottom = ops.ReduceSum(keep_dims=True)(e_coord_mask_fl, axis=1) + e_vector_mean = e_vector_mean_top / e_vector_mean_bottom + e_vectors_factor1 = e_vectors * e_coord_mask_fl + e_vectors_factor2 = e_vector_mean * ~(e_coord_mask) + e_vectors = e_vectors_factor1 + e_vectors_factor2 + # Normalize and remove nans + edge_s = ops.Concat(axis=-1)([d_rbf, pos_embeddings]) + edge_v = ops.ExpandDims()(normalize(e_vectors), -2) + edge_s, edge_v = map(nan_to_num, (edge_s, edge_v)) + # Also add indications of whether the coordinates are present + + edge_s = ops.Concat(axis=-1)([ + edge_s, + ops.ExpandDims()((~coord_mask_src).astype(np.float32), -1), + ops.ExpandDims()((~coord_mask_dest).astype(np.float32), -1)]) + e_residue_mask = ops.Cast()(e_residue_mask, ms.bool_) + fill_value = ms.Tensor(-1, dtype=edge_index.dtype) + edge_index = edge_index.masked_fill(~e_residue_mask, fill_value) + + if self.remove_edges_without_coords: + 
edge_index = ops.masked_fill(edge_index, ~e_coord_mask.squeeze(-1), fill_value) + + return (edge_s, edge_v), ms_transpose(edge_index, 0, 1) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/ufold/ufold.py b/MindSPONGE/src/mindsponge/pipeline/models/ufold/ufold.py new file mode 100644 index 0000000000000000000000000000000000000000..742e61dbf385fc00a39191aecf49b07ea3f56738 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/ufold/ufold.py @@ -0,0 +1,162 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ufold""" +import mindspore as ms +from mindspore import jit, nn +from mindspore import ops +from mindspore import Tensor, context +from mindspore.nn import TrainOneStepCell +from mindspore.common import dtype as mstype +from mindspore.common import mutable + +from ..model import Model +from .nn_arch import Unet as FCNNet + +sign = ops.Sign() + + +def evaluate_exact_new(pred_a, true_a, eps=1e-11): + """get pred, recall and f1_score""" + tp_map = sign(ms.Tensor(pred_a) * ms.Tensor(true_a)) + tp = tp_map.sum() + pred_p = sign(ms.Tensor(pred_a)).sum() + true_p = true_a.sum() + fp = pred_p - tp + fn = true_p - tp + recall = (tp + eps)/(tp + fn + eps) + precision = (tp + eps)/(tp + fp + eps) + f1_score_ms = (2 * tp + eps)/(2 * tp + fp + fn + eps) + return precision, recall, f1_score_ms + + +class MyWithLossCell(nn.Cell): + def __init__(self, network, loss_fn): + super(MyWithLossCell, self).__init__(auto_prefix=False) + self.network = network + self.loss_fn = loss_fn + + def construct(self, x, y, label): + out = self.network(x) + return self.loss_fn(out*y, label) + + +class UFold(Model): + """UFold""" + def __init__(self, config): + self.config = config + self.use_jit = self.config.use_jit + if context.get_context("device_target") == "GPU": + self.mixed_precision = False + else: + self.mixed_precision = True + self.dataset_ckpt_name = { + 'ArchiveII': 'ufold_train', + 'bpnew': 'ufold_train', + 'TS0': 'ufold_train', + 'TS1': 'ufold_train_pdbfinetune', + 'TS2': 'ufold_train_pdbfinetune', + 'TS3': 'ufold_train_pdbfinetune', + 'All': 'ufold_train_99', + } + + self.checkpoint_urls = { + 'ufold_train': 'https://download.mindspore.cn/mindscience/mindsponge/ufold/checkpoint/ufold_train.ckpt', + 'ufold_train_pdbfinetune': + 'https://download.mindspore.cn/mindscience/mindsponge/ufold/checkpoint/ufold_train_pdbfinetune.ckpt', + 'ufold_train_99': + 'https://download.mindspore.cn/mindscience/mindsponge/ufold/checkpoint/ufold_train_99.ckpt' + } + + self.ckpt_name = self.dataset_ckpt_name.get(self.config.test_ckpt) + self.checkpoint_url = self.checkpoint_urls.get(self.ckpt_name) + 
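+        # Local cache path for the selected pretrained checkpoint.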
self.checkpoint_path = "./" + self.ckpt_name + ".ckpt" + self.result_no_train = [] + self.cast = ops.Cast() + self.zeroslike = ops.ZerosLike() + self.network = FCNNet(img_ch=17) + self.pos_weight = ms.Tensor([300], mstype.float32) + self.criterion_bce_weighted = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight) + self.u_optimizer = nn.Adam(params=self.network.trainable_params(), learning_rate=1e-4) + self.loss_net = MyWithLossCell(self.network, self.criterion_bce_weighted) + self.train_net = TrainOneStepCell(self.loss_net, self.u_optimizer) + if self.config.is_training: + self.train_net.set_train() + super().__init__(self.checkpoint_url, self.checkpoint_path, self.network, + mixed_precision=self.mixed_precision) + + + def forward(self, data): + if self.use_jit: + mse = self._jit_forward(data) + else: + mse = self._pynative_forward(data) + return mse + + # pylint: disable=arguments-differ + def predict(self, data): + pred_contacts = [] + for d in data: + _, seq_embeddings, _, _, _, _, _, _ = d + seq_embeddings = ms.Tensor(seq_embeddings).unsqueeze(0) + seq_embedding_batch = ms.Tensor(ops.Cast()(seq_embeddings, mstype.float32)) + pred_contacts.append(self.forward(seq_embedding_batch)) + return pred_contacts + + + def loss(self, data): + pass + + + def grad_operations(self, gradient): + pass + + + @jit + def backward(self, data): + loss = self.train_net(*data) + return loss + + + def train_step(self, data): + contacts, seq_embeddings, _, seq_lens, _, _ = data.values() + contacts_batch = Tensor(ops.Cast()(contacts, mstype.float32)) + seq_embedding_batch = Tensor(ops.Cast()(seq_embeddings, mstype.float32)) + pred_contacts = self.network(seq_embedding_batch) + contact_masks = ops.ZerosLike()(pred_contacts) + contact_masks[:, :seq_lens[0].item(), :seq_lens[0].item()] = 1 + contact_masks = contact_masks.astype(ms.float32) + feat = [seq_embedding_batch, contact_masks, contacts_batch] + feat = mutable(feat) + loss = self.backward(feat) + return loss + + + @jit + def _jit_forward(self, data): + mse = self.network(data) + return mse + + + def _pynative_forward(self, data): + mse = self.network(data) + return mse diff --git a/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.AttentionInteractionNetwork.rst b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.AttentionInteractionNetwork.rst new file mode 100644 index 0000000000000000000000000000000000000000..7778f3a537968f0a3a629102da109eae1e335619 --- /dev/null +++ b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.AttentionInteractionNetwork.rst @@ -0,0 +1,32 @@ +mindchemistry.cell.orb.AttentionInteractionNetwork +================================================== + +.. 
py:class:: mindchemistry.cell.orb.AttentionInteractionNetwork(num_node_in: int, num_node_out: int, num_edge_in: int, num_edge_out: int, num_mlp_layers: int, mlp_hidden_dim: int, attention_gate: str = "sigmoid", distance_cutoff: bool = True, polynomial_order: int = 4, cutoff_rmax: float = 6.0) + + 注意力交互网络。实现基于注意力机制的消息传递神经网络层,用于分子图的边更新。 + + 参数: + - **num_node_in** (int) - 节点输入特征数量。 + - **num_node_out** (int) - 节点输出特征数量。 + - **num_edge_in** (int) - 边输入特征数量。 + - **num_edge_out** (int) - 边输出特征数量。 + - **num_mlp_layers** (int) - 节点和边更新MLP的隐藏层数量。 + - **mlp_hidden_dim** (int) - MLP的隐藏维度大小。 + - **attention_gate** (str,可选) - 注意力门类型, ``"sigmoid"`` 或 ``"softmax"``。默认值: ``"sigmoid"``。 + - **distance_cutoff** (bool,可选) - 是否使用基于距离的边截断。默认值: ``True``。 + - **polynomial_order** (int,可选) - 多项式截断函数的阶数。默认值: ``4``。 + - **cutoff_rmax** (float,可选) - 截断的最大距离。默认值: ``6.0``。 + + 输入: + - **graph_edges** (dict) - 边特征字典,必须包含键"feat",形状为 :math:`(n_{edges}, num\_edge\_in)`。 + - **graph_nodes** (dict) - 节点特征字典,必须包含键"feat",形状为 :math:`(n_{nodes}, num\_node\_in)`。 + - **senders** (Tensor) - 每条边的发送节点索引,形状为 :math:`(n_{edges},)`。 + - **receivers** (Tensor) - 每条边的接收节点索引,形状为 :math:`(n_{edges},)`。 + + 输出: + - **edges** (dict) - 更新的边特征字典,键"feat"的形状为 :math:`(n_{edges}, num\_edge\_out)`。 + - **nodes** (dict) - 更新的节点特征字典,键"feat"的形状为 :math:`(n_{nodes}, num\_node\_out)`。 + + 异常: + - **ValueError** - 如果 `attention_gate` 不是"sigmoid"或"softmax"。 + - **ValueError** - 如果边或节点特征不包含必需的"feat"键。 \ No newline at end of file diff --git a/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.EnergyHead.rst b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.EnergyHead.rst new file mode 100644 index 0000000000000000000000000000000000000000..fb549db328a7b26008669b6bbe43d7dc1a925bd9 --- /dev/null +++ b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.EnergyHead.rst @@ -0,0 +1,28 @@ +mindchemistry.cell.orb.EnergyHead +================================== + +.. 
py:class:: mindchemistry.cell.orb.EnergyHead(latent_dim: int, num_mlp_layers: int, mlp_hidden_dim: int, target_property_dim: int, predict_atom_avg: bool = True, reference_energy_name: str = "mp-traj-d3", train_reference: bool = False, dropout: Optional[float] = None, node_aggregation: Optional[str] = None) + + 图级能量预测头。实现用于预测分子图总能量或原子平均能量的神经网络头。支持节点级聚合、参考能量偏移和灵活的输出模式。 + + 参数: + - **latent_dim** (int) - 每个节点的输入特征维度。 + - **num_mlp_layers** (int) - MLP中的隐藏层数量。 + - **mlp_hidden_dim** (int) - MLP的隐藏维度大小。 + - **target_property_dim** (int) - 能量属性的输出维度(通常为1)。 + - **predict_atom_avg** (bool,可选) - 是否预测每原子平均能量而不是总能量。默认值: ``True``。 + - **reference_energy_name** (str,可选) - 用于偏移的参考能量名称,例如 ``"vasp-shifted"``。默认值: ``"mp-traj-d3"``。 + - **train_reference** (bool,可选) - 是否将参考能量训练为可学习参数。默认值: ``False``。 + - **dropout** (Optional[float],可选) - MLP的dropout率。默认值: ``None``。 + - **node_aggregation** (str,可选) - 节点预测的聚合方法,例如 ``"mean"`` 或 ``"sum"``。默认值: ``None``。 + + 输入: + - **node_features** (dict) - 节点特征字典,必须包含键"feat",形状为 :math:`(n_{nodes}, latent\_dim)`。 + - **n_node** (Tensor) - 图中节点数量,形状为 :math:`(1,)`。 + + 输出: + - **output** (dict) - 包含键"graph_pred"的字典,值的形状为 :math:`(1, target\_property\_dim)`。 + + 异常: + - **ValueError** - 如果 `node_features` 中缺少必需的特征键。 + - **ValueError** - 如果 `node_aggregation` 不是支持的类型。 \ No newline at end of file diff --git a/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.GraphHead.rst b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.GraphHead.rst new file mode 100644 index 0000000000000000000000000000000000000000..75ae5ad7c52da0a39f5b91142a8458c68fd51323 --- /dev/null +++ b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.GraphHead.rst @@ -0,0 +1,25 @@ +mindchemistry.cell.orb.GraphHead +================================= + +.. py:class:: mindchemistry.cell.orb.GraphHead(latent_dim: int, num_mlp_layers: int, mlp_hidden_dim: int, target_property_dim: int, node_aggregation: str = "mean", dropout: Optional[float] = None, compute_stress: Optional[bool] = False) + + 图级预测头。实现可以附加到基础模型的图级预测头,用于从节点特征预测图级属性(例如应力张量),使用聚合和MLP。 + + 参数: + - **latent_dim** (int) - 每个节点的输入特征维度。 + - **num_mlp_layers** (int) - MLP中的隐藏层数量。 + - **mlp_hidden_dim** (int) - MLP的隐藏维度大小。 + - **target_property_dim** (int) - 图级属性的输出维度。 + - **node_aggregation** (str,可选) - 节点预测的聚合方法,例如 ``"mean"`` 或 ``"sum"``。默认值: ``"mean"``。 + - **dropout** (Optional[float],可选) - MLP的dropout率。默认值: ``None``。 + - **compute_stress** (bool,可选) - 是否计算和输出应力张量。默认值: ``False``。 + + 输入: + - **node_features** (dict) - 节点特征字典,必须包含键"feat",形状为 :math:`(n_{nodes}, latent\_dim)`。 + - **n_node** (Tensor) - 图中节点数量,形状为 :math:`(1,)`。 + + 输出: + - **output** (dict) - 包含键"stress_pred"的字典,值的形状为 :math:`(1, target\_property\_dim)`。 + + 异常: + - **ValueError** - 如果 `node_features` 中缺少必需的特征键。 \ No newline at end of file diff --git a/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.MoleculeGNS.rst b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.MoleculeGNS.rst new file mode 100644 index 0000000000000000000000000000000000000000..c44551f329f4d6bd28dcacf08210fc60cc7a662d --- /dev/null +++ b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.MoleculeGNS.rst @@ -0,0 +1,34 @@ +mindchemistry.cell.orb.MoleculeGNS +=================================== + +.. 
py:class:: mindchemistry.cell.orb.MoleculeGNS(num_node_in_features: int, num_node_out_features: int, num_edge_in_features: int, latent_dim: int, num_message_passing_steps: int, num_mlp_layers: int, mlp_hidden_dim: int, node_feature_names: List[str], edge_feature_names: List[str], use_embedding: bool = True, interactions: str = "simple_attention", interaction_params: Optional[Dict[str, Any]] = None) + + 分子图神经网络。实现用于分子性质预测的灵活模块化图神经网络,基于注意力或其他交互机制的消息传递。支持节点和边嵌入、多个消息传递步骤,以及用于复杂分子图的可定制交互层。 + + 参数: + - **num_node_in_features** (int) - 每个节点的输入特征数量。 + - **num_node_out_features** (int) - 每个节点的输出特征数量。 + - **num_edge_in_features** (int) - 每条边的输入特征数量。 + - **latent_dim** (int) - 节点和边表示的潜在维度。 + - **num_message_passing_steps** (int) - 消息传递层的数量。 + - **num_mlp_layers** (int) - 节点和边更新MLP的隐藏层数量。 + - **mlp_hidden_dim** (int) - MLP的隐藏维度大小。 + - **node_feature_names** (List[str]) - 从输入字典中使用的节点特征键列表。 + - **edge_feature_names** (List[str]) - 从输入字典中使用的边特征键列表。 + - **use_embedding** (bool,可选) - 是否对节点使用原子序数嵌入。默认值: ``True``。 + - **interactions** (str,可选) - 要使用的交互层类型(例如, ``"simple_attention"``)。默认值: ``"simple_attention"``。 + - **interaction_params** (Optional[Dict[str, Any]],可选) - 交互层的参数,例如截断、多项式阶数、门类型。默认值: ``None``。 + + 输入: + - **edge_features** (dict) - 边特征字典,必须包含 `edge_feature_names` 中指定的键。 + - **node_features** (dict) - 节点特征字典,必须包含 `node_feature_names` 中指定的键。 + - **senders** (Tensor) - 每条边的发送节点索引,形状为 :math:`(n_{edges},)`。 + - **receivers** (Tensor) - 每条边的接收节点索引,形状为 :math:`(n_{edges},)`。 + + 输出: + - **edges** (dict) - 更新的边特征字典,键"feat"的形状为 :math:`(n_{edges}, latent\_dim)`。 + - **nodes** (dict) - 更新的节点特征字典,键"feat"的形状为 :math:`(n_{nodes}, latent\_dim)`。 + + 异常: + - **ValueError** - 如果 `edge_features` 或 `node_features` 中缺少必需的特征键。 + - **ValueError** - 如果 `interactions` 不是支持的类型。 \ No newline at end of file diff --git a/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.NodeHead.rst b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.NodeHead.rst new file mode 100644 index 0000000000000000000000000000000000000000..2e422d861892a40688faea4b8b38ca0a5848bb63 --- /dev/null +++ b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.NodeHead.rst @@ -0,0 +1,26 @@ +mindchemistry.cell.orb.NodeHead +=============================== + +.. 
py:class:: mindchemistry.cell.orb.NodeHead(latent_dim: int, num_mlp_layers: int, mlp_hidden_dim: int, target_property_dim: int, dropout: Optional[float] = None, remove_mean: bool = True) + + 节点级预测头。 + + 实现用于从节点特征预测节点级属性的神经网络头。该头可以添加到基础模型中以在预训练期间启用辅助任务,或在微调步骤中添加。 + + 参数: + - **latent_dim** (int) - 每个节点的输入特征维度。 + - **num_mlp_layers** (int) - MLP中的隐藏层数量。 + - **mlp_hidden_dim** (int) - MLP的隐藏维度大小。 + - **target_property_dim** (int) - 节点级目标属性的输出维度。 + - **dropout** (Optional[float],可选) - MLP的dropout率。默认值: ``None``。 + - **remove_mean** (bool,可选) - 如果为True,从输出中移除均值,通常用于力预测。默认值: ``True``。 + + 输入: + - **node_features** (dict) - 节点特征字典,必须包含键 "feat",形状为 :math:`(n_{nodes}, latent\_dim)`。 + - **n_node** (Tensor) - 图中节点数量,形状为 :math:`(1,)`。 + + 输出: + - **output** (dict) - 包含键 "node_pred" 的字典,值的形状为 :math:`(n_{nodes}, target\_property\_dim)`。 + + 异常: + - **ValueError** - 如果 `node_features` 中缺少必需的特征键。 \ No newline at end of file diff --git a/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.Orb.rst b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.Orb.rst new file mode 100644 index 0000000000000000000000000000000000000000..fe1d53a26ce54523eb47205d9a5fa5729892f8c2 --- /dev/null +++ b/docs/api_python/mindchemistry/cell/mindchemistry.cell.orb.Orb.rst @@ -0,0 +1,35 @@ +mindchemistry.cell.orb.Orb +=========================== + +.. py:class:: mindchemistry.cell.orb.Orb(model: MoleculeGNS, node_head: Optional[NodeHead] = None, graph_head: Optional[GraphHead] = None, stress_head: Optional[GraphHead] = None, model_requires_grad: bool = True, cutoff_layers: Optional[int] = None) + + Orb图回归器。将预训练的基础模型(如MoleculeGNS)与可选的节点、图和应力回归头结合,支持微调或特征提取工作流程。 + + 参数: + - **model** (MoleculeGNS) - 用于消息传递和特征提取的预训练或随机初始化基础模型。 + - **node_head** (NodeHead,可选) - 节点级属性预测的回归头。默认值: ``None``。 + - **graph_head** (GraphHead,可选) - 图级属性预测(例如能量)的回归头。默认值: ``None``。 + - **stress_head** (GraphHead,可选) - 应力预测的回归头。默认值: ``None``。 + - **model_requires_grad** (bool,可选) - 是否微调基础模型(True)或冻结其参数(False)。默认值: ``True``。 + - **cutoff_layers** (int,可选) - 如果提供,仅使用基础模型的前 ``"cutoff_layers"`` 个消息传递层。默认值: ``None``。 + + 输入: + - **edge_features** (dict) - 边特征字典(例如,`{"vectors": Tensor, "r": Tensor}`)。 + - **node_features** (dict) - 节点特征字典(例如,`{"atomic_numbers": Tensor, ...}`)。 + - **senders** (Tensor) - 每条边的发送节点索引。形状::math:`(n_{edges},)`。 + - **receivers** (Tensor) - 每条边的接收节点索引。形状::math:`(n_{edges},)`。 + - **n_node** (Tensor) - 批次中每个图的节点数量。形状::math:`(n_{graphs},)`。 + + 输出: + - **output** (dict) - 包含以下内容的字典: + + - **edges** (dict) - 消息传递后的边特征,例如 `{..., "feat": Tensor}`。 + - **nodes** (dict) - 消息传递后的节点特征,例如 `{..., "feat": Tensor}`。 + - **graph_pred** (Tensor) - 图级预测,例如能量。形状::math:`(n_{graphs}, target\_property\_dim)`。 + - **node_pred** (Tensor) - 节点级预测。形状::math:`(n_{nodes}, target\_property\_dim)`。 + - **stress_pred** (Tensor) - 应力预测(如果提供stress_head)。形状::math:`(n_{graphs}, 6)`。 + + 异常: + - **ValueError** - 如果既未提供node_head也未提供graph_head。 + - **ValueError** - 如果cutoff_layers超过基础模型中的消息传递步骤数。 + - **ValueError** - 如果graph_head需要时未提供atomic_numbers。 \ No newline at end of file diff --git a/docs/api_python/mindchemistry/mindchemistry.cell.rst b/docs/api_python/mindchemistry/mindchemistry.cell.rst new file mode 100644 index 0000000000000000000000000000000000000000..d346532e43dc2c3414d4e35edcb73aa369809ab4 --- /dev/null +++ b/docs/api_python/mindchemistry/mindchemistry.cell.rst @@ -0,0 +1,18 @@ +mindchemistry.cell +================== + +.. 
mscnplatformautosummary::
+    :toctree: cell
+    :nosignatures:
+    :template: classtemplate.rst
+
+    mindchemistry.cell.Allegro
+    mindchemistry.cell.AutoEncoder
+    mindchemistry.cell.FCNet
+    mindchemistry.cell.MLPNet
+    mindchemistry.cell.orb.AttentionInteractionNetwork
+    mindchemistry.cell.orb.EnergyHead
+    mindchemistry.cell.orb.GraphHead
+    mindchemistry.cell.orb.MoleculeGNS
+    mindchemistry.cell.orb.NodeHead
+    mindchemistry.cell.orb.Orb
\ No newline at end of file
diff --git a/docs/api_python/mindflow/cell/mindflow.cell.MultiHeadAttention.rst b/docs/api_python/mindflow/cell/mindflow.cell.MultiHeadAttention.rst
new file mode 100644
index 0000000000000000000000000000000000000000..5d245889df1929ca2b91d1fb11f5459be19d130b
--- /dev/null
+++ b/docs/api_python/mindflow/cell/mindflow.cell.MultiHeadAttention.rst
@@ -0,0 +1,25 @@
+mindflow.cell.MultiHeadAttention
+=================================
+
+.. py:class:: mindflow.cell.MultiHeadAttention(in_channels, num_heads, enable_flash_attn=False, fa_dtype=mstype.bfloat16, drop_mode='dropout', dropout_rate=0.0, compute_dtype=mstype.float32)
+
+    多头注意力机制,具体细节可以参见 `Attention Is All You Need `_ 。
+
+    参数:
+        - **in_channels** (int) - 输入特征维度。
+        - **num_heads** (int) - 注意力头的数量。
+        - **enable_flash_attn** (bool) - 是否使能FlashAttention。FlashAttention只支持 `Ascend` 后端。具体细节参见 `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness `_ 。
+          默认值: ``False`` 。
+        - **fa_dtype** (mindspore.dtype) - FlashAttention计算类型。支持以下类型: `mstype.bfloat16`、 `mstype.float16`。默认值: ``mstype.bfloat16`` ,表示 ``mindspore.bfloat16`` 。
+        - **drop_mode** (str) - dropout方式。默认值: ``dropout`` 。支持以下类型: ``dropout`` 和 ``droppath`` 。
+        - **dropout_rate** (float) - dropout层丢弃的比率。取值在 `[0, 1]` 。默认值: ``0.0`` 。
+        - **compute_dtype** (mindspore.dtype) - 网络层的数据类型。默认值: ``mstype.float32`` ,表示 ``mindspore.float32`` 。
+
+    输入:
+        - **x** (Tensor) - shape为 :math:`(batch\_size, sequence\_len, in\_channels)` 的Tensor。
+        - **attn_mask** (Tensor,可选) - shape为 :math:`(sequence\_len, sequence\_len)` 或 :math:`(batch\_size, 1, sequence\_len, sequence\_len)` 的Tensor。默认值: ``None`` 。
+        - **key_padding_mask** (Tensor,可选) - shape为 :math:`(batch\_size, sequence\_len)` 的Tensor。默认值: ``None`` 。
+
+    输出:
+        - **output** (Tensor) - shape为 :math:`(batch\_size, sequence\_len, in\_channels)` 的Tensor。
diff --git a/docs/api_python/mindflow/cell/mindflow.cell.TransformerBlock.rst b/docs/api_python/mindflow/cell/mindflow.cell.TransformerBlock.rst
new file mode 100644
index 0000000000000000000000000000000000000000..574b01e5d13fe7e8a0161fec8bd09fbd59ca6abd
--- /dev/null
+++ b/docs/api_python/mindflow/cell/mindflow.cell.TransformerBlock.rst
@@ -0,0 +1,23 @@
+mindflow.cell.TransformerBlock
+======================================
+
+.. 
py:class:: mindflow.cell.TransformerBlock(in_channels, num_heads, enable_flash_attn=False, fa_dtype=mstype.bfloat16, drop_mode='dropout', dropout_rate=0.0, compute_dtype=mstype.float32)
+
+    `TransformerBlock` 由 `MultiHeadAttention` 和 `FeedForward` 网络堆叠而成。
+
+    参数:
+        - **in_channels** (int) - 输入特征维度。
+        - **num_heads** (int) - 注意力头的数量。
+        - **enable_flash_attn** (bool) - 是否使能FlashAttention。FlashAttention只支持 `Ascend` 后端。具体细节参见 `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness `_ 。
+          默认值: ``False`` 。
+        - **fa_dtype** (mindspore.dtype) - FlashAttention计算类型。支持以下类型: `mstype.bfloat16`、 `mstype.float16`。默认值: ``mstype.bfloat16`` ,表示 ``mindspore.bfloat16`` 。
+        - **drop_mode** (str) - dropout方式。默认值: ``dropout`` 。支持以下类型: ``dropout`` 和 ``droppath`` 。
+        - **dropout_rate** (float) - dropout层丢弃的比率,在 ``[0, 1]`` 范围。默认值: ``0.0`` 。
+        - **compute_dtype** (mindspore.dtype) - 网络层的数据类型。默认值: ``mstype.float32`` ,表示 ``mindspore.float32`` 。
+
+    输入:
+        - **x** (Tensor) - shape为 :math:`(batch\_size, sequence\_len, in\_channels)` 的Tensor。
+        - **mask** (Tensor) - shape为 :math:`(sequence\_len, sequence\_len)` 或 :math:`(batch\_size, 1, sequence\_len, sequence\_len)` 的Tensor。
+
+    输出:
+        - **output** (Tensor) - shape为 :math:`(batch\_size, sequence\_len, in\_channels)` 的Tensor。
diff --git a/docs/api_python/mindflow/mindflow.cell.rst b/docs/api_python/mindflow/mindflow.cell.rst
new file mode 100644
index 0000000000000000000000000000000000000000..c52ce4cebd44996933d8aaba45dc88aea6953006
--- /dev/null
+++ b/docs/api_python/mindflow/mindflow.cell.rst
@@ -0,0 +1,34 @@
+mindflow.cell
+==================
+
+.. mscnplatformautosummary::
+    :toctree: cell
+    :nosignatures:
+    :template: classtemplate.rst
+
+    mindflow.cell.ConditionDiffusionTransformer
+    mindflow.cell.DiffusionTrainer
+    mindflow.cell.DiffusionTransformer
+    mindflow.cell.DDIMPipeline
+    mindflow.cell.DDIMScheduler
+    mindflow.cell.DDPMPipeline
+    mindflow.cell.DDPMScheduler
+    mindflow.cell.FCSequential
+    mindflow.cell.FNO1D
+    mindflow.cell.FNO2D
+    mindflow.cell.FNO3D
+    mindflow.cell.InputScale
+    mindflow.cell.LinearBlock
+    mindflow.cell.MultiHeadAttention
+    mindflow.cell.MultiScaleFCSequential
+    mindflow.cell.PDENet
+    mindflow.cell.PeRCNN
+    mindflow.cell.ResBlock
+    mindflow.cell.SNO
+    mindflow.cell.SNO1D
+    mindflow.cell.SNO2D
+    mindflow.cell.SNO3D
+    mindflow.cell.TransformerBlock
+    mindflow.cell.UNet2D
+    mindflow.cell.ViT
+    mindflow.cell.get_activation
diff --git a/docs/api_python_en/mindchemistry/mindchemistry.cell.rst b/docs/api_python_en/mindchemistry/mindchemistry.cell.rst
new file mode 100644
index 0000000000000000000000000000000000000000..78cd3bfffeded29ea9cbc13642292928e4f6d464
--- /dev/null
+++ b/docs/api_python_en/mindchemistry/mindchemistry.cell.rst
@@ -0,0 +1,18 @@
+mindchemistry.cell
+==================
+
+.. 
msplatformautosummary:: + :toctree: cell + :nosignatures: + :template: classtemplate.rst + + mindchemistry.cell.Allegro + mindchemistry.cell.AutoEncoder + mindchemistry.cell.FCNet + mindchemistry.cell.MLPNet + mindchemistry.cell.orb.AttentionInteractionNetwork + mindchemistry.cell.orb.EnergyHead + mindchemistry.cell.orb.GraphHead + mindchemistry.cell.orb.MoleculeGNS + mindchemistry.cell.orb.NodeHead + mindchemistry.cell.orb.Orb \ No newline at end of file diff --git a/docs/api_python_en/mindflow/mindflow.cell.rst b/docs/api_python_en/mindflow/mindflow.cell.rst new file mode 100644 index 0000000000000000000000000000000000000000..76791f5cfea2685f8339c2c149a51698ef2aa34b --- /dev/null +++ b/docs/api_python_en/mindflow/mindflow.cell.rst @@ -0,0 +1,34 @@ +mindflow.cell +================== + +.. msplatformautosummary:: + :toctree: cell + :nosignatures: + :template: classtemplate.rst + + mindflow.cell.ConditionDiffusionTransformer + mindflow.cell.DiffusionTrainer + mindflow.cell.DiffusionTransformer + mindflow.cell.DDIMPipeline + mindflow.cell.DDIMScheduler + mindflow.cell.DDPMPipeline + mindflow.cell.DDPMScheduler + mindflow.cell.FCSequential + mindflow.cell.FNO1D + mindflow.cell.FNO2D + mindflow.cell.FNO3D + mindflow.cell.InputScale + mindflow.cell.LinearBlock + mindflow.cell.MultiHeadAttention + mindflow.cell.MultiScaleFCSequential + mindflow.cell.PDENet + mindflow.cell.PeRCNN + mindflow.cell.ResBlock + mindflow.cell.SNO + mindflow.cell.SNO1D + mindflow.cell.SNO2D + mindflow.cell.SNO3D + mindflow.cell.TransformerBlock + mindflow.cell.UNet2D + mindflow.cell.ViT + mindflow.cell.get_activation diff --git a/mindscience/common/fourier.py b/mindscience/common/fourier.py new file mode 100644 index 0000000000000000000000000000000000000000..64f980668abda0e4b045f4645c35202cb0b24445 --- /dev/null +++ b/mindscience/common/fourier.py @@ -0,0 +1,656 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +''' provide complex dft based on the real dft API in mindflow.dft ''' +import numpy as np +import scipy +import mindspore as ms +import mindspore.common.dtype as mstype +from mindspore import nn, ops, Tensor, mint +from mindspore.common.initializer import Zero +from mindspore.ops import operations as P + +from ..utils.check_func import check_param_no_greater, check_param_value + + +class MyRoll(nn.Cell): + ''' Custom defined roll operator to avoid bug in MindSpore ''' + def __init__(self): + super().__init__() + + if ms.get_context('device_target') == 'Ascend' and ms.get_context('mode') == ms.GRAPH_MODE: + self.roller = mint.roll + else: + self.roller = None + + def construct(self, x, shifts, dims): + ''' Same as mint.roll ''' + shifts = np.atleast_1d(shifts).astype(int).tolist() + dims = np.atleast_1d(dims).astype(int).tolist() + + if self.roller: + return self.roller(x, shifts, dims) + + for i, j in zip(shifts, dims): + n = x.shape[j] + x = ops.swapaxes(x, j, 0) + x = ops.cat([x[n - i % n:], x[:n - i % n]], axis=0) + x = ops.swapaxes(x, j, 0) + return x + +class MyFlip(nn.Cell): + ''' Custom defined flip operator to avoid bug in MindSpore ''' + def __init__(self): + super().__init__() + msver = tuple([int(s) for s in ms.__version__.split('.')]) + + if msver <= (2, 4, 0) and \ + ms.get_context('device_target') == 'Ascend' and \ + ms.get_context('mode') == ms.PYNATIVE_MODE: + self.fliper = None + else: + self.fliper = mint.flip + + def construct(self, x, dims): + ''' same as mint.flip ''' + dims = np.atleast_1d(dims).astype(int).tolist() + + if self.fliper: + return self.fliper(x, dims) + + for j in dims: + x = ops.swapaxes(x, j, 0) + x = x[::-1] + x = ops.swapaxes(x, j, 0) + return x + + +def convert_shape(shape): + ''' convert shape to suitable format ''' + if isinstance(shape, int): + n = shape + elif len(shape) == 1: + n, = shape + else: + raise TypeError("Only support 1D dct/dst, but got shape {}".format(shape)) + return n + + +def convert_params(shape, modes, dim): + ''' convert input arguments to suitable format ''' + shape = tuple(np.atleast_1d(shape).astype(int).tolist()) + ndim = len(shape) + + if dim is None: + dim = tuple([n - ndim for n in range(ndim)]) + else: + dim = tuple(np.atleast_1d(dim).astype(int).tolist()) + + if modes is None or isinstance(modes, int): + modes = tuple([modes] * ndim) + else: + modes = tuple(np.atleast_1d(modes).astype(int).tolist()) + + return shape, modes, dim + + +def check_params(shape, modes, dim): + ''' check lawfulness of input arguments ''' + check_param_no_greater(len(dim), "dim length", 3) + check_param_value(len(shape), "shape length", len(dim)) + check_param_value(len(modes), "modes length", len(dim)) + if np.any(modes): + for i, (m, n) in enumerate(zip(modes, shape)): + # if for last axis mode need to be n//2+1, mode should be set to None + check_param_no_greater(m, f'mode{i+1}', n // 2) + + +class _DFT1d(nn.Cell): + '''One dimensional Discrete Fourier Transformation''' + + def __init__(self, n, mode, last_index, idx=0, scale='sqrtn', inv=False, compute_dtype=mstype.float32): + super().__init__() + + self.n = n + self.dft_mat = scipy.linalg.dft(n, scale=scale) + self.last_index = last_index + self.inv = inv + self.odd = bool(n % 2) + self.idx = idx + self.mode_upper = mode if mode else n // 2 + (self.last_index or self.odd) + self.mode_lower = mode if mode else n - self.mode_upper + self.compute_dtype = compute_dtype + + # generate DFT matrix for positive 
and negative frequencies + dft_mat_mode = self.dft_mat[:, :self.mode_upper] + self.a_re_upper = Tensor(dft_mat_mode.real, dtype=compute_dtype) + self.a_im_upper = Tensor(dft_mat_mode.imag, dtype=compute_dtype) + + dft_mat_mode = self.dft_mat[:, -self.mode_lower:] + self.a_re_lower = Tensor(dft_mat_mode.real, dtype=compute_dtype) + self.a_im_lower = Tensor(dft_mat_mode.imag, dtype=compute_dtype) + + # the zero matrix to fill the un-transformed modes + m = self.n - (self.mode_upper + self.mode_lower) + if m > 0: + self.mat = Tensor(shape=m, dtype=compute_dtype, init=Zero()) + + self.concat = ops.Concat(axis=-1) + self.cast = P.Cast() + + if self.inv: + self.a_re_upper = self.a_re_upper.T + self.a_im_upper = -self.a_im_upper.T + self.a_re_lower = self.a_re_lower.T + self.a_im_lower = -self.a_im_lower.T + + # last axis is real-transformed, so the inverse is conjugate of the positive frequencies + if last_index: + mode_res = min(self.mode_lower, self.mode_upper - 1) + dft_mat_res = self.dft_mat[:, -mode_res:] + a_re_res = MyFlip()(Tensor(dft_mat_res.real, dtype=compute_dtype), dims=-1) + a_im_res = MyFlip()(Tensor(dft_mat_res.imag, dtype=compute_dtype), dims=-1) + + a_re_res = ops.pad(a_re_res, (1, self.mode_upper - mode_res - 1)) + a_im_res = ops.pad(a_im_res, (1, self.mode_upper - mode_res - 1)) + + self.a_re_upper += a_re_res.T + self.a_im_upper += a_im_res.T + + def swap_axes(self, x_re, x_im): + return x_re.swapaxes(-1, self.idx), x_im.swapaxes(-1, self.idx) + + def complex_matmul(self, x_re, x_im, a_re, a_im): + y_re = ops.matmul(x_re, a_re) - ops.matmul(x_im, a_im) + y_im = ops.matmul(x_im, a_re) + ops.matmul(x_re, a_im) + return y_re, y_im + + def zero_mat(self, dims): + mat = self.mat + for n in dims[::-1]: + mat = mint.repeat_interleave(mat.expand_dims(0), n, 0) + return mat + + def compute_forward(self, x_re, x_im): + ''' Forward transform for rdft ''' + y_re, y_im = self.complex_matmul( + x_re=x_re, x_im=x_im, a_re=self.a_re_upper, a_im=self.a_im_upper) + + if self.last_index: + return y_re, y_im + + y_re2, y_im2 = self.complex_matmul( + x_re=x_re, x_im=x_im, a_re=self.a_re_lower, a_im=self.a_im_lower) + + if self.n == self.mode_upper + self.mode_lower: + y_re = self.concat((y_re, y_re2)) + y_im = self.concat((y_im, y_im2)) + else: + mat = self.zero_mat(x_re.shape[:-1]) + y_re = self.concat((y_re, mat, y_re2)) + y_im = self.concat((y_im, mat, y_im2)) + + return y_re, y_im + + def compute_inverse(self, x_re, x_im): + ''' Inverse transform for irdft ''' + y_re, y_im = self.complex_matmul(x_re=x_re[..., :self.mode_upper], + x_im=x_im[..., :self.mode_upper], + a_re=self.a_re_upper, + a_im=self.a_im_upper) + if self.last_index: + return y_re, y_im + + y_re_res, y_im_res = self.complex_matmul(x_re=x_re[..., -self.mode_lower:], + x_im=x_im[..., -self.mode_lower:], + a_re=self.a_re_lower, + a_im=self.a_im_lower) + return y_re + y_re_res, y_im + y_im_res + + def construct(self, x): + ''' perform 1d rdft/irdft with matmul operations ''' + x_re, x_im = x + x_re, x_im = self.cast(x_re, self.compute_dtype), self.cast(x_im, self.compute_dtype) + x_re, x_im = self.swap_axes(x_re, x_im) + if self.inv: + y_re, y_im = self.compute_inverse(x_re, x_im) + else: + y_re, y_im = self.compute_forward(x_re, x_im) + y_re, y_im = self.swap_axes(y_re, y_im) + return y_re, y_im + + +class _DFTn(nn.Cell): + ''' Base class for n-D DFT transform ''' + def __init__(self, shape, dim=None, norm='backward', modes=None, compute_dtype=mstype.float32): + super().__init__() + + shape, modes, dim = convert_params(shape, 
modes, dim) + check_params(shape, modes, dim) + + ndim = len(shape) + inv, scale, r2c_flags = self.set_options(ndim, norm) + self.dft1_seq = nn.SequentialCell() + for n, m, r, d in zip(shape, modes, r2c_flags, dim): + self.dft1_seq.append(_DFT1d( + n=n, mode=m, last_index=r, idx=d, scale=scale, inv=inv, compute_dtype=compute_dtype)) + + def set_options(self, ndim, norm): + ''' + Choose the dimensions, normalization, and transformation mode (forward/backward). + Derivative APIs overwrite the options to achieve their specific goals. + ''' + inv = False + scale = { + 'backward': None, + 'forward': 'n', + 'ortho': 'sqrtn', + }[norm] + r2c_flags = np.zeros(ndim, dtype=bool).tolist() + r2c_flags[-1] = True + return inv, scale, r2c_flags + + def construct(self, *args, **kwargs): + raise NotImplementedError + + +class RDFTn(_DFTn): + r""" + 1/2/3D discrete real Fourier transformation on real number. The results should be same as + `scipy.fft.rfftn() `_ . + + Args: + shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included. + dim (tuple): Dimensions to be transformed. Default: None, the leading dimensions will be transformed. + norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward', + same as torch.fft.rfftn + modes (tuple, int, None): The length of the output transform axis. The `modes` must be no greater than half of the + dimension of input 'x'. + compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32. + + Inputs: + - **ar** (Tensor) - The real tensor to be transformed, with trailing dimensions aligned with `shape`. + + Outputs: + - **br** (Tensor) - Real part of the output tensor, with trailing dimensions aligned with `shape`, + except for the last dimension, which should be shape[-1] / 2 + 1. + - **bi** (Tensor) - Imag part of the output tensor, with trailing dimensions aligned with `shape`, + except for the last dimension, which should be shape[-1] / 2 + 1. + + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> from mindspore import ops + >>> from mindflow.core import RDFTn + >>> ar = ops.rand((2, 32, 512)) + >>> dft_cell = RDFTn(x.shape[-2:]) + >>> br, bi = dft_cell(ar) + >>> print(br.shape) + (2, 32, 257) + """ + def construct(self, ar): + ''' perform n-dimensional rDFT on real tensor ''' + # n-D Fourier transform with last axis being real-transformed, output dimension (..., m, n//2+1) + # the last ndim dimensions of ar must accord with shape + return self.dft1_seq((ar, ar * 0)) + + +class IRDFTn(_DFTn): + r""" + 1/2/3D discrete inverse real Fourier transformation on complex number. The results should be same as + `scipy.fft.irfftn() `_ . + + Args: + shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included. + dim (tuple): Dimensions to be transformed. Default: None, the leading dimensions will be transformed. + norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward', + same as torch.fft.irfftn + modes (tuple, int, None): The length of the output transform axis. The `modes` must be no greater than half of the + dimension of input 'x'. + compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32. + + Inputs: + - **ar** (Tensor) - Real part of the tensor to be transformed, with trailing dimensions aligned with `shape`, + except for the last dimension, which should be shape[-1] / 2 + 1. 
+        - **ai** (Tensor) - Imag part of the tensor to be transformed, with trailing dimensions aligned with `shape`,
+          except for the last dimension, which should be shape[-1] / 2 + 1.
+
+    Outputs:
+        - **br** (Tensor) - The output real tensor, with trailing dimensions aligned with `shape`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindflow.core import IRDFTn
+        >>> ar = ops.rand((2, 32, 257))
+        >>> ai = ops.rand((2, 32, 257))
+        >>> dft_cell = IRDFTn((32, 512))
+        >>> br = dft_cell(ar, ai)
+        >>> print(br.shape)
+        (2, 32, 512)
+    """
+    def set_options(self, ndim, norm):
+        inv = True
+        scale = {
+            'forward': None,
+            'backward': 'n',
+            'ortho': 'sqrtn',
+        }[norm]
+        r2c_flags = np.zeros(ndim, dtype=bool).tolist()
+        r2c_flags[-1] = True
+        return inv, scale, r2c_flags
+
+    def construct(self, ar, ai):
+        ''' perform n-dimensional irDFT on complex tensor and output real tensor '''
+        return self.dft1_seq((ar, ai))[0]
+
+
+class DFTn(_DFTn):
+    r"""
+    1/2/3D discrete Fourier transformation on complex number. The results should be same as
+    `scipy.fft.fftn() `_ .
+
+    Args:
+        shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+        dim (tuple): Dimensions to be transformed. Default: None, the leading dimensions will be transformed.
+        norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward',
+            same as torch.fft.fftn
+        modes (tuple, int, None): The length of the output transform axis. The `modes` must be no greater than half of the
+            dimension of input 'x'.
+        compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+    Inputs:
+        - **ar** (Tensor) - Real part of the tensor to be transformed, with trailing dimensions aligned with `shape`.
+        - **ai** (Tensor) - Imag part of the tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+    Outputs:
+        - **br** (Tensor) - Real part of the output tensor, with trailing dimensions aligned with `shape`.
+        - **bi** (Tensor) - Imag part of the output tensor, with trailing dimensions aligned with `shape`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindflow.cell import DFTn
+        >>> ar = ops.rand((2, 32, 512))
+        >>> ai = ops.rand((2, 32, 512))
+        >>> dft_cell = DFTn(ar.shape[-2:])
+        >>> br, bi = dft_cell(ar, ai)
+        >>> print(br.shape)
+        (2, 32, 512)
+    """
+    def set_options(self, ndim, norm):
+        inv = False
+        scale = {
+            'forward': 'n',
+            'backward': None,
+            'ortho': 'sqrtn',
+        }[norm]
+        r2c_flags = np.zeros(ndim, dtype=bool).tolist()
+        return inv, scale, r2c_flags
+
+    def construct(self, ar, ai):
+        ''' perform n-dimensional DFT on complex tensor '''
+        # n-D complex Fourier transform, output dimension (..., m, n)
+        return self.dft1_seq((ar, ai))
+
+
+class IDFTn(DFTn):
+    r"""
+    1/2/3D discrete inverse Fourier transformation on complex number. The results should be same as
+    `scipy.fft.ifftn() `_ .
+
+    Args:
+        shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+        dim (tuple): Dimensions to be transformed. Default: None, the leading dimensions will be transformed.
+        norm (str): Normalization mode, should be one of 'forward', 'backward', 'ortho'. Default: 'backward',
+            same as torch.fft.ifftn
+        modes (tuple, int, None): The length of the output transform axis. The `modes` must be no greater than half of the
+            dimension of input 'x'. 
+        compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+    Inputs:
+        - **ar** (Tensor) - Real part of the tensor to be transformed, with trailing dimensions aligned with `shape`.
+        - **ai** (Tensor) - Imag part of the tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+    Outputs:
+        - **br** (Tensor) - Real part of the output tensor, with trailing dimensions aligned with `shape`.
+        - **bi** (Tensor) - Imag part of the output tensor, with trailing dimensions aligned with `shape`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindflow.cell import IDFTn
+        >>> ar = ops.rand((2, 32, 512))
+        >>> ai = ops.rand((2, 32, 512))
+        >>> dft_cell = IDFTn(ar.shape[-2:])
+        >>> br, bi = dft_cell(ar, ai)
+        >>> print(br.shape)
+        (2, 32, 512)
+    """
+    def set_options(self, ndim, norm):
+        inv = True
+        scale = {
+            'forward': None,
+            'backward': 'n',
+            'ortho': 'sqrtn',
+        }[norm]
+        r2c_flags = np.zeros(ndim, dtype=bool).tolist()
+        return inv, scale, r2c_flags
+
+
+class DCT(nn.Cell):
+    r"""
+    1D discrete cosine transformation on real number on the last axis. The results should be same as
+    `scipy.fft.dct() `_ .
+    Reference: `Type 2 DCT using N FFT (Makhoul) `_ .
+
+    Args:
+        shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+            Must be a length-1 tuple.
+        compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+    Inputs:
+        - **a** (Tensor) - The real tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+    Outputs:
+        - **b** (Tensor) - The output real tensor, with trailing dimensions aligned with `shape`.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindflow.cell import DCT
+        >>> a = ops.rand((2, 32, 512))
+        >>> dft_cell = DCT(a.shape[-1:])
+        >>> b = dft_cell(a)
+        >>> print(b.shape)
+        (2, 32, 512)
+    """
+    def __init__(self, shape, compute_dtype=mstype.float32):
+        super().__init__()
+
+        n = convert_shape(shape)
+
+        self.dft_cell = DFTn(n, compute_dtype=compute_dtype)
+
+        w = Tensor(np.arange(n) * np.pi / (2 * n), dtype=compute_dtype)
+        self.cosw = ops.cos(w)
+        self.sinw = ops.sin(w)
+
+        self.fliper = MyFlip()
+
+    def construct(self, a):
+        ''' perform 1-dimensional DCT on real tensor '''
+        b_half1 = a[..., ::2]
+        b_half2 = self.fliper(a[..., 1::2], dims=-1)
+        b = ops.cat([b_half1, b_half2], axis=-1)
+        cr, ci = self.dft_cell(b, b * 0)
+        return 2 * (cr * self.cosw + ci * self.sinw)
+
+
+class IDCT(nn.Cell):
+    r"""
+    1D inverse discrete cosine transformation on real number on the last axis. The results should be same as
+    `scipy.fft.idct() `_ .
+    Reference: `A fast cosine transform in one and two dimensions
+    `_ .
+
+    Args:
+        shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included.
+            Must be a length-1 tuple.
+        compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32.
+
+    Inputs:
+        - **a** (Tensor) - The real tensor to be transformed, with trailing dimensions aligned with `shape`.
+
+    Outputs:
+        - **b** (Tensor) - The output real tensor, with trailing dimensions aligned with `shape`. 
+ + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> from mindspore import ops + >>> from mindflow.cell import IDCT + >>> a = ops.rand((2, 32, 512)) + >>> dft_cell = IDCT(x.shape[-1:]) + >>> b = dft_cell(a) + >>> print(b.shape) + (2, 32, 512) + """ + def __init__(self, shape, compute_dtype=mstype.float32): + super().__init__() + + n = convert_shape(shape) + + # assert n % 2 == 0, 'only support even length' # n has to be even, or IRDFTn would fail + + self.dft_cell = IRDFTn(n, compute_dtype=compute_dtype) + + w = Tensor(np.arange(n // 2 + 1) * np.pi / (2 * n), dtype=compute_dtype) + self.cosw = ops.cos(w) + self.sinw = ops.sin(w) + + self.fliper = MyFlip() + + def construct(self, a): + ''' perform 1-dimensional iDCT on real tensor ''' + n = a.shape[-1] + + br = a[..., :n // 2 + 1] + bi = ops.pad(self.fliper(- a[..., -(n // 2):], dims=-1), (1, 0)) + vr = (br * self.cosw - bi * self.sinw) / 2 + vi = (bi * self.cosw + br * self.sinw) / 2 + + c = self.dft_cell(vr, vi) # (..., n) + c1 = c[..., :(n + 1) // 2] + c2 = self.fliper(c[..., (n + 1) // 2:], dims=-1) + d1 = ops.pad(c1.reshape(-1)[..., None], (0, 1)).reshape(*c1.shape[:-1], -1) + d2 = ops.pad(c2.reshape(-1)[..., None], (1, 0)).reshape(*c2.shape[:-1], -1) + # in case n is odd, d1 and d2 need to be aligned + d1 = d1[..., :n] + d2 = ops.pad(d2, (0, n % 2)) + return d1 + d2 + + +class DST(nn.Cell): + r""" + 1D discrete sine transformation on real number on the last axis. The results should be same as + `scipy.fft.dct() `_ . + Reference: `Wikipedia `_ . + + Args: + shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included. + Must be a length-1 tuple. + compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32. + + Inputs: + - **a** (Tensor) - The real tensor to be transformed, with trailing dimensions aligned with `shape`. + + Outputs: + - **b** (Tensor) - The output real tensor, with trailing dimensions aligned with `shape`. + + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> from mindspore import ops + >>> from mindflow.cell import DST + >>> a = ops.rand((2, 32, 512)) + >>> dft_cell = DST(x.shape[-1:]) + >>> b = dft_cell(a) + >>> print(b.shape) + (2, 32, 512) + """ + def __init__(self, shape, compute_dtype=mstype.float32): + super().__init__() + n = convert_shape(shape) + self.dft_cell = DCT(n, compute_dtype=compute_dtype) + multiplier = np.ones(n) + multiplier[..., 1::2] *= -1 + self.multiplier = Tensor(multiplier, dtype=compute_dtype) + + def construct(self, a): + ''' perform 1-dimensional DST on real tensor ''' + return self.dft_cell.fliper(self.dft_cell(a * self.multiplier), dims=-1) + + +class IDST(nn.Cell): + r""" + 1D inverse discrete sine transformation on real number on the last axis. The results should be same as + `scipy.fft.dct() `_ . + Reference: `Wikipedia `_ . + + Args: + shape (tuple): The shape of the dimensions to be transformed, other dimensions need not be included. + Must be a length-1 tuple. + compute_dtype (mindspore.dtype): The type of input tensor. Default: mindspore.float32. + + Inputs: + - **a** (Tensor) - The real tensor to be transformed, with trailing dimensions aligned with `shape`. + + Outputs: + - **b** (Tensor) - The output real tensor, with trailing dimensions aligned with `shape`. 
+ + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> from mindspore import ops + >>> from mindflow.cell import IDST + >>> a = ops.rand((2, 32, 512)) + >>> dft_cell = IDST(x.shape[-1:]) + >>> b = dft_cell(a) + >>> print(b.shape) + (2, 32, 512) + """ + def __init__(self, shape, compute_dtype=mstype.float32): + super().__init__() + n = convert_shape(shape) + self.dft_cell = IDCT(n, compute_dtype=compute_dtype) + multiplier = np.ones(n) + multiplier[..., 1::2] *= -1 + self.multiplier = Tensor(multiplier, dtype=compute_dtype) + + def construct(self, a): + ''' perform 1-dimensional iDST on real tensor ''' + return self.dft_cell(self.dft_cell.fliper(a, dims=-1)) * self.multiplier diff --git a/mindscience/models/neural_operator/ffno.py b/mindscience/models/neural_operator/ffno.py index d4ae17e31525da95478ebe6788a3fafbc1bbc9a9..be22763c34b691723a5e3af21d48447f94efc047 100644 --- a/mindscience/models/neural_operator/ffno.py +++ b/mindscience/models/neural_operator/ffno.py @@ -21,7 +21,7 @@ from mindspore.common.initializer import XavierNormal, initializer import mindspore.common.dtype as mstype from .ffno_sp import SpectralConv1d, SpectralConv2d, SpectralConv3d -from ...common.math import get_grid_1d, get_grid_2d, get_grid_3d +from ...core.math import get_grid_1d, get_grid_2d, get_grid_3d from ...utils.check_func import check_param_type @@ -316,8 +316,10 @@ class FFNO(nn.Cell): check_param_type(positional_embedding, "positional_embedding", data_type=bool, exclude_type=str) if ff_weight_norm: - raise ValueError(f"The weight normalization is not supported in feedforward\ - but got value of ff_weight_norm {ff_weight_norm}") + raise ValueError( + f"The weight normalization is not supported in feedforward\ + but got value of ff_weight_norm {ff_weight_norm}") + if r_padding < 0: raise ValueError(f"The right padding value cannot be negative but got value of r_padding {r_padding}") @@ -326,7 +328,8 @@ class FFNO(nn.Cell): self.hidden_channels = hidden_channels self.lifting_channels = lifting_channels self.projection_channels = projection_channels - self.n_modes, self.resolutions = validate_and_expand_dimensions(1, n_modes, resolutions, False) + self.n_modes, self.resolutions = validate_and_expand_dimensions( + 1, n_modes, resolutions, False) self.n_layers = n_layers self.r_padding = r_padding self.data_format = data_format @@ -338,18 +341,26 @@ class FFNO(nn.Cell): self._concat = ops.Concat(axis=-1) self._positional_embedding = self._transpose(len(self.resolutions)) self._padding = self._pad(len(self.resolutions)) - self._lifting = self.lift_channels( - self.in_channels, self.hidden_channels, self.lifting_channels, self.ffno_compute_dtype) + if self.lifting_channels: + self._lifting = nn.SequentialCell([ + nn.Dense(self.in_channels, self.lifting_channels, has_bias=True).to_float(self.ffno_compute_dtype), + nn.Dense(self.lifting_channels, self.hidden_channels, has_bias=True).to_float(self.ffno_compute_dtype)]) + else: + self._lifting = nn.SequentialCell( + nn.Dense(self.in_channels, self.hidden_channels, has_bias=True).to_float(self.ffno_compute_dtype) + ) self.fourier_weight = None if share_weight: param_list = [] for i, n_mode in enumerate(self.n_modes): weight_shape = [hidden_channels, hidden_channels, n_mode] + w_re = Parameter(initializer(XavierNormal(), weight_shape, mstype.float32), name=f'base_w_re_{i}', requires_grad=True) w_im = Parameter(initializer(XavierNormal(), weight_shape, mstype.float32), name=f'base_w_im_{i}', requires_grad=True) + param_list.append(w_re) 
param_list.append(w_im) @@ -374,16 +385,17 @@ class FFNO(nn.Cell): dft_compute_dtype=self.dft_compute_dtype ) for _ in range(self.n_layers)]) - self._projection = self.lift_channels( - self.hidden_channels, self.out_channels, self.projection_channels, self.ffno_compute_dtype) - - def lift_channels(self, in_c, out_c, mid_c=0, compute_dtype=mstype.float32): - if mid_c: - return nn.SequentialCell([ - nn.Dense(in_c, mid_c, has_bias=True).to_float(compute_dtype), - nn.Dense(mid_c, out_c, has_bias=True).to_float(compute_dtype) + if self.projection_channels: + self._projection = nn.SequentialCell([ + nn.Dense(self.hidden_channels, self.projection_channels, has_bias=True).to_float( + self.ffno_compute_dtype), + nn.Dense(self.projection_channels, self.out_channels, has_bias=True).to_float( + self.ffno_compute_dtype) ]) - return nn.SequentialCell(nn.Dense(in_c, out_c, has_bias=True).to_float(compute_dtype)) + else: + self._projection = nn.SequentialCell( + nn.Dense(self.hidden_channels, self.out_channels, has_bias=True).to_float( + self.ffno_compute_dtype)) def construct(self, x: Tensor): """construct""" diff --git a/mindscience/models/neural_operator/ffno_sp.py b/mindscience/models/neural_operator/ffno_sp.py index 78a13a1a1b55ef032f1753b9ad2508995fca1986..b1fa1382fd96061f05209937c8c79976c17a1448 100644 --- a/mindscience/models/neural_operator/ffno_sp.py +++ b/mindscience/models/neural_operator/ffno_sp.py @@ -18,8 +18,8 @@ import mindspore as ms import mindspore.common.dtype as mstype from mindspore import nn, ops, Tensor, Parameter, ParameterTuple, mint from mindspore.common.initializer import XavierNormal, initializer -from ...common.math import get_grid_1d, get_grid_2d, get_grid_3d -from ...sciops import RDFTn, IRDFTn +from ...core.math import get_grid_1d, get_grid_2d, get_grid_3d +from ...core.fourier import RDFTn, IRDFTn class FeedForward(nn.Cell): diff --git a/mindscience/models/neural_operator/fno_sp.py b/mindscience/models/neural_operator/fno_sp.py index 5c287c7179cf9946f930f284d857892b4594ca4f..bb02333507dc94592727b632597b22d661b6b7d5 100644 --- a/mindscience/models/neural_operator/fno_sp.py +++ b/mindscience/models/neural_operator/fno_sp.py @@ -21,7 +21,7 @@ from mindspore import nn, ops, Tensor, Parameter, mint from mindspore.common.initializer import Zero from mindspore.ops import operations as P -from ...sciops import RDFTn, IRDFTn +from ...core.fourier import RDFTn, IRDFTn class SpectralConvDft(nn.Cell): diff --git a/mindscience/pde/README.md b/mindscience/pde/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5d0c2f43bf7337fd6c1f1a8dcd86b6fd27b0a7d7 --- /dev/null +++ b/mindscience/pde/README.md @@ -0,0 +1,422 @@ +## mindflow.pde + + + +### pde 介绍 + +- 文件结构 + +```bash +. 
+├── flow_with_loss.py +├── __init__.py +├── pde_with_loss.py +└── sympy2mindspore + ├── __init__.py + ├── parse_sympy.py + ├── pde_node.py + └── sympy_translation.py +``` + +- pde 模块是 MindSpore Science 框架中用于求解流体力学、静力学等领域中的偏微分方程的科学计算算子库,给出了自定义的数学运算实现(`mindspore function`),并可将 `sympy` 库中的符号计算转换成相应的 `mindspore function`。此外,pde 模块目前也支持在不同的算子神经网络框架下(如 FNO,FFNO,SNO,PDENet 等)给出一些**流体力学和静力学方程的损失函数计算**。pde 模块将形如加法、幂运算、微分等数学运算定义为相应的 Node 类,通过 `sympy_to_mindspore()` 方法为用户提供简洁的**形式化泛函计算接口**。结合 `mindflow` 中的其他模块,例如 `mindflow.cell`,用户能够更加高效的进行微分方程的神经网络求解和处理科学计算任务。 + + + +### sympy2mindspore 介绍 + +- 此子模块包含三个主要文件,其调用关系为: + +``` +parse_sympy.py --> sympy_translation.py --> pde_node.py +``` + + + +- `pde_node.py` 给出了部分数学运算的 mindspore 实现: + + - 定义字典 `MINDSPORE_SYMPY_TRANSLATIONS`,为基础数学符号转变成 mindspore 实现建立桥梁; + + - 对于某些更高级的数学运算,定义对应的 Node 类,类中必须包含 `compute()` 方法对输入进行相应的运算; + + - 共包含以下类 + + | 名称 | 作用 | 初始化输入 | compute() 输入 | compute() 输出 | + | ---------------- | -------------------------------------------------------- | ------------------------------------------------------------ | --------------------------- | ----------------------------------------- | + | `NumberNode` | 可作为这些 Node 类的起始输入,是一种数据结构 | `Tensor` | 空 `dict` | Tensor 的值 | + | `AddNode` | 对所有输入进行求和 | 某些 Node 类实例所构成的 `list` | `dict` 形式的数据 | 在这些实例计算结果的基础上再求总和 | + | `PowNode` | 对所有输入进行幂运算 | 某些 Node 类实例和幂值成对构成的 `list` | `dict` 形式的数据 | 以 list 第 0 维为底, 第 1 维为幂进行计算 | + | `MulNode` | 对所有输入进行求积 | 某些 Node 类实例所构成的 `list` | `dict` 形式的数据 | 在这些实例计算结果的基础上再求乘积 | + | `SymbolNode` | 对输入的数据按照变量的维度进行列分割 | 变量构成的 `list`,及某一变量的下标 | `dict` 形式的数据(Tensor) | Tensor 在该下标下的列 slice | + | `ParamNode` | 对输入的数据按照参数的维度进行列分割 | 参数构成的 list,及某一参数的下标 | `dict` 形式的数据(Tensor) | Tensor 在该下标下的列 slice | + | `NetOutputNode` | 针对多元输出的场景:对输入的数据按照指定的维度进行列分割 | 输出变量构成的 `list`,及某一变量的下标 | `dict` 形式的数据(Tensor) | Tensor 在该下标下的列 slice | + | `MSFunctionNode` | 将数学符号转化为自定义 Node 类结构的算子 | 某些 Node 类实例所构成的 `list` | `dict` 形式的数据 | 该数学符号对应的运算结果 | + | `DerivativeNode` | 对已计算得到的一阶和二阶微分的值进行总结输出 | 微分算子的阶,自变量下标,微分下标(如梯度和 Hessian 的分量) | `dict` 形式的数据 | 对应下标下的微分值 | + + + + - 示例一:`SymbolNode` + + ```python + import numpy as np + from mindspore import Tensor + from mindflow.pde.sympy2mindspore.pde_node import SymbolNode + + # consider three-dimensional variables, the interested index is 1, i.e., 'y' + node = SymbolNode(in_vars=['x', 'y', 'z'], in_var_idx=1) + + # number of variables = 4 + inputs = Tensor(np.random.rand(4, 3).astype(np.float32)) + + # data structured as dict + data = {"inputs": inputs} + result = node.compute(data) + print(result) + + # predicted outputs: + # [[0.71518934] + # [0.4236548 ] + # [0.891773 ] + # [0.79172504]] + ``` + + + + - 示例二:`NumberNode`, `AddNode`, `SymbolNode` + + 测试函数 + $$ + f(x,y) = x+y +3 + $$ + 对不同的输入得到的值。 + + ```python + import numpy as np + from mindspore import Tensor + from mindspore import dtype as mstype + from mindflow.pde.sympy2mindspore.pde_node import NumberNode, AddNode, SymbolNode + + # first input column + x_node = SymbolNode(in_vars=['x', 'y'], in_var_idx=0) + + # second input column + y_node = SymbolNode(in_vars=['x', 'y'], in_var_idx=1) + + # the constant number 3 + const_node = NumberNode(nodes=[Tensor(np.float32(3.0), mstype.float32)]) + + # initialized using a list of Node-instances + add_node = AddNode([x_node, y_node, const_node]) + + # data structured as dict + input_data = Tensor(np.array([[1.0, 2.0], [2.0, 3.0]]).astype(np.float32)) + data = {'inputs': input_data} + result = add_node.compute(data) + print(result) + + # predicted outputs: + #[[ 6.] 
+ # [10.]] + ``` + + + + - 示例三:DerivativeNode + + 考虑二元函数 f 的一阶梯度和二阶 Hessian。假设二者已通过其他途径计算获得,现利用 DerivativeNode 对其进行总结和拆分。目标:输出在 norm 方向下的方向导数,以及 + $$ + \frac{\partial^2f}{\partial x^2}, \frac{\partial^2f}{\partial x \partial y}, \frac{\partial^2f}{\partial y \partial x}, \frac{\partial^2f}{\partial y^2} + $$ + 具体的值。 + + ```python + import numpy as np + import mindspore.numpy as mnp + from mindspore import Tensor + from mindspore import dtype as mstype + from mindflow.pde.sympy2mindspore.pde_node import DerivativeNode + + + in_vars = ["x", "y"] + + # data structured as dict, where "norm" is the normal vector defined on the boundary. + data = { + "jacobian": [Tensor(np.array([1.0, 2.0]), dtype=mstype.float32)], # Example Jacobian + "hessian": [ + [Tensor(np.array([1.0, 2.0]), dtype=mstype.float32), # dxdx and dxdy + Tensor(np.array([2.0, 3.0]), dtype=mstype.float32)], # dydx and dydy + ], # Example Hessian + "norm": Tensor(np.array([0.5, 0.5]), dtype=mstype.float32), # Example norm + } + + # First-order derivative example + first_order_node = DerivativeNode(in_vars, order=1, in_var_idx=0, out_var_idx=0, is_norm=True) + result_first_order = first_order_node.compute(data) + print("First-order derivative result:", result_first_order) + + # Second-order derivative example + dxdx_node = DerivativeNode(in_vars, order=2, in_var_idx=(0, 0), out_var_idx=0) + dxdy_node = DerivativeNode(in_vars, order=2, in_var_idx=(0, 1), out_var_idx=0) + dydx_node = DerivativeNode(in_vars, order=2, in_var_idx=(1, 0), out_var_idx=0) + dydy_node = DerivativeNode(in_vars, order=2, in_var_idx=(1, 1), out_var_idx=0) + result_dxdx = dxdx_node.compute(data) + result_dxdy = dxdy_node.compute(data) + result_dydx = dydx_node.compute(data) + result_dydy = dydy_node.compute(data) + print("dxdx derivative result:", result_dxdx) + print("dxdy derivative result:", result_dxdy) + print("dydx derivative result:", result_dydx) + print("dydy derivative result:", result_dydy) + + # predicted outputs: + # First-order derivative result: 1.5 + # dxdx derivative result: 1.0 + # dxdy derivative result: 2.0 + # dydx derivative result: 2.0 + # dydy derivative result: 3.0 + ``` + + + + +- `sympy_translation.py` 集成了 `pde_node.py` 中规定的基础数学运算和更高级的运算,对于输入的 `sympy` 中的对象(一些运算)手动翻译成 mindspore 函数类(也即 Node 类,可以实现具体的计算);最后对这些算子再进行一层封装、分类和组装,从而实现对给出的一串数学符号进行具体的计算。 + + + +- `parse_sympy.py` 中的核心函数 `sympy_to_mindspore` 通过调用`sympy_translation.py` 以及 `_make_nodes()` 函数,将来自 `sympy` 的数学运算串翻译成 mindspore function,并生成计算图。 + + - 示例:对于以下两种运算生成对应的 mindspore function: + $$ + \begin{align} + & f(x,y) = x+y \\ + & \nabla [u](x,y) \cdot \mathbf{1} = \frac{\partial u}{\partial x} + \frac{\partial u}{\partial y} \\ + \end{align} + $$ + +```python +from mindflow.pde import sympy_to_mindspore +from sympy import symbols, Function, diff + + +x, y = symbols('x, y') +u = Function('u')(x, y) +in_vars = [x, y] +out_vars = [u] +eq1 = x + y +eq2 = diff(u, (x, 1)) + diff(u, (y, 1)) +equations = {"eq1": eq1, "eq2": eq2} +res = sympy_to_mindspore(equations, in_vars, out_vars) + +# predicted outputs: +# eq1: x + y +# Item numbers of current derivative formula nodes: 2 +# eq2: Derivative(u(x, y), x) + Derivative(u(x, y), y) +# Item numbers of current derivative formula nodes: 2 +``` + + + + + +### pde_with_loss.py 介绍 + +- 此模块被应用于神经网络求解单一方程的方法(如 PINNs)中; + +- 此模块定义基类 `PDEWithLoss`,该类规定了一个**完整、可计算且可被训练求解**的偏微分方程应有的函数和数据结构: + + - `pde(self)`返回一个 sympy 数学符号串,该符号串表示了当前 PDE 问题在求解区域内部的方程左端,默认右端项为零; + + - `self.pde_nodes`:对应于 PDE 的 mindspore function; + + - `bc(self)`:返回一个 `sympy` 
数学符号串,该符号串表示了当前 PDE 问题的边界条件表达式左端,默认右端项为零;注意此函数并非 `PDEWithLoss` 强制规定; + + - `self.bc_nodes`:对应于 bc 的 mindspore function;注意此成员并非 `PDEWithLoss` 强制规定; + + - `get_loss(self)`:返回一个 mindspore function,此函数为用户自定义函数,应当在子类中被实现;函数内容即为采用训练法求解当前 PDE 问题时定义的神经网络损失函数; + + - `parse_node(self, formula_nodes, inputs, norm)`:输入 `formula_nodes` 为 mindspore function, 可以是自己类中的 `pde` 或者 `bc`;此函数是 mindspore function 实际发生计算的位置。 + + - 示例:定义一个 `PDEWithLoss` 子类,其中 + + - 方程:二维区域上的二阶椭圆型方程; + $$ + \begin{align} + -\Delta u + u &= f = 4,~x \in \Omega \subset \mathbb{R}^2,\\ + \nabla u \cdot \mathbf{1} &= g = 2,~x\in \partial \Omega. + \end{align} + $$ + + - 损失函数:基于两隐藏层全连接神经网络 $u_{\theta}$ 的 PINNs 损失函数; + $$ + L(\theta) = \int_{\Omega} \left( -\Delta u_{\theta}(x) + u_{\theta}(x) -f(x) \right)^2 \mathrm{d}x + \int_{\partial \Omega} \left( \nabla u_{\theta}(x) \cdot \mathbf{1} - g(x) \right)^2 \mathrm{d} S. + $$ + 注:积分形式的损失函数经过 Monte-Carlo 离散后(数值积分权值均为 1),等价于 MSE 形式。 + + ```python + import numpy as np + from sympy import symbols, Function, diff + from mindspore import nn, ops, Tensor + from mindspore import dtype as mstype + from mindflow.pde import PDEWithLoss, sympy_to_mindspore + + # define a fully-connected neural network with tanh activation + class Net(nn.Cell): + def __init__(self, cin=2, cout=1, hidden=10): + super().__init__() + self.fc1 = nn.Dense(cin, hidden) + self.fc2 = nn.Dense(hidden, hidden) + self.fcout = nn.Dense(hidden, cout) + self.act = ops.Tanh() + + def construct(self, x): + x = self.act(self.fc1(x)) + x = self.act(self.fc2(x)) + x = self.fcout(x) + return x + model = Net() + + # user-defined class to describe Poisson's equation with pure Neumann's boundary condition. + class MyProblem(PDEWithLoss): + def __init__(self, model, loss_fn=nn.MSELoss()): # Take the MSE loss function + self.x, self.y = symbols('x y') + self.u = Function('u')(self.x, self.y) + self.in_vars = [self.x, self.y] + self.out_vars = [self.u] + super(MyProblem, self).__init__(model, in_vars=self.in_vars, out_vars=self.out_vars) + self.loss_fn = loss_fn + self.bc_nodes = sympy_to_mindspore(self.bc(), self.in_vars, self.out_vars) + + # pde's info inside the domain + def pde(self): + my_eq = diff(self.u, (self.x, 2)) + diff(self.u, (self.y, 2)) - self.u + 4.0 + equations = {"my_eq": my_eq} + return equations + + # pde's info on the boundary + def bc(self): + bc_eq = diff(self.u, (self.x, 1)) + diff(self.u, (self.y, 1)) - 2.0 + equations = {"bc_eq": bc_eq} + return equations + + # PINN's loss function + def get_loss(self, pde_data, bc_data): + pde_res = self.parse_node(self.pde_nodes, inputs=pde_data) + pde_loss = self.loss_fn(pde_res[0], Tensor(np.array([0.0]), mstype.float32)) + bc_res = self.parse_node(self.bc_nodes, inputs=bc_data) + bc_loss = self.loss_fn(bc_res[0], Tensor(np.array([0.0]), mstype.float32)) + return pde_loss + bc_loss + + problem = MyProblem(model) + print(problem.pde()) + print(problem.bc()) + + # predicted outputs: + # my_eq: -u(x, y) + Derivative(u(x, y), (x, 2)) + Derivative(u(x, y), (y, 2)) + 4.0 + # Item numbers of current derivative formula nodes: 4 + # bc_eq: Derivative(u(x, y), x) + Derivative(u(x, y), y) - 2.0 + # Item numbers of current derivative formula nodes: 3 + # {'my_eq': -u(x, y) + Derivative(u(x, y), (x, 2)) + Derivative(u(x, y), (y, 2)) + 4.0} + # {'bc_eq': Derivative(u(x, y), x) + Derivative(u(x, y), y) - 2.0} + ``` + + + +- 此模块目前支持的方程如下: + + - 一维有黏性的 Burgers' equation (目前初始条件和边界条件待补充): + $$ + \frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} - \epsilon \frac{\partial^2 u}{\partial 
x^2} = 0.
+    $$
+
+    ```python
+    def pde(self):
+        """
+        Define Burgers 1-D governing equations based on sympy, abstract method.
+
+        Returns:
+            dict, user defined sympy symbolic equations.
+        """
+        burgers_eq = diff(self.u, (self.t, 1)) + self.u * diff(self.u, (self.x, 1)) - \
+            self.mu * diff(self.u, (self.x, 2))
+
+        equations = {"burgers": burgers_eq}
+        return equations
+    ```
+
+  - 二维不可压 Navier-Stokes equation(目前初始条件和边界条件待补充):
+    $$
+    \text{连续性方程:}\quad\quad \frac{\partial u}{\partial x} + \frac{\partial v}{\partial y} = 0,\\
+    x~\text{方向动量守恒:} \frac{\partial u}{\partial t} + u\frac{\partial u}{\partial x} + v\frac{\partial u}{\partial y} = -\frac{1}{\rho} \frac{\partial p}{\partial x} + \nu \left( \frac{\partial^2 u}{\partial x^2}+\frac{\partial^2 u}{\partial y^2} \right), \\
+    y~\text{方向动量守恒:} \frac{\partial v}{\partial t} + u\frac{\partial v}{\partial x} + v\frac{\partial v}{\partial y} = -\frac{1}{\rho} \frac{\partial p}{\partial y} + \nu \left( \frac{\partial^2 v}{\partial x^2}+\frac{\partial^2 v}{\partial y^2} \right).
+    $$
+
+    ```python
+    def pde(self):
+        """
+        Define governing equations based on sympy, abstract method.
+
+        Returns:
+            dict, user defined sympy symbolic equations.
+        """
+
+        # momentum conservation along x
+        momentum_x = self.u.diff(self.t) + self.u * self.u.diff(self.x) + self.v * self.u.diff(self.y) + \
+            self.p.diff(self.x) - self.number * (diff(self.u, (self.x, 2)) + diff(self.u, (self.y, 2)))
+
+        # momentum conservation along y
+        momentum_y = self.v.diff(self.t) + self.u * self.v.diff(self.x) + self.v * self.v.diff(self.y) + \
+            self.p.diff(self.y) - self.number * (diff(self.v, (self.x, 2)) + diff(self.v, (self.y, 2)))
+
+        # continuity equation
+        continuty = self.u.diff(self.x) + self.v.diff(self.y)
+
+        equations = {"momentum_x": momentum_x, "momentum_y": momentum_y, "continuty": continuty}
+        return equations
+    ```
+
+  - 二维 Poisson's equation (目前边界条件待补充):
+    $$
+    -\Delta u = f = 1.
+    $$
+
+    ```python
+    def pde(self):
+        """
+        Define Poisson 2-D governing equations based on sympy, abstract method.
+
+        Returns:
+            dict, user defined sympy symbolic equations.
+        """
+        poisson = diff(self.u, (self.x, 2)) + diff(self.u, (self.y, 2)) + 1.0
+
+        equations = {"poisson": poisson}
+        return equations
+    ```
+
+
+
+### flow_with_loss.py 介绍
+
+- 此模块定义基类 `FlowWithLoss`,该类规定了一个**完整、可计算且可被训练求解**的偏微分方程应有的函数和数据结构:
+
+  - `step(self, inputs)`:返回一个 Tensor 表示所采用模型的预测;
+  - `get_loss(self, inputs, labels)`:根据输入(可以是区域的 sample 点)和标签(可以是方程的源项信息)来构建可被优化的损失函数。
+
+- 此模块目前支持的流体力学场景包括:
+
+  - 稳态流(steady flow),即流场中任意一点的流体属性与时间无关;
+  - 非稳态流(unsteady flow),即流场中至少存在一点的流体属性与时间相关。
+
+- 可接入多种类型的神经网络以及自定义网络,目前该模块主要**被应用于算子学习**方法中,因此没有包含具体的方程定义。
+
+  - 示例:FNO 求解二维不可压 Navier-Stokes 方程 https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/data_driven/navier_stokes/fno2d/train.py
+
+
+
diff --git a/mindscience/pde/README_EN.md b/mindscience/pde/README_EN.md
new file mode 100644
index 0000000000000000000000000000000000000000..5c71dd6641695df9da2f14930b6026dcfc82db8f
--- /dev/null
+++ b/mindscience/pde/README_EN.md
@@ -0,0 +1,423 @@
+## mindflow.pde
+
+
+
+### Introduction to the pde module
+
+- File structure
+
+```bash
+. 
+├── flow_with_loss.py +├── __init__.py +├── pde_with_loss.py +└── sympy2mindspore + ├── __init__.py + ├── parse_sympy.py + ├── pde_node.py + └── sympy_translation.py +``` + +- The **PDE module** is a scientific computing operator library within the **MindSpore Science** framework, designed for solving partial differential equations (PDEs) in fields such as fluid dynamics and statics. It provides customized implementations of mathematical operations (`mindspore function`) and can convert symbolic computations from the `sympy` library into corresponding `mindspore function` code. + + Furthermore, the PDE module currently supports the computation of loss functions for various fluid dynamics and statics equations under different neural operator architectures (e.g., FNO, FFNO, SNO, PDENet). + + The module defines fundamental mathematical operations—such as addition, exponentiation, and differentiation—as corresponding **Node** classes. It offers users a concise, formal functional computation interface through the `sympy_to_mindspore()` method. + + Integrated with other modules in `mindflow`, such as `mindflow.cell`, the PDE module enables users to solve differential equations using neural networks and handle scientific computing tasks with greater efficiency. + + + +### sympy2mindspore + +- The calling relationships between the three main files in this submodule are as follows: + +``` +parse_sympy.py --> sympy_translation.py --> pde_node.py +``` + + + +- `pde_node.py` provides the MindSpore implementation for several mathematical operations. + + - The dictionary `MINDSPORE_SYMPY_TRANSLATIONS` is defined with the purpose of bridging basic mathematical symbols to their corresponding MindSpore implementations. + + - For more advanced mathematical operations, corresponding Node classes must be defined. These classes are required to include a `compute()` method that performs the respective operation on the input data. 
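+  - As a minimal sketch of this `compute()` convention (a hypothetical `ScaleNode` written purely for illustration; it is not part of `pde_node.py`, whose real classes are listed in the table below):
+
+    ```python
+    class ScaleNode:
+        """Hypothetical node: scale the result of one upstream node by a constant factor."""
+        def __init__(self, nodes, factor=2.0):
+            self.nodes = nodes      # list of upstream Node instances
+            self.factor = factor    # illustrative extra parameter
+
+        def compute(self, data):
+            # 'data' is the dict-format input shared by all nodes in the graph
+            return self.factor * self.nodes[0].compute(data)
+    ```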
+  - The following classes are included:
+
+    | Name | Function | Initialization Input | Input of compute() | Output of compute() |
+    | ---------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ----------------------------- | ------------------------------------------------------------ |
+    | `NumberNode` | Initial input for Node classes, serving as a data structure | `Tensor` | Empty `dict` | Tensor values |
+    | `AddNode` | Summation of all inputs | `list` of certain Node class instances | Data in `dict` format | Summation over computed instance results |
+    | `PowNode` | Computation of the power of all inputs | `list` of Node instance-power value pairs | Data in `dict` format | Computation using `list` dimension 0 as base, dimension 1 as exponent |
+    | `MulNode` | Product of all inputs | `list` of certain Node class instances | Data in `dict` format | Product of computed instance results |
+    | `SymbolNode` | Column-wise split of input data by variable dimension | `list` of input variables and the index of the variable of interest | Data (Tensor) in `dict` format | Column slice of tensor at specified index |
+    | `ParamNode` | Column-wise split of input data by parameter dimension | Parameter `list` with specified index | Data (Tensor) in `dict` format | Column slice of tensor at specified index |
+    | `NetOutputNode` | Column-wise split of input data along specified dimensions for multi-output scenarios | Output variable `list` with specified index | Data (Tensor) in `dict` format | Column slice of tensor at specified index |
+    | `MSFunctionNode` | Operator transformation from mathematical symbols to custom Node structures | `list` of selected Node instances | Data in `dict` format | Computation result of the mathematical symbol |
+    | `DerivativeNode` | Summary output of computed first and second-order differential values | Order of differential operator, independent variable index, differentiation index | Data in `dict` format | Derivative value at corresponding index |
+
+
+
+  - Example I: SymbolNode
+
+    ```python
+    import numpy as np
+    from mindspore import Tensor
+    from mindflow.pde.sympy2mindspore.pde_node import SymbolNode
+
+    # consider three-dimensional variables, the interested index is 1, i.e., 'y'
+    node = SymbolNode(in_vars=['x', 'y', 'z'], in_var_idx=1)
+
+    # number of variables = 4
+    inputs = Tensor(np.random.rand(4, 3).astype(np.float32))
+
+    # data structured as dict
+    data = {"inputs": inputs}
+    result = node.compute(data)
+    print(result)
+
+    # predicted outputs:
+    # [[0.71518934]
+    # [0.4236548 ]
+    # [0.891773  ]
+    # [0.79172504]]
+    ```
+
+
+
+  - Example II: NumberNode, AddNode, SymbolNode
+
+    Test the outputs of different inputs for the following function:
+    $$
+    f(x,y) = x+y +3
+    $$
+    ```python
+    import numpy as np
+    from mindspore import Tensor
+    from mindspore import dtype as mstype
+    from mindflow.pde.sympy2mindspore.pde_node import NumberNode, AddNode, SymbolNode
+
+    # first input column
+    x_node = SymbolNode(in_vars=['x', 'y'], in_var_idx=0)
+
+    # second input column
+    y_node = SymbolNode(in_vars=['x', 'y'], in_var_idx=1)
+
+    # the constant number 3
+    const_node = NumberNode(nodes=[Tensor(np.float32(3.0), mstype.float32)])
+
+    # initialized using a list of Node-instances
+    add_node = AddNode([x_node, y_node, const_node])
+
+    # data structured as dict
+    input_data = Tensor(np.array([[1.0, 2.0], [2.0, 3.0]]).astype(np.float32))
+    data = {'inputs': input_data}
+    result = add_node.compute(data)
+    print(result)
+
+    # 
predicted outputs: + #[[ 6.] + # [10.]] + ``` + + + + - Example III: `DerivativeNode` + + Let us consider the gradient and Hessian operator of the function f. Assume that the derivatives have been obtained via a certain process. Now use `DerivativeNode` to compute and split these values. Our goal is to compute/output the directional derivative along the *norm* direction, as well as the hessian + $$ + \frac{\partial^2f}{\partial x^2}, \frac{\partial^2f}{\partial x \partial y}, \frac{\partial^2f}{\partial y \partial x}, \frac{\partial^2f}{\partial y^2}. + $$ + ```python + import numpy as np + import mindspore.numpy as mnp + from mindspore import Tensor + from mindspore import dtype as mstype + from mindflow.pde.sympy2mindspore.pde_node import DerivativeNode + + + in_vars = ["x", "y"] + + # data structured as dict, where "norm" is the normal vector defined on the boundary. + data = { + "jacobian": [Tensor(np.array([1.0, 2.0]), dtype=mstype.float32)], # Example Jacobian + "hessian": [ + [Tensor(np.array([1.0, 2.0]), dtype=mstype.float32), # dxdx and dxdy + Tensor(np.array([2.0, 3.0]), dtype=mstype.float32)], # dydx and dydy + ], # Example Hessian + "norm": Tensor(np.array([0.5, 0.5]), dtype=mstype.float32), # Example norm + } + + # First-order derivative example + first_order_node = DerivativeNode(in_vars, order=1, in_var_idx=0, out_var_idx=0, is_norm=True) + result_first_order = first_order_node.compute(data) + print("First-order derivative result:", result_first_order) + + # Second-order derivative example + dxdx_node = DerivativeNode(in_vars, order=2, in_var_idx=(0, 0), out_var_idx=0) + dxdy_node = DerivativeNode(in_vars, order=2, in_var_idx=(0, 1), out_var_idx=0) + dydx_node = DerivativeNode(in_vars, order=2, in_var_idx=(1, 0), out_var_idx=0) + dydy_node = DerivativeNode(in_vars, order=2, in_var_idx=(1, 1), out_var_idx=0) + result_dxdx = dxdx_node.compute(data) + result_dxdy = dxdy_node.compute(data) + result_dydx = dydx_node.compute(data) + result_dydy = dydy_node.compute(data) + print("dxdx derivative result:", result_dxdx) + print("dxdy derivative result:", result_dxdy) + print("dydx derivative result:", result_dydx) + print("dydy derivative result:", result_dydy) + + # predicted outputs: + # First-order derivative result: 1.5 + # dxdx derivative result: 1.0 + # dxdy derivative result: 2.0 + # dydx derivative result: 2.0 + # dydy derivative result: 3.0 + ``` + + + + +- `sympy_translation.py` integrates the fundamental and advanced mathematical operations defined in `pde_node.py`. It manually translates input objects from SymPy (representing certain operations) into MindSpore function classes (i.e., Node classes capable of performing concrete computations). Finally, these operators are further encapsulated, categorized, and assembled to enable the concrete evaluation of a given sequence of mathematical symbols. + + + +- The core function `sympy_to_mindspore` in `parse_sympy.py` translates sequences of mathematical operations from SymPy into MindSpore functions and generates a computational graph by invoking `sympy_translation.py` and the `_make_nodes()` function. 


    - Example: Generate the corresponding MindSpore functions for the following two operations:
      $$
      \begin{align}
      & f(x,y) = x + y \\
      & \nabla [u](x,y) \cdot \mathbf{1} = \frac{\partial u}{\partial x} + \frac{\partial u}{\partial y} \\
      \end{align}
      $$

```python
from mindflow.pde import sympy_to_mindspore
from sympy import symbols, Function, diff


x, y = symbols('x, y')
u = Function('u')(x, y)
in_vars = [x, y]
out_vars = [u]
eq1 = x + y
eq2 = diff(u, (x, 1)) + diff(u, (y, 1))
equations = {"eq1": eq1, "eq2": eq2}
res = sympy_to_mindspore(equations, in_vars, out_vars)

# predicted outputs:
# eq1: x + y
# Item numbers of current derivative formula nodes: 2
# eq2: Derivative(u(x, y), x) + Derivative(u(x, y), y)
# Item numbers of current derivative formula nodes: 2
```





### pde_with_loss.py

- This module is applied in neural-network methods for solving single equations (such as the PINNs method).

- This module defines the base class `PDEWithLoss`, which specifies the functions and data structures that a **complete, computable, and trainable** partial differential equation should possess.

  - `pde(self)`: Returns a `dict` of SymPy expressions, each representing the left-hand side of the governing equation within the solution domain, with the right-hand side defaulting to zero.

  - `self.pde_nodes`: The corresponding MindSpore function for the PDE.

  - `bc(self)`: Returns a `dict` of SymPy expressions, each representing the left-hand side of a boundary-condition expression of the current PDE problem, with the right-hand side defaulting to zero. Note that this function is not mandatory for `PDEWithLoss`.

  - `self.bc_nodes`: The corresponding MindSpore function for the boundary conditions (bc). Note that this member is not mandatory for `PDEWithLoss`.

  - `get_loss(self)`: Returns a user-defined MindSpore function and should be implemented in a subclass. It defines the neural-network loss function used when solving the current PDE problem with a training-based method.

  - `parse_node(self, formula_nodes, inputs, norm)`: The input `formula_nodes` is a MindSpore function, typically `self.pde_nodes` or `self.bc_nodes`. This method is where the actual evaluation of the MindSpore function on `inputs` takes place.

  - Example: Define a subclass of `PDEWithLoss`, where

    - The equation is a second-order elliptic equation on a two-dimensional domain with a pure Neumann boundary condition.
      $$
      \begin{align}
      -\Delta u + u &= f = 4,~x \in \Omega \subset \mathbb{R}^2,\\
      \nabla u \cdot \mathbf{1} &= g = 2,~x\in \partial \Omega.
      \end{align}
      $$

    - The loss function is built from a fully-connected network $u_{\theta}$ with two hidden layers and the PINNs formulation.
      $$
      L(\theta) = \int_{\Omega} \left( -\Delta u_{\theta}(x) + u_{\theta}(x) - f(x) \right)^2 \mathrm{d}x + \int_{\partial \Omega} \left( \nabla u_{\theta}(x) \cdot \mathbf{1} - g(x) \right)^2 \mathrm{d} S.
      $$
      Note that the integral-type loss reduces to the MSE loss (up to a constant factor) when it is discretized with a Monte-Carlo quadrature rule whose integration weights are all set to 1.


    ```python
    import numpy as np
    from sympy import symbols, Function, diff
    from mindspore import nn, ops, Tensor
    from mindspore import dtype as mstype
    from mindflow.pde import PDEWithLoss, sympy_to_mindspore

    # define a fully-connected neural network with tanh activation
    class Net(nn.Cell):
        def __init__(self, cin=2, cout=1, hidden=10):
            super().__init__()
            self.fc1 = nn.Dense(cin, hidden)
            self.fc2 = nn.Dense(hidden, hidden)
            self.fcout = nn.Dense(hidden, cout)
            self.act = ops.Tanh()

        def construct(self, x):
            x = self.act(self.fc1(x))
            x = self.act(self.fc2(x))
            x = self.fcout(x)
            return x
    model = Net()

    # user-defined class describing the elliptic equation above with a pure Neumann boundary condition
    class MyProblem(PDEWithLoss):
        def __init__(self, model, loss_fn=nn.MSELoss()):  # use the MSE loss function
            self.x, self.y = symbols('x y')
            self.u = Function('u')(self.x, self.y)
            self.in_vars = [self.x, self.y]
            self.out_vars = [self.u]
            super(MyProblem, self).__init__(model, in_vars=self.in_vars, out_vars=self.out_vars)
            self.loss_fn = loss_fn
            self.bc_nodes = sympy_to_mindspore(self.bc(), self.in_vars, self.out_vars)

        # pde's info inside the domain
        def pde(self):
            my_eq = diff(self.u, (self.x, 2)) + diff(self.u, (self.y, 2)) - self.u + 4.0
            equations = {"my_eq": my_eq}
            return equations

        # pde's info on the boundary
        def bc(self):
            bc_eq = diff(self.u, (self.x, 1)) + diff(self.u, (self.y, 1)) - 2.0
            equations = {"bc_eq": bc_eq}
            return equations

        # PINNs loss function
        def get_loss(self, pde_data, bc_data):
            pde_res = self.parse_node(self.pde_nodes, inputs=pde_data)
            pde_loss = self.loss_fn(pde_res[0], Tensor(np.array([0.0]), mstype.float32))
            bc_res = self.parse_node(self.bc_nodes, inputs=bc_data)
            bc_loss = self.loss_fn(bc_res[0], Tensor(np.array([0.0]), mstype.float32))
            return pde_loss + bc_loss

    problem = MyProblem(model)
    print(problem.pde())
    print(problem.bc())

    # predicted outputs:
    # my_eq: -u(x, y) + Derivative(u(x, y), (x, 2)) + Derivative(u(x, y), (y, 2)) + 4.0
    # Item numbers of current derivative formula nodes: 4
    # bc_eq: Derivative(u(x, y), x) + Derivative(u(x, y), y) - 2.0
    # Item numbers of current derivative formula nodes: 3
    # {'my_eq': -u(x, y) + Derivative(u(x, y), (x, 2)) + Derivative(u(x, y), (y, 2)) + 4.0}
    # {'bc_eq': Derivative(u(x, y), x) + Derivative(u(x, y), y) - 2.0}
    ```



- The equations currently supported by this module are as follows:

  - One-dimensional Burgers' equation with artificial viscosity $\epsilon$ (denoted `self.mu` in the code; initial and boundary conditions are left for the user to specify):
    $$
    \frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} - \epsilon \frac{\partial^2 u}{\partial x^2} = 0.
    $$

    ```python
    def pde(self):
        """
        Define Burgers 1-D governing equations based on sympy, abstract method.

        Returns:
            dict, user defined sympy symbolic equations. 
        """
        burgers_eq = diff(self.u, (self.t, 1)) + self.u * diff(self.u, (self.x, 1)) - \
            self.mu * diff(self.u, (self.x, 2))

        equations = {"burgers": burgers_eq}
        return equations
    ```



  - Two-dimensional incompressible Navier-Stokes equations (initial and boundary conditions are left for the user to specify; the implementation assumes a constant density $\rho = 1$ and denotes the kinematic viscosity $\nu$ by `self.number`):
    $$
    \text{continuity}\quad\quad \frac{\partial u}{\partial x} + \frac{\partial v}{\partial y} = 0,\\
    \text{momentum}~x \quad\quad \frac{\partial u}{\partial t} + u\frac{\partial u}{\partial x} + v\frac{\partial u}{\partial y} = -\frac{1}{\rho} \frac{\partial p}{\partial x} + \nu \left( \frac{\partial^2 u}{\partial x^2}+\frac{\partial^2 u}{\partial y^2} \right), \\
    \text{momentum}~y \quad\quad \frac{\partial v}{\partial t} + u\frac{\partial v}{\partial x} + v\frac{\partial v}{\partial y} = -\frac{1}{\rho} \frac{\partial p}{\partial y} + \nu \left( \frac{\partial^2 v}{\partial x^2}+\frac{\partial^2 v}{\partial y^2} \right).
    $$


    ```python
    def pde(self):
        """
        Define governing equations based on sympy, abstract method.

        Returns:
            dict, user defined sympy symbolic equations.
        """

        # momentum conservation along x
        momentum_x = self.u.diff(self.t) + self.u * self.u.diff(self.x) + self.v * self.u.diff(self.y) + \
            self.p.diff(self.x) - self.number * (diff(self.u, (self.x, 2)) + diff(self.u, (self.y, 2)))

        # momentum conservation along y
        momentum_y = self.v.diff(self.t) + self.u * self.v.diff(self.x) + self.v * self.v.diff(self.y) + \
            self.p.diff(self.y) - self.number * (diff(self.v, (self.x, 2)) + diff(self.v, (self.y, 2)))

        # continuity equation
        continuty = self.u.diff(self.x) + self.v.diff(self.y)

        equations = {"momentum_x": momentum_x, "momentum_y": momentum_y, "continuty": continuty}
        return equations
    ```



  - Two-dimensional Poisson's equation (boundary conditions are left for the user to specify):
    $$
    -\Delta u = f = 1.
    $$

    ```python
    def pde(self):
        """
        Define Poisson 2-D governing equations based on sympy, abstract method.

        Returns:
            dict, user defined sympy symbolic equations.
        """
        poisson = diff(self.u, (self.x, 2)) + diff(self.u, (self.y, 2)) + 1.0

        equations = {"poisson": poisson}
        return equations
    ```




### flow_with_loss.py

- This module defines the base class `FlowWithLoss`, which specifies the functions and data structures that a **complete, computable, and trainable** data-driven flow problem should possess:

  - `step(self, inputs)`: Returns a Tensor representing the prediction of the adopted model.

  - `get_loss(self, inputs, labels)`: Constructs an optimizable loss function from the inputs (e.g., sample points or input fields) and the labels (e.g., reference flow fields or source terms of the equation).

- The fluid dynamics scenarios currently supported by this module include:

  - **Steady flow**: the fluid properties at every point in the flow field are independent of time.

  - **Unsteady flow**: at least one point in the flow field has fluid properties that vary with time.

- The module is compatible with various types of neural networks, including custom networks. It is currently **applied mainly in operator-learning methods** and does not include specific equation definitions; a minimal usage sketch is given below.
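
  - Example: a minimal sketch of how a `FlowWithLoss`-style problem can be driven. This is an illustrative sketch only: it assumes that a subclass such as `SteadyFlowWithLoss` is exported from `mindflow.pde`, that it accepts a backbone model and a loss function, and that it exposes the `get_loss(inputs, labels)` interface described above; check the installed MindFlow version for the exact signature.

    ```python
    import numpy as np
    from mindspore import nn, Tensor
    # assumed export; follows the FlowWithLoss interface described above
    from mindflow.pde import SteadyFlowWithLoss

    # a small surrogate network mapping 2-D input points to a scalar field
    model = nn.SequentialCell(nn.Dense(2, 64), nn.Tanh(), nn.Dense(64, 1))

    # wrap the backbone model together with a loss function
    problem = SteadyFlowWithLoss(model, loss_fn=nn.MSELoss())

    # inputs: sample points / input fields; labels: reference data with matching batch size
    inputs = Tensor(np.random.rand(16, 2).astype(np.float32))
    labels = Tensor(np.random.rand(16, 1).astype(np.float32))

    # get_loss builds the data-driven training loss from predictions and labels
    loss = problem.get_loss(inputs, labels)
    print(loss)
    ```

    In a real training script this loss would be wrapped by `value_and_grad` and an optimizer, as in the FNO example referenced next.
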
+ - Example: Using FNO to solve the 2D incompressible Navier-Stokes equations: + https://gitee.com/mindspore/mindscience/blob/master/MindFlow/applications/data_driven/navier_stokes/fno2d/train.py + diff --git a/tests/st/mindchemistry/cell/test_orb/base.py b/tests/st/mindchemistry/cell/test_orb/base.py new file mode 100644 index 0000000000000000000000000000000000000000..12ebf5a21d0b2a32f60d24e838446b03af9276ca --- /dev/null +++ b/tests/st/mindchemistry/cell/test_orb/base.py @@ -0,0 +1,119 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Base data class.""" + +from typing import Dict, Mapping, NamedTuple, Optional, Union + +from mindspore import Tensor + + +Metric = Union[Tensor, int, float] +TensorDict = Mapping[str, Optional[Tensor]] + + +class ModelOutput(NamedTuple): + """A model's output.""" + + loss: Tensor + log: Mapping[str, Metric] + + +class AtomGraphs(NamedTuple): + """A class representing the input to a model for a graph. + + Args: + senders (ms.Tensor): The integer source nodes for each edge. + receivers (ms.Tensor): The integer destination nodes for each edge. + n_node (ms.Tensor): A (batch_size, ) shaped tensor containing the number of nodes per graph. + n_edge (ms.Tensor): A (batch_size, ) shaped tensor containing the number of edges per graph. + node_features (Dict[str, ms.Tensor]): A dictionary containing node feature tensors. + It will always contain "atomic_numbers" and "positions" keys, representing the + atomic numbers of each node, and the 3d cartesian positions of them respectively. + edge_features (Dict[str, ms.Tensor]): A dictionary containing edge feature tensors. + system_features (Optional[TensorDict]): An optional dictionary containing system-level features. + node_targets (Optional[Dict[ms.Tensor]]): An optional dict of tensors containing targets + for individual nodes. This tensor is commonly expected to have shape (num_nodes, *). + edge_target (Optional[ms.Tensor]): An optional tensor containing targets for individual edges. + This tensor is commonly expected to have (num_edges, *). + system_targets (Optional[Dict[ms.Tensor]]): An optional dict of tensors containing targets for the + entire system. system_id (Optional[ms.Tensor]): An optional tensor containing the ID of the system. + fix_atoms (Optional[ms.Tensor]): An optional tensor containing information on fixed atoms in the system. 
+ """ + + senders: Tensor + receivers: Tensor + n_node: Tensor + n_edge: Tensor + node_features: Dict[str, Tensor] + edge_features: Dict[str, Tensor] + system_features: Dict[str, Tensor] + node_targets: Optional[Dict[str, Tensor]] = None + edge_targets: Optional[Dict[str, Tensor]] = None + system_targets: Optional[Dict[str, Tensor]] = None + system_id: Optional[Tensor] = None + fix_atoms: Optional[Tensor] = None + tags: Optional[Tensor] = None + radius: Optional[float] = None + max_num_neighbors: Optional[int] = None + + @property + def positions(self): + """Get positions of atoms.""" + return self.node_features["positions"] + + @positions.setter + def positions(self, val: Tensor): + self.node_features["positions"] = val + + @property + def atomic_numbers(self): + """Get integer atomic numbers.""" + return self.node_features["atomic_numbers"] + + @atomic_numbers.setter + def atomic_numbers(self, val: Tensor): + self.node_features["atomic_numbers"] = val + + @property + def cell(self): + """Get unit cells.""" + assert self.system_features + return self.system_features.get("cell") + + @cell.setter + def cell(self, val: Tensor): + assert self.system_features + self.system_features["cell"] = val + + def clone(self): + """Clone the AtomGraphs object.""" + return AtomGraphs( + senders=self.senders.clone(), + receivers=self.receivers.clone(), + n_node=self.n_node.clone(), + n_edge=self.n_edge.clone(), + node_features={k: v.clone() for k, v in self.node_features.items()}, + edge_features={k: v.clone() for k, v in self.edge_features.items()}, + system_features={k: v.clone() for k, v in self.system_features.items()}, + node_targets={k: v.clone() for k, v in (self.node_targets or {}).items()}, + edge_targets=self.edge_targets.clone() if self.edge_targets is not None else None, + system_targets={k: v.clone() for k, v in (self.system_targets or {}).items()}, + system_id=self.system_id.clone() if self.system_id is not None else None, + fix_atoms=self.fix_atoms.clone() if self.fix_atoms is not None else None, + tags=self.tags.clone() if self.tags is not None else None, + radius=self.radius, + max_num_neighbors=self.max_num_neighbors + ) diff --git a/tests/st/mindchemistry/cell/test_orb/test_orb.py b/tests/st/mindchemistry/cell/test_orb/test_orb.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb6534088433e267d87585eda9781da29755aac --- /dev/null +++ b/tests/st/mindchemistry/cell/test_orb/test_orb.py @@ -0,0 +1,451 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""test mindchemistry ORB""" + +import os +import sys +from typing import Optional +import pickle + +import requests +import pytest +import numpy as np +import mindspore +from mindspore import nn, Tensor, load_checkpoint, load_param_into_net + +from mindchemistry.cell import ( + AttentionInteractionNetwork, + MoleculeGNS, + NodeHead, + GraphHead, + EnergyHead, + Orb, +) +import base +from utils import numpy_to_tensor, tensor_to_numpy, is_equal + +# pylint: disable=C0413 +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")) +sys.path.append(PROJECT_ROOT) +from common.cell import compare_output + + +def load_graph_data(pkl_path: str): + """Load graph data from pickle file. + Args: + pkl_path: Path to the pickle file + Returns: + tuple: (input_graph_ms, output_graph_np) + """ + with open(pkl_path, "rb") as f: + loaded = pickle.load(f) + + input_graph_np = loaded["input_graph"] + output_graph_np = loaded["output_graph"] + + input_graph_ms = base.AtomGraphs( + *[numpy_to_tensor(getattr(input_graph_np, field)) + for field in input_graph_np._fields] + ) + + return input_graph_ms, output_graph_np + + +def get_gns( + latent_dim: int = 256, + mlp_hidden_dim: int = 512, + num_message_passing_steps: int = 15, + num_edge_in_features: int = 23, + distance_cutoff: bool = True, + attention_gate: str = "sigmoid", +) -> MoleculeGNS: + """Define the base pretrained model architecture. + Args: + latent_dim: The latent dimension of the model. + mlp_hidden_dim: The hidden dimension of the MLP layers. + num_message_passing_steps: The number of message passing steps. + num_edge_in_features: The number of edge input features. + distance_cutoff: Whether to use distance cutoff in the interaction. + attention_gate: The type of attention gate to use. + Returns: + MoleculeGNS: The MoleculeGNS model instance. + """ + return MoleculeGNS( + num_node_in_features=256, + num_node_out_features=3, + num_edge_in_features=num_edge_in_features, + latent_dim=latent_dim, + interactions="simple_attention", + interaction_params={ + "distance_cutoff": distance_cutoff, + "polynomial_order": 4, + "cutoff_rmax": 6, + "attention_gate": attention_gate, + }, + num_message_passing_steps=num_message_passing_steps, + num_mlp_layers=2, + mlp_hidden_dim=mlp_hidden_dim, + use_embedding=True, + node_feature_names=["feat"], + edge_feature_names=["feat"], + ) + + +def load_model_for_inference(model: nn.Cell, weights_path: str) -> nn.Cell: + """Load a pretrained model in inference mode. + Args: + model: The model to load the weights into. + weights_path: Path to the checkpoint file. + Returns: + nn.Cell: The model with loaded weights. + Raises: + FileNotFoundError: If the checkpoint file does not exist. + ValueError: If the checkpoint file has more parameters than the model. + """ + if not os.path.exists(weights_path): + raise FileNotFoundError(f"Checkpoint file {weights_path} not found.") + param_dict = load_checkpoint(weights_path) + + try: + load_param_into_net(model, param_dict) + except ValueError: + print("Warning: The checkpoint file has more parameters than the model. 
\ + This may be due to a mismatch in the model architecture or version.") + params = [] + for key in param_dict: + params.append(param_dict[key]) + for parameters in model.trainable_params(): + param_ckpt = params.pop(0) + assert parameters.shape == param_ckpt.shape, f"Shape mismatch: {parameters.name}" + param_ckpt = param_ckpt.reshape(parameters.shape) + parameters.set_data(param_ckpt) + + model.set_train(False) + return model + +def orb_v2(weights_path: Optional[str]) -> nn.Cell: + """Load ORB v2. + Args: + weights_path: Path to the checkpoint file. + Returns: + Orb GraphRegressor: The ORB v2 model instance. + """ + gns = get_gns() + + model = Orb( + graph_head=EnergyHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=1, + node_aggregation="mean", + reference_energy_name="vasp-shifted", + train_reference=True, + predict_atom_avg=True, + ), + node_head=NodeHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=3, + remove_mean=True, + ), + stress_head=GraphHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=6, + compute_stress=True, + ), + model=gns, + ) + model = load_model_for_inference(model, weights_path) + return model + + +def orb_mptraj_only_v2( + weights_path: Optional[str] = None, +): + """Load ORB MPTraj Only v2.""" + + return orb_v2(weights_path,) + + +def download_file(url, local_filename): + """Download a file from a URL to a local path.""" + response = requests.get(url, timeout=30) + if response.status_code == 200: + with open(local_filename, 'wb') as f: + f.write(response.content) + else: + print(f"Failed to download file. HTTP Status Code: {response.status_code}") + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_attn(): + """ + Feature: Test AttentionInteractionNetwork in platform ascend. + Description: The forward output should has expected shape and accuracy. + Expectation: Success or throw AssertionError. + """ + mindspore.set_context(mode=mindspore.PYNATIVE_MODE) + # prepare data + download_file( + 'https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/test/attn_input_output.pkl', + 'attn_input_output.pkl' + ) + download_file( + 'https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/orb_ckpts/attn_net.ckpt', + 'attn_net.ckpt' + ) + input_graph_ms, output_graph_np = load_graph_data('attn_input_output.pkl') + + attn_net = AttentionInteractionNetwork( + num_node_in=256, + num_node_out=256, + num_edge_in=256, + num_edge_out=256, + num_mlp_layers=2, + mlp_hidden_dim=512, + ) + + # load checkpoint + param_dict = load_checkpoint('attn_net.ckpt') + load_param_into_net(attn_net, param_dict) + + # inference + edges, nodes = attn_net( + input_graph_ms.edge_features, + input_graph_ms.node_features, + input_graph_ms.senders, + input_graph_ms.receivers, + ) + + # Validate results + out_node_feats = tensor_to_numpy(nodes["feat"]) + out_edge_feats = tensor_to_numpy(edges["feat"]) + out_node_feats_np = output_graph_np.node_features["feat"] + out_edge_feats_np = output_graph_np.edge_features["feat"] + + flag_node = is_equal(out_node_feats, out_node_feats_np) + flag_edge = is_equal(out_edge_feats, out_edge_feats_np) + assert flag_node, "Failed! Node features mismatch in attention network" + assert flag_edge, "Failed! 
Edge features mismatch in attention network" + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_gns(): + """ + Feature: Test MoleculeGNS network in platform ascend. + Description: The forward output should has expected shape and accuracy. + Expectation: Success or throw AssertionError. + """ + mindspore.set_context(mode=mindspore.PYNATIVE_MODE) + download_file( + 'https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/test/gns_input_output.pkl', + 'gns_input_output.pkl' + ) + download_file( + 'https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/orb_ckpts/gns_net.ckpt', + 'gns_net.ckpt' + ) + input_graph_ms, output_graph_np = load_graph_data('gns_input_output.pkl') + + # load gns model and checkpoint + gns_model = get_gns() + + # load checkpoint + param_dict = load_checkpoint('gns_net.ckpt') + load_param_into_net(gns_model, param_dict) + + edges, nodes = gns_model( + input_graph_ms.edge_features, + input_graph_ms.node_features, + input_graph_ms.senders, + input_graph_ms.receivers, + ) + + out_node_feats = tensor_to_numpy(nodes["feat"]) + out_edge_feats = tensor_to_numpy(edges["feat"]) + out_node_feats_np = output_graph_np.node_features["feat"] + out_edge_feats_np = output_graph_np.edge_features["feat"] + + flag_node = is_equal(out_node_feats, out_node_feats_np) + flag_edge = is_equal(out_edge_feats, out_edge_feats_np) + assert flag_node, "Failed! Node features mismatch in MoleculeGNS network" + assert flag_edge, "Failed! Edge features mismatch in MoleculeGNS network" + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_node_head(): + """ + Feature: Test NodeHead in platform ascend. + Description: The forward output should has expected shape and accuracy. + Expectation: Success or throw AssertionError. + """ + mindspore.set_context(mode=mindspore.PYNATIVE_MODE) + node_head = NodeHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=3, + remove_mean=True, + ) + + n_atoms = 4 + n_node = Tensor([n_atoms], mindspore.int32) + atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + for i, num in enumerate(atomic_numbers.asnumpy()): + atomic_numbers_embedding_np[i, num - 1] = 1.0 + + node_features = { + "atomic_numbers": atomic_numbers, + "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + } + + output = node_head(node_features, n_node) + assert output['node_pred'].shape == (4, 3), \ + f"Expected node_pred shape (4, 3), but got {output['node_pred'].shape}" + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_graph_head(): + """ + Feature: Test GraphHead in platform ascend. + Description: The forward output should has expected shape and accuracy. + Expectation: Success or throw AssertionError. 
+ """ + mindspore.set_context(mode=mindspore.PYNATIVE_MODE) + graph_head = GraphHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=6, + compute_stress=True, + ) + + n_atoms = 4 + n_node = Tensor([n_atoms], mindspore.int32) + atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + for i, num in enumerate(atomic_numbers.asnumpy()): + atomic_numbers_embedding_np[i, num - 1] = 1.0 + + node_features = { + "atomic_numbers": atomic_numbers, + "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + } + + output = graph_head(node_features, n_node) + assert output['stress_pred'].shape == (1, 6), \ + f"Expected stress_pred shape (1, 6), but got {output['stress_pred'].shape}" + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_energy_head(): + """ + Feature: Test EnergyHead in platform ascend. + Description: The forward output should has expected shape and accuracy. + Expectation: Success or throw AssertionError. + """ + mindspore.set_context(mode=mindspore.PYNATIVE_MODE) + energy_head = EnergyHead( + latent_dim=256, + num_mlp_layers=1, + mlp_hidden_dim=256, + target_property_dim=1, + node_aggregation="mean", + reference_energy_name="vasp-shifted", + train_reference=True, + predict_atom_avg=True, + ) + + n_atoms = 4 + n_node = Tensor([n_atoms], mindspore.int32) + atomic_numbers = Tensor(np.random.randint(1, 119, size=(n_atoms,), dtype=np.int32)) + atomic_numbers_embedding_np = np.zeros((n_atoms, 118), dtype=np.float32) + for i, num in enumerate(atomic_numbers.asnumpy()): + atomic_numbers_embedding_np[i, num - 1] = 1.0 + + node_features = { + "atomic_numbers": atomic_numbers, + "atomic_numbers_embedding": Tensor(atomic_numbers_embedding_np), + "positions": Tensor(np.random.randn(n_atoms, 3).astype(np.float32)), + "feat": Tensor(np.random.randn(n_atoms, 256).astype(np.float32)) + } + + output = energy_head(node_features, n_node) + assert output['graph_pred'].shape == (1, 1), \ + f"Expected graph_pred shape {(1, 1)}, but got {output['graph_pred'].shape}" + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_inference(): + """ + Feature: Test Orb network in platform ascend. + Description: The forward output should has expected shape and accuracy. + Expectation: Success or throw AssertionError. 
+ """ + mindspore.set_context(mode=mindspore.PYNATIVE_MODE) + download_file( + 'https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/test/orb_input_output.pkl', + 'orb_input_output.pkl' + ) + download_file( + 'https://download-mindspore.osinfra.cn/mindscience/mindchemistry/orb/orb_ckpts/orb-mptraj-only-v2.ckpt', + 'orb-mptraj-only-v2.ckpt' + ) + reference_path = 'orb_input_output.pkl' + with open(reference_path, "rb") as f: + loaded = pickle.load(f) + + atom_graph_ms = loaded["input_graph"] + output_pt = loaded["output"] + + regressor = orb_mptraj_only_v2(weights_path='orb-mptraj-only-v2.ckpt') + regressor.set_train(False) + + out_ms = regressor.predict( + atom_graph_ms.edge_features, + atom_graph_ms.node_features, + atom_graph_ms.senders, + atom_graph_ms.receivers, + atom_graph_ms.n_node, + atom_graph_ms.atomic_numbers, + ) + + out_ms = {k: tensor_to_numpy(v) for k, v in out_ms.items()} + + for k in out_ms: + flag = compare_output(out_ms[k], output_pt[k]) + assert flag, f"Failed! Orb network inference output {k} mismatch" diff --git a/tests/st/mindchemistry/cell/test_orb/utils.py b/tests/st/mindchemistry/cell/test_orb/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ad100964ae2e4addcd18be6c699012889445031a --- /dev/null +++ b/tests/st/mindchemistry/cell/test_orb/utils.py @@ -0,0 +1,105 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utils""" +import os +import sys +from typing import Any + +import numpy as np +from mindspore import Tensor + +# pylint: disable=C0413 +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")) +sys.path.append(PROJECT_ROOT) +from common.cell import compare_output, FP32_ATOL, FP32_RTOL + + +def tensor_to_numpy(data: Any) -> Any: + """Convert MindSpore Tensors to NumPy arrays recursively. + This function traverses the input data structure and converts all MindSpore Tensors + to NumPy arrays, while leaving other data types unchanged. + Args: + data (Any): Input data which can be a MindSpore Tensor, dict, list, tuple, or other types. + Returns: + Any: Data structure with MindSpore Tensors converted to NumPy arrays. + """ + if isinstance(data, Tensor): + return data.numpy() + if isinstance(data, dict): + return {k: tensor_to_numpy(v) for k, v in data.items()} + if isinstance(data, (list, tuple)): + return type(data)(tensor_to_numpy(v) for v in data) + return data + + +def numpy_to_tensor(data: Any) -> Any: + """Convert NumPy arrays to MindSpore Tensors recursively. + This function traverses the input data structure and converts all NumPy arrays + to MindSpore Tensors, while leaving other data types unchanged. + Args: + data (Any): Input data which can be a NumPy array, dict, list, tuple, or other types. + Returns: + Any: Data structure with NumPy arrays converted to MindSpore Tensors. 
+ """ + if isinstance(data, np.ndarray): + return Tensor(data) + if isinstance(data, dict): + return {k: numpy_to_tensor(v) for k, v in data.items()} + if isinstance(data, (list, tuple)): + return type(data)(numpy_to_tensor(v) for v in data) + return data + + +def is_equal(a: Any, b: Any) -> bool: + """Compare two objects for equality with special handling for different types. + + This function performs a deep comparison between two objects, supporting: + - NumPy arrays (using tolerance-based comparison) + - Dictionaries (recursive comparison of values) + - Lists and tuples (element-wise comparison) + - NamedTuples (field-wise comparison) + - Other types (using standard equality comparison) + + Args: + a (Any): First object to compare + b (Any): Second object to compare + + Returns: + bool: True if objects are considered equal, False otherwise + + Examples: + >>> is_equal(np.array([1.0]), np.array([1.0])) + True + >>> is_equal({'a': 1, 'b': 2}, {'a': 1, 'b': 2}) + True + >>> is_equal([1, 2, 3], [1, 2, 3]) + True + """ + if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + return compare_output(a, b, FP32_ATOL, FP32_RTOL) + if isinstance(a, dict) and isinstance(b, dict): + if a.keys() != b.keys(): + return False + return all(is_equal(a[k], b[k]) for k in a) + if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)): + if len(a) != len(b): + return False + return all(is_equal(x, y) for x, y in zip(a, b)) + if hasattr(a, "_fields") and hasattr(b, "_fields"): + if a._fields != b._fields: + return False + return all(is_equal(getattr(a, f), getattr(b, f)) for f in a._fields) + return a == b diff --git a/tests/st/mindflow/cell/attention/test_attention.py b/tests/st/mindflow/cell/attention/test_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..7db0b58bfd6d8dac21e6b07b5d681f494c698eca --- /dev/null +++ b/tests/st/mindflow/cell/attention/test_attention.py @@ -0,0 +1,332 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""attention testcase""" +import os +import sys +import pytest + +import numpy as np +from mindspore import Tensor, ops, load_checkpoint, load_param_into_net, jit_class, context +from mindspore import dtype as mstype + +from mindflow.cell import Attention, MultiHeadAttention, TransformerBlock, DropPath, ViT +from mindflow.core import RelativeRMSELoss + +PROJECT_ROOT = os.path.abspath(os.path.join( + os.path.dirname(__file__), "../../../")) +sys.path.append(PROJECT_ROOT) + +from common.cell import compare_output, validate_checkpoint, validate_model_infer, validate_output_dtype +from common.cell import FP32_RTOL, FP32_ATOL, FP16_RTOL, FP16_ATOL + +BATCH_SIZE, NUM_HEADS, SEQ_LEN, IN_CHANNELS = 2, 4, 15, 64 + + +def load_inputs(): + x = Tensor(np.load('input.npy').astype(np.float32)) + mask = Tensor(np.load('mask.npy').astype(np.int32)) + return x, mask + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('compute_dtype', [mstype.float16, mstype.float32]) +def test_attention_qkv(mode, compute_dtype): + """ + Feature: attention + Description: test qkv dtype and shape + Expectation: success + """ + context.set_context(mode=mode) + net = Attention(IN_CHANNELS, NUM_HEADS, compute_dtype=compute_dtype) + x = ops.randn((BATCH_SIZE, SEQ_LEN, IN_CHANNELS)) + qkv = net.get_qkv(x) + for tensor in qkv: + assert tensor.dtype == compute_dtype + assert tensor.shape == (BATCH_SIZE, NUM_HEADS, SEQ_LEN, IN_CHANNELS//NUM_HEADS) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('fa_dtype', [mstype.float16, mstype.bfloat16]) +def test_flash_attn(mode, fa_dtype): + """ + Feature: FlashAttn + Description: test forward result + Expectation: success + """ + context.set_context(mode=mode) + in_shape = (BATCH_SIZE, NUM_HEADS, SEQ_LEN, IN_CHANNELS//NUM_HEADS) + query, key, value = ops.randn(in_shape), ops.randn(in_shape), ops.randn(in_shape) + mask = ops.randint(0, 2, (SEQ_LEN, SEQ_LEN)) + net = MultiHeadAttention(IN_CHANNELS, NUM_HEADS, enable_flash_attn=True, fa_dtype=fa_dtype) + output = net.attn(query, key, value, mask) + assert output.dtype == fa_dtype + assert output.shape == in_shape + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('fa_dtype', [mstype.float16, mstype.bfloat16]) +def test_multihead_fa(mode, fa_dtype): + """ + Feature: FlashAttention + Description: test forward result + Expectation: success + """ + context.set_context(mode=mode) + net = MultiHeadAttention(IN_CHANNELS, NUM_HEADS, enable_flash_attn=True, fa_dtype=fa_dtype) + in_shape = (BATCH_SIZE, SEQ_LEN, IN_CHANNELS) + x = ops.randn(in_shape) + mask = ops.randint(0, 2, (BATCH_SIZE, 1, SEQ_LEN, SEQ_LEN)) + output = net(x, mask) + assert output.dtype == mstype.float32 + assert output.shape == in_shape + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('fa_dtype', [mstype.float16, mstype.bfloat16]) +def test_fa_forward(mode, fa_dtype): + """ + Feature: FlashAttention + 
Description: test FlashAttention forward result + Expectation: success + """ + context.set_context(mode=mode) + net = MultiHeadAttention(IN_CHANNELS, NUM_HEADS, enable_flash_attn=False) + fa_net = MultiHeadAttention(IN_CHANNELS, NUM_HEADS, enable_flash_attn=True, fa_dtype=fa_dtype) + batch_size, seq_len = 256, 512 + in_shape = (batch_size, seq_len, IN_CHANNELS) + x = ops.randn(in_shape) + mask = ops.randint(0, 2, (batch_size, 1, seq_len, seq_len)) + validate_checkpoint(net, fa_net, (x, mask), FP32_RTOL, FP32_ATOL) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_attention_mask1(mode): + """ + Feature: attention + Description: test attention mask function + Expectation: success + """ + context.set_context(mode=mode) + net = Attention(IN_CHANNELS, NUM_HEADS, compute_dtype=mstype.float16) + attn_mask = ops.randint(0, 2, (SEQ_LEN, SEQ_LEN)) + key_padding_mask = ops.randint(0, 2, (BATCH_SIZE, SEQ_LEN)) + mask = net.merge_mask(attn_mask, key_padding_mask) + assert mask.shape == (BATCH_SIZE, 1, SEQ_LEN, SEQ_LEN) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_attention_mask2(mode): + """ + Feature: attention + Description: test attention mask function + Expectation: success + """ + context.set_context(mode=mode) + net = Attention(IN_CHANNELS, NUM_HEADS) + attn_mask = ops.randint(0, 2, (BATCH_SIZE, 1, SEQ_LEN, SEQ_LEN)) + key_padding_mask = ops.randint(0, 2, (BATCH_SIZE, SEQ_LEN)) + mask = net.merge_mask(attn_mask, key_padding_mask) + assert mask.shape == (BATCH_SIZE, 1, SEQ_LEN, SEQ_LEN) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_multihead_attention_forward(mode): + """ + Feature: MultiHeadAttention + Description: test result dtype + Expectation: success + """ + context.set_context(mode=mode) + net_32 = MultiHeadAttention( + IN_CHANNELS, NUM_HEADS, compute_dtype=mstype.float32) + net_16 = MultiHeadAttention( + IN_CHANNELS, NUM_HEADS, compute_dtype=mstype.float16) + x, mask = load_inputs() + validate_checkpoint(net_32, net_16, (x, mask), FP16_RTOL, FP16_ATOL) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_multihead_attention(mode): + """ + Feature: MultiHeadAttention + Description: test forward result shape + Expectation: success + """ + context.set_context(mode=mode) + net = MultiHeadAttention(in_channels=IN_CHANNELS, num_heads=NUM_HEADS) + x, mask = load_inputs() + validate_model_infer(net, (x, mask), './multihead.ckpt', + './multihead_output.npy', FP32_RTOL, FP32_ATOL) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('compute_dtype', [mstype.float16, mstype.bfloat16]) +def test_multihead_attention_dtype(mode, compute_dtype): + """ + Feature: MultiHeadAttention + Description: test forward result dtype + Expectation: success + """ + context.set_context(mode=mode) + net = MultiHeadAttention( + in_channels=IN_CHANNELS, num_heads=NUM_HEADS, compute_dtype=compute_dtype) + x, mask = load_inputs() + 
validate_output_dtype(net, (x, mask), compute_dtype) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_attn_block(mode): + """ + Feature: TransformerBlock + Description: test forward result + Expectation: success + """ + context.set_context(mode=mode) + net = TransformerBlock(in_channels=IN_CHANNELS, num_heads=NUM_HEADS) + x, mask = load_inputs() + validate_model_infer(net, (x, mask), './attention_block.ckpt', + './attention_block_output.npy', FP32_RTOL, FP32_ATOL) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_vit_forward(mode): + """ + Feature: ViT + Description: test forward result dtype + Expectation: success + """ + context.set_context(mode=mode) + x = ops.rand(32, 3, 192, 384) + model = ViT(in_channels=3, + out_channels=3, + encoder_depths=6, + encoder_embed_dim=768, + encoder_num_heads=12, + decoder_depths=6, + decoder_embed_dim=512, + decoder_num_heads=16, + compute_dtype=mstype.float32 + ) + output = model(x) + assert output.dtype == mstype.float32 + assert output.shape == (32, 288, 768) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_droppath(mode): + """ + Feature: DropPath train eval mode + Description: test forward result shape + Expectation: success + """ + context.set_context(mode=mode) + net = DropPath() + x = np.random.rand(BATCH_SIZE, SEQ_LEN, IN_CHANNELS) + net.set_train(True) + out = net(Tensor(x)).numpy() + assert out.shape == x.shape + net.set_train(False) + out = net(Tensor(x)).numpy() + assert np.array_equal(out, x) + + +@jit_class +class Trainer: + """Trainer""" + + def __init__(self, net, loss_fn): + self.net = net + self.loss_fn = loss_fn + + def get_loss(self, data, label): + "get loss" + pred = self.net(data) + return self.loss_fn(label, pred) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_multihead_attention_grad(mode): + """ + Feature: MultiHeadAttention + Description: test backward result + Expectation: success + """ + context.set_context(mode=mode) + ckpt_path = './multihead.ckpt' + model = MultiHeadAttention( + IN_CHANNELS, NUM_HEADS, compute_dtype=mstype.float32) + params = load_checkpoint(ckpt_path) + load_param_into_net(model, params) + + input_data = Tensor(np.load('./input.npy')) + input_label = Tensor(np.load('./label.npy')) + + trainer = Trainer(model, RelativeRMSELoss()) + + def forward_fn(data, label): + loss = trainer.get_loss(data, label) + return loss + + grad_fn = ops.value_and_grad( + forward_fn, None, model.trainable_params(), has_aux=False) + + _, grads = grad_fn(input_data, input_label) + + convert_grads = tuple(grad.asnumpy() for grad in grads) + with np.load('./grads.npz') as data: + output_target = tuple(data[key] for key in data.files) + + validate_ans = compare_output( + convert_grads, output_target, rtol=1e-6, atol=1e-6) + assert validate_ans, "The verification of scaleddot grad case failed." 
diff --git a/tests/st/mindflow/cell/test_optimizers.py b/tests/st/mindflow/cell/test_optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..4085507fd42dcb9b85facc4ea593c5fb03b249fe --- /dev/null +++ b/tests/st/mindflow/cell/test_optimizers.py @@ -0,0 +1,275 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Optimizers Test Case""" +import os +import random +import sys + +import pytest +import numpy as np + +import mindspore as ms +from mindspore import ops, set_seed, nn, mint +from mindspore import dtype as mstype +from mindflow import UNet2D, TransformerBlock, MultiHeadAttention, AdaHessian +from mindflow.cell.attention import FeedForward +from mindflow.cell.unet2d import Down + +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +sys.path.append(PROJECT_ROOT) + +# pylint: disable=wrong-import-position + +from common.cell import FP32_RTOL + +# pylint: enable=wrong-import-position + +set_seed(0) +np.random.seed(0) +random.seed(0) + + +class TestAdaHessianAccuracy(AdaHessian): + ''' Child class for testing the accuracy of AdaHessian optimizer ''' + + def gen_rand_vecs(self, grads): + ''' generate certain vector for accuracy test ''' + return [ms.Tensor(np.arange(p.size).reshape(p.shape) - p.size // 2, dtype=ms.float32) for p in grads] + + +class TestUNet2D(UNet2D): + ''' Child class for testing optimizing UNet with AdaHessian ''' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + class TestDown(Down): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + in_channels = args[0] + kernel_size = kwargs['kernel_size'] + stride = kwargs['stride'] + # replace the `maxpool` layer in the original UNet with `conv` to avoid `vjp` problem + self.maxpool = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride) + + self.layers_down = nn.CellList() + for i in range(self.n_layers): + self.layers_down.append(TestDown(self.base_channels * 2**i, self.base_channels * 2 ** (i+1), + kernel_size=self.kernel_size, stride=self.stride, + activation=self.activation, enable_bn=self.enable_bn)) + + +class TestAttentionBlock(TransformerBlock): + ''' Child class for testing optimizing Attention with AdaHessian ''' + + def __init__(self, + in_channels: int, + num_heads: int, + enable_flash_attn: bool = False, + fa_dtype: mstype = mstype.bfloat16, + drop_mode: str = "dropout", + dropout_rate: float = 0.0, + compute_dtype: mstype = mstype.float32, + ): + super().__init__(in_channels=in_channels, + num_heads=num_heads, + enable_flash_attn=enable_flash_attn, + fa_dtype=fa_dtype, + drop_mode=drop_mode, + dropout_rate=dropout_rate, + compute_dtype=compute_dtype, + ) + + class TestMlp(FeedForward): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.act_fn = 
nn.ReLU() # replace `gelu` with `relu` to avoid `vjp` problem + + class TestMultiHeadAttention(MultiHeadAttention): + ''' MultiHeadAttention modified to support vjp ''' + def get_qkv(self, x: ms.Tensor) -> tuple[ms.Tensor]: + ''' use masks to select out q, k, v, instead of tensor reshaping & indexing ''' + b, n, c_full = x.shape + c = c_full // self.num_heads + + # use matmul with masks to select out q, k, v to avoid vjp problem + q_mask = ms.Tensor(np.vstack([np.eye(c), np.zeros([2 * c, c])]), dtype=self.compute_dtype) + k_mask = ms.Tensor(np.vstack([np.zeros([c, c]), np.eye(c), np.zeros([c, c])]), dtype=self.compute_dtype) + v_mask = ms.Tensor(np.vstack([np.zeros([2 * c, c]), np.eye(c)]), dtype=self.compute_dtype) + + qkv = self.qkv(x) + qkv = qkv.reshape(b, n, self.num_heads, -1).swapaxes(1, 2) + + q = mint.matmul(qkv, q_mask) + k = mint.matmul(qkv, k_mask) + v = mint.matmul(qkv, v_mask) + + return q, k, v + + self.ffn = TestMlp( + in_channels=in_channels, + dropout_rate=dropout_rate, + compute_dtype=compute_dtype, + ) + self.attention = TestMultiHeadAttention( + in_channels=in_channels, + num_heads=num_heads, + enable_flash_attn=enable_flash_attn, + fa_dtype=fa_dtype, + drop_mode=drop_mode, + dropout_rate=dropout_rate, + compute_dtype=compute_dtype, + ) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_adahessian_accuracy(mode): + """ + Feature: AdaHessian forward accuracy test + Description: Test the accuracy of the AdaHessian optimizer in both GRAPH_MODE and PYNATIVE_MODE + with input data specified in the code below. + The expected output is compared to a reference output stored in + './mindflow/core/optimizers/data/adahessian_output.npy'. + Expectation: The output should match the target data within the defined relative tolerance, + ensuring the AdaHessian computation is accurate. + """ + ms.set_context(mode=mode) + + weight_init = ms.Tensor(np.reshape(range(72), [4, 2, 3, 3]), dtype=ms.float32) + bias_init = ms.Tensor(np.arange(4), dtype=ms.float32) + + net = nn.Conv2d( + in_channels=2, out_channels=4, kernel_size=3, has_bias=True, weight_init=weight_init, bias_init=bias_init) + + def forward(a): + return ops.sqrt(ops.mean(ops.square(net(a)))) + + grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params()) + + optimizer = TestAdaHessianAccuracy( + net.trainable_params(), + learning_rate=0.1, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.) + + inputs = ms.Tensor(np.reshape(range(100), [2, 2, 5, 5]), dtype=ms.float32) + + for _ in range(4): + optimizer(grad_fn, inputs) + + outputs = net(inputs).numpy() + outputs_ref = np.load('/home/workspace/mindspore_dataset/mindscience/mindflow/optimizers/adahessian_output.npy') + relative_error = np.max(np.abs(outputs - outputs_ref)) / np.max(np.abs(outputs_ref)) + assert relative_error < FP32_RTOL, "The verification of adahessian accuracy is not successful." + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('model_option', ['unet', 'attention']) +def test_adahessian_st(mode, model_option): + """ + Feature: AdaHessian ST test + Description: Test the function of the AdaHessian optimizer in both GRAPH_MODE and PYNATIVE_MODE + on the complex network such as UNet. 
The input is a Tensor specified in the code + and the output is the loss after 4 rounds of optimization. + Expectation: The output should be finite, ensuring the AdaHessian runs successfully on UNet. + """ + ms.set_context(mode=mode) + + # default test with Attention network + net = TestAttentionBlock(in_channels=256, num_heads=4) + inputs = ms.Tensor(np.sin(np.arange(102400)).reshape(4, 100, 256), dtype=ms.float32) + + # test with UNet network + if model_option.lower() == 'unet': + net = TestUNet2D( + in_channels=2, + out_channels=4, + base_channels=8, + n_layers=4, + kernel_size=2, + stride=2, + activation='relu', + data_format="NCHW", + enable_bn=False, # bn leads to bug in PYNATIVE_MODE for MS2.5.0 + ) + inputs = ms.Tensor(np.random.rand(2, 2, 64, 64), dtype=ms.float32) + + def forward(a): + return ops.sqrt(ops.mean(ops.square(net(a)))) + + grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params()) + + optimizer = AdaHessian( + net.trainable_params(), + learning_rate=0.1, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.) + + for _ in range(4): + optimizer(grad_fn, inputs) + + loss = forward(inputs) + assert ops.isfinite(loss) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.PYNATIVE_MODE]) +def test_adahessian_compare(mode): + """ + Feature: AdaHessian compare with Adam + Description: Compare the algorithm results of the AdaHessian optimizer with Adam. + The code runs in PYNATIVE_MODE and the network under comparison is TransformerBlock. + The optimization runs 100 rounds to demonstrate an essential loss decrease. + Expectation: The loss of AdaHessian outperforms Adam by 20% under the same configuration on an Attention network. + """ + ms.set_context(mode=mode) + + def get_loss(optimizer_option): + ''' compare Adam and AdaHessian ''' + net = TestAttentionBlock(in_channels=256, num_heads=4) + inputs = ms.Tensor(np.sin(np.arange(102400)).reshape(4, 100, 256), dtype=ms.float32) + + def forward(a): + return ops.sqrt(ops.mean(ops.square(net(a)))) + + grad_fn = ms.grad(forward, grad_position=None, weights=net.trainable_params()) + + if optimizer_option.lower() == 'adam': + optimizer = nn.Adam( + net.trainable_params(), + learning_rate=0.01, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.) + else: + optimizer = AdaHessian( + net.trainable_params(), + learning_rate=0.01, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.) + + for _ in range(20): + if optimizer_option.lower() == 'adam': + optimizer(grad_fn(inputs)) + else: + optimizer(grad_fn, inputs) + + loss = forward(inputs) + return loss + + loss_adam = get_loss('adam') + loss_adahessian = get_loss('adahessian') + + assert loss_adam * 0.8 > loss_adahessian, (loss_adam, loss_adahessian) diff --git a/tests/st/mindflow/cfd/couette/test_couette.py b/tests/st/mindflow/cfd/couette/test_couette.py new file mode 100644 index 0000000000000000000000000000000000000000..f5f7c218ba55044d2479f5eae899458cac7eb945 --- /dev/null +++ b/tests/st/mindflow/cfd/couette/test_couette.py @@ -0,0 +1,92 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""solve couette flow""" +import time +import os + +import numpy as np +import pytest + +from mindspore import numpy as mnp +from mindflow import load_yaml_config +from mindflow import cfd +from mindflow.cfd.runtime import RunTime +from mindflow.cfd.simulator import Simulator + + +def couette_ic_2d(mesh_x, mesh_y): + rho = mnp.ones_like(mesh_x) + u = mnp.zeros_like(mesh_y) + v = mnp.zeros_like(mesh_x) + w = mnp.zeros_like(mesh_x) + p = mnp.ones_like(mesh_x) + return mnp.stack([rho, u, v, w, p], axis=0) + + +def label_fun(y, t): + nu = 0.1 + h = 1.0 + u_max = 0.1 + coe = 0.0 + for i in range(1, 100): + coe += np.sin(i * np.pi * (1 - y / h)) * np.exp(-(i ** 2) * (np.pi ** 2) * nu * t / (h ** 2)) / i + return u_max * y / h - (2 * u_max / np.pi) * coe + + +def train(): + '''train and evaluate the network''' + config = load_yaml_config('{}/couette.yaml'.format(os.path.split(os.path.realpath(__file__))[0])) + + simulator = Simulator(config) + runtime = RunTime(config['runtime'], simulator.mesh_info, simulator.material) + + mesh_x, mesh_y, _ = simulator.mesh_info.mesh_xyz() + pri_var = couette_ic_2d(mesh_x, mesh_y) + con_var = cfd.cal_con_var(pri_var, simulator.material) + + dy = 1 / config['mesh']['ny'] + cell_centers = np.linspace(dy / 2, 1 - dy / 2, config['mesh']['ny']) + + start = time.time() + + while runtime.time_loop(pri_var): + runtime.compute_timestep(pri_var) + con_var = simulator.integration_step(con_var, runtime.timestep) + pri_var = cfd.cal_pri_var(con_var, simulator.material) + runtime.advance() + + label_u = label_fun(cell_centers, runtime.current_time.asnumpy()) + sim_u = pri_var.asnumpy()[1, 0, :, 0] + + err = np.abs(label_u - sim_u).sum() / np.abs(label_u).sum() + per_epoch_time = (time.time() - start) / 500 + + print(f'l1 error: {err:.10f}') + print(f'per epoch time: {per_epoch_time:.10f}') + + return err + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_couette_gpu(): + """ + Feature: cfd couette test in the gpu + Description: None. + Expectation: Success or throw error when error is larger than 1 + """ + err = train() + assert err < 1.0 diff --git a/tests/st/mindflow/networks/burgers/test_burgers.py b/tests/st/mindflow/networks/burgers/test_burgers.py new file mode 100644 index 0000000000000000000000000000000000000000..407a01ad97e6aa95b48e0d19cd7334a40145f76c --- /dev/null +++ b/tests/st/mindflow/networks/burgers/test_burgers.py @@ -0,0 +1,166 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""burgers pinns st case""" +import time +import pytest + +import numpy as np +from mindspore import context, nn, ops, jit, set_seed, Tensor +import mindspore.common.dtype as mstype + +from model import Burgers1D + + +set_seed(123456) +np.random.seed(123456) + + +def _calculate_error(label, prediction): + """calculate l2 error""" + error = label - prediction + l2_error = np.sqrt( + np.sum(np.square(error[..., 0]))) / np.sqrt(np.sum(np.square(label[..., 0]))) + + return l2_error + + +def _get_prediction(model, inputs, label_shape, batch_size): + """get prediction""" + prediction = np.zeros(label_shape) + prediction = prediction.reshape((-1, label_shape[1])) + inputs = inputs.reshape((-1, inputs.shape[1])) + + index = 0 + while index < inputs.shape[0]: + index_end = min(index + batch_size, inputs.shape[0]) + test_batch = Tensor(inputs[index: index_end, :], mstype.float32) + prediction[index: index_end, :] = model(test_batch).asnumpy() + index = index_end + + prediction = prediction.reshape(label_shape) + prediction = prediction.reshape((-1, label_shape[1])) + return prediction + + +def calculate_l2_error(model, inputs, label, batch_size): + """calculate evaluation error""" + label_shape = label.shape + prediction = _get_prediction(model, inputs, label_shape, batch_size) + label = label.reshape((-1, label_shape[1])) + l2_error = _calculate_error(label, prediction) + return l2_error + + +class Net(nn.Cell): + """MLP""" + def __init__(self, in_channels=2, hidden_channels=128, out_channels=1): + super().__init__() + self.act = nn.Tanh() + self.layers = nn.SequentialCell( + nn.Dense(in_channels, hidden_channels, activation=self.act), + nn.Dense(hidden_channels, hidden_channels, activation=self.act), + nn.Dense(hidden_channels, hidden_channels, activation=self.act), + nn.Dense(hidden_channels, hidden_channels, activation=self.act), + nn.Dense(hidden_channels, out_channels) + ) + + def construct(self, x): + return self.layers(x) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_mindflow_burgers_pinns(): + """ + Feature: burgers pinns + Description: test train and eval + Expectation: success + """ + context.set_context(mode=context.GRAPH_MODE, jit_config={"jit_level": "O2"}) + model = Net() + optimizer = nn.Adam(model.trainable_params(), 0.0001) + problem = Burgers1D(model) + use_ascend = context.get_context(attr_key='device_target') == "Ascend" + if use_ascend: + from mindspore.amp import DynamicLossScaler, auto_mixed_precision, all_finite + loss_scaler = DynamicLossScaler(1024, 2, 100) + auto_mixed_precision(model, 'O3') + else: + loss_scaler = None + + pde_data = Tensor([[-0.02541629, 0.12696983], + [0.30243418, 0.96671784], + [0.6249878, 0.260476], + [-0.61371905, 0.8972365], + [0.70554, 0.37674972]], mstype.float32) + + ic_data = Tensor([[0.1678119, 0.], + [-0.45064327, 0.], + [0.01379196, 0.], + [0.40799928, 0.], + [0.13942307, 0.]], mstype.float32) + bc_data = Tensor([[1., 0.1909238], + [1., 0.70078486], + [-1., 0.70864534], + [1., 0.7291773], + [1., 0.30929238]], mstype.float32) + inputs = Tensor([[-1., 0.], + [-0.99215686, 0.], + [-0.9843137, 0.], + [-0.9764706, 0.], + [-0.96862745, 0.]], mstype.float32) + + label = np.array([[1.22464680e-16], + [2.46374492e-02], + [4.92599411e-02], + [7.38525275e-02], + [9.84002783e-02]], np.float32) + + def 
forward_fn(pde_data, ic_data, bc_data): + loss = problem.get_loss(pde_data, ic_data, bc_data) + if use_ascend: + loss = loss_scaler.scale(loss) + + return loss + + grad_fn = ops.value_and_grad( + forward_fn, None, optimizer.parameters, has_aux=False) + + @jit + def train_step(pde_data, ic_data, bc_data): + loss, grads = grad_fn(pde_data, ic_data, bc_data) + if use_ascend: + loss = loss_scaler.unscale(loss) + if all_finite(grads): + grads = loss_scaler.unscale(grads) + + loss = ops.depend(loss, optimizer(grads)) + return loss + + for epoch in range(1, 1 + 10): + model.set_train(True) + time_beg = time.time() + train_loss = train_step(pde_data, ic_data, bc_data) + epoch_time = time.time() - time_beg + print(f"epoch: {epoch} train loss: {train_loss} epoch time: {epoch_time}s") + + model.set_train(False) + eval_error = calculate_l2_error(model, inputs, label, 5) + print("eval_error:", eval_error) + + assert epoch_time < 0.05 + assert train_loss < 0.6 + assert eval_error < 0.8 diff --git a/tests/st/mindflow/networks/ffno/test_ffno.py b/tests/st/mindflow/networks/ffno/test_ffno.py new file mode 100644 index 0000000000000000000000000000000000000000..8cda98808dc60934f3e44d050638a093de4340d3 --- /dev/null +++ b/tests/st/mindflow/networks/ffno/test_ffno.py @@ -0,0 +1,381 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""mindflow st testcase""" + +import os +import sys +import time + +import pytest +import numpy as np + +import mindspore as ms +from mindspore import nn, Tensor, set_seed, load_param_into_net, load_checkpoint +from mindspore import dtype as mstype + +from mindflow.cell import FFNO1D, FFNO2D, FFNO3D + +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")) +sys.path.append(PROJECT_ROOT) + +# pylint: disable=wrong-import-position + +from common.cell.utils import compare_output +from common.cell import FP32_RTOL + +# pylint: enable=wrong-import-position + +set_seed(123456) +folder_path = "/home/workspace/mindspore_dataset/mindscience/ffno" + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno1d_output(mode): + """ + Feature: Test FFNO1D network in platform ascend. + Description: None. + Expectation: Success or throw AssertionError. 
+ """ + ms.set_context(mode=mode) + model1d = FFNO1D(in_channels=2, + out_channels=2, + n_modes=[2], + resolutions=[6], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data1d = Tensor(np.load(os.path.join(folder_path, "ffno_data1d.npy")), dtype=mstype.float32) + param1d = load_checkpoint(os.path.join(folder_path, "ffno1d.ckpt")) + load_param_into_net(model1d, param1d) + output1d = model1d(data1d) + target1d = np.load(os.path.join(folder_path, "ffno_target1d.npy")) + + assert output1d.shape == (2, 6, 2) + assert output1d.dtype == mstype.float32 + assert compare_output(output1d.asnumpy(), target1d, rtol=FP32_RTOL, atol=FP32_RTOL) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno1d_mse_loss_output(mode): + """ + Feature: Test FFNO1D MSE Loss in platform ascend. + Description: None. + Expectation: Success or throw AssertionError. + """ + ms.set_context(mode=mode) + model1d = FFNO1D(in_channels=2, + out_channels=2, + n_modes=[2], + resolutions=[6], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data1d = Tensor(np.ones((2, 6, 2)), dtype=mstype.float32) + label_1d = Tensor(np.ones((2, 6, 2)), dtype=mstype.float32) + param1d = load_checkpoint(os.path.join(folder_path, "ffno1d.ckpt")) + load_param_into_net(model1d, param1d) + + loss_fn = nn.MSELoss() + optimizer_1d = nn.SGD(model1d.trainable_params(), learning_rate=0.01) + net_with_loss_1d = nn.WithLossCell(model1d, loss_fn) + train_step_1d = nn.TrainOneStepCell(net_with_loss_1d, optimizer_1d) + + # calculate two steps of loss + loss_1d = train_step_1d(data1d, label_1d) + target_loss_1_1d = 0.63846040 + assert compare_output(loss_1d.asnumpy(), target_loss_1_1d, rtol=FP32_RTOL, atol=FP32_RTOL) + + loss_1d = train_step_1d(data1d, label_1d) + target_loss_2_1d = 0.04462930 + assert compare_output(loss_1d.asnumpy(), target_loss_2_1d, rtol=FP32_RTOL, atol=FP32_RTOL) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno2d_output(mode): + """ + Feature: Test FFNO2D network in platform ascend. + Description: None. + Expectation: Success or throw AssertionError. + """ + ms.set_context(mode=mode) + model2d = FFNO2D(in_channels=2, + out_channels=2, + n_modes=[2, 2], + resolutions=[6, 6], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data2d = Tensor(np.load(os.path.join(folder_path, "ffno_data2d.npy")), dtype=mstype.float32) + param2d = load_checkpoint(os.path.join(folder_path, "ffno2d.ckpt")) + load_param_into_net(model2d, param2d) + output2d = model2d(data2d) + target2d = np.load(os.path.join(folder_path, "ffno_target2d.npy")) + + assert output2d.shape == (2, 6, 6, 2) + assert output2d.dtype == mstype.float32 + assert compare_output(output2d.asnumpy(), target2d, rtol=FP32_RTOL, atol=FP32_RTOL) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno2d_mse_loss_output(mode): + """ + Feature: Test FFNO2D MSE Loss in platform ascend. + Description: None. + Expectation: Success or throw AssertionError. 
+ """ + ms.set_context(mode=mode) + model2d = FFNO2D(in_channels=2, + out_channels=2, + n_modes=[2, 2], + resolutions=[6, 6], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data2d = Tensor(np.ones((2, 6, 6, 2)), dtype=mstype.float32) + label_2d = Tensor(np.ones((2, 6, 6, 2)), dtype=mstype.float32) + param2d = load_checkpoint(os.path.join(folder_path, "ffno2d.ckpt")) + load_param_into_net(model2d, param2d) + + loss_fn = nn.MSELoss() + optimizer_2d = nn.SGD(model2d.trainable_params(), learning_rate=0.01) + net_with_loss_2d = nn.WithLossCell(model2d, loss_fn) + train_step_2d = nn.TrainOneStepCell(net_with_loss_2d, optimizer_2d) + + # calculate two steps of loss + loss_2d = train_step_2d(data2d, label_2d) + target_loss_1_2d = 1.70347130 + assert compare_output(loss_2d.asnumpy(), target_loss_1_2d, rtol=FP32_RTOL, atol=FP32_RTOL) + + loss_2d = train_step_2d(data2d, label_2d) + target_loss_2_2d = 0.28143430 + assert compare_output(loss_2d.asnumpy(), target_loss_2_2d, rtol=FP32_RTOL, atol=FP32_RTOL) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno3d_output(mode): + """ + Feature: Test FFNO3D network in platform ascend. + Description: None. + Expectation: Success or throw AssertionError. + """ + ms.set_context(mode=mode) + model3d = FFNO3D(in_channels=2, + out_channels=2, + n_modes=[2, 2, 2], + resolutions=[6, 6, 6], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data3d = Tensor(np.load(os.path.join(folder_path, "ffno_data3d.npy")), dtype=mstype.float32) + param3d = load_checkpoint(os.path.join(folder_path, "ffno3d.ckpt")) + load_param_into_net(model3d, param3d) + output3d = model3d(data3d) + target3d = np.load(os.path.join(folder_path, "ffno_target3d.npy")) + + assert output3d.shape == (2, 6, 6, 6, 2) + assert output3d.dtype == mstype.float32 + assert compare_output(output3d.asnumpy(), target3d, rtol=FP32_RTOL, atol=FP32_RTOL) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno3d_mse_loss_output(mode): + """ + Feature: Test FFNO3D MSE Loss in platform ascend. + Description: None. + Expectation: Success or throw AssertionError. 
+ """ + ms.set_context(mode=mode) + model3d = FFNO3D(in_channels=2, + out_channels=2, + n_modes=[2, 2, 2], + resolutions=[6, 6, 6], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data3d = Tensor(np.ones((2, 6, 6, 6, 2)), dtype=mstype.float32) + label_3d = Tensor(np.ones((2, 6, 6, 6, 2)), dtype=mstype.float32) + param3d = load_checkpoint(os.path.join(folder_path, "ffno3d.ckpt")) + load_param_into_net(model3d, param3d) + + loss_fn = nn.MSELoss() + optimizer_3d = nn.SGD(model3d.trainable_params(), learning_rate=0.01) + net_with_loss_3d = nn.WithLossCell(model3d, loss_fn) + train_step_3d = nn.TrainOneStepCell(net_with_loss_3d, optimizer_3d) + + # calculate two steps of loss + loss_3d = train_step_3d(data3d, label_3d) + target_loss_1_3d = 1.94374371 + assert compare_output(loss_3d.asnumpy(), target_loss_1_3d, rtol=FP32_RTOL, atol=FP32_RTOL) + + loss_3d = train_step_3d(data3d, label_3d) + target_loss_2_3d = 0.24034855 + assert compare_output(loss_3d.asnumpy(), target_loss_2_3d, rtol=FP32_RTOL, atol=FP32_RTOL) + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno1d_speed(mode): + """ + Feature: Test FFNO1D training speed in platform ascend. + Description: The speed of each training step. + Expectation: Success or throw AssertionError. + """ + ms.set_context(mode=mode) + model1d = FFNO1D(in_channels=32, + out_channels=32, + n_modes=[16], + resolutions=[128], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data1d = Tensor(np.ones((32, 128, 32)), dtype=mstype.float32) + label_1d = Tensor(np.ones((32, 128, 32)), dtype=mstype.float32) + + loss_fn = nn.MSELoss() + optimizer_1d = nn.SGD(model1d.trainable_params(), learning_rate=0.01) + net_with_loss_1d = nn.WithLossCell(model1d, loss_fn) + train_step_1d = nn.TrainOneStepCell(net_with_loss_1d, optimizer_1d) + + steps = 10 + for _ in range(10): + train_step_1d(data1d, label_1d) + + start_time = time.time() + for _ in range(10): + train_step_1d(data1d, label_1d) + end_time = time.time() + + assert (end_time - start_time) / steps < 0.5 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno2d_speed(mode): + """ + Feature: Test FFNO2D training speed in platform ascend. + Description: The speed of each training step. + Expectation: Success or throw AssertionError. 
+ """ + ms.set_context(mode=mode) + model2d = FFNO2D(in_channels=32, + out_channels=32, + n_modes=[16, 16], + resolutions=[64, 64], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data2d = Tensor(np.ones((32, 64, 64, 32)), dtype=mstype.float32) + label_2d = Tensor(np.ones((32, 64, 64, 32)), dtype=mstype.float32) + + loss_fn = nn.MSELoss() + optimizer_2d = nn.SGD(model2d.trainable_params(), learning_rate=0.01) + net_with_loss_2d = nn.WithLossCell(model2d, loss_fn) + train_step_2d = nn.TrainOneStepCell(net_with_loss_2d, optimizer_2d) + + steps = 10 + for _ in range(steps): + train_step_2d(data2d, label_2d) + + start_time = time.time() + for _ in range(steps): + train_step_2d(data2d, label_2d) + end_time = time.time() + + assert (end_time - start_time) / steps < 1 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +def test_ffno3d_speed(mode): + """ + Feature: Test FFNO3D training speed in platform ascend. + Description: The speed of each training step. + Expectation: Success or throw AssertionError. + """ + ms.set_context(mode=mode) + model3d = FFNO3D(in_channels=2, + out_channels=2, + n_modes=[16, 16, 16], + resolutions=[32, 32, 32], + hidden_channels=2, + n_layers=2, + share_weight=True, + r_padding=8, + ffno_compute_dtype=mstype.float32) + + data3d = Tensor(np.ones((2, 32, 32, 32, 2)), dtype=mstype.float32) + label_3d = Tensor(np.ones((2, 32, 32, 32, 2)), dtype=mstype.float32) + + loss_fn = nn.MSELoss() + optimizer_3d = nn.SGD(model3d.trainable_params(), learning_rate=0.01) + net_with_loss_3d = nn.WithLossCell(model3d, loss_fn) + train_step_3d = nn.TrainOneStepCell(net_with_loss_3d, optimizer_3d) + + steps = 10 + for _ in range(steps): + train_step_3d(data3d, label_3d) + + start_time = time.time() + for _ in range(steps): + train_step_3d(data3d, label_3d) + end_time = time.time() + + assert (end_time - start_time) / steps < 3 diff --git a/tests/st/mindflow/networks/fno/test_fno.py b/tests/st/mindflow/networks/fno/test_fno.py new file mode 100644 index 0000000000000000000000000000000000000000..7fbc0c437d57b553a2a592c494dadcc3ceccb3f8 --- /dev/null +++ b/tests/st/mindflow/networks/fno/test_fno.py @@ -0,0 +1,95 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""mindflow st testcase""" + +import pytest +import numpy as np + +from mindspore import Tensor, context, set_seed, load_param_into_net, load_checkpoint +from mindspore import dtype as mstype +from mindflow.cell import FNO1D, FNO2D, FNO3D +from mindflow.cell.neural_operators.fno_sp import SpectralConv1dDft, SpectralConv2dDft, SpectralConv3dDft + +RTOL = 0.001 +set_seed(123456) + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_fno_output(): + """ + Feature: Test FNO1D, FNO2D and FNO3D network in platform gpu and ascend. + Description: None. + Expectation: Success or throw AssertionError. + Need to adaptive 910B + """ + context.set_context(mode=context.GRAPH_MODE) + model1d = FNO1D( + in_channels=2, out_channels=2, n_modes=[2], resolutions=[6], fno_compute_dtype=mstype.float32) + model2d = FNO2D( + in_channels=2, out_channels=2, n_modes=[2, 2], resolutions=[6, 6], fno_compute_dtype=mstype.float32) + model3d = FNO3D( + in_channels=2, out_channels=2, n_modes=[2, 2, 2], resolutions=[6, 6, 6], fno_compute_dtype=mstype.float32) + data1d = Tensor(np.ones((2, 6, 2)), dtype=mstype.float32) + data2d = Tensor(np.ones((2, 6, 6, 2)), dtype=mstype.float32) + data3d = Tensor(np.ones((2, 6, 6, 6, 2)), dtype=mstype.float32) + output1d = model1d(data1d) + output2d = model2d(data2d) + output3d = model3d(data3d) + assert output1d.shape == (2, 6, 2) + assert output1d.dtype == mstype.float32 + assert output2d.shape == (2, 6, 6, 2) + assert output2d.dtype == mstype.float32 + assert output3d.shape == (2, 6, 6, 6, 2) + assert output3d.dtype == mstype.float32 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_spectralconvdft_output(): + """ + Feature: Test SpectralConv1dDft, SpectralConv2dDft and SpectralConv3dDft network in platform gpu and ascend. + Description: None. + Expectation: Success or throw AssertionError. 
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    model1d = SpectralConv1dDft(in_channels=2, out_channels=2, n_modes=[2], resolutions=[6])
+    model2d = SpectralConv2dDft(in_channels=2, out_channels=2, n_modes=[2, 2], resolutions=[6, 6])
+    model3d = SpectralConv3dDft(in_channels=2, out_channels=2, n_modes=[2, 2, 2], resolutions=[6, 6, 6])
+    data1d = Tensor(np.ones((2, 2, 6)), dtype=mstype.float32)
+    data2d = Tensor(np.ones((2, 2, 6, 6)), dtype=mstype.float32)
+    data3d = Tensor(np.ones((2, 2, 6, 6, 6)), dtype=mstype.float32)
+    target1d = 3.64671636
+    target2d = 35.93239212
+    target3d = 149.64256287
+    param1 = load_checkpoint("./spectralconv1d.ckpt")
+    param2 = load_checkpoint("./spectralconv2d.ckpt")
+    param3 = load_checkpoint("./spectralconv3d.ckpt")
+    load_param_into_net(model1d, param1)
+    load_param_into_net(model2d, param2)
+    load_param_into_net(model3d, param3)
+    output1d = model1d(data1d)
+    output2d = model2d(data2d)
+    output3d = model3d(data3d)
+    assert output1d.shape == (2, 2, 6)
+    assert output1d.dtype == mstype.float32
+    assert abs(output1d.sum() - target1d) < RTOL
+    assert output2d.shape == (2, 2, 6, 6)
+    assert output2d.dtype == mstype.float32
+    assert abs(output2d.sum() - target2d) < RTOL
+    assert output3d.shape == (2, 2, 6, 6, 6)
+    assert output3d.dtype == mstype.float32
+    assert abs(output3d.sum() - target3d) < RTOL
diff --git a/tests/st/mindflow/networks/navier_stokes/test_mindflow_navier_stokes.py b/tests/st/mindflow/networks/navier_stokes/test_mindflow_navier_stokes.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a8c827c9cedc8f48d8b063442e1c3236807b7e4
--- /dev/null
+++ b/tests/st/mindflow/networks/navier_stokes/test_mindflow_navier_stokes.py
@@ -0,0 +1,124 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""navier stokes pinns st case""" +import time +import pytest + +import numpy as np +from mindspore import context, nn, ops, jit, set_seed, Tensor +import mindspore.common.dtype as mstype + +from model import NavierStokes2D + + +set_seed(123456) +np.random.seed(123456) + + +class Net(nn.Cell): + """MLP""" + + def __init__(self, in_channels=2, hidden_channels=128, out_channels=1, act=nn.Tanh()): + super().__init__() + self.act = act + self.layers = nn.SequentialCell( + nn.Dense(in_channels, hidden_channels, activation=self.act), + nn.Dense(hidden_channels, hidden_channels, activation=self.act), + nn.Dense(hidden_channels, hidden_channels, activation=self.act), + nn.Dense(hidden_channels, out_channels) + ) + + def construct(self, x): + return self.layers(x) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_mindflow_navier_stokes(): + """ + Feature: navier_stokes pinns + Description: test train + Expectation: success + """ + context.set_context(mode=context.GRAPH_MODE, jit_config={"jit_level": "O2"}) + model = Net(in_channels=3, out_channels=3) + optimizer = nn.Adam(model.trainable_params(), 0.0001) + problem = NavierStokes2D(model) + use_ascend = context.get_context(attr_key='device_target') == "Ascend" + if use_ascend: + from mindspore.amp import DynamicLossScaler, auto_mixed_precision, all_finite + loss_scaler = DynamicLossScaler(1024, 2, 100) + auto_mixed_precision(model, 'O3') + else: + loss_scaler = None + + pde_data = Tensor([[4.4814544, -1.6147294, 3.8416946], + [5.8124804, -0.49786586, 5.0063257], + [1.1191559, 0.9042227, 4.2193437], + [1.9051491, 1.1916666, 3.8141823], + [2.8169591, 0.3456305, 2.9655836]], mstype.float32) + bc_data = Tensor([[3.8282828, -2., 1.4], + [6.3030305, -2., 1.4], + [1., 1.1020408, 3.6], + [5.8080807, -2., 6.7], + [3.5454545, -2., 4.]], mstype.float32) + ic_data = Tensor([[6.6565657, -0.2857143, 0.], + [7.7171717, 1.1836735, 0.], + [2.6262627, -0.4489796, 0.], + [2.909091, 1.3469387, 0.], + [1.2121212, -0.04081633, 0.]], mstype.float32) + bc_label = Tensor([[0.9934991, -0.06462386], + [1.0738759, 0.14259282], + [1.2844703, -0.06666262], + [1.0665872, 0.14132853], + [1.0729637, 0.1342909]], mstype.float32) + ic_label = Tensor([[0.6541988, 0.26443157, -0.0937218], + [1.0388365, -0.32561874, -0.04165602], + [0.13536283, 0.00210919, -0.06288609], + [1.0713681, -0.24921523, -0.08316418], + [-0.20542848, 0.19257492, -0.39120102]], mstype.float32) + + def forward_fn(pde_data, bc_data, bc_label, ic_data, ic_label): + loss = problem.get_loss(pde_data, bc_data, bc_label, ic_data, ic_label) + if use_ascend: + loss = loss_scaler.scale(loss) + return loss + + grad_fn = ops.value_and_grad( + forward_fn, None, optimizer.parameters, has_aux=False) + + @jit + def train_step(pde_data, bc_data, bc_label, ic_data, ic_label): + loss, grads = grad_fn(pde_data, bc_data, bc_label, ic_data, ic_label) + if use_ascend: + loss = loss_scaler.unscale(loss) + if all_finite(grads): + grads = loss_scaler.unscale(grads) + + loss = ops.depend(loss, optimizer(grads)) + return loss + epochs = 10 + for epoch in range(1, 1 + epochs): + model.set_train(True) + time_beg = time.time() + train_loss = train_step(pde_data, bc_data, bc_label, ic_data, ic_label) + epoch_time = time.time() - time_beg + print( + f"epoch: {epoch} train loss: {train_loss} epoch time: {epoch_time}s") + model.set_train(False) + + assert epoch_time < 0.05 + assert train_loss < 0.8 diff --git 
a/tests/st/mindflow/networks/test_vit.py b/tests/st/mindflow/networks/test_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..262050bc1e6fad393f28f58604f1e27aaaef4c7f
--- /dev/null
+++ b/tests/st/mindflow/networks/test_vit.py
@@ -0,0 +1,54 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""mindflow st testcase"""
+
+import pytest
+import numpy as np
+
+from mindspore import Tensor, context
+from mindspore import dtype as mstype
+from mindflow.cell import ViT
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+def test_vit_output():
+    """
+    Feature: Test ViT network in platform gpu and ascend.
+    Description: None.
+    Expectation: Success or throw AssertionError.
+    Needs to be adapted for Ascend 910B.
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    input_tensor = Tensor(np.ones((32, 3, 192, 384)), mstype.float32)
+    print('input_tensor.shape: ', input_tensor.shape)
+    print('input_tensor.dtype: ', input_tensor.dtype)
+
+    model = ViT(in_channels=3,
+                out_channels=3,
+                encoder_depths=6,
+                encoder_embed_dim=768,
+                encoder_num_heads=12,
+                decoder_depths=6,
+                decoder_embed_dim=512,
+                decoder_num_heads=16,
+                )
+
+    output_tensor = model(input_tensor)
+    print('output_tensor.shape: ', output_tensor.shape)
+    print('output_tensor.dtype: ', output_tensor.dtype)
+    assert output_tensor.shape == (32, 288, 768)
+    assert output_tensor.dtype == mstype.float32
diff --git a/tests/st/mindflow/operators/test_dft.py b/tests/st/mindflow/operators/test_dft.py
new file mode 100644
index 0000000000000000000000000000000000000000..e001cee122bc3a033bccb9e55309d7006dbb4a4e
--- /dev/null
+++ b/tests/st/mindflow/operators/test_dft.py
@@ -0,0 +1,127 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +"""Test mindflow dft""" + +import torch +import numpy as np +import pytest + +import mindspore as ms +from mindspore import ops +from mindflow.cell.neural_operators.dft import dft1, dft2, idft1, idft2 + + +def dft_1d_torch(x, dim=-1): + x = torch.Tensor(x) + + x_re_im = torch.fft.fft(x, dim=dim, norm="ortho") + x_re, x_im = x_re_im.real, x_re_im.imag + return x_re.numpy(), x_im.numpy() + + +def dft_2d_torch(x, dim=-1): + x = torch.Tensor(x) + + x_re_im = torch.fft.rfft2(x, dim=dim, norm="ortho") + x_re, x_im = x_re_im.real, x_re_im.imag + return x_re.numpy(), x_im.numpy() + + +def idft_1d_torch(x_re, x_im, dim=-1): + x = torch.stack([torch.Tensor(x_re), torch.Tensor(x_im)], dim=-1) + x = torch.view_as_complex(x) + x = torch.fft.ifft(x, norm="ortho", dim=dim) + return x.numpy() + + +def idft_2d_torch(x_re, x_im, dim=-1): + x = torch.stack([torch.Tensor(x_re), torch.Tensor(x_im)], dim=-1) + x = torch.view_as_complex(x) + x = torch.fft.irfft2(x, norm="ortho", dim=dim) + return x.numpy() + + +def dft_1d_ms(x, shape, mode, dim=(-1,)): + x = ms.Tensor(x) + x_re = x + x_im = ops.zeros_like(x_re) + dft1_cell = dft1(shape=shape, modes=mode, dim=dim) + x_ft_re, x_ft_im = dft1_cell((x_re, x_im)) + return x_ft_re.asnumpy(), x_ft_im.asnumpy() + + +def dft_2d_ms(x, shape, mode, dim=(-1,)): + x = ms.Tensor(x) + x_re = x + x_im = ops.zeros_like(x_re) + dft2_cell = dft2(shape=shape, modes=mode, dim=dim) + x_ft_re, x_ft_im = dft2_cell((x_re, x_im)) + return x_ft_re.asnumpy(), x_ft_im.asnumpy() + + +def idft_1d_ms(x_re, x_im, shape, mode, dim=(-1)): + x_re = ms.Tensor(x_re) + x_im = ms.Tensor(x_im) + idft1_cell = idft1(shape=shape, modes=mode, dim=dim) + x_ms, _ = idft1_cell((x_re, x_im)) + return x_ms.asnumpy() + + +def idft_2d_ms(x_re, x_im, shape, mode, dim=(-1)): + x_re = ms.Tensor(x_re) + x_im = ms.Tensor(x_im) + idft2_cell = idft2(shape=shape, modes=mode, dim=dim) + x_ms, _ = idft2_cell((x_re, x_im)) + return x_ms.asnumpy() + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_dft1d(): + """ + Feature: Test dft1d in platform gpu and ascend. + Description: None. + Expectation: Success or throw AssertionError. + Torch problem, need to adaptive 910B + """ + x = np.random.randn(1, 6, 8, 2) + x_re_torch1d, x_im_torch1d = dft_1d_torch(x, dim=-2) + x_re_ms1d, x_im_ms1d = dft_1d_ms(x, shape=(8,), mode=5, dim=(-2,)) + + x_torch1d = idft_1d_torch(x_re_torch1d, x_im_torch1d, dim=-2) + x_ms1d = idft_1d_ms(x_re_ms1d, x_im_ms1d, shape=(8,), mode=5, dim=(-2,)) + + assert np.sum(x_torch1d - x_ms1d) < 0.001 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +def test_dft2d(): + """ + Feature: Test dft2d in platform gpu and ascend. + Description: None. + Expectation: Success or throw AssertionError. 
+ Torch problem, need to adaptive 910B + """ + x = np.random.randn(1, 6, 8, 2) + x_re_torch2d, x_im_torch2d = dft_2d_torch(x, dim=(-3, -2)) + x_re_ms2d, x_im_ms2d = dft_2d_ms(x, shape=(6, 8), mode=(3, 5), dim=(-3, -2)) + + x_torch2d = idft_2d_torch(x_re_torch2d, x_im_torch2d, dim=(-3, -2)) + x_ms2d = idft_2d_ms(x_re_ms2d, x_im_ms2d, shape=(6, 8), mode=(3, 5), dim=(-3, -2)) + + assert np.sum(x_torch2d - x_ms2d) < 0.001 diff --git a/tests/st/mindflow/operators/test_fourier.py b/tests/st/mindflow/operators/test_fourier.py new file mode 100644 index 0000000000000000000000000000000000000000..188e247206e784ac3910b38d4c19b5f27ac98513 --- /dev/null +++ b/tests/st/mindflow/operators/test_fourier.py @@ -0,0 +1,248 @@ +# ============================================================================ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Optimizers Test Case""" +import os +import random +import sys +from time import time as toc +import pytest +import numpy as np +from scipy.fft import dct, dst +import mindspore as ms +from mindspore import set_seed, ops +from mindflow import DFTn, IDFTn, RDFTn, IRDFTn, DCT, IDCT, DST, IDST + +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +sys.path.append(PROJECT_ROOT) + +# pylint: disable=wrong-import-position + +from common.cell import FP32_RTOL, FP16_RTOL, FP32_ATOL, FP16_ATOL +from common.cell.utils import compare_output + +# pylint: enable=wrong-import-position + +set_seed(0) +np.random.seed(0) +random.seed(0) + + +def gen_input(shape=(5, 6, 4, 8), rand_test=True): + ''' Generate random or deterministic tensor for input of the tests + ''' + a = np.random.rand(*shape) + 1j * np.random.rand(*shape) + if not rand_test: + a = sum([np.arange(n).reshape([n] + [1] * j) for j, n in enumerate(shape[::-1])]) + 1j * \ + sum([np.arange(n).reshape([n] + [1] * j) for j, n in enumerate(shape[::-1])]) + ar, ai = (ms.Tensor(a.real, dtype=ms.float32), ms.Tensor(a.imag, dtype=ms.float32)) + return a, ar, ai + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('device_target', ['CPU', 'Ascend']) +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('ndim', [1, 2, 3]) +@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16]) +def test_rdft_accuracy(device_target, mode, ndim, compute_dtype): + """ + Feature: Test RDFTn & IRDFTn accuracy + Description: Input random tensor, compare the results of RDFTn and IRDFTn with numpy results + Expectation: The output tensors should be equal within tolerance + """ + ms.set_context(device_target=device_target, mode=mode) + a, ar, _ = gen_input() + shape = a.shape + + b = np.fft.rfftn(a.real, s=a.shape[-ndim:], axes=range(-ndim, 0)) + br, bi = RDFTn(shape[-ndim:], compute_dtype=compute_dtype)(ar) + cr = IRDFTn(shape[-ndim:], compute_dtype=compute_dtype)(br, bi) + 
+ rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10 + atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 20 + + assert compare_output(br.numpy(), b.real, rtol, atol) + assert compare_output(bi.numpy(), b.imag, rtol, atol) + assert compare_output(cr.numpy(), a.real, rtol, atol) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('device_target', ['CPU', 'Ascend']) +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('ndim', [1, 2, 3]) +@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16]) +def test_dft_accuracy(device_target, mode, ndim, compute_dtype): + """ + Feature: Test DFTn & IDFTn accuracy + Description: Input random tensor, compare the results of DFTn and IDFTn with numpy results + Expectation: The output tensors should be equal within tolerance + """ + ms.set_context(device_target=device_target, mode=mode) + a, ar, ai = gen_input() + shape = a.shape + + b = np.fft.fftn(a, s=a.shape[-ndim:], axes=range(-ndim, 0)) + br, bi = DFTn(shape[-ndim:], compute_dtype=compute_dtype)(ar, ai) + cr, ci = IDFTn(shape[-ndim:], compute_dtype=compute_dtype)(br, bi) + + rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10 + atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 20 + + assert compare_output(br.numpy(), b.real, rtol, atol) + assert compare_output(bi.numpy(), b.imag, rtol, atol) + assert compare_output(cr.numpy(), a.real, rtol, atol) + assert compare_output(ci.numpy(), a.imag, rtol, atol) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('device_target', ['CPU', 'Ascend']) +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16]) +def test_dct_accuracy(device_target, mode, compute_dtype): + """ + Feature: Test DCT & IDCT accuracy + Description: Input random tensor, compare the results of DCT and IDCT with numpy results + Expectation: The output tensors should be equal within tolerance + """ + ms.set_context(device_target=device_target, mode=mode) + a, ar, _ = gen_input() + shape = a.shape + + b = dct(a.real) + br = DCT(shape[-1:], compute_dtype=compute_dtype)(ar) + cr = IDCT(shape[-1:], compute_dtype=compute_dtype)(br) + + rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10 + atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 20 + + assert compare_output(br.numpy(), b.real, rtol, atol) + assert compare_output(cr.numpy(), a.real, rtol, atol) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('device_target', ['CPU', 'Ascend']) +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16]) +def test_dst_accuracy(device_target, mode, compute_dtype): + """ + Feature: Test DST & IDST accuracy + Description: Input random tensor, compare the results of DST and IDST with numpy results + Expectation: The output tensors should be equal within tolerance + """ + ms.set_context(device_target=device_target, mode=mode) + a, ar, _ = gen_input() + shape = a.shape + + b = dst(a.real) + br = DST(shape[-1:], compute_dtype=compute_dtype)(ar) + cr = IDST(shape[-1:], compute_dtype=compute_dtype)(br) + + rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10 + atol = FP32_ATOL if 
compute_dtype == ms.float32 else FP16_ATOL * 20 + + assert compare_output(br.numpy(), b.real, rtol, atol) + assert compare_output(cr.numpy(), a.real, rtol, atol) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('device_target', ['Ascend']) +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('ndim', [1, 2, 3]) +def test_dft_speed(device_target, mode, ndim): + """ + Feature: Test DFTn & IDFTn speed + Description: Input random tensor, clock the time of 10 runs of the + gradient function containing DFT & iDFT operators + Expectation: The average time of each run should be within 0.5s + """ + # test dftn & idftn speed + ms.set_context(device_target=device_target, mode=mode) + a, ar, ai = gen_input(shape=(64, 128, 256)) + shape = a.shape + + warmup_steps = 10 + timed_steps = 10 + + dft_cell = DFTn(shape[-ndim:]) + idft_cell = IDFTn(shape[-ndim:]) + + def forward_fn(xr, xi): + br, bi = dft_cell(xr, xi) + cr, ci = idft_cell(br, bi) + return ops.sum(cr * cr + ci * ci) + + grad_fn = ms.value_and_grad(forward_fn, grad_position=(0, 1)) + + # warmup run + for _ in range(warmup_steps): + _, (g1, g2) = grad_fn(ar, ai) + ar = ar - .1 * g1 + ai = ai - .1 * g2 + + # timed run + tic = toc() + for _ in range(timed_steps): + _, (g1, g2) = grad_fn(ar, ai) + ar = ar - .1 * g1 + ai = ai - .1 * g2 + + assert (toc() - tic) / timed_steps < 0.5 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('device_target', ['CPU', 'Ascend']) +@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE]) +@pytest.mark.parametrize('ndim', [1, 2, 3]) +@pytest.mark.parametrize('compute_dtype', [ms.float32, ms.float16]) +def test_dft_grad(device_target, mode, ndim, compute_dtype): + """ + Feature: Test the correctness of DFTn & IDFTn grad calculation + Description: Input random tensor, compare the autograd results with theoretic solutions + Expectation: The autograd results should be equal to theoretic solutions + """ + ms.set_context(device_target=device_target, mode=mode) + a, ar, ai = gen_input() + shape = a.shape + + dft_cell = DFTn(shape[-ndim:], compute_dtype=compute_dtype) + + def forward_fn(xr, xi): + yr, yi = dft_cell(xr, xi) + return ops.sum(yr * yr + yi * yi) + + grad_fn = ms.value_and_grad(forward_fn, grad_position=(0, 1)) + _, (g1, g2) = grad_fn(ar, ai) + + # analytic solution of the gradient + b = np.fft.fftn(a, s=a.shape[-ndim:], axes=range(-ndim, 0)) + g = np.fft.ifftn(b, s=a.shape[-ndim:], axes=range(-ndim, 0)) * 2 * np.prod(a.shape[-ndim:]) + + rtol = FP32_RTOL if compute_dtype == ms.float32 else FP16_RTOL * 10 + atol = FP32_ATOL if compute_dtype == ms.float32 else FP16_ATOL * 500 # grad func leads to larger error + + assert compare_output(g1.numpy(), g.real, rtol, atol) + assert compare_output(g2.numpy(), g.imag, rtol, atol) diff --git a/tests/st/mindsponge/test_megaprotein/test_megaprotein.py b/tests/st/mindsponge/test_megaprotein/test_megaprotein.py new file mode 100644 index 0000000000000000000000000000000000000000..b8bce8d5f4a1e3567212a43b5aa5ddbe89bc35fa --- /dev/null +++ b/tests/st/mindsponge/test_megaprotein/test_megaprotein.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Test MEGAFoldProtein examples."""
+import numpy as np
+import mindspore as ms
+from mindsponge import PipeLine
+
+ms.set_context(mode=ms.GRAPH_MODE)
+
+# MEGA-EvoGen inference: generate the MSA-derived features for the protein sequence
+fasta = "GYDKDLCEWSMTADQTEVETQIEADIMNIVKRDRPEMKAEVQKQLKSGGVMQYNYVLYCDKNFNNKNIIAEVVGE"
+msa_generator = PipeLine(name="MEGAEvoGen")
+msa_generator.set_device_id(0)
+msa_generator.initialize(key="evogen_predict_256")
+msa_generator.model.from_pretrained()
+msa_feature = msa_generator.predict(fasta)
+
+# MEGA-Fold inference: predict the protein structure
+fold_prediction = PipeLine(name="MEGAFold")
+fold_prediction.set_device_id(0)
+fold_prediction.initialize(key="predict_256")
+fold_prediction.model.from_pretrained()
+final_atom_positions, final_atom_mask, aatype, _, _ = fold_prediction.model.predict(msa_feature)
+
+# MEGA-Assessment: evaluate the predicted protein structure
+protein_assessment = PipeLine(name="MEGAAssessment")
+protein_assessment.set_device_id(0)
+protein_assessment.initialize("predict_256")
+protein_assessment.model.from_pretrained()
+msa_feature['decoy_aatype'] = np.pad(aatype, (0, 256 - aatype.shape[0]))
+msa_feature['decoy_atom_positions'] = np.pad(final_atom_positions,
+                                             ((0, 256 - final_atom_positions.shape[0]), (0, 0), (0, 0)))
+msa_feature['decoy_atom_mask'] = np.pad(final_atom_mask, ((0, 256 - final_atom_mask.shape[0]), (0, 0)))
+
+res = protein_assessment.model.predict(msa_feature)
+print("score is:", np.mean(res[:msa_feature['num_residues']]))