梦里花落0921
class TextLoader():def __init__(self, data_dir, batch_size, seq_length, encoding='utf-8'):self.data_dir = data_dirself.batch_size = batch_sizeself.seq_length = seq_lengthself.encoding = encoding#第一次运行程序时只有input.txt一个文件,剩下两个文件是运行之后产生的input_file = os.path.join(data_dir, "input.txt")vocab_file = os.path.join(data_dir, "vocab.pkl")tensor_file = os.path.join(data_dir, "data.npy")#如果是第一次执行则调用preprocess函数,否则调用load_preprocessed函数。if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)):print("reading text file")self.preprocess(input_file, vocab_file, tensor_file)else:print("loading preprocessed files")self.load_preprocessed(vocab_file, tensor_file)self.create_batches()self.reset_batch_pointer()def preprocess(self, input_file, vocab_file, tensor_file):with codecs.open(input_file, "r", encoding=self.encoding) as f:data = f.read()#使用Counter函数对输入数据进行统计。counter保存data中每个字符出现的次数counter = collections.Counter(data)#对counter进行排序,出现次数最多的排在前面count_pairs = sorted(counter.items(), key=lambda x: -x[1])#将data中出现的所有字符保存,这里有65个,所以voacb_size=65self.chars, _ = zip(*count_pairs)self.vocab_size = len(self.chars)#按照字符出现次数多少顺序将chars保存,vocab中存储的是char和顺序,这样方便将data转化为索引self.vocab = dict(zip(self.chars, range(len(self.chars))))with open(vocab_file, 'wb') as f:#保存charscPickle.dump(self.chars, f)#将data中每个字符转化为索引下标。self.tensor = np.array(list(map(self.vocab.get, data)))np.save(tensor_file, self.tensor)def load_preprocessed(self, vocab_file, tensor_file):#如果是第二次运行,则可以直接读取之前保存的chars和tensorwith open(vocab_file, 'rb') as f:self.chars = cPickle.load(f)self.vocab_size = len(self.chars)self.vocab = dict(zip(self.chars, range(len(self.chars))))self.tensor = np.load(tensor_file)self.num_batches = int(self.tensor.size / (self.batch_size *self.seq_length))def create_batches(self):#首先将数据按batch_size切割,然后每个batch_size在按照seq_length进行切割self.num_batches = int(self.tensor.size / (self.batch_size *self.seq_length))if self.num_batches == 0:assert False, "Not enough data. Make seq_length and batch_size small."self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length]xdata = self.tensor#构造target,这里使用上一个词预测下一个词,所以直接将x向后一个字符即可ydata = np.copy(self.tensor)ydata[:-1] = xdata[1:]ydata[-1] = xdata[0]#将数据进行切分,这里我们假设数据总长度为10000,batch_size为100, seq_length为10.# 所以num_batches=10,所以,xdata在reshape之后变成[100, 100],然后在第二个维度上切成10份,# 所以最终得到[100, 10, 10]的数据self.x_batches = np.split(xdata.reshape(self.batch_size, -1),self.num_batches, 1)self.y_batches = np.split(ydata.reshape(self.batch_size, -1),self.num_batches, 1)def next_batch(self):x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]self.pointer += 1return x, ydef reset_batch_pointer(self):self.pointer = 0