diff --git a/README.md b/README.md
index 04757b6ea33b3e68413900223a072aa7375303c8..ef2ef4edc45fe0c5cb9cd0257b97125ce6c4810b 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,182 @@
 # pix2pix-tensorflow
-Tensorflow Port of Image-to-image translation using conditional adversarial nets https://phillipi.github.io/pix2pix/
+
+Based on [pix2pix](https://phillipi.github.io/pix2pix/) by Isola et al.
+
+[Article about this implementation](https://affinelayer.com/pix2pix/)
+
+Tensorflow implementation of pix2pix.  Learns a mapping from input images to output images, like these examples from the original paper:
+
+<img src="docs/examples.jpg" width="900px"/>
+
+This port is based directly on the torch implementation, not on an existing Tensorflow implementation.  It is meant to be a faithful reproduction of the original work and does not add any new features.  In testing, processing speed on a GPU with cuDNN was equivalent to the Torch implementation.
+
+## Setup
+
+### Prerequisites
+- Tensorflow 0.12.1
+
+### Recommended
+- Linux with Tensorflow GPU edition + cuDNN
+
+### Getting Started
+
+```sh
+# Clone this repo
+git clone https://github.com/affinelayer/pix2pix-tensorflow.git
+cd pix2pix-tensorflow
+# Download the CMP Facades dataset http://cmp.felk.cvut.cz/~tylecr1/facade/
+python tools/download-dataset.py facades
+# Train the model (this may take 1-8 hours depending on GPU; on CPU you will be waiting for quite a while)
+python pix2pix.py --mode train --output_dir facades_train --max_epochs 200 --input_dir facades/train --which_direction BtoA
+# Test the model
+python pix2pix.py --mode test --output_dir facades_test --input_dir facades/val --checkpoint facades_train
+```
+
+The test run will output an HTML file at `facades_test/index.html` that shows input/output/target image sets.
+
+## Datasets
+
+The data format used by this program is the same as the original pix2pix format: each training example is a single image with the input and desired output side by side, like this:
+
+<img src="docs/ab.png" width="256px"/>
+
+For example:
+
+<img src="docs/418.png" width="256px"/>
+
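+The combined image is simply the A image and the B image concatenated along the width axis; `pix2pix.py` splits it back into two halves at `width // 2`.  A minimal sketch of that split, using Pillow and numpy purely for illustration (neither is required by this repo):
+
+```python
+import numpy as np
+from PIL import Image
+
+# load a combined AB image (like the example above) and split it down the middle
+ab = np.array(Image.open("docs/418.png"))
+width = ab.shape[1]
+a_image = ab[:, :width // 2]  # left half: A
+b_image = ab[:, width // 2:]  # right half: B
+print(a_image.shape, b_image.shape)
+```
+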
+Some datasets have been made available by the authors of the pix2pix paper.  To download those datasets, use the included script `tools/download-dataset.py`.
+
+| dataset | image |
+| --- | --- |
+| `python tools/download-dataset.py facades` <br> 400 images from [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/). (31MB)  | <img src="docs/facades.jpg" width="256px"/> |
+| `python tools/download-dataset.py cityscapes` <br> 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/). (113MB) | <img src="docs/cityscapes.jpg" width="256px"/> |
+| `python tools/download-dataset.py maps` <br> 1096 training images scraped from Google Maps. (246MB) | <img src="docs/maps.jpg" width="256px"/> |
+| `python tools/download-dataset.py edges2shoes` <br> 50k training images from [UT Zappos50K dataset](http://vision.cs.utexas.edu/projects/finegrained/utzap50k/). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing. (2.2GB) | <img src="docs/edges2shoes.jpg" width="256px"/>  |
+| `python tools/download-dataset.py edges2handbags` <br> 137K Amazon Handbag images from [iGAN project](https://github.com/junyanz/iGAN). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing. (8.6GB) | <img src="docs/edges2handbags.jpg" width="256px"/> |
+
+The `facades` dataset is the smallest and easiest to get started with.
+
+### Creating your own dataset
+
+#### Example: creating images with blank centers for [inpainting](https://people.eecs.berkeley.edu/~pathak/context_encoder/)
+
+<img src="docs/combine.png" width="900px"/>
+
+```sh
+# Resize source images
+python tools/process.py --input_dir photos/original --operation resize --output_dir photos/resized
+# Create images with blank centers
+python tools/process.py --input_dir photos/resized --operation blank --output_dir photos/blank
+# Combine resized images with blanked images
+python tools/process.py --input_dir photos/resized --b_dir photos/blank --operation combine --output_dir photos/combined
+# Split into train/val set
+python tools/split.py --dir photos/combined
+```
+
+The folder `photos/combined` will now have `train` and `val` subfolders that you can use for training and testing.
+
+#### Creating image pairs from existing images
+
+If you have two directories `a` and `b` with corresponding images (same name, same dimensions, different data), you can combine them with `process.py`:
+
+```sh
+python tools/process.py --input_dir a --b_dir b --operation combine --output_dir c
+```
+
+This puts the images in a side-by-side combined image that `pix2pix.py` expects.
+
+#### Colorization
+
+For colorization, your images should ideally all be the same aspect ratio.  You can resize and crop them with the resize command:
+```sh
+python tools/process.py --input_dir photos/original --operation resize --output_dir photos/resized
+```
+
+No other processing is required; the colorization mode (see the Training section below) uses single images instead of image pairs.
+
+## Training
+
+### Image Pairs
+
+For normal training with image pairs, you need to specify which directory contains the training images, and which direction to train on.  The direction options are `AtoB` or `BtoA`:
+```sh
+python pix2pix.py --mode train --output_dir facades_train --max_epochs 200 --input_dir facades/train --which_direction BtoA
+```
+
+### Colorization
+
+`pix2pix.py` includes special code to handle colorization with single images instead of pairs; using it looks like this:
+
+```sh
+python pix2pix.py --mode train --output_dir photos_train --max_epochs 200 --input_dir photos/train --lab_colorization
+```
+
+In this mode, image A is the black-and-white image (lightness only), and image B contains the color channels of that image (no lightness information).
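+
+This matches the Lab split in `load_examples()` in `pix2pix.py`: the L channel is scaled from [0, 100] to [-1, 1] and becomes A, while the (a, b) channels are scaled from roughly [-110, 110] to [-1, 1] and become B.  A minimal sketch of that normalization (assuming `lab` is a `[height, width, 3]` float tensor from `rgb_to_lab`):
+
+```python
+import tensorflow as tf
+
+def split_lab(lab):
+    # lab: [height, width, 3] tensor as produced by rgb_to_lab() in pix2pix.py
+    L_chan, a_chan, b_chan = tf.unstack(lab, axis=2)
+    a_image = tf.expand_dims(L_chan, axis=2) / 50 - 1   # lightness only, range [-1, 1]
+    b_image = tf.stack([a_chan, b_chan], axis=2) / 110  # color channels, range ~[-1, 1]
+    return a_image, b_image
+```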
+
+### Tips
+
+You can view the losses and the computation graph using TensorBoard:
+```sh
+tensorboard --logdir=facades_train
+```
+
+<img src="docs/tensorboard-scalar.png" width="250px"/> <img src="docs/tensorboard-image.png" width="250px"/> <img src="docs/tensorboard-graph.png" width="250px"/>
+
+If you wish to write in-progress pictures as the network is training, use `--display_freq 50`.  This will update `facades_train/index.html` every 50 steps with the current training inputs and outputs.
+
+## Testing
+
+Testing is done with `--mode test`.  Specify the checkpoint to use with `--checkpoint`; this should point to the `output_dir` you created previously with `--mode train`:
+
+```sh
+python pix2pix.py --mode test --output_dir facades_test --input_dir facades/val --checkpoint facades_train
+```
+
+Testing mode loads some of the configuration options from the provided checkpoint, so you do not need to specify options such as `--which_direction`.
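+
+Specifically, `pix2pix.py` restores `which_direction`, `ngf`, `ndf`, and `lab_colorization` from the `options.json` file written into the checkpoint directory during training; all other settings come from the command line.  A minimal sketch of that restore step (the helper name here is illustrative, not part of the script):
+
+```python
+import json
+import os
+
+RESTORED_OPTIONS = {"which_direction", "ngf", "ndf", "lab_colorization"}
+
+def load_checkpoint_options(checkpoint_dir):
+    # options.json is written by pix2pix.py at the start of a training run
+    with open(os.path.join(checkpoint_dir, "options.json")) as f:
+        saved = json.load(f)
+    return {k: v for k, v in saved.items() if k in RESTORED_OPTIONS}
+```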
+
+The test run will output an HTML file at `facades_test/index.html` that shows input/output/target image sets:
+
+<img src="docs/test-html.png" width="300px"/>
+
+## Implementation Validation
+
+Validation of the code was performed on a Linux machine with a ~1.3 TFLOPS Nvidia GTX 750 Ti GPU.  Due to a lack of compute power, validation is not extensive and only the `facades` dataset at 200 epochs was tested.
+
+```sh
+git clone https://github.com/affinelayer/pix2pix-tensorflow.git
+cd pix2pix-tensorflow
+python tools/download-dataset.py facades
+time nvidia-docker run --volume $PWD:/prj --workdir /prj --env PYTHONUNBUFFERED=x affinelayer/tensorflow:pix2pix python pix2pix.py --mode train --output_dir facades_train --max_epochs 200 --input_dir facades/train --which_direction BtoA
+nvidia-docker run --volume $PWD:/prj --workdir /prj --env PYTHONUNBUFFERED=x affinelayer/tensorflow:pix2pix python pix2pix.py --mode test --output_dir facades_test --input_dir facades/val --checkpoint facades_train
+```
+
+Comparison on facades dataset:
+
+| Input | Tensorflow | Torch | Target |
+| --- | --- | --- | --- |
+| <img src="docs/1-inputs.png" width="256px"> | <img src="docs/1-tensorflow.png" width="256px"> | <img src="docs/1-torch.jpg" width="256px"> | <img src="docs/1-targets.png" width="256px"> |
+| <img src="docs/5-inputs.png" width="256px"> | <img src="docs/5-tensorflow.png" width="256px"> | <img src="docs/5-torch.jpg" width="256px"> | <img src="docs/5-targets.png" width="256px"> |
+| <img src="docs/51-inputs.png" width="256px"> | <img src="docs/51-tensorflow.png" width="256px"> | <img src="docs/51-torch.jpg" width="256px"> | <img src="docs/51-targets.png" width="256px"> |
+| <img src="docs/95-inputs.png" width="256px"> | <img src="docs/95-tensorflow.png" width="256px"> | <img src="docs/95-torch.jpg" width="256px"> | <img src="docs/95-targets.png" width="256px"> |
+
+## Unimplemented Features
+
+The following models have not been implemented:
+- defineG_encoder_decoder
+- defineG_unet_128
+- defineD_pixelGAN
+
+## Citation
+If you use this code for your research, please cite the paper this code is based on, <a href="https://arxiv.org/pdf/1611.07004v1.pdf">Image-to-Image Translation with Conditional Adversarial Networks</a>:
+
+```
+@article{pix2pix2016,
+  title={Image-to-Image Translation with Conditional Adversarial Networks},
+  author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A},
+  journal={arxiv},
+  year={2016}
+}
+```
+
+## Acknowledgments
+This is a port of [pix2pix](https://github.com/phillipi/pix2pix) from Torch to Tensorflow.  It also contains colorspace conversion code ported from Torch.
diff --git a/docs/1-inputs.png b/docs/1-inputs.png
new file mode 100644
index 0000000000000000000000000000000000000000..a12be3c732f3c38e87a95c1a36c05dc871020ebc
Binary files /dev/null and b/docs/1-inputs.png differ
diff --git a/docs/1-targets.png b/docs/1-targets.png
new file mode 100644
index 0000000000000000000000000000000000000000..f45487797f4f4d05f3b06b3962dd8910aefce52c
Binary files /dev/null and b/docs/1-targets.png differ
diff --git a/docs/1-tensorflow.png b/docs/1-tensorflow.png
new file mode 100644
index 0000000000000000000000000000000000000000..262392a02119efe87b6cc35370c94fb75715bac5
Binary files /dev/null and b/docs/1-tensorflow.png differ
diff --git a/docs/1-torch.jpg b/docs/1-torch.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2b70d27828610b09f0cc125b532f3089e6240b54
Binary files /dev/null and b/docs/1-torch.jpg differ
diff --git a/docs/418.png b/docs/418.png
new file mode 100644
index 0000000000000000000000000000000000000000..34bfb3dc1d669a0eb75715b413db16a2efeb4c0e
Binary files /dev/null and b/docs/418.png differ
diff --git a/docs/5-inputs.png b/docs/5-inputs.png
new file mode 100644
index 0000000000000000000000000000000000000000..d58a5196a88cb5b8f96bfa6ef3ab71458db73759
Binary files /dev/null and b/docs/5-inputs.png differ
diff --git a/docs/5-targets.png b/docs/5-targets.png
new file mode 100644
index 0000000000000000000000000000000000000000..066d88db8b21d35e496c5f68456e8d03df5dd30d
Binary files /dev/null and b/docs/5-targets.png differ
diff --git a/docs/5-tensorflow.png b/docs/5-tensorflow.png
new file mode 100644
index 0000000000000000000000000000000000000000..591e126635e69a925c75fde9047ee06892961e6a
Binary files /dev/null and b/docs/5-tensorflow.png differ
diff --git a/docs/5-torch.jpg b/docs/5-torch.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c989387ad57ef71715e287967f74b538a03780ac
Binary files /dev/null and b/docs/5-torch.jpg differ
diff --git a/docs/51-inputs.png b/docs/51-inputs.png
new file mode 100644
index 0000000000000000000000000000000000000000..1d8a571978a885c1c1fd7081bb1e1cd5ae08d0fd
Binary files /dev/null and b/docs/51-inputs.png differ
diff --git a/docs/51-targets.png b/docs/51-targets.png
new file mode 100644
index 0000000000000000000000000000000000000000..42012dddcacef8b841fa73f14508184919e982bc
Binary files /dev/null and b/docs/51-targets.png differ
diff --git a/docs/51-tensorflow.png b/docs/51-tensorflow.png
new file mode 100644
index 0000000000000000000000000000000000000000..19075ce04d3c1819953d07e7faa7f47b29b3fc28
Binary files /dev/null and b/docs/51-tensorflow.png differ
diff --git a/docs/51-torch.jpg b/docs/51-torch.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a4013e00b87864c524ad02c359cc1b1c9969dfb2
Binary files /dev/null and b/docs/51-torch.jpg differ
diff --git a/docs/95-inputs.png b/docs/95-inputs.png
new file mode 100644
index 0000000000000000000000000000000000000000..6fc2ec2636a4f7b0877bb43ee2e2fb2c2225d0f3
Binary files /dev/null and b/docs/95-inputs.png differ
diff --git a/docs/95-targets.png b/docs/95-targets.png
new file mode 100644
index 0000000000000000000000000000000000000000..f594d737b36056c005584aca1d5a8f613b52d0b8
Binary files /dev/null and b/docs/95-targets.png differ
diff --git a/docs/95-tensorflow.png b/docs/95-tensorflow.png
new file mode 100644
index 0000000000000000000000000000000000000000..e4c34d1ca9fe0c0ce8a2c9fd7889a53f148d4c0d
Binary files /dev/null and b/docs/95-tensorflow.png differ
diff --git a/docs/95-torch.jpg b/docs/95-torch.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..84bed739bcbcac40f402323e88f70c95b1fb3425
Binary files /dev/null and b/docs/95-torch.jpg differ
diff --git a/docs/ab.png b/docs/ab.png
new file mode 100644
index 0000000000000000000000000000000000000000..1dadedbd7b5fcf4ba4e0d783d4a15dd0362da967
Binary files /dev/null and b/docs/ab.png differ
diff --git a/docs/cityscapes.jpg b/docs/cityscapes.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..dfebed7359c8ece36d4f84346e8dae3718029c6a
Binary files /dev/null and b/docs/cityscapes.jpg differ
diff --git a/docs/combine.png b/docs/combine.png
new file mode 100644
index 0000000000000000000000000000000000000000..72b35952cfe4b09400f325cfb3f29b674ef5ee52
Binary files /dev/null and b/docs/combine.png differ
diff --git a/docs/edges2handbags.jpg b/docs/edges2handbags.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..4dbcac47043505afddcdd69025d6ad246999b69f
Binary files /dev/null and b/docs/edges2handbags.jpg differ
diff --git a/docs/edges2shoes.jpg b/docs/edges2shoes.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..55278d45a297652b73d30dfec2e050705226b4d6
Binary files /dev/null and b/docs/edges2shoes.jpg differ
diff --git a/docs/examples.jpg b/docs/examples.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b1f24d5ef635e16773ea163a28cec68d09a59978
Binary files /dev/null and b/docs/examples.jpg differ
diff --git a/docs/facades.jpg b/docs/facades.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..b88704f656a9e80f2fb9eb88ac3bc768a0e37db5
Binary files /dev/null and b/docs/facades.jpg differ
diff --git a/docs/maps.jpg b/docs/maps.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..4ecdfec85f5eef1324ac78cce3d002248a8b8e4d
Binary files /dev/null and b/docs/maps.jpg differ
diff --git a/docs/tensorboard-graph.png b/docs/tensorboard-graph.png
new file mode 100644
index 0000000000000000000000000000000000000000..fce1f62bfba2d5a5e60e56ccde4acd25eec4fd04
Binary files /dev/null and b/docs/tensorboard-graph.png differ
diff --git a/docs/tensorboard-image.png b/docs/tensorboard-image.png
new file mode 100644
index 0000000000000000000000000000000000000000..8a9581bf4f9778e8322b39647939613c59eb6cdd
Binary files /dev/null and b/docs/tensorboard-image.png differ
diff --git a/docs/tensorboard-scalar.png b/docs/tensorboard-scalar.png
new file mode 100644
index 0000000000000000000000000000000000000000..358028c985b74eca8fe3949bff0283dfa5e7aacc
Binary files /dev/null and b/docs/tensorboard-scalar.png differ
diff --git a/docs/test-html.png b/docs/test-html.png
new file mode 100644
index 0000000000000000000000000000000000000000..aed11d0ab1dfd91ec0e77759ae69febc07018514
Binary files /dev/null and b/docs/test-html.png differ
diff --git a/pix2pix.py b/pix2pix.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d7faa928913a31b4014e1810b1e3707a1183096
--- /dev/null
+++ b/pix2pix.py
@@ -0,0 +1,691 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+import numpy as np
+import argparse
+import os
+import json
+import glob
+import random
+import collections
+import math
+import time
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--input_dir", required=True, help="path to folder containing images")
+parser.add_argument("--mode", required=True, choices=["train", "test"])
+parser.add_argument("--output_dir", required=True, help="where to put output files")
+parser.add_argument("--seed", type=int)
+parser.add_argument("--checkpoint", default=None, help="directory with checkpoint to resume training from or use for testing")
+
+parser.add_argument("--max_steps", type=int, help="number of training steps (0 to disable)")
+parser.add_argument("--max_epochs", type=int, help="number of training epochs")
+parser.add_argument("--summary_freq", type=int, default=10, help="update summaries every summary_freq steps")
+parser.add_argument("--progress_freq", type=int, default=50, help="display progress every progress_freq steps")
+# to get tracing working on GPU, LD_LIBRARY_PATH may need to be modified:
+# LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/extras/CUPTI/lib64
+parser.add_argument("--trace_freq", type=int, default=0, help="trace execution every trace_freq steps")
+parser.add_argument("--display_freq", type=int, default=0, help="write current training images every display_freq steps")
+parser.add_argument("--save_freq", type=int, default=5000, help="save model every save_freq steps, 0 to disable")
+
+parser.add_argument("--aspect_ratio", type=float, default=1.0, help="aspect ratio of output images (width/height)")
+parser.add_argument("--lab_colorization", action="store_true", help="split A image into brightness (A) and color (B), ignore B image")
+parser.add_argument("--batch_size", type=int, default=1, help="number of images in batch")
+parser.add_argument("--which_direction", type=str, default="AtoB", choices=["AtoB", "BtoA"])
+parser.add_argument("--ngf", type=int, default=64, help="number of generator filters in first conv layer")
+parser.add_argument("--ndf", type=int, default=64, help="number of discriminator filters in first conv layer")
+parser.add_argument("--scale_size", type=int, default=286, help="scale images to this size before cropping to 256x256")
+parser.add_argument("--flip", dest="flip", action="store_true", help="flip images horizontally")
+parser.add_argument("--no_flip", dest="flip", action="store_false", help="don't flip images horizontally")
+parser.set_defaults(flip=True)
+parser.add_argument("--lr", type=float, default=0.0002, help="initial learning rate for adam")
+parser.add_argument("--beta1", type=float, default=0.5, help="momentum term of adam")
+parser.add_argument("--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient")
+parser.add_argument("--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient")
+a = parser.parse_args()
+
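+# EPS keeps the log terms in the GAN losses away from log(0)
+# CROP_SIZE is the resolution images are cropped to and that the network operates on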
+EPS = 1e-12
+CROP_SIZE = 256
+
+Examples = collections.namedtuple("Examples", "paths, inputs, targets, count, steps_per_epoch")
+Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, gen_loss_GAN, gen_loss_L1, train")
+
+
+def conv(batch_input, out_channels, stride):
+    with tf.variable_scope("conv"):
+        in_channels = batch_input.get_shape()[3]
+        filter = tf.get_variable("filter", [4, 4, in_channels, out_channels], dtype=tf.float32, initializer=tf.random_normal_initializer(0, 0.02))
+        # [batch, in_height, in_width, in_channels], [filter_height, filter_width, in_channels, out_channels]
+        #     => [batch, out_height, out_width, out_channels]
+        padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT")
+        conv = tf.nn.conv2d(padded_input, filter, [1, stride, stride, 1], padding="VALID")
+        return conv
+
+
+def lrelu(x, a):
+    with tf.name_scope("lrelu"):
+        # adding these together creates the leak part and linear part
+        # then cancels them out by subtracting/adding an absolute value term
+        # leak: a*x/2 - a*abs(x)/2
+        # linear: x/2 + abs(x)/2
+
+        # this block looks like it has 2 inputs on the graph unless we do this
+        x = tf.identity(x)
+        return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
+
+
+def batchnorm(input):
+    with tf.variable_scope("batchnorm"):
+        # this block looks like it has 3 inputs on the graph unless we do this
+        input = tf.identity(input)
+
+        channels = input.get_shape()[3]
+        offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer)
+        scale = tf.get_variable("scale", [channels], dtype=tf.float32, initializer=tf.random_normal_initializer(1.0, 0.02))
+        mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False)
+        variance_epsilon = 1e-5
+        normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon)
+        return normalized
+
+
+def deconv(batch_input, out_channels):
+    with tf.variable_scope("deconv"):
+        batch, in_height, in_width, in_channels = [int(d) for d in batch_input.get_shape()]
+        filter = tf.get_variable("filter", [4, 4, out_channels, in_channels], dtype=tf.float32, initializer=tf.random_normal_initializer(0, 0.02))
+        # [batch, in_height, in_width, in_channels], [filter_height, filter_width, out_channels, in_channels]
+        #     => [batch, out_height, out_width, out_channels]
+        conv = tf.nn.conv2d_transpose(batch_input, filter, [batch, in_height * 2, in_width * 2, out_channels], [1, 2, 2, 1], padding="SAME")
+        return conv
+
+
+def check_image(image):
+    assertion = tf.assert_equal(tf.shape(image)[-1], 3, message="image must have 3 color channels")
+    with tf.control_dependencies([assertion]):
+        image = tf.identity(image)
+
+    if image.get_shape().ndims not in (3, 4):
+        raise ValueError("image must be either 3 or 4 dimensions")
+
+    # make the last dimension 3 so that you can unstack the colors
+    shape = list(image.get_shape())
+    shape[-1] = 3
+    image.set_shape(shape)
+    return image
+
+# based on https://github.com/torch/image/blob/9f65c30167b2048ecbe8b7befdc6b2d6d12baee9/generic/image.c
+def rgb_to_lab(srgb):
+    with tf.name_scope("rgb_to_lab"):
+        srgb = check_image(srgb)
+        srgb_pixels = tf.reshape(srgb, [-1, 3])
+
+        with tf.name_scope("srgb_to_xyz"):
+            linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
+            exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
+            rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
+            rgb_to_xyz = tf.constant([
+                #    X        Y          Z
+                [0.412453, 0.212671, 0.019334], # R
+                [0.357580, 0.715160, 0.119193], # G
+                [0.180423, 0.072169, 0.950227], # B
+            ])
+            xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)
+
+        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
+        with tf.name_scope("xyz_to_cielab"):
+            # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)
+
+            # normalize for D65 white point
+            xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])
+
+            epsilon = 6/29
+            linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
+            exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
+            fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask
+
+            # convert to lab
+            fxfyfz_to_lab = tf.constant([
+                #  l       a       b
+                [  0.0,  500.0,    0.0], # fx
+                [116.0, -500.0,  200.0], # fy
+                [  0.0,    0.0, -200.0], # fz
+            ])
+            lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])
+
+        return tf.reshape(lab_pixels, tf.shape(srgb))
+
+
+def lab_to_rgb(lab):
+    with tf.name_scope("lab_to_rgb"):
+        lab = check_image(lab)
+        lab_pixels = tf.reshape(lab, [-1, 3])
+
+        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
+        with tf.name_scope("cielab_to_xyz"):
+            # convert to fxfyfz
+            lab_to_fxfyfz = tf.constant([
+                #   fx      fy        fz
+                [1/116.0, 1/116.0,  1/116.0], # l
+                [1/500.0,     0.0,      0.0], # a
+                [    0.0,     0.0, -1/200.0], # b
+            ])
+            fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)
+
+            # convert to xyz
+            epsilon = 6/29
+            linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
+            exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
+            xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask
+
+            # denormalize for D65 white point
+            xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])
+
+        with tf.name_scope("xyz_to_srgb"):
+            xyz_to_rgb = tf.constant([
+                #     r           g          b
+                [ 3.2404542, -0.9692660,  0.0556434], # x
+                [-1.5371385,  1.8760108, -0.2040259], # y
+                [-0.4985314,  0.0415560,  1.0572252], # z
+            ])
+            rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
+            # avoid a slightly negative number messing up the conversion
+            rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
+            linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
+            exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
+            srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1/2.4) * 1.055) - 0.055) * exponential_mask
+
+        return tf.reshape(srgb_pixels, tf.shape(lab))
+
+
+def load_examples():
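+    # prefer JPEG inputs; if the directory contains no .jpg files, fall back to PNG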
+    input_paths = glob.glob(os.path.join(a.input_dir, "*.jpg"))
+    decode = tf.image.decode_jpeg
+    if len(input_paths) == 0:
+        input_paths = glob.glob(os.path.join(a.input_dir, "*.png"))
+        decode = tf.image.decode_png
+
+    def get_name(path):
+        name, _ = os.path.splitext(os.path.basename(path))
+        return name
+
+    # if the image names are numbers, sort by the value rather than asciibetically
+    # having sorted inputs means that the outputs are sorted in test mode
+    if all(get_name(path).isdigit() for path in input_paths):
+        input_paths = sorted(input_paths, key=lambda path: int(get_name(path)))
+    else:
+        input_paths = sorted(input_paths)
+
+    with tf.name_scope("load_images"):
+        path_queue = tf.train.string_input_producer(input_paths, shuffle=a.mode == "train")
+        reader = tf.WholeFileReader()
+        paths, contents = reader.read(path_queue)
+        raw_input = decode(contents)
+        raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32)
+
+        assertion = tf.assert_equal(tf.shape(raw_input)[2], 3, message="image does not have 3 channels")
+        with tf.control_dependencies([assertion]):
+            raw_input = tf.identity(raw_input)
+
+        raw_input.set_shape([None, None, 3])
+
+        if a.lab_colorization:
+            # load color and brightness from image, no B image exists here
+            lab = rgb_to_lab(raw_input)
+            L_chan, a_chan, b_chan = tf.unstack(lab, axis=2)
+            a_images = tf.expand_dims(L_chan, axis=2) / 50 - 1 # black and white with input range [0, 100]
+            b_images = tf.stack([a_chan, b_chan], axis=2) / 110 # color channels with input range ~[-110, 110], not exact
+        else:
+            # break apart image pair and move to range [-1, 1]
+            width = tf.shape(raw_input)[1] # [height, width, channels]
+            a_images = raw_input[:,:width//2,:] * 2 - 1
+            b_images = raw_input[:,width//2:,:] * 2 - 1
+
+    if a.which_direction == "AtoB":
+        inputs, targets = [a_images, b_images]
+    elif a.which_direction == "BtoA":
+        inputs, targets = [b_images, a_images]
+    else:
+        raise Exception("invalid direction")
+
+    # synchronize seed for image operations so that we do the same operations to both
+    # input and output images
+    seed = random.randint(0, 2**31 - 1)
+    def transform(image):
+        r = image
+        if a.flip:
+            r = tf.image.random_flip_left_right(r, seed=seed)
+
+        # area produces a nice downscaling, but does nearest neighbor for upscaling
+        # assume we're going to be doing downscaling here
+        r = tf.image.resize_images(r, [a.scale_size, a.scale_size], method=tf.image.ResizeMethod.AREA)
+
+        offset = tf.cast(tf.floor(tf.random_uniform([2], 0, a.scale_size - CROP_SIZE + 1, seed=seed)), dtype=tf.int32)
+        if a.scale_size > CROP_SIZE:
+            r = tf.image.crop_to_bounding_box(r, offset[0], offset[1], CROP_SIZE, CROP_SIZE)
+        elif a.scale_size < CROP_SIZE:
+            raise Exception("scale size cannot be less than crop size")
+        return r
+
+    with tf.name_scope("input_images"):
+        input_images = transform(inputs)
+
+    with tf.name_scope("target_images"):
+        target_images = transform(targets)
+
+    paths, inputs, targets = tf.train.batch([paths, input_images, target_images], batch_size=a.batch_size)
+    steps_per_epoch = int(math.ceil(len(input_paths) / a.batch_size))
+
+    return Examples(
+        paths=paths,
+        inputs=inputs,
+        targets=targets,
+        count=len(input_paths),
+        steps_per_epoch=steps_per_epoch,
+    )
+
+
+def create_model(inputs, targets):
+    def create_generator(generator_inputs, generator_outputs_channels):
+        layers = []
+
+        # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
+        with tf.variable_scope("encoder_1"):
+            output = conv(generator_inputs, a.ngf, stride=2)
+            layers.append(output)
+
+        layer_specs = [
+            a.ngf * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
+            a.ngf * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
+            a.ngf * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
+            a.ngf * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
+            a.ngf * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
+            a.ngf * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
+            a.ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
+        ]
+
+        for out_channels in layer_specs:
+            with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
+                rectified = lrelu(layers[-1], 0.2)
+                # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
+                convolved = conv(rectified, out_channels, stride=2)
+                output = batchnorm(convolved)
+                layers.append(output)
+
+        layer_specs = [
+            (a.ngf * 8, 0.5),   # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
+            (a.ngf * 8, 0.5),   # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
+            (a.ngf * 8, 0.5),   # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
+            (a.ngf * 8, 0.0),   # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
+            (a.ngf * 4, 0.0),   # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
+            (a.ngf * 2, 0.0),   # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
+            (a.ngf, 0.0),       # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
+        ]
+
+        num_encoder_layers = len(layers)
+        for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
+            skip_layer = num_encoder_layers - decoder_layer - 1
+            with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
+                if decoder_layer == 0:
+                    # first decoder layer doesn't have skip connections
+                    # since it is directly connected to the skip_layer
+                    input = layers[-1]
+                else:
+                    input = tf.concat_v2([layers[-1], layers[skip_layer]], axis=3)
+
+                rectified = tf.nn.relu(input)
+                # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
+                output = deconv(rectified, out_channels)
+                output = batchnorm(output)
+
+                if dropout > 0.0:
+                    output = tf.nn.dropout(output, keep_prob=1 - dropout)
+
+                layers.append(output)
+
+        # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
+        with tf.variable_scope("decoder_1"):
+            input = tf.concat_v2([layers[-1], layers[0]], axis=3)
+            rectified = tf.nn.relu(input)
+            output = deconv(rectified, generator_outputs_channels)
+            output = tf.tanh(output)
+            layers.append(output)
+
+        return layers[-1]
+
+    def create_discriminator(discrim_inputs, discrim_targets):
+        n_layers = 3
+        layers = []
+
+        # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2]
+        input = tf.concat_v2([discrim_inputs, discrim_targets], axis=3)
+
+        # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
+        with tf.variable_scope("layer_1"):
+            convolved = conv(input, a.ndf, stride=2)
+            rectified = lrelu(convolved, 0.2)
+            layers.append(rectified)
+
+        # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
+        # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
+        # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
+        for i in range(n_layers):
+            with tf.variable_scope("layer_%d" % (len(layers) + 1)):
+                out_channels = a.ndf * min(2**(i+1), 8)
+                stride = 1 if i == n_layers - 1 else 2  # last layer here has stride 1
+                convolved = conv(layers[-1], out_channels, stride=stride)
+                normalized = batchnorm(convolved)
+                rectified = lrelu(normalized, 0.2)
+                layers.append(rectified)
+
+        # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
+        with tf.variable_scope("layer_%d" % (len(layers) + 1)):
+            convolved = conv(rectified, out_channels=1, stride=1)
+            output = tf.sigmoid(convolved)
+            layers.append(output)
+
+        return layers[-1]
+
+    with tf.variable_scope("generator") as scope:
+        out_channels = int(targets.get_shape()[-1])
+        outputs = create_generator(inputs, out_channels)
+
+    # create two copies of discriminator, one for real pairs and one for fake pairs
+    # they share the same underlying variables
+    with tf.name_scope("real_discriminator"):
+        with tf.variable_scope("discriminator"):
+            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
+            predict_real = create_discriminator(inputs, targets)
+
+    with tf.name_scope("fake_discriminator"):
+        with tf.variable_scope("discriminator", reuse=True):
+            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
+            predict_fake = create_discriminator(inputs, outputs)
+
+    with tf.name_scope("discriminator_loss"):
+        # minimizing -tf.log will try to get inputs to 1
+        # predict_real => 1
+        # predict_fake => 0
+        discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))
+
+    with tf.name_scope("generator_loss"):
+        # predict_fake => 1
+        # abs(targets - outputs) => 0
+        gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
+        gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
+        gen_loss = gen_loss_GAN * a.gan_weight + gen_loss_L1 * a.l1_weight
+
+    with tf.name_scope("discriminator_train"):
+        discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
+        discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
+        discrim_train = discrim_optim.minimize(discrim_loss, var_list=discrim_tvars)
+
+    with tf.name_scope("generator_train"):
+        with tf.control_dependencies([discrim_train]):
+            gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
+            gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
+            gen_train = gen_optim.minimize(gen_loss, var_list=gen_tvars)
+
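+    # track exponential moving averages of the losses; the Model tuple below reports the averaged values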
+    ema = tf.train.ExponentialMovingAverage(decay=0.99)
+    update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])
+
+    global_step = tf.contrib.framework.get_or_create_global_step()
+    incr_global_step = tf.assign(global_step, global_step+1)
+
+    return Model(
+        predict_real=predict_real,
+        predict_fake=predict_fake,
+        discrim_loss=ema.average(discrim_loss),
+        gen_loss_GAN=ema.average(gen_loss_GAN),
+        gen_loss_L1=ema.average(gen_loss_L1),
+        outputs=outputs,
+        train=tf.group(update_losses, incr_global_step, gen_train),
+    )
+
+
+def save_images(fetches, image_dir, step=None):
+    filesets = []
+    for i, in_path in enumerate(fetches["paths"]):
+        name, _ = os.path.splitext(os.path.basename(in_path))
+        fileset = {"name": name, "step": step}
+        for kind in ["inputs", "outputs", "targets"]:
+            filename = name + "-" + kind + ".png"
+            if step is not None:
+                filename = "%08d-%s" % (step, filename)
+            fileset[kind] = filename
+            out_path = os.path.join(image_dir, filename)
+            contents = fetches[kind][i]
+            with open(out_path, "w") as f:
+                f.write(contents)
+        filesets.append(fileset)
+    return filesets
+
+
+def append_index(filesets, step=False):
+    index_path = os.path.join(a.output_dir, "index.html")
+    if os.path.exists(index_path):
+        index = open(index_path, "a")
+    else:
+        index = open(index_path, "w")
+        index.write("<html><body><table><tr>")
+        if step:
+            index.write("<th>step</th>")
+        index.write("<th>name</th><th>input</th><th>output</th><th>target</th></tr>")
+
+    for fileset in filesets:
+        index.write("<tr>")
+
+        if step:
+            index.write("<td>%d</td>" % fileset["step"])
+        index.write("<td>%s</td>" % fileset["name"])
+
+        for kind in ["inputs", "outputs", "targets"]:
+            index.write("<td><img src='images/%s'></td>" % fileset[kind])
+
+        index.write("</tr>")
+    return index_path
+
+
+def main():
+    if a.seed is None:
+        a.seed = random.randint(0, 2**31 - 1)
+
+    tf.set_random_seed(a.seed)
+    np.random.seed(a.seed)
+    random.seed(a.seed)
+
+    if not os.path.exists(a.output_dir):
+        os.makedirs(a.output_dir)
+
+    if a.mode == "test":
+        if a.checkpoint is None:
+            raise Exception("checkpoint required for test mode")
+
+        # load some options from the checkpoint
+        options = {"which_direction", "ngf", "ndf", "lab_colorization"}
+        with open(os.path.join(a.checkpoint, "options.json")) as f:
+            for key, val in json.loads(f.read()).iteritems():
+                if key in options:
+                    print("loaded", key, "=", val)
+                    setattr(a, key, val)
+        # disable these features in test mode
+        a.scale_size = CROP_SIZE
+        a.flip = False
+
+    for k, v in a._get_kwargs():
+        print(k, "=", v)
+
+    with open(os.path.join(a.output_dir, "options.json"), "w") as f:
+        f.write(json.dumps(vars(a), sort_keys=True, indent=4))
+
+    examples = load_examples()
+
+    print("examples count = %d" % examples.count)
+
+    model = create_model(examples.inputs, examples.targets)
+
+    def deprocess(image):
+        if a.aspect_ratio != 1.0:
+            # upscale to correct aspect ratio
+            size = [CROP_SIZE, int(round(CROP_SIZE * a.aspect_ratio))]
+            image = tf.image.resize_images(image, size=size, method=tf.image.ResizeMethod.BICUBIC)
+
+        if a.lab_colorization:
+            # colorization mode images can be 1 channel (L) or 2 channels (a,b)
+            num_channels = int(image.get_shape()[-1])
+            if num_channels == 1:
+                return tf.image.convert_image_dtype((image + 1) / 2, dtype=tf.uint8, saturate=True)
+            elif num_channels == 2:
+                # (a, b) color channels, convert to rgb
+                # a_chan and b_chan have range [-1, 1] => [-110, 110]
+                a_chan, b_chan = tf.unstack(image * 110, axis=3)
+                # get L_chan from inputs or targets
+                if a.which_direction == "AtoB":
+                    brightness = examples.inputs
+                elif a.which_direction == "BtoA":
+                    brightness = examples.targets
+                else:
+                    raise Exception("invalid direction")
+                # L_chan has range [-1, 1] => [0, 100]
+                L_chan = tf.squeeze((brightness + 1) / 2 * 100, axis=3)
+                lab = tf.stack([L_chan, a_chan, b_chan], axis=3)
+                rgb = lab_to_rgb(lab)
+                return tf.image.convert_image_dtype(rgb, dtype=tf.uint8, saturate=True)
+            else:
+                raise Exception("unexpected number of channels")
+        else:
+            return tf.image.convert_image_dtype((image + 1) / 2, dtype=tf.uint8, saturate=True)
+
+    # reverse any processing on images so they can be written to disk or displayed to user
+    with tf.name_scope("deprocess_inputs"):
+        deprocessed_inputs = deprocess(examples.inputs)
+
+    with tf.name_scope("deprocess_targets"):
+        deprocessed_targets = deprocess(examples.targets)
+
+    with tf.name_scope("deprocess_outputs"):
+        deprocessed_outputs = deprocess(model.outputs)
+
+    with tf.name_scope("encode_images"):
+        display_fetches = {
+            "paths": examples.paths,
+            "inputs": tf.map_fn(tf.image.encode_png, deprocessed_inputs, dtype=tf.string, name="input_pngs"),
+            "targets": tf.map_fn(tf.image.encode_png, deprocessed_targets, dtype=tf.string, name="target_pngs"),
+            "outputs": tf.map_fn(tf.image.encode_png, deprocessed_outputs, dtype=tf.string, name="output_pngs"),
+        }
+
+    # summaries
+    with tf.name_scope("inputs_summary"):
+        tf.summary.image("inputs", deprocessed_inputs)
+
+    with tf.name_scope("targets_summary"):
+        tf.summary.image("targets", deprocessed_targets)
+
+    with tf.name_scope("outputs_summary"):
+        tf.summary.image("outputs", deprocessed_outputs)
+
+    with tf.name_scope("predict_real_summary"):
+        tf.summary.image("predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))
+
+    with tf.name_scope("predict_fake_summary"):
+        tf.summary.image("predict_fake", tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))
+
+    tf.summary.scalar("discriminator_loss", model.discrim_loss)
+    tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
+    tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)
+
+    with tf.name_scope("parameter_count"):
+        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])
+
+    image_dir = os.path.join(a.output_dir, "images")
+    if not os.path.exists(image_dir):
+        os.makedirs(image_dir)
+
+    saver = tf.train.Saver(max_to_keep=1)
+
+    logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None
+    sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None)
+    with sv.managed_session() as sess:
+        print("parameter_count =", sess.run(parameter_count))
+
+        if a.checkpoint is not None:
+            print("loading model from checkpoint")
+            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
+            saver.restore(sess, checkpoint)
+
+        if a.mode == "test":
+            # testing
+            # run a single epoch over all input data
+            for step in range(examples.steps_per_epoch):
+                results = sess.run(display_fetches)
+                filesets = save_images(results, image_dir)
+                for i, path in enumerate(results["paths"]):
+                    print(step * a.batch_size + i + 1, "evaluated image", os.path.basename(path))
+                index_path = append_index(filesets)
+
+            print("wrote index at", index_path)
+        else:
+            # training
+            max_steps = 2**32
+            if a.max_epochs is not None:
+                max_steps = examples.steps_per_epoch * a.max_epochs
+            if a.max_steps is not None:
+                max_steps = a.max_steps
+
+            start_time = time.time()
+            for step in range(max_steps):
+                def should(freq):
+                    return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1)
+
+                options = None
+                run_metadata = None
+                if should(a.trace_freq):
+                    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
+                    run_metadata = tf.RunMetadata()
+
+                fetches = {
+                    "train": model.train,
+                    "global_step": sv.global_step,
+                }
+
+                if should(a.progress_freq):
+                    fetches["discrim_loss"] = model.discrim_loss
+                    fetches["gen_loss_GAN"] = model.gen_loss_GAN
+                    fetches["gen_loss_L1"] = model.gen_loss_L1
+
+                if should(a.summary_freq):
+                    fetches["summary"] = sv.summary_op
+
+                if should(a.display_freq):
+                    fetches["display"] = display_fetches
+
+                results = sess.run(fetches, options=options, run_metadata=run_metadata)
+
+                if should(a.summary_freq):
+                    sv.summary_writer.add_summary(results["summary"], results["global_step"])
+
+                if should(a.display_freq):
+                    print("saving display images")
+                    filesets = save_images(results["display"], image_dir, step=results["global_step"])
+                    append_index(filesets, step=True)
+
+                if should(a.trace_freq):
+                    print("recording trace")
+                    sv.summary_writer.add_run_metadata(run_metadata, "step_%d" % results["global_step"])
+
+                if should(a.progress_freq):
+                    global_step = results["global_step"]
+                    print("progress  epoch %d  step %d  image/sec %0.1f" % (global_step // examples.steps_per_epoch, global_step % examples.steps_per_epoch, global_step * a.batch_size / (time.time() - start_time)))
+                    print("discrim_loss", results["discrim_loss"])
+                    print("gen_loss_GAN", results["gen_loss_GAN"])
+                    print("gen_loss_L1", results["gen_loss_L1"])
+
+                if should(a.save_freq):
+                    print("saving model")
+                    saver.save(sess, os.path.join(a.output_dir, "model"), global_step=sv.global_step)
+
+                if sv.should_stop():
+                    break
+
+
+main()
diff --git a/tools/download-dataset.py b/tools/download-dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e90c5cf867032e6cdd0bbc81f0b5d384e2156e3
--- /dev/null
+++ b/tools/download-dataset.py
@@ -0,0 +1,21 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import urllib2
+import sys
+import tarfile
+import tempfile
+import shutil
+
+dataset = sys.argv[1]
+url = "https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/%s.tar.gz" % dataset
+with tempfile.TemporaryFile() as tmp:
+    print("downloading", url)
+    shutil.copyfileobj(urllib2.urlopen(url), tmp)
+    print("extracting")
+    tmp.seek(0)
+    tar = tarfile.open(fileobj=tmp)
+    tar.extractall()
+    tar.close()
+    print("done")
diff --git a/tools/process.py b/tools/process.py
new file mode 100644
index 0000000000000000000000000000000000000000..a111e6411aa050ac455d81ae01dd68c724eb13da
--- /dev/null
+++ b/tools/process.py
@@ -0,0 +1,246 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import random
+import tensorflow as tf
+import numpy as np
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--input_dir", required=True, help="path to folder containing images")
+parser.add_argument("--output_dir", required=True, help="output path")
+parser.add_argument("--operation", required=True, choices=["grayscale", "resize", "blank", "combine"])
+parser.add_argument("--pad", action="store_true", help="pad instead of crop for resize operation")
+parser.add_argument("--size", type=int, default=256, help="size to use for resize operation")
+parser.add_argument("--b_dir", type=str, help="path to folder containing B images for combine operation")
+a = parser.parse_args()
+
+
+def grayscale(img):
+    img = img / 255
+    img = 0.299 * img[:,:,0] + 0.587 * img[:,:,1] + 0.114 * img[:,:,2]
+    return (np.expand_dims(img, axis=2) * 255).astype(np.uint8)
+
+
+def normalize(img):
+    img -= img.min()
+    img /= img.max()
+    return img
+
+
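+# build a tensorflow op from placeholders once and return a python function that
+# evaluates it in the default session with the supplied keyword arguments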
+def create_op(func, **placeholders):
+    op = func(**placeholders)
+
+    def f(**kwargs):
+        feed_dict = {}
+        for argname, argvalue in kwargs.iteritems():
+            placeholder = placeholders[argname]
+            feed_dict[placeholder] = argvalue
+        return op.eval(feed_dict=feed_dict)
+
+    return f
+
+downscale = create_op(
+    func=tf.image.resize_images,
+    images=tf.placeholder(tf.float32, [None, None, None]),
+    size=tf.placeholder(tf.int32, [2]),
+    method=tf.image.ResizeMethod.AREA,
+)
+
+upscale = create_op(
+    func=tf.image.resize_images,
+    images=tf.placeholder(tf.float32, [None, None, None]),
+    size=tf.placeholder(tf.int32, [2]),
+    method=tf.image.ResizeMethod.BICUBIC,
+)
+
+decode_jpeg = create_op(
+    func=tf.image.decode_jpeg,
+    contents=tf.placeholder(tf.string),
+)
+
+decode_png = create_op(
+    func=tf.image.decode_png,
+    contents=tf.placeholder(tf.string),
+)
+
+rgb_to_grayscale = create_op(
+    func=tf.image.rgb_to_grayscale,
+    images=tf.placeholder(tf.float32),
+)
+
+grayscale_to_rgb = create_op(
+    func=tf.image.grayscale_to_rgb,
+    images=tf.placeholder(tf.float32),
+)
+
+encode_jpeg = create_op(
+    func=tf.image.encode_jpeg,
+    image=tf.placeholder(tf.uint8),
+)
+
+encode_png = create_op(
+    func=tf.image.encode_png,
+    image=tf.placeholder(tf.uint8),
+)
+
+crop = create_op(
+    func=tf.image.crop_to_bounding_box,
+    image=tf.placeholder(tf.float32),
+    offset_height=tf.placeholder(tf.int32, []),
+    offset_width=tf.placeholder(tf.int32, []),
+    target_height=tf.placeholder(tf.int32, []),
+    target_width=tf.placeholder(tf.int32, []),
+)
+
+pad = create_op(
+    func=tf.image.pad_to_bounding_box,
+    image=tf.placeholder(tf.float32),
+    offset_height=tf.placeholder(tf.int32, []),
+    offset_width=tf.placeholder(tf.int32, []),
+    target_height=tf.placeholder(tf.int32, []),
+    target_width=tf.placeholder(tf.int32, []),
+)
+
+to_uint8 = create_op(
+    func=tf.image.convert_image_dtype,
+    image=tf.placeholder(tf.float32),
+    dtype=tf.uint8,
+    saturate=True,
+)
+
+to_float32 = create_op(
+    func=tf.image.convert_image_dtype,
+    image=tf.placeholder(tf.uint8),
+    dtype=tf.float32,
+)
+
+
+def load(path):
+    contents = open(path).read()
+    _, ext = os.path.splitext(path.lower())
+
+    if ext == ".jpg":
+        image = decode_jpeg(contents=contents)
+    elif ext == ".png":
+        image = decode_png(contents=contents)
+    else:
+        raise Exception("invalid image suffix")
+
+    return to_float32(image=image)
+
+
+def find(d):
+    result = []
+    for filename in os.listdir(d):
+        _, ext = os.path.splitext(filename.lower())
+        if ext == ".jpg" or ext == ".png":
+            result.append(os.path.join(d, filename))
+    result.sort()
+    return result
+
+
+def save(image, path):
+    _, ext = os.path.splitext(path.lower())
+    image = to_uint8(image=image)
+    if ext == ".jpg":
+        encoded = encode_jpeg(image=image)
+    elif ext == ".png":
+        encoded = encode_png(image=image)
+    else:
+        raise Exception("invalid image suffix")
+
+    if os.path.exists(path):
+        raise Exception("file already exists at " + path)
+
+    with open(path, "w") as f:
+        f.write(encoded)
+
+
+def png_path(path):
+    basename, _ = os.path.splitext(os.path.basename(path))
+    return os.path.join(os.path.dirname(path), basename + ".png")
+
+
+def main():
+    random.seed(0)
+
+    if not os.path.exists(a.output_dir):
+        os.makedirs(a.output_dir)
+
+    with tf.Session() as sess:
+        for src_path in find(a.input_dir):
+            dst_path = png_path(os.path.join(a.output_dir, os.path.basename(src_path)))
+            print(src_path, "->", dst_path)
+            src = load(src_path)
+
+            if a.operation == "grayscale":
+                dst = grayscale_to_rgb(images=rgb_to_grayscale(images=src))
+            elif a.operation == "resize":
+                height, width, _ = src.shape
+                dst = src
+                if height != width:
+                    if a.pad:
+                        size = max(height, width)
+                        # pad to correct ratio
+                        oh = (size - height) // 2
+                        ow = (size - width) // 2
+                        dst = pad(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size)
+                    else:
+                        # crop to correct ratio
+                        size = min(height, width)
+                        oh = (height - size) // 2
+                        ow = (width - size) // 2
+                        dst = crop(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size)
+
+                assert(dst.shape[0] == dst.shape[1])
+
+                size, _, _ = dst.shape
+                if size > a.size:
+                    dst = downscale(images=dst, size=[a.size, a.size])
+                elif size < a.size:
+                    dst = upscale(images=dst, size=[a.size, a.size])
+            elif a.operation == "blank":
+                height, width, _ = src.shape
+                if height != width:
+                    raise Exception("non-square image")
+
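+                # blank out a centered square covering 30% of the image, filling it with white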
+                image_size = width
+                size = int(image_size * 0.3)
+                offset = int(image_size / 2 - size / 2)
+
+                dst = src
+                dst[offset:offset + size,offset:offset + size,:] = np.ones([size, size, 3])
+            elif a.operation == "combine":
+                if a.b_dir is None:
+                    raise Exception("missing b_dir")
+
+                # find corresponding file in b_dir, could have a different extension
+                basename, _ = os.path.splitext(os.path.basename(src_path))
+                for ext in [".png", ".jpg"]:
+                    sibling_path = os.path.join(a.b_dir, basename + ext)
+                    if os.path.exists(sibling_path):
+                        sibling = load(sibling_path)
+                        break
+                else:
+                    raise Exception("could not find sibling image for " + src_path)
+
+                # make sure that dimensions are correct
+                height, width, _ = src.shape
+                if height != sibling.shape[0] or width != sibling.shape[1]:
+                    raise Exception("differing sizes")
+
+                # remove alpha channel
+                src = src[:,:,:3]
+                sibling = sibling[:,:,:3]
+                dst = np.concatenate([src, sibling], axis=1)
+            else:
+                raise Exception("invalid operation")
+
+            save(dst, dst_path)
+
+
+main()
diff --git a/tools/split.py b/tools/split.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef93f2cc68d009354cda2dd8d91d1b6e2d81d890
--- /dev/null
+++ b/tools/split.py
@@ -0,0 +1,40 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import argparse
+import glob
+import os
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--dir", type=str, required=True, help="path to folder containing images")
+parser.add_argument("--train_frac", type=float, default=0.8, help="percentage of images to use for training set")
+parser.add_argument("--test_frac", type=float, default=0.0, help="percentage of images to use for test set")
+a = parser.parse_args()
+
+
+def main():
+    random.seed(0)
+
+    files = glob.glob(os.path.join(a.dir, "*.png"))
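+    # send train_frac of the files to train and test_frac to test; the remainder goes to val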
+    assignments = []
+    assignments.extend(["train"] * int(a.train_frac * len(files)))
+    assignments.extend(["test"] * int(a.test_frac * len(files)))
+    assignments.extend(["val"] * int(len(files) - len(assignments)))
+    random.shuffle(assignments)
+
+    for name in ["train", "val", "test"]:
+        if name in assignments:
+            d = os.path.join(a.dir, name)
+            if not os.path.exists(d):
+                os.makedirs(d)
+
+    print(len(files), len(assignments))
+    for inpath, assignment in zip(files, assignments):
+        outpath = os.path.join(a.dir, assignment, os.path.basename(inpath))
+        print(inpath, "->", outpath)
+        os.rename(inpath, outpath)
+
+main()