从数据迭代器中载入一个批量的图片,并以 SVG 格式显示:
import numpy as np
from IPython import display
from pylab import plt, mpl
class Loader:
    """Iterate over array-like data in mini-batches and preview image batches.

    Usage
    =====
    With ``L`` an instance of this class:

    * ``len(L)`` returns the number of batches produced by one full pass.
    * ``iter(L)`` is the batch iterator.

    Yields
    ======
    NumPy arrays: a batch of ``X`` alone when no labels were given,
    otherwise a ``(X_batch, Y_batch)`` tuple.
    """

    def __init__(self, batch_size, X, Y=None, shuffle=True, name=None):
        """Store the data and the batching configuration.

        Args:
            batch_size: Number of samples per batch; must be at least 1.
            X: NumPy-like data supporting ``X[:]`` and ``len`` (the original
                notes say an HDF5 array is acceptable).
            Y: Optional labels, same kind of object as ``X``.
            shuffle: Whether the sample order is reshuffled on every pass.
            name: Optional identifier for this loader.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
        # Always define the attribute so that reading ``loader.name`` never
        # raises AttributeError when no name was supplied.
        self.name = name
        # ``[:]`` takes a full slice once, up front; for an HDF5 node this
        # presumably reads the dataset into memory -- TODO confirm.
        self.X = X[:]
        self.Y = None if Y is None else Y[:]
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        """Yield the batches of one pass, the last one possibly shorter."""
        n = len(self.X)
        order = np.arange(n)
        if self.shuffle:
            np.random.shuffle(order)
        for start in range(0, n, self.batch_size):
            # Slicing clamps at the end of the array, so the final batch is
            # simply shorter when n is not a multiple of batch_size.
            picked = order[start:start + self.batch_size]
            x_batch = np.take(self.X, picked, axis=0)
            if self.Y is None:
                yield x_batch
            else:
                yield x_batch, np.take(self.Y, picked, axis=0)

    def __len__(self):
        """Return the number of batches ``__iter__`` yields in one pass."""
        # Ceiling division: a trailing partial batch still counts as a batch.
        return (len(self.X) + self.batch_size - 1) // self.batch_size

    def use_svg_display(self):
        """Ask IPython to render matplotlib figures as SVG."""
        # NOTE(review): IPython documents display.set_matplotlib_formats as
        # deprecated since 7.23 in favour of
        # matplotlib_inline.backend_inline.set_matplotlib_formats.
        display.set_matplotlib_formats('svg')

    def show_imgs(self, label_names, imgs, labels, figsize=(7, 7)):
        """Show a batch of images in a 4-row grid, each captioned by its label.

        Args:
            label_names: Indexable mapping a label value to its display name.
            imgs: Array whose first axis indexes the images.
            labels: One label per image, used to index ``label_names``.
            figsize: Size of the whole figure passed to ``plt.subplots``.
        """
        n = imgs.shape[0]
        if n == 0:
            return  # nothing to draw; a grid with zero columns is invalid
        rows = 4
        cols = (n + rows - 1) // rows  # enough columns to hold every image
        self.use_svg_display()
        # squeeze=False keeps ``axes`` two-dimensional even for a single column.
        _, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        # dtype='U' turns the looked-up names into unicode strings.
        names = np.asanyarray(
            [label_names[label] for label in labels], dtype='U')
        # ``axes.flat`` walks the grid row by row, matching the image order.
        for k, cell in enumerate(axes.flat):
            if k >= n:
                cell.axis('off')  # blank out cells left over in the last row
                continue
            cell.imshow(imgs[k])
            cell.get_yaxis().set_visible(False)
            cell.set_xlabel(names[k])
            cell.set_xticks([])
        plt.show()
数据集的使用方法见:《使用迭代器获取 Cifar 等常用数据集》。
import tables as tb

batch_size = 32

# Open the HDF5 file in a context manager so the handle is closed even if
# loading or plotting fails; the original never closed it.
with tb.open_file('E:/xdata/X.h5') as h5:
    data = h5.root.cifar10
    trainset = Loader(batch_size, data.trainX, data.trainY,
                      shuffle=True, name='train')
    # Preview only the first batch. label_names is looked up while the file
    # is still open because it is a node of that file.
    for imgs, labels in trainset:
        trainset.show_imgs(data.label_names, imgs, labels)
        break