当运行network.set_train()
时,虽然network.training
这个flag已经recursively修改了,但是network.phase
没有正确修改,子网络的network.xx.phase
还是最初的值。
Hardware Environment(Ascend
/GPU
/CPU
) / 硬件环境: GPU
Software Environment / 软件环境 (Mandatory / 必填):
-- MindSpore version (e.g., 1.7.0.Bxxx) : 2.0.0rc1
-- Python version (e.g., Python 3.7.5) : 3.9.16
-- OS platform and distribution (e.g., Linux Ubuntu 16.04):Linux
Excute Mode / 执行模式 (Mandatory / 必填)(PyNative
/Graph
): PyNative
#test the set_train function
import mindspore.nn as nn
class Backbone(nn.Cell):
def __init__(self):
super().__init__()
self.fc = nn.Dense(100, 100)
def construct(self, x):
x = self.fc(x)
return x
class Head(nn.Cell):
def __init__(self):
super().__init__()
self.fc = nn.Dense(100, 100)
def construct(self, x):
x = self.fc(x)
return x
class Network(nn.Cell):
def __init__(self):
super().__init__()
self.backbone = Backbone()
self.head = Head()
def construct(self, x):
x = self.backbone(x)
x = self.head(x)
return x
network = Network()
network.set_train(False)
print(f"network.phase: {network.phase}\tnetwork.backbone.phase: {network.backbone.phase}\t network.head.phase: {network.head.phase}")
print(f"network.training: {network.training}\tnetwork.backbone.training: {network.backbone.training}\tnetwork.head.training: {network.head.training}")
运行上述代码
子网络的phase都改成了predict
network.phase: predict network.backbone.phase: train network.head.phase: train
network.training: False network.backbone.training: False network.head.training: False
Please assign maintainer to check this issue.
请为此issue分配处理人。
@fangwenyi @chengxiaoli @wuweikang
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。
Please add labels (comp or sig), also you can visit https://gitee.com/mindspore/community/blob/master/sigs/dx/docs/labels.md to find more.
为了让代码尽快被审核,请您为Pull Request打上 组件(comp)或兴趣组(sig) 标签,打上标签的PR可直接推送给责任人进行审核。
更多的标签可以查看https://gitee.com/mindspore/community/blob/master/sigs/dx/docs/labels.md
以组件相关代码提交为例,如果你提交的是data组件代码,你可以这样评论:
//comp/data
当然你也可以邀请data SIG组来审核代码,可以这样写:
//sig/data
另外你还可以给这个PR标记类型,例如是bugfix或者是特性需求:
//kind/bug or //kind/feature
恭喜你,你已经学会了使用命令来打标签,接下来就在下面的评论里打上标签吧!
您好,问题我们已复现,正在分析中
您好,network.phase是一个内部参数,不会影响训练和推理的状态
是的,这不算是一个bug。只是希望training和flag这两个flag是同步变化的,以免造成用户的误解。
登录 后才可以发表评论