# HCAL_project/architectures/pix2pix.py
# Davide Lancierini, 2 Dec 2018 (14 KB) - "First commit"
  1. import numpy as np
  2. import os
  3. import math
  4.  
  5. import tensorflow as tf
  6. import matplotlib.pyplot as plt
  7. from datetime import datetime
  8.  
  9. from architectures.utils.NN_building_blocks import *
  10. from architectures.utils.NN_gen_building_blocks import *
  11. from architectures.utils.toolbox import *
  12.  
  13.  
  14. #some hyperparameters of the network
  15. LEARNING_RATE = None
  16. BETA1 = None
  17. COST_TYPE=None
  18. BATCH_SIZE = None
  19. EPOCHS = None
  20. SAVE_SAMPLE_PERIOD = None
  21. PATH = None
  22. SEED = None
  23. rnd_seed=1
  24. preprocess=None
  25. LAMBDA=.01
  26. EPS=1e-10
  27. CYCL_WEIGHT=None
  28. GAN_WEIGHT=None
  29. DISCR_STEPS=None
  30. GEN_STEPS=None
  31.  
  32. min_true=None
  33. max_true=None
  34.  
  35. min_reco=None
  36. max_reco=None
  37.  
  38. n_H_A=None
  39. n_W_A=None
  40. n_W_B=None
  41. n_H_B=None
  42. n_C=None
  43.  
  44. d_sizes=None
  45. g_sizes_enc=None
  46. g_sizes_dec=None
  47.  
  48.  
  49.  
  50. class pix2pix(object):
  51.  
  52. def __init__(
  53.  
  54. self,
  55. n_H_A=n_H_A, n_W_A=n_W_A,
  56. n_H_B=n_H_B, n_W_B=n_W_B, n_C=n_C,
  57. min_true=min_true, max_true=max_true,
  58. min_reco=min_reco, max_reco=max_reco,
  59. d_sizes=d_sizes, g_sizes_enc=g_sizes_enc, g_sizes_dec=g_sizes_dec,
  60. lr=LEARNING_RATE, beta1=BETA1, preprocess=preprocess,
  61. cost_type=COST_TYPE, gan_weight=GAN_WEIGHT, cycl_weight=CYCL_WEIGHT,
  62. discr_steps=DISCR_STEPS, gen_steps=GEN_STEPS,
  63. batch_size=BATCH_SIZE, epochs=EPOCHS,
  64. save_sample=SAVE_SAMPLE_PERIOD, path=PATH, seed=SEED,
  65.  
  66. ):
  67.  
  68. """
  69.  
  70. Positional arguments:
  71.  
  72. - width of (square) image
  73. - number of channels of input image
  74. - discriminator sizes
  75.  
  76. a python dict of the kind
  77. d_sizes = { 'convblocklayer_n':[(n_c+1, kernel, stride, apply_batch_norm, weight initializer),
  78. (,,,,),
  79. (,,,,),
  80. ],
  81. 'convblock_shortcut_layer_n':[(,,,)],
  82. 'dense_layers':[(n_o, apply_bn, weight_init)]
  83. }
  84. - generator sizes
  85.  
  86. a python dictionary of the kind
  87.  
  88. g_sizes = {
  89. 'z':latent_space_dim,
  90. 'projection': int,
  91. 'bn_after_project':bool
  92.  
  93. 'deconvblocklayer_n':[(n_c+1, kernel, stride, apply_batch_norm, weight initializer),
  94. (,,,,),
  95. (,,,,),
  96. ],
  97. 'deconvblock_shortcut_layer_n':[(,,,)],
  98. 'dense_layers':[(n_o, apply_bn, weight_init)]
  99. 'activation':function
  100. }
  101.  
  102. Keyword arguments:
  103.  
  104. - lr = LEARNING_RATE (float32)
  105. - beta1 = ema parameter for adam opt (float32)
  106. - batch_size (int)
  107. - save_sample = after how many batches iterations to save a sample (int)
  108. - path = relative path for saving samples
  109.  
  110. """
  111.  
  112.  
  113. self.max_reco=max_reco
  114. self.min_reco = min_reco
  115.  
  116. self.max_true=max_true
  117. self.min_true=min_true
  118.  
  119. self.seed=seed
  120.  
  121. self.n_W_A = n_W_A
  122. self.n_H_A = n_H_A
  123.  
  124. self.n_W_B = n_W_B
  125. self.n_H_B = n_H_B
  126.  
  127. self.n_C = n_C
  128.  
  129. self.batch_sz = tf.placeholder(
  130. tf.int32,
  131. shape=(),
  132. name='batch_sz'
  133. )
  134.  
  135. self.input_A = tf.placeholder(
  136. tf.float32,
  137. shape=(None,
  138. n_H_A, n_W_A, n_C),
  139. name='X_A',
  140. )
  141.  
  142. self.input_B = tf.placeholder(
  143. tf.float32,
  144. shape=(None,
  145. n_H_B, n_W_B, n_C),
  146. name='X_B',
  147. )
  148.  
  149. self.input_test_A = tf.placeholder(
  150. tf.float32,
  151. shape=(None,
  152. n_H_A, n_W_A, n_C),
  153. name='X_test_A'
  154. )
  155.  
  156. D = pix2pixDiscriminator(self.input_A, d_sizes, 'B')
  157. G = pix2pixGenerator(self.input_A, n_H_B, n_W_B, g_sizes_enc, g_sizes_dec, 'A_to_B')
  158.  
  159. with tf.variable_scope('generator_A_to_B') as scope:
  160.  
  161. sample_images = G.g_forward(self.input_A)
  162.  
  163. with tf.variable_scope('discriminator_B') as scope:
  164.  
  165. predicted_real = D.d_forward(self.input_A, self.input_B)
  166. with tf.variable_scope('discriminator_B') as scope:
  167. scope.reuse_variables()
  168. predicted_fake = D.d_forward(self.input_A, sample_images, reuse=True)
  169.  
  170. #get sample images at test time
  171. with tf.variable_scope('generator_A_to_B') as scope:
  172. scope.reuse_variables()
  173. self.sample_images_test_A_to_B = G.g_forward(
  174. self.input_test_A, reuse=True, is_training=False
  175. )
  176.  
  177. self.d_params = [t for t in tf.trainable_variables() if 'discriminator' in t.name]
  178. self.g_params = [t for t in tf.trainable_variables() if 'generator' in t.name]
  179.  
  180.  
  181. if cost_type == 'GAN':
  182.  
  183. #Discriminator cost
  184. predicted_real= tf.nn.sigmoid(predicted_real)
  185. predicted_fake=tf.nn.sigmoid(predicted_fake)
  186.  
  187. d_cost_real = tf.log(predicted_real + EPS)
  188.  
  189. #d_cost_fake is low if fake images are predicted as real
  190. d_cost_fake = tf.log(1 - predicted_fake +EPS)
  191.  
  192. self.d_cost = tf.reduce_mean(-(d_cost_real + d_cost_fake))
  193.  
  194.  
  195. # #Discriminator cost
  196. # #d_cost_real is low if real images are predicted as real
  197. # d_cost_real = tf.nn.sigmoid_cross_entropy_with_logits(
  198. # logits = predicted_real,
  199. # labels = tf.ones_like(predicted_real)-0.01
  200. # )
  201. # #d_cost_fake is low if fake images are predicted as real
  202. # d_cost_fake = tf.nn.sigmoid_cross_entropy_with_logits(
  203. # logits = predicted_fake,
  204. # labels = tf.zeros_like(predicted_fake)+0.01
  205. # )
  206.  
  207. # self.d_cost = tf.reduce_mean(d_cost_real)+ tf.reduce_mean(d_cost_fake)
  208.  
  209. #Generator cost
  210. #g_cost is low if logits from discriminator on samples generated by generator are predicted as true (1)
  211. self.g_cost_GAN = tf.reduce_mean(-tf.log(predicted_fake + EPS))
  212.  
  213. # self.g_cost_GAN = tf.reduce_mean(
  214. # tf.nn.sigmoid_cross_entropy_with_logits(
  215. # logits=predicted_fake,
  216. # labels=tf.ones_like(predicted_fake)-0.01
  217. # )
  218. # )
  219. self.g_cost_l1 = tf.reduce_mean(tf.square(self.input_B - sample_images))
  220. #self.g_cost_sum = tf.abs(tf.reduce_sum(self.input_B)-tf.reduce_sum(sample_images))
  221. self.g_cost=gan_weight*self.g_cost_GAN + cycl_weight*self.g_cost_l1
  222.  
  223. if cost_type == 'WGAN-gp':
  224.  
  225.  
  226. self.g_cost_GAN = -tf.reduce_mean(predicted_fake)
  227.  
  228. self.g_cost_l1 = tf.reduce_mean(tf.abs(self.input_B - sample_images))
  229. self.g_cost=gan_weight*self.g_cost_GAN+cycl_weight*self.g_cost_l1
  230.  
  231.  
  232. self.d_cost = tf.reduce_mean(predicted_fake) - tf.reduce_mean(predicted_real)
  233. alpha = tf.random_uniform(
  234. shape=[self.batch_sz,self.n_H_A,self.n_W_A,self.n_C],
  235. minval=0.,
  236. maxval=1.
  237. )
  238.  
  239. # interpolates_1 = alpha*self.input_A+(1-alpha)*sample_images
  240. interpolates = alpha*self.input_B+(1-alpha)*sample_images
  241.  
  242. with tf.variable_scope('discriminator_B') as scope:
  243. scope.reuse_variables()
  244. disc_interpolates = D.d_forward(self.input_A, interpolates,reuse = True)
  245.  
  246. gradients = tf.gradients(disc_interpolates,[interpolates])[0]
  247. slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
  248. self.gradient_penalty = tf.reduce_mean((slopes-1)**2)
  249. self.d_cost+=LAMBDA*self.gradient_penalty
  250.  
  251. self.d_train_op = tf.train.AdamOptimizer(
  252. learning_rate=lr,
  253. beta1=beta1,
  254. beta2=0.9
  255. ).minimize(
  256. self.d_cost,
  257. var_list=self.d_params
  258. )
  259.  
  260. self.g_train_op = tf.train.AdamOptimizer(
  261. learning_rate=lr,
  262. beta1=beta1,
  263. beta2=0.9
  264. ).minimize(
  265. self.g_cost,
  266. var_list=self.g_params
  267. )
  268.  
  269. #saving for later
  270. self.batch_size=batch_size
  271. self.epochs=epochs
  272. self.save_sample=save_sample
  273. self.path=path
  274. self.lr = lr
  275.  
  276. self.D=D
  277. self.G=G
  278.  
  279. self.sample_images=sample_images
  280. self.preprocess=preprocess
  281. self.cost_type=cost_type
  282. self.cycl_weight=cycl_weight
  283.  
  284. self.gen_steps=gen_steps
  285. self.discr_steps=discr_steps
  286.  
  287. def set_session(self,session):
  288.  
  289. self.session = session
  290.  
  291. for layer in self.D.d_conv_layers:
  292. layer.set_session(session)
  293.  
  294. for layer in self.G.g_enc_conv_layers:
  295. layer.set_session(session)
  296.  
  297. for layer in self.G.g_dec_conv_layers:
  298. layer.set_session(session)
  299.  
  300. def fit(self, X_A, X_B, validating_size):
  301.  
  302. all_A = X_A
  303. all_B = X_B
  304. gen_steps = self.gen_steps
  305. discr_steps = self.discr_steps
  306.  
  307. m = X_A.shape[0]
  308. train_A = all_A[0:m-validating_size]
  309. train_B = all_B[0:m-validating_size]
  310.  
  311. validating_A = all_A[m-validating_size:m]
  312. validating_B = all_B[m-validating_size:m]
  313.  
  314. seed=self.seed
  315.  
  316. d_costs=[]
  317. d_gps=[]
  318. g_costs=[]
  319. g_GANs=[]
  320. g_l1s=[]
  321. N=len(train_A)
  322. n_batches = N // self.batch_size
  323.  
  324. total_iters=0
  325.  
  326. print('\n ****** \n')
  327. print('Training pix2pix (from 1611.07004) GAN with a total of ' +str(N)+' samples distributed in '+ str((N)//self.batch_size) +' batches of size '+str(self.batch_size)+'\n')
  328. print('The validation set consists of {0} images'.format(validating_A.shape[0]))
  329. print('The learning rate is '+str(self.lr)+', and every ' +str(self.save_sample)+ ' batches a generated sample will be saved to '+ self.path)
  330. print('\n ****** \n')
  331.  
  332. for epoch in range(self.epochs):
  333.  
  334. seed+=1
  335. print('Epoch:', epoch)
  336.  
  337. batches_A = unsupervised_random_mini_batches(train_A, self.batch_size, seed)
  338. batches_B = unsupervised_random_mini_batches(train_B, self.batch_size, seed)
  339.  
  340. for X_batch_A, X_batch_B in zip(batches_A, batches_B):
  341.  
  342. bs=X_batch_A.shape[0]
  343.  
  344. t0 = datetime.now()
  345.  
  346. g_cost=0
  347. g_GAN=0
  348. g_l1=0
  349. for i in range(gen_steps):
  350.  
  351. _, g_cost, g_GAN, g_l1 = self.session.run(
  352. (self.g_train_op, self.g_cost, self.g_cost_GAN, self.g_cost_l1),
  353. feed_dict={self.input_A:X_batch_A, self.input_B:X_batch_B, self.batch_sz:bs},
  354. )
  355. g_cost+=g_cost
  356. g_GAN+=g_GAN
  357. g_l1+=g_l1
  358.  
  359. g_costs.append(g_cost/gen_steps)
  360. g_GANs.append(g_GAN/gen_steps)
  361. g_l1s.append(self.cycl_weight*g_l1/gen_steps)
  362.  
  363. d_cost=0
  364. d_gp=0
  365. for i in range(discr_steps):
  366.  
  367. if self.cost_type=='WGAN-gp':
  368. _, d_cost, d_gp = self.session.run(
  369. (self.d_train_op, self.d_cost, self.gradient_penalty),
  370. feed_dict={self.input_A:X_batch_A, self.input_B:X_batch_B, self.batch_sz:bs},
  371. )
  372.  
  373. d_gp+=d_gp
  374.  
  375.  
  376. else:
  377. _, d_cost = self.session.run(
  378. (self.d_train_op, self.d_cost),
  379. feed_dict={self.input_A:X_batch_A, self.input_B:X_batch_B, self.batch_sz:bs},
  380. )
  381. d_cost+=d_cost
  382.  
  383. d_costs.append(d_cost/discr_steps)
  384. if self.cost_type=='WGAN-gp':
  385. d_gps.append(LAMBDA*d_gp/discr_steps)
  386.  
  387. total_iters+=1
  388. if total_iters % self.save_sample ==0:
  389.  
  390. plt.clf()
  391. print("At iter: %d - dt: %s" % (total_iters, datetime.now() - t0))
  392. print("Discriminator cost {0:.4g}, Generator cost {1:.4g}".format(d_costs[-1], g_costs[-1]))
  393. print('Saving a sample...')
  394.  
  395. if self.preprocess!=False:
  396. draw_nn_sample(validating_A, validating_B, 1, self.preprocess,
  397. self.min_true, self.max_true,
  398. self.min_reco, self.max_reco,
  399. f=self.get_sample_A_to_B, is_training=True,
  400. total_iters=total_iters, PATH=self.path)
  401. else:
  402. draw_nn_sample(validating_A, validating_B, 1, self.preprocess,
  403. f=self.get_sample_A_to_B, is_training=True,
  404. total_iters=total_iters, PATH=self.path)
  405. plt.clf()
  406. plt.subplot(1,2,1)
  407. plt.plot(d_costs, label='Discriminator GAN cost')
  408. plt.plot(g_GANs, label='Generator GAN cost')
  409. plt.xlabel('Epoch')
  410. plt.ylabel('Cost')
  411. plt.legend()
  412. plt.subplot(1,2,2)
  413. plt.plot(g_costs, label='Generator total cost')
  414. plt.plot(g_GANs, label='Generator GAN cost')
  415. plt.plot(g_l1s, label='Generator l1 cycle cost')
  416. plt.xlabel('Epoch')
  417. plt.ylabel('Cost')
  418. plt.legend()
  419.  
  420. fig = plt.gcf()
  421. fig.set_size_inches(15,5)
  422. plt.savefig(self.path+'/cost_iteration_gen_disc_B_to_A.png',dpi=150)
  423.  
  424.  
  425. def get_sample_A_to_B(self, Z):
  426. one_sample = self.session.run(
  427. self.sample_images_test_A_to_B,
  428. feed_dict={self.input_test_A:Z, self.batch_sz: 1})
  429.  
  430. return one_sample
  431.  
  432. def get_samples_A_to_B(self, Z):
  433. many_samples = self.session.run(
  434. self.sample_images_test_A_to_B,
  435. feed_dict={self.input_test_A:Z, self.batch_sz: Z.shape[0]})
  436.  
  437. return many_samples