Sequence分类问题中处理不定长数据
之前写过一篇关于stanford
alpaca的代码的分析,最近在kaggle上看到一个检测某段长文本是否是AI生成的任务LLM -
Detect AI Generated
Text,自己也在尝试做这个任务的时候,发现斯坦福的这份代码真是常看常新。对于数据的准备部分,有很多选择,比如在创建Dataset的时候就把所有的字符串数据tokenize好,在get_item()的函数返回时就返回input_ids,也可以是像斯坦福的这份代码一样,先把数据读取进来然后再用DataCollator
处理(padding)。
我之前没有发现斯坦福这份代码这么写的真正原因,直到我自己来处理这种不定长的序列输入时才发现这样写的绝妙,因为我们都知道矩阵是每一行都需要是同样的size,所以斯坦福的写法在数据处理前期一直在用list,而不是batch。
先来说说transformer的对于序列分类的官方教程的写法传送门
transformer的这个教程直接使用的是自己的数据集,已经规整为datasets了,首先它对数据集使用map函数做了截断的处理:
1 | from transformers import AutoTokenizer |
然后重点来了,注意在上面的preprocess_function
中并没有对序列进行padding,只是对过长的序列做了截断。接着作者使用了datacollatorwithpadding
,给出的理由是:
It's more efficient to dynamically pad the sentences to the longest length in a batch during collation, instead of padding the whole dataset to the maximum length.
我们可以用官方的文档中看到对于datacollator的定义:
Data collators are objects that will form a batch by using a list of dataset elements as input. These elements are of the same type as the elements of
train_dataset
oreval_dataset
也就是data collators的输入是一个list,list里的每一个元素跟train_dataset中的元素是一样的。在data collator中你可以做一些processing,如padding,random masking。我们接下来可以看到斯坦福的羊驼代码就是将padding的步骤放到了data collator内。
1 | from transformers import DataCollatorWithPadding |
在这里的教程里,作者使用了DataCollatorWithPadding
,它会动态地pad
inputs。我看到文档里还有class transformers.DataCollatorForTokenClassification
这个类,maybe可以处理不定长输入,留作后续探索。
transformer的这个教程还是过于简单了,在实际的case中情况会复杂一点。接下来我们看斯坦福的羊驼咋处理不定长sequence的。
首先它先把Dataset定义好:
1 | def preprocess( |
玄机在:
1 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: |
当我们调用SupervisedDataset
实例化数据后我们来看看数据长什么样子
1 | train_dataset中有两个key,一个Input_ids,一个labels |
而且input_ids中每一个值的长度都是不同的,这是因为没有做padding的结果,仅仅是将所有的过长的sequence截断了。
羊驼的代码将所有的padding细节都放到了collator里:
1 |
|
这里作者写了一个自己的类,继承自object。这里没有继承transformer的DefaultDataCollator,暂时不清楚用意,但我觉得应该也可以。这个类实现了一个__call__
方法,接受的是一个Sequence(可迭代对象),对象中是字典(input_ids,
labels),我们上面在创建数据集的时候getitem每次返回一个dict,这个dict里有input_id和label。现在的collator接受的是这个字典的list,也就是有很多个数据(batch_size),我们对这个batch里的数据统一进行padding,这样就实现了在batch内部去pad,避免将所有的字符串都pad成最长的字符长度。