# Build the dataset: `data` appears to be the corpus encoded as integer word
# ids (it is indexed by position and fed into int32 arrays below), `count`
# the word frequencies, and the two dicts map word <-> id — confirm against
# build_dataset's definition.
data,count,dictionary,reverse_dictionary = build_dataset(arg_words=words)
# Delete `words` to save memory — the raw token list is no longer needed.
del words
# Global cursor into `data`; read and advanced by generate_batch() below.
data_index = 0
# 3. Generate training batches for the skip-gram model
def generate_batch(arg_batch_size,arg_num_skips,arg_ski_windows):
    """Generate one skip-gram training batch from the global corpus `data`.

    Slides a window of ``2 * arg_ski_windows + 1`` word ids over `data`
    (wrapping around at the end via the global cursor `data_index`) and, for
    each center word, samples `arg_num_skips` distinct context words.

    Args:
        arg_batch_size: number of (center, context) pairs to produce; must be
            a multiple of ``arg_num_skips``.
        arg_num_skips: context words sampled per center word; must not exceed
            ``2 * arg_ski_windows`` (there are only that many context slots).
        arg_ski_windows: window radius — words considered on each side of the
            center word.

    Returns:
        Tuple ``(batch, labels)`` where ``batch`` is int32 of shape
        ``(arg_batch_size,)`` holding center-word ids and ``labels`` is int32
        of shape ``(arg_batch_size, 1)`` holding sampled context-word ids.

    Raises:
        ValueError: if the arguments violate the constraints above.
    """
    global data_index
    if arg_batch_size % arg_num_skips != 0:
        raise ValueError("arg_batch_size must be a multiple of arg_num_skips")
    if arg_num_skips > 2 * arg_ski_windows:
        raise ValueError("arg_num_skips must be <= 2 * arg_ski_windows")
    # np.empty replaces the discouraged direct np.ndarray(...) construction;
    # every cell is written before being read.
    l_batch = np.empty(shape=(arg_batch_size,), dtype=np.int32)
    l_labels = np.empty(shape=(arg_batch_size, 1), dtype=np.int32)
    span = 2 * arg_ski_windows + 1  # [ left context | center | right context ]
    buffer = collections.deque(maxlen=span)
    # Pre-fill the sliding window, wrapping around the corpus end.
    for _ in range(span):
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)
    for i in range(arg_batch_size // arg_num_skips):
        target = arg_ski_windows            # center position inside the window
        targets_to_avoid = [arg_ski_windows]  # never use the center as its own label
        for j in range(arg_num_skips):
            # Rejection-sample a window position not chosen yet for this center.
            while target in targets_to_avoid:
                target = random.randint(0, span - 1)
            targets_to_avoid.append(target)
            l_batch[i * arg_num_skips + j] = buffer[arg_ski_windows]
            # BUG FIX: row index must stride by arg_num_skips (it previously
            # used arg_ski_windows), otherwise label rows are written to the
            # wrong positions and the tail of l_labels stays uninitialized.
            l_labels[i * arg_num_skips + j, 0] = buffer[target]
        # Slide the window one word to the right (deque drops the oldest).
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)
    return l_batch, l_labels
# Show a sample batch
# Print each (center -> context) pair from one small generated batch.
batch, labels = generate_batch(arg_batch_size=8, arg_num_skips=2, arg_ski_windows=1)
for idx in range(8):
    center_id = batch[idx]
    context_id = labels[idx, 0]
    print(center_id, reverse_dictionary[center_id], "->", context_id, reverse_dictionary[context_id])