1 Star 0 Fork 0

dengxuezheng/VIDEVAL_release

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
demo_pred_MOS_pretrained_VIDEVAL.py 2.00 KB
一键复制 编辑 原始数据 按行查看 历史
vztu 提交于 2020-05-19 00:17 . initial commit
# -*- coding: utf-8 -*-
"""
This script predicts a quality score in [1,5] given a VIDEVAL feature
vector by a pretrained VIDEVAL model
Input:
- feature matrix:
eg, features/TEST_VIDEOS_VIDEVAL_feats.mat
Output:
- predicted scores:
eg, results/TEST_VIDEOS_VIDEVAL_pred.csv
"""
# Load libraries
import os
import time
import warnings

import numpy as np
import scipy.io
from sklearn import model_selection
from sklearn.preprocessing import MinMaxScaler
from sklearn.svm import SVR
from sklearn.svm import SVC
try:
    # scikit-learn < 0.23 shipped a vendored copy of joblib at this path.
    from sklearn.externals import joblib
except ImportError:
    # scikit-learn >= 0.23 removed sklearn.externals.joblib; fall back to the
    # standalone joblib package (a scikit-learn dependency).
    import joblib
# Silence every warning emitted from this point on (same effect as
# warnings.filterwarnings("ignore")).
# NOTE(review): a blanket filter also hides NumPy / scikit-learn deprecation
# warnings; consider restricting it to the categories that are actually noisy.
warnings.simplefilter("ignore")
# ===========================================================================
# Main part of the script
#
'''======================== parameters ================================'''
# Regressor, dataset and feature-algorithm identifiers for this run.
model_name, data_name, algo_name = 'SVR', 'TEST_VIDEOS', 'VIDEVAL'
# Input features, pretrained artifacts and output CSV; all paths are relative
# to the current working directory.
mat_file = os.path.join('features', '{}_{}_feats.mat'.format(data_name, algo_name))
model_file = os.path.join('model', '{}_trained_svr.pkl'.format(algo_name))
scaler_file = os.path.join('model', '{}_trained_scaler.pkl'.format(algo_name))
pars_file = os.path.join('model', '{}_logistic_pars.mat'.format(algo_name))
result_file = os.path.join('results', '{}_{}_pred.csv'.format(data_name, algo_name))
print(
    "Predict quality scores using pretrained {} with {} on dataset {} ...".format(
        algo_name, model_name, data_name
    )
)
'''======================== read files =============================== '''
# Feature matrix stored under the 'feats_mat' key of the .mat file.
# np.float64 replaces the old np.float alias (removed in NumPy 1.24); that
# alias meant the builtin float, i.e. float64, so the dtype is unchanged.
X_mat = scipy.io.loadmat(mat_file)
X = np.asarray(X_mat['feats_mat'], dtype=np.float64)
# Zero out NaN and +/-inf entries so the scaler and regressor see finite input.
X[~np.isfinite(X)] = 0
# NOTE(security): joblib.load unpickles arbitrary objects -- only load model
# files that come from a trusted source.
model = joblib.load(model_file)
scaler = joblib.load(scaler_file)
# Logistic-mapping parameters: first row of the 'popt' entry, later unpacked
# positionally as the four logistic_func coefficients.
popt = np.asarray(scipy.io.loadmat(pars_file)['popt'][0], dtype=np.float64)
# Apply the loaded scaler, then the loaded regressor.
X = scaler.transform(X)
y_pred = model.predict(X)
def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
    """Apply a 4-parameter logistic mapping element-wise.

    Computes ``bayta2 + (bayta1 - bayta2) / (1 + exp(-(X - bayta3) / |bayta4|))``.

    Args:
        X: raw predictions (scalar or NumPy array).
        bayta1: value approached as X grows large (assuming bayta4 != 0).
        bayta2: value approached as X becomes very negative.
        bayta3: midpoint; at X == bayta3 the result is (bayta1 + bayta2) / 2.
        bayta4: slope scale; only its absolute value is used.

    Returns:
        The mapped value(s), broadcast to the shape of X.
    """
    scale = np.abs(bayta4)
    exponent = -((X - bayta3) / scale)
    denominator = 1 + np.exp(exponent)
    return bayta2 + (bayta1 - bayta2) / denominator
# Map the raw regressor outputs through the fitted logistic curve.
y = logistic_func(y_pred, *popt)
# np.savetxt does not create missing directories; ensure the output folder
# exists so a fresh checkout does not fail with FileNotFoundError.
os.makedirs(os.path.dirname(result_file) or '.', exist_ok=True)
np.savetxt(result_file, y, delimiter=",")
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/dengxuezheng/VIDEVAL_release.git
git@gitee.com:dengxuezheng/VIDEVAL_release.git
dengxuezheng
VIDEVAL_release
VIDEVAL_release
master

搜索帮助