本文结构:
什么是 GAN?
优点?
keras 例子?
什么是 GAN?
GAN,全称为 Generative Adversarial Nets,直译为生成式对抗网络,是一种非监督式模型。
一种应用是生成在原始数据集中不存在的但是却比较合理的数据,还可以拓展一张图片,生成下一帧影像,由简单几笔生成一幅画:
模型:
主要有两部分:
The Generative Model:通过输入任意随机数据,尝试生成一些真实的东西(曲线,图像,声音,文本,...)
The Discriminative Model:试图判定哪些是虚假的数据,来减小对真实数据的误报。
优点:
Markov chains are never needed
避免了计算复杂度特别高的过程,直接进行采样和推断,应用效率相应提高。
a wide variety of functions can be incorporated into the model
针对不同的任务就可以设计不同类型的损失函数。
can represent very sharp, even degenerate distributions
Keras 例子:
任务:生成 sin 曲线。
%matplotlib inlineimport osimport randomimport numpy as npimport pandas as pdimport matplotlib.pyplot as pltimport seaborn as snsfrom tqdm import tqdm_notebook as tqdmfrom keras.models import Modelfrom keras.layers import Input, Reshapefrom keras.layers.core import Dense, Activation, Dropout, Flattenfrom keras.layers.normalization import BatchNormalizationfrom keras.layers.convolutional import UpSampling1D, Conv1Dfrom keras.layers.advanced_activations import LeakyReLUfrom keras.optimizers import Adam, SGDfrom keras.callbacks import TensorBoard
1. Generative model:
输入:noise data
输出:尝试生成真实的 sin 数据
def get_generative(G_in, dense_dim=200, out_dim=50, lr=1e-3): x = Dense(dense_dim)(G_in) x = Activation('tanh')(x) G_out = Dense(out_dim, activation='tanh')(x) G = Model(G_in, G_out) opt = SGD(lr=lr) G.compile(loss='binary_crossentropy', optimizer=opt) return G, G_out
2. Discriminative model:
输出:识别此数据是真实的,还是由 Generative model 生成的
def get_discriminative(D_in, lr=1e-3, drate=.25, n_channels=50, conv_sz=5, leak=.2): x = Reshape((-1, 1))(D_in) x = Conv1D(n_channels, conv_sz, activation='relu')(x) x = Dropout(drate)(x) x = Flatten()(x) x = Dense(n_channels)(x) D_out = Dense(2, activation='sigmoid')(x) D = Model(D_in, D_out) dopt = Adam(lr=lr) D.compile(loss='binary_crossentropy', optimizer=dopt) return D, D_out
3. chain the two models into a GAN:
set_trainability 的作用是每次训练 generator 时要冻住 discriminator。
def set_trainability(model, trainable=False): model.trainable = trainable for layer in model.layers: layer.trainable = trainable def make_gan(GAN_in, G, D): set_trainability(D, False) x = G(GAN_in) GAN_out = D(x) GAN = Model(GAN_in, GAN_out) GAN.compile(loss='binary_crossentropy', optimizer=G.optimizer) return GAN, GAN_out
4. Training:
交替训练 discriminator 和 chained GAN,在训练 chained GAN 时要冻住 discriminator 的参数:
def sample_noise(G, noise_dim=10, n_samples=10000): X = np.random.uniform(0, 1, size=[n_samples, noise_dim]) y = np.zeros((n_samples, 2)) y[:, 1] = 1 return X, ydef train(GAN, G, D, epochs=500, n_samples=10000, noise_dim=10, batch_size=32, verbose=False, v_freq=50): d_loss = [] g_loss = [] e_range = range(epochs) if verbose: e_range = tqdm(e_range) for epoch in e_range: X, y = sample_data_and_gen(G, n_samples=n_samples, noise_dim=noise_dim) set_trainability(D, True) d_loss.append(D.train_on_batch(X, y)) X, y = sample_noise(G, n_samples=n_samples, noise_dim=noise_dim) set_trainability(D, False) g_loss.append(GAN.train_on_batch(X, y)) if verbose and (epoch + 1) % v_freq == 0: print("Epoch #{}: Generative Loss: {}, Discriminative Loss: {}".format(epoch + 1, g_loss[-1], d_loss[-1])) return d_loss, g_loss d_loss, g_loss = train(GAN, G, D, verbose=True)
5. Results:
N_VIEWED_SAMPLES = 2 data_and_gen, _ = sample_data_and_gen(G, n_samples=N_VIEWED_SAMPLES) pd.DataFrame(np.transpose(data_and_gen[N_VIEWED_SAMPLES:])).rolling(5).mean()[5:].plot()