diff --git a/.gitignore b/.gitignore index 1e38629ce5049cbde4d10de60755088bc0143cea..8dd26bad5ca1590f0933ff843a7d607829558aaa 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ +*.pyc tags test diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..4f916f8b94fe67eaa1543ede094cb72290c3c625 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,122 @@ +# docker build --rm --tag affinelayer/pix2pix-tensorflow . +# docker push affinelayer/pix2pix-tensorflow + +FROM nvidia/cuda:8.0-cudnn5-devel-ubuntu16.04 + +WORKDIR /root + +RUN apt-get update + +# caffe +# from https://github.com/BVLC/caffe/blob/master/docker/cpu/Dockerfile +RUN apt-get install -y --no-install-recommends \ + build-essential \ + cmake \ + git \ + wget \ + libatlas-base-dev \ + libboost-all-dev \ + libgflags-dev \ + libgoogle-glog-dev \ + libhdf5-serial-dev \ + libleveldb-dev \ + liblmdb-dev \ + libopencv-dev \ + libprotobuf-dev \ + libsnappy-dev \ + protobuf-compiler \ + python-dev \ + python-numpy \ + python-pip \ + python-setuptools \ + python-scipy + +ENV CAFFE_ROOT=/opt/caffe + +RUN mkdir -p $CAFFE_ROOT && \ + cd $CAFFE_ROOT && \ + git clone --depth 1 https://github.com/s9xie/hed . && \ + git checkout 9e74dd710773d8d8a469ad905c76f4a7fa08f945 && \ + pip install --upgrade pip && \ + cd python && for req in $(cat requirements.txt) pydot; do pip install $req; done && cd .. && \ + # https://github.com/s9xie/hed/pull/23 + sed -i "s|add_subdirectory(examples)||g" CMakeLists.txt && \ + # https://github.com/s9xie/hed/issues/11 + sed -i "647s|//||" include/caffe/loss_layers.hpp && \ + sed -i "648s|//||" include/caffe/loss_layers.hpp && \ + mkdir build && cd build && \ + cmake -DCPU_ONLY=1 .. && \ + make -j"$(nproc)" + +ENV PYCAFFE_ROOT $CAFFE_ROOT/python +ENV PYTHONPATH $PYCAFFE_ROOT:$PYTHONPATH +ENV PATH $CAFFE_ROOT/build/tools:$PYCAFFE_ROOT:$PATH +RUN echo "$CAFFE_ROOT/build/lib" >> /etc/ld.so.conf.d/caffe.conf && ldconfig + +RUN cd $CAFFE_ROOT && curl -O http://vcl.ucsd.edu/hed/hed_pretrained_bsds.caffemodel + +# octave +RUN apt-get install -y --no-install-recommends octave liboctave-dev && \ + octave --eval "pkg install -forge image" && \ + echo "pkg load image;" >> /root/.octaverc + +RUN apt-get install -y --no-install-recommends unzip && \ + curl -O https://pdollar.github.io/toolbox/archive/piotr_toolbox.zip && \ + unzip piotr_toolbox.zip && \ + octave --eval "addpath(genpath('/root/toolbox')); savepath;" && \ + echo "#include <stdlib.h>" > wrappers.hpp && \ + cat /root/toolbox/channels/private/wrappers.hpp >> wrappers.hpp && \ + mv wrappers.hpp /root/toolbox/channels/private/wrappers.hpp && \ + mkdir /root/mex && \ + cd /root/toolbox/channels/private && \ + mkoctfile --mex -DMATLAB_MEX_FILE -o /root/mex/convConst.mex convConst.cpp && \ + mkoctfile --mex -DMATLAB_MEX_FILE -o /root/mex/gradientMex.mex gradientMex.cpp && \ + mkoctfile --mex -DMATLAB_MEX_FILE -o /root/mex/imPadMex.mex imPadMex.cpp && \ + mkoctfile --mex -DMATLAB_MEX_FILE -o /root/mex/imResampleMex.mex imResampleMex.cpp && \ + mkoctfile --mex -DMATLAB_MEX_FILE -o /root/mex/rgbConvertMex.mex rgbConvertMex.cpp && \ + octave --eval "addpath('/root/mex'); savepath;" + +RUN curl -O https://raw.githubusercontent.com/pdollar/edges/master/private/edgesNmsMex.cpp && \ + octave --eval "mex edgesNmsMex.cpp" && \ + mv edgesNmsMex.mex /root/mex/ + +# from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.gpu +RUN apt-get install -y --no-install-recommends \ + build-essential 
\ + curl \ + libfreetype6-dev \ + libpng12-dev \ + libzmq3-dev \ + pkg-config \ + python \ + python-dev \ + rsync \ + software-properties-common \ + unzip + +# gpu tracing in tensorflow +ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH + +RUN pip install \ + appdirs==1.4.0 \ + funcsigs==1.0.2 \ + google-api-python-client==1.6.2 \ + google-auth==0.7.0 \ + google-auth-httplib2==0.0.2 \ + google-cloud-core==0.22.1 \ + google-cloud-storage==0.22.0 \ + googleapis-common-protos==1.5.2 \ + httplib2==0.10.3 \ + mock==2.0.0 \ + numpy==1.12.0 \ + oauth2client==4.0.0 \ + packaging==16.8 \ + pbr==1.10.0 \ + protobuf==3.2.0 \ + pyasn1==0.2.2 \ + pyasn1-modules==0.0.8 \ + pyparsing==2.1.10 \ + rsa==3.4.2 \ + six==1.10.0 \ + uritemplate==3.0.0 \ + https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp27-none-linux_x86_64.whl diff --git a/pix2pix.py b/pix2pix.py index eddc5902f3db4d8cc02c02f50108bfe404f0c796..5dd82d894199f65974a0bcd4a734dd3549bc7ce9 100644 --- a/pix2pix.py +++ b/pix2pix.py @@ -577,13 +577,23 @@ def main(): if a.lab_colorization: raise Exception("export not supported for lab_colorization") - input = tf.placeholder(tf.float32, shape=[CROP_SIZE, CROP_SIZE, 3], name="input") + input = tf.placeholder(tf.string, shape=[1]) + input_data = tf.decode_base64(input[0]) + input_image = tf.image.decode_png(input_data) + # remove alpha channel if present + input_image = input_image[:,:,:3] + input_image = tf.image.convert_image_dtype(input_image, dtype=tf.float32) + input_image.set_shape([CROP_SIZE, CROP_SIZE, 3]) + batch_input = tf.expand_dims(input_image, axis=0) + with tf.variable_scope("generator") as scope: - outputs = create_generator(tf.expand_dims(preprocess(input), axis=0), 3) + batch_output = deprocess(create_generator(preprocess(batch_input), 3)) - output = deprocess(tf.identity(outputs[0,:,:,:], name="output")) + output_image = tf.image.convert_image_dtype(batch_output, dtype=tf.uint8)[0] + output_data = tf.image.encode_png(output_image) + output = tf.convert_to_tensor([tf.encode_base64(output_data)]) - key = tf.placeholder(tf.string, shape=[None]) + key = tf.placeholder(tf.string, shape=[1]) inputs = { "key": key.name, "input": input.name @@ -725,7 +735,8 @@ def main(): print("wrote index at", index_path) else: # training - start_time = time.time() + start = time.time() + for step in range(max_steps): def should(freq): return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1) @@ -753,25 +764,27 @@ def main(): fetches["display"] = display_fetches results = sess.run(fetches, options=options, run_metadata=run_metadata) - global_step = results["global_step"] if should(a.summary_freq): - sv.summary_writer.add_summary(results["summary"], global_step) + print("recording summary") + sv.summary_writer.add_summary(results["summary"], results["global_step"]) if should(a.display_freq): print("saving display images") - filesets = save_images(results["display"], step=global_step) + filesets = save_images(results["display"], step=results["global_step"]) append_index(filesets, step=True) if should(a.trace_freq): print("recording trace") - sv.summary_writer.add_run_metadata(run_metadata, "step_%d" % global_step) + sv.summary_writer.add_run_metadata(run_metadata, "step_%d" % results["global_step"]) if should(a.progress_freq): # global_step will have the correct step count if we resume from a checkpoint - train_epoch = math.ceil(global_step / examples.steps_per_epoch) - train_step = global_step - (train_epoch - 1) * examples.steps_per_epoch - 
print("progress epoch %d step %d image/sec %0.1f" % (train_epoch, train_step, global_step * a.batch_size / (time.time() - start_time))) + train_epoch = math.ceil(results["global_step"] / examples.steps_per_epoch) + train_step = (results["global_step"] - 1) % examples.steps_per_epoch + 1 + rate = (step + 1) * a.batch_size / (time.time() - start) + remaining = (max_steps - step) * a.batch_size / rate + print("progress epoch %d step %d image/sec %0.1f remaining %dm" % (train_epoch, train_step, rate, remaining / 60)) print("discrim_loss", results["discrim_loss"]) print("gen_loss_GAN", results["gen_loss_GAN"]) print("gen_loss_L1", results["gen_loss_L1"]) diff --git a/server/README.md b/server/README.md new file mode 100644 index 0000000000000000000000000000000000000000..364159ad25986368d41d171959e24876b1bb4bb1 --- /dev/null +++ b/server/README.md @@ -0,0 +1,54 @@ +# pix2pix-tensorflow server + +Host pix2pix-tensorflow models to be used with something like the [Image-to-Image Demo](https://affinelayer.com/pixsrv/). + +This is a simple python server that serves models exported from `pix2pix.py --mode export`. It can serve local models or use [Cloud ML](https://cloud.google.com/ml/) to run the model. + +## Local + +Using the [pix2pix-tensorflow Docker image](https://hub.docker.com/r/affinelayer/pix2pix-tensorflow/): + +```sh +alias p2p-run="sudo docker run --rm --volume /:/host --workdir /host\$PWD --env PYTHONUNBUFFERED=x --env CUDA_CACHE_PATH=/host/tmp/cuda-cache --env HOME=/host\$HOME --publish 8000:8000 affinelayer/pix2pix-tensorflow" + +# export a model to upload +p2p-run python export-example-model.py --output_dir models/example +# process an image with the model using local tensorflow +p2p-run python process-local.py \ + --model_dir models/example \ + --input_file static/facades-input.png \ + --output_file output.png +# run local server +p2p-run python serve.py --local_models_dir models +# test the local server +curl -X POST http://localhost:8000/example \ + --data-binary @static/facades-input.png >! output.png +``` + +If you open [http://localhost:8000/](http://localhost:8000/) in a browser, you should see an interactive demo, though this expects the server to be hosting the exported models available here: + +- [edges2shoes](https://mega.nz/#!HtYwAZTY!5tBLYt_6HFj9u2Kxgp4-I36O4EV9r3bDP44ztX3qesI) +- [edges2handbags](https://mega.nz/#!Clg3EaLA!YW2jfRHvwpJn5Elww_wM-f3eRzKiGHLw-F4A3eQCceI) +- [facades](https://mega.nz/#!f1ZjmZoa!mCSxFRxt1WLBpNFsv5raoroEigxomDVpdi40aOG1KMc) + +Extract those to the models directory and restart the server to have it host the models. + +## Cloud ML + +For this you'll want to generate a service account JSON file from https://console.cloud.google.com/iam-admin/serviceaccounts/project (select "Furnish a new private key"). If you are already logged in with the gcloud SDK, the script will auto-detect credentials from that if you leave off the `--credentials` option. 
+ + ```sh + # upload model to google cloud ml + p2p-run python upload-model.py \ + --bucket your-models-bucket-name-here \ + --model_name example \ + --model_dir models/example \ + --credentials service-account.json + # process an image with the model using google cloud ml + p2p-run python process-cloud.py \ + --model_name example \ + --input_file static/facades-input.png \ + --output_file output.png \ + --credentials service-account.json + ``` + diff --git a/server/export-example-model.py b/server/export-example-model.py new file mode 100644 index 0000000000000000000000000000000000000000..92997fdda698a734d0ce72cca008fdba565bb94e --- /dev/null +++ b/server/export-example-model.py @@ -0,0 +1,50 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import json +import os +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--output_dir", required=True, help="directory to put exported model in") +a = parser.parse_args() + + +def main(): + if not os.path.exists(a.output_dir): + os.makedirs(a.output_dir) + + input = tf.placeholder(tf.string, shape=[1]) + key = tf.placeholder(tf.string, shape=[1]) + + in_data = tf.decode_base64(input[0]) + img = tf.image.decode_png(in_data) + img = tf.image.rgb_to_grayscale(img) + out_data = tf.image.encode_png(img) + output = tf.convert_to_tensor([tf.encode_base64(out_data)]) + + variable_to_allow_model_saving = tf.Variable(1, dtype=tf.float32) + + inputs = { + "key": key.name, + "input": input.name + } + tf.add_to_collection("inputs", json.dumps(inputs)) + outputs = { + "key": tf.identity(key).name, + "output": output.name, + } + tf.add_to_collection("outputs", json.dumps(outputs)) + + init_op = tf.global_variables_initializer() + with tf.Session() as sess: + sess.run(init_op) + saver = tf.train.Saver() + saver.export_meta_graph(filename=os.path.join(a.output_dir, "export.meta")) + saver.save(sess, os.path.join(a.output_dir, "export"), write_meta_graph=False) + + print("exported example model to %s" % a.output_dir) + +main() diff --git a/server/process-cloud.py b/server/process-cloud.py new file mode 100644 index 0000000000000000000000000000000000000000..4f0c9396c098580335475fb3aa06ff8896d3ad9e --- /dev/null +++ b/server/process-cloud.py @@ -0,0 +1,44 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import json +import base64 +import oauth2client.service_account +import googleapiclient.discovery + + +parser = argparse.ArgumentParser() +parser.add_argument("--model_name", required=True, help="name of Cloud Machine Learning model") +parser.add_argument("--input_file", required=True, help="input PNG image file") +parser.add_argument("--output_file", required=True, help="output PNG image file") +parser.add_argument("--credentials", required=True, help="JSON credentials for a Google Cloud Platform service account") +a = parser.parse_args() + +scopes = ["https://www.googleapis.com/auth/cloud-platform"] +credentials = oauth2client.service_account.ServiceAccountCredentials.from_json_keyfile_name(a.credentials, scopes) +ml = googleapiclient.discovery.build("ml", "v1beta1", credentials=credentials) + + +def main(): + with open(a.credentials) as f: + project_id = json.loads(f.read())["project_id"] + + with open(a.input_file) as f: + input_data = f.read() + + input_instance = dict(input=base64.urlsafe_b64encode(input_data), key="0") + input_instance = 
json.loads(json.dumps(input_instance)) + request = ml.projects().predict(name="projects/" + project_id + "/models/" + a.model_name, body={"instances": [input_instance]}) + response = request.execute() + output_instance = json.loads(json.dumps(response["predictions"][0])) + + b64data = output_instance["output"].encode("ascii") + b64data += "=" * (-len(b64data) % 4) + output_data = base64.urlsafe_b64decode(b64data) + + with open(a.output_file, "w") as f: + f.write(output_data) + +main() \ No newline at end of file diff --git a/server/process-local.py b/server/process-local.py new file mode 100644 index 0000000000000000000000000000000000000000..2034ea670213eb906c64e20bbeabcaaad42a7567 --- /dev/null +++ b/server/process-local.py @@ -0,0 +1,45 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import numpy as np +import argparse +import json +import base64 + + +parser = argparse.ArgumentParser() +parser.add_argument("--model_dir", required=True, help="directory containing exported model") +parser.add_argument("--input_file", required=True, help="input PNG image file") +parser.add_argument("--output_file", required=True, help="output PNG image file") +a = parser.parse_args() + +def main(): + with open(a.input_file) as f: + input_data = f.read() + + input_instance = dict(input=base64.urlsafe_b64encode(input_data), key="0") + input_instance = json.loads(json.dumps(input_instance)) + + with tf.Session() as sess: + saver = tf.train.import_meta_graph(a.model_dir + "/export.meta") + saver.restore(sess, a.model_dir + "/export") + input_vars = json.loads(tf.get_collection("inputs")[0]) + output_vars = json.loads(tf.get_collection("outputs")[0]) + input = tf.get_default_graph().get_tensor_by_name(input_vars["input"]) + output = tf.get_default_graph().get_tensor_by_name(output_vars["output"]) + + input_value = np.array(input_instance["input"]) + output_value = sess.run(output, feed_dict={input: np.expand_dims(input_value, axis=0)})[0] + + output_instance = dict(output=output_value, key="0") + + b64data = output_instance["output"].encode("ascii") + b64data += "=" * (-len(b64data) % 4) + output_data = base64.urlsafe_b64decode(b64data) + + with open(a.output_file, "w") as f: + f.write(output_data) + +main() \ No newline at end of file diff --git a/server/serve.py b/server/serve.py new file mode 100644 index 0000000000000000000000000000000000000000..7347416676a496403d211d17b5f96e5daa694bf3 --- /dev/null +++ b/server/serve.py @@ -0,0 +1,192 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import socket +import urlparse +import time +import argparse +import base64 +import os +import json +import traceback + +# https://github.com/Nakiami/MultithreadedSimpleHTTPServer/blob/master/MultithreadedSimpleHTTPServer.py +try: + # Python 2 + from SocketServer import ThreadingMixIn + from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler +except ImportError: + # Python 3 + from socketserver import ThreadingMixIn + from http.server import HTTPServer, BaseHTTPRequestHandler + +socket.setdefaulttimeout(30) + +parser = argparse.ArgumentParser() +parser.add_argument("--local_models_dir", help="directory containing local models to serve (either this or --cloud_model_names must be specified)") +parser.add_argument("--cloud_model_names", help="comma separated list of cloud models to serve (either this or --local_models_dir must be specified)") 
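+# --addr/--port configure the listening socket; --credentials/--project are only consulted when serving cloud models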
+parser.add_argument("--addr", default="", help="address to listen on") +parser.add_argument("--port", default=8000, type=int, help="port to listen on") +parser.add_argument("--credentials", help="JSON credentials for a Google Cloud Platform service account, generate this at https://console.cloud.google.com/iam-admin/serviceaccounts/project (select \"Furnish a new private key\")") +parser.add_argument("--project", help="Google Cloud Project to use, only necessary if using default application credentials") +a = parser.parse_args() + + +models = {} +local = a.local_models_dir is not None +ml = None +project_id = None + + +class Handler(BaseHTTPRequestHandler): + def do_GET(self): + if not os.path.exists("static"): + self.send_response(404) + return + + if self.path == "/": + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.end_headers() + with open("static/index.html") as f: + self.wfile.write(f.read()) + return + + filenames = [name for name in os.listdir("static") if not name.startswith(".")] + path = self.path[1:] + if path not in filenames: + self.send_response(404) + return + + self.send_response(200) + if path.endswith(".png"): + self.send_header("Content-Type", "image/png") + elif path.endswith(".jpg"): + self.send_header("Content-Type", "image/jpeg") + else: + self.send_header("Content-Type", "application/octet-stream") + self.end_headers() + with open("static/" + path) as f: + self.wfile.write(f.read()) + + + def do_OPTIONS(self): + self.send_response(200) + if "origin" in self.headers: + self.send_header("access-control-allow-origin", "*") + + allow_headers = self.headers.get("access-control-request-headers", "*") + self.send_header("access-control-allow-headers", allow_headers) + self.send_header("access-control-allow-methods", "POST, OPTIONS") + self.send_header("access-control-max-age", "3600") + self.end_headers() + + + def do_POST(self): + start = time.time() + + status = 200 + headers = {} + body = "" + + try: + name = self.path[1:] + if name not in models: + raise Exception("invalid model") + + content_len = int(self.headers.getheader("content-length", 0)) + if content_len > 1 * 1024 * 1024: + raise Exception("post body too large") + input_data = self.rfile.read(content_len) + input_b64data = base64.urlsafe_b64encode(input_data) + + if local: + m = models[name] + output_b64data = m["sess"].run(m["output"], feed_dict={m["input"]: [input_b64data]})[0] + else: + input_instance = dict(input=input_b64data, key="0") + request = ml.projects().predict(name="projects/" + project_id + "/models/" + name, body={"instances": [input_instance]}) + response = request.execute() + output_instance = response["predictions"][0] + output_b64data = output_instance["output"].encode("ascii") + + # add any missing padding + output_b64data += "=" * (-len(output_b64data) % 4) + output_data = base64.urlsafe_b64decode(output_b64data) + headers["content-type"] = "image/png" + body = output_data + except Exception as e: + print("exception", traceback.format_exc()) + status = 500 + body = "server error" + + self.send_response(status) + if "origin" in self.headers: + self.send_header("access-control-allow-origin", "*") + for key, value in headers.iteritems(): + self.send_header(key, value) + self.end_headers() + self.wfile.write(body) + + print("finished in %0.1fs" % (time.time() - start)) + + +class ThreadedHTTPServer(ThreadingMixIn, HTTPServer): + pass + + +def main(): + if a.local_models_dir is not None: + import tensorflow as tf + for name in os.listdir(a.local_models_dir): 
+ if name.startswith("."): + continue + + print("loading model", name) + + with tf.Graph().as_default() as graph: + sess = tf.Session(graph=graph) + saver = tf.train.import_meta_graph(os.path.join(a.local_models_dir, name, "export.meta")) + + saver.restore(sess, os.path.join(a.local_models_dir, name, "export")) + input_vars = json.loads(tf.get_collection("inputs")[0]) + output_vars = json.loads(tf.get_collection("outputs")[0]) + input = graph.get_tensor_by_name(input_vars["input"]) + output = graph.get_tensor_by_name(output_vars["output"]) + + models[name] = dict( + sess=sess, + input=input, + output=output, + ) + elif a.cloud_model_names is not None: + import oauth2client.service_account + import googleapiclient.discovery + for name in a.cloud_model_names.split(","): + models[name] = None + + scopes = ["https://www.googleapis.com/auth/cloud-platform"] + global project_id + if a.credentials is None: + credentials = oauth2client.client.GoogleCredentials.get_application_default() + # use this only to detect the project + import google.cloud.storage + storage = google.cloud.storage.Client() + project_id = storage.project + if a.project is not None: + project_id = a.project + else: + credentials = oauth2client.service_account.ServiceAccountCredentials.from_json_keyfile_name(a.credentials, scopes) + with open(a.credentials) as f: + project_id = json.loads(f.read())["project_id"] + + global ml + ml = googleapiclient.discovery.build("ml", "v1beta1", credentials=credentials) + else: + raise Exception("must specify --local_models_dir or --cloud_model_names") + + print("listening on %s:%s" % (a.addr, a.port)) + ThreadedHTTPServer((a.addr, a.port), Handler).serve_forever() + +main() diff --git a/server/static/edges2cats-input.png b/server/static/edges2cats-input.png new file mode 100644 index 0000000000000000000000000000000000000000..e48918c8523851d6e00d33a785d1cb2709c5447a Binary files /dev/null and b/server/static/edges2cats-input.png differ diff --git a/server/static/edges2cats-output.png b/server/static/edges2cats-output.png new file mode 100644 index 0000000000000000000000000000000000000000..b6049e8b609534706dfc53f92d4b268dc2b35a2a Binary files /dev/null and b/server/static/edges2cats-output.png differ diff --git a/server/static/edges2cats-sheet.jpg b/server/static/edges2cats-sheet.jpg new file mode 100644 index 0000000000000000000000000000000000000000..40b8508d3ccb9349a251af8e1a71d554078e1f89 Binary files /dev/null and b/server/static/edges2cats-sheet.jpg differ diff --git a/server/static/edges2handbags-input.png b/server/static/edges2handbags-input.png new file mode 100644 index 0000000000000000000000000000000000000000..8e342bf48c50a7a1ea940846fc53388acb662499 Binary files /dev/null and b/server/static/edges2handbags-input.png differ diff --git a/server/static/edges2handbags-output.png b/server/static/edges2handbags-output.png new file mode 100644 index 0000000000000000000000000000000000000000..0d4e8330fd1e638d69e4e33c8e66c9ba2882240f Binary files /dev/null and b/server/static/edges2handbags-output.png differ diff --git a/server/static/edges2handbags-sheet.jpg b/server/static/edges2handbags-sheet.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e06ec790abdaa85c1c75fa75fd702aacf2054dad Binary files /dev/null and b/server/static/edges2handbags-sheet.jpg differ diff --git a/server/static/edges2shoes-input.png b/server/static/edges2shoes-input.png new file mode 100644 index 0000000000000000000000000000000000000000..296f5b661d9b83c4aa835074b65cde359aa135c0 Binary files 
/dev/null and b/server/static/edges2shoes-input.png differ diff --git a/server/static/edges2shoes-output.png b/server/static/edges2shoes-output.png new file mode 100644 index 0000000000000000000000000000000000000000..4faa4e59c6d982aae2c2466002dbbe569ed23712 Binary files /dev/null and b/server/static/edges2shoes-output.png differ diff --git a/server/static/edges2shoes-sheet.jpg b/server/static/edges2shoes-sheet.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e535c5ef97d55ec1f0bb2d7732f05d041277ee8f Binary files /dev/null and b/server/static/edges2shoes-sheet.jpg differ diff --git a/server/static/editor.png b/server/static/editor.png new file mode 100644 index 0000000000000000000000000000000000000000..82090f1088ad8ff830d53862ef8dfccacf140413 Binary files /dev/null and b/server/static/editor.png differ diff --git a/server/static/facades-input.png b/server/static/facades-input.png new file mode 100644 index 0000000000000000000000000000000000000000..f405fc4f8b4ad113199cb067d13064fe1550a649 Binary files /dev/null and b/server/static/facades-input.png differ diff --git a/server/static/facades-output.png b/server/static/facades-output.png new file mode 100644 index 0000000000000000000000000000000000000000..03e092b132cd20ac044b2b8262ee7e18fa484f85 Binary files /dev/null and b/server/static/facades-output.png differ diff --git a/server/static/facades-sheet.jpg b/server/static/facades-sheet.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e1c4961235fd53a9944112c276953b3535cb6755 Binary files /dev/null and b/server/static/facades-sheet.jpg differ diff --git a/server/static/index.html b/server/static/index.html new file mode 100644 index 0000000000000000000000000000000000000000..58a26ba036a935b673a0c5db67a53613205a9cad --- /dev/null +++ b/server/static/index.html @@ -0,0 +1,818 @@ +<html> +<body> + +<div>edges2shoes</div> +<div id="edges2shoes"></div> + +<div>facades</div> +<div id="facades"></div> + +<div>edges2handbags</div> +<div id="edges2handbags"></div> + +<div>edges2cats</div> +<div id="edges2cats"></div> + +<script> + +var editor_background = new Image() +editor_background.src = "editor.png" + +var SIZE = 256 + +var editors = [] +var request_in_progress = false +var last_request_failed = false +var base_url = "" // this will cause it to talk to the server of this file + +function main() { + var create_editor = function(config) { + var editor = new Editor(config) + var elem = document.getElementById(config.name) + elem.appendChild(editor.view.ctx.canvas) + editor.view.ctx.canvas.onselectstart = function(e) { console.log("selectstart2"); + e.preventDefault(); return false; } + editors.push(editor) + } + + create_editor({ + name: "edges2shoes", + generate_url: base_url + "/edges2shoes_AtoB", + mode: "line", + clear: "#FFFFFF", + colors: { + line: "#000000", + eraser: "#ffffff", + }, + draw: "#000000", + initial_input: "/edges2shoes-input.png", + initial_output: "/edges2shoes-output.png", + sheet_url: "/edges2shoes-sheet.jpg", + }) + + create_editor({ + name: "edges2handbags", + generate_url: base_url + "/edges2handbags_AtoB", + mode: "line", + clear: "#FFFFFF", + colors: { + line: "#000000", + eraser: "#ffffff", + }, + draw: "#000000", + initial_input: "/edges2handbags-input.png", + initial_output: "/edges2handbags-output.png", + sheet_url: "/edges2handbags-sheet.jpg", + }) + + create_editor({ + name: "edges2cats", + generate_url: base_url + "/edges2cats_AtoB", + mode: "line", + clear: "#FFFFFF", + colors: { + line: "#000000", + eraser: "#ffffff", + 
}, + draw: "#000000", + initial_input: "/edges2cats-input.png", + initial_output: "/edges2cats-output.png", + sheet_url: "/edges2cats-sheet.jpg", + }) + + create_editor({ + name: "facades", + generate_url: base_url + "/facades_BtoA", + mode: "rect", + colors: { + background: "#0006d9", + wall: "#0d3dfb", + door: "#a50000", + "window": "#0075ff", + "window sill": "#68f898", + "window head": "#1dffdd", + "shutter": "#eeed28", + balcony: "#b8ff38", + trim: "#ff9204", + cornice: "#ff4401", + column: "#f60001", + entrance: "#00c9ff", + }, + clear: "#0d3dfb", + draw: "#0075ff", + initial_input: "/facades-input.png", + initial_output: "/facades-output.png", + sheet_url: "/facades-sheet.jpg", + }) + + init() +} +window.onload = main + +function render() { + for (var i = 0; i < editors.length; i++) { + editors[i].render() + } +} + +// editor + +function Editor(config) { + this.config = config + this.view = new View(this.config.name, 800, 400) + + this.buffers = [] + + this.buffer = createContext(SIZE, SIZE, SCALE) + this.buffer.fillStyle = this.config.clear + this.buffer.fillRect(0, 0, SIZE, SIZE) + + var image = new Image() + image.src = this.config.initial_input + image.onload = () => { + this.buffer.drawImage(image, 0, 0) + } + + this.output = createContext(SIZE, SIZE, 1) + var output = new Image() + output.src = this.config.initial_output + output.onload = () => { + this.output.drawImage(output, 0, 0) + } + + this.sheet_loaded = false + this.sheet = new Image() + this.sheet.src = this.config.sheet_url + this.sheet.onload = () => { + this.sheet_loaded = true + update() + } + this.sheet_index = 0 +} + +Editor.prototype = { + push_buffer: function() { + this.buffers.push(this.buffer) + var buffer = createContext(SIZE, SIZE, SCALE) + buffer.save() + buffer.scale(1/SCALE, 1/SCALE) + buffer.drawImage(this.buffer.canvas, 0, 0) + buffer.restore() + this.buffer = buffer + }, + pop_buffer: function() { + if (this.buffers.length == 0) { + return + } + this.buffer = this.buffers.pop() + }, +render: function() { + var v = this.view + + v.ctx.clearRect(0, 0, v.f.width, v.f.height) + v.ctx.save() + v.ctx.scale(1/SCALE, 1/SCALE) + v.ctx.drawImage(editor_background, 0, 0) + v.ctx.restore() + + v.frame("tools", 8, 41, 100, 250, () => { + var i = 0 + for (var name in this.config.colors) { + var color = this.config.colors[name] + v.frame("color_selector", 0, i*21, v.f.width, 20, () => { + if (v.contains(mouse_pos)) { + cursor_style = "pointer" + } + + if (mouse_released && v.contains(mouse_pos)) { + this.config.draw = color + update() + } + + if (this.config.draw == color) { + v.ctx.save() + var radius = 5 + v.ctx.beginPath() + v.ctx.moveTo(radius, 0) + var sides = [v.f.width, v.f.height, v.f.width, v.f.height] + for (var i = 0; i < sides.length; i++) { + var side = sides[i] + v.ctx.lineTo(side - radius, 0) + v.ctx.arcTo(side, 0, side, radius, radius) + v.ctx.translate(side, 0) + v.ctx.rotate(90 / 180 * Math.PI) + } + v.ctx.fillStyle = rgba([0.5, 0.5, 0.5, 1.0]) + v.ctx.stroke() + v.ctx.restore() + v.ctx.font = "bold 8pt Arial" + } else { + v.ctx.font = "8pt Arial" + } + + v.ctx.fillText(name, v.f.width - v.ctx.measureText(name).width - 26, 10) + + v.frame("color", v.f.width-25, 0, 20, 20, () => { + v.ctx.beginPath() + v.ctx.fillStyle = "#666666" + v.ctx.arc(10, 10, 9, 0, 2 * Math.PI, false) + v.ctx.fill() + v.ctx.beginPath() + v.ctx.fillStyle = color + v.ctx.arc(10, 10, 8, 0, 2 * Math.PI, false) + v.ctx.fill() + }) + }) + i++ + } + }) + + v.frame("output", 530, 40, 256, 256, () => { + 
v.ctx.drawImage(this.output.canvas, 0, 0) + }) + + v.frame("input", 140, 40, 256, 256+40, () => { + v.frame("image", 0, 0, 256, 256, () => { + v.ctx.drawImage(this.buffer.canvas, 0, 0, v.f.width, v.f.height) + + if (v.contains(mouse_pos)) { + cursor_style = "crosshair" + if (this.config.mode == "line" && this.config.draw == "#ffffff") { + // eraser tool + cursor_style = "url(/eraser.png) 8 8, auto" + } + } + + if (this.config.mode == "line") { + // this is to make undo work with lines, rather than removing only single frame line segments + var drag_from_outside = mouse_down && v.contains(mouse_pos) && !v.contains(last_mouse_pos) + var start_inside = mouse_pressed && v.contains(mouse_pos) + if (drag_from_outside || start_inside) { + this.push_buffer() + } + + if (mouse_down && v.contains(mouse_pos)) { + var last = v.relative(last_mouse_pos) + var cur = v.relative(mouse_pos) + this.buffer.beginPath() + this.buffer.lineCap = "round" + this.buffer.strokeStyle = this.config.draw + if (this.config.draw == "#ffffff") { + // eraser mode + this.buffer.lineWidth = 15 + } else { + this.buffer.lineWidth = 1 + } + this.buffer.moveTo(last.x, last.y) + this.buffer.lineTo(cur.x, cur.y) + this.buffer.stroke() + this.buffer.closePath() + } + } else { + if (v.contains(drag_start)) { + var start = v.relative(drag_start) + var end = v.relative(mouse_pos) + var width = end.x - start.x + var height = end.y - start.y + if (mouse_down) { + v.ctx.save() + v.ctx.rect(0, 0, v.f.width, v.f.height) + v.ctx.clip(); + v.ctx.fillStyle = this.config.draw + v.ctx.fillRect(start.x, start.y, width, height) + v.ctx.restore() + } else if (mouse_released) { + this.push_buffer() + this.buffer.fillStyle = this.config.draw + this.buffer.fillRect(start.x, start.y, width, height) + v.ctx.drawImage(this.buffer.canvas, 0, 0, v.f.width, v.f.height) + } + } + } + }) + }) + + v.frame("process_button", 461 - 32, 148, 32*2, 40, () => { + if (request_in_progress) { + do_button(v, "...") + } else { + if (do_button(v, "process")) { + if (request_in_progress) { + console.log("request already in progress") + return + } + + last_request_failed = false + var convert = createContext(SIZE, SIZE, 1) + convert.drawImage(this.buffer.canvas, 0, 0, convert.canvas.width, convert.canvas.height) + var input_b64 = convert.canvas.toDataURL("image/png").replace(/^data:image\/png;base64,/, "") + var xhr = new XMLHttpRequest() + xhr.open("POST", this.config.generate_url, true) + xhr.setRequestHeader("Content-Type", "image/png") + xhr.responseType = "arraybuffer" + xhr.timeout = 45000 + + xhr.onreadystatechange = () => { + if (xhr.readyState == 4) { + request_in_progress = false + update() + if (xhr.status == 200) { + console.log("request complete", xhr.status) + var output_bin = new Uint8Array(xhr.response) + var output_b64 = bin_to_b64(output_bin) + var output = new Image() + output.src = "data:image\/png;base64," + output_b64 + output.onload = () => { + // browsers besides chrome need to wait for the image to load + this.output.drawImage(output, 0, 0) + update() + } + } else { + last_request_failed = true + } + } + } + request_in_progress = true + update() + xhr.send(b64_to_bin(input_b64)) + } + } + }) + + v.frame("undo_button", 192-32, 310, 64, 40, () => { + if (do_button(v, "undo")) { + this.pop_buffer() + update() + } + }) + + v.frame("clear_button", 270-32, 310, 64, 40, () => { + if (do_button(v, "clear")) { + this.buffers = [] + this.buffer.fillStyle = this.config.clear + this.buffer.fillRect(0, 0, SIZE, SIZE) + this.output.fillStyle = "#FFFFFF" + 
this.output.fillRect(0, 0, SIZE, SIZE) + } + }) + + if (this.sheet_loaded) { + v.frame("random_button", 347-32, 310, 64, 40, () => { + if (do_button(v, "random")) { + // pick next sheet entry + this.buffers = [] + var y_offset = this.sheet_index * SIZE + this.buffer.drawImage(this.sheet, 0, y_offset, SIZE, SIZE, 0, 0, SIZE, SIZE) + this.output.drawImage(this.sheet, SIZE, y_offset, SIZE, SIZE, 0, 0, SIZE, SIZE) + this.sheet_index = (this.sheet_index + 1) % (this.sheet.height / SIZE) + update() + } + }) + } + + v.frame("save_button", 655-32, 310, 64, 40, () => { + if (do_button(v, "save")) { + // create a canvas to hold the part of the canvas that we wish to store + var x = 125 * SCALE + var y = 0 + var width = 800 * SCALE - x + var height = 310 * SCALE - y + var convert = createContext(width, height, 1) + convert.drawImage(v.ctx.canvas, x, y, width, height, 0, 0, convert.canvas.width, convert.canvas.height) + var data_b64 = convert.canvas.toDataURL("image/png").replace(/^data:image\/png;base64,/, "") + var data = b64_to_bin(data_b64) + var blob = new Blob([data], {type: "application/octet-stream"}) + var url = window.URL.createObjectURL(blob) + var a = document.createElement("a") + a.href = url + a.download = "pix2pix.png" + // use createEvent instead of .click() to work in firefox + // also can't revoke the object url because firefox breaks + var event = document.createEvent("MouseEvents") + event.initEvent("click", true, true) + a.dispatchEvent(event) + // safari doesn't work at all + } + }) + + if (last_request_failed) { + v.frame("server_error", 50, 350, v.f.width, 50, () => { + v.ctx.font = "20px Arial" + v.ctx.fillStyle = "red" + v.center_text("error connecting to server, try again later") + }) + } + }, +} + +// utility + +function createContext(width, height, scale) { + var canvas = document.createElement("canvas") + canvas.width = width * scale + canvas.height = height * scale + stylize(canvas, { + width: fmt("%dpx", width), + height: fmt("%dpx", height), + margin: "10px auto 10px auto", + }) + var ctx = canvas.getContext("2d") + ctx.scale(scale, scale) + return ctx +} + +function b64_to_bin(str) { + var binstr = atob(str) + var bin = new Uint8Array(binstr.length) + for (var i = 0; i < binstr.length; i++) { + bin[i] = binstr.charCodeAt(i) + } + return bin +} + +function bin_to_b64(bin) { + var parts = [] + for (var i = 0; i < bin.length; i++) { + parts.push(String.fromCharCode(bin[i])) + } + var binstr = parts.join("") + return btoa(binstr) +} + + +// immediate mode + +var SCALE = 2 + +var updated = true +var frame_rate = 0 +var now = new Date() +var last_frame = new Date() +var animations = {} +var values = {} + +var cursor_style = null +var mouse_pos = [0, 0] +var last_mouse_pos = [0, 0] +var drag_start = [0, 0] +var mouse_down = false +var mouse_pressed = false +var mouse_released = false + +function View(name, width, height) { + this.ctx = createContext(width, height, SCALE) + // https://developer.apple.com/library/safari/documentation/AudioVideo/Conceptual/HTML-canvas-guide/AddingText/AddingText.html + this.ctx.textBaseline = "middle" + this.frames = [{name: name, offset_x: 0, offset_y: 0, width: width, height: height}] + this.f = this.frames[0] +} + +View.prototype = { + push_frame: function(name, x, y, width, height) { + this.ctx.save() + this.ctx.translate(x, y) + var current = this.frames[this.frames.length - 1] + var next = {name: name, offset_x: current.offset_x + x, offset_y: current.offset_y + y, width: width, height: height} + this.frames.push(next) + this.f = next + 
}, + pop_frame: function() { + this.ctx.restore() + this.frames.pop() + this.f = this.frames[this.frames.length - 1] + }, + frame: function(name, x, y, width, height, func) { + this.push_frame(name, x, y, width, height) + func() + this.pop_frame() + }, + frame_path: function() { + var parts = [] + for (var i = 0; i < this.frames.length; i++) { + parts.push(this.frames[i].name) + } + return parts.join(".") + }, + relative: function(pos) { + // adjust x and y relative to the top left corner of the canvas + // then adjust relative to the current frame + var rect = this.ctx.canvas.getBoundingClientRect() + return {x: pos.x - rect.left - this.f.offset_x, y: pos.y - rect.top - this.f.offset_y} + }, + contains: function(pos) { + // first check that position is inside canvas container + var rect = this.ctx.canvas.getBoundingClientRect() + if (pos.x < rect.left || pos.x > rect.left + rect.width || pos.y < rect.top || pos.y > rect.top + rect.height) { + return false + } + // translate coordinates to the current frame + var rel = this.relative(pos) + return 0 < rel.x && rel.x < this.f.width && 0 < rel.y && rel.y < this.f.height + }, + put_image_data: function(d, x, y) { + this.ctx.putImageData(d, (x + this.f.offset_x) * SCALE, (y + this.f.offset_y) * SCALE) + }, + center_text: function(s) { + this.ctx.fillText(s, (this.f.width - this.ctx.measureText(s).width)/2, this.f.height/2) + }, +} + +function do_button(v, text) { + name = v.frame_path() + + if (v.contains(mouse_pos)) { + cursor_style = "pointer" + } + + if (request_in_progress) { + animate(name, parse_color("#aaaaaaFF"), 100) + } else if (mouse_down && v.contains(mouse_pos)) { + animate(name, parse_color("#FF0000FF"), 50) + } else { + if (v.contains(mouse_pos)) { + animate(name, parse_color("#f477a5FF"), 100) + } else { + animate(name, parse_color("#f92672FF"), 100) + } + } + + v.ctx.save() + var radius = 5 + v.ctx.beginPath() + v.ctx.moveTo(radius, 0) + var sides = [v.f.width, v.f.height, v.f.width, v.f.height] + for (var i = 0; i < sides.length; i++) { + var side = sides[i] + v.ctx.lineTo(side - radius, 0) + v.ctx.arcTo(side, 0, side, radius, radius) + v.ctx.translate(side, 0) + v.ctx.rotate(90 / 180 * Math.PI) + } + v.ctx.fillStyle = rgba(calculate(name)) + v.ctx.fill() + v.ctx.restore() + + v.ctx.font = "16px Arial" + v.ctx.fillStyle = "#f8f8f8" + v.center_text(text) + + if (request_in_progress) { + return false + } + + return mouse_released && v.contains(mouse_pos) && v.contains(drag_start) +} + +function stylize(elem, style) { + for (var key in style) { + elem.style[key] = style[key] + } +} + +function update() { + updated = true +} + +function frame() { + var raf = window.requestAnimationFrame(frame) + + if (!updated && Object.keys(animations).length == 0) { + return + } + + now = new Date() + cursor_style = null + updated = false + + try { + render() + } catch (e) { + window.cancelAnimationFrame(raf) + throw e + } + + if (cursor_style == null) { + document.body.style.cursor = "default" + } else { + document.body.style.cursor = cursor_style + } + + last_frame = now + last_mouse_pos = mouse_pos + mouse_pressed = false + mouse_released = false +} + +function array_equal(a, b) { + if (a.length != b.length) { + return false + } + + for (var i = 0; i < a.length; i++) { + if (a[i] != b[i]) { + return false + } + } + return true +} + +function animate(name, end, duration) { + if (values[name] == undefined) { + // no value has been set for this element, set it immediately + values[name] = end + return + } + + var v = calculate(name) + if 
(array_equal(v, end)) { + return + } + if (duration == 0) { + delete animations[name] + values[name] = end + return + } + var a = animations[name] + if (a != undefined && array_equal(a.end, end)) { + return + } + animations[name] = {time: now, start: v, end: end, duration: duration} +} + +function calculate(name) { + if (values[name] == undefined) { + throw "calculate used before calling animate" + } + + var a = animations[name] + if (a != undefined) { + // update value + var t = Math.min((now - a.time)/a.duration, 1.0) + t = t*t*t*(t*(t*6 - 15) + 10) // smootherstep + var result = [] + for (var i = 0; i < a.start.length; i++) { + result[i] = a.start[i] + (a.end[i] - a.start[i]) * t + } + if (t == 1.0) { + delete animations[name] + } + values[name] = result + } + return values[name] +} + +function rgba(v) { + return fmt("rgba(%d, %d, %d, %f)", v[0] * 255, v[1] * 255, v[2] * 255, v[3]) +} + +var parse_color = function(c) { + return [ + parseInt(c.substr(1,2), 16) / 255, + parseInt(c.substr(3,2), 16) / 255, + parseInt(c.substr(5,2), 16) / 255, + parseInt(c.substr(7,2), 16) / 255, + ] +} + +document.addEventListener("mousemove", function(e) { + mouse_pos = {x: e.clientX, y: e.clientY} + update() +}) + +document.addEventListener("mousedown", function(e) { + drag_start = {x: e.clientX, y: e.clientY} + mouse_down = true + mouse_pressed = true + update() +}) + +document.addEventListener("mouseup", function(e) { + mouse_down = false + mouse_released = true + update() +}) + +function default_format(obj) { + if (typeof(obj) === "string") { + return obj + } else { + return JSON.stringify(obj) + } +} + +function fmt() { + if (arguments.length === 0) { + return "error" + } + + var format = arguments[0] + var output = "" + + var arg_index = 1 + var i = 0 + + while (i < format.length) { + var c = format[i] + i++ + + if (c != "%") { + output += c + continue + } + + if (i === format.length) { + output += "%!(NOVERB)" + break + } + + var flag = format[i] + i++ + + var pad_char = " " + + if (flag == "0") { + pad_char = "0" + } else { + // not a flag + i-- + } + + var width = 0 + while (format[i] >= "0" && format[i] <= "9") { + width *= 10 + width += parseInt(format[i], 10) + i++ + } + + var f = format[i] + i++ + + if (f === "%") { + output += "%" + continue + } + + if (arg_index === arguments.length) { + output += "%!" + f + "(MISSING)" + continue + } + + var arg = arguments[arg_index] + arg_index++ + + var o = null + + if (f === "v") { + o = default_format(arg) + } else if (f === "s" && typeof(arg) === "string") { + o = arg + } else if (f === "T") { + o = typeof(arg) + } else if (f === "d" && typeof(arg) === "number") { + o = arg.toFixed(0) + } else if (f === "f" && typeof(arg) === "number") { + o = arg.toString() + } else if (f === "x" && typeof(arg) === "number") { + o = Math.round(arg).toString(16) + } else if (f === "t" && typeof(arg) === "boolean") { + if (arg) { + o = "true" + } else { + o = "false" + } + } else { + output += "%!" 
+ f + "(" + typeof(arg) + "=" + default_format(arg) + ")" + } + + if (o !== null) { + if (o.length < width) { + output += Array(width - o.length + 1).join(pad_char) + } + output += o + } + } + + if (arg_index < arguments.length) { + output += "%!(EXTRA " + while (arg_index < arguments.length) { + var arg = arguments[arg_index] + output += typeof(arg) + "=" + default_format(arg) + if (arg_index < arguments.length - 1) { + output += ", " + } + arg_index++ + } + output += ")" + } + + return output +} + +function init() { + window.requestAnimationFrame(frame) +} + +</script> + +</body> +</html> diff --git a/server/upload-model.py b/server/upload-model.py new file mode 100644 index 0000000000000000000000000000000000000000..ff10bcac982553e753331c29795fd782ab130be0 --- /dev/null +++ b/server/upload-model.py @@ -0,0 +1,101 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import json +import os +import time +import sys +import base64 +import oauth2client.service_account +import googleapiclient.discovery +import google.cloud.storage + + +parser = argparse.ArgumentParser() +parser.add_argument("--bucket", required=True, help="Google Cloud Storage bucket to upload to") +parser.add_argument("--model_name", required=True, help="name of Google Cloud Machine Learning model to create or update") +parser.add_argument("--model_dir", required=True, help="path to directory containing exported model") +parser.add_argument("--runtime_version", default="0.12", help="tensorflow version to use for the model") +parser.add_argument("--credentials", help="JSON credentials for a Google Cloud Platform service account") +parser.add_argument("--project", help="Google Cloud Project to use to override project detection") +a = parser.parse_args() + +scopes = ["https://www.googleapis.com/auth/cloud-platform"] +if a.credentials is None: + credentials = oauth2client.client.GoogleCredentials.get_application_default() + storage = google.cloud.storage.Client() + project_id = storage.project + if a.project is not None: + project_id = a.project +else: + credentials = oauth2client.service_account.ServiceAccountCredentials.from_json_keyfile_name(a.credentials, scopes) + with open(a.credentials) as f: + project_id = json.loads(f.read())["project_id"] + storage = google.cloud.storage.Client.from_service_account_json(a.credentials, project=project_id) + +ml = googleapiclient.discovery.build("ml", "v1beta1", credentials=credentials) + + +def main(): + try: + bucket = storage.get_bucket(a.bucket) + except google.cloud.exceptions.NotFound as e: + print("creating bucket %s" % a.bucket) + bucket = storage.create_bucket(a.bucket) + + project_path = "projects/%s" % project_id + model_path = "%s/models/%s" % (project_path, a.model_name) + + try: + ml.projects().models().get(name=model_path).execute() + except googleapiclient.errors.HttpError as e: + if e.resp["status"] != "404": + raise + print("creating model %s" % a.model_name) + ml.projects().models().create(parent=project_path, body=dict(name=a.model_name)).execute() + + version_number = 0 + resp = ml.projects().models().versions().list(parent=model_path).execute() + for version in resp.get("versions", []): + name = version["name"] + number = int(name.split("/")[-1][1:]) + if number > version_number: + version_number = number + + version_number += 1 + print("creating version v%d" % version_number) + + for filename in os.listdir(a.model_dir): + if not filename.startswith("export.") and filename != 
"checkpoint": + continue + + print("uploading", filename) + filepath = os.path.join(a.model_dir, filename) + blob = bucket.blob("%s-v%d/%s" % (a.model_name, version_number, filename)) + blob.upload_from_filename(filepath) + + version_path = "%s/versions/v%d" % (model_path, version_number) + version = dict( + name="v%d" % version_number, + runtimeVersion=a.runtime_version, + deploymentUri="gs://%s/%s-v%d/" % (a.bucket, a.model_name, version_number), + ) + operation = ml.projects().models().versions().create(parent=model_path, body=version).execute() + + sys.stdout.write("waiting for creation to finish") + while True: + operation = ml.projects().operations().get(name=operation["name"]).execute() + if "done" in operation and operation["done"]: + break + sys.stdout.write(".") + sys.stdout.flush() + time.sleep(10) + print() + + print("setting version %d as default" % version_number) + ml.projects().models().versions().setDefault(name=version_path, body=dict()).execute() + + +main() \ No newline at end of file diff --git a/tools/process.py b/tools/process.py index ac9aab82007f6e31311974a32baa5e28c1bde706..06ffcffb262a38a763fd08cfe511a7f1d589eb7d 100644 --- a/tools/process.py +++ b/tools/process.py @@ -2,230 +2,292 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + import argparse import os +import tempfile +import subprocess import tensorflow as tf import numpy as np +import tfimage as im +import threading +import time +import multiprocessing parser = argparse.ArgumentParser() parser.add_argument("--input_dir", required=True, help="path to folder containing images") parser.add_argument("--output_dir", required=True, help="output path") -parser.add_argument("--operation", required=True, choices=["grayscale", "resize", "blank", "combine"]) +parser.add_argument("--operation", required=True, choices=["grayscale", "resize", "blank", "combine", "edges"]) +parser.add_argument("--workers", type=int, default=1, help="number of workers") +# resize parser.add_argument("--pad", action="store_true", help="pad instead of crop for resize operation") parser.add_argument("--size", type=int, default=256, help="size to use for resize operation") +# combine parser.add_argument("--b_dir", type=str, help="path to folder containing B images for combine operation") a = parser.parse_args() -def create_op(func, **placeholders): - op = func(**placeholders) - - def f(**kwargs): - feed_dict = {} - for argname, argvalue in kwargs.iteritems(): - placeholder = placeholders[argname] - feed_dict[placeholder] = argvalue - return op.eval(feed_dict=feed_dict) - - return f - -downscale = create_op( - func=tf.image.resize_images, - images=tf.placeholder(tf.float32, [None, None, None]), - size=tf.placeholder(tf.int32, [2]), - method=tf.image.ResizeMethod.AREA, -) - -upscale = create_op( - func=tf.image.resize_images, - images=tf.placeholder(tf.float32, [None, None, None]), - size=tf.placeholder(tf.int32, [2]), - method=tf.image.ResizeMethod.BICUBIC, -) - -decode_jpeg = create_op( - func=tf.image.decode_jpeg, - contents=tf.placeholder(tf.string), -) - -decode_png = create_op( - func=tf.image.decode_png, - contents=tf.placeholder(tf.string), -) - -rgb_to_grayscale = create_op( - func=tf.image.rgb_to_grayscale, - images=tf.placeholder(tf.float32), -) - -grayscale_to_rgb = create_op( - func=tf.image.grayscale_to_rgb, - images=tf.placeholder(tf.float32), -) - -encode_jpeg = create_op( - func=tf.image.encode_jpeg, - image=tf.placeholder(tf.uint8), -) - -encode_png = create_op( - 
func=tf.image.encode_png, - image=tf.placeholder(tf.uint8), -) - -crop = create_op( - func=tf.image.crop_to_bounding_box, - image=tf.placeholder(tf.float32), - offset_height=tf.placeholder(tf.int32, []), - offset_width=tf.placeholder(tf.int32, []), - target_height=tf.placeholder(tf.int32, []), - target_width=tf.placeholder(tf.int32, []), -) - -pad = create_op( - func=tf.image.pad_to_bounding_box, - image=tf.placeholder(tf.float32), - offset_height=tf.placeholder(tf.int32, []), - offset_width=tf.placeholder(tf.int32, []), - target_height=tf.placeholder(tf.int32, []), - target_width=tf.placeholder(tf.int32, []), -) - -to_uint8 = create_op( - func=tf.image.convert_image_dtype, - image=tf.placeholder(tf.float32), - dtype=tf.uint8, - saturate=True, -) - -to_float32 = create_op( - func=tf.image.convert_image_dtype, - image=tf.placeholder(tf.uint8), - dtype=tf.float32, -) - - -def load(path): - contents = open(path).read() - _, ext = os.path.splitext(path.lower()) - - if ext == ".jpg": - image = decode_jpeg(contents=contents) - elif ext == ".png": - image = decode_png(contents=contents) +def resize(src): + height, width, _ = src.shape + dst = src + if height != width: + if a.pad: + size = max(height, width) + # pad to correct ratio + oh = (size - height) // 2 + ow = (size - width) // 2 + dst = im.pad(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size) + else: + # crop to correct ratio + size = min(height, width) + oh = (height - size) // 2 + ow = (width - size) // 2 + dst = im.crop(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size) + + assert(dst.shape[0] == dst.shape[1]) + + size, _, _ = dst.shape + if size > a.size: + dst = im.downscale(images=dst, size=[a.size, a.size]) + elif size < a.size: + dst = im.upscale(images=dst, size=[a.size, a.size]) + return dst + + +def blank(src): + height, width, _ = src.shape + if height != width: + raise Exception("non-square image") + + image_size = width + size = int(image_size * 0.3) + offset = int(image_size / 2 - size / 2) + + dst = src + dst[offset:offset + size,offset:offset + size,:] = np.ones([size, size, 3]) + return dst + + +def combine(src, src_path): + if a.b_dir is None: + raise Exception("missing b_dir") + + # find corresponding file in b_dir, could have a different extension + basename, _ = os.path.splitext(os.path.basename(src_path)) + for ext in [".png", ".jpg"]: + sibling_path = os.path.join(a.b_dir, basename + ext) + if os.path.exists(sibling_path): + sibling = im.load(sibling_path) + break else: - raise Exception("invalid image suffix") - - return to_float32(image=image) - + raise Exception("could not find sibling image for " + src_path) + + # make sure that dimensions are correct + height, width, _ = src.shape + if height != sibling.shape[0] or width != sibling.shape[1]: + raise Exception("differing sizes") + + # convert both images to RGB if necessary + if src.shape[2] == 1: + src = im.grayscale_to_rgb(images=src) + + if sibling.shape[2] == 1: + sibling = im.grayscale_to_rgb(images=sibling) + + # remove alpha channel + if src.shape[2] == 4: + src = src[:,:,:3] + + if sibling.shape[2] == 4: + sibling = sibling[:,:,:3] + + return np.concatenate([src, sibling], axis=1) + + +def grayscale(src): + return im.grayscale_to_rgb(images=im.rgb_to_grayscale(images=src)) + + +net = None +def run_caffe(src): + # lazy load caffe and create net + global net + if net is None: + # don't require caffe unless we are doing edge detection + os.environ["GLOG_minloglevel"] = "2" # disable 
logging from caffe + import caffe + # using this requires using the docker image or assembling a bunch of dependencies + # and then changing these hardcoded paths + net = caffe.Net("/opt/caffe/examples/hed/deploy.prototxt", "/opt/caffe/hed_pretrained_bsds.caffemodel", caffe.TEST) + + net.blobs["data"].reshape(1, *src.shape) + net.blobs["data"].data[...] = src + net.forward() + return net.blobs["sigmoid-fuse"].data[0][0,:,:] + + +# create the pool before we launch processing threads +# we must create the pool after run_caffe is defined +if a.operation == "edges": + edge_pool = multiprocessing.Pool(a.workers) + +def edges(src): + # based on https://github.com/phillipi/pix2pix/blob/master/scripts/edges/batch_hed.py + # and https://github.com/phillipi/pix2pix/blob/master/scripts/edges/PostprocessHED.m + import scipy.io + src = src * 255 + border = 128 # put a padding around images since edge detection seems to detect edge of image + src = src[:,:,:3] # remove alpha channel if present + src = np.pad(src, ((border, border), (border, border), (0,0)), "reflect") + src = src[:,:,::-1] + src -= np.array((104.00698793,116.66876762,122.67891434)) + src = src.transpose((2, 0, 1)) + + # [height, width, channels] => [batch, channel, height, width] + fuse = edge_pool.apply(run_caffe, [src]) + fuse = fuse[border:-border, border:-border] + + with tempfile.NamedTemporaryFile(suffix=".png") as png_file, tempfile.NamedTemporaryFile(suffix=".mat") as mat_file: + scipy.io.savemat(mat_file.name, {"input": fuse}) + + octave_code = r""" +E = 1-load(input_path).input; +E = imresize(E, [image_width,image_width]); +E = 1 - E; +E = single(E); +[Ox, Oy] = gradient(convTri(E, 4), 1); +[Oxx, ~] = gradient(Ox, 1); +[Oxy, Oyy] = gradient(Oy, 1); +O = mod(atan(Oyy .* sign(-Oxy) ./ (Oxx + 1e-5)), pi); +E = edgesNmsMex(E, O, 1, 5, 1.01, 1); +E = double(E >= max(eps, threshold)); +E = bwmorph(E, 'thin', inf); +E = bwareaopen(E, small_edge); +E = 1 - E; +E = uint8(E * 255); +imwrite(E, output_path); +""" + + config = dict( + input_path="'%s'" % mat_file.name, + output_path="'%s'" % png_file.name, + image_width=256, + threshold=25.0/255.0, + small_edge=5, + ) + + args = ["octave"] + for k, v in config.items(): + args.extend(["--eval", "%s=%s;" % (k, v)]) + + args.extend(["--eval", octave_code]) + try: + subprocess.check_output(args, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + print("octave failed") + print("returncode:", e.returncode) + print("output:", e.output) + raise + return im.load(png_file.name) + + +def process(src_path, dst_path): + src = im.load(src_path) + + if a.operation == "grayscale": + dst = grayscale(src) + elif a.operation == "resize": + dst = resize(src) + elif a.operation == "blank": + dst = blank(src) + elif a.operation == "combine": + dst = combine(src, src_path) + elif a.operation == "edges": + dst = edges(src) + else: + raise Exception("invalid operation") -def find(d): - result = [] - for filename in os.listdir(d): - _, ext = os.path.splitext(filename.lower()) - if ext == ".jpg" or ext == ".png": - result.append(os.path.join(d, filename)) - result.sort() - return result + im.save(dst, dst_path) -def save(image, path): - _, ext = os.path.splitext(path.lower()) - image = to_uint8(image=image) - if ext == ".jpg": - encoded = encode_jpeg(image=image) - elif ext == ".png": - encoded = encode_png(image=image) - else: - raise Exception("invalid image suffix") +complete_lock = threading.Lock() +start = time.time() +num_complete = 0 +total = 0 - if os.path.exists(path): - raise 
Exception("file already exists at " + path) +def complete(): + global num_complete, rate, last_complete - with open(path, "w") as f: - f.write(encoded) + with complete_lock: + num_complete += 1 + now = time.time() + elapsed = now - start + rate = num_complete / elapsed + if rate > 0: + remaining = (total - num_complete) / rate + else: + remaining = 0 + print("%d/%d complete %0.2f images/sec %dm%ds elapsed %dm%ds remaining" % (num_complete, total, rate, elapsed // 60, elapsed % 60, remaining // 60, remaining % 60)) -def png_path(path): - basename, _ = os.path.splitext(os.path.basename(path)) - return os.path.join(os.path.dirname(path), basename + ".png") + last_complete = now def main(): if not os.path.exists(a.output_dir): os.makedirs(a.output_dir) - with tf.Session() as sess: - for src_path in find(a.input_dir): - dst_path = png_path(os.path.join(a.output_dir, os.path.basename(src_path))) - print(src_path, "->", dst_path) - src = load(src_path) - - if a.operation == "grayscale": - dst = grayscale_to_rgb(images=rgb_to_grayscale(images=src)) - elif a.operation == "resize": - height, width, _ = src.shape - dst = src - if height != width: - if a.pad: - size = max(height, width) - # pad to correct ratio - oh = (size - height) // 2 - ow = (size - width) // 2 - dst = pad(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size) - else: - # crop to correct ratio - size = min(height, width) - oh = (height - size) // 2 - ow = (width - size) // 2 - dst = crop(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size) - - assert(dst.shape[0] == dst.shape[1]) - - size, _, _ = dst.shape - if size > a.size: - dst = downscale(images=dst, size=[a.size, a.size]) - elif size < a.size: - dst = upscale(images=dst, size=[a.size, a.size]) - elif a.operation == "blank": - height, width, _ = src.shape - if height != width: - raise Exception("non-square image") - - image_size = width - size = int(image_size * 0.3) - offset = int(image_size / 2 - size / 2) - - dst = src - dst[offset:offset + size,offset:offset + size,:] = np.ones([size, size, 3]) - elif a.operation == "combine": - if a.b_dir is None: - raise Exception("missing b_dir") - - # find corresponding file in b_dir, could have a different extension - basename, _ = os.path.splitext(os.path.basename(src_path)) - for ext in [".png", ".jpg"]: - sibling_path = os.path.join(a.b_dir, basename + ext) - if os.path.exists(sibling_path): - sibling = load(sibling_path) + src_paths = [] + dst_paths = [] + + for src_path in im.find(a.input_dir): + name, _ = os.path.splitext(os.path.basename(src_path)) + dst_path = os.path.join(a.output_dir, name + ".png") + if not os.path.exists(dst_path): + src_paths.append(src_path) + dst_paths.append(dst_path) + + global total + total = len(src_paths) + + if a.workers == 1: + with tf.Session() as sess: + for src_path, dst_path in zip(src_paths, dst_paths): + process(src_path, dst_path) + complete() + else: + queue = tf.train.input_producer(zip(src_paths, dst_paths), shuffle=False, num_epochs=1) + dequeue_op = queue.dequeue() + + def worker(coord): + with sess.as_default(): + while not coord.should_stop(): + try: + src_path, dst_path = sess.run(dequeue_op) + except tf.errors.OutOfRangeError: + coord.request_stop() break - else: - raise Exception("could not find sibling image for " + src_path) - - # make sure that dimensions are correct - height, width, _ = src.shape - if height != sibling.shape[0] or width != sibling.shape[1]: - raise Exception("differing sizes") - - # remove alpha 
-                else:
-                    raise Exception("could not find sibling image for " + src_path)
-
-                # make sure that dimensions are correct
-                height, width, _ = src.shape
-                if height != sibling.shape[0] or width != sibling.shape[1]:
-                    raise Exception("differing sizes")
-
-                # remove alpha channel
-                src = src[:,:,:3]
-                sibling = sibling[:,:,:3]
-                dst = np.concatenate([src, sibling], axis=1)
-            else:
-                raise Exception("invalid operation")
-
-            save(dst, dst_path)
-
-main()
+                    process(src_path, dst_path)
+                    complete()
+
+        # init epoch counter for the queue
+        local_init_op = tf.local_variables_initializer()
+        with tf.Session() as sess:
+            sess.run(local_init_op)
+
+            coord = tf.train.Coordinator()
+            threads = tf.train.start_queue_runners(coord=coord)
+            for i in range(a.workers):
+                t = threading.Thread(target=worker, args=(coord,))
+                t.start()
+                threads.append(t)
+
+            try:
+                coord.join(threads)
+            except KeyboardInterrupt:
+                coord.request_stop()
+                coord.join(threads)
+
+main()
\ No newline at end of file
diff --git a/tools/split.py b/tools/split.py
index ef93f2cc68d009354cda2dd8d91d1b6e2d81d890..04867f18d7e4ef701a469af0473fa151f5360586 100644
--- a/tools/split.py
+++ b/tools/split.py
@@ -12,6 +12,7 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("--dir", type=str, required=True, help="path to folder containing images")
 parser.add_argument("--train_frac", type=float, default=0.8, help="percentage of images to use for training set")
 parser.add_argument("--test_frac", type=float, default=0.0, help="percentage of images to use for test set")
+parser.add_argument("--sort", action="store_true", help="if set, sort the images instead of shuffling them")
 a = parser.parse_args()
 
@@ -19,11 +20,15 @@ def main():
     random.seed(0)
 
     files = glob.glob(os.path.join(a.dir, "*.png"))
+    files.sort()
+
     assignments = []
     assignments.extend(["train"] * int(a.train_frac * len(files)))
     assignments.extend(["test"] * int(a.test_frac * len(files)))
     assignments.extend(["val"] * int(len(files) - len(assignments)))
-    random.shuffle(assignments)
+
+    if not a.sort:
+        random.shuffle(assignments)
 
     for name in ["train", "val", "test"]:
         if name in assignments:
diff --git a/tools/tfimage.py b/tools/tfimage.py
new file mode 100644
index 0000000000000000000000000000000000000000..8618f7acd5b13dfbc9c7852465d7b85b1f103db8
--- /dev/null
+++ b/tools/tfimage.py
@@ -0,0 +1,142 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+import os
+
+
+def create_op(func, **placeholders):
+    op = func(**placeholders)
+
+    def f(**kwargs):
+        feed_dict = {}
+        for argname, argvalue in kwargs.items():
+            placeholder = placeholders[argname]
+            feed_dict[placeholder] = argvalue
+        return tf.get_default_session().run(op, feed_dict=feed_dict)
+
+    return f
+
+downscale = create_op(
+    func=tf.image.resize_images,
+    images=tf.placeholder(tf.float32, [None, None, None]),
+    size=tf.placeholder(tf.int32, [2]),
+    method=tf.image.ResizeMethod.AREA,
+)
+
+upscale = create_op(
+    func=tf.image.resize_images,
+    images=tf.placeholder(tf.float32, [None, None, None]),
+    size=tf.placeholder(tf.int32, [2]),
+    method=tf.image.ResizeMethod.BICUBIC,
+)
+
+decode_jpeg = create_op(
+    func=tf.image.decode_jpeg,
+    contents=tf.placeholder(tf.string),
+)
+
+decode_png = create_op(
+    func=tf.image.decode_png,
+    contents=tf.placeholder(tf.string),
+)
+
+rgb_to_grayscale = create_op(
+    func=tf.image.rgb_to_grayscale,
+    images=tf.placeholder(tf.float32),
+)
+
+grayscale_to_rgb = create_op(
+    func=tf.image.grayscale_to_rgb,
+    images=tf.placeholder(tf.float32),
+)
+
+encode_jpeg = create_op(
+    func=tf.image.encode_jpeg,
+    image=tf.placeholder(tf.uint8),
+)
+
+encode_png = create_op(
+    func=tf.image.encode_png,
+    image=tf.placeholder(tf.uint8),
+)
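+
+# note: the crop/pad wrappers below take scalar int32 offset and target
+# placeholders, mirroring tf.image.crop_to_bounding_box / pad_to_bounding_box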
+
+crop = create_op(
+    func=tf.image.crop_to_bounding_box,
+    image=tf.placeholder(tf.float32),
+    offset_height=tf.placeholder(tf.int32, []),
+    offset_width=tf.placeholder(tf.int32, []),
+    target_height=tf.placeholder(tf.int32, []),
+    target_width=tf.placeholder(tf.int32, []),
+)
+
+pad = create_op(
+    func=tf.image.pad_to_bounding_box,
+    image=tf.placeholder(tf.float32),
+    offset_height=tf.placeholder(tf.int32, []),
+    offset_width=tf.placeholder(tf.int32, []),
+    target_height=tf.placeholder(tf.int32, []),
+    target_width=tf.placeholder(tf.int32, []),
+)
+
+to_uint8 = create_op(
+    func=tf.image.convert_image_dtype,
+    image=tf.placeholder(tf.float32),
+    dtype=tf.uint8,
+    saturate=True,
+)
+
+to_float32 = create_op(
+    func=tf.image.convert_image_dtype,
+    image=tf.placeholder(tf.uint8),
+    dtype=tf.float32,
+)
+
+
+def load(path):
+    contents = open(path, "rb").read()
+    _, ext = os.path.splitext(path.lower())
+
+    if ext == ".jpg":
+        image = decode_jpeg(contents=contents)
+    elif ext == ".png":
+        image = decode_png(contents=contents)
+    else:
+        raise Exception("invalid image suffix")
+
+    return to_float32(image=image)
+
+
+def find(d):
+    result = []
+    for filename in os.listdir(d):
+        _, ext = os.path.splitext(filename.lower())
+        if ext == ".jpg" or ext == ".png":
+            result.append(os.path.join(d, filename))
+    result.sort()
+    return result
+
+
+def save(image, path, replace=False):
+    _, ext = os.path.splitext(path.lower())
+    image = to_uint8(image=image)
+    if ext == ".jpg":
+        encoded = encode_jpeg(image=image)
+    elif ext == ".png":
+        encoded = encode_png(image=image)
+    else:
+        raise Exception("invalid image suffix")
+
+    dirname = os.path.dirname(path)
+    if dirname != "" and not os.path.exists(dirname):
+        os.makedirs(dirname)
+
+    if os.path.exists(path):
+        if replace:
+            os.remove(path)
+        else:
+            raise Exception("file already exists at " + path)
+
+    with open(path, "wb") as f:
+        f.write(encoded)
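+
+# usage sketch (assumes the "import tfimage as im" convention used by
+# tools/process.py): every wrapper above runs against the default session,
+# so callers are expected to hold one open:
+#
+#   with tf.Session():
+#       img = im.load("input.jpg")                     # float32 in [0, 1]
+#       img = im.upscale(images=img, size=[256, 256])
+#       im.save(img, "output.png", replace=True)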