Sunday, March 3, 2024

Generative Adversarial Networks (GAN) and DCGAN: Tutorial and Keras Implementation


Generative Adversarial Networks (GANs) are a special type of deep learning model with two main parts: a generator and a discriminator. The generator creates fake data, and the discriminator judges whether that data looks real when compared with real data. By training against each other, GANs get better at producing data that looks real, changing how we create new images, expand datasets, and learn without supervision. The following points show the significance of GANs in AI.
  • Creative Applications: GANs create realistic images, music, text, and videos, enabling creativity in art, content, and virtual environments.
  • Data Augmentation: GANs generate synthetic data to enhance small datasets, improving the performance of machine learning models.
  • Defense Against Deepfakes: GANs are used to develop defenses and detect manipulated media content amidst the growth of deepfake technology.
  • Drug Discovery and Molecular Design: GANs play an increasing role in drug discovery, producing novel molecular structures with desired properties, potentially transforming the pharmaceutical industry.


This article comprises interactive video tutorials and code demonstrations that explain the GAN architecture. The discussion covers various topics and culminates in a simple, straightforward code implementation.

GAN Part-1

GAN Part-2

GAN Part-3

Keras implementation of GAN 

The following contains the Keras implementation of a Deep Convolutional GAN (DCGAN). Please go through the video tutorials above to properly understand and use the code.
The system is built using Python 3.10 and relies on several essential library dependencies:
  • TensorFlow (version 2.15)
  • tqdm (version 4.66.2)
  • h5py (version 3.10)
  • Keras (version 2.15)

Train DCGAN.

