Commit bc08f03d authored by Pavlo Beylin

Implement targeted patch optimization with random noise.

parent b66751f0
@@ -28,6 +28,7 @@ def show(img):
     plt.imshow(img.detach().cpu())
     plt.show()
+MIN_THRESHOLD = 0.00001
 classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
            "train", "truck", "boat", "traffic light", "fire hydrant",
@@ -76,7 +77,7 @@ def read_image(path):
 def extract_bounding_box(patch):
-    mask = torch.where(patch < 0, torch.zeros(patch.shape).cuda(), torch.ones(patch.shape).cuda()).sum(2)
+    mask = torch.where(patch < MIN_THRESHOLD, torch.zeros(patch.shape).cuda(), torch.ones(patch.shape).cuda()).sum(2)
     bb_x1 = torch.nonzero(mask.sum(0))[0]
     bb_y1 = torch.nonzero(mask.sum(1))[0]
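Since the patch itself is now clamped to stay above MIN_THRESHOLD (see the update step further down), the mask test changes from "patch < 0" to "patch < MIN_THRESHOLD": zero-valued background pixels drop out while every patch pixel survives. A minimal CPU-only sketch of the same idea; the names here are illustrative, not from the repo:

import torch

MIN_THRESHOLD = 0.00001

def bbox_from_mask(patch):
    # 1 where any channel is above the threshold, 0 elsewhere; sum(2) collapses channels
    mask = torch.where(patch < MIN_THRESHOLD,
                       torch.zeros_like(patch),
                       torch.ones_like(patch)).sum(2)
    cols = torch.nonzero(mask.sum(0)).flatten()  # columns containing patch pixels
    rows = torch.nonzero(mask.sum(1)).flatten()  # rows containing patch pixels
    return cols[0], rows[0], cols[-1], rows[-1]  # x1, y1, x2, y2

demo = torch.zeros(8, 8, 3)
demo[2:5, 3:7, :] = 0.5
print(bbox_from_mask(demo))  # (tensor(3), tensor(2), tensor(6), tensor(4))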
@@ -90,15 +91,12 @@ def get_best_prediction(true_box, res, cls_nr):
     min_distance = float("inf")
     best_prediction = None
-    for pred in res.pred[0]:
-        pred_cls_nr = int(pred[-1])
-        if pred_cls_nr != cls_nr:
-            continue
+    for pred in res:
         pred_dist = torch.dist(true_box.cuda(), pred[:4])
-        if pred_dist < min_distance:
+        if pred_dist < min_distance and pred[5:].max() > 0.1:
             min_distance = pred_dist
-            best_prediction = pred
+            best_prediction = pred[cls_nr+5]
     return best_prediction
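get_best_prediction now walks raw prediction rows instead of the post-NMS Detections: each row appears to be laid out as (x1, y1, x2, y2, objectness, per-class scores...), the candidate closest to the reference box with a best class score above 0.1 wins, and the function returns the differentiable score of the requested class rather than the whole detection. A self-contained sketch under that layout assumption:

import torch

def closest_class_score(true_box, raw_preds, cls_nr, conf_floor=0.1):
    # true_box: (x1, y1, x2, y2); raw_preds: rows of (x1, y1, x2, y2, obj, cls_0, cls_1, ...)
    best_score, min_distance = None, float("inf")
    for pred in raw_preds:
        dist = torch.dist(true_box, pred[:4])
        if dist < min_distance and pred[5:].max() > conf_floor:
            min_distance = dist
            best_score = pred[cls_nr + 5]  # differentiable score of the target class
    return best_score

raw = torch.tensor([[10., 10., 30., 30., 0.9, 0.2, 0.7],
                    [60., 60., 80., 80., 0.8, 0.6, 0.1]])
print(closest_class_score(torch.tensor([12., 11., 29., 28.]), raw, cls_nr=1))  # tensor(0.7000)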
@@ -123,6 +121,7 @@ if __name__ == "__main__":
         raise IOError("We cannot open webcam")
     patch = read_image(PATH)
+    patch = torch.rand_like(patch)
     patch.requires_grad = True
     optimizer = optim.Adam([patch], lr=0.0001, amsgrad=True)
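Starting the patch from uniform noise instead of the image read from PATH is the "random noise" the commit title refers to; the noise tensor becomes the only parameter Adam updates. A short sketch with an arbitrary template shape standing in for the loaded image:

import torch
from torch import optim

template = torch.zeros(300, 300, 3)        # stand-in for read_image(PATH)
patch = torch.rand_like(template)          # uniform random start in [0, 1)
patch.requires_grad = True                 # leaf tensor: gradients accumulate on the patch
optimizer = optim.Adam([patch], lr=0.0001, amsgrad=True)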
@@ -131,6 +130,9 @@ if __name__ == "__main__":
     img_size_y = 480
     ctr = -1
+    pred = -1
+    move = False
+    rotate = False
     while True:
         ctr += 1
@@ -145,9 +147,10 @@ if __name__ == "__main__":
         # transform patch (every couple of frames)
         if ctr % 1 == 0:
+            move = pred > 0.5
             # print("{} {}".format(float(patch.min()), float(patch.max())))
             trans_patch = patch_transformer(patch.cuda(), torch.ones([1, 14, 5]).cuda(), img_size_x, img_size_y,
-                                            do_rotate=True, rand_loc=False)
+                                            do_rotate=move, rand_loc=move)
             trans_patch = torch.transpose(trans_patch[0][0].T, 0, 1)
         # extract bounding box (x1, y1, x2, y2)
@@ -159,44 +162,52 @@ if __name__ == "__main__":
         frame = patch_applier(frame, trans_patch)
         # detect object on our frame
-        results = model.forward_pt(frame)
+        results, raw_results = model.forward_pt(frame)
         if ctr % 1 == 0:
             # debug_preds()
             pass
-        pred_box = get_best_prediction(bounding_box, results, 15) # get cats
+        pred = get_best_prediction(bounding_box, raw_results, 15) # get cats
+        # pred = get_best_prediction(bounding_box, raw_results, 42) # get forked
-        if pred_box is not None:
-            # print("P:{}".format(pred_box[-2]))
+        if pred is not None:
+            print("P:{}".format(pred))
             # loss
-            loss = -1 * pred_box[-2] # optimize class
-            # loss = 1 * pred_box[-2] # adversarial
+            loss = -1 * pred # optimize class
+            # loss = 1 * pred # adversarial
             loss.backward(retain_graph=True)
-            optimizer.step()
+            # sgn_grads = torch.sign(optimizer.param_groups[0]['params'][0].grad)
+            # optimizer.param_groups[0]['params'][0].grad = sgn_grads
+            # optimizer.step()
+            patch.data -= torch.sign(patch.grad) * 0.001
+            patch.data = patch.detach().clone().clamp(MIN_THRESHOLD, 0.99999).data
-            # TODO: general clamping (regardless of min / max values) kills all patch updates. Why? Cloning maybe?
-            if patch.max() > 1 or patch.min() < 0:
-                print("clamped {}".format(time.time()))
-                patch = torch.clamp(patch.detach().clone(), 0., 1.)
         # show us frame with detection
         # cv2.imshow("img", results_np.render()[0])
-        cv2.imshow("img", results.render()[0])
+        try:
+            cv2.imshow("img", results.render()[0])
+        except Exception:
+            print("catproblem")
         if cv2.waitKey(25) & 0xFF == ord("q"):
             cv2.destroyAllWindows()
             break
+        if cv2.waitKey(25) & 0xFF == ord("u"):
+            move = not move
+        if cv2.waitKey(25) & 0xFF == ord("o"):
+            rotate = not rotate
         # calculate FPS
         fps += 1
         TIME = time.time() - start_time
         if TIME > display_time:
-            print("FPS:", fps / TIME)
+            # print("FPS:", fps / TIME)
             fps = 0
             start_time = time.time()
-        # time.sleep(0.3)
+        # time.sleep(0.2)
     cap.release()
     cv2.destroyAllWindows()
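The optimization step no longer goes through optimizer.step(): after backpropagating the negated class score, the patch moves by a fixed-size signed-gradient step (FGSM-style) and is clamped into [MIN_THRESHOLD, 0.99999], which keeps every pixel strictly above the zero background that extract_bounding_box keys on. A standalone sketch of one such step, written with torch.no_grad() rather than .data and with a toy differentiable score standing in for the detector:

import torch

MIN_THRESHOLD = 0.00001
patch = torch.rand(300, 300, 3, requires_grad=True)

def target_class_score(p):
    return p.mean()                     # toy stand-in for get_best_prediction(...)

loss = -1 * target_class_score(patch)   # maximize the target class score
loss.backward()

with torch.no_grad():
    patch -= torch.sign(patch.grad) * 0.001   # fixed-size signed step
    patch.clamp_(MIN_THRESHOLD, 0.99999)      # keep pixels strictly positive and below 1
    patch.grad.zero_()                        # reset before the next frame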
@@ -328,13 +328,13 @@ class AutoShape(nn.Module):
         t.append(time_sync())
         # Post-process
-        y_sup = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS
+        y_sup, y_raw = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS
         # for i in range(n):
         # scale_coords(shape1, y[i][:, :4], shape0[i])
         t.append(time_sync())
         # return Detections(imgs, y, files, t, self.names, x.shape)
-        return Detections([img], y_sup, None, t, self.names, x.shape)
+        return Detections([img], y_sup, None, t, self.names, x.shape), y_raw
     @torch.no_grad()
     def forward(self, imgs, size=640, augment=False, profile=False):
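forward_pt now hands back a tuple: the Detections object for rendering plus the raw rows (y_raw), which still carry gradients, so the attack can build its loss on a class score directly. A small sketch of this dual-return pattern with the detector internals replaced by a toy confidence filter:

import torch

def detect(raw):
    kept = raw[raw[:, 4] > 0.25]   # stand-in for confidence filtering / NMS
    return kept, raw               # (rows for display, raw rows for the loss)

raw = torch.rand(10, 7, requires_grad=True)  # (x1, y1, x2, y2, obj, cls_0, cls_1)
kept, raw_rows = detect(raw)
loss = -raw_rows[:, 6].max()       # e.g. push one class score up
loss.backward()
print(kept.shape, raw.grad.shape)  # gradients reach the original tensor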
@@ -123,7 +123,7 @@ class PatchTransformer(nn.Module):
         noise = torch.cuda.FloatTensor(adv_batch.size()).uniform_(-1, 1) * self.noise_factor
         # Apply contrast/brightness/noise, clamp
-        adv_batch = adv_batch * contrast + brightness + noise
+        adv_batch = adv_batch # * contrast + brightness + noise
         adv_batch = torch.clamp(adv_batch, 0.000001, 0.99999)
@@ -541,6 +541,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None
     t = time.time()
     output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
+    mad_output = None
     for xi, x in enumerate(prediction): # image index, image inference
         # Apply constraints
         # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
@@ -567,6 +568,9 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None
         # Box (center x, center y, width, height) to (x1, y1, x2, y2)
         box = xywh2xyxy(x[:, :4])
+        mad_output = x.clone()
+        mad_output[:, :4] = box
+
         # Detections matrix nx6 (xyxy, conf, cls)
         if multi_label:
             i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
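mad_output is a clone of the raw candidate rows whose first four columns have already been converted from (center x, center y, width, height) to corner coordinates, so callers can compare them against bounding boxes directly. A self-contained re-implementation of that conversion (same arithmetic as xywh2xyxy, not the repo's helper):

import torch

def xywh_to_xyxy(x):
    y = x.clone()
    y[:, 0] = x[:, 0] - x[:, 2] / 2   # x1 = cx - w/2
    y[:, 1] = x[:, 1] - x[:, 3] / 2   # y1 = cy - h/2
    y[:, 2] = x[:, 0] + x[:, 2] / 2   # x2 = cx + w/2
    y[:, 3] = x[:, 1] + x[:, 3] / 2   # y2 = cy + h/2
    return y

rows = torch.tensor([[50., 40., 20., 10., 0.9, 0.1, 0.8]])   # cx, cy, w, h, obj, cls scores
raw = rows.clone()
raw[:, :4] = xywh_to_xyxy(rows[:, :4])
print(raw)   # [[40., 35., 60., 45., 0.9, 0.1, 0.8]]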
@@ -609,7 +613,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None
             print(f'WARNING: NMS time limit {time_limit}s exceeded')
             break # time limit exceeded
-    return output
+    return output, mad_output
 def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()