# HCAL_project/architectures/bicycle_GAN.py
import numpy as np
import os 
import math

import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime

from architectures.utils.NN_building_blocks import *
from architectures.utils.NN_gen_building_blocks import *
from architectures.utils.arbitrary_act import *

def lrelu(x, alpha=0.2):
    return tf.maximum(alpha*x,x)

# def tf_HCAL_act(x, mask):

#     output = tf.where(mask,
#                       tf.multiply(6.120,tf.nn.sigmoid(tf.subtract(x,5.))), 
#                       tf.multiply(1.530,tf.nn.sigmoid(tf.subtract(x,5.))) 
#                         )
    
#     return output

def tf_HCAL_act(x, mask):
    
    output = tf.where(mask,
                      tf.maximum(0.,tf.minimum(6120.,x)) , 
                      tf.maximum(0.,tf.minimum(1530.,x)) 
                        )

    return output
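
# Minimal sanity-check sketch for tf_HCAL_act (never called by the pipeline;
# the 52x64 image size is an illustrative placeholder, while the 12:40 / 16:48
# inner window mirrors the mask built in bicycle_GAN.fit below).
def _check_tf_HCAL_act(n_H=52, n_W=64):
    x = tf.placeholder(tf.float32, shape=(None, n_H, n_W, 1))
    m = tf.placeholder(tf.bool, shape=(None, n_H, n_W, 1))
    y = tf_HCAL_act(x, m)
    mask = np.zeros((1, n_H, n_W, 1), dtype=bool)
    mask[:, 12:40, 16:48, :] = True   # inner HCAL region
    big = np.full((1, n_H, n_W, 1), 1e4, dtype=np.float32)
    with tf.Session() as sess:
        out = sess.run(y, feed_dict={x: big, m: mask})
    # inner cells saturate at 6120, outer cells at 1530
    assert out[0, 20, 20, 0] == 6120. and out[0, 0, 0, 0] == 1530.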

# placeholder defaults for the constructor arguments below
pretrain=None
LEARNING_RATE = None
BETA1 = None
COST_TYPE=None
BATCH_SIZE = None
EPOCHS = None
SAVE_SAMPLE_PERIOD = None
PATH = None
SEED = None
rnd_seed=1
PREPROCESS=None
LAMBDA=.01
EPS=1e-6
CYCL_WEIGHT=None
LATENT_WEIGHT=None
KL_WEIGHT=None
DISCR_STEPS=None
GEN_STEPS=None
VAE_STEPS=None

min_true=None
max_true=None

min_reco=None
max_reco=None

n_H_A=None
n_W_A=None
n_W_B=None
n_H_B=None
n_C=None

d_sizes=None
g_sizes_enc=None
g_sizes_dec=None
e_sizes=None


class bicycle_GAN(object):
    #fix args
    def __init__(

        self, 
        n_H_A=n_H_A, n_W_A=n_W_A,
        n_H_B=n_H_B, n_W_B=n_W_B, n_C=n_C,
        min_true=min_true, max_true=max_true, 
        min_reco=min_reco, max_reco=max_reco,
        d_sizes=d_sizes, g_sizes_enc=g_sizes_enc, g_sizes_dec=g_sizes_dec, e_sizes=e_sizes, 
        pretrain=pretrain, lr=LEARNING_RATE, beta1=BETA1, preprocess=PREPROCESS,
        cost_type=COST_TYPE, cycl_weight=CYCL_WEIGHT, latent_weight=LATENT_WEIGHT, kl_weight=KL_WEIGHT,
        discr_steps=DISCR_STEPS, gen_steps=GEN_STEPS, vae_steps=VAE_STEPS,
        batch_size=BATCH_SIZE, epochs=EPOCHS,
        save_sample=SAVE_SAMPLE_PERIOD, path=PATH, seed=SEED,
        ):

        """

        Arguments:

            - n_H_A, n_W_A and n_H_B, n_W_B: heights and widths of the input (A) and target (B) images
            - n_C: number of channels
            - d_sizes: discriminator sizes

                a python dict of the kind
                    d_sizes = { 'convblocklayer_n':[(n_c+1, kernel, stride, apply_batch_norm, weight initializer),
                                                   (,,,,),
                                                   (,,,,),
                                                   ],
                                'convblock_shortcut_layer_n':[(,,,)],
                                'dense_layers':[(n_o, apply_bn, weight_init)]
                                }
            - g_sizes_enc, g_sizes_dec: generator sizes

                a python dictionary of the kind

                    g_sizes = { 
                                'z':latent_space_dim,
                                'projection': int,
                                'bn_after_project':bool,

                                'deconvblocklayer_n':[(n_c+1, kernel, stride, apply_batch_norm, weight initializer),
                                                   (,,,,),
                                                   (,,,,),
                                                   ],
                                'deconvblock_shortcut_layer_n':[(,,,)],
                                'dense_layers':[(n_o, apply_bn, weight_init)],
                                'activation':function
                                }

            - e_sizes: encoder sizes; must contain the key 'latent_dims'

        Keyword arguments:

            - lr = LEARNING_RATE (float32)
            - beta1 = exponential decay rate of the first-moment estimate in the Adam optimizer (float32)
            - batch_size (int)
            - save_sample = after how many batch iterations a generated sample is saved (int)
            - path = relative path for saving samples

        """

        latent_dims=e_sizes['latent_dims']

        self.min_true=min_true
        self.max_true=max_true

        self.min_reco=min_reco 
        self.max_reco=max_reco
        self.seed=seed

        self.n_W_A = n_W_A
        self.n_H_A = n_H_A

        self.n_W_B = n_W_B
        self.n_H_B = n_H_B
        self.n_C = n_C 

        #input data
        
        self.input_A = tf.placeholder(
            tf.float32,
            shape=(None, 
                   n_H_A, n_W_A, n_C),
            name='X_A',
        )

        self.input_B = tf.placeholder(
            tf.float32,
            shape=(None, 
                   n_H_B, n_W_B, n_C),
            name='X_B',
        )
        
        self.batch_sz = tf.placeholder(
            tf.int32, 
            shape=(), 
            name='batch_sz'
        )
        self.lr = tf.placeholder(
            tf.float32, 
            shape=(), 
            name='lr'
        )

        self.z = tf.placeholder(
            tf.float32,
            shape=(None,
                    latent_dims)
        )

        self.input_test_A = tf.placeholder(
                    tf.float32,
                    shape=(None, 
                           n_H_A, n_W_A, n_C),
                    name='X_test_A',
                )

        self.mask = tf.placeholder(
                    tf.bool,
                    shape=(None,
                         n_H_A, n_W_A, n_C),
                    name='inner_outer',
                )
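
        # The three networks of the bicycleGAN (Zhu et al. 2017):
        #   G maps an A image and a latent code z to a B image,
        #   D discriminates real from generated B images (minibatch variant),
        #   E encodes a B image back into the latent space.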

        G = bicycleGenerator(self.input_A, self.n_H_B, self.n_W_B, g_sizes_enc, g_sizes_dec, 'A_to_B')

        D = Discriminator_minibatch(self.input_B, d_sizes, 'B')
        #D = Discriminator(self.input_B, d_sizes, 'B')

        E = convEncoder(self.input_B, e_sizes, 'B')

        with tf.variable_scope('encoder_B') as scope:
            z_encoded, z_encoded_mu, z_encoded_log_sigma = E.e_forward(self.input_B)
        
        with tf.variable_scope('generator_A_to_B') as scope:
            sample_A_to_B_encoded = tf_HCAL_act(G.g_forward(self.input_A, z_encoded), self.mask)

        with tf.variable_scope('generator_A_to_B') as scope:
            scope.reuse_variables()
            sample_A_to_B = self.sample_A_to_B = tf_HCAL_act(G.g_forward(self.input_A, self.z, reuse=True), self.mask)
        
        with tf.variable_scope('encoder_B') as scope:
            scope.reuse_variables()
            z_recon, z_recon_mu, z_recon_log_sigma = E.e_forward(sample_A_to_B, reuse=True)

        with tf.variable_scope('discriminator_B') as scope:

            logits_real, feature_output_real = D.d_forward(self.input_B)

            
        with tf.variable_scope('discriminator_B') as scope:
            scope.reuse_variables()
            logits_fake, feature_output_fake = D.d_forward(sample_A_to_B, reuse=True)
            logits_fake_encoded, feature_output_fake_encoded = D.d_forward(sample_A_to_B_encoded, reuse=True)

        with tf.variable_scope('generator_A_to_B') as scope:
            scope.reuse_variables()
            self.test_images_A_to_B = tf_HCAL_act(G.g_forward(
                self.input_test_A, self.z, reuse=True, is_training=False
                ),self.mask)

        #parameters lists
        self.d_params =[t for t in tf.trainable_variables() if 'discriminator' in t.name]
        self.e_params =[t for t in tf.trainable_variables() if 'encoder' in t.name]
        self.g_params =[t for t in tf.trainable_variables() if 'generator' in t.name]
        
        predicted_real= tf.nn.sigmoid(logits_real)
        #predicted_real=tf.maximum(tf.minimum(predicted_real, 0.99), 0.00)

        predicted_fake=tf.nn.sigmoid(logits_fake)
        #predicted_fake=tf.maximum(tf.minimum(predicted_fake, 0.99), 0.00)

        predicted_fake_encoded = tf.nn.sigmoid(logits_fake_encoded)
        #predicted_fake_encoded =tf.maximum(tf.minimum(predicted_fake_encoded, 0.99), 0.00)

        epsilon=1e-2
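        # epsilon implements label smoothing for the cross-entropy losses below:
        # real labels become (1-epsilon) and fake labels epsilon. cost_type picks
        # sigmoid-cross-entropy GAN, Wasserstein GAN with gradient penalty, or
        # feature matching for the generator.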
        #GAN LOSS
        if cost_type=='GAN': 

            #DISCRIMINATOR LOSSES
            self.d_cost_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits_real,
                labels=(1-epsilon)*tf.ones_like(logits_real)
            )
            )
            self.d_cost_fake_lr_GAN = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits_fake,
                labels=epsilon+tf.zeros_like(logits_fake)
            )
            )
            self.d_cost_fake_vae_GAN = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits_fake_encoded,
                labels=epsilon+tf.zeros_like(logits_fake_encoded)
            )
            )
            #GENERATOR LOSSES
            self.g_cost_lr_GAN = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_fake,
                    labels=(1-epsilon)*tf.ones_like(logits_fake)
                )
            )

            self.g_cost_vae_GAN = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_fake_encoded,
                    labels=(1-epsilon)*tf.ones_like(logits_fake_encoded)
                )
            )

        if cost_type=='WGAN':
            #DISCRIMINATOR
            self.d_cost_real= -tf.reduce_mean(logits_real)
            self.d_cost_fake_vae_GAN = tf.reduce_mean(logits_fake_encoded)
            self.d_cost_fake_lr_GAN =  tf.reduce_mean(logits_fake)

            #GRADIENT PENALTY (WGAN-GP): penalise deviations of the critic's
            #gradient norm from 1 along random real/fake interpolations
            gp_epsilon = tf.random_uniform(
                [self.batch_sz, 1, 1, 1], 
                minval=0.,
                maxval=1.,
                )
            # interpolate between the real B images and the two kinds of fakes;
            # the three mixing weights sum to 1 for each random draw
            interpolated = gp_epsilon*self.input_B \
                           + ((1-gp_epsilon)/2)*sample_A_to_B \
                           + ((1-gp_epsilon)/2)*sample_A_to_B_encoded
            with tf.variable_scope('discriminator_B') as scope:
                scope.reuse_variables()
                logits_interpolated, _ = D.d_forward(interpolated, reuse=True)

            gradients = tf.gradients(logits_interpolated, [interpolated], name='D_logits_intp')[0]
            grad_l2 = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1,2,3]))
            self.grad_penalty = tf.reduce_mean(tf.square(grad_l2-1.0))

            #GENERATOR
            self.g_cost_vae_GAN= - tf.reduce_mean(logits_fake_encoded)
            self.g_cost_lr_GAN= - tf.reduce_mean(logits_fake)
            
        if cost_type=='FEATURE': 
            #DISCRIMINATOR LOSSES
            self.d_cost_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits_real,
                labels=(1-epsilon)*tf.ones_like(logits_real)
            )
            )
            self.d_cost_fake_lr_GAN = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits_fake,
                labels=epsilon+tf.zeros_like(logits_fake)
            )
            )
            self.d_cost_fake_vae_GAN = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits_fake_encoded,
                labels=epsilon+tf.zeros_like(logits_fake_encoded)
            )
            )

            #GENERATOR LOSSES

            g_cost_lr_GAN = tf.reduce_sum(tf.squared_difference(feature_output_real,feature_output_fake), axis=1)
            self.g_cost_lr_GAN = tf.reduce_mean(g_cost_lr_GAN)

            g_cost_vae_GAN = tf.reduce_sum(tf.squared_difference(feature_output_real,feature_output_fake_encoded), axis=1)
            self.g_cost_vae_GAN=tf.reduce_mean(g_cost_vae_GAN)

        #CYCLIC (IMAGE RECONSTRUCTION) COST: L1 distance between the real B
        #image and its reconstruction from the encoded latent code

        g_cost_cycl = tf.reduce_mean(tf.abs(self.input_B - sample_A_to_B_encoded), axis=[1,2,3])
        self.g_cost_cycl = tf.reduce_mean(g_cost_cycl)


        # self.g_4_cells_cycl_encoded=tf.reduce_mean(
        #         tf.abs(
        #         tf.cast(
        #         tf.convert_to_tensor(
        #         [
        #         tf.nn.top_k(
        #             tf.reshape(
        #                     self.input_B[i],
        #             [-1]), 
        #         k=4)[0] -
        #         tf.nn.top_k(
        #             tf.reshape(
        #                     sample_A_to_B_encoded[i],
        #             [-1]), 
        #         k=4)[0] 
        #         for i in range(16)]
        # )
        # ,dtype=tf.float32)
        # )
        # )
        # self.g_4_cells_pos_encoded=tf.reduce_mean(tf.abs(tf.cast(tf.convert_to_tensor(
        #         [
        #         tf.nn.top_k(
        #             tf.reshape(
        #                     self.input_B[i],
        #             [-1]), 
        #         k=4)[1] -
        #         tf.nn.top_k(
        #             tf.reshape(
        #                     sample_A_to_B_encoded[i],
        #             [-1]), 
        #         k=4)[1] 
        #         for i in range(16)]
        # ), dtype=tf.float32)
        # )
        # )
        

        
        #ENCODER COSTS

        e_cost_latent_cycle = tf.reduce_mean(tf.abs(self.z - z_recon), axis=1)
        self.e_cost_latent_cycle=tf.reduce_mean(e_cost_latent_cycle)
        
        # KL(q(z|B) || N(0,1)); note: despite its name, z_encoded_log_sigma is
        # used here as sigma itself (tf.log and tf.square are applied directly)
        self.e_cost_kl = -0.5 * tf.reduce_mean(1 + 2*tf.log(z_encoded_log_sigma) - tf.square(z_encoded_mu) - tf.square(z_encoded_log_sigma))
        
        #TOTAL COSTS
        
        self.d_cost = self.d_cost_fake_vae_GAN + self.d_cost_fake_lr_GAN + 2*self.d_cost_real 
        if cost_type=='WGAN':
            # include the gradient penalty, weighted by the module-level LAMBDA
            self.d_cost += LAMBDA*self.grad_penalty
        self.g_cost = self.g_cost_vae_GAN + self.g_cost_lr_GAN + cycl_weight*self.g_cost_cycl + self.e_cost_latent_cycle
        self.e_cost = latent_weight*self.e_cost_latent_cycle + kl_weight*self.e_cost_kl + self.g_cost_vae_GAN + self.g_cost_cycl
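
        # These totals follow the bicycleGAN objective (cVAE-GAN + cLR-GAN):
        # the discriminator sees both kinds of fakes against a twice-weighted
        # real term, the generator combines both adversarial terms with the
        # image- and latent-space reconstruction costs, and the encoder adds
        # the KL regularisation towards N(0,1).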

        self.d_train_op = tf.train.AdamOptimizer(
                learning_rate=lr, 
                beta1=beta1,
            ).minimize(
                self.d_cost, 
                var_list=self.d_params
            )

        self.g_train_op = tf.train.AdamOptimizer(
                learning_rate=lr, 
                beta1=beta1,
            ).minimize(
                self.g_cost, 
                var_list=self.g_params
            )

        self.e_train_op = tf.train.AdamOptimizer(
                learning_rate=lr, 
                beta1=beta1,
            ).minimize(
                self.e_cost, 
                var_list=self.e_params
            )

        # threshold the sigmoid outputs (not the raw logits) at 0.5
        real_predictions = tf.cast(predicted_real>0.5,tf.float32)
        fake_predictions = tf.cast(predicted_fake<0.5,tf.float32)
        fake_enc_predictions = tf.cast(predicted_fake_encoded<0.5,tf.float32)

        num_predictions=2.0*tf.cast(self.batch_sz, tf.float32)

        num_correct = tf.reduce_sum(real_predictions)+tf.reduce_sum(fake_predictions)
        num_correct_enc = tf.reduce_sum(real_predictions)+tf.reduce_sum(fake_enc_predictions)

        self.d_accuracy= num_correct/num_predictions
        self.d_accuracy_enc= num_correct_enc/num_predictions

        self.D=D
        self.G=G
        self.E=E

        self.latent_weight=latent_weight
        self.cycl_weight=cycl_weight
        self.kl_weight=kl_weight
        self.latent_dims=latent_dims

        self.batch_size=batch_size
        self.epochs=epochs
        self.save_sample=save_sample
        self.path=path
        self.lr = lr

        self.preprocess=preprocess
        self.cost_type=cost_type
        self.gen_steps=gen_steps
        self.vae_steps=vae_steps
        self.discr_steps=discr_steps

    def set_session(self, session):

        self.session = session

        for layer in self.D.d_conv_layers:
            layer.set_session(session)

        for layer in self.G.g_enc_conv_layers:
            layer.set_session(session)

        for layer in self.G.g_dec_conv_layers:
            layer.set_session(session)

        for layer in self.E.e_conv_layers:
            layer.set_session(session)

        for layer in self.E.e_dense_layers:
            layer.set_session(session)

    def fit(self, X_A, X_B, validating_size):

        all_A = X_A
        all_B = X_B

        gen_steps=self.gen_steps
        discr_steps=self.discr_steps
        vae_steps = self.vae_steps
        m = X_A.shape[0]
        train_A = all_A[0:m-validating_size]
        train_B = all_B[0:m-validating_size]

        validating_A = all_A[m-validating_size:m]
        validating_B = all_B[m-validating_size:m]

        seed=self.seed

        d_costs=[]
        d_costs_vae_GAN=[]
        d_costs_lr_GAN=[]
        d_costs_GAN=[ ]

        g_costs=[]
        g_costs_lr_GAN=[]
        g_costs_vae_GAN=[]
        g_costs_cycl=[]

        e_costs=[]
        e_costs_kl=[]
        e_costs_latent_cycle=[]

        N=len(train_A)
        n_batches = N // self.batch_size

        total_iters=0

        print('\n ****** \n')
        print('Training bicycleGAN with a total of '+str(N)+' samples distributed in '+str(n_batches)+' batches of size '+str(self.batch_size)+'\n')
        print('The validation set consists of {0} images'.format(validating_A.shape[0]))
        print('The learning rate is '+str(self.lr)+', and every ' +str(self.save_sample)+ ' batches a generated sample will be saved to '+ self.path)
        print('\n ****** \n')
        
        for epoch in range(self.epochs):

            seed+=1
            print('Epoch:', epoch)

            batches_A = unsupervised_random_mini_batches(train_A, self.batch_size, seed)
            batches_B = unsupervised_random_mini_batches(train_B, self.batch_size, seed)

            for X_batch_A, X_batch_B in zip(batches_A[:-1], batches_B[:-1]):

                bs=X_batch_A.shape[0]
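
                # boolean mask flagging the inner HCAL cells (rows 12:40, cols 16:48);
                # tf_HCAL_act saturates inner and outer cells at different maxima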

                in_out_mask_zeros = np.zeros([bs, self.n_H_A, self.n_W_A, self.n_C])
                in_out_mask_zeros[:, 12:40, 16:48, :]=1
                in_out_mask = in_out_mask_zeros.astype(dtype=bool)

                t0 = datetime.now()


                e_cost=0
                e_cost_latent_cycle=0
                e_cost_kl=0

                g_cost=0
                g_cost_cycl=0
                g_cost_lr_GAN=0
                g_cost_vae_GAN=0
                #cluster_diff=0

                d_cost=0
                d_cost_vae_GAN=0
                d_cost_lr_GAN=0

                # one latent draw shared by the D, G and VAE updates for this batch
                sample_z = np.random.normal(size=(bs, self.latent_dims))
                for i in range(discr_steps):

                    #sample_z = np.random.normal(size=(bs, self.latent_dims))

                    _, d_acc, d_acc_enc, d_cost_i, d_cost_vae_GAN_i, d_cost_lr_GAN_i = self.session.run(

                        (self.d_train_op, self.d_accuracy, self.d_accuracy_enc, self.d_cost, self.d_cost_fake_vae_GAN, self.d_cost_fake_lr_GAN),

                        feed_dict={self.input_A:X_batch_A, self.input_B:X_batch_B, 
                                self.z:sample_z, self.mask:in_out_mask, self.batch_sz:bs
                                },
                    )
                    # accumulate the per-step costs; averaged over discr_steps below
                    d_cost+=d_cost_i
                    d_cost_vae_GAN+=d_cost_vae_GAN_i
                    d_cost_lr_GAN+=d_cost_lr_GAN_i

                d_costs.append(d_cost/discr_steps)
                d_costs_vae_GAN.append(d_cost_vae_GAN/discr_steps)
                d_costs_lr_GAN.append(d_cost_lr_GAN/discr_steps)

                for i in range(gen_steps):

                    #sample_z = np.random.normal(size=(bs, self.latent_dims))

                    _, g_cost_i, g_cost_cycl_i, g_cost_lr_GAN_i, g_cost_vae_GAN_i = self.session.run(

                        (self.g_train_op, self.g_cost, 
                        self.g_cost_cycl, self.g_cost_lr_GAN, self.g_cost_vae_GAN, 
                            ),

                        feed_dict={self.input_A:X_batch_A, self.input_B:X_batch_B, 
                                    self.z:sample_z, self.mask:in_out_mask, self.batch_sz:bs
                                    },
                    )

                    # accumulate the per-step costs; averaged over gen_steps below
                    g_cost+=g_cost_i
                    g_cost_cycl+=g_cost_cycl_i
                    g_cost_lr_GAN+=g_cost_lr_GAN_i
                    g_cost_vae_GAN+=g_cost_vae_GAN_i
                    #cluster_diff+=cluster_diff

                g_costs.append(g_cost/gen_steps)
                g_costs_vae_GAN.append(g_cost_vae_GAN/gen_steps)
                g_costs_lr_GAN.append(g_cost_lr_GAN/gen_steps)
                g_costs_cycl.append(self.cycl_weight*g_cost_cycl/gen_steps)
                #cluster_diffs.append(self.cycl_weight*cluster_diff/gen_steps)

                for i in range(vae_steps):
                    
                    #sample_z = np.random.normal(size=(bs, self.latent_dims))
                    _, e_cost_i, e_cost_latent_cycle_i, e_cost_kl_i = self.session.run(

                        (self.e_train_op, self.e_cost, 
                        self.e_cost_latent_cycle, self.e_cost_kl
                            ),

                        feed_dict={self.input_A:X_batch_A, self.input_B:X_batch_B, 
                                    self.z:sample_z, self.mask:in_out_mask, self.batch_sz:bs
                                    },
                    )

                    # accumulate the per-step costs; averaged over vae_steps below
                    e_cost+=e_cost_i
                    e_cost_latent_cycle+=e_cost_latent_cycle_i
                    e_cost_kl+=e_cost_kl_i


                    #cluster_diff+=cluster_diff
                e_costs.append(e_cost/vae_steps)
                e_costs_latent_cycle.append(self.latent_weight*e_cost_latent_cycle/vae_steps)
                e_costs_kl.append(self.kl_weight*e_cost_kl/vae_steps)

                total_iters+=1
                if total_iters % self.save_sample==0:
                    plt.clf()
                    print("At iter: %d  -  dt: %s - d_acc: %.2f, - d_acc_enc: %.2f" % (total_iters, datetime.now() - t0, d_acc, d_acc_enc))
                    print("Discriminator cost {0:.4g}, Generator cost {1:.4g}, VAE Cost {2:.4g}, KL divergence cost {3:.4g}".format(d_cost, g_cost, e_cost, e_cost_kl))
                    print('Saving a sample...')


                    if self.preprocess is not False:
                        draw_nn_sample(validating_A, validating_B, 1, self.preprocess,
                                        self.min_true, self.max_true, 
                                        self.min_reco, self.max_reco,
                                        f=self.get_sample_A_to_B, is_training=True,
                                        total_iters=total_iters, PATH=self.path)
                    else:
                        draw_nn_sample(validating_A, validating_B, 1, self.preprocess,
                                        f=self.get_sample_A_to_B, is_training=True,
                                        total_iters=total_iters, PATH=self.path)

                    plt.clf()
                    plt.subplot(2,4,1)
                    plt.plot(d_costs, label='Discriminator total cost')
                    plt.plot(d_costs_lr_GAN, label='Discriminator of image with encoded noise cost')
                    plt.plot(d_costs_vae_GAN, label='Discriminator of image with input noise cost')
                    plt.xlabel('Batch')
                    plt.ylabel('Cost')
                    plt.legend()

                    plt.subplot(2,4,2)
                    plt.plot(g_costs, label='Generator total cost')
                    plt.plot(g_costs_cycl, label='Generator cyclic cost')
                    #plt.plot(g_costs_GAN, label='GAN cost (encoded noise image)')
                    #plt.plot(g_costs_vae_GAN, label='GAN cost (input noise image)')
                    plt.xlabel('Batch')
                    plt.ylabel('Cost')
                    plt.legend()

                    plt.subplot(2,4,3)
                    plt.plot(e_costs, label='VAE cost')
                    #plt.plot(e_costs_kl, label='KL cost')
                    #plt.plot(e_costs_latent_cycle, label='Latent space cyclic cost')
                    plt.xlabel('Batch')
                    plt.ylabel('Cost')
                    plt.legend()


                    plt.subplot(2,4,6)
                    #plt.plot(g_costs, label='Generator cost')
                    #plt.plot(g_costs_cycl, label='Generator cyclic cost')
                    plt.plot(g_costs_lr_GAN, label='GAN cost (encoded noise image)')
                    plt.plot(g_costs_vae_GAN, label='GAN cost (input noise image)')
                    plt.xlabel('Batch')
                    plt.ylabel('Cost')
                    plt.legend()

                    plt.subplot(2,4,7)
                    plt.plot(e_costs_latent_cycle, label='Latent space cyclic cost')
                    plt.xlabel('Batch')
                    plt.ylabel('Cost')
                    plt.legend()

                    plt.subplot(2,4,8) 
                    plt.plot(e_costs_kl, label='KL cost')
                    #plt.plot(e_costs_latent_cycle, label='Latent space cyclic cost')
                    plt.xlabel('Batch')
                    plt.ylabel('Cost')
                    plt.legend()


                    fig = plt.gcf()
                    fig.set_size_inches(15,7)
                    plt.savefig(self.path+'/cost_iteration.png',dpi=80)

            print('Printing validation set histograms at epoch {0}'.format(epoch))
            
            if not os.path.exists(self.path+'/epoch{0}/'.format(epoch)):
                os.mkdir(self.path+'/epoch{0}/'.format(epoch))

            validating_NN=np.zeros_like(validating_B)
            for i in range(validating_size):
                validating_NN[i]=self.get_sample_A_to_B(validating_A[i].reshape(1,self.n_H_A,self.n_W_A,self.n_C))
            
            print('ET Distribution plots are being printed...')
            if self.preprocess is not False:
                validation_MC_hist= denormalise(validating_B, self.min_reco, self.max_reco).reshape(validating_size, self.n_H_B*self.n_W_B)
                validation_MC_hist = np.sum(validation_MC_hist,axis=1)
                max_MC_hist = np.max(validation_MC_hist)

                validation_NN_hist= denormalise(validating_NN, self.min_reco, self.max_reco).reshape(validating_size, self.n_H_B*self.n_W_B)
                validation_NN_hist = np.sum(validation_NN_hist,axis=1)
                max_NN_hist = np.max(validation_NN_hist)

                validation_true_hist= denormalise(validating_A, self.min_true, self.max_true).reshape(validating_size, self.n_H_A*self.n_W_A)
                validation_true_hist = np.sum(validation_true_hist,axis=1)
                max_true_hist = np.max(validation_true_hist)

                validation_NN_hist_rescaled=(validation_NN_hist/max_NN_hist)*max_MC_hist

            else:
                validation_MC_hist= validating_B.reshape(validating_size, self.n_H_B*self.n_W_B)
                validation_MC_hist = np.sum(validation_MC_hist,axis=1)
                max_MC_hist = np.max(validation_MC_hist)

                validation_NN_hist= validating_NN.reshape(validating_size, self.n_H_B*self.n_W_B)
                validation_NN_hist = np.sum(validation_NN_hist,axis=1)
                max_NN_hist = np.max(validation_NN_hist)

                validation_true_hist= validating_A.reshape(validating_size, self.n_H_A*self.n_W_A)
                validation_true_hist = np.sum(validation_true_hist,axis=1)
                max_true_hist = np.max(validation_true_hist)
                
                validation_NN_hist_rescaled=(validation_NN_hist/max_NN_hist)*max_MC_hist


            plt.clf()
            plt.subplot(1,3,1)
            h_true = plt.hist(validation_true_hist, bins=30, edgecolor='black')
            plt.xlabel('E (MeV)')
            plt.ylabel('dN/dE')
            plt.title('True pion E_T distribution,\n max true hist: {0} '.format(max_true_hist))
            plt.subplot(1,3,2)
            h_reco = plt.hist(validation_MC_hist, bins=30, edgecolor='black')
            plt.xlabel('E (MeV)')
            plt.ylabel('dN/dE')
            plt.title('Reco pion E_T distribution,\n max MC hist: {0} '.format(max_MC_hist))
            plt.subplot(1,3,3)
            h_nn = plt.hist(validation_NN_hist_rescaled, bins=30, edgecolor='black')
            plt.xlabel('E (MeV)')
            plt.ylabel('dN/dE')
            plt.title('Reco pion E_T distribution from bicycleGAN, \n max NN hist: {0} '.format(max_NN_hist))
            fig = plt.gcf()
            fig.set_size_inches(16,4)

            plt.savefig(self.path+'/epoch{0}/distribution_at_epoch_{1}.png'.format(epoch, epoch), dpi=80)
            
            plt.clf()
            diff=plt.bar(np.arange(0, max_MC_hist, step=max_MC_hist/30), 
             height=(h_nn[0]-h_reco[0]), edgecolor='black', 
             linewidth=1, color='lightblue', width=max_MC_hist/30, align='edge') 
            plt.xlabel('E (MeV)')
            plt.ylabel('dN/dE')
            plt.title("ET distribution difference NN output - MC output")
            fig = plt.gcf()
            fig.set_size_inches(12,4)
            plt.savefig(self.path+'/epoch{0}/difference_at_epoch_{1}.png'.format(epoch, epoch), dpi=80)
            print('Done')

            print('Resolution plots are being printed...')

            diffNN = validation_NN_hist_rescaled-validation_true_hist
            diffMC = validation_MC_hist-validation_true_hist

            plt.clf()
            plt.subplot(1,2,1)
            h_reco = plt.hist(diffMC/1000, bins=30, range=(-80, 40), edgecolor='black')
            plt.xlabel('ET recoMC - ET true (GeV)')
            plt.ylabel('dN/dETdiff')
            plt.title('Resolution as simulated by MC')
            plt.subplot(1,2,2)
            h_nn = plt.hist(diffNN/1000, bins=30, range=(-80, 40), edgecolor='black')
            plt.xlabel('ET recoNN - ET true (GeV)')
            plt.ylabel('dN/dETdiff')
            plt.title('Resolution as simulated by NN')
            fig = plt.gcf()
            fig.set_size_inches(12,4)
            plt.savefig(self.path+'/epoch{0}/resolution_at_epoch_{1}.png'.format(epoch, epoch), dpi=80)
            print('Done')

    def get_sample_A_to_B(self, X):
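        # a fresh latent code is drawn on every call, so repeated calls on the
        # same X produce different plausible B-domain samples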

        z = np.random.normal(size=(1, self.latent_dims))

        in_out_mask_zeros = np.zeros([1, self.n_H_A, self.n_W_A, self.n_C])
        in_out_mask_zeros[:, 12:40, 16:48, :]=1
        in_out_mask = in_out_mask_zeros.astype(dtype=bool)

        one_sample = self.session.run(
            self.test_images_A_to_B, 
            feed_dict={self.input_test_A:X, self.z:z, self.mask:in_out_mask, self.batch_sz: 1})

        return one_sample 

    def get_samples_A_to_B(self, X):

        bs=X.shape[0]
        z = np.random.normal(size=(bs, self.latent_dims))

        in_out_mask_zeros = np.zeros([bs, self.n_H_A, self.n_W_A, self.n_C])
        in_out_mask_zeros[:, 12:40, 16:48, :]=1
        in_out_mask = in_out_mask_zeros.astype(dtype=bool)

        many_samples = self.session.run(
            self.test_images_A_to_B, 
            feed_dict={self.input_test_A:X, self.z:z, self.mask:in_out_mask, self.batch_sz:bs})

        return many_samples
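
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the size dictionaries must follow the
# schemas expected by bicycleGenerator, Discriminator_minibatch and convEncoder
# in architectures/utils, and e_sizes must contain the key 'latent_dims'.
#
#     gan = bicycle_GAN(n_H_A=..., n_W_A=..., n_H_B=..., n_W_B=..., n_C=...,
#                       d_sizes=d_sizes, g_sizes_enc=g_sizes_enc,
#                       g_sizes_dec=g_sizes_dec, e_sizes=e_sizes,
#                       lr=..., beta1=..., cost_type='GAN',
#                       batch_size=..., epochs=...,
#                       save_sample=..., path=..., seed=...)
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         gan.set_session(sess)
#         gan.fit(X_A, X_B, validating_size=...)
#         reco = gan.get_samples_A_to_B(X_A[:16])
# ---------------------------------------------------------------------------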