# example of training a DCGAN on Fashion-MNIST
from numpy import expand_dims
from tqdm import tqdm
import keras
import tensorflow as tf
import numpy as np
from keras import Model
from keras.optimizers import Adam
from keras.layers import Input, Reshape, Flatten
from keras.layers import Dense, BatchNormalization, Conv2D, Conv2DTranspose, LeakyReLU, Dropout
# Module-level training defaults for 28x28 single-channel images.
latent_dim = 100                      # length of the generator's noise vector
batch_size = 32                       # real images per training step
img_shape = input_shape = (28, 28, 1)  # (height, width, channels) for both networks
class GAN_1:
    """Build, train and save a DCGAN on Fashion-MNIST training images.

    The generator maps a latent noise vector to a 28x28 image with values in
    [0, 1]; the discriminator scores 28x28 images as real (label 1) or
    generated (label 0).
    """

    def __init__(self):
        print("welcome to GAN coding")

    def preprocess_real_part_training_dataset(self, batch_size):
        """Load Fashion-MNIST training images as a shuffled, batched dataset.

        Args:
            batch_size: images per batch. The incomplete final batch is
                dropped so every batch has exactly ``batch_size`` images.

        Returns:
            A ``tf.data.Dataset`` of float32 batches shaped
            ``(batch_size, 28, 28, 1)`` with values in [0, 1].
        """
        # Only the training images are used; labels and the test split are ignored.
        (dataX, _), (_, _) = keras.datasets.fashion_mnist.load_data()
        # Add a trailing grayscale channel axis: (N, 28, 28) -> (N, 28, 28, 1).
        dataX = expand_dims(dataX, axis=-1)
        # uint8 [0, 255] -> float32 [0, 1], the same range as the generator's
        # sigmoid output.
        dataX = dataX.astype(np.float32) / 255.0
        trainX = tf.data.Dataset.from_tensor_slices(dataX).shuffle(1000)
        # drop_remainder keeps batch sizes constant; prefetch(1) prepares the
        # next batch while the current training step runs.
        trainX = trainX.batch(batch_size, drop_remainder=True).prefetch(1)
        return trainX

    def define_generator(self, latent_dim, img_shape):
        """Build and compile the DCGAN generator.

        Args:
            latent_dim: length of the input noise vector.
            img_shape: output image shape; only ``img_shape[2]`` (channels) is
                read. Height and width are fixed at 28 by the 7x7 projection
                followed by two stride-2 upsampling layers.

        Returns:
            A compiled ``Model`` mapping ``(batch, latent_dim)`` noise to
            ``(batch, 28, 28, img_shape[2])`` images in [0, 1].
        """
        inputs = Input(shape=(latent_dim,))
        # Project the noise vector and reshape it to a 7x7x128 feature map.
        proj = Dense(128 * 7 * 7)(inputs)
        proj = Reshape((7, 7, 128))(proj)
        # Upsample to 14x14.
        upsample_1 = Conv2DTranspose(filters=128, kernel_size=4, strides=2, padding='same',
                                     activation=LeakyReLU(alpha=0.2))(proj)
        upsample_1 = BatchNormalization()(upsample_1)
        # Upsample to 28x28.
        upsample_2 = Conv2DTranspose(filters=128, kernel_size=4, strides=2, padding='same',
                                     activation=LeakyReLU(alpha=0.2))(upsample_1)
        upsample_2 = BatchNormalization()(upsample_2)
        # Sigmoid keeps pixels in [0, 1], matching the scaled real images.
        gen_output = Conv2D(filters=img_shape[2], kernel_size=7, activation='sigmoid',
                            padding='same')(upsample_2)
        g_model = Model(inputs, gen_output)
        g_model.compile(loss='binary_crossentropy',
                        optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
                        metrics=['accuracy'])
        return g_model

    def define_descriminator(self, input_shape):
        """Build and compile the DCGAN discriminator.

        Args:
            input_shape: image shape, e.g. ``(28, 28, 1)``.

        Returns:
            A compiled ``Model`` mapping images to a single sigmoid score,
            trained with label 1 for real and 0 for generated images.
        """
        inputs = Input(shape=input_shape)
        # Three stride-2 convolutions downsample the image; dropout regularizes
        # the discriminator so it does not overpower the generator too early.
        conv1 = Conv2D(filters=64, kernel_size=3, strides=2, activation=LeakyReLU(alpha=0.2),
                       padding='same')(inputs)
        conv1 = Dropout(0.4)(conv1)
        conv1 = Conv2D(filters=128, kernel_size=3, strides=2, activation=LeakyReLU(alpha=0.2),
                       padding='same')(conv1)
        conv1 = Dropout(0.4)(conv1)
        conv1 = Conv2D(filters=256, kernel_size=3, strides=2, activation=LeakyReLU(alpha=0.2),
                       padding='same')(conv1)
        conv1 = Dropout(0.4)(conv1)
        flatten_layer = Flatten()(conv1)
        discriminator_decision_layer = Dense(1, activation='sigmoid')(flatten_layer)
        d_model = Model(inputs, discriminator_decision_layer)
        d_model.compile(loss='binary_crossentropy',
                        optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
                        metrics=['accuracy'])
        return d_model

    def define_gan(self, latent_dim0, img_shape0, generator=None, discriminator=None):
        """Stack a generator and a frozen discriminator into the combined GAN.

        Training this model on "real" labels for generated images updates the
        generator only.

        Args:
            latent_dim0: length of the generator's noise vector.
            img_shape0: image shape produced by the generator and consumed by
                the discriminator.
            generator: optional pre-built generator. Pass the same model used
                to produce fake images during training so GAN updates reach it.
                Built fresh when omitted.
            discriminator: optional pre-built, already compiled discriminator.
                Pass the model trained on real/fake batches so both training
                phases share its weights. Built fresh when omitted.

        Returns:
            A compiled ``Model`` mapping noise to the discriminator's score.
        """
        if generator is None:
            generator = self.define_generator(latent_dim=latent_dim0, img_shape=img_shape0)
        if discriminator is None:
            discriminator = self.define_descriminator(input_shape=img_shape0)
        # Freeze the discriminator for the combined model only. Keras applies a
        # `trainable` change at compile time, so the discriminator (compiled
        # while trainable) still learns from its own train_on_batch calls.
        discriminator.trainable = False
        latent_input = Input(shape=(latent_dim0,))
        gan_output = discriminator(generator(latent_input))
        gan_model = Model(latent_input, gan_output)
        # Same DCGAN optimizer settings as the two sub-networks.
        gan_model.compile(loss="binary_crossentropy",
                          optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
        return gan_model

    def train_save_models(self, input_shape, latent_dim, img_shape, n_epochs=2, n_batch=256,
                          dataset=None):
        """Train the DCGAN, then save the models as HDF5 files.

        Each step first trains the discriminator on a half-generated,
        half-real batch. It then trains the generator, through the combined
        GAN, to make the discriminator label fresh fakes as real.

        Args:
            input_shape: discriminator input shape, e.g. ``(28, 28, 1)``.
            latent_dim: length of the generator's noise vector.
            img_shape: generator output shape, e.g. ``(28, 28, 1)``.
            n_epochs: number of passes over the training data.
            n_batch: real images per batch when ``dataset`` is omitted.
            dataset: optional iterable of real-image batches shaped
                ``(batch, *input_shape)``; defaults to
                ``preprocess_real_part_training_dataset(n_batch)``.

        Side effects:
            Writes ``g_model.h5``, ``d_model.h5`` and ``gan_model.h5`` to the
            current working directory.
        """
        g_model = self.define_generator(latent_dim=latent_dim, img_shape=img_shape)
        d_model = self.define_descriminator(input_shape)
        # Share g_model and d_model with the combined GAN so both training
        # phases below update the same weights that get saved.
        gan_main = self.define_gan(latent_dim0=latent_dim, img_shape0=img_shape,
                                   generator=g_model, discriminator=d_model)
        if dataset is None:
            dataset = self.preprocess_real_part_training_dataset(batch_size=n_batch)
        d_loss = gan_loss = None  # stays None if the dataset yields no batches
        for i in tqdm(range(n_epochs)):
            print("Epoch {}/{}".format(i + 1, n_epochs))
            for X_batch in dataset:
                # Size fake batches and labels from the real batch received.
                n_real = X_batch.shape[0]
                # Phase 1: discriminator learns generated -> 0, real -> 1.
                noise = tf.random.normal(shape=[n_real, latent_dim])
                generated_images = g_model(noise)
                X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
                y1 = tf.constant([[0.]] * n_real + [[1.]] * n_real)
                d_loss = d_model.train_on_batch(x=X_fake_and_real, y=y1)
                # Phase 2: generator learns to make the frozen discriminator
                # output "real" for fresh noise.
                noise1 = tf.random.normal(shape=[n_real, latent_dim])
                y2 = tf.constant([[1.]] * n_real)
                gan_loss = gan_main.train_on_batch(noise1, y2)
            print("discriminator loss =>", d_loss, " Gan-Loss => ", gan_loss)
        g_model.save("g_model.h5")
        d_model.save("d_model.h5")
        gan_main.save("gan_model.h5")

if __name__ == "__main__":
    print("Executed when invoked directly")
    input_shape1 = (28, 28, 1)
    img_shape1 = (28, 28, 1)
    latent_dim1 = 100
    n_batch1 = 32  # one batch size shared by data loading and training
    gan1 = GAN_1()
    # Real Fashion-MNIST batches, batched to match n_batch below.
    trainX = gan1.preprocess_real_part_training_dataset(batch_size=n_batch1)
    # train_save_models builds, trains and saves its own generator,
    # discriminator and combined GAN, so no models are created here.
    gan1.train_save_models(input_shape=input_shape1, latent_dim=latent_dim1,
                           img_shape=img_shape1, n_epochs=50, n_batch=n_batch1)

Test the trained GAN model.

# example of loading the generator model and generating images
import numpy as np
from keras.models import load_model
from numpy.random import randn
from keras.models import load_model
from matplotlib import pyplot
import matplotlib.pyplot as plt
# Load the generator saved as g_model.h5. compile=False: only inference is
# needed, so the training configuration/optimizer is not restored.
model = load_model('g_model.h5', compile=False)
# Generate synthetic images from standard-normal noise (mean 0, std 1), the
# same distribution tf.random.normal draws from during training.
num_images = 10
latent_dim = 100
noise = np.random.normal(0, 1, (num_images, latent_dim))
generated_images = model.predict(noise)

# Plot the generated images side by side in one row.
plt.figure(figsize=(10, 10))
for i in range(num_images):
    plt.subplot(1, num_images, i + 1)
    plt.imshow(generated_images[i, :, :, 0], cmap='gray')
    plt.axis('off')  # hide ticks so only the images are shown
# Needed to display the figure when run as a plain script (outside notebooks).
plt.show()

NOTE: This code is not intended for any commercial use. It is created solely for simple educational purposes. 

Niraj Kumar