GPT2 on Hugging Face(pytorch 变压器)运行时错误:

我正在尝试使用我的自定义数据集微调 gpt2。我使用拥抱面变压器的文档创建了一个基本示例。我收到上述错误。我知道这意味着什么:(基本上它是在非标量张量上向后调用)但由于我几乎只使用 API 调用,所以我不知道如何解决这个问题。有什么建议么?


from pathlib import Path

from absl import flags, app

import IPython

import torch

from transformers import GPT2LMHeadModel, Trainer,  TrainingArguments

from data_reader import GetDataAsPython


# this is my custom data, but i get the same error for the basic case below

# data = GetDataAsPython('data.json')

# data = [data_point.GetText2Text() for data_point in data]

# print("Number of data samples is", len(data))


data = ["this is a trial text", "this is another trial text"]


train_texts = data


from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')


special_tokens_dict = {'pad_token': '<PAD>'}

num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

train_encodigs = tokenizer(train_texts, truncation=True, padding=True)



class BugFixDataset(torch.utils.data.Dataset):

    def __init__(self, encodings):

        self.encodings = encodings

    

    def __getitem__(self, index):

        item = {key: torch.tensor(val[index]) for key, val in self.encodings.items()}

        return item


    def __len__(self):

        return len(self.encodings['input_ids'])


train_dataset = BugFixDataset(train_encodigs)


training_args = TrainingArguments(

    output_dir='./results',          

    num_train_epochs=3,              

    per_device_train_batch_size=1,  

    per_device_eval_batch_size=1,   

    warmup_steps=500,                

    weight_decay=0.01,               

    logging_dir='./logs',

    logging_steps=10,

)


model = GPT2LMHeadModel.from_pretrained('gpt2', return_dict=True)

model.resize_token_embeddings(len(tokenizer))


trainer = Trainer(

    model=model,

    args=training_args,

    train_dataset=train_dataset,

)


trainer.train()


慕哥9229398
浏览 145回答 1
1回答

海绵宝宝撒

我终于弄明白了。问题在于数据样本不包含目标输出。即使很难的 gpt 也是自我监督的,这必须明确地告诉模型。你必须添加以下行:item['labels']&nbsp;=&nbsp;torch.tensor(self.encodings['input_ids'][index])到Dataset类的getitem函数,然后就可以正常运行了!
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python