Commit 3d9e90a3 authored by Pavlo Beylin

Add probabilistic cat selector.

parent 5f366d50
@@ -3,6 +3,7 @@ import torch
import cv2
import time
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
@@ -36,18 +37,20 @@ classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
"keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator",
"book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"]
PATH = "cat_patch0.jpg"
PATCH_SIZE = 100
PATH = "cat_patch1.jpg"
PATCH_SIZE = 300
def debug_preds():
detected_classes = [int(results.pred[0][i][-1]) for i in range(0, len(results.pred[0]))]
print(detected_classes)
# print(detected_classes)
for det in results.pred[0]:
if int(det[-1]) == 0: # person
print("Person ({}):".format(float(det[-2])))
print("x1:y1 : {}:{}".format(float(det[0]), float(det[1])))
print("x2:y2 : {}:{}".format(float(det[2]), float(det[3])))
if int(det[-1]) == 15: # cat
print("Pred BB: ", end="")
# print("x1:y1 : {}:{}".format(float(det[0]), float(det[1])))
# print("x2:y2 : {}:{}".format(float(det[2]), float(det[3])))
print("{} {} {} {} ({}):".format(
int(det[0]), int(det[1]), int(det[2]), int(det[3]), float(det[-2])))
# from https://github.com/wangzh0ng/adversarial_yolo2
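For orientation (not part of the diff): debug_preds() above and the new get_best_prediction() below both index into results.pred[0], which for a YOLOv5 hub model is an N x 6 tensor whose rows are [x1, y1, x2, y2, confidence, class]. A minimal standalone sketch of that indexing, using a hand-made dummy tensor in place of real model output:

# Illustrative sketch, not part of the commit: the row layout assumed by
# debug_preds() and get_best_prediction(). Each row is
# [x1, y1, x2, y2, confidence, class_index]; class 15 is "cat" in COCO order.
import torch

dummy_pred = torch.tensor([
    [ 10.0,  20.0, 110.0, 220.0, 0.91,  0.0],   # a "person" detection
    [300.0, 150.0, 420.0, 260.0, 0.57, 15.0],   # a "cat" detection
])

for det in dummy_pred:
    cls_nr = int(det[-1])      # last element: class index
    conf = float(det[-2])      # second to last: confidence
    box = det[:4]              # first four: corner coordinates
    if cls_nr == 15:           # cat
        print("cat at {} with confidence {:.2f}".format(box.tolist(), conf))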
@@ -66,6 +69,31 @@ def read_image(path):
return tf(patch_img)
def extract_bounding_box(patch):
mask = torch.where(torch.tensor(patch) < 0.1, torch.zeros(patch.shape), torch.ones(patch.shape)).sum(2)
bb_x1 = mask.sum(0).nonzero()[0]
bb_y1 = mask.sum(1).nonzero()[0]
bb_x2 = mask.sum(0).nonzero()[-1]
bb_y2 = mask.sum(1).nonzero()[-1]
return torch.stack([bb_x1, bb_y1, bb_x2, bb_y2], axis=0).sum(1)
def get_best_prediction(true_box, res, cls_nr):
min_distance = float("inf")
best_prediction = None
for pred in res.pred[0]:
if int(pred[-1]) != cls_nr:
continue
pred_dist = torch.dist(true_box.cuda(), pred[:4])
if pred_dist < min_distance:
min_distance = pred_dist
best_prediction = pred
return best_prediction
if __name__ == "__main__":
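A minimal, CPU-only sketch (not part of the commit) of what extract_bounding_box computes: pixels whose channels are all below 0.1 are treated as background, and the first/last non-empty column and row give the box corners. The patch here is synthetic:

# Illustrative sketch of the nonzero-mask bounding box in extract_bounding_box.
import numpy as np
import torch

patch = np.zeros((480, 640, 3), dtype=np.float32)
patch[100:200, 50:150, :] = 1.0                   # a bright 100x100 square

mask = torch.where(torch.tensor(patch) < 0.1,
                   torch.zeros(patch.shape),
                   torch.ones(patch.shape)).sum(2)  # HxW count of non-dark channels
x1 = int(mask.sum(0).nonzero()[0])    # first non-empty column
y1 = int(mask.sum(1).nonzero()[0])    # first non-empty row
x2 = int(mask.sum(0).nonzero()[-1])   # last non-empty column
y2 = int(mask.sum(1).nonzero()[-1])   # last non-empty row
print(x1, y1, x2, y2)                 # expected: 50 100 149 199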
@@ -73,7 +101,6 @@ if __name__ == "__main__":
patch_transformer = PatchTransformer().cuda()
patch_applier = PatchApplier().cuda()
# set start time to current time
start_time = time.time()
@@ -94,24 +121,39 @@ if __name__ == "__main__":
img_size_x = 640
img_size_y = 480
ctr = -1
while True:
ctr += 1
ret, frame = cap.read()
# resize the captured frame if needed
frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA)
# cv2.imshow("Web cam input", frame)
# transform patch
# transform patch (only every 100th frame)
if ctr % 100 == 0:
trans_patch = patch_transformer(patch.cuda(), torch.ones([1, 14, 5]).cuda(), img_size_x, img_size_y,
do_rotate=True, rand_loc=True)
trans_patch_np = torch.transpose(trans_patch[0][0].T, 0, 1).detach().cpu().numpy()
# cv2.imshow("patch", trans_patch_np)
# extract bounding box (x1, y1, x2, y2)
bounding_box = extract_bounding_box(trans_patch_np)
print("True BB: {} {} {} {}".format(int(bounding_box[0]), int(bounding_box[1]), int(bounding_box[2]),
int(bounding_box[3])))
# apply patch
frame = patch_applier(frame, trans_patch_np)
# run object detection on the frame
results = model(frame.copy())
if ctr % 100 == 0:
# debug_preds()
pass
pred_box = get_best_prediction(bounding_box, results, 15) # get cats
if pred_box is not None:
print("P:{}".format(pred_box[-2]))
# show us frame with detection
cv2.imshow("img", results.render()[0])
@@ -16,7 +16,7 @@ class PatchApplier(nn.Module):
super(PatchApplier, self).__init__()
def forward(self, img, patch):
img = torch.where(torch.tensor(patch < 0.1), torch.tensor(img)/256, torch.tensor(patch))*256
img = torch.where(torch.tensor(patch < 1e-05), torch.tensor(img)/256, torch.tensor(patch))*256
return img.detach().numpy()
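The change shown for this file is the transparency threshold in PatchApplier.forward: patch pixels below the threshold let the original frame show through, everything else is overwritten by the patch, so lowering the cut-off from 0.1 to 1e-05 keeps dark but non-black patch pixels that were previously discarded. A standalone sketch of that torch.where compositing with synthetic data:

# Illustrative sketch of the torch.where compositing in PatchApplier.forward;
# frame and patch here are synthetic, not real webcam/patch data.
import numpy as np
import torch

frame = np.full((480, 640, 3), 128, dtype=np.uint8)   # grey camera frame, 0..255
patch = np.zeros((480, 640, 3), dtype=np.float32)     # patch image, 0..1
patch[100:200, 50:150, :] = 0.02                       # dark, but not black

threshold = 1e-05                                      # was 0.1 before this commit
out = torch.where(torch.tensor(patch < threshold),     # transparent where patch ~ 0
                  torch.tensor(frame, dtype=torch.float32) / 256,
                  torch.tensor(patch)) * 256

print(out[0, 0, 0].item())      # untouched frame pixel -> 128.0
print(out[150, 100, 0].item())  # dark patch pixel survives the 1e-05 cut-off (~5.12)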