from __future__ import print_function, division import tensorflow from keras.datasets import mnist from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, GaussianNoise, Lambda, Concatenate from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D from keras.layers import MaxPooling2D, merge from keras.layers.advanced_activations import LeakyReLU from keras.layers.convolutional import UpSampling2D, Conv2D from keras.models import Sequential, Model from keras.optimizers import Adam from keras import losses from keras.utils import to_categorical import keras.backend as K import matplotlib.pyplot as plt import numpy as np class AdversarialAutoencoder(): def __init__(self): self.img_rows = 28 self.img_cols = 28 self.channels = 1 self.img_shape = (self.img_rows, self.img_cols, self.channels) self.latent_dim = 10 self.n_labels = 10 optimizer = Adam(0.0002, 0.5) # Build and compile the discriminator self.discriminator = self.build_discriminator() self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy']) # Build the encoder / decoder self.encoder = self.build_encoder() self.decoder = self.build_decoder() img = Input(shape=self.img_shape) # The generator takes the image, encodes it and reconstructs it # from the encoding encoded_repr = self.encoder(img) labels = Input(shape=(self.n_labels,)) # .encoder(img) reconstructed_img = self.decoder([encoded_repr, labels]) # For the adversarial_autoencoder model we will only train the generator self.discriminator.trainable = False # The discriminator determines validity of the encoding validity = self.discriminator([encoded_repr, labels]) # The adversarial_autoencoder model (stacked generator and discriminator) self.adversarial_autoencoder = Model([img, labels], [reconstructed_img, validity]) self.adversarial_autoencoder.compile(loss=['binary_crossentropy', 'binary_crossentropy'], loss_weights=[0.5, 0.5], optimizer=optimizer) def build_encoder(self): # Encoder img = Input(shape=self.img_shape) h = Flatten()(img) h = Dense(512)(h) h = LeakyReLU(alpha=0.2)(h) h = Dense(512)(h) h = LeakyReLU(alpha=0.2)(h) # mu = Dense(self.latent_dim)(h) latent_repr = Dense(self.latent_dim)(h) # log_var = Dense(self.latent_dim)(h) # latent_repr = merge([mu, log_var], # mode=lambda p: p[0] + K.random_normal(K.shape(p[0])) * K.exp(p[1] / 2), # output_shape=lambda p: p[0]) return Model(img, latent_repr) def build_decoder(self): # model = Sequential() z = Input(shape=(self.latent_dim,)) l = Input(shape=(self.n_labels,)) x = Concatenate()([z, l]) x = Dense(512, input_dim=self.latent_dim)(x) x = LeakyReLU(alpha=0.2)(x) x = Dense(512)(x) x = LeakyReLU(alpha=0.2)(x) x = Dense(np.prod(self.img_shape), activation='sigmoid')(x) img = Reshape(self.img_shape)(x) # model = Model(z, img) # model.summary() # img = model(z) return Model([z, l], img) def build_discriminator(self): encoded_repr = Input(shape=(self.latent_dim, )) l = Input(shape=(self.n_labels,)) x = Concatenate()([encoded_repr, l]) x = Dense(512, input_dim=self.latent_dim)(x) x = LeakyReLU(alpha=0.2)(x) x = Dense(256)(x) x = LeakyReLU(alpha=0.2)(x) x = Dense(1, activation="sigmoid")(x) # model.summary() model = Model([encoded_repr, l], x) validity = model([encoded_repr, l]) return Model([encoded_repr, l], validity) def train(self, epochs, batch_size=128, sample_interval=50): # Load the dataset (X_train, y_train), (_, _) = mnist.load_data() # Rescale 0 to 1 X_train = (X_train.astype(np.float32)) / 255. X_train = np.expand_dims(X_train, axis=3) y_train = to_categorical(y_train) # Adversarial ground truths valid = np.ones((batch_size, 1)) fake = np.zeros((batch_size, 1)) for epoch in range(epochs): # --------------------- # Train Discriminator # --------------------- # Select a random batch of images idx = np.random.randint(0, X_train.shape[0], batch_size) imgs = X_train[idx] lbls = y_train[idx] latent_fake = self.encoder.predict(imgs) latent_real = np.random.normal(size=(batch_size, self.latent_dim)) # Train the discriminator d_loss_real = self.discriminator.train_on_batch([latent_real, lbls], valid) d_loss_fake = self.discriminator.train_on_batch([latent_fake, lbls], fake) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # --------------------- # Train Generator # --------------------- # Train the generator g_loss = self.adversarial_autoencoder.train_on_batch([imgs, lbls], [imgs, valid]) # Plot the progress print ("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[0], g_loss[1])) # If at save interval => save generated image samples if epoch % sample_interval == 0: self.sample_images(epoch) def sample_images(self, epoch): r, c = 5, 5 z = np.random.normal(size=(r*c, self.latent_dim)) gen_imgs = self.decoder.predict([z, to_categorical( np.ones(len(z)), # np.random.randint(0, self.n_labels, len(z)), num_classes=self.n_labels, )]) gen_imgs = 0.5 * gen_imgs + 0.5 fig, axs = plt.subplots(r, c) cnt = 0 for i in range(r): for j in range(c): axs[i,j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray') axs[i,j].axis('off') cnt += 1 fig.savefig("images/mnist_%d.png" % epoch) plt.show() plt.close() def save_model(self): def save(model, model_name): model_path = "saved_model/%s.json" % model_name weights_path = "saved_model/%s_weights.hdf5" % model_name options = {"file_arch": model_path, "file_weight": weights_path} json_string = model.to_json() open(options['file_arch'], 'w').write(json_string) model.save_weights(options['file_weight']) save(self.generator, "aae_generator") save(self.discriminator, "aae_discriminator") if __name__ == '__main__': aae = AdversarialAutoencoder() aae.train(epochs=20000, batch_size=32, sample_interval=200)