猿问

在转换过程中从 tensorflow 对象中提取 numpy 值

我正在尝试使用 tensorflow 获取词嵌入,并且我已经使用我的语料库创建了相邻的工作列表。

我的词汇表中唯一单词的数量为 8000,相邻单词列表的数量约为 160 万

单词列表示例照片

由于数据非常大,我试图将单词列表分批写入 TFRecords 文件。

def save_tfrecords_wordlist(toprocess_word_lists, path ):    

    writer = tf.io.TFRecordWriter(path)


    for word_list in toprocess_word_lists:

        features=tf.train.Features(

            feature={

        'word_list_X': tf.train.Feature( bytes_list=tf.train.BytesList(value=[word_list[0].encode('utf-8')] )),

        'word_list_Y': tf.train.Feature( bytes_list=tf.train.BytesList(value=[word_list[1].encode('utf-8') ]))

                }

            )

        example = tf.train.Example(features = features)

        writer.write(example.SerializeToString())

    writer.close()

定义批次

batches = [0,250000,500000,750000,1000000,1250000,1500000,1641790]


for i in range(len(batches) - 1 ):


    batches_start = batches[i]

    batches_end = batches[i + 1]

    print( str(batches_start) + " -- " + str(batches_end ))


    toprocess_word_lists = word_lists[batches_start:batches_end]

    save_tfrecords_wordlist( toprocess_word_lists, path +"/TFRecords/data_" + str(i) +".tfrecords")

##############################


def _parse_function(example_proto):


  features = {"word_list_X": tf.io.FixedLenFeature((), tf.string),

          "word_list_Y": tf.io.FixedLenFeature((), tf.string)}

  parsed_features = tf.io.parse_single_example(example_proto, features)




海绵宝宝撒
浏览 111回答 1
1回答

沧海一幻觉

似乎您无法从映射函数(1、2)内部调用 .numpy() 函数,尽管我能够使用来自(doc)的 py_function 进行管理。在下面的示例中,我已将我解析的数据集映射到一个函数,该函数将我的图像转换为np.uint8以便使用 matplotlib绘制它们。records_path = data_directory+'TFRecords'+'/data_0.tfrecord'# Create a datasetdataset = tf.data.TFRecordDataset(filenames=records_path)# Map our dataset to the parsing function parsed_dataset = dataset.map(parsing_fn)converted_dataset = parsed_dataset.map(lambda image,label:                                       tf.py_function(func=converting_function,                                                      inp=[image,label],                                                      Tout=[np.uint8,tf.int64]))# Gets the iteratoriterator = tf.compat.v1.data.make_one_shot_iterator(converted_dataset) for i in range(5):    image,label = iterator.get_next()    plt.imshow(image)    plt.show()    print('label: ', label)输出:解析函数:def parsing_fn(serialized):    # Define a dict with the data-names and types we expect to    # find in the TFRecords file.    features = \        {            'image': tf.io.FixedLenFeature([], tf.string),            'label': tf.io.FixedLenFeature([], tf.int64)                    }    # Parse the serialized data so we get a dict with our data.    parsed_example = tf.io.parse_single_example(serialized=serialized,                                             features=features)    # Get the image as raw bytes.    image_raw = parsed_example['image']    # Decode the raw bytes so it becomes a tensor with type.    image = tf.io.decode_jpeg(image_raw)        # Get the label associated with the image.    label = parsed_example['label']        # The image and label are now correct TensorFlow types.    return image, label更新:实际上并没有签出,但 tf.shape() 似乎也是一个有前途的选择。
随时随地看视频慕课网APP

相关分类

Python
我要回答