当前位置: 智能网 > 人工智能 > NLP ——从0开始快速上手百度 ERNIE

NLP ——从0开始快速上手百度 ERNIE

放大字体 缩小字体 发布日期:2020-12-17 12:02:29   浏览次数:209


三、具体实现过程

开始写代码!

ChnSentiCorp任务运行的shell脚本是 ERNIE/ernie/run_classifier.py,该文件定义了分类任务Fine-tuning 的详细过程,下面我们将通过如下几个步骤进行详细剖析:

环境准备。导入相关的依赖,解析命令行参数;

实例化ERNIE 模型,优化器以及Tokenizer, 并设置超参数

定义辅助函数

运行训练循环

1. 环境准备

import相关的依赖,解析命令行参数。

import syssys.path.append('./ERNIE')import numpy as npfrom sklearn.metrics import f1_scoreimport paddle as Pimport paddle.fluid as Fimport paddle.fluid.layers as Limport paddle.fluid.dygraph as D
from ernie.tokenizing_ernie import ErnieTokenizerfrom ernie.modeling_ernie import ErnieModelForSequenceClassification2. 实例化ERNIE 模型,优化器以及Tokenizer, 并设置超参数

设置好所有的超参数,对于ERNIE任务学习率推荐取 1e-5/2e-5/5e-5, 根据显存大小调节BATCH大小, 最大句子长度不超过512.

BATCH=32MAX_SEQLEN=300LR=5e-5EPOCH=10
D.guard().__enter__() # 为了让Paddle进入动态图模式,需要添加这一行在最前面
ernie = ErnieModelForSequenceClassification.from_pretrained('ernie-1.0', num_labels=3)optimizer = F.optimizer.Adam(LR, parameter_list=ernie.parameters())tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')3. 定义辅助函数

(1)定义函数 make_data,将文本数据读入内存并转换为numpy List存储。

def make_data(path):    data = []    for i, l in enumerate(open(path)):        if i == 0:            continue        l = l.strip().split('')        text, label = l[0], int(l[1])        text_id, _ = tokenizer.encode(text) # ErnieTokenizer 会自动添加ERNIE所需要的特殊token,如[CLS], [SEP]        text_id = text_id[:MAX_SEQLEN]        text_id = np.pad(text_id, [0, MAX_SEQLEN-len(text_id)], mode='constant') # 对所有句子都补长至300,这样会比较费显存;        label_id = np.array(label+1)        data.append((text_id, label_id))    return data
train_data = make_data('./chnsenticorp/train/part.0')test_data = make_data('./chnsenticorp/dev/part.0')

(2)定义函数get_batch_data,用于获取BATCH条样本并按照批处理维度stack到一起。

def get_batch_data(data, i):    d = data[i*BATCH: (i + 1) * BATCH]    feature, label = zip(*d)    feature = np.stack(feature)  # 将BATCH行样本整合在一个numpy.array中    label = np.stack(list(label))    feature = D.to_variable(feature) # 使用to_variable将numpy.array转换为paddle tensor    label = D.to_variable(label)    return feature, label4. 运行训练循环

队训练数据重复EPOCH遍训练循环;每次循环开头都会重新shuffle数据。在训练过程中每间隔100步在验证数据集上进行测试并汇报结果(acc)。

for i in range(EPOCH):    np.random.shuffle(train_data) # 每个epoch都shuffle数据以获得最佳训练效果;    #train    for j in range(len(train_data) // BATCH):        feature, label = get_batch_data(train_data, j)        loss, _ = ernie(feature, labels=label) # ernie模型的返回值包含(loss, logits);其中logits目前暂时不需要使用        loss.backward()        optimizer.minimize(loss)        ernie.clear_gradients()        if j % 10 == 0:            print('train %d: loss %.5f' % (j, loss.numpy()))        # evaluate        if j % 100 == 0:            all_pred, all_label = [], []            with D.base._switch_tracer_mode_guard_(is_train=False): # 在这个with域内ernie不会进行梯度计算;                ernie.eval() # 控制模型进入eval模式,这将会关闭所有的dropout;                for j in range(len(test_data) // BATCH):                    feature, label = get_batch_data(test_data, j)                    loss, logits = ernie(feature, labels=label)                     all_pred.extend(L.argmax(logits, -1).numpy())                    all_label.extend(label.numpy())                ernie.train()            f1 = f1_score(all_label, all_pred, average='macro')            acc = (np.array(all_label) == np.array(all_pred)).astype(np.float32).mean()            print('acc %.5f' % acc)

训练过程中单次迭代输出的日志如下所示:

train 0: loss 0.05833acc 0.91723train 10: loss 0.03602train 20: loss 0.00047train 30: loss 0.02403train 40: loss 0.01642train 50: loss 0.12958train 60: loss 0.04629train 70: loss 0.00942train 80: loss 0.00068train 90: loss 0.05485train 100: loss 0.01527acc 0.92821train 110: loss 0.00927train 120: loss 0.07236train 130: loss 0.01391train 140: loss 0.01612

包含了当前 batch 的训练得到的Loss(ave loss)和每个Epochde 精度(acc)信息。训练完成后用户可以参考快速运行中的方法使用模型体验推理功能。

其它特性

ERNIE 还提供了混合精度训练、模型蒸馏等高级功能,可以在 README 中获得这些功能的使用方法。

图片标题


<上一页  3  
 
关键词: 数据 模型 推理

[ 智能网搜索 ]  [ 打印本文 ]  [ 违规举报

猜你喜欢

 
推荐图文
ITECH直流电源在人工智能领域的应用 基于朴素贝叶斯自动过滤垃圾广告
2020年是人工智能相关业务发展的重要一年 我国人工智能市场规模、行业短板、发展前景一览
推荐智能网
点击排行

 
 
新能源网 | 锂电网 | 智能网 | 环保设备网 | 联系方式