Sequence分类问题中处理不定长数据

之前写过一篇关于stanford alpaca的代码的分析,最近在kaggle上看到一个检测某段长文本是否是AI生成的任务LLM - Detect AI Generated Text,自己也在尝试做这个任务的时候,发现斯坦福的这份代码真是常看常新。对于数据的准备部分,有很多选择,比如在创建Dataset的时候就把所有的字符串数据tokenize好,在get_item()的函数返回时就返回input_ids,也可以是像斯坦福的这份代码一样,先把数据读取进来然后再用DataCollator处理(padding)。

我之前没有发现斯坦福这份代码这么写的真正原因,直到我自己来处理这种不定长的序列输入时才发现这样写的绝妙,因为我们都知道矩阵是每一行都需要是同样的size,所以斯坦福的写法在数据处理前期一直在用list,而不是batch。

先来说说transformer的对于序列分类的官方教程的写法传送门

transformer的这个教程直接使用的是自己的数据集,已经规整为datasets了,首先它对数据集使用map函数做了截断的处理:

1
2
3
4
5
6
7
8
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def preprocess_function(examples):
return tokenizer(examples["text"], truncation=True)

tokenized_imdb = imdb.map(preprocess_function, batched=True)

然后重点来了,注意在上面的preprocess_function中并没有对序列进行padding,只是对过长的序列做了截断。接着作者使用了datacollatorwithpadding,给出的理由是:

It's more efficient to dynamically pad the sentences to the longest length in a batch during collation, instead of padding the whole dataset to the maximum length.

我们可以用官方的文档中看到对于datacollator的定义

Data collators are objects that will form a batch by using a list of dataset elements as input. These elements are of the same type as the elements of train_dataset or eval_dataset

也就是data collators的输入是一个list,list里的每一个元素跟train_dataset中的元素是一样的。在data collator中你可以做一些processing,如padding,random masking。我们接下来可以看到斯坦福的羊驼代码就是将padding的步骤放到了data collator内。

1
2
3
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

在这里的教程里,作者使用了DataCollatorWithPadding,它会动态地pad inputs。我看到文档里还有class transformers.DataCollatorForTokenClassification这个类,maybe可以处理不定长输入,留作后续探索。

transformer的这个教程还是过于简单了,在实际的case中情况会复杂一点。接下来我们看斯坦福的羊驼咋处理不定长sequence的。

首先它先把Dataset定义好:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def preprocess(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
label[:source_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=labels)

class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""

def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
super(SupervisedDataset, self).__init__()
logging.warning("Loading data...")
list_data_dict = utils.jload(data_path)

logging.warning("Formatting inputs...")
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
sources = [
prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
for example in list_data_dict
]
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] # 将output之后加上[EOS]

logging.warning("Tokenizing inputs... This may take some time...")
data_dict = preprocess(sources, targets, tokenizer)

self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]

def __len__(self):
return len(self.input_ids)

def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(input_ids=self.input_ids[i], labels=self.labels[i])

玄机在:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
) # 这里做了循环,也就是对strings这个list里的每一个sequence单独做的tokenize,然后把这些不等长的input_ids一同放到一个list里。之所以用list,是因为list里可以存储不等长的list。一直到这一步都没有做padding
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)

当我们调用SupervisedDataset实例化数据后我们来看看数据长什么样子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
train_dataset中有两个key,一个Input_ids,一个labels
input_ids中的值长这样:
tensor([ 2, 45943, 16, 41, 15741, 14, 7448, 10, 3685, 4,
21062, 10, 1263, 14, 16574, 25830, 5, 2069, 4, 50118,
50118, 48134, 41241, 35, 50118, 31033, 130, 4965, 13, 4959,
2245, 4, 50118, 50118, 48134, 19121, 35, 134, 4, 43800,
10, 9320, 5626, 8, 146, 686, 7, 680, 2710, 9,
12849, 8, 8942, 4, 1437, 50118, 176, 4, 30450, 4595,
7, 489, 110, 809, 2171, 8, 670, 4, 1437, 50118,
246, 4, 2315, 615, 3581, 8, 3014, 10, 4292, 3581,
3078, 4, 2])
labels中的值长这样:
tensor([ -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, 134, 4, 43800,
10, 9320, 5626, 8, 146, 686, 7, 680, 2710, 9,
12849, 8, 8942, 4, 1437, 50118, 176, 4, 30450, 4595,
7, 489, 110, 809, 2171, 8, 670, 4, 1437, 50118,
246, 4, 2315, 615, 3581, 8, 3014, 10, 4292, 3581,
3078, 4, 2])

而且input_ids中每一个值的长度都是不同的,这是因为没有做padding的结果,仅仅是将所有的过长的sequence截断了。

羊驼的代码将所有的padding细节都放到了collator里:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""

tokenizer: transformers.PreTrainedTokenizer

def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id), # see https://pytorch.org/docs/stable/generated/torch.ne.html#torch.ne
)

这里作者写了一个自己的类,继承自object。这里没有继承transformer的DefaultDataCollator,暂时不清楚用意,但我觉得应该也可以。这个类实现了一个__call__方法,接受的是一个Sequence(可迭代对象),对象中是字典(input_ids, labels),我们上面在创建数据集的时候getitem每次返回一个dict,这个dict里有input_id和label。现在的collator接受的是这个字典的list,也就是有很多个数据(batch_size),我们对这个batch里的数据统一进行padding,这样就实现了在batch内部去pad,避免将所有的字符串都pad成最长的字符长度。