diff --git a/.gitignore b/.gitignore index 181a5519b9ab844dc28fd43e9cefdbaf12c1e505..b899f89b51f1efee5dd6faae0b675a525813d55a 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ models terraform.tfstate terraform.tfstate.backup terraform.tfvars +.idea diff --git a/pix2pix.py b/pix2pix.py index fbcbace82c38d616a273d25d3259158c890743db..20ba819340c414b14ba753047919744e14a777f4 100644 --- a/pix2pix.py +++ b/pix2pix.py @@ -581,13 +581,17 @@ def main(): input = tf.placeholder(tf.string, shape=[1]) input_data = tf.decode_base64(input[0]) input_image = tf.image.decode_png(input_data) + # remove alpha channel if present - input_image = input_image[:,:,:3] + input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 4), lambda: input_image[:,:,:3], lambda: input_image) + # convert grayscale to RGB + input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 1), lambda: tf.image.grayscale_to_rgb(input_image), lambda: input_image) + input_image = tf.image.convert_image_dtype(input_image, dtype=tf.float32) input_image.set_shape([CROP_SIZE, CROP_SIZE, 3]) batch_input = tf.expand_dims(input_image, axis=0) - with tf.variable_scope("generator") as scope: + with tf.variable_scope("generator"): batch_output = deprocess(create_generator(preprocess(batch_input), 3)) output_image = tf.image.convert_image_dtype(batch_output, dtype=tf.uint8)[0] diff --git a/server/README.md b/server/README.md index 0804c90a026dfdfbc34a866af9590b936d95203d..f5fac9be602501ece825ff5be1a19c45dc635c15 100644 --- a/server/README.md +++ b/server/README.md @@ -88,3 +88,107 @@ cp terraform.tfvars.example terraform.tfvars python ../tools/dockrun.py terraform plan python ../tools/dockrun.py terraform apply ``` + +## Full training + exporting + hosting commands + +Tested with Python 3.6, Tensorflow 1.0.0, Docker, gcloud, and Terraform (https://www.terraform.io/downloads.html) + +```sh +git clone https://github.com/affinelayer/pix2pix-tensorflow.git +cd pix2pix-tensorflow + +# get some images (only 2 for testing) 
+mkdir source +curl -o source/cat1.jpg https://farm5.staticflickr.com/4032/4394955222_eea73818d9_o.jpg +curl -o source/cat2.jpg http://wallpapercave.com/wp/ePMeSmp.jpg + +# resize source images +python tools/process.py \ + --input_dir source \ + --operation resize \ + --output_dir resized + +# create edges from resized images (uses docker container since compiling the dependencies is annoying) +python tools/dockrun.py python tools/process.py \ + --input_dir resized \ + --operation edges \ + --output_dir edges + +# combine resized with edges +python tools/process.py \ + --input_dir edges \ + --b_dir resized \ + --operation combine \ + --output_dir combined + +# train on images (only 1 epoch for testing) +python pix2pix.py \ + --mode train \ + --output_dir train \ + --max_epochs 1 \ + --input_dir combined \ + --which_direction AtoB + +# export model (creates a version of the model that works with the server in server/serve.py as well as google hosted tensorflow) +python pix2pix.py \ + --mode export \ + --output_dir server/models/edges2cats_AtoB \ + --checkpoint train + +# process image locally using exported model +python server/tools/process-local.py \ + --model_dir server/models/edges2cats_AtoB \ + --input_file edges/cat1.png \ + --output_file output.png + +# serve model locally +cd server +python serve.py --port 8000 --local_models_dir models + +# open http://localhost:8000 in a browser, and scroll to the bottom, you should be able to process an edges2cat image and get a bunch of noise as output + +# serve model remotely + +export GOOGLE_PROJECT=<project name> + +# build image +# make sure models are in a directory called "models" in the current directory +docker build --rm --tag us.gcr.io/$GOOGLE_PROJECT/pix2pix-server . 
+ +# test image locally +docker run --publish 8000:8000 --rm --name server us.gcr.io/$GOOGLE_PROJECT/pix2pix-server python -u serve.py \ + --port 8000 \ + --local_models_dir models + +# run this while the above server is running +python tools/process-remote.py \ + --input_file static/edges2cats-input.png \ + --url http://localhost:8000/edges2cats_AtoB \ + --output_file output.png + +# publish image to private google container repository +python tools/upload-image.py --project $GOOGLE_PROJECT --version v1 + +# create a google cloud server +cp terraform.tfvars.example terraform.tfvars +# edit terraform.tfvars to put your cloud info in there +# get the service-account.json from the google cloud console +# make sure GCE is enabled on your account as well +python ../tools/dockrun.py terraform plan +python ../tools/dockrun.py terraform apply + +# get name of server +gcloud compute instance-groups list-instances pix2pix-manager +# ssh to server +gcloud compute ssh <name of instance here> +# look at the logs (can take a while to load docker image) +sudo journalctl -f -u pix2pix +# if you have never made an http-server before, apparently you may need this rule +gcloud compute firewall-rules create http-server --allow=tcp:80 --target-tags http-server +# get ip address of load balancer +gcloud compute forwarding-rules list +# open that in the browser, should see the same page you saw locally + +# to destroy the GCP resources, use this +terraform destroy +``` \ No newline at end of file diff --git a/server/serve.py b/server/serve.py index 5ece7a0dfb1a2df8c18d768c28709f61b154b85a..f0241fe93d35868f3d92061a7916ea1163007a32 100644 --- a/server/serve.py +++ b/server/serve.py @@ -3,7 +3,6 @@ from __future__ import division from __future__ import print_function import socket -import urlparse import time import argparse import base64 @@ -99,7 +98,7 @@ class Handler(BaseHTTPRequestHandler): self.send_response(200) self.send_header("Content-Type", "text/html") self.end_headers() - with open("static/index.html") as f: + with 
open("static/index.html", "rb") as f: self.wfile.write(f.read()) return @@ -117,7 +116,7 @@ class Handler(BaseHTTPRequestHandler): else: self.send_header("Content-Type", "application/octet-stream") self.end_headers() - with open("static/" + path) as f: + with open("static/" + path, "rb") as f: self.wfile.write(f.read()) @@ -154,7 +153,7 @@ class Handler(BaseHTTPRequestHandler): variants = models[name] # "cloud" and "local" are the two possible variants - content_len = int(self.headers.getheader("content-length", 0)) + content_len = int(self.headers.get("content-length", "0")) if content_len > 1 * 1024 * 1024: raise Exception("post body too large") input_data = self.rfile.read(content_len) @@ -192,9 +191,9 @@ class Handler(BaseHTTPRequestHandler): raise Exception("too many requests") # add any missing padding - output_b64data += "=" * (-len(output_b64data) % 4) + output_b64data += b"=" * (-len(output_b64data) % 4) output_data = base64.urlsafe_b64decode(output_b64data) - if output_data.startswith("\x89PNG"): + if output_data.startswith(b"\x89PNG"): headers["content-type"] = "image/png" else: headers["content-type"] = "image/jpeg" @@ -207,7 +206,7 @@ class Handler(BaseHTTPRequestHandler): body = "server error" self.send_response(status) - for key, value in headers.iteritems(): + for key, value in headers.items(): self.send_header(key, value) self.end_headers() self.wfile.write(body) @@ -273,7 +272,7 @@ def main(): project_id = a.project else: credentials = oauth2client.service_account.ServiceAccountCredentials.from_json_keyfile_name(a.credentials, scopes) - with open(a.credentials) as f: + with open(a.credentials, "r") as f: project_id = json.loads(f.read())["project_id"] # due to what appears to be a bug, we cannot get the discovery document when specifying an http client diff --git a/tools/dockrun.py b/tools/dockrun.py index 611573783bcef0f62973135034657d49bdb64eb4..b3e8d473cc36a0909ea6b769528af8de221ef329 100644 --- a/tools/dockrun.py +++ b/tools/dockrun.py @@ 
-101,8 +101,6 @@ def main(): "PYTHONUNBUFFERED=x", "--env", "CUDA_CACHE_PATH=/host/tmp/cuda-cache", - "--env", - "HOME=/host" + os.environ["HOME"], ] if a.port is not None: