Commit b66751f0 authored by Pavlo Beylin
Browse files

Implement clamped cat optimization.

parent df1aa540
......@@ -4,6 +4,7 @@ import torch
import cv2
import time
import matplotlib
from torch import optim
import models
from models.common import Detections
......@@ -42,7 +43,7 @@ classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
"keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator",
"book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"]
PATH = "cat_patch1.jpg"
PATH = "cat_patch0.jpg"
PATCH_SIZE = 300
......@@ -75,7 +76,7 @@ def read_image(path):
def extract_bounding_box(patch):
mask = torch.where(patch < 0.1, torch.zeros(patch.shape).cuda(), torch.ones(patch.shape).cuda()).sum(2)
mask = torch.where(patch < 0, torch.zeros(patch.shape).cuda(), torch.ones(patch.shape).cuda()).sum(2)
bb_x1 = torch.nonzero(mask.sum(0))[0]
bb_y1 = torch.nonzero(mask.sum(1))[0]
......@@ -124,6 +125,8 @@ if __name__ == "__main__":
patch = read_image(PATH)
patch.requires_grad = True
optimizer = optim.Adam([patch], lr=0.0001, amsgrad=True)
img_size_x = 640
img_size_y = 480
......@@ -135,21 +138,22 @@ if __name__ == "__main__":
with torch.set_grad_enabled(True):
# with torch.autograd.detect_anomaly():
# resize our captured frame if we need
frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA)
frame = torch.tensor(frame, dtype=torch.float32, requires_grad=True).cuda()
# transform patch (every couple of frames)
if ctr % 100 == 0:
if ctr % 1 == 0:
# print("{} {}".format(float(patch.min()), float(patch.max())))
trans_patch = patch_transformer(patch.cuda(), torch.ones([1, 14, 5]).cuda(), img_size_x, img_size_y,
do_rotate=True, rand_loc=True)
# trans_patch_np = torch.transpose(trans_patch[0][0].T, 0, 1).detach().cpu().numpy()
do_rotate=True, rand_loc=False)
trans_patch = torch.transpose(trans_patch[0][0].T, 0, 1)
# extract bounding box (x1, y1, x2, y2)
bounding_box = extract_bounding_box(trans_patch)
print("True BB: {} {} {} {}".format(int(bounding_box[0]), int(bounding_box[1]), int(bounding_box[2]),
int(bounding_box[3])))
# print("True BB: {} {} {} {}".format(int(bounding_box[0]), int(bounding_box[1]), int(bounding_box[2]),
# int(bounding_box[3])))
# apply patch
frame = patch_applier(frame, trans_patch)
......@@ -157,17 +161,26 @@ if __name__ == "__main__":
# detect object on our frame
results = model.forward_pt(frame)
if ctr % 100 == 0:
debug_preds()
if ctr % 1 == 0:
# debug_preds()
pass
pred_box = get_best_prediction(bounding_box, results, 15) # get cats
if pred_box is not None:
print("P:{}".format(pred_box[-2]))
loss = -1 * pred_box[-2]
# print("P:{}".format(pred_box[-2]))
# loss
loss = -1 * pred_box[-2] # optimize class
# loss = 1 * pred_box[-2] # adversarial
loss.backward(retain_graph=True)
pass
optimizer.step()
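# TODO: general clamping (regardless of min / max values) kills all patch updates. Why? Cloning maybe?
# NOTE(review): likely yes - the assignment below rebinds `patch` to a detached clone, while the optimizer still holds the original tensor, so the new `patch` gets no gradient and is never stepped. Clamping in place under torch.no_grad() (e.g. patch.clamp_(0., 1.)) should avoid this - confirm.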
if patch.max() > 1 or patch.min() < 0:
print("clamped {}".format(time.time()))
patch = torch.clamp(patch.detach().clone(), 0., 1.)
# show us frame with detection
# cv2.imshow("img", results_np.render()[0])
......@@ -183,6 +196,7 @@ if __name__ == "__main__":
print("FPS:", fps / TIME)
fps = 0
start_time = time.time()
# time.sleep(0.3)
cap.release()
cv2.destroyAllWindows()
......@@ -138,7 +138,7 @@ class PatchTransformer(nn.Module):
msk_batch = torch.cuda.FloatTensor(cls_mask.size()).fill_(1) - cls_mask
# Pad patch and mask to image dimensions
mypad = nn.ConstantPad2d((int(pad_x + 0.5), int(pad_x), int(pad_y + 0.5), int(pad_y)), 0)
mypad = nn.ConstantPad2d((int(pad_x + 0.5), int(pad_x), int(pad_y + 0.5), int(pad_y)), -1)
adv_batch = mypad(adv_batch)
msk_batch = mypad(msk_batch)
......@@ -213,7 +213,7 @@ class PatchTransformer(nn.Module):
adv_batch_t = adv_batch_t.view(s[0], s[1], s[2], s[3], s[4])
msk_batch_t = msk_batch_t.view(s[0], s[1], s[2], s[3], s[4])
adv_batch_t = torch.clamp(adv_batch_t, 0.000001, 0.999999)
adv_batch_t = torch.clamp(adv_batch_t, -0.000001, 0.999999)
# img = msk_batch_t[0, 0, :, :, :].detach().cpu()
# img = transforms.ToPILImage()(img)
# img.show()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment