Commit 3e896e33 authored by Pavlo Beylin's avatar Pavlo Beylin
Browse files

Add Total Variation loss. Implement smooth key controls.

parent bc08f03d
......@@ -8,6 +8,7 @@ from torch import optim
import models
from models.common import Detections
from utils.external import TotalVariation
from utils.general import scale_coords
matplotlib.use('TkAgg')
......@@ -23,11 +24,6 @@ model = torch.hub.load('ultralytics/yolov5', 'yolov5l') # or yolov5m, yolov5l,
# model = torch.hub.load('ultralytics/yolov3', 'yolov3')
def show(img):
plt.imshow(img.detach().cpu())
plt.show()
MIN_THRESHOLD = 0.00001
classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
......@@ -47,6 +43,11 @@ classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
PATH = "cat_patch0.jpg"
PATCH_SIZE = 300
total_variation = TotalVariation()
def show(img):
plt.imshow(img.detach().cpu())
plt.show()
def debug_preds():
detected_classes = [int(results.pred[0][i][-1]) for i in range(0, len(results.pred[0]))]
......@@ -87,6 +88,21 @@ def extract_bounding_box(patch):
return torch.stack([bb_x1, bb_y1, bb_x2, bb_y2]).sum(1)
def get_avg_prediction(res, cls_nr):
avg_prediction = 0
ctr = 0
if res is None:
return 0
for pred in res:
if pred[5:].max() > 0.4 or True:
ctr += 1
avg_prediction += pred[cls_nr + 5]
return avg_prediction / (ctr if ctr > 0 else 1)
def get_best_prediction(true_box, res, cls_nr):
min_distance = float("inf")
best_prediction = None
......@@ -96,7 +112,7 @@ def get_best_prediction(true_box, res, cls_nr):
if pred_dist < min_distance and pred[5:].max() > 0.1:
min_distance = pred_dist
best_prediction = pred[cls_nr+5]
best_prediction = pred[cls_nr + 5]
return best_prediction
......@@ -133,6 +149,7 @@ if __name__ == "__main__":
pred = -1
move = False
rotate = False
transform_interval = 10
while True:
ctr += 1
......@@ -146,11 +163,10 @@ if __name__ == "__main__":
frame = torch.tensor(frame, dtype=torch.float32, requires_grad=True).cuda()
# transform patch (every couple of frames)
if ctr % 1 == 0:
move = pred > 0.5
if ctr % transform_interval == 0:
# print("{} {}".format(float(patch.min()), float(patch.max())))
trans_patch = patch_transformer(patch.cuda(), torch.ones([1, 14, 5]).cuda(), img_size_x, img_size_y,
do_rotate=move, rand_loc=move)
do_rotate=rotate, rand_loc=move)
trans_patch = torch.transpose(trans_patch[0][0].T, 0, 1)
# extract bounding box (x1, y1, x2, y2)
......@@ -168,15 +184,25 @@ if __name__ == "__main__":
# debug_preds()
pass
pred = get_best_prediction(bounding_box, raw_results, 15) # get cats
# pred = get_best_prediction(bounding_box, raw_results, 15) # get cats
# pred = get_best_prediction(bounding_box, raw_results, 42) # get forked
pred = get_avg_prediction(raw_results, 15) # make everything cats
pred = get_avg_prediction(raw_results, 0) # make everything person
if pred is not None:
print("P:{}".format(pred))
# loss
loss = -1 * pred # optimize class
# loss = 1 * pred # adversarial
# loss = -1 * pred # optimize class
loss = 1 * pred # adversarial
# Total Variation Loss
tv_loss = total_variation(patch)
loss += tv_loss
if not isinstance(loss, torch.Tensor):
continue
loss.backward(retain_graph=True)
# sgn_grads = torch.sign(optimizer.param_groups[0]['params'][0].grad)
......@@ -185,20 +211,30 @@ if __name__ == "__main__":
patch.data -= torch.sign(patch.grad) * 0.001
patch.data = patch.detach().clone().clamp(MIN_THRESHOLD, 0.99999).data
# show us frame with detection
# cv2.imshow("img", results_np.render()[0])
try:
cv2.imshow("img", results.render()[0])
except Exception:
print("catproblem")
if cv2.waitKey(25) & 0xFF == ord("q"):
key = cv2.waitKey(25) & 0xFF
if key == ord("q"):
cv2.destroyAllWindows()
break
if cv2.waitKey(25) & 0xFF == ord("u"):
if key == ord("u"):
move = not move
if cv2.waitKey(25) & 0xFF == ord("o"):
print("Move: {}".format(move))
if key == ord("o"):
rotate = not rotate
print("Rotate: {}".format(rotate))
if key == ord("+"):
transform_interval += 1
print("Transform Interval: {}".format(transform_interval))
if key == ord("-"):
transform_interval -= 1
transform_interval = max(transform_interval, 1)
print("Transform Interval: {}".format(transform_interval))
# calculate FPS
fps += 1
......
import torch
import torch.nn as nn
# source: https://github.com/wangzh0ng/adversarial_yolo2
class TotalVariation(nn.Module):
"""TotalVariation: calculates the total variation of a patch.
Module providing the functionality necessary to calculate the total vatiation (TV) of an adversarial patch.
"""
def __init__(self):
super(TotalVariation, self).__init__()
def forward(self, adv_patch):
# bereken de total variation van de adv_patch
tvcomp1 = torch.sum(torch.abs(adv_patch[:, :, 1:] - adv_patch[:, :, :-1] + 0.000001), 0)
tvcomp1 = torch.sum(torch.sum(tvcomp1, 0), 0)
tvcomp2 = torch.sum(torch.abs(adv_patch[:, 1:, :] - adv_patch[:, :-1, :] + 0.000001), 0)
tvcomp2 = torch.sum(torch.sum(tvcomp2, 0), 0)
tv = tvcomp1 + tvcomp2
return tv / torch.numel(adv_patch)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment