diff --git a/pix2pix.py b/pix2pix.py index bb79786ebdfb55b2c17d85ed8c76fde834d5657b..16f7105696bf15443f2aeac9307ea69471dd83e7 100644 --- a/pix2pix.py +++ b/pix2pix.py @@ -47,8 +47,9 @@ parser.add_argument("--lr", type=float, default=0.0002, help="initial learning r parser.add_argument("--beta1", type=float, default=0.5, help="momentum term of adam") parser.add_argument("--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient") parser.add_argument("--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient") -parser.add_argument("--loss", type=str, default="binary_crossentropy", choices=["binary_crossentropy", "wasserstein", "wasserstein-gp"]) +parser.add_argument("--loss", type=str, default="binary_crossentropy", choices=["binary_crossentropy", "wasserstein"]) parser.add_argument('--lambda_', type=float, default=10., help='gradient penalty lambda hyperparameter, default: 10.') +parser.add_argument('--clip', type=float, default=0.01, help='weight clipping for wasserstein loss') # export options parser.add_argument("--output_filetype", default="png", choices=["png", "jpeg"]) @@ -481,7 +482,7 @@ def create_model(inputs, targets): if a.loss == "wasserstein": discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")] - clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in discrim_tvars] + clip_D = [p.assign(tf.clip_by_value(p, -c, c)) for p in discrim_tvars] else: clip_D = []