Simple Transformers: Multi-Class Text Classification with BERT, RoBERTa, XLNet, XLM, and DistilBERT
[Overview] This article introduces Simple Transformers, an easy-to-use library built on top of the Transformers library from the AI startup Hugging Face. Hugging Face Transformers is aimed at researchers and anyone else who needs full control over how things work; Simple Transformers wraps it so that common tasks, such as the multi-class text classification shown below, take only a few lines of code.
conda create -n transformers python pandas tqdm
conda activate transformers
# With a GPU (CUDA 10.0):
conda install pytorch cudatoolkit=10.0 -c pytorch
# Or, CPU-only (run one of these two, not both):
conda install pytorch cpuonly -c pytorch
conda install -c anaconda scipy
conda install -c anaconda scikit-learn
pip install transformers
pip install tensorboardx
pip install simpletransformers
import pandas as pd

# The CSV has no header row; column 0 is the label, columns 1 and 2 are text fields.
train_df = pd.read_csv('data/train.csv', header=None)
train_df['text'] = train_df.iloc[:, 1] + " " + train_df.iloc[:, 2]  # merge the two text columns
train_df = train_df.drop(train_df.columns[[1, 2]], axis=1)
train_df.columns = ['label', 'text']
train_df = train_df[['text', 'label']]
train_df['text'] = train_df['text'].apply(lambda x: x.replace('\\', ' '))  # strip stray backslashes
train_df['label'] = train_df['label'].apply(lambda x: x - 1)  # shift labels 1-4 to 0-3 for num_labels=4
eval_df = pd.read_csv('data/test.csv', header=None)
eval_df['text'] = eval_df.iloc[:, 1] + " " + eval_df.iloc[:, 2]
eval_df = eval_df.drop(eval_df.columns[[1, 2]], axis=1)
eval_df.columns = ['label', 'text']
eval_df = eval_df[['text', 'label']]
eval_df['text'] = eval_df['text'].apply(lambda x: x.replace('\\', ' '))
eval_df['label'] = eval_df['label'].apply(lambda x: x - 1)  # shift labels 1-4 to 0-3
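The preprocessing above can be tried end-to-end on an in-memory frame. The two sample rows below are made up, and the three-column layout (label 1-4, title, description) is an assumption inferred from the code:

```python
import pandas as pd

# Hypothetical rows mimicking the assumed CSV layout:
# column 0 = label (1-4), column 1 = title, column 2 = description.
df = pd.DataFrame([
    [3, "Wall St. Bears", "Short-sellers see green again."],
    [1, "Peace talks", "Negotiators meet \\in Geneva."],
])

df['text'] = df.iloc[:, 1] + " " + df.iloc[:, 2]        # merge title + description
df = df.drop(df.columns[[1, 2]], axis=1)                # keep only label + text
df.columns = ['label', 'text']
df = df[['text', 'label']]
df['text'] = df['text'].apply(lambda x: x.replace('\\', ' '))  # strip stray backslashes
df['label'] = df['label'].apply(lambda x: x - 1)        # 1-4 -> 0-3

print(df['label'].tolist())  # -> [2, 0]
```

The result is exactly the two-column `text`/`label` frame that `train_model` and `eval_model` expect.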
from simpletransformers.model import TransformerModel
# Create a TransformerModel (downloads the pretrained roberta-base weights)
model = TransformerModel('roberta', 'roberta-base', num_labels=4)
# Alternatively, load a model from a local directory:
model = TransformerModel('xlnet', 'path_to_model/', num_labels=4)
# Default arguments, as defined in the library source
# (cpu_count here is multiprocessing.cpu_count):
self.args = {
    'output_dir': 'outputs/',
    'cache_dir': 'cache_dir',
    'fp16': True,
    'fp16_opt_level': 'O1',
    'max_seq_length': 128,
    'train_batch_size': 8,
    'gradient_accumulation_steps': 1,
    'eval_batch_size': 8,
    'num_train_epochs': 1,
    'weight_decay': 0,
    'learning_rate': 4e-5,
    'adam_epsilon': 1e-8,
    'warmup_ratio': 0.06,
    'warmup_steps': 0,
    'max_grad_norm': 1.0,
    'logging_steps': 50,
    'save_steps': 2000,
    'overwrite_output_dir': False,
    'reprocess_input_data': False,
    'process_count': cpu_count() - 2 if cpu_count() > 2 else 1,
}
# Create a TransformerModel with modified attributes
model = TransformerModel('roberta', 'roberta-base', num_labels=4,
                         args={'learning_rate': 1e-5, 'num_train_epochs': 2,
                               'reprocess_input_data': True, 'overwrite_output_dir': True})
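Conceptually, the `args` dictionary you pass in is merged over the defaults shown above, with your values winning. The sketch below mirrors that behavior with a plain dict merge; it is not the library's exact code:

```python
# A subset of the defaults listed above.
defaults = {'learning_rate': 4e-5, 'num_train_epochs': 1,
            'reprocess_input_data': False, 'overwrite_output_dir': False}

# The overrides passed to TransformerModel in the snippet above.
overrides = {'learning_rate': 1e-5, 'num_train_epochs': 2,
             'reprocess_input_data': True, 'overwrite_output_dir': True}

# Later keys win, so every override replaces its default.
args = {**defaults, **overrides}
print(args['learning_rate'])  # -> 1e-05
```

Any key you omit simply keeps its default value.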
# Train the model
model.train_model(train_df)
result, model_outputs, wrong_predictions = model.eval_model(eval_df)
from sklearn.metrics import f1_score, accuracy_score
def f1_multiclass(labels, preds):
    return f1_score(labels, preds, average='micro')

result, model_outputs, wrong_predictions = model.eval_model(eval_df, f1=f1_multiclass, acc=accuracy_score)
{'mcc': 0.937104098029913, 'f1': 0.9527631578947369, 'acc': 0.9527631578947369}
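The `f1` and `acc` values above are identical, and that is no accident: micro-averaged F1 pools true/false positives and negatives across all classes, so for single-label multi-class classification it reduces to plain accuracy. A quick check on toy data:

```python
from sklearn.metrics import f1_score, accuracy_score

# Made-up labels and predictions over four classes; 4 of 6 are correct.
labels = [0, 1, 2, 3, 2, 1]
preds  = [0, 1, 2, 0, 2, 2]

f1 = f1_score(labels, preds, average='micro')
acc = accuracy_score(labels, preds)
print(f1, acc)  # both equal 4/6
```

To see per-class differences, use `average='macro'` or `average=None` instead.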
predictions, raw_outputs = model.predict(['Some arbitrary sentence'])
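`predict()` returns both the predicted labels and the raw per-class outputs (logits); applying a softmax turns the latter into probabilities. The logit values below are made up for illustration:

```python
import numpy as np

# Hypothetical raw outputs for one sentence across four classes.
raw = np.array([[2.1, -0.3, 0.4, -1.2]])

# Softmax: exponentiate, then normalize each row to sum to 1.
probs = np.exp(raw) / np.exp(raw).sum(axis=1, keepdims=True)
pred = probs.argmax(axis=1)
print(pred)  # -> [0]
```

The argmax of the logits already matches `predictions`; the softmax is only needed when you want calibrated-looking class probabilities.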
(This article was translated by AI科技大本营. For reprint permission, contact WeChat 1092722531.)