Commit 3d9e90a3 authored by Pavlo Beylin

Add probabilistic catselector.

parent 5f366d50
@@ -3,6 +3,7 @@ import torch
 import cv2
 import time
 import matplotlib
 matplotlib.use('TkAgg')
 import matplotlib.pyplot as plt
@@ -36,18 +37,20 @@ classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
            "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator",
            "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"]

-PATH = "cat_patch0.jpg"
-PATCH_SIZE = 100
+PATH = "cat_patch1.jpg"
+PATCH_SIZE = 300


 def debug_preds():
     detected_classes = [int(results.pred[0][i][-1]) for i in range(0, len(results.pred[0]))]
-    print(detected_classes)
+    # print(detected_classes)

     for det in results.pred[0]:
-        if int(det[-1]) == 0: # person
-            print("Person ({}):".format(float(det[-2])))
-            print("x1:y1 : {}:{}".format(float(det[0]), float(det[1])))
-            print("x2:y2 : {}:{}".format(float(det[2]), float(det[3])))
+        if int(det[-1]) == 15: # cat
+            print("Pred BB: ", end="")
+            # print("x1:y1 : {}:{}".format(float(det[0]), float(det[1])))
+            # print("x2:y2 : {}:{}".format(float(det[2]), float(det[3])))
+            print("{} {} {} {} ({}):".format(
+                int(det[0]), int(det[1]), int(det[2]), int(det[3]), float(det[-2])))


 # from https://github.com/wangzh0ng/adversarial_yolo2
@@ -66,6 +69,31 @@ def read_image(path):
     return tf(patch_img)


+def extract_bounding_box(patch):
+    mask = torch.where(torch.tensor(patch) < 0.1, torch.zeros(patch.shape), torch.ones(patch.shape)).sum(2)
+    bb_x1 = mask.sum(0).nonzero()[0]
+    bb_y1 = mask.sum(1).nonzero()[0]
+    bb_x2 = mask.sum(0).nonzero()[-1]
+    bb_y2 = mask.sum(1).nonzero()[-1]
+    return torch.stack([bb_x1, bb_y1, bb_x2, bb_y2], axis=0).sum(1)
+
+
+def get_best_prediction(true_box, res, cls_nr):
+    min_distance = float("inf")
+    best_prediction = None
+
+    for pred in res.pred[0]:
+        if int(pred[-1]) != cls_nr:
+            continue
+        pred_dist = torch.dist(true_box.cuda(), pred[:4])
+        if pred_dist < min_distance:
+            min_distance = pred_dist
+            best_prediction = pred
+
+    return best_prediction
+
+
 if __name__ == "__main__":
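A quick sanity check of the new helper (an illustration, not part of the commit): extract_bounding_box thresholds the transformed patch at 0.1, sums the surviving channels into a per-pixel mask, and takes the first and last nonzero column and row indices as the box corners. The toy canvas below is hypothetical and assumes extract_bounding_box from the hunk above is in scope; pixel values use the same 0..1 range as the patch.

import numpy as np
import torch

# Hypothetical toy input: a black 100x100 canvas with a bright square
# covering rows 20..39 and columns 30..59.
toy_patch = np.zeros((100, 100, 3), dtype=np.float32)
toy_patch[20:40, 30:60, :] = 1.0

# Expected corners (x1, y1, x2, y2): first/last nonzero column and row.
print(extract_bounding_box(toy_patch))  # tensor([30, 20, 59, 39])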
@@ -73,7 +101,6 @@ if __name__ == "__main__":
     patch_transformer = PatchTransformer().cuda()
     patch_applier = PatchApplier().cuda()

     # set start time to current time
     start_time = time.time()
@@ -94,24 +121,39 @@ if __name__ == "__main__":
     img_size_x = 640
     img_size_y = 480

+    ctr = -1
     while True:
+        ctr += 1
         ret, frame = cap.read()

         # resize our captured frame if we need
         frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA)
         # cv2.imshow("Web cam input", frame)

-        # transform patch
-        trans_patch = patch_transformer(patch.cuda(), torch.ones([1, 14, 5]).cuda(), img_size_x, img_size_y,
-                                        do_rotate=True, rand_loc=True)
-        trans_patch_np = torch.transpose(trans_patch[0][0].T, 0, 1).detach().cpu().numpy()
-        # cv2.imshow("patch", trans_patch_np)
+        # transform patch (every couple of frames)
+        if ctr % 100 == 0:
+            trans_patch = patch_transformer(patch.cuda(), torch.ones([1, 14, 5]).cuda(), img_size_x, img_size_y,
+                                            do_rotate=True, rand_loc=True)
+            trans_patch_np = torch.transpose(trans_patch[0][0].T, 0, 1).detach().cpu().numpy()
+
+            # extract bounding box (x1, y1, x2, y2)
+            bounding_box = extract_bounding_box(trans_patch_np)
+            print("True BB: {} {} {} {}".format(int(bounding_box[0]), int(bounding_box[1]), int(bounding_box[2]),
+                                                int(bounding_box[3])))

         # apply patch
         frame = patch_applier(frame, trans_patch_np)

         # detect object on our frame
         results = model(frame.copy())
-        # debug_preds()
+        if ctr % 100 == 0:
+            # debug_preds()
+            pass
+
+        pred_box = get_best_prediction(bounding_box, results, 15) # get cats
+        if pred_box is not None:
+            print("P:{}".format(pred_box[-2]))

         # show us frame with detection
         cv2.imshow("img", results.render()[0])
...
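The loop above is where the commit's probabilistic cat selection happens: the known bounding box of the pasted patch is compared against every class-15 (cat) prediction, and the prediction whose corners are closest in L2 distance is kept, with its confidence printed as "P:<score>". A minimal sketch of that matching follows, using hypothetical hand-written predictions in YOLOv5's [x1, y1, x2, y2, conf, cls] row layout; it assumes get_best_prediction from the hunk above is in scope and that a CUDA device is available, since the helper moves the true box to the GPU.

import torch

class FakeDetections:  # stand-in for the YOLOv5 results object; only .pred is used
    def __init__(self, rows):
        self.pred = [rows]

true_box = torch.tensor([30., 20., 59., 39.])
rows = torch.tensor([
    [28., 18., 61., 42., 0.91, 15.],      # cat, corners close to the patch box
    [200., 150., 320., 300., 0.88, 15.],  # cat, far away
    [25., 15., 60., 45., 0.95, 0.],       # person, skipped: wrong class
]).cuda()

best = get_best_prediction(true_box, FakeDetections(rows), 15)
print("P:{}".format(float(best[-2])))  # confidence of the closest cat box (0.91 here)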
@@ -16,7 +16,7 @@ class PatchApplier(nn.Module):
         super(PatchApplier, self).__init__()

     def forward(self, img, patch):
-        img = torch.where(torch.tensor(patch < 0.1), torch.tensor(img)/256, torch.tensor(patch))*256
+        img = torch.where(torch.tensor(patch < 1e-05), torch.tensor(img)/256, torch.tensor(patch))*256
         return img.detach().numpy()
...
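The second file only tightens the transparency threshold in PatchApplier.forward: patch pixels below the cutoff are treated as background, so the camera frame (rescaled from 0..255 to 0..1 and back) shows through, while everything else is overwritten by the patch. Lowering the cutoff from 0.1 to 1e-05 keeps dark-but-nonzero patch pixels instead of punching holes in the patch. Below is a minimal standalone sketch of that compositing rule with made-up values, not taken from the commit.

import numpy as np
import torch

frame = np.full((1, 2, 3), 128, dtype=np.float32)  # camera pixels in 0..255
patch = np.array([[[0.0, 0.0, 0.0],                 # exactly zero -> transparent
                   [0.05, 0.05, 0.05]]], dtype=np.float32)  # dark but nonzero

out = torch.where(torch.tensor(patch < 1e-05),
                  torch.tensor(frame) / 256,   # keep (rescaled) frame where patch is empty
                  torch.tensor(patch)) * 256   # otherwise use the patch value
print(out.numpy())  # first pixel stays 128.0; second becomes 12.8, kept under the new
                    # cutoff, whereas the old 0.1 threshold would have dropped it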