diff --git a/pix2pix.py b/pix2pix.py
index e774f0ef37099036ecdf68180c0fc90bba3a73d2..bb79786ebdfb55b2c17d85ed8c76fde834d5657b 100644
--- a/pix2pix.py
+++ b/pix2pix.py
@@ -58,7 +58,7 @@ EPS = 1e-12
 CROP_SIZE = 256
 
 Examples = collections.namedtuple("Examples", "paths, inputs, targets, count, steps_per_epoch")
-Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train")
+Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train, clip_D")
 
 
 def preprocess(image):
@@ -478,8 +478,12 @@ def create_model(inputs, targets):
         discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
         discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
         discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)
-        if a.loss == "wasserstein":
-            clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in discrim_tvars]
+
+    if a.loss == "wasserstein":
+        discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
+        clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in discrim_tvars]
+    else:
+        clip_D = []
 
     with tf.name_scope("generator_train"):
         with tf.control_dependencies([discrim_train]):
@@ -506,6 +510,7 @@ def create_model(inputs, targets):
         gen_grads_and_vars=gen_grads_and_vars,
         outputs=outputs,
         train=tf.group(update_losses, incr_global_step, gen_train),
+        clip_D=clip_D
     )
 
 
@@ -803,8 +808,12 @@ def main():
 
                 if should(a.display_freq):
                     fetches["display"] = display_fetches
+                fetches["clip_D"] = model.clip_D
 
-                results = sess.run(fetches, options=options, run_metadata=run_metadata)
+                if a.loss == "wasserstein":
+                    results = sess.run(fetches, options=options, run_metadata=run_metadata)
+                else:
+                    results = sess.run(fetches, options=options, run_metadata=run_metadata)
                 discrim_loss.append(results["discrim_loss"])
                 gen_loss_GAN.append(results["gen_loss_GAN"])
                 gen_loss_L1.append(results["gen_loss_L1"])