代码拉取完成,页面将自动刷新
同步操作将从 Tanhaiyan/keyPointsDetectionMethodWithTorch 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
from keypoints_Net import CoordRegression
from data_process import *
import torch.optim as optim
import dsntnn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def calculate_loss(epochs=100):
from data_process_fashion_hw import blouseDataset
dataloader = DataLoader(dataset=blouseDataset, batch_size=8, shuffle=True, drop_last=True)
# dataloader_val = DataLoader(dataset=blouseDataset, batch_size=2, shuffle=True)
model = CoordRegression(n_locations=13)
optimizer = optim.RMSprop(model.parameters(), lr=2e-4, alpha=0.85)
# optimizer = optim.RMSprop(model.parameters(), lr=2.5e-4)
model = torch.nn.DataParallel(model).cuda()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# start training
for epoch in range(epochs):
model.train()
print("Epoch: {}/{}".format(epoch + 1, epochs))
optimizer.zero_grad()
for i_batch, data in enumerate(dataloader):
img, landmarks = data
landmarks = landmarks[:, :, 0:2]
landmarks = (landmarks*2+1)/64.0 - 1
img = img.to(device)
landmarks = torch.tensor(landmarks, dtype=torch.float32)
landmarks = landmarks.to(device)
# 每张图像训练连续两次和三次分别保存基数和偶数模型
# forward pass
coords, heatmaps = model(img)
# per-location euclidean losses
euc_losses = dsntnn.euclidean_losses(coords, landmarks)
# print("predict", heatmaps.shape)
# per-location regulation losses
reg_losses = dsntnn.js_reg_losses(heatmaps, landmarks, sigma_t=1.0)
# combine losses into an overall loss
loss = dsntnn.average_loss(euc_losses + reg_losses)
# Calculate gradients
optimizer.zero_grad()
loss.backward()
train_loss = loss.data
# Update model parameters with RMSprop
optimizer.step()
print(str(i_batch),
# ',euc_losses,', euc_losses.data,
# ',reg_losses,', reg_losses.data,
',current_loss,{:.3f}'.format(loss.data))
torch.save(model, r'D:/TanHaiyan/Models/KPDEM/trouser_kp/' + 'blouse_kp' + str(epoch) + '.pth')
if __name__ == "__main__":
calculate_loss()
print("The end!")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。