220 Star 945 Fork 694

GVPMindSpore/mindscience

Create your Gitee Account
Explore and code with more than 13.5 million developers,Free private repositories !:)
Sign up
文件
Clone or Download
mindflow.cell.DiffusionTrainer.rst 1.65 KB
Copy Edit Raw Blame History
xf233 authored 2025-04-11 17:08 +08:00 . update flow doc

mindflow.cell.DiffusionTrainer

.. py:class:: mindflow.cell.DiffusionTrainer(model, scheduler, objective='pred_noise', p2_loss_weight_gamma=0., p2_loss_weight_k=1, loss_type='l1')

    扩散模型训练控制实现。

    参数:
        - **model** (nn.Cell) - 用于扩散模型训练的骨干网络。
        - **scheduler** (DiffusionScheduler) - 用于训练的噪声控制器。
        - **objective** (str) - 扩散模型预测结果的类型。默认值: ``pred_noise`` 。支持以下类型: ``pred_noise`` , ``pred_v`` 和 ``pred_x0`` 。
        - **p2_loss_weight_gamma** (float) - p2 loss权重 `gamma` ,具体信息查看 `Perception Prioritized Training of Diffusion Models <https://arxiv.org/abs/2204.00227>`_ 。默认值: ``0.0`` 。
        - **p2_loss_weight_k** (float) - p2 loss权重 `k` ,具体信息查看 `Perception Prioritized Training of Diffusion Models <https://arxiv.org/abs/2204.00227>`_ 。默认值: ``1`` 。
        - **loss_type** (str) - loss函数类型。默认值: ``l1`` 。支持以下类型: ``l1`` 和 ``l2`` 。

    异常:
        - **TypeError** - 如果 `scheduler` 不是 `DiffusionScheduler` 类型。

    .. py:method:: get_loss(original_samples: Tensor, noise: Tensor, timesteps: Tensor, condition: Tensor = None)

        计算扩散过程的前向loss。

        参数:
            - **original_samples** (Tensor) - 原始样本。
            - **noise** (Tensor) - 随机噪声。
            - **timesteps** (Tensor) - 时间步。
            - **condition** (Tensor) - 控制条件。默认值: ``None`` 。

        返回:
            - Tensor - 前向loss。
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mindspore/mindscience.git
git@gitee.com:mindspore/mindscience.git
mindspore
mindscience
mindscience
r0.7

Search