import numpy as np
import os
import math

import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime

from architectures.utils.NN_building_blocks import *
from architectures.utils.NN_gen_building_blocks import *
from architectures.utils.toolbox import *


# Hyperparameters of the network, expected to be set by the calling script
LEARNING_RATE = None
BETA1 = None
COST_TYPE = None
BATCH_SIZE = None
EPOCHS = None
SAVE_SAMPLE_PERIOD = None
PATH = None
SEED = None
rnd_seed = 1
preprocess = None
LAMBDA = .01
EPS = 1e-10
CYCL_WEIGHT = None
GAN_WEIGHT = None
DISCR_STEPS = None
GEN_STEPS = None

min_true = None
max_true = None

min_reco = None
max_reco = None

n_H_A = None
n_W_A = None
n_W_B = None
n_H_B = None
n_C = None

d_sizes = None
g_sizes_enc = None
g_sizes_dec = None


class pix2pix(object):

    def __init__(
        self,
        n_H_A=n_H_A, n_W_A=n_W_A,
        n_H_B=n_H_B, n_W_B=n_W_B, n_C=n_C,
        min_true=min_true, max_true=max_true,
        min_reco=min_reco, max_reco=max_reco,
        d_sizes=d_sizes, g_sizes_enc=g_sizes_enc, g_sizes_dec=g_sizes_dec,
        lr=LEARNING_RATE, beta1=BETA1, preprocess=preprocess,
        cost_type=COST_TYPE, gan_weight=GAN_WEIGHT, cycl_weight=CYCL_WEIGHT,
        discr_steps=DISCR_STEPS, gen_steps=GEN_STEPS,
        batch_size=BATCH_SIZE, epochs=EPOCHS,
        save_sample=SAVE_SAMPLE_PERIOD, path=PATH, seed=SEED,
    ):

- """
-
- Positional arguments:
-
- - width of (square) image
- - number of channels of input image
- - discriminator sizes
-
- a python dict of the kind
- d_sizes = { 'convblocklayer_n':[(n_c+1, kernel, stride, apply_batch_norm, weight initializer),
- (,,,,),
- (,,,,),
- ],
- 'convblock_shortcut_layer_n':[(,,,)],
- 'dense_layers':[(n_o, apply_bn, weight_init)]
- }
- - generator sizes
-
- a python dictionary of the kind
-
- g_sizes = {
- 'z':latent_space_dim,
- 'projection': int,
- 'bn_after_project':bool
-
- 'deconvblocklayer_n':[(n_c+1, kernel, stride, apply_batch_norm, weight initializer),
- (,,,,),
- (,,,,),
- ],
- 'deconvblock_shortcut_layer_n':[(,,,)],
- 'dense_layers':[(n_o, apply_bn, weight_init)]
- 'activation':function
- }
-
- Keyword arguments:
-
- - lr = LEARNING_RATE (float32)
- - beta1 = ema parameter for adam opt (float32)
- - batch_size (int)
- - save_sample = after how many batches iterations to save a sample (int)
- - path = relative path for saving samples
-
- """

        self.max_reco = max_reco
        self.min_reco = min_reco

        self.max_true = max_true
        self.min_true = min_true

        self.seed = seed

        self.n_W_A = n_W_A
        self.n_H_A = n_H_A

        self.n_W_B = n_W_B
        self.n_H_B = n_H_B

        self.n_C = n_C

        self.batch_sz = tf.placeholder(
            tf.int32,
            shape=(),
            name='batch_sz',
        )

        self.input_A = tf.placeholder(
            tf.float32,
            shape=(None, n_H_A, n_W_A, n_C),
            name='X_A',
        )

        self.input_B = tf.placeholder(
            tf.float32,
            shape=(None, n_H_B, n_W_B, n_C),
            name='X_B',
        )

        self.input_test_A = tf.placeholder(
            tf.float32,
            shape=(None, n_H_A, n_W_A, n_C),
            name='X_test_A',
        )

        D = pix2pixDiscriminator(self.input_A, d_sizes, 'B')
        G = pix2pixGenerator(self.input_A, n_H_B, n_W_B, g_sizes_enc, g_sizes_dec, 'A_to_B')

        with tf.variable_scope('generator_A_to_B') as scope:
            sample_images = G.g_forward(self.input_A)

        with tf.variable_scope('discriminator_B') as scope:
            predicted_real = D.d_forward(self.input_A, self.input_B)

        with tf.variable_scope('discriminator_B') as scope:
            scope.reuse_variables()
            predicted_fake = D.d_forward(self.input_A, sample_images, reuse=True)

        # get sample images at test time
        with tf.variable_scope('generator_A_to_B') as scope:
            scope.reuse_variables()
            self.sample_images_test_A_to_B = G.g_forward(
                self.input_test_A, reuse=True, is_training=False
            )

        self.d_params = [t for t in tf.trainable_variables() if 'discriminator' in t.name]
        self.g_params = [t for t in tf.trainable_variables() if 'generator' in t.name]

        if cost_type == 'GAN':

            # Discriminator cost
            predicted_real = tf.nn.sigmoid(predicted_real)
            predicted_fake = tf.nn.sigmoid(predicted_fake)

            # d_cost_real is close to 0 when real images are predicted as real
            d_cost_real = tf.log(predicted_real + EPS)
            # d_cost_fake is close to 0 when fake images are predicted as fake
            d_cost_fake = tf.log(1 - predicted_fake + EPS)

            self.d_cost = tf.reduce_mean(-(d_cost_real + d_cost_fake))

            # Alternative discriminator cost with one-sided label smoothing:
            # d_cost_real = tf.nn.sigmoid_cross_entropy_with_logits(
            #     logits=predicted_real,
            #     labels=tf.ones_like(predicted_real) - 0.01
            # )
            # d_cost_fake = tf.nn.sigmoid_cross_entropy_with_logits(
            #     logits=predicted_fake,
            #     labels=tf.zeros_like(predicted_fake) + 0.01
            # )
            # self.d_cost = tf.reduce_mean(d_cost_real) + tf.reduce_mean(d_cost_fake)

            # Generator cost
            # g_cost_GAN is low when the discriminator predicts the generated
            # samples as real (1)
            self.g_cost_GAN = tf.reduce_mean(-tf.log(predicted_fake + EPS))

            # self.g_cost_GAN = tf.reduce_mean(
            #     tf.nn.sigmoid_cross_entropy_with_logits(
            #         logits=predicted_fake,
            #         labels=tf.ones_like(predicted_fake) - 0.01
            #     )
            # )

            # L1 reconstruction loss (tf.abs, not tf.square, so that the name
            # matches the L1 term of the pix2pix paper)
            self.g_cost_l1 = tf.reduce_mean(tf.abs(self.input_B - sample_images))
            # self.g_cost_sum = tf.abs(tf.reduce_sum(self.input_B) - tf.reduce_sum(sample_images))
            self.g_cost = gan_weight * self.g_cost_GAN + cycl_weight * self.g_cost_l1
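
            # The combined objective above mirrors the pix2pix paper
            # (1611.07004): G* = arg min_G max_D L_cGAN(G, D) + lambda * L_L1(G),
            # with gan_weight and cycl_weight playing the roles of the two
            # coefficients (the paper uses lambda = 100 in its experiments).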

        if cost_type == 'WGAN-gp':

            self.g_cost_GAN = -tf.reduce_mean(predicted_fake)

            self.g_cost_l1 = tf.reduce_mean(tf.abs(self.input_B - sample_images))
            self.g_cost = gan_weight * self.g_cost_GAN + cycl_weight * self.g_cost_l1

            self.d_cost = tf.reduce_mean(predicted_fake) - tf.reduce_mean(predicted_real)

            # the interpolation mixes input_B with generated samples, so alpha
            # must be shaped like the B images
            alpha = tf.random_uniform(
                shape=[self.batch_sz, self.n_H_B, self.n_W_B, self.n_C],
                minval=0.,
                maxval=1.
            )

            # interpolates_1 = alpha*self.input_A + (1-alpha)*sample_images
            interpolates = alpha * self.input_B + (1 - alpha) * sample_images

            with tf.variable_scope('discriminator_B') as scope:
                scope.reuse_variables()
                disc_interpolates = D.d_forward(self.input_A, interpolates, reuse=True)

            gradients = tf.gradients(disc_interpolates, [interpolates])[0]
            # the gradient norm must be taken over all non-batch axes
            # (height, width, channels), not just axis 1
            slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
            self.gradient_penalty = tf.reduce_mean((slopes - 1) ** 2)
            self.d_cost += LAMBDA * self.gradient_penalty
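
            # This is the gradient-penalty term of WGAN-gp (Gulrajani et al.,
            # 1704.00028):
            #   L_gp = LAMBDA * E[(||grad_{x_hat} D(x_hat)||_2 - 1)^2],
            # with x_hat sampled uniformly on straight lines between real
            # target images and generated samples.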

        self.d_train_op = tf.train.AdamOptimizer(
            learning_rate=lr,
            beta1=beta1,
            beta2=0.9
        ).minimize(
            self.d_cost,
            var_list=self.d_params
        )

        self.g_train_op = tf.train.AdamOptimizer(
            learning_rate=lr,
            beta1=beta1,
            beta2=0.9
        ).minimize(
            self.g_cost,
            var_list=self.g_params
        )

        # saving for later
        self.batch_size = batch_size
        self.epochs = epochs
        self.save_sample = save_sample
        self.path = path
        self.lr = lr

        self.D = D
        self.G = G

        self.sample_images = sample_images
        self.preprocess = preprocess
        self.cost_type = cost_type
        self.cycl_weight = cycl_weight

        self.gen_steps = gen_steps
        self.discr_steps = discr_steps

    def set_session(self, session):

        self.session = session

        for layer in self.D.d_conv_layers:
            layer.set_session(session)

        for layer in self.G.g_enc_conv_layers:
            layer.set_session(session)

        for layer in self.G.g_dec_conv_layers:
            layer.set_session(session)

    def fit(self, X_A, X_B, validating_size):

        all_A = X_A
        all_B = X_B
        gen_steps = self.gen_steps
        discr_steps = self.discr_steps

        m = X_A.shape[0]
        train_A = all_A[0:m - validating_size]
        train_B = all_B[0:m - validating_size]

        validating_A = all_A[m - validating_size:m]
        validating_B = all_B[m - validating_size:m]

        seed = self.seed

        d_costs = []
        d_gps = []
        g_costs = []
        g_GANs = []
        g_l1s = []
        N = len(train_A)
        n_batches = N // self.batch_size

        total_iters = 0

        print('\n ****** \n')
        print('Training pix2pix GAN (arXiv:1611.07004) on a total of ' + str(N) + ' samples, split into ' + str(n_batches) + ' batches of size ' + str(self.batch_size) + '\n')
        print('The validation set consists of {0} images'.format(validating_A.shape[0]))
        print('The learning rate is ' + str(self.lr) + ', and a generated sample will be saved to ' + self.path + ' every ' + str(self.save_sample) + ' batch iterations')
        print('\n ****** \n')

        for epoch in range(self.epochs):

            seed += 1
            print('Epoch:', epoch)

            batches_A = unsupervised_random_mini_batches(train_A, self.batch_size, seed)
            batches_B = unsupervised_random_mini_batches(train_B, self.batch_size, seed)

            for X_batch_A, X_batch_B in zip(batches_A, batches_B):

                bs = X_batch_A.shape[0]

                t0 = datetime.now()

                # accumulate costs over gen_steps updates in dedicated
                # variables, so the per-step values returned by session.run
                # are not overwritten before they are summed
                g_cost_tot = 0
                g_GAN_tot = 0
                g_l1_tot = 0
                for i in range(gen_steps):

                    _, g_cost, g_GAN, g_l1 = self.session.run(
                        (self.g_train_op, self.g_cost, self.g_cost_GAN, self.g_cost_l1),
                        feed_dict={self.input_A: X_batch_A, self.input_B: X_batch_B, self.batch_sz: bs},
                    )
                    g_cost_tot += g_cost
                    g_GAN_tot += g_GAN
                    g_l1_tot += g_l1

                g_costs.append(g_cost_tot / gen_steps)
                g_GANs.append(g_GAN_tot / gen_steps)
                g_l1s.append(self.cycl_weight * g_l1_tot / gen_steps)

                d_cost_tot = 0
                d_gp_tot = 0
                for i in range(discr_steps):

                    if self.cost_type == 'WGAN-gp':
                        _, d_cost, d_gp = self.session.run(
                            (self.d_train_op, self.d_cost, self.gradient_penalty),
                            feed_dict={self.input_A: X_batch_A, self.input_B: X_batch_B, self.batch_sz: bs},
                        )
                        d_gp_tot += d_gp

                    else:
                        _, d_cost = self.session.run(
                            (self.d_train_op, self.d_cost),
                            feed_dict={self.input_A: X_batch_A, self.input_B: X_batch_B, self.batch_sz: bs},
                        )

                    d_cost_tot += d_cost

                d_costs.append(d_cost_tot / discr_steps)
                if self.cost_type == 'WGAN-gp':
                    d_gps.append(LAMBDA * d_gp_tot / discr_steps)

                total_iters += 1
                if total_iters % self.save_sample == 0:

                    plt.clf()
                    print("At iter: %d - dt: %s" % (total_iters, datetime.now() - t0))
                    print("Discriminator cost {0:.4g}, Generator cost {1:.4g}".format(d_costs[-1], g_costs[-1]))
                    print('Saving a sample...')

                    if self.preprocess is not False:
                        draw_nn_sample(validating_A, validating_B, 1, self.preprocess,
                                       self.min_true, self.max_true,
                                       self.min_reco, self.max_reco,
                                       f=self.get_sample_A_to_B, is_training=True,
                                       total_iters=total_iters, PATH=self.path)
                    else:
                        draw_nn_sample(validating_A, validating_B, 1, self.preprocess,
                                       f=self.get_sample_A_to_B, is_training=True,
                                       total_iters=total_iters, PATH=self.path)

                    plt.clf()
                    # the cost lists grow by one entry per batch, so the x axis
                    # counts batch iterations, not epochs
                    plt.subplot(1, 2, 1)
                    plt.plot(d_costs, label='Discriminator GAN cost')
                    plt.plot(g_GANs, label='Generator GAN cost')
                    plt.xlabel('Batch iteration')
                    plt.ylabel('Cost')
                    plt.legend()

                    plt.subplot(1, 2, 2)
                    plt.plot(g_costs, label='Generator total cost')
                    plt.plot(g_GANs, label='Generator GAN cost')
                    plt.plot(g_l1s, label='Generator l1 cost')
                    plt.xlabel('Batch iteration')
                    plt.ylabel('Cost')
                    plt.legend()

                    fig = plt.gcf()
                    fig.set_size_inches(15, 5)
                    plt.savefig(self.path + '/cost_iteration_gen_disc_A_to_B.png', dpi=150)

    def get_sample_A_to_B(self, Z):

        one_sample = self.session.run(
            self.sample_images_test_A_to_B,
            feed_dict={self.input_test_A: Z, self.batch_sz: 1})

        return one_sample

    def get_samples_A_to_B(self, Z):

        many_samples = self.session.run(
            self.sample_images_test_A_to_B,
            feed_dict={self.input_test_A: Z, self.batch_sz: Z.shape[0]})

        return many_samples
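
# A minimal usage sketch, illustrative only: the dict contents, image shapes
# and hyperparameter values below are hypothetical, and the data arrays X_A,
# X_B, X_A_test are assumed to be float32 numpy arrays of shape
# (m, n_H, n_W, n_C) prepared by the calling script.
#
# tf.reset_default_graph()
# model = pix2pix(
#     n_H_A=64, n_W_A=64, n_H_B=64, n_W_B=64, n_C=1,
#     d_sizes=d_sizes, g_sizes_enc=g_sizes_enc, g_sizes_dec=g_sizes_dec,
#     lr=2e-4, beta1=0.5, preprocess=False,
#     cost_type='GAN', gan_weight=1., cycl_weight=100.,
#     discr_steps=1, gen_steps=1,
#     batch_size=32, epochs=100,
#     save_sample=100, path='samples', seed=1,
# )
# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     model.set_session(sess)
#     model.fit(X_A, X_B, validating_size=100)
#     fake_B = model.get_samples_A_to_B(X_A_test)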