diff --git a/pix2pix.py b/pix2pix.py index e774f0ef37099036ecdf68180c0fc90bba3a73d2..bb79786ebdfb55b2c17d85ed8c76fde834d5657b 100644 --- a/pix2pix.py +++ b/pix2pix.py @@ -58,7 +58,7 @@ EPS = 1e-12 CROP_SIZE = 256 Examples = collections.namedtuple("Examples", "paths, inputs, targets, count, steps_per_epoch") -Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train") +Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train, clip_D") def preprocess(image): @@ -478,8 +478,12 @@ def create_model(inputs, targets): discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1) discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars) discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars) - if a.loss == "wasserstein": - clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in discrim_tvars] + + if a.loss == "wasserstein": + discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")] + clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in discrim_tvars] + else: + clip_D = [] with tf.name_scope("generator_train"): with tf.control_dependencies([discrim_train]): @@ -506,6 +510,7 @@ def create_model(inputs, targets): gen_grads_and_vars=gen_grads_and_vars, outputs=outputs, train=tf.group(update_losses, incr_global_step, gen_train), + clip_D=clip_D ) @@ -803,8 +808,12 @@ def main(): if should(a.display_freq): fetches["display"] = display_fetches + fetches["clip_D"] = model.clip_D - results = sess.run(fetches, options=options, run_metadata=run_metadata) + if a.loss == "wasserstein": + results = sess.run(fetches, options=options, run_metadata=run_metadata) + else: + results = sess.run(fetches, options=options, run_metadata=run_metadata) discrim_loss.append(results["discrim_loss"]) gen_loss_GAN.append(results["gen_loss_GAN"]) gen_loss_L1.append(results["gen_loss_L1"])