继续浏览精彩内容
慕课网APP
程序员的梦工厂
打开
继续
感谢您的支持,我会继续努力的
赞赏金额会直接到老师账户
将二维码发送给自己后长按识别
微信支付
支付宝支付

使用 matplotlib 显示 SVG 图片(批量)

心之宙
关注TA
已关注
手记 71
粉丝 37
获赞 167

载入数据迭代器中一批量图片,并以 SVG 格式显示图片:

from pylab import plt, mpl
from IPython import display


class Loader:
    """
    方法
    ========
    L 为该类的实例
    len(L)::返回 batch 的批数
    iter(L)::即为数据迭代器

    Return
    ========
    可迭代对象(numpy 对象)
    """

    def __init__(self, batch_size, X, Y=None, shuffle=True, name=None):
        '''
        X, Y 均为类 numpy, 可以是 HDF5 
        '''
        if name is not None:
            self.name = name
        self.X = X[:]
        if Y is None:
            # print('不存在标签!')
            self.Y = None
        else:
            self.Y = Y[:]
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        n = len(self.X)
        idx = np.arange(n)

        if self.shuffle:
            np.random.shuffle(idx)

        for k in range(0, n, self.batch_size):
            K = idx[k:min(k + self.batch_size, n)].tolist()
            if self.Y is None:
                yield np.take(self.X[:], K, 0)
            else:
                yield np.take(self.X[:], K, 0), np.take(self.Y[:], K, 0)

    def __len__(self):
        return round(len(self.X) / self.batch_size)

    def use_svg_display(self):
        # 用矢量图显示。
        display.set_matplotlib_formats('svg')

    def show_imgs(self, label_names, imgs, labels, figsize=(7, 7)):
        '''
        展示 多张图片
        '''
        n = imgs.shape[0]
        h, w = 4, int(n / 4)
        self.use_svg_display()
        _, ax = plt.subplots(h, w, figsize=figsize)  # 设置图的尺寸
        K = np.arange(n).reshape((h, w))
        names = np.asanyarray(
            [label_names[label] for label in labels], dtype='U')
        names = names.reshape((h, w))
        for i in range(h):
            for j in range(w):
                img = imgs[K[i, j]]
                ax[i][j].imshow(img)
                ax[i][j].axes.get_yaxis().set_visible(False)
                ax[i][j].axes.set_xlabel(names[i][j])
                ax[i][j].set_xticks([])
        plt.show()
import tables as tb
h5 = tb.open_file('E:/xdata/X.h5')

data = h5.root.cifar10

batch_size = 32
trainset = Loader(batch_size, data.trainX, data.trainY, shuffle=True, name='train')

for imgs, labels in iter(trainset):
    trainset.show_imgs(data.label_names, imgs, labels)
    break
打开App,阅读手记
0人推荐
发表评论
随时随地看视频慕课网APP