"""Evaluate a fine-tuned XLM-R sequence classifier on the PAWS-X English test split.

Usage: python <script> <model_path>
"""
import logging

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

import os
import sys

# Make the parent directory importable so classification_utils resolves.
sys.path.insert(0, os.path.abspath('..'))

import torch
from transformers import XLMRobertaForSequenceClassification, XLMRobertaTokenizer

from classification_utils.my_dataset import MultilingualNLIDataset
from classification_utils.predict_function import predict

# Validate CLI args up front instead of crashing with a bare IndexError.
if len(sys.argv) < 2:
    logger.error("Usage: %s <model_path>", sys.argv[0])
    sys.exit(1)
model_path = sys.argv[1]

# Evaluation configuration (PAWS-X paraphrase task, English test split).
taskname = 'pawsx'
data_dir = '../datasets/pawsx'
split = 'test'
max_seq_length = 128
eval_langs = ['en']
batch_size = 32
# Fall back to CPU when no GPU is available so the script still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info("Evaluating %s on %s/%s (device=%s)", model_path, taskname, split, device)

# Load the fine-tuned model and re-initialize the matching tokenizer
# from the same checkpoint directory.
model = XLMRobertaForSequenceClassification.from_pretrained(model_path).to(device)
tokenizer = XLMRobertaTokenizer.from_pretrained(model_path)

# Build one dataset per evaluation language and run prediction.
eval_dataset = MultilingualNLIDataset(
    task=taskname, data_dir=data_dir, split=split, prefix='xlmr',
    max_seq_length=max_seq_length, langs=eval_langs, tokenizer=tokenizer)
eval_datasets = [eval_dataset.lang_datasets[lang] for lang in eval_langs]
predict(model, eval_datasets, eval_langs, device, batch_size)