我正在尝试通过 Dataset.map() 构建输入管道：用 py_func 包装器包裹 .h5 解析函数，再用 map 进行映射。我想在 map 函数中传入两个参数：filename 和 window_size。以下代码的调用顺序为：Dataset.map --> _pyfn_wrapper --> parse_h5
问题在于：使用 map() 时，_pyfn_wrapper 只能接收一个参数，因为我不知道如何用 from_tensor_slices 将两种类型的数据（一个字符串列表和一个 int 列表）压缩（zip）在一起。
def helper(window_size, batch_size, ncores=mp.cpu_count()):
    """Build a shuffled, batched input pipeline over every file under ``./``.

    Each file path is paired with ``window_size`` and parsed by ``parse_h5``
    through ``tf.py_func`` into an ``(img, label)`` pair.

    Args:
        window_size: side length forwarded to ``parse_h5``; both outputs are
            reshaped there to ``(window_size, window_size, 1)``.
        batch_size: number of samples per batch; also used as the shuffle
            buffer size. An incomplete final batch is dropped.
        ncores: parallelism for ``map`` and base for the prefetch depth.
            The default ``mp.cpu_count()`` is evaluated once, at definition.

    Returns:
        ``(inputs, f_len)`` where ``inputs`` is a dict with keys ``'img'`` and
        ``'label'`` (the iterator's next batch) plus ``'iterator_init_op'``,
        and ``f_len`` is the number of files found.
    """
    # Every file below the current directory, as absolute paths.
    # NOTE(review): there is no extension filter -- any non-HDF5 file found
    # here will make parse_h5 fail at run time.
    flist = [
        os.path.abspath(os.path.join(dirpath, fname))
        for dirpath, _, fnames in os.walk('./')
        for fname in fnames
    ]
    f_len = len(flist)

    # from_tensor_slices accepts a tuple of tensors with different dtypes and
    # slices them in lockstep, yielding (filename, window_size) elements.
    # Explicit dtypes keep the element spec correct even when flist is empty.
    filenames = tf.constant(flist, dtype=tf.string)
    window_sizes = tf.constant([window_size] * f_len, dtype=tf.int32)
    batch = tf.data.Dataset.from_tensor_slices((filenames, window_sizes))

    def _parse(filename, ws):
        # Dataset.map unpacks each tuple element into positional arguments.
        X, y = tf.py_func(parse_h5, [filename, ws], [tf.float32, tf.float32])
        # py_func outputs carry no static shape; parse_h5 reshapes both to
        # (window_size, window_size, 1), so declare it for downstream ops.
        X.set_shape([window_size, window_size, 1])
        y.set_shape([window_size, window_size, 1])
        return X, y

    # Dataset has no `map_fn` method; `map` is the correct call.
    batch = batch.map(_parse, num_parallel_calls=ncores)
    batch = batch.shuffle(batch_size).batch(batch_size, drop_remainder=True).prefetch(ncores + 6)

    # Initializable iterator: its initializer op must be run before get_next.
    it = batch.make_initializable_iterator()
    iter_init_op = it.initializer
    # Next (img, label) batch.
    X_it, y_it = it.get_next()
    inputs = {'img': X_it, 'label': y_it, 'iterator_init_op': iter_init_op}
    return inputs, f_len
def _pyfn_wrapper(filename, window_size=100):
    """Wrap ``parse_h5`` in ``tf.py_func`` for use with ``Dataset.map``.

    ``Dataset.map`` unpacks tuple elements into positional arguments, so a
    dataset built with ``from_tensor_slices((filenames, window_sizes))`` can be
    mapped with this function directly. A dataset of filenames alone still
    works and uses the default ``window_size``.

    Args:
        filename: scalar string tensor holding the path of an .h5 file.
        window_size: side length passed on to ``parse_h5``. Defaults to 100,
            the value that was previously hard-coded here.

    Returns:
        The list of two float32 tensors ``[X, y]`` produced by ``parse_h5``.
    """
    return tf.py_func(
        parse_h5,                   # the wrapped Python function
        [filename, window_size],    # its inputs
        [tf.float32, tf.float32],   # dtypes of its two outputs (X, y)
    )
def parse_h5(name, window_size):
    """Load one ``(X, y)`` sample from an HDF5 file.

    Called through ``tf.py_func``, so ``name`` normally arrives as ``bytes``;
    a plain ``str`` path is accepted too.

    Args:
        name: path of the .h5 file, as bytes (UTF-8) or str.
        window_size: side length used to reshape the ``'X'`` and ``'y'``
            datasets to ``(window_size, window_size, 1)``.

    Returns:
        Tuple ``(X, y)`` of float32 arrays of shape
        ``(window_size, window_size, 1)``.
    """
    path = name.decode('utf-8') if isinstance(name, bytes) else name
    # py_func may hand window_size over as a numpy scalar; use plain ints.
    shape = (int(window_size), int(window_size), 1)
    with h5py.File(path, 'r') as f:
        # tf.py_func requires returned dtypes to match its Tout exactly
        # (float32 in the callers), so cast whatever is stored on disk.
        X = f['X'][:].reshape(shape).astype('float32', copy=False)
        y = f['y'][:].reshape(shape).astype('float32', copy=False)
    return X, y
一只名叫tom的猫
相关分类