# web2020-lab2 **Repository Path**: leoncoci/web2020-lab2 ## Basic Information - **Project Name**: web2020-lab2 - **Description**: USTC webinfo 2020 lab2 - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2021-01-03 - **Last Updated**: 2021-04-07 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # web2020-lab2 #### 介绍 USTC webinfo 2020 lab2 #### 使用环境 - python 3.7+ - windows 10 #### 安装教程 1. pip install torch 2. pip install transformers 3. pip install numpy 4. pip install transformers 5. pip install sklearn 6. pip install sentencepiece 7. ~(其他依赖库) #### 使用说明 先使用perprocess.py对训练集和测试集进行预处理,然后修改文件中包含的一些路径和设置参数 ##### xlnet 1. 在https://huggingface.co/models 上找到xlnet-base-cased模型,下载config.json、pytorch_model.bin和spiece.model文件到../xlnet-base-cased下 2. 运行 xlnet-main.py ##### CNN-Chainer 1. 运行 CNN-main.py #### 关键函数说明 ##### xlnet `process_data`对经过perprocess.py处理后的训练集做二次处理,返回值中的sentence_list 是个一维的 list,里面存了每一行文本。label_list 是个一维的 list,里面的值是 int 类型的,就是将原本 str 类型的 label 标签转为对应的 index。classes_list 就是去重后的 label。num_classes 就是 len(classes_list) `convert`将提取出来的sentence_list经过XLNetTokenizer,以每句话为单位,获取一句话中所有词的索引,attention mask等相关内容 `train_eval`主要分类函数,将训练集中划分一部分做为验证集,同时调用Huggingface中的XLNetForSequenceClassification,一个现成的API完成简单的句子分类 `eval`评估函数,每经过特定数量的epoch(这里取得是60)就在验证集上测试一次,保存到目前为止在验证集上准确率最高的模型参数 `save`保存函数,将训练好的模型保存到models文件夹下 `pred`预测函数,将测试集的分类结果保存为../data/result.txt文件 ##### CNN-Chainer `label_aasign`将9中类型的关系按先后各分成俩类(共18类),加上Other类共19类,分类对应的label `max_sen_length`返回所以语句的最大长度 `process`主要的功能函数 #### 文件说明及其他 1. xlnet-main使用的是fine-tune Huggingface 的预训练模型,它主要用到了XLNet模型,关于该模型的内容在https://huggingface.co/transformers/model_doc/xlnet.html 和https://github.com/zihangdai/xlnet 和https://github.com/huggingface/transformers 有详细说明 2. ../xlnet-base-cased文件下的文件是用于tokenizer和model初始化的,因为句子分类模型用的是XLNet的API,他需要用到SentencePiece以及XLNet基本模型,官方文档中给出的方法是 xxx.per_trained('xlnet-base-cased'), 但在国内容易出现RuntimeError,因此需要手动导入模型 3. ../models下是训练得到的model文件和config.json文件 4. 文件夹下data保存用到的文件和输出log(CNN-Chainer)