1 Star 0 Fork 5

hjy / keyPointsDetectionMethodWithTorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train_KPDEM_model.py 2.33 KB
一键复制 编辑 原始数据 按行查看 历史
unknown 提交于 2021-06-18 15:23 . KPDM_v2.1
from keypoints_Net import CoordRegression
from data_process import *
import torch.optim as optim
import dsntnn
# Module-wide compute device: the current CUDA device when a GPU is usable,
# otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def _normalize_landmarks(landmarks, size=64.0):
    """Convert pixel keypoint coordinates to the normalized [-1, 1] range.

    Keeps only the first two channels of the last axis (presumably x, y; any
    further channels are dropped — TODO confirm against the dataset) and maps
    pixel index ``p`` on a ``size``-pixel axis to ``(2p + 1) / size - 1``,
    i.e. the centre of that pixel expressed in [-1, 1] units.

    Args:
        landmarks: Keypoints indexed as ``[batch, point, channel]``; a tensor
            or anything ``torch.as_tensor`` accepts.
        size: Extent of the pixel grid the coordinates refer to (the original
            script hard-coded 64).

    Returns:
        A float32 tensor of shape ``(batch, point, 2)``.
    """
    coords = torch.as_tensor(landmarks[:, :, 0:2], dtype=torch.float32)
    return (coords * 2 + 1) / size - 1


def calculate_loss(epochs=100, batch_size=8, lr=2e-4, n_locations=13,
                   save_dir=r'D:/TanHaiyan/Models/KPDEM/trouser_kp/'):
    """Train a ``CoordRegression`` keypoint model on the blouse dataset.

    Despite its name, this runs the full training loop. On every epoch it
    does the following:
      * iterates ``blouseDataset``;
      * minimizes ``dsntnn.average_loss`` of the per-keypoint Euclidean loss
        plus the ``dsntnn.js_reg_losses`` heatmap regularization, using RMSprop;
      * prints each batch's loss;
      * saves the whole DataParallel-wrapped model to
        ``save_dir/blouse_kp<epoch>.pth``.

    Args:
        epochs: Number of passes over the dataset.
        batch_size: Samples per batch; an incomplete final batch is dropped.
        lr: RMSprop learning rate.
        n_locations: Number of keypoints the model predicts; must match the
            dataset's landmark count (TODO confirm 13 for blouses).
        save_dir: Directory for the per-epoch checkpoints; created if missing.

    Returns:
        None.
    """
    import os

    from data_process_fashion_hw import blouseDataset

    dataloader = DataLoader(dataset=blouseDataset, batch_size=batch_size,
                            shuffle=True, drop_last=True)

    model = CoordRegression(n_locations=n_locations)
    # Move the model before constructing the optimizer (torch.optim guidance),
    # and use `device` instead of `.cuda()` so CPU-only machines still work.
    model = torch.nn.DataParallel(model).to(device)
    optimizer = optim.RMSprop(model.parameters(), lr=lr, alpha=0.85)

    # Create the checkpoint directory up front so a missing path fails (or is
    # fixed) before an epoch of training is spent.
    # NOTE(review): 'blouse_kp*' files land in a 'trouser_kp' directory by
    # default — confirm this mismatch is intended.
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(epochs):
        model.train()
        print("Epoch: {}/{}".format(epoch + 1, epochs))
        for i_batch, (img, landmarks) in enumerate(dataloader):
            img = img.to(device)
            landmarks = _normalize_landmarks(landmarks).to(device)

            # Forward pass; the model returns coordinates and heatmaps.
            coords, heatmaps = model(img)
            euc_losses = dsntnn.euclidean_losses(coords, landmarks)
            reg_losses = dsntnn.js_reg_losses(heatmaps, landmarks, sigma_t=1.0)
            loss = dsntnn.average_loss(euc_losses + reg_losses)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(i_batch, ',current_loss,{:.3f}'.format(loss.item()))
        torch.save(model, os.path.join(save_dir, 'blouse_kp{}.pth'.format(epoch)))
def _main():
    """Script entry point: train with the default settings, then report completion."""
    calculate_loss()
    print("The end!")


if __name__ == "__main__":
    _main()
1
https://gitee.com/hjycodehjy/key-points-detection-method-with-torch.git
git@gitee.com:hjycodehjy/key-points-detection-method-with-torch.git
hjycodehjy
key-points-detection-method-with-torch
keyPointsDetectionMethodWithTorch
master

搜索帮助