MindSpore/models · eval.py
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Object Recognition eval."""
import os
import time
import math
from pprint import pformat
import numpy as np
import cv2
from mindspore.common import dtype as mstype
import mindspore.dataset.transforms as transforms
import mindspore.dataset.vision as vision
import mindspore.dataset as de
from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.backbone.resnet import get_backbone
from src.my_logging import get_logger
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id())
class TxtDataset:
"""TxtDataset"""
def __init__(self, root_all, filenames):
super(TxtDataset, self).__init__()
self.imgs = []
self.labels = []
        for root, filename in zip(root_all, filenames):
            with open(filename, "r") as fin:
                for line in fin:
                    self.imgs.append(os.path.join(root, line.strip().split(" ")[0]))
                    self.labels.append(line.strip())
def __getitem__(self, index):
try:
img = cv2.cvtColor(cv2.imread(self.imgs[index]), cv2.COLOR_BGR2RGB)
        except Exception:  # log the offending path before re-raising
            print("failed to read image:", self.imgs[index])
raise
return img, index
def __len__(self):
return len(self.imgs)
def get_all_labels(self):
return self.labels
class DistributedSampler:
"""DistributedSampler"""
def __init__(self, dataset):
self.dataset = dataset
self.num_replicas = 1
self.rank = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
def __iter__(self):
indices = list(range(len(self.dataset)))
indices = indices[self.rank :: self.num_replicas]
return iter(indices)
def __len__(self):
return self.num_samples
def get_dataloader(img_predix_all, img_list_all, batch_size, img_transforms):
dataset = TxtDataset(img_predix_all, img_list_all)
sampler = DistributedSampler(dataset)
dataset_column_names = ["image", "index"]
ds = de.GeneratorDataset(dataset, column_names=dataset_column_names, sampler=sampler)
ds = ds.map(input_columns=["image"], operations=img_transforms)
ds = ds.batch(batch_size, num_parallel_workers=8, drop_remainder=False)
return ds, len(dataset), dataset.get_all_labels()
def generate_test_pair(jk_list, zj_list):
"""generate_test_pair"""
file_paths = [jk_list, zj_list]
jk_dict = {}
zj_dict = {}
jk_zj_dict_list = [jk_dict, zj_dict]
for path, x_dict in zip(file_paths, jk_zj_dict_list):
with open(path, "r") as fr:
for line in fr:
label = line.strip().split(" ")[1]
tmp = x_dict.get(label, [])
tmp.append(line.strip())
x_dict[label] = tmp
zj2jk_pairs = []
for key in jk_dict:
jk_file_list = jk_dict[key]
zj_file_list = zj_dict[key]
for zj_file in zj_file_list:
zj2jk_pairs.append([zj_file, jk_file_list])
return zj2jk_pairs
def check_minmax(args, data, min_value=0.99, max_value=1.01):
    """Check that all values of data fall within [min_value, max_value]; raise ValueError otherwise."""
    min_data = data.min()
    max_data = data.max()
    if np.isnan(min_data) or np.isnan(max_data):
        args.logger.info("ERROR, NaN in embedding norms, please check for fp16 overflow or other numerical errors")
        raise ValueError("NaN in embedding norms")
    if min_data < min_value or max_data > max_value:
        args.logger.info(
            "ERROR, min or max is out of range, range=[{}, {}], minmax=[{}, {}]".format(
                min_value, max_value, min_data, max_data
            )
        )
        raise ValueError("embedding norm out of range")
def get_model(args):
"""get_model"""
net = get_backbone(args)
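    # Cast the backbone to fp16; it is switched back to fp32 below when the device target is GPU.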
net.add_flags_recursive(fp16=True)
if args.weight.endswith(".ckpt"):
param_dict = load_checkpoint(args.weight)
param_dict_new = {}
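        # Drop optimizer state ("moments.*") and strip the "network." wrapper prefix so keys match the bare backbone.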
for key, value in param_dict.items():
if key.startswith("moments."):
continue
elif key.startswith("network."):
param_dict_new[key[8:]] = value
else:
param_dict_new[key] = value
load_param_into_net(net, param_dict_new)
args.logger.info("INFO, ------------- load model success--------------")
else:
args.logger.info("ERROR, not support file:{}, please check weight in config.py".format(args.weight))
return 0
if args.device_target == "GPU":
net.to_float(mstype.float32)
net.set_train(False)
return net
def topk(matrix, k, axis=1):
"""topk"""
if axis == 0:
row_index = np.arange(matrix.shape[1 - axis])
topk_index = np.argpartition(-matrix, k, axis=axis)[0:k, :]
topk_data = matrix[topk_index, row_index]
topk_index_sort = np.argsort(-topk_data, axis=axis)
topk_data_sort = topk_data[topk_index_sort, row_index]
topk_index_sort = topk_index[0:k, :][topk_index_sort, row_index]
else:
column_index = np.arange(matrix.shape[1 - axis])[:, None]
topk_index = np.argpartition(-matrix, k, axis=axis)[:, 0:k]
topk_data = matrix[column_index, topk_index]
topk_index_sort = np.argsort(-topk_data, axis=axis)
topk_data_sort = topk_data[column_index, topk_index_sort]
topk_index_sort = topk_index[:, 0:k][column_index, topk_index_sort]
return topk_data_sort, topk_index_sort
def cal_topk(args, idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot):
"""cal_topk"""
args.logger.info("start idx:{} subprocess...".format(idx))
correct = np.array([0] * 2)
tot = np.array([0])
zj, jk_all = zj2jk_pairs[idx]
zj_embedding = test_embedding_tot[zj]
jk_all_embedding = np.concatenate([np.expand_dims(test_embedding_tot[jk], axis=0) for jk in jk_all], axis=0)
args.logger.info("INFO, calculate top1 acc index:{}, zj_embedding shape:{}".format(idx, zj_embedding.shape))
args.logger.info("INFO, calculate top1 acc index:{}, jk_all_embedding shape:{}".format(idx, jk_all_embedding.shape))
test_time = time.time()
mm = np.matmul(np.expand_dims(zj_embedding, axis=0), dis_embedding_tot)
top100_jk2zj = np.squeeze(topk(mm, 100)[0], axis=0)
top100_zj2jk = topk(np.matmul(jk_all_embedding, dis_embedding_tot), 100)[0]
test_time_used = time.time() - test_time
args.logger.info(
"INFO, calculate top1 acc index:{}, np.matmul().top(100) time used:{:.2f}s".format(idx, test_time_used)
)
tot[0] = len(jk_all)
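    # A pair counts as a top-1 hit when its jk-zj similarity exceeds the probe's best distractor score
    # (element 0 of the descending top-100 computed above).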
for i, jk in enumerate(jk_all):
jk_embedding = test_embedding_tot[jk]
similarity = np.dot(jk_embedding, zj_embedding)
if similarity > top100_jk2zj[0]:
correct[0] += 1
if similarity > top100_zj2jk[i, 0]:
correct[1] += 1
return correct, tot
def l2normalize(features):
epsilon = 1e-12
l2norm = np.sum(np.abs(features) ** 2, axis=1, keepdims=True) ** (1.0 / 2)
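    # Clamp near-zero norms to +/-epsilon so the division below never divides by zero.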
l2norm[np.logical_and(l2norm < 0, l2norm > -epsilon)] = -epsilon
l2norm[np.logical_and(l2norm >= 0, l2norm < epsilon)] = epsilon
return features / l2norm
def modelarts_pre_process():
"""modelarts pre process function."""
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, "face_recognition_dataset")):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, "r")
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
i = 0
for file in fz.namelist():
                    if i % max(1, data_num // 100) == 0:
                        print("unzip percent: {}%".format(int(i / max(1, data_num // 100))), flush=True)
i += 1
fz.extract(file, save_dir)
print(
"cost time: {}min:{}s.".format(
int((time.time() - s_time) / 60), int(int(time.time() - s_time) % 60)
)
)
print("Extract Done.")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if config.need_modelarts_dataset_unzip:
zip_file_1 = os.path.join(config.data_path, "face_recognition_dataset.zip")
save_dir_1 = os.path.join(config.data_path)
sync_lock = "/tmp/unzip_sync.lock"
        # Each server contains at most 8 devices.
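        # Only the first device on each server extracts the archive; the others poll the lock file until extraction completes.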
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
config.ckpt_path = os.path.join(config.output_path, str(get_rank_id()), config.ckpt_path)
@moxing_wrapper(pre_process=modelarts_pre_process)
def run_eval(args):
"""run eval function."""
if not os.path.exists(args.test_dir):
args.logger.info("ERROR, test_dir is not exists, please set test_dir in config.py.")
return 0
all_start_time = time.time()
net = get_model(args)
compile_time_used = time.time() - all_start_time
args.logger.info(
"INFO, graph compile finished, time used:{:.2f}s, start calculate img embedding".format(compile_time_used)
)
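    # ToTensor converts HWC uint8 images to CHW float32 in [0, 1]; Normalize with mean/std 0.5 then maps them to [-1, 1].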
img_transforms = transforms.Compose(
[vision.ToTensor(), vision.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), is_hwc=False)]
)
# for test images
args.logger.info("INFO, start step1, calculate test img embedding, weight file = {}".format(args.weight))
step1_start_time = time.time()
ds, img_tot, all_labels = get_dataloader(
args.test_img_predix, args.test_img_list, args.test_batch_size, img_transforms
)
args.logger.info("INFO, dataset total test img:{}, total test batch:{}".format(img_tot, ds.get_dataset_size()))
test_embedding_tot_np = np.zeros((img_tot, args.emb_size))
test_img_labels = all_labels
data_loader = ds.create_dict_iterator(output_numpy=True, num_epochs=1)
for i, data in enumerate(data_loader):
img, idxs = data["image"], data["index"]
out = net(Tensor(img)).asnumpy().astype(np.float32)
embeddings = l2normalize(out)
for batch in range(embeddings.shape[0]):
test_embedding_tot_np[idxs[batch]] = embeddings[batch]
try:
check_minmax(args, np.linalg.norm(test_embedding_tot_np, ord=2, axis=1))
except ValueError:
return 0
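    # Key the test embeddings by their label line so cal_topk can look up zj/jk entries directly.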
test_embedding_tot = {}
for idx, label in enumerate(test_img_labels):
test_embedding_tot[label] = test_embedding_tot_np[idx]
step2_start_time = time.time()
step1_time_used = step2_start_time - step1_start_time
args.logger.info(
"INFO, step1 finished, time used:{:.2f}s, start step2, calculate dis img embedding".format(step1_time_used)
)
# for dis images
ds_dis, img_tot, _ = get_dataloader(args.dis_img_predix, args.dis_img_list, args.dis_batch_size, img_transforms)
dis_embedding_tot_np = np.zeros((img_tot, args.emb_size))
total_batch = ds_dis.get_dataset_size()
args.logger.info("INFO, dataloader total dis img:{}, total dis batch:{}".format(img_tot, total_batch))
start_time = time.time()
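    # Shard the distractor set across ranks: start_idx is this rank's offset into the full embedding array;
    # the last delta_num ranks process one image fewer.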
img_per_gpu = int(math.ceil(1.0 * img_tot / args.world_size))
delta_num = img_per_gpu * args.world_size - img_tot
start_idx = img_per_gpu * args.local_rank - max(0, args.local_rank - (args.world_size - delta_num))
data_loader = ds_dis.create_dict_iterator(output_numpy=True, num_epochs=1)
for idx, data in enumerate(data_loader):
img = data["image"]
out = net(Tensor(img)).asnumpy().astype(np.float32)
embeddings = l2normalize(out)
dis_embedding_tot_np[start_idx : (start_idx + embeddings.shape[0])] = embeddings
start_idx += embeddings.shape[0]
if args.local_rank % 8 == 0 and idx % args.log_interval == 0 and idx > 0:
speed = 1.0 * (args.dis_batch_size * args.log_interval * args.world_size) / (time.time() - start_time)
time_left = (total_batch - idx - 1) * args.dis_batch_size * args.world_size / speed
args.logger.info(
"INFO, processed [{}/{}], speed: {:.2f} img/s, left:{:.2f}s".format(idx, total_batch, speed, time_left)
)
start_time = time.time()
try:
check_minmax(args, np.linalg.norm(dis_embedding_tot_np, ord=2, axis=1))
except ValueError:
return 0
step3_start_time = time.time()
step2_time_used = step3_start_time - step2_start_time
args.logger.info("INFO, step2 finished, time used:{:.2f}s, start step3, calculate top1 acc".format(step2_time_used))
# clear npu memory
img = None
net = None
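    # Transpose to (emb_size, num_dis) so each (1, emb_size) probe embedding can be matmul'd against every distractor at once.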
dis_embedding_tot_np = np.transpose(dis_embedding_tot_np, (1, 0))
args.logger.info("INFO, calculate top1 acc dis_embedding_tot_np shape:{}".format(dis_embedding_tot_np.shape))
# find best match
assert len(args.test_img_list) % 2 == 0
task_num = int(len(args.test_img_list) / 2)
correct = np.array([0] * (2 * task_num))
tot = np.array([0] * task_num)
    for i in range(task_num):
jk_list = args.test_img_list[2 * i]
zj_list = args.test_img_list[2 * i + 1]
zj2jk_pairs = sorted(generate_test_pair(jk_list, zj_list))
sampler = DistributedSampler(zj2jk_pairs)
args.logger.info("INFO, calculate top1 acc sampler len:{}".format(len(sampler)))
for idx in sampler:
out1, out2 = cal_topk(args, idx, zj2jk_pairs, test_embedding_tot, dis_embedding_tot_np)
correct[2 * i] += out1[0]
correct[2 * i + 1] += out1[1]
tot[i] += out2[0]
args.logger.info("local_rank={},tot={},correct={}".format(args.local_rank, tot, correct))
step3_time_used = time.time() - step3_start_time
args.logger.info("INFO, step3 finished, time used:{:.2f}s".format(step3_time_used))
args.logger.info("weight:{}".format(args.weight))
    for i in range(task_num):
test_set_name = "test_dataset"
zj2jk_acc = correct[2 * i] / tot[i]
jk2zj_acc = correct[2 * i + 1] / tot[i]
avg_acc = (zj2jk_acc + jk2zj_acc) / 2
results = "[{}]: zj2jk={:.4f}, jk2zj={:.4f}, avg={:.4f}".format(test_set_name, zj2jk_acc, jk2zj_acc, avg_acc)
args.logger.info(results)
args.logger.info("INFO, tot time used: {:.2f}s".format(time.time() - all_start_time))
return 0
if __name__ == "__main__":
config.test_img_predix = [
os.path.join(config.test_dir, "test_dataset/"),
os.path.join(config.test_dir, "test_dataset/"),
]
config.test_img_list = [
os.path.join(config.test_dir, "lists/jk_list.txt"),
os.path.join(config.test_dir, "lists/zj_list.txt"),
]
config.dis_img_predix = [
os.path.join(config.test_dir, "dis_dataset/"),
]
config.dis_img_list = [
os.path.join(config.test_dir, "lists/dis_list.txt"),
]
log_path = os.path.join(config.ckpt_path, "logs")
config.logger = get_logger(log_path, config.local_rank)
config.logger.info("Config %s", pformat(config))
run_eval(config)