Commit b66751f0 authored by Pavlo Beylin
Browse files

Implement clamped cat optimization.

parent df1aa540
...@@ -4,6 +4,7 @@ import torch ...@@ -4,6 +4,7 @@ import torch
import cv2 import cv2
import time import time
import matplotlib import matplotlib
from torch import optim
import models import models
from models.common import Detections from models.common import Detections
...@@ -42,7 +43,7 @@ classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus", ...@@ -42,7 +43,7 @@ classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
"keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator",
"book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"] "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"]
PATH = "cat_patch1.jpg" PATH = "cat_patch0.jpg"
PATCH_SIZE = 300 PATCH_SIZE = 300
...@@ -75,7 +76,7 @@ def read_image(path): ...@@ -75,7 +76,7 @@ def read_image(path):
def extract_bounding_box(patch): def extract_bounding_box(patch):
mask = torch.where(patch < 0.1, torch.zeros(patch.shape).cuda(), torch.ones(patch.shape).cuda()).sum(2) mask = torch.where(patch < 0, torch.zeros(patch.shape).cuda(), torch.ones(patch.shape).cuda()).sum(2)
bb_x1 = torch.nonzero(mask.sum(0))[0] bb_x1 = torch.nonzero(mask.sum(0))[0]
bb_y1 = torch.nonzero(mask.sum(1))[0] bb_y1 = torch.nonzero(mask.sum(1))[0]
...@@ -124,6 +125,8 @@ if __name__ == "__main__": ...@@ -124,6 +125,8 @@ if __name__ == "__main__":
patch = read_image(PATH) patch = read_image(PATH)
patch.requires_grad = True patch.requires_grad = True
optimizer = optim.Adam([patch], lr=0.0001, amsgrad=True)
img_size_x = 640 img_size_x = 640
img_size_y = 480 img_size_y = 480
...@@ -135,21 +138,22 @@ if __name__ == "__main__": ...@@ -135,21 +138,22 @@ if __name__ == "__main__":
with torch.set_grad_enabled(True): with torch.set_grad_enabled(True):
# with torch.autograd.detect_anomaly(): # with torch.autograd.detect_anomaly():
# resize our captured frame if we need # resize our captured frame if we need
frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA) frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA)
frame = torch.tensor(frame, dtype=torch.float32, requires_grad=True).cuda() frame = torch.tensor(frame, dtype=torch.float32, requires_grad=True).cuda()
# transform patch (every couple of frames) # transform patch (every couple of frames)
if ctr % 100 == 0: if ctr % 1 == 0:
# print("{} {}".format(float(patch.min()), float(patch.max())))
trans_patch = patch_transformer(patch.cuda(), torch.ones([1, 14, 5]).cuda(), img_size_x, img_size_y, trans_patch = patch_transformer(patch.cuda(), torch.ones([1, 14, 5]).cuda(), img_size_x, img_size_y,
do_rotate=True, rand_loc=True) do_rotate=True, rand_loc=False)
# trans_patch_np = torch.transpose(trans_patch[0][0].T, 0, 1).detach().cpu().numpy()
trans_patch = torch.transpose(trans_patch[0][0].T, 0, 1) trans_patch = torch.transpose(trans_patch[0][0].T, 0, 1)
# extract bounding box (x1, y1, x2, y2) # extract bounding box (x1, y1, x2, y2)
bounding_box = extract_bounding_box(trans_patch) bounding_box = extract_bounding_box(trans_patch)
print("True BB: {} {} {} {}".format(int(bounding_box[0]), int(bounding_box[1]), int(bounding_box[2]), # print("True BB: {} {} {} {}".format(int(bounding_box[0]), int(bounding_box[1]), int(bounding_box[2]),
int(bounding_box[3]))) # int(bounding_box[3])))
# apply patch # apply patch
frame = patch_applier(frame, trans_patch) frame = patch_applier(frame, trans_patch)
...@@ -157,17 +161,26 @@ if __name__ == "__main__": ...@@ -157,17 +161,26 @@ if __name__ == "__main__":
# detect object on our frame # detect object on our frame
results = model.forward_pt(frame) results = model.forward_pt(frame)
if ctr % 100 == 0: if ctr % 1 == 0:
debug_preds() # debug_preds()
pass pass
pred_box = get_best_prediction(bounding_box, results, 15) # get cats pred_box = get_best_prediction(bounding_box, results, 15) # get cats
if pred_box is not None: if pred_box is not None:
print("P:{}".format(pred_box[-2])) # print("P:{}".format(pred_box[-2]))
loss = -1 * pred_box[-2]
# loss
loss = -1 * pred_box[-2] # optimize class
# loss = 1 * pred_box[-2] # adversarial
loss.backward(retain_graph=True) loss.backward(retain_graph=True)
pass optimizer.step()
# TODO: general clamping (regardless of min / max values) kills all patch updates. Why? Cloning maybe?
if patch.max() > 1 or patch.min() < 0:
print("clamped {}".format(time.time()))
patch = torch.clamp(patch.detach().clone(), 0., 1.)
# show us frame with detection # show us frame with detection
# cv2.imshow("img", results_np.render()[0]) # cv2.imshow("img", results_np.render()[0])
...@@ -183,6 +196,7 @@ if __name__ == "__main__": ...@@ -183,6 +196,7 @@ if __name__ == "__main__":
print("FPS:", fps / TIME) print("FPS:", fps / TIME)
fps = 0 fps = 0
start_time = time.time() start_time = time.time()
# time.sleep(0.3)
cap.release() cap.release()
cv2.destroyAllWindows() cv2.destroyAllWindows()
...@@ -138,7 +138,7 @@ class PatchTransformer(nn.Module): ...@@ -138,7 +138,7 @@ class PatchTransformer(nn.Module):
msk_batch = torch.cuda.FloatTensor(cls_mask.size()).fill_(1) - cls_mask msk_batch = torch.cuda.FloatTensor(cls_mask.size()).fill_(1) - cls_mask
# Pad patch and mask to image dimensions # Pad patch and mask to image dimensions
mypad = nn.ConstantPad2d((int(pad_x + 0.5), int(pad_x), int(pad_y + 0.5), int(pad_y)), 0) mypad = nn.ConstantPad2d((int(pad_x + 0.5), int(pad_x), int(pad_y + 0.5), int(pad_y)), -1)
adv_batch = mypad(adv_batch) adv_batch = mypad(adv_batch)
msk_batch = mypad(msk_batch) msk_batch = mypad(msk_batch)
...@@ -213,7 +213,7 @@ class PatchTransformer(nn.Module): ...@@ -213,7 +213,7 @@ class PatchTransformer(nn.Module):
adv_batch_t = adv_batch_t.view(s[0], s[1], s[2], s[3], s[4]) adv_batch_t = adv_batch_t.view(s[0], s[1], s[2], s[3], s[4])
msk_batch_t = msk_batch_t.view(s[0], s[1], s[2], s[3], s[4]) msk_batch_t = msk_batch_t.view(s[0], s[1], s[2], s[3], s[4])
adv_batch_t = torch.clamp(adv_batch_t, 0.000001, 0.999999) adv_batch_t = torch.clamp(adv_batch_t, -0.000001, 0.999999)
# img = msk_batch_t[0, 0, :, :, :].detach().cpu() # img = msk_batch_t[0, 0, :, :, :].detach().cpu()
# img = transforms.ToPILImage()(img) # img = transforms.ToPILImage()(img)
# img.show() # img.show()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment