Commit 3e896e33 authored by Pavlo Beylin

Add Total Variation loss. Implement smooth key controls.

parent bc08f03d
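
For reference, the total variation term this commit adds (see the new TotalVariation module at the bottom of the diff) is the anisotropic TV of the patch, normalized by the number of patch elements. In the notation of that module, for a patch p with dimensions C x H x W it computes approximately

TV(p) = \frac{1}{C H W} \left( \sum_{c,i,j} \left| p_{c,i,j+1} - p_{c,i,j} + \varepsilon \right| + \sum_{c,i,j} \left| p_{c,i+1,j} - p_{c,i,j} + \varepsilon \right| \right), \quad \varepsilon = 10^{-6},

where the small epsilon corresponds to the 0.000001 added inside the absolute values in the code.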
@@ -8,6 +8,7 @@ from torch import optim
 import models
 from models.common import Detections
+from utils.external import TotalVariation
 from utils.general import scale_coords
 
 matplotlib.use('TkAgg')
@@ -23,11 +24,6 @@ model = torch.hub.load('ultralytics/yolov5', 'yolov5l')  # or yolov5m, yolov5l,
 # model = torch.hub.load('ultralytics/yolov3', 'yolov3')
 
-def show(img):
-    plt.imshow(img.detach().cpu())
-    plt.show()
-
 
 MIN_THRESHOLD = 0.00001
 
 classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
@@ -47,6 +43,11 @@ classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
 PATH = "cat_patch0.jpg"
 PATCH_SIZE = 300
 
+total_variation = TotalVariation()
+
+
+def show(img):
+    plt.imshow(img.detach().cpu())
+    plt.show()
+
 
 def debug_preds():
     detected_classes = [int(results.pred[0][i][-1]) for i in range(0, len(results.pred[0]))]
@@ -87,6 +88,21 @@ def extract_bounding_box(patch):
     return torch.stack([bb_x1, bb_y1, bb_x2, bb_y2]).sum(1)
 
 
+def get_avg_prediction(res, cls_nr):
+    avg_prediction = 0
+    ctr = 0
+    if res is None:
+        return 0
+    for pred in res:
+        if pred[5:].max() > 0.4 or True:
+            ctr += 1
+            avg_prediction += pred[cls_nr + 5]
+    return avg_prediction / (ctr if ctr > 0 else 1)
+
+
 def get_best_prediction(true_box, res, cls_nr):
     min_distance = float("inf")
     best_prediction = None
@@ -96,7 +112,7 @@ def get_best_prediction(true_box, res, cls_nr):
         if pred_dist < min_distance and pred[5:].max() > 0.1:
             min_distance = pred_dist
-            best_prediction = pred[cls_nr+5]
+            best_prediction = pred[cls_nr + 5]
 
     return best_prediction
@@ -133,24 +149,24 @@ if __name__ == "__main__":
     pred = -1
     move = False
     rotate = False
+    transform_interval = 10
 
     while True:
         ctr += 1
         ret, frame = cap.read()
 
         with torch.set_grad_enabled(True):
             # with torch.autograd.detect_anomaly():
             # resize our captured frame if we need
             frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA)
             frame = torch.tensor(frame, dtype=torch.float32, requires_grad=True).cuda()
 
             # transform patch (every couple of frames)
-            if ctr % 1 == 0:
-                move = pred > 0.5
+            if ctr % transform_interval == 0:
                 # print("{} {}".format(float(patch.min()), float(patch.max())))
                 trans_patch = patch_transformer(patch.cuda(), torch.ones([1, 14, 5]).cuda(), img_size_x, img_size_y,
-                                                do_rotate=move, rand_loc=move)
+                                                do_rotate=rotate, rand_loc=move)
                 trans_patch = torch.transpose(trans_patch[0][0].T, 0, 1)
 
                 # extract bounding box (x1, y1, x2, y2)
@@ -168,15 +184,25 @@ if __name__ == "__main__":
                 # debug_preds()
                 pass
 
-            pred = get_best_prediction(bounding_box, raw_results, 15)  # get cats
+            # pred = get_best_prediction(bounding_box, raw_results, 15)  # get cats
             # pred = get_best_prediction(bounding_box, raw_results, 42)  # get forked
+            pred = get_avg_prediction(raw_results, 15)  # make everything cats
+            pred = get_avg_prediction(raw_results, 0)  # make everything person
 
             if pred is not None:
                 print("P:{}".format(pred))
 
                 # loss
-                loss = -1 * pred  # optimize class
-                # loss = 1 * pred  # adversarial
+                # loss = -1 * pred  # optimize class
+                loss = 1 * pred  # adversarial
+
+                # Total Variation Loss
+                tv_loss = total_variation(patch)
+                loss += tv_loss
+
+            if not isinstance(loss, torch.Tensor):
+                continue
 
             loss.backward(retain_graph=True)
             # sgn_grads = torch.sign(optimizer.param_groups[0]['params'][0].grad)
@@ -185,20 +211,30 @@ if __name__ == "__main__":
             patch.data -= torch.sign(patch.grad) * 0.001
             patch.data = patch.detach().clone().clamp(MIN_THRESHOLD, 0.99999).data
 
             # show us frame with detection
            # cv2.imshow("img", results_np.render()[0])
             try:
                 cv2.imshow("img", results.render()[0])
             except Exception:
                 print("catproblem")
 
-            if cv2.waitKey(25) & 0xFF == ord("q"):
+            key = cv2.waitKey(25) & 0xFF
+            if key == ord("q"):
                 cv2.destroyAllWindows()
                 break
 
-            if cv2.waitKey(25) & 0xFF == ord("u"):
-                move = not move
-            if cv2.waitKey(25) & 0xFF == ord("o"):
-                rotate = not rotate
+            if key == ord("u"):
+                move = not move
+                print("Move: {}".format(move))
+            if key == ord("o"):
+                rotate = not rotate
+                print("Rotate: {}".format(rotate))
+            if key == ord("+"):
+                transform_interval += 1
+                print("Transform Interval: {}".format(transform_interval))
+            if key == ord("-"):
+                transform_interval -= 1
+                transform_interval = max(transform_interval, 1)
+                print("Transform Interval: {}".format(transform_interval))
 
         # calculate FPS
         fps += 1
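
The "smooth key controls" half of the commit amounts to polling cv2.waitKey once per frame and branching on the returned key code, rather than calling waitKey separately for each key, which can swallow presses and adds an extra wait per call. Below is a minimal standalone sketch of that pattern, assuming OpenCV with a GUI backend; the dummy frame and the toggle variables are illustrative, not taken from the repository:

import cv2
import numpy as np

# Poll the keyboard once per frame and dispatch on the result, so a single
# key press is never consumed by a second waitKey call in the same iteration.
move, rotate, transform_interval = False, False, 10
frame = np.zeros((300, 300, 3), dtype=np.uint8)  # placeholder frame

while True:
    cv2.imshow("img", frame)
    key = cv2.waitKey(25) & 0xFF  # single poll per iteration
    if key == ord("q"):
        cv2.destroyAllWindows()
        break
    if key == ord("u"):
        move = not move
    if key == ord("o"):
        rotate = not rotate
    if key == ord("+"):
        transform_interval += 1
    if key == ord("-"):
        transform_interval = max(transform_interval - 1, 1)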
...
utils/external.py (new file):
+import torch
+import torch.nn as nn
+
+
+# source: https://github.com/wangzh0ng/adversarial_yolo2
+class TotalVariation(nn.Module):
+    """TotalVariation: calculates the total variation of a patch.
+
+    Module providing the functionality necessary to calculate the total variation (TV) of an adversarial patch.
+    """
+
+    def __init__(self):
+        super(TotalVariation, self).__init__()
+
+    def forward(self, adv_patch):
+        # compute the total variation of the adv_patch
+        tvcomp1 = torch.sum(torch.abs(adv_patch[:, :, 1:] - adv_patch[:, :, :-1] + 0.000001), 0)
+        tvcomp1 = torch.sum(torch.sum(tvcomp1, 0), 0)
+        tvcomp2 = torch.sum(torch.abs(adv_patch[:, 1:, :] - adv_patch[:, :-1, :] + 0.000001), 0)
+        tvcomp2 = torch.sum(torch.sum(tvcomp2, 0), 0)
+        tv = tvcomp1 + tvcomp2
+        return tv / torch.numel(adv_patch)
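
A minimal usage sketch of the new module, assuming it is run from the repository root so utils.external is importable; the random patch, step size, and clamp bounds are placeholders that mirror the main loop above:

import torch

from utils.external import TotalVariation  # module added in this commit

# Penalize the total variation of a random patch and take one FGSM-style
# sign-gradient step, as the training loop above does for the real patch.
total_variation = TotalVariation()
patch = torch.rand(3, 300, 300, requires_grad=True)  # C x H x W, like PATCH_SIZE

tv_loss = total_variation(patch)  # scalar: mean absolute neighbor difference
tv_loss.backward()

with torch.no_grad():
    patch -= torch.sign(patch.grad) * 0.001  # small sign-gradient step
    patch.clamp_(0.00001, 0.99999)           # keep the patch a valid image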