diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..1e38629ce5049cbde4d10de60755088bc0143cea
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+tags
+test
diff --git a/README.md b/README.md
index 0b0b1cee5e9e7b3717b56b93c8fc5fccd6e52490..15f57eabe0036a177fdd6679f12664c21b22107e 100644
--- a/README.md
+++ b/README.md
@@ -27,14 +27,51 @@ cd pix2pix-tensorflow
 # download the CMP Facades dataset http://cmp.felk.cvut.cz/~tylecr1/facade/
 python tools/download-dataset.py facades
 # train the model (this may take 1-8 hours depending on GPU, on CPU you will be waiting for a bit)
-python pix2pix.py --mode train --output_dir facades_train --max_epochs 200 --input_dir facades/train --which_direction BtoA
+python pix2pix.py \
+  --mode train \
+  --output_dir facades_train \
+  --max_epochs 200 \
+  --input_dir facades/train \
+  --which_direction BtoA
 # test the model
-python pix2pix.py --mode test --output_dir facades_test --input_dir facades/val --checkpoint facades_train
+python pix2pix.py \
+  --mode test \
+  --output_dir facades_test \
+  --input_dir facades/val \
+  --checkpoint facades_train
 ```
 
 The test run will output an HTML file at `facades_test/index.html` that shows input/output/target image sets.
 
-## Datasets
+If you have Docker installed, you can use the provided Docker image to run pix2pix without having to install the correct version of Tensorflow yourself:
+
+```sh
+# train the model
+sudo nvidia-docker run \
+  --volume $PWD:/prj \
+  --workdir /prj \
+  --env PYTHONUNBUFFERED=x \
+  affinelayer/pix2pix-tensorflow \
+    python pix2pix.py \
+      --mode train \
+      --output_dir facades_train \
+      --max_epochs 200 \
+      --input_dir facades/train \
+      --which_direction BtoA
+# test the model
+sudo nvidia-docker run \
+  --volume $PWD:/prj \
+  --workdir /prj \
+  --env PYTHONUNBUFFERED=x \
+  affinelayer/pix2pix-tensorflow \
+    python pix2pix.py \
+      --mode test \
+      --output_dir facades_test \
+      --input_dir facades/val \
+      --checkpoint facades_train
+```
+
+## Datasets and Trained Models
 
 The data format used by this program is the same as the original pix2pix format, which consists of images of input and desired output side by side like:
 
@@ -44,15 +81,15 @@ For example:
 
 <img src="docs/418.png" width="256px"/>
 
-Some datasets have been made available by the authors of the pix2pix paper.  To download those datasets, use the included script `tools/download-dataset.py`.
+Some datasets have been made available by the authors of the pix2pix paper.  To download those datasets, use the included script `tools/download-dataset.py`.  There are also links to pre-trained models alongside each dataset:
 
-| dataset | image |
+| dataset | example |
 | --- | --- |
-| `python tools/download-dataset.py facades` <br> 400 images from [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/). (31MB)  | <img src="docs/facades.jpg" width="256px"/> |
-| `python tools/download-dataset.py cityscapes` <br> 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/). (113M) | <img src="docs/cityscapes.jpg" width="256px"/> |
-| `python tools/download-dataset.py maps` <br> 1096 training images scraped from Google Maps (246M) | <img src="docs/maps.jpg" width="256px"/> |
-| `python tools/download-dataset.py edges2shoes` <br> 50k training images from [UT Zappos50K dataset](http://vision.cs.utexas.edu/projects/finegrained/utzap50k/). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing. (2.2GB) | <img src="docs/edges2shoes.jpg" width="256px"/>  |
-| `python tools/download-dataset.py edges2handbags` <br> 137K Amazon Handbag images from [iGAN project](https://github.com/junyanz/iGAN). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing. (8.6GB) | <img src="docs/edges2handbags.jpg" width="256px"/> |
+| `python tools/download-dataset.py facades` <br> 400 images from [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/). (31MB) <br> Pre-trained: [BtoA](https://mega.nz/#!2xpyQBoK!GVtkZN7lqY4aaZltMFdZsPNVE6bUsWyiVUN6RwJtIxQ)  | <img src="docs/facades.jpg" width="256px"/> |
+| `python tools/download-dataset.py cityscapes` <br> 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/). (113M) <br> Pre-trained: [AtoB](https://mega.nz/#!rxByxK6S!W9ZBUqgdGTFDWVlOE_ljVt1G3bU89bdu_nS9Bi1ujiA) [BtoA](https://mega.nz/#!b1olDbhL!mxsYC5AF_WH64CXoukN0KB-nw15kLQ0Etii-F-HDTps) | <img src="docs/cityscapes.jpg" width="256px"/> |
+| `python tools/download-dataset.py maps` <br> 1096 training images scraped from Google Maps (246M) <br> Pre-trained: [AtoB](https://mega.nz/#!i8pkkBJT!3NKLar9sUr-Vh_vNVQF-xwK9-D9iCqaCmj1T27xRf4w) [BtoA](https://mega.nz/#!r8xwCBCD!lNBrY_2QO6pyUJziGj7ikPheUL_yXA8xGXFlM3GPL3c) | <img src="docs/maps.jpg" width="256px"/> |
+| `python tools/download-dataset.py edges2shoes` <br> 50k training images from [UT Zappos50K dataset](http://vision.cs.utexas.edu/projects/finegrained/utzap50k/). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing. (2.2GB) <br> Pre-trained: [AtoB](https://mega.nz/#!OoYT3QiQ!8y3zLESvhOyeA6UsjEbcJphi3_uEt534waSL5_f_D4Y) | <img src="docs/edges2shoes.jpg" width="256px"/>  |
+| `python tools/download-dataset.py edges2handbags` <br> 137K Amazon Handbag images from [iGAN project](https://github.com/junyanz/iGAN). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing. (8.6GB) <br> Pre-trained: [AtoB](https://mega.nz/#!KlpBHKrZ!iJ3x6xzgk0wnJkPiAf0UxPzhYSmpC3kKH1DY5n_dd0M) | <img src="docs/edges2handbags.jpg" width="256px"/> |
 
 The `facades` dataset is the smallest and easiest to get started with.
 
@@ -64,13 +101,24 @@ The `facades` dataset is the smallest and easiest to get started with.
 
 ```sh
 # Resize source images
-python tools/process.py --input_dir photos/original --operation resize --output_dir photos/resized
+python tools/process.py \
+  --input_dir photos/original \
+  --operation resize \
+  --output_dir photos/resized
 # Create images with blank centers
-python tools/process.py --input_dir photos/resized --operation blank --output_dir photos/blank
+python tools/process.py \
+  --input_dir photos/resized \
+  --operation blank \
+  --output_dir photos/blank
 # Combine resized images with blanked images
-python tools/process.py --input_dir photos/resized --b_dir photos/blank --operation combine --output_dir photos/combined
+python tools/process.py \
+  --input_dir photos/resized \
+  --b_dir photos/blank \
+  --operation combine \
+  --output_dir photos/combined
 # Split into train/val set
-python tools/split.py --dir photos/combined
+python tools/split.py \
+  --dir photos/combined
 ```
 
 The folder `photos/combined` will now have `train` and `val` subfolders that you can use for training and testing.
@@ -80,7 +128,11 @@ The folder `photos/combined` will now have `train` and `val` subfolders that you
 If you have two directories `a` and `b`, with corresponding images (same name, same dimensions, different data) you can combine them with `process.py`:
 
 ```sh
-python tools/process.py --input_dir a --b_dir b --operation combine --output_dir c
+python tools/process.py \
+  --input_dir a \
+  --b_dir b \
+  --operation combine \
+  --output_dir c
 ```
 
 This puts the images in a side-by-side combined image that `pix2pix.py` expects.
@@ -89,7 +141,10 @@ This puts the images in a side-by-side combined image that `pix2pix.py` expects.
 
 For colorization, your images should ideally all be the same aspect ratio.  You can resize and crop them with the resize command:
 ```sh
-python tools/process.py --input_dir photos/original --operation resize --output_dir photos/resized
+python tools/process.py \
+  --input_dir photos/original \
+  --operation resize \
+  --output_dir photos/resized
 ```
 
 No other processing is required; the colorization mode (see Training section below) uses single images instead of image pairs.
@@ -100,7 +155,12 @@ No other processing is required, the colorzation mode (see Training section belo
 
 For normal training with image pairs, you need to specify which directory contains the training images and which direction to train on.  The direction options are `AtoB` or `BtoA`.
 ```sh
-python pix2pix.py --mode train --output_dir facades_train --max_epochs 200 --input_dir facades/train --which_direction BtoA
+python pix2pix.py \
+  --mode train \
+  --output_dir facades_train \
+  --max_epochs 200 \
+  --input_dir facades/train \
+  --which_direction BtoA
 ```
 
 ### Colorization
@@ -108,7 +168,12 @@ python pix2pix.py --mode train --output_dir facades_train --max_epochs 200 --inp
 `pix2pix.py` includes special code to handle colorization with single images instead of pairs; using it looks like this:
 
 ```sh
-python pix2pix.py --mode train --output_dir photos_train --max_epochs 200 --input_dir photos/train --lab_colorization
+python pix2pix.py \
+  --mode train \
+  --output_dir photos_train \
+  --max_epochs 200 \
+  --input_dir photos/train \
+  --lab_colorization
 ```
 
 In this mode, image A is the black and white image (lightness only), and image B contains the color channels of that image (no lightness information).
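+
+As a rough illustration of this split (using scikit-image, which is not a dependency of this project, and a hypothetical image path):
+
+```python
+from skimage import color, io
+
+rgb = io.imread("photos/train/example.png")  # hypothetical example image
+lab = color.rgb2lab(rgb)                     # L in [0, 100], a/b roughly in [-110, 110]
+A = lab[:, :, 0]   # lightness channel only -> what pix2pix sees as image A
+B = lab[:, :, 1:]  # a/b color channels only -> what pix2pix sees as image B
+```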
@@ -129,7 +194,11 @@ If you wish to write in-progress pictures as the network is training, use `--dis
 Testing is done with `--mode test`.  You should specify the checkpoint to use with `--checkpoint`; this should point to the `output_dir` that you created previously with `--mode train`:
 
 ```sh
-python pix2pix.py --mode test --output_dir facades_test --input_dir facades/val --checkpoint facades_train
+python pix2pix.py \
+  --mode test \
+  --output_dir facades_test \
+  --input_dir facades/val \
+  --checkpoint facades_train
 ```
 
 The testing mode will load some of the configuration options from the provided checkpoint, so you do not need to specify options such as `which_direction`.
@@ -138,16 +207,35 @@ The test run will output an HTML file at `facades_test/index.html` that shows in
 
 <img src="docs/test-html.png" width="300px"/>
 
-## Implementation Validation
+## Code Validation
 
-Validation of the code was performed on a Linux machine with a ~1.3 TFLOPS Nvidia GTX 750 Ti GPU.  Due to a lack of compute power, validation is not extensive and only the `facades` dataset at 200 epochs was tested.
+Validation of the code was performed on a Linux machine with a ~1.3 TFLOPS Nvidia GTX 750 Ti GPU and an Azure NC6 instance with a K80 GPU.
 
 ```sh
 git clone https://github.com/affinelayer/pix2pix-tensorflow.git
 cd pix2pix-tensorflow
 python tools/download-dataset.py facades
-time nvidia-docker run --volume $PWD:/prj --workdir /prj --env PYTHONUNBUFFERED=x affinelayer/tensorflow:pix2pix python pix2pix.py --mode train --output_dir facades_train --max_epochs 200 --input_dir facades/train --which_direction BtoA
-nvidia-docker run --volume $PWD:/prj --workdir /prj --env PYTHONUNBUFFERED=x affinelayer/tensorflow:pix2pix python pix2pix.py --mode test --output_dir facades_test --input_dir facades/val --checkpoint facades_train
+sudo nvidia-docker run \
+  --volume $PWD:/prj \
+  --workdir /prj \
+  --env PYTHONUNBUFFERED=x \
+  affinelayer/pix2pix-tensorflow \
+    python pix2pix.py \
+      --mode train \
+      --output_dir facades_train \
+      --max_epochs 200 \
+      --input_dir facades/train \
+      --which_direction BtoA
+sudo nvidia-docker run \
+  --volume $PWD:/prj \
+  --workdir /prj \
+  --env PYTHONUNBUFFERED=x \
+  affinelayer/pix2pix-tensorflow \
+    python pix2pix.py \
+      --mode test \
+      --output_dir facades_test \
+      --input_dir facades/val \
+      --checkpoint facades_train
 ```
 
 Comparison on facades dataset:
diff --git a/pix2pix.py b/pix2pix.py
index a556bbdcf0e67b42c5ede6a21e9b6c9553b49fff..aa8b74e7490b2abd3e61acb777475d88de8a059a 100644
--- a/pix2pix.py
+++ b/pix2pix.py
@@ -14,15 +14,15 @@ import math
 import time
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--input_dir", required=True, help="path to folder containing images")
-parser.add_argument("--mode", required=True, choices=["train", "test"])
+parser.add_argument("--input_dir", help="path to folder containing images")
+parser.add_argument("--mode", required=True, choices=["train", "test", "export"])
 parser.add_argument("--output_dir", required=True, help="where to put output files")
 parser.add_argument("--seed", type=int)
 parser.add_argument("--checkpoint", default=None, help="directory with checkpoint to resume training from or use for testing")
 
 parser.add_argument("--max_steps", type=int, help="number of training steps (0 to disable)")
 parser.add_argument("--max_epochs", type=int, help="number of training epochs")
-parser.add_argument("--summary_freq", type=int, default=10, help="update summaries every summary_freq steps")
+parser.add_argument("--summary_freq", type=int, default=100, help="update summaries every summary_freq steps")
 parser.add_argument("--progress_freq", type=int, default=50, help="display progress every progress_freq steps")
 # to get tracing working on GPU, LD_LIBRARY_PATH may need to be modified:
 # LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/extras/CUPTI/lib64
@@ -31,7 +31,7 @@ parser.add_argument("--display_freq", type=int, default=0, help="write current t
 parser.add_argument("--save_freq", type=int, default=5000, help="save model every save_freq steps, 0 to disable")
 
 parser.add_argument("--aspect_ratio", type=float, default=1.0, help="aspect ratio of output images (width/height)")
-parser.add_argument("--lab_colorization", action="store_true", help="split A image into brightness (A) and color (B), ignore B image")
+parser.add_argument("--lab_colorization", action="store_true", help="image into brightness (A) and color (B)")
 parser.add_argument("--batch_size", type=int, default=1, help="number of images in batch")
 parser.add_argument("--which_direction", type=str, default="AtoB", choices=["AtoB", "BtoA"])
 parser.add_argument("--ngf", type=int, default=64, help="number of generator filters in first conv layer")
@@ -50,7 +50,43 @@ EPS = 1e-12
 CROP_SIZE = 256
 
 Examples = collections.namedtuple("Examples", "paths, inputs, targets, count, steps_per_epoch")
-Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, gen_loss_GAN, gen_loss_L1, train")
+Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train")
+
+
+def preprocess(image):
+    with tf.name_scope("preprocess"):
+        # [0, 1] => [-1, 1]
+        return image * 2 - 1
+
+
+def deprocess(image):
+    with tf.name_scope("deprocess"):
+        # [-1, 1] => [0, 1]
+        return (image + 1) / 2
+
+
+def preprocess_lab(lab):
+    with tf.name_scope("preprocess_lab"):
+        L_chan, a_chan, b_chan = tf.unstack(lab, axis=2)
+        # L_chan: black and white with input range [0, 100]
+        # a_chan/b_chan: color channels with input range ~[-110, 110], not exact
+        # [0, 100] => [-1, 1],  ~[-110, 110] => [-1, 1]
+        return [L_chan / 50 - 1, a_chan / 110, b_chan / 110]
+
+
+def deprocess_lab(L_chan, a_chan, b_chan):
+    with tf.name_scope("deprocess_lab"):
+        # this is axis=3 instead of axis=2 because we process individual images but deprocess batches
+        return tf.stack([(L_chan + 1) / 2 * 100, a_chan * 110, b_chan * 110], axis=3)
+
+
+def augment(image, brightness):
+    # (a, b) color channels, combine with L channel and convert to rgb
+    a_chan, b_chan = tf.unstack(image, axis=3)
+    L_chan = tf.squeeze(brightness, axis=3)
+    lab = deprocess_lab(L_chan, a_chan, b_chan)
+    rgb = lab_to_rgb(lab)
+    return rgb
 
 
 def conv(batch_input, out_channels, stride):
@@ -199,7 +235,7 @@ def lab_to_rgb(lab):
 
 
 def load_examples():
-    if not os.path.exists(a.input_dir):
+    if a.input_dir is None or not os.path.exists(a.input_dir):
         raise Exception("input_dir does not exist")
 
     input_paths = glob.glob(os.path.join(a.input_dir, "*.jpg"))
@@ -238,14 +274,14 @@ def load_examples():
         if a.lab_colorization:
             # load color and brightness from image, no B image exists here
             lab = rgb_to_lab(raw_input)
-            L_chan, a_chan, b_chan = tf.unstack(lab, axis=2)
-            a_images = tf.expand_dims(L_chan, axis=2) / 50 - 1 # black and white with input range [0, 100]
-            b_images = tf.stack([a_chan, b_chan], axis=2) / 110 # color channels with input range ~[-110, 110], not exact
+            L_chan, a_chan, b_chan = preprocess_lab(lab)
+            a_images = tf.expand_dims(L_chan, axis=2)
+            b_images = tf.stack([a_chan, b_chan], axis=2)
         else:
             # break apart image pair and move to range [-1, 1]
             width = tf.shape(raw_input)[1] # [height, width, channels]
-            a_images = raw_input[:,:width//2,:] * 2 - 1
-            b_images = raw_input[:,width//2:,:] * 2 - 1
+            a_images = preprocess(raw_input[:,:width//2,:])
+            b_images = preprocess(raw_input[:,width//2:,:])
 
     if a.which_direction == "AtoB":
         inputs, targets = [a_images, b_images]
@@ -279,86 +315,87 @@ def load_examples():
     with tf.name_scope("target_images"):
         target_images = transform(targets)
 
-    paths, inputs, targets = tf.train.batch([paths, input_images, target_images], batch_size=a.batch_size)
+    paths_batch, inputs_batch, targets_batch = tf.train.batch([paths, input_images, target_images], batch_size=a.batch_size)
     steps_per_epoch = int(math.ceil(len(input_paths) / a.batch_size))
 
     return Examples(
-        paths=paths,
-        inputs=inputs,
-        targets=targets,
+        paths=paths_batch,
+        inputs=inputs_batch,
+        targets=targets_batch,
         count=len(input_paths),
         steps_per_epoch=steps_per_epoch,
     )
 
 
-def create_model(inputs, targets):
-    def create_generator(generator_inputs, generator_outputs_channels):
-        layers = []
-
-        # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
-        with tf.variable_scope("encoder_1"):
-            output = conv(generator_inputs, a.ngf, stride=2)
+def create_generator(generator_inputs, generator_outputs_channels):
+    layers = []
+
+    # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
+    with tf.variable_scope("encoder_1"):
+        output = conv(generator_inputs, a.ngf, stride=2)
+        layers.append(output)
+
+    layer_specs = [
+        a.ngf * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
+        a.ngf * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
+        a.ngf * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
+        a.ngf * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
+        a.ngf * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
+        a.ngf * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
+        a.ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
+    ]
+
+    for out_channels in layer_specs:
+        with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
+            rectified = lrelu(layers[-1], 0.2)
+            # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
+            convolved = conv(rectified, out_channels, stride=2)
+            output = batchnorm(convolved)
             layers.append(output)
 
-        layer_specs = [
-            a.ngf * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
-            a.ngf * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
-            a.ngf * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
-            a.ngf * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
-            a.ngf * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
-            a.ngf * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
-            a.ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
-        ]
-
-        for out_channels in layer_specs:
-            with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
-                rectified = lrelu(layers[-1], 0.2)
-                # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
-                convolved = conv(rectified, out_channels, stride=2)
-                output = batchnorm(convolved)
-                layers.append(output)
-
-        layer_specs = [
-            (a.ngf * 8, 0.5),   # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
-            (a.ngf * 8, 0.5),   # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
-            (a.ngf * 8, 0.5),   # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
-            (a.ngf * 8, 0.0),   # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
-            (a.ngf * 4, 0.0),   # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
-            (a.ngf * 2, 0.0),   # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
-            (a.ngf, 0.0),       # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
-        ]
-
-        num_encoder_layers = len(layers)
-        for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
-            skip_layer = num_encoder_layers - decoder_layer - 1
-            with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
-                if decoder_layer == 0:
-                    # first decoder layer doesn't have skip connections
-                    # since it is directly connected to the skip_layer
-                    input = layers[-1]
-                else:
-                    input = tf.concat_v2([layers[-1], layers[skip_layer]], axis=3)
-
-                rectified = tf.nn.relu(input)
-                # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
-                output = deconv(rectified, out_channels)
-                output = batchnorm(output)
-
-                if dropout > 0.0:
-                    output = tf.nn.dropout(output, keep_prob=1 - dropout)
-
-                layers.append(output)
-
-        # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
-        with tf.variable_scope("decoder_1"):
-            input = tf.concat_v2([layers[-1], layers[0]], axis=3)
+    layer_specs = [
+        (a.ngf * 8, 0.5),   # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
+        (a.ngf * 8, 0.5),   # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
+        (a.ngf * 8, 0.5),   # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
+        (a.ngf * 8, 0.0),   # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
+        (a.ngf * 4, 0.0),   # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
+        (a.ngf * 2, 0.0),   # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
+        (a.ngf, 0.0),       # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
+    ]
+
+    num_encoder_layers = len(layers)
+    for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
+        skip_layer = num_encoder_layers - decoder_layer - 1
+        with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
+            if decoder_layer == 0:
+                # first decoder layer doesn't have skip connections
+                # since it is directly connected to the skip_layer
+                input = layers[-1]
+            else:
+                input = tf.concat_v2([layers[-1], layers[skip_layer]], axis=3)
+
             rectified = tf.nn.relu(input)
-            output = deconv(rectified, generator_outputs_channels)
-            output = tf.tanh(output)
+            # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
+            output = deconv(rectified, out_channels)
+            output = batchnorm(output)
+
+            if dropout > 0.0:
+                output = tf.nn.dropout(output, keep_prob=1 - dropout)
+
             layers.append(output)
 
-        return layers[-1]
+    # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
+    with tf.variable_scope("decoder_1"):
+        input = tf.concat_v2([layers[-1], layers[0]], axis=3)
+        rectified = tf.nn.relu(input)
+        output = deconv(rectified, generator_outputs_channels)
+        output = tf.tanh(output)
+        layers.append(output)
 
+    return layers[-1]
+
+
+def create_model(inputs, targets):
     def create_discriminator(discrim_inputs, discrim_targets):
         n_layers = 3
         layers = []
@@ -424,13 +461,15 @@ def create_model(inputs, targets):
     with tf.name_scope("discriminator_train"):
         discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
         discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
-        discrim_train = discrim_optim.minimize(discrim_loss, var_list=discrim_tvars)
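+        # compute and apply gradients separately (instead of minimize) so the gradients
+        # can be visualized with the histogram summaries added below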
+        discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
+        discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)
 
     with tf.name_scope("generator_train"):
         with tf.control_dependencies([discrim_train]):
             gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
             gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
-            gen_train = gen_optim.minimize(gen_loss, var_list=gen_tvars)
+            gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
+            gen_train = gen_optim.apply_gradients(gen_grads_and_vars)
 
     ema = tf.train.ExponentialMovingAverage(decay=0.99)
     update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])
@@ -442,17 +481,23 @@ def create_model(inputs, targets):
         predict_real=predict_real,
         predict_fake=predict_fake,
         discrim_loss=ema.average(discrim_loss),
+        discrim_grads_and_vars=discrim_grads_and_vars,
         gen_loss_GAN=ema.average(gen_loss_GAN),
         gen_loss_L1=ema.average(gen_loss_L1),
+        gen_grads_and_vars=gen_grads_and_vars,
         outputs=outputs,
         train=tf.group(update_losses, incr_global_step, gen_train),
     )
 
 
-def save_images(fetches, image_dir, step=None):
+def save_images(fetches, step=None):
+    image_dir = os.path.join(a.output_dir, "images")
+    if not os.path.exists(image_dir):
+        os.makedirs(image_dir)
+
     filesets = []
     for i, in_path in enumerate(fetches["paths"]):
-        name, _ = os.path.splitext(os.path.basename(in_path))
+        name, _ = os.path.splitext(os.path.basename(in_path.decode("utf8")))
         fileset = {"name": name, "step": step}
         for kind in ["inputs", "outputs", "targets"]:
             filename = name + "-" + kind + ".png"
@@ -461,7 +506,7 @@ def save_images(fetches, image_dir, step=None):
             fileset[kind] = filename
             out_path = os.path.join(image_dir, filename)
             contents = fetches[kind][i]
-            with open(out_path, "w") as f:
+            with open(out_path, "wb") as f:
                 f.write(contents)
         filesets.append(fileset)
     return filesets
@@ -493,6 +538,9 @@ def append_index(filesets, step=False):
 
 
 def main():
+    if tf.__version__ != "0.12.1":
+        raise Exception("Tensorflow version 0.12.1 required")
+
     if a.seed is None:
         a.seed = random.randint(0, 2**31 - 1)
 
@@ -503,14 +551,14 @@ def main():
     if not os.path.exists(a.output_dir):
         os.makedirs(a.output_dir)
 
-    if a.mode == "test":
+    if a.mode == "test" or a.mode == "export":
         if a.checkpoint is None:
             raise Exception("checkpoint required for test mode")
 
         # load some options from the checkpoint
         options = {"which_direction", "ngf", "ndf", "lab_colorization"}
         with open(os.path.join(a.checkpoint, "options.json")) as f:
-            for key, val in json.loads(f.read()).iteritems():
+            for key, val in json.loads(f.read()).items():
                 if key in options:
                     print("loaded", key, "=", val)
                     setattr(a, key, val)
@@ -524,71 +572,107 @@ def main():
     with open(os.path.join(a.output_dir, "options.json"), "w") as f:
         f.write(json.dumps(vars(a), sort_keys=True, indent=4))
 
-    examples = load_examples()
+    if a.mode == "export":
+        # export the generator to a meta graph that can be imported later for standalone generation
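+        # A rough sketch of how the exported files might be loaded later for standalone
+        # generation (hypothetical script; "facades_export" is an assumed output_dir):
+        #     saver = tf.train.import_meta_graph("facades_export/export.meta")
+        #     inputs = json.loads(tf.get_collection("inputs")[0])
+        #     outputs = json.loads(tf.get_collection("outputs")[0])
+        #     with tf.Session() as sess:
+        #         saver.restore(sess, "facades_export/export")
+        #         # feed a 256x256x3 float32 image with values in [0, 1]
+        #         generated = sess.run(outputs["output"], {inputs["input"]: image})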
+        if a.lab_colorization:
+            raise Exception("export not supported for lab_colorization")
+
+        input = tf.placeholder(tf.float32, shape=[CROP_SIZE, CROP_SIZE, 3], name="input")
+        with tf.variable_scope("generator") as scope:
+            outputs = create_generator(tf.expand_dims(preprocess(input), axis=0), 3)
+
+        output = deprocess(tf.identity(outputs[0,:,:,:], name="output"))
+
+        key = tf.placeholder(tf.string, shape=[None])
+        inputs = {
+            "key": key.name,
+            "input": input.name
+        }
+        tf.add_to_collection("inputs", json.dumps(inputs))
+        outputs = {
+            "key":  tf.identity(key).name,
+            "output": output.name,
+        }
+        tf.add_to_collection("outputs", json.dumps(outputs))
+
+        init_op = tf.global_variables_initializer()
+        restore_saver = tf.train.Saver()
+        export_saver = tf.train.Saver()
 
+        with tf.Session() as sess:
+            sess.run(init_op)
+            print("loading model from checkpoint")
+            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
+            restore_saver.restore(sess, checkpoint)
+            print("exporting model")
+            export_saver.export_meta_graph(filename=os.path.join(a.output_dir, "export.meta"))
+            export_saver.save(sess, os.path.join(a.output_dir, "export"), write_meta_graph=False)
+
+        return
+
+    examples = load_examples()
     print("examples count = %d" % examples.count)
 
+    # inputs and targets are [batch_size, height, width, channels]
     model = create_model(examples.inputs, examples.targets)
 
-    def deprocess(image):
+    # undo colorization splitting on images that we use for display/output
+    if a.lab_colorization:
+        if a.which_direction == "AtoB":
+            # inputs is brightness, this will be handled fine as a grayscale image
+            # need to augment targets and outputs with brightness
+            targets = augment(examples.targets, examples.inputs)
+            outputs = augment(model.outputs, examples.inputs)
+            # inputs can be deprocessed normally and handled as if they are single channel
+            # grayscale images
+            inputs = deprocess(examples.inputs)
+        elif a.which_direction == "BtoA":
+            # inputs will be color channels only, get brightness from targets
+            inputs = augment(examples.inputs, examples.targets)
+            targets = deprocess(examples.targets)
+            outputs = deprocess(model.outputs)
+        else:
+            raise Exception("invalid direction")
+    else:
+        inputs = deprocess(examples.inputs)
+        targets = deprocess(examples.targets)
+        outputs = deprocess(model.outputs)
+
+    def convert(image):
         if a.aspect_ratio != 1.0:
             # upscale to correct aspect ratio
             size = [CROP_SIZE, int(round(CROP_SIZE * a.aspect_ratio))]
             image = tf.image.resize_images(image, size=size, method=tf.image.ResizeMethod.BICUBIC)
 
-        if a.lab_colorization:
-            # colorization mode images can be 1 channel (L) or 2 channels (a,b)
-            num_channels = int(image.get_shape()[-1])
-            if num_channels == 1:
-                return tf.image.convert_image_dtype((image + 1) / 2, dtype=tf.uint8, saturate=True)
-            elif num_channels == 2:
-                # (a, b) color channels, convert to rgb
-                # a_chan and b_chan have range [-1, 1] => [-110, 110]
-                a_chan, b_chan = tf.unstack(image * 110, axis=3)
-                # get L_chan from inputs or targets
-                if a.which_direction == "AtoB":
-                    brightness = examples.inputs
-                elif a.which_direction == "BtoA":
-                    brightness = examples.targets
-                else:
-                    raise Exception("invalid direction")
-                # L_chan has range [-1, 1] => [0, 100]
-                L_chan = tf.squeeze((brightness + 1) / 2 * 100, axis=3)
-                lab = tf.stack([L_chan, a_chan, b_chan], axis=3)
-                rgb = lab_to_rgb(lab)
-                return tf.image.convert_image_dtype(rgb, dtype=tf.uint8, saturate=True)
-            else:
-                raise Exception("unexpected number of channels")
-        else:
-            return tf.image.convert_image_dtype((image + 1) / 2, dtype=tf.uint8, saturate=True)
+        return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)
 
     # reverse any processing on images so they can be written to disk or displayed to user
-    with tf.name_scope("deprocess_inputs"):
-        deprocessed_inputs = deprocess(examples.inputs)
+    with tf.name_scope("convert_inputs"):
+        converted_inputs = convert(inputs)
 
-    with tf.name_scope("deprocess_targets"):
-        deprocessed_targets = deprocess(examples.targets)
+    with tf.name_scope("convert_targets"):
+        converted_targets = convert(targets)
 
-    with tf.name_scope("deprocess_outputs"):
-        deprocessed_outputs = deprocess(model.outputs)
+    with tf.name_scope("convert_outputs"):
+        converted_outputs = convert(outputs)
 
     with tf.name_scope("encode_images"):
         display_fetches = {
             "paths": examples.paths,
-            "inputs": tf.map_fn(tf.image.encode_png, deprocessed_inputs, dtype=tf.string, name="input_pngs"),
-            "targets": tf.map_fn(tf.image.encode_png, deprocessed_targets, dtype=tf.string, name="target_pngs"),
-            "outputs": tf.map_fn(tf.image.encode_png, deprocessed_outputs, dtype=tf.string, name="output_pngs"),
+            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name="input_pngs"),
+            "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name="target_pngs"),
+            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name="output_pngs"),
         }
 
     # summaries
     with tf.name_scope("inputs_summary"):
-        tf.summary.image("inputs", deprocessed_inputs)
+        tf.summary.image("inputs", converted_inputs)
 
     with tf.name_scope("targets_summary"):
-        tf.summary.image("targets", deprocessed_targets)
+        tf.summary.image("targets", converted_targets)
 
     with tf.name_scope("outputs_summary"):
-        tf.summary.image("outputs", deprocessed_outputs)
+        tf.summary.image("outputs", converted_outputs)
 
     with tf.name_scope("predict_real_summary"):
         tf.summary.image("predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))
@@ -600,13 +684,15 @@ def main():
     tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
     tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
 
+    for var in tf.trainable_variables():
+        tf.summary.histogram(var.op.name + "/values", var)
+
+    for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars:
+        tf.summary.histogram(var.op.name + "/gradients", grad)
+
     with tf.name_scope("parameter_count"):
         parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])
 
-    image_dir = os.path.join(a.output_dir, "images")
-    if not os.path.exists(image_dir):
-        os.makedirs(image_dir)
-
     saver = tf.train.Saver(max_to_keep=1)
 
     logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None
@@ -619,25 +705,26 @@ def main():
             checkpoint = tf.train.latest_checkpoint(a.checkpoint)
             saver.restore(sess, checkpoint)
 
+        max_steps = 2**32
+        if a.max_epochs is not None:
+            max_steps = examples.steps_per_epoch * a.max_epochs
+        if a.max_steps is not None:
+            max_steps = a.max_steps
+
         if a.mode == "test":
             # testing
-            # run a single epoch over all input data
-            for step in range(examples.steps_per_epoch):
+            # at most, process the test data once
+            max_steps = min(examples.steps_per_epoch, max_steps)
+            for step in range(max_steps):
                 results = sess.run(display_fetches)
-                filesets = save_images(results, image_dir)
-                for i, path in enumerate(results["paths"]):
-                    print(step * a.batch_size + i + 1, "evaluated image", os.path.basename(path))
+                filesets = save_images(results)
+                for i, f in enumerate(filesets):
+                    print("evaluated image", f["name"])
                 index_path = append_index(filesets)
 
             print("wrote index at", index_path)
         else:
             # training
-            max_steps = 2**32
-            if a.max_epochs is not None:
-                max_steps = examples.steps_per_epoch * a.max_epochs
-            if a.max_steps is not None:
-                max_steps = a.max_steps
-
             start_time = time.time()
             for step in range(max_steps):
                 def should(freq):
@@ -666,22 +753,25 @@ def main():
                     fetches["display"] = display_fetches
 
                 results = sess.run(fetches, options=options, run_metadata=run_metadata)
+                global_step = results["global_step"]
 
                 if should(a.summary_freq):
-                    sv.summary_writer.add_summary(results["summary"], results["global_step"])
+                    sv.summary_writer.add_summary(results["summary"], global_step)
 
                 if should(a.display_freq):
                     print("saving display images")
-                    filesets = save_images(results["display"], image_dir, step=results["global_step"])
+                    filesets = save_images(results["display"], step=global_step)
                     append_index(filesets, step=True)
 
                 if should(a.trace_freq):
                     print("recording trace")
-                    sv.summary_writer.add_run_metadata(run_metadata, "step_%d" % results["global_step"])
+                    sv.summary_writer.add_run_metadata(run_metadata, "step_%d" % global_step)
 
                 if should(a.progress_freq):
-                    global_step = results["global_step"]
-                    print("progress  epoch %d  step %d  image/sec %0.1f" % (global_step // examples.steps_per_epoch, global_step % examples.steps_per_epoch, global_step * a.batch_size / (time.time() - start_time)))
+                    # global_step will have the correct step count if we resume from a checkpoint
+                    train_epoch = math.ceil(global_step / examples.steps_per_epoch)
+                    train_step = global_step - (train_epoch - 1) * examples.steps_per_epoch
+                    print("progress  epoch %d  step %d  image/sec %0.1f" % (train_epoch, train_step, global_step * a.batch_size / (time.time() - start_time)))
                     print("discrim_loss", results["discrim_loss"])
                     print("gen_loss_GAN", results["gen_loss_GAN"])
                     print("gen_loss_L1", results["gen_loss_L1"])
diff --git a/tools/download-dataset.py b/tools/download-dataset.py
index 6e90c5cf867032e6cdd0bbc81f0b5d384e2156e3..6f2466c20072d23f1522b72442237b28ce7f76c2 100644
--- a/tools/download-dataset.py
+++ b/tools/download-dataset.py
@@ -2,7 +2,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import urllib2
+try:
+    from urllib.request import urlopen  # python 3
+except ImportError:
+    from urllib2 import urlopen  # python 2
 import sys
 import tarfile
 import tempfile
@@ -12,7 +15,7 @@ dataset = sys.argv[1]
 url = "https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/%s.tar.gz" % dataset
 with tempfile.TemporaryFile() as tmp:
     print("downloading", url)
-    shutil.copyfileobj(urllib2.urlopen(url), tmp)
+    shutil.copyfileobj(urlopen(url), tmp)
     print("extracting")
     tmp.seek(0)
     tar = tarfile.open(fileobj=tmp)
diff --git a/tools/process.py b/tools/process.py
index a111e6411aa050ac455d81ae01dd68c724eb13da..ac9aab82007f6e31311974a32baa5e28c1bde706 100644
--- a/tools/process.py
+++ b/tools/process.py
@@ -4,7 +4,6 @@ from __future__ import print_function
 
 import argparse
 import os
-import random
 import tensorflow as tf
 import numpy as np
 
@@ -19,18 +18,6 @@ parser.add_argument("--b_dir", type=str, help="path to folder containing B image
 a = parser.parse_args()
 
 
-def grayscale(img):
-    img = img / 255
-    img = 0.299 * img[:,:,0] + 0.587 * img[:,:,1] + 0.114 * img[:,:,2]
-    return (np.expand_dims(img, axis=2) * 255).astype(np.uint8)
-
-
-def normalize(img):
-    img -= img.min()
-    img /= img.max()
-    return img
-
-
 def create_op(func, **placeholders):
     op = func(**placeholders)
 
@@ -166,8 +153,6 @@ def png_path(path):
 
 
 def main():
-    random.seed(0)
-
     if not os.path.exists(a.output_dir):
         os.makedirs(a.output_dir)
 
diff --git a/tools/test.py b/tools/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ee372a27cf7b539f314c8d6f25346d2399dd36d
--- /dev/null
+++ b/tools/test.py
@@ -0,0 +1,67 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import subprocess
+import os
+import sys
+import time
+import argparse
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--long", action="store_true")
+a = parser.parse_args()
+
+
+def run(cmd, image="affinelayer/pix2pix-tensorflow"):
+    docker = "docker"
+    if sys.platform.startswith("linux"):
+        docker = "nvidia-docker"
+
+    datapath = os.path.abspath("../data")
+    prefix = [
+        docker, "run", "--rm",
+        "--volume", os.getcwd() + ":/prj",
+        "--volume", datapath + ":/data",
+        "--workdir", "/prj",
+        "--env", "PYTHONUNBUFFERED=x",
+        "--volume", "/tmp/cuda-cache:/cuda-cache",
+        "--env", "CUDA_CACHE_PATH=/cuda-cache",
+        image,
+    ]
+    args = prefix + cmd.split(" ")
+    print(" ".join(args))
+    subprocess.check_call(args)
+
+
+def main():
+    start = time.time()
+
+    if a.long:
+        run("python pix2pix.py --mode train --output_dir test/facades_BtoA_train --max_epochs 200 --input_dir /data/official/facades/train --which_direction BtoA --seed 0")
+        run("python pix2pix.py --mode test --output_dir test/facades_BtoA_test --input_dir /data/official/facades/val --seed 0 --checkpoint test/facades_BtoA_train")
+
+        run("python pix2pix.py --mode train --output_dir test/color-lab_AtoB_train --max_epochs 10 --input_dir /data/color-lab/train --which_direction AtoB --seed 0 --lab_colorization")
+        run("python pix2pix.py --mode test --output_dir test/color-lab_AtoB_test --input_dir /data/color-lab/val --seed 0 --checkpoint test/color-lab_AtoB_train")
+    else:
+        # training
+        for direction in ["AtoB", "BtoA"]:
+            for dataset in ["facades", "maps"]:
+                name = dataset + "_" + direction
+                run("python pix2pix.py --mode train --output_dir test/%s_train --max_steps 1 --input_dir /data/official/%s/train --which_direction %s --seed 0" % (name, dataset, direction))
+                run("python pix2pix.py --mode test --output_dir test/%s_test --max_steps 1 --input_dir /data/official/%s/val --seed 0 --checkpoint test/%s_train" % (name, dataset, name))
+
+            # test lab colorization
+            dataset = "color-lab"
+            name = dataset + "_" + direction
+            run("python pix2pix.py --mode train --output_dir test/%s_train --max_steps 1 --input_dir /data/%s/train --which_direction %s --seed 0 --lab_colorization" % (name, dataset, direction))
+            run("python pix2pix.py --mode test --output_dir test/%s_test --max_steps 1 --input_dir /data/%s/val --seed 0 --checkpoint test/%s_train" % (name, dataset, name))
+
+        # using pretrained model
+        for dataset, direction in [("facades", "BtoA"), ("edges2shoes", "AtoB"), ("maps", "AtoB"), ("maps", "BtoA"), ("cityscapes", "AtoB"), ("cityscapes", "BtoA"), ("edges2handbags", "AtoB")]:
+            name = dataset + "_" + direction
+            run("python pix2pix.py --mode test --output_dir test/%s_pretrained_test --input_dir /data/official/%s/val --max_steps 100 --which_direction %s --seed 0 --checkpoint /data/pretrained/%s" % (name, dataset, direction, name))
+            run("python pix2pix.py --mode export --output_dir test/%s_pretrained_export --checkpoint /data/pretrained/%s" % (name, name))
+
+        # test python3
+        run("python pix2pix.py --mode train --output_dir test/py3_facades_AtoB_train --max_steps 1 --input_dir /data/official/facades/train --which_direction AtoB --seed 0", image="tensorflow/tensorflow:0.12.1-gpu-py3")
+        run("python pix2pix.py --mode test --output_dir test/py3_facades_AtoB_test --max_steps 1 --input_dir /data/official/facades/val --seed 0 --checkpoint test/py3_facades_AtoB_train", image="tensorflow/tensorflow:0.12.1-gpu-py3")
+
+    print("elapsed", int(time.time() - start))
+    # short: 2521 seconds (mac)
+    # long: about 9 hours (linux)
+
+
+main()