# Repository page header (navigation links "Newer" / "Older"), kept as a comment:
# HCAL_project / architectures / pix2pix_cycleDisc.py
# Davide Lancierini, 2 Dec 2018, 13 KB, "First commit"
import numpy as np
import os 
import math

import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime

from architectures.utils.NN_building_blocks import *
from architectures.utils.NN_gen_building_blocks import *
from architectures.utils.toolbox import *


# some dummy constants
LEARNING_RATE = None
BETA1 = None
COST_TYPE=None
BATCH_SIZE = None
EPOCHS = None
SAVE_SAMPLE_PERIOD = None
PATH = None
SEED = None
rnd_seed=1
preprocess=None
LAMBDA=.01
EPS=1e-10
CYCL_WEIGHT=None
GAN_WEIGHT=None
DISCR_STEPS=None
GEN_STEPS=None
max_true=None
max_reco=None

n_H_A=None
n_W_A=None
n_W_B=None
n_H_B=None
n_C=None

d_sizes=None
g_sizes_enc=None
g_sizes_dec=None


class pix2pix_cycleDisc(object):
    """Image-to-image GAN (pix2pix, arXiv 1611.07004) translating domain A into B.

    The constructor builds the TensorFlow graph: one generator mapping A
    images to B images, one discriminator scoring B images, the two costs
    and one Adam training op for each network. ``set_session`` has to be
    called before ``fit`` or the sampling helpers, because all of them run
    ops through ``self.session``.
    """

    def __init__(

        self,
        n_H_A=n_H_A, n_W_A=n_W_A,
        n_H_B=n_H_B, n_W_B=n_W_B, n_C=n_C,
        max_true=max_true, max_reco=max_reco,
        d_sizes=d_sizes, g_sizes_enc=g_sizes_enc, g_sizes_dec=g_sizes_dec,
        lr=LEARNING_RATE, beta1=BETA1, preprocess=preprocess,
        cost_type=COST_TYPE, gan_weight=GAN_WEIGHT, cycl_weight=CYCL_WEIGHT,
        discr_steps=DISCR_STEPS, gen_steps=GEN_STEPS,
        batch_size=BATCH_SIZE, epochs=EPOCHS,
        save_sample=SAVE_SAMPLE_PERIOD, path=PATH, seed=SEED,

        ):

        """
        Build the generator, the discriminator, their costs and training ops.

        Arguments (all have module-level placeholder defaults):

        - n_H_A, n_W_A = height and width of domain A images
        - n_H_B, n_W_B = height and width of domain B images
        - n_C = number of channels, shared by both domains
        - max_true, max_reco = stored and forwarded to draw_nn_sample when
          preprocess is not False
        - d_sizes = discriminator layout, passed unchanged to Discriminator
          (presumably a dict of conv block / dense layer specs -- see the
          Discriminator building block for the exact keys)
        - g_sizes_enc, g_sizes_dec = encoder and decoder layouts, passed
          unchanged to pix2pixGenerator
        - lr = learning rate handed to both Adam optimizers
        - beta1 = beta1 handed to both Adam optimizers (beta2 is fixed to 0.9)
        - preprocess = forwarded to draw_nn_sample
        - cost_type = 'GAN' (sigmoid cross entropy) or 'WGAN-gp'
          (critic cost plus LAMBDA * gradient penalty)
        - gan_weight, cycl_weight = weights of the adversarial term and of
          the L1 term |B - G(A)| in the generator cost
        - discr_steps, gen_steps = discriminator / generator updates run on
          every mini batch
        - batch_size (int), epochs (int)
        - save_sample = after how many batch iterations a sample is saved
        - path = directory for the saved samples and the cost plot
        - seed = base seed for the mini batch shuffling

        Raises:

        - ValueError if cost_type is neither 'GAN' nor 'WGAN-gp'

        """

        # Fail before touching the graph: without a supported cost type no
        # cost tensor would exist for the optimizers below.
        if cost_type not in ('GAN', 'WGAN-gp'):
            raise ValueError(
                "cost_type must be 'GAN' or 'WGAN-gp', got {0!r}".format(cost_type)
            )

        self.max_true = max_true
        self.max_reco = max_reco

        self.seed = seed

        self.n_W_A = n_W_A
        self.n_H_A = n_H_A

        self.n_W_B = n_W_B
        self.n_H_B = n_H_B
        self.n_C = n_C

        self.input_A = tf.placeholder(
            tf.float32,
            shape=(None, n_H_A, n_W_A, n_C),
            name='X_A',
        )

        self.input_B = tf.placeholder(
            tf.float32,
            shape=(None, n_H_B, n_W_B, n_C),
            name='X_B',
        )

        self.batch_sz = tf.placeholder(
            tf.int32,
            shape=(),
            name='batch_sz'
        )

        # NOTE: this placeholder is not consumed by the optimizers (they get
        # the python value `lr`) and self.lr is rebound to that value at the
        # end of the constructor; it is only kept so the graph is unchanged.
        self.lr = tf.placeholder(
            tf.float32,
            shape=(),
            name='lr'
        )

        # NOTE(review): the discriminator is constructed from input_A but is
        # only ever applied to B-shaped tensors -- confirm this is intended
        # when A and B have different sizes.
        D = Discriminator(self.input_A, d_sizes, 'B')
        G = pix2pixGenerator(self.input_A, self.n_H_B, self.n_W_B, g_sizes_enc, g_sizes_dec, 'A_to_B')

        # discriminator output on real B images
        with tf.variable_scope('discriminator'):
            logits = D.d_forward(self.input_B)

        # generated B images
        with tf.variable_scope('generator'):
            sample_images = G.g_forward(self.input_A)

        # discriminator output on generated images, sharing the same weights
        with tf.variable_scope('discriminator') as scope:
            scope.reuse_variables()
            sample_logits = D.d_forward(sample_images, reuse=True)

        self.input_test_A = tf.placeholder(
            tf.float32,
            shape=(None, n_H_A, n_W_A, n_C),
            name='X_test_A'
        )

        # get sample images at test time
        with tf.variable_scope('generator') as scope:
            scope.reuse_variables()
            self.sample_images_test_A_to_B = G.g_forward(
                self.input_test_A, reuse=True, is_training=False
            )

        self.d_params = [t for t in tf.trainable_variables() if 'discriminator' in t.name]
        self.g_params = [t for t in tf.trainable_variables() if 'generator' in t.name]

        if cost_type == 'GAN':

            # d_cost_real is low if real images are predicted as real
            d_cost_real = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits,
                labels=tf.ones_like(logits)
            )
            # d_cost_fake is low if generated images are predicted as fake;
            # the labels mirror the tensor they are compared against
            d_cost_fake = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=sample_logits,
                labels=tf.zeros_like(sample_logits)
            )

            self.d_cost = tf.reduce_mean(d_cost_real) + tf.reduce_mean(d_cost_fake)

            # g_cost_GAN is low if the discriminator predicts the generated
            # samples as real (1)
            self.g_cost_GAN = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=sample_logits,
                    labels=tf.ones_like(sample_logits)
                )
            )
            self.g_cost_l1 = tf.reduce_mean(tf.abs(self.input_B - sample_images))

            self.g_cost = gan_weight*self.g_cost_GAN + cycl_weight*self.g_cost_l1

            # only defined for the WGAN-gp cost; fit() does not fetch it here
            self.gradient_penalty = None

        else:  # cost_type == 'WGAN-gp'

            self.d_cost = tf.reduce_mean(sample_logits) - tf.reduce_mean(logits)

            # kept as attributes because fit() fetches them on every step
            self.g_cost_GAN = -tf.reduce_mean(sample_logits)
            self.g_cost_l1 = tf.reduce_mean(tf.abs(self.input_B - sample_images))

            self.g_cost = gan_weight*self.g_cost_GAN + cycl_weight*self.g_cost_l1

            # alpha multiplies B-shaped tensors, so it needs B dimensions
            alpha = tf.random_uniform(
                shape=[self.batch_sz, self.n_H_B, self.n_W_B, self.n_C],
                minval=0.,
                maxval=1.
            )

            interpolated = alpha*self.input_B + (1-alpha)*sample_images

            with tf.variable_scope('discriminator') as scope:
                scope.reuse_variables()
                disc_interpolates = D.d_forward(interpolated, reuse=True)

            gradients = tf.gradients(disc_interpolates, [interpolated])[0]
            slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))

            # kept as an attribute because fit() fetches it for monitoring
            self.gradient_penalty = tf.reduce_mean(tf.square(slopes-1))
            self.d_cost += LAMBDA*self.gradient_penalty

        self.d_train_op = tf.train.AdamOptimizer(
            learning_rate=lr,
            beta1=beta1,
            beta2=0.9
        ).minimize(
            self.d_cost,
            var_list=self.d_params
        )

        self.g_train_op = tf.train.AdamOptimizer(
            learning_rate=lr,
            beta1=beta1,
            beta2=0.9
        ).minimize(
            self.g_cost,
            var_list=self.g_params
        )

        # saving for later
        self.batch_size = batch_size
        self.epochs = epochs
        self.save_sample = save_sample
        self.path = path
        self.lr = lr

        self.D = D
        self.G = G

        self.sample_images = sample_images
        self.preprocess = preprocess
        self.cost_type = cost_type
        self.cycl_weight = cycl_weight

        self.gen_steps = gen_steps
        self.discr_steps = discr_steps

    def set_session(self, session):
        """Store the session and hand it to every conv layer of D and G."""

        self.session = session

        for layer in self.D.d_conv_layers:
            layer.set_session(session)

        for layer in self.G.g_enc_conv_layers:
            layer.set_session(session)

        for layer in self.G.g_dec_conv_layers:
            layer.set_session(session)

    def fit(self, X_A, X_B, validating_size):
        """
        Train generator and discriminator on paired samples (X_A[i], X_B[i]).

        Arguments:

        - X_A, X_B = arrays of domain A / domain B images, indexed by sample
          along the first axis
        - validating_size = number of trailing samples that are held out of
          training and only used to draw the periodic samples

        For every mini batch gen_steps generator updates are run, followed by
        discr_steps discriminator updates; the costs averaged over those
        steps are recorded. Every save_sample batch iterations a sample is
        drawn with draw_nn_sample, and after each epoch the cost plot in
        self.path is overwritten.
        """

        gen_steps = self.gen_steps
        discr_steps = self.discr_steps
        is_wgan = self.cost_type == 'WGAN-gp'

        m = X_A.shape[0]
        train_A = X_A[0:m-validating_size]
        train_B = X_B[0:m-validating_size]

        validating_A = X_A[m-validating_size:m]
        validating_B = X_B[m-validating_size:m]

        seed = self.seed

        d_costs = []
        d_gps = []
        g_costs = []
        g_GANs = []
        g_l1s = []
        N = len(train_A)
        n_batches = N // self.batch_size

        total_iters = 0

        print('\n ****** \n')
        print('Training pix2pix (from 1611.07004) GAN with a total of ' +str(N)+' samples distributed in '+ str(n_batches) +' batches of size '+str(self.batch_size)+'\n')
        print('The validation set consists of {0} images'.format(validating_A.shape[0]))
        print('The learning rate is '+str(self.lr)+', and every ' +str(self.save_sample)+ ' batches a generated sample will be saved to '+ self.path)
        print('\n ****** \n')

        for epoch in range(self.epochs):

            seed += 1
            print('Epoch:', epoch)

            # Both domains are shuffled with the same seed; presumably this
            # yields the same permutation so the A/B pairs stay aligned --
            # verify against unsupervised_random_mini_batches.
            batches_A = unsupervised_random_mini_batches(train_A, self.batch_size, seed)
            batches_B = unsupervised_random_mini_batches(train_B, self.batch_size, seed)

            for X_batch_A, X_batch_B in zip(batches_A, batches_B):

                bs = X_batch_A.shape[0]
                feed = {self.input_A: X_batch_A, self.input_B: X_batch_B, self.batch_sz: bs}

                t0 = datetime.now()

                # Running sums are kept apart from the fetched values so that
                # the recorded numbers are true means over the update steps.
                g_cost_sum = 0
                g_GAN_sum = 0
                g_l1_sum = 0
                for _ in range(gen_steps):

                    _, g_cost, g_GAN, g_l1 = self.session.run(
                        (self.g_train_op, self.g_cost, self.g_cost_GAN, self.g_cost_l1),
                        feed_dict=feed,
                    )
                    g_cost_sum += g_cost
                    g_GAN_sum += g_GAN
                    g_l1_sum += g_l1

                g_costs.append(g_cost_sum/gen_steps)
                g_GANs.append(g_GAN_sum/gen_steps)
                g_l1s.append(self.cycl_weight*g_l1_sum/gen_steps)

                d_cost_sum = 0
                d_gp_sum = 0
                for _ in range(discr_steps):

                    if is_wgan:
                        _, d_cost, d_gp = self.session.run(
                            (self.d_train_op, self.d_cost, self.gradient_penalty),
                            feed_dict=feed,
                        )
                        d_gp_sum += d_gp
                    else:
                        _, d_cost = self.session.run(
                            (self.d_train_op, self.d_cost),
                            feed_dict=feed,
                        )

                    d_cost_sum += d_cost

                d_costs.append(d_cost_sum/discr_steps)
                if is_wgan:
                    d_gps.append(LAMBDA*d_gp_sum/discr_steps)

                total_iters += 1
                if total_iters % self.save_sample == 0:

                    print("At iter: %d  -  dt: %s" % (total_iters, datetime.now() - t0))
                    print("Discriminator cost {0:.4g}, Generator cost {1:.4g}".format(d_costs[-1], g_costs[-1]))
                    print('Saving a sample...')

                    if self.preprocess != False:
                        draw_nn_sample(validating_A, validating_B, 1, self.preprocess,
                                        self.max_true, self.max_reco,
                                        f=self.get_sample_A_to_B, is_training=True,
                                        total_iters=total_iters, PATH=self.path)
                    else:
                        draw_nn_sample(validating_A, validating_B, 1, self.preprocess,
                                        f=self.get_sample_A_to_B, is_training=True,
                                        total_iters=total_iters, PATH=self.path)

            self._save_cost_plot(d_costs, g_costs, g_GANs, g_l1s)

    def _save_cost_plot(self, d_costs, g_costs, g_GANs, g_l1s):
        """Redraw the two cost panels and overwrite the figure in self.path."""

        # NOTE: each list holds one entry per batch iteration, so the x axis
        # counts iterations even though the label reads 'Epoch'.
        plt.clf()
        plt.subplot(1, 2, 1)
        plt.plot(d_costs, label='Discriminator GAN cost')
        plt.plot(g_GANs, label='Generator GAN cost')
        plt.xlabel('Epoch')
        plt.ylabel('Cost')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(g_costs, label='Generator total cost')
        plt.plot(g_GANs, label='Generator GAN cost')
        plt.plot(g_l1s, label='Generator l1 cycle cost')
        plt.xlabel('Epoch')
        plt.ylabel('Cost')
        plt.legend()

        fig = plt.gcf()
        fig.set_size_inches(15, 5)
        plt.savefig(self.path+'/cost_iteration_gen_disc_B_to_A.png', dpi=150)

    def get_sample_A_to_B(self, Z):
        """Run the test-time generator on Z, feeding batch_sz = 1."""

        one_sample = self.session.run(
            self.sample_images_test_A_to_B,
            feed_dict={self.input_test_A: Z, self.batch_sz: 1})

        return one_sample

    def get_samples_A_to_B(self, Z):
        """Run the test-time generator on Z, feeding batch_sz = Z.shape[0]."""

        many_samples = self.session.run(
            self.sample_images_test_A_to_B,
            feed_dict={self.input_test_A: Z, self.batch_sz: Z.shape[0]})

        return many_samples