220 Star 946 Fork 694

GVPMindSpore/mindscience

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
mindflow.cell.DiffusionTransformer.rst 1.20 KB
一键复制 编辑 原始数据 按行查看 历史
俞涵 提交于 2025-04-23 15:33 +08:00 . modify api format error

mindflow.cell.DiffusionTransformer

.. py:class:: mindflow.cell.DiffusionTransformer(in_channels, out_channels, hidden_channels, layers, heads, time_token_cond=True, compute_dtype=mstype.float32)

    以Transformer作为骨干网络的扩散模型。

    参数:
        - **in_channels** (int) - 输入特征维度。
        - **out_channels** (int) - 输出特征维度。
        - **hidden_channels** (int) - 隐藏层特征维度。
        - **layers** (int) - `Transformer` 层数。
        - **heads** (int) - 注意力头数。
        - **time_token_cond** (bool) - 是否将时间作为作为条件token。 Default: ``True`` 。
        - **compute_dtype** (mindspore.dtype) - 计算数据类型。支持 ``mstype.float32`` or ``mstype.float16`` 。 默认值: ``mstype.float32`` ,表示 ``mindspore.float32`` 。

    输入:
        - **x** (Tensor) - 网络输入。shape为 :math:`(batch\_size, sequence\_len, in\_channels)` 的Tensor。
        - **timestep** (Tensor) - 时间步。shape为 :math:`(batch\_size,)` 的Tensor。

    输出:
        - **output** (Tensor) - shape为 :math:`(batch\_size, sequence\_len, out\_channels)` 的Tensor。
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mindspore/mindscience.git
git@gitee.com:mindspore/mindscience.git
mindspore
mindscience
mindscience
r0.7

搜索帮助