Commit 20ac0a35 authored by Pavlo Beylin
Browse files

CSM calculation for non-max-suppressed predictions for the 10 best classes.

parent d0ced529
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from main import coco_class_names
def calc_yolo_csms(imgs_and_preds: torch.Tensor,
sign: bool = False,
rescale: bool = True) -> torch.Tensor:
def calc_yolo_person_csms(imgs_and_logits: torch.Tensor,
sign: bool = True,
rescale_factor = 0,
loss_rescale_factor = 1) -> torch.Tensor:
'''
computes the cosine similarity map for given input images X
computes the YOLO cosine similarity map for given input images X
Parameters
---------
model: torch model
imgs_and_preds: torch tensor; shape: (Batch_Size, Channels, Width, Height)
imgs_and_logits: torch tensor; shape: (Batch_Size, Channels, Width, Height)
sign: use sign of gradients to calculate cosine similarity maps
rescale: rescale the logits before applying softmax -> solves gradient obfuscation problem of large logits
Returns
---------
......@@ -24,29 +28,42 @@ def calc_yolo_csms(imgs_and_preds: torch.Tensor,
imgs = []
clss = []
for tup in imgs_and_preds:
img, pred, frame, x1, y1, x2, y2 = tup
for tup in imgs_and_logits:
img, logit, frame, x1, y1, x2, y2 = tup
if not img.requires_grad:
img.requires_grad_()
logit = pred[5:]
cls = torch.argmax(pred[5:])
cls = torch.argmax(logit)
# rescale network output to avoid gradient obfuscation
if rescale:
logit = logit / torch.max(torch.abs(logit)) * 10
if rescale_factor > 0:
logit = rescale_factor * logit / torch.max(torch.abs(logit))
# calculate all classes
classes = len(logit)
# only first ten classes
classes = 10
# get top10 classes and person
classes = list(torch.sort(logit, descending=True)[1][:10])
class_names = [coco_class_names[c] for c in classes]
if classes[0] != 0: # do not consider predictions if person is not the top prediction
continue
print(f'Top Ten: {class_names}')
deltas = []
for c in range(classes):
for c in classes:
# calculate loss and compute gradient w.r.t. the input of the current class
y = torch.ones(1, device="cuda", dtype=torch.long) * c
loss = F.cross_entropy(logit.unsqueeze(0), y)
frame_grad = torch.autograd.grad(loss, frame, retain_graph=True)[0][:, 5:]
loss = F.cross_entropy(logit.unsqueeze(0), y) * loss_rescale_factor
frame_grad = torch.autograd.grad(loss, frame, retain_graph=True)[0]
img_grad = frame_grad[int(y1):int(y2), int(x1):int(x2), :]
try:
if torch.min(img_grad) == 0:
print(f"img grad contains zero {c}")
except Exception as e:
print(e)
# take sign of gradient as in the original paper
if sign:
img_grad = torch.sign(img_grad)
......@@ -56,23 +73,22 @@ def calc_yolo_csms(imgs_and_preds: torch.Tensor,
deltas = torch.stack(deltas)
# compute cosine similarity matrices
try:
deltas = torch.max(deltas, dim=-3).values # take only the maximum value of all channels to compute the
deltas = deltas.view(classes, 1, -1)
norm = torch.norm(deltas, p=2, dim=2, keepdim=True)
deltas = torch.max(deltas, dim=-1).values # take only the maximum value of all channels to compute the
deltas = deltas.view(len(classes), 1, -1)
norm = torch.norm(deltas, p=2, dim=2, keepdim=True)
# norm[norm == 0] = 1
if torch.min(norm) != 0:
deltas = deltas / norm
deltas = deltas.transpose(0, 1)
csm = torch.matmul(deltas, deltas.transpose(1, 2))
except Exception as e:
print("error")
# raise e
else:
print("norm min contains 0")
deltas = deltas.transpose(0, 1)
csm = torch.matmul(deltas, deltas.transpose(1, 2))
# division by zero can lead to NaNs
if torch.isnan(csm).any():
# raise Exception("NaNs in CSM!")
print("NaNs in csm")
raise Exception("NaNs in CSM!")
# print("NaNs in csm")
else:
print(f'{deltas.mean()}')
imgs.append(img)
csms.append(csm)
clss.append(cls)
......
......@@ -29,7 +29,7 @@ model = torch.hub.load('ultralytics/yolov5', 'yolov5l') # or yolov5m, yolov5l,
MIN_THRESHOLD = 0.00001
classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
coco_class_names = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
"train", "truck", "boat", "traffic light", "fire hydrant",
"stop sign", "parking meter", "bench", "bird", "cat", "dog",
"horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
......@@ -40,8 +40,8 @@ classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
"banana", "apple", "sandwich", "orange", "broccoli", "carrot",
"hot dog", "pizza", "donut", "cake", "chair", "sofa", "pottedplant",
"bed", "diningtable", "toilet", "tvmonitor", "laptop", "mouse", "remote",
"keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator",
"book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"]
"keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator",
"book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"]
PATH = "saved_patches/realcat.jpg"
PATH = "saved_patches/fatcat.jpg"
......@@ -63,9 +63,9 @@ def show(imgs):
def debug_preds():
detected_classes = [int(results.pred[0][i][-1]) for i in range(0, len(results.pred[0]))]
detected_classes = [int(detections.pred[0][i][-1]) for i in range(0, len(detections.pred[0]))]
# print(detected_classes)
for det in results.pred[0]:
for det in detections.pred[0]:
if int(det[-1]) == 15: # cat
print("Pred BB: ", end="")
# print("x1:y1 : {}:{}".format(float(det[0]), float(det[1])))
......@@ -165,23 +165,17 @@ def get_best_prediction(true_box, res, cls_nr):
def calculate_csms(frame, predictions):
imgs_and_preds = []
for pred in predictions:
x1, y1, x2, y2, conf = pred[:5].float()
imgs_and_logits = []
for i in range(len(predictions.pred[0])):
x1, y1, x2, y2, conf = predictions.pred[0][i][:5].float()
pred_img_section = frame.flip(2)[int(y1):int(y2), int(x1):int(x2), :]
tup = (pred_img_section, pred, frame, x1, y1, x2, y2)
# print(tup)
imgs_and_preds.append(tup)
tup = (pred_img_section, predictions.logits[i], frame, x1, y1, x2, y2)
imgs_and_logits.append(tup)
# if conf > 0.8:
# cls = classes[int(pred[5:].argmax())]
# print(f"{cls}: {conf} - {pred[:5].float()}")
# show(frame.flip(2)[int(y1):int(y2), int(x1):int(x2), :] / 255.)
# print("done")
imgs, csms, cls = CSM.calc_yolo_csms(imgs_and_preds)
# TODO insert non_max_suppression
imgs, csms, cls = CSM.calc_yolo_person_csms(imgs_and_logits, rescale_factor=0, loss_rescale_factor=1000)
return imgs, csms, cls
......@@ -258,11 +252,11 @@ if __name__ == "__main__":
if not (fix_frame and frame_read):
# resize our captured frame if we need
frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA)
frame_original = torch.tensor(frame, dtype=torch.float32, requires_grad=True).cuda()
frame_original = torch.tensor(frame, dtype=torch.float32, requires_grad=True, device="cuda")
frame = frame_original.clone()
frame_read = True
results = None
detections = None
for _ in range(transform_interval):
ctr += 1
......@@ -283,19 +277,20 @@ if __name__ == "__main__":
frame = patch_applier(frame_original, trans_patch)
# detect object on our frame
if ctr % 1 == 0 or results is None:
results, raw_results = model.forward_pt(frame)
if ctr % 1 == 0 or detections is None:
detections, raw_results = model.forward_pt(frame)
if ctr % 1 == 0:
# debug_preds()
pass
# calculate Cosine Similarity Matrix
imgs, csms, clss = calculate_csms(frame, raw_results)
# imgs, csms, clss = calculate_csms(frame, raw_results)
imgs, csms, clss = calculate_csms(frame, detections)
for i in range(len(csms)):
# show only person predictions
if clss[i] == 0:
show([imgs[i]/255, csms[i].T])
show([torch.min(torch.ones_like(imgs[i]), imgs[i]/255), csms[i].T])
# iou, pred = get_best_prediction(bounding_box, raw_results, 15) # get cat
iou, pred = get_best_prediction(bounding_box, raw_results, 0) # get personal
......@@ -336,16 +331,16 @@ if __name__ == "__main__":
# sgn_grads = torch.sign(optimizer.param_groups[0]['params'][0].grad)
# optimizer.param_groups[0]['params'][0].grad = sgn_grads
# optimizer.step()
patch.data -= torch.sign(gradient_sum) * 0.001 # * 0 # TODO reactivate
patch.data -= torch.sign(gradient_sum) * 0.001
patch.data = patch.detach().clone().clamp(MIN_THRESHOLD, 0.99999).data
gradient_sum = 0
# show us frame with detection
# cv2.imshow("img", results_np.render()[0])
try:
cv2.imshow("img", results.render()[0])
except Exception:
print("catproblem")
cv2.imshow("img", detections.render()[0])
except Exception as e:
print(f"catproblem {e}")
key = cv2.waitKey(25) & 0xFF
if key == ord("q"):
......
......@@ -328,13 +328,15 @@ class AutoShape(nn.Module):
t.append(time_sync())
# Post-process
y_sup, y_raw = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS
y_sup, y_raw, logits = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS
# for i in range(n):
# scale_coords(shape1, y[i][:, :4], shape0[i])
t.append(time_sync())
# return Detections(imgs, y, files, t, self.names, x.shape)
return Detections([img], y_sup, None, t, self.names, x.shape), y_raw
detections = Detections([img], y_sup, None, t, self.names, x.shape)
detections.logits = logits
return detections, y_raw
@torch.no_grad()
def forward(self, imgs, size=640, augment=False, profile=False):
......
......@@ -542,6 +542,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
t = time.time()
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
mad_output = None
logits = []
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
......@@ -577,6 +578,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
else: # best class only
conf, j = x[:, 5:].max(1, keepdim=True)
logits.append(x[:, 5:][conf.view(-1) > conf_thres])
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
# Filter by class
......@@ -609,11 +611,12 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i]
logits[xi] = [logits[0][x] for x in i]
if (time.time() - t) > time_limit:
print(f'WARNING: NMS time limit {time_limit}s exceeded')
break # time limit exceeded
return output, mad_output
return output, mad_output, logits[0]
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
......
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment