2.3K Star 8.1K Fork 4.3K

GVPMindSpore / mindspore

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
mindspore.train.RunContext.rst 6.22 KB
一键复制 编辑 原始数据 按行查看 历史
宦晓玲 提交于 2024-04-18 09:38 . modify the error links

mindspore.train.RunContext

.. py:class:: mindspore.train.RunContext(original_args)

    保存和管理模型的相关信息。

    `RunContext` 主要用于收集训练或推理过程中模型的上下文相关信息并作为入参传入callback对象中来实现信息的共享。

    Callback的类方法中,调用 `RunContext.original_args()` 可以获取模型当前的上下文信息,用户也可以为此信息添加额外的自定义属性,同时 `request_stop()` 方法可以控制训练过程的停止。具体用法请查看 `回调机制Callback <https://www.mindspore.cn/tutorials/zh-CN/master/advanced/model/callback.html>`_。

    `RunContext.original_args()` 存储的模型信息为一个字典型变量,在训练和推理过程会存储不同的属性。详情如下:

    +--------------------------+------------------------+---------------------------------------+
    |   训练过程支持的属性     |   推理过程支持的属性   |               说明                    |
    +==========================+========================+=======================================+
    |   train_network          |                        |    包含了优化器和损失的训练网络       |
    +--------------------------+------------------------+---------------------------------------+
    |   epoch_num              |                        |      训练的epoch数                    |
    +--------------------------+------------------------+---------------------------------------+
    |  train_dataset           |                        |         训练集                        |
    +--------------------------+------------------------+---------------------------------------+
    |   loss_fn                |                        |         损失函数                      |
    +--------------------------+------------------------+---------------------------------------+
    |   optimizer              |                        |         优化器                        |
    +--------------------------+------------------------+---------------------------------------+
    |  parallel_mode           |                        |         并行模式                      |
    +--------------------------+------------------------+---------------------------------------+
    |   device_number          |                        |         设备编号                      |
    +--------------------------+------------------------+---------------------------------------+
    |   train_dataset_element  |                        |         当前step的训练数据            |
    +--------------------------+------------------------+---------------------------------------+
    |  last_save_ckpt_step     |                        |      最后一次存储ckpt的step           |
    +--------------------------+------------------------+---------------------------------------+
    |  latest_ckpt_file        |                        |            ckpt文件名                 |
    +--------------------------+------------------------+---------------------------------------+
    |   cur_epoch_num          |                        |          当前的epoch                  |
    +--------------------------+------------------------+---------------------------------------+
    |                          |  eval_network          |          评估网络                     |
    +--------------------------+------------------------+---------------------------------------+
    |                          |  valid_dataset         |          验证集                       |
    +--------------------------+------------------------+---------------------------------------+
    |                          |   metrics              |          评估指标                     |
    +--------------------------+------------------------+---------------------------------------+
    |   mode                   |   mode                 |        "train"或"eval"模式            |
    +--------------------------+------------------------+---------------------------------------+
    |  batch_num               |   batch_num            |        训练或推理的batch数            |
    +--------------------------+------------------------+---------------------------------------+
    |   list_callback          |   list_callback        |        回调列表                       |
    +--------------------------+------------------------+---------------------------------------+
    |   network                |    network             |       基础的网络结构                  |
    +--------------------------+------------------------+---------------------------------------+
    |  cur_step_num            |    cur_step_num        |       当前的训练或推理的step          |
    +--------------------------+------------------------+---------------------------------------+
    |   dataset_sink_mode      |    dataset_sink_mode   |       训练或推理的数据是否下沉        |
    +--------------------------+------------------------+---------------------------------------+
    |   net_outputs            |      net_outputs       |       训练或推理的网络输出            |
    +--------------------------+------------------------+---------------------------------------+

    参数:
        - **original_args** (dict) - 模型的相关信息。

    .. py:method:: get_stop_requested()

        获取是否停止训练的标志。

        返回:
            bool,如果为True,则 `Model.train()` 停止迭代。

    .. py:method:: original_args()

        获取模型相关信息的对象。

        返回:
            dict,含有模型的相关信息的对象。

        教程样例:
            - `回调机制 Callback - 自定义回调机制
              <https://mindspore.cn/tutorials/zh-CN/master/advanced/model/callback.html#自定义回调机制>`_

    .. py:method:: request_stop()

        在训练期间设置停止请求。

        可以使用此函数请求停止训练。 `Model.train()` 会检查是否调用此函数。

        教程样例:
            - `回调机制 Callback - 自定义终止训练
              <https://mindspore.cn/tutorials/zh-CN/master/advanced/model/callback.html#自定义终止训练>`_
Python
1
https://gitee.com/mindspore/mindspore.git
git@gitee.com:mindspore/mindspore.git
mindspore
mindspore
mindspore
master

搜索帮助