Commit 20ac0a35 authored by Pavlo Beylin's avatar Pavlo Beylin
Browse files

CSM calculation for non-max-suppressed predictions for the 10 best classes.

parent d0ced529
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import matplotlib.pyplot as plt
from main import coco_class_names
def calc_yolo_csms(imgs_and_preds: torch.Tensor,
sign: bool = False, def calc_yolo_person_csms(imgs_and_logits: torch.Tensor,
rescale: bool = True) -> torch.Tensor: sign: bool = True,
rescale_factor = 0,
loss_rescale_factor = 1) -> torch.Tensor:
''' '''
computes the cosine similarity map for given input images X computes the YOLO cosine similarity map for given input images X
Parameters Parameters
--------- ---------
model: torch model model: torch model
imgs_and_preds: torch tensor; shape: (Batch_Size, Channels, Width, Height) imgs_and_logits: torch tensor; shape: (Batch_Size, Channels, Width, Height)
sign: use sign of gradients to calculate cosine similarity maps sign: use sign of gradients to calculate cosine similarity maps
rescale: rescale the logits before applying softmax -> solves gradient obfuscation problem of large logits
Returns Returns
--------- ---------
...@@ -24,29 +28,42 @@ def calc_yolo_csms(imgs_and_preds: torch.Tensor, ...@@ -24,29 +28,42 @@ def calc_yolo_csms(imgs_and_preds: torch.Tensor,
imgs = [] imgs = []
clss = [] clss = []
for tup in imgs_and_preds: for tup in imgs_and_logits:
img, pred, frame, x1, y1, x2, y2 = tup img, logit, frame, x1, y1, x2, y2 = tup
if not img.requires_grad: if not img.requires_grad:
img.requires_grad_() img.requires_grad_()
logit = pred[5:] cls = torch.argmax(logit)
cls = torch.argmax(pred[5:])
# rescale network output to avoid gradient obfuscation # rescale network output to avoid gradient obfuscation
if rescale: if rescale_factor > 0:
logit = logit / torch.max(torch.abs(logit)) * 10 logit = rescale_factor * logit / torch.max(torch.abs(logit))
# calculate all classes
classes = len(logit) classes = len(logit)
# only first ten classes
classes = 10 # get top10 classes and person
classes = list(torch.sort(logit, descending=True)[1][:10])
class_names = [coco_class_names[c] for c in classes]
if classes[0] != 0: # do not consider predictions if person is not the top prediction
continue
print(f'Top Ten: {class_names}')
deltas = [] deltas = []
for c in range(classes): for c in classes:
# calculate loss and compute gradient w.r.t. the input of the current class # calculate loss and compute gradient w.r.t. the input of the current class
y = torch.ones(1, device="cuda", dtype=torch.long) * c y = torch.ones(1, device="cuda", dtype=torch.long) * c
loss = F.cross_entropy(logit.unsqueeze(0), y) loss = F.cross_entropy(logit.unsqueeze(0), y) * loss_rescale_factor
frame_grad = torch.autograd.grad(loss, frame, retain_graph=True)[0][:, 5:] frame_grad = torch.autograd.grad(loss, frame, retain_graph=True)[0]
img_grad = frame_grad[int(y1):int(y2), int(x1):int(x2), :] img_grad = frame_grad[int(y1):int(y2), int(x1):int(x2), :]
try:
if torch.min(img_grad) == 0:
print(f"img grad contains zero {c}")
except Exception as e:
print(e)
# take sign of gradient as in the original paper # take sign of gradient as in the original paper
if sign: if sign:
img_grad = torch.sign(img_grad) img_grad = torch.sign(img_grad)
...@@ -56,23 +73,22 @@ def calc_yolo_csms(imgs_and_preds: torch.Tensor, ...@@ -56,23 +73,22 @@ def calc_yolo_csms(imgs_and_preds: torch.Tensor,
deltas = torch.stack(deltas) deltas = torch.stack(deltas)
# compute cosine similarity matrices # compute cosine similarity matrices
try: deltas = torch.max(deltas, dim=-1).values # take only the maximum value of all channels to compute the
deltas = torch.max(deltas, dim=-3).values # take only the maximum value of all channels to compute the deltas = deltas.view(len(classes), 1, -1)
deltas = deltas.view(classes, 1, -1)
norm = torch.norm(deltas, p=2, dim=2, keepdim=True) norm = torch.norm(deltas, p=2, dim=2, keepdim=True)
# norm[norm == 0] = 1
if torch.min(norm) != 0:
deltas = deltas / norm deltas = deltas / norm
else:
print("norm min contains 0")
deltas = deltas.transpose(0, 1) deltas = deltas.transpose(0, 1)
csm = torch.matmul(deltas, deltas.transpose(1, 2)) csm = torch.matmul(deltas, deltas.transpose(1, 2))
except Exception as e:
print("error")
# raise e
# division by zero can lead to NaNs # division by zero can lead to NaNs
if torch.isnan(csm).any(): if torch.isnan(csm).any():
# raise Exception("NaNs in CSM!") raise Exception("NaNs in CSM!")
print("NaNs in csm") # print("NaNs in csm")
else: else:
print(f'{deltas.mean()}')
imgs.append(img) imgs.append(img)
csms.append(csm) csms.append(csm)
clss.append(cls) clss.append(cls)
......
...@@ -29,7 +29,7 @@ model = torch.hub.load('ultralytics/yolov5', 'yolov5l') # or yolov5m, yolov5l, ...@@ -29,7 +29,7 @@ model = torch.hub.load('ultralytics/yolov5', 'yolov5l') # or yolov5m, yolov5l,
MIN_THRESHOLD = 0.00001 MIN_THRESHOLD = 0.00001
classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus", coco_class_names = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
"train", "truck", "boat", "traffic light", "fire hydrant", "train", "truck", "boat", "traffic light", "fire hydrant",
"stop sign", "parking meter", "bench", "bird", "cat", "dog", "stop sign", "parking meter", "bench", "bird", "cat", "dog",
"horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
...@@ -63,9 +63,9 @@ def show(imgs): ...@@ -63,9 +63,9 @@ def show(imgs):
def debug_preds(): def debug_preds():
detected_classes = [int(results.pred[0][i][-1]) for i in range(0, len(results.pred[0]))] detected_classes = [int(detections.pred[0][i][-1]) for i in range(0, len(detections.pred[0]))]
# print(detected_classes) # print(detected_classes)
for det in results.pred[0]: for det in detections.pred[0]:
if int(det[-1]) == 15: # cat if int(det[-1]) == 15: # cat
print("Pred BB: ", end="") print("Pred BB: ", end="")
# print("x1:y1 : {}:{}".format(float(det[0]), float(det[1]))) # print("x1:y1 : {}:{}".format(float(det[0]), float(det[1])))
...@@ -165,23 +165,17 @@ def get_best_prediction(true_box, res, cls_nr): ...@@ -165,23 +165,17 @@ def get_best_prediction(true_box, res, cls_nr):
def calculate_csms(frame, predictions): def calculate_csms(frame, predictions):
imgs_and_preds = [] imgs_and_logits = []
for pred in predictions:
x1, y1, x2, y2, conf = pred[:5].float()
for i in range(len(predictions.pred[0])):
x1, y1, x2, y2, conf = predictions.pred[0][i][:5].float()
pred_img_section = frame.flip(2)[int(y1):int(y2), int(x1):int(x2), :] pred_img_section = frame.flip(2)[int(y1):int(y2), int(x1):int(x2), :]
tup = (pred_img_section, pred, frame, x1, y1, x2, y2) tup = (pred_img_section, predictions.logits[i], frame, x1, y1, x2, y2)
# print(tup) imgs_and_logits.append(tup)
imgs_and_preds.append(tup)
# if conf > 0.8:
# cls = classes[int(pred[5:].argmax())]
# print(f"{cls}: {conf} - {pred[:5].float()}")
# show(frame.flip(2)[int(y1):int(y2), int(x1):int(x2), :] / 255.)
# print("done")
imgs, csms, cls = CSM.calc_yolo_csms(imgs_and_preds) # TODO insert non_max_suppression
imgs, csms, cls = CSM.calc_yolo_person_csms(imgs_and_logits, rescale_factor=0, loss_rescale_factor=1000)
return imgs, csms, cls return imgs, csms, cls
...@@ -258,11 +252,11 @@ if __name__ == "__main__": ...@@ -258,11 +252,11 @@ if __name__ == "__main__":
if not (fix_frame and frame_read): if not (fix_frame and frame_read):
# resize our captured frame if we need # resize our captured frame if we need
frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA) frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA)
frame_original = torch.tensor(frame, dtype=torch.float32, requires_grad=True).cuda() frame_original = torch.tensor(frame, dtype=torch.float32, requires_grad=True, device="cuda")
frame = frame_original.clone() frame = frame_original.clone()
frame_read = True frame_read = True
results = None detections = None
for _ in range(transform_interval): for _ in range(transform_interval):
ctr += 1 ctr += 1
...@@ -283,19 +277,20 @@ if __name__ == "__main__": ...@@ -283,19 +277,20 @@ if __name__ == "__main__":
frame = patch_applier(frame_original, trans_patch) frame = patch_applier(frame_original, trans_patch)
# detect object on our frame # detect object on our frame
if ctr % 1 == 0 or results is None: if ctr % 1 == 0 or detections is None:
results, raw_results = model.forward_pt(frame) detections, raw_results = model.forward_pt(frame)
if ctr % 1 == 0: if ctr % 1 == 0:
# debug_preds() # debug_preds()
pass pass
# calculate Cosine Similarity Matrix # calculate Cosine Similarity Matrix
imgs, csms, clss = calculate_csms(frame, raw_results) # imgs, csms, clss = calculate_csms(frame, raw_results)
imgs, csms, clss = calculate_csms(frame, detections)
for i in range(len(csms)): for i in range(len(csms)):
# show only person predictions # show only person predictions
if clss[i] == 0: if clss[i] == 0:
show([imgs[i]/255, csms[i].T]) show([torch.min(torch.ones_like(imgs[i]), imgs[i]/255), csms[i].T])
# iou, pred = get_best_prediction(bounding_box, raw_results, 15) # get cat # iou, pred = get_best_prediction(bounding_box, raw_results, 15) # get cat
iou, pred = get_best_prediction(bounding_box, raw_results, 0) # get personal iou, pred = get_best_prediction(bounding_box, raw_results, 0) # get personal
...@@ -336,16 +331,16 @@ if __name__ == "__main__": ...@@ -336,16 +331,16 @@ if __name__ == "__main__":
# sgn_grads = torch.sign(optimizer.param_groups[0]['params'][0].grad) # sgn_grads = torch.sign(optimizer.param_groups[0]['params'][0].grad)
# optimizer.param_groups[0]['params'][0].grad = sgn_grads # optimizer.param_groups[0]['params'][0].grad = sgn_grads
# optimizer.step() # optimizer.step()
patch.data -= torch.sign(gradient_sum) * 0.001 # * 0 # TODO reactivate patch.data -= torch.sign(gradient_sum) * 0.001
patch.data = patch.detach().clone().clamp(MIN_THRESHOLD, 0.99999).data patch.data = patch.detach().clone().clamp(MIN_THRESHOLD, 0.99999).data
gradient_sum = 0 gradient_sum = 0
# show us frame with detection # show us frame with detection
# cv2.imshow("img", results_np.render()[0]) # cv2.imshow("img", results_np.render()[0])
try: try:
cv2.imshow("img", results.render()[0]) cv2.imshow("img", detections.render()[0])
except Exception: except Exception as e:
print("catproblem") print(f"catproblem {e}")
key = cv2.waitKey(25) & 0xFF key = cv2.waitKey(25) & 0xFF
if key == ord("q"): if key == ord("q"):
......
...@@ -328,13 +328,15 @@ class AutoShape(nn.Module): ...@@ -328,13 +328,15 @@ class AutoShape(nn.Module):
t.append(time_sync()) t.append(time_sync())
# Post-process # Post-process
y_sup, y_raw = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS y_sup, y_raw, logits = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS
# for i in range(n): # for i in range(n):
# scale_coords(shape1, y[i][:, :4], shape0[i]) # scale_coords(shape1, y[i][:, :4], shape0[i])
t.append(time_sync()) t.append(time_sync())
# return Detections(imgs, y, files, t, self.names, x.shape) # return Detections(imgs, y, files, t, self.names, x.shape)
return Detections([img], y_sup, None, t, self.names, x.shape), y_raw detections = Detections([img], y_sup, None, t, self.names, x.shape)
detections.logits = logits
return detections, y_raw
@torch.no_grad() @torch.no_grad()
def forward(self, imgs, size=640, augment=False, profile=False): def forward(self, imgs, size=640, augment=False, profile=False):
......
...@@ -542,6 +542,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non ...@@ -542,6 +542,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
t = time.time() t = time.time()
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
mad_output = None mad_output = None
logits = []
for xi, x in enumerate(prediction): # image index, image inference for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints # Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
...@@ -577,6 +578,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non ...@@ -577,6 +578,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
else: # best class only else: # best class only
conf, j = x[:, 5:].max(1, keepdim=True) conf, j = x[:, 5:].max(1, keepdim=True)
logits.append(x[:, 5:][conf.view(-1) > conf_thres])
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
# Filter by class # Filter by class
...@@ -609,11 +611,12 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non ...@@ -609,11 +611,12 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
i = i[iou.sum(1) > 1] # require redundancy i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i] output[xi] = x[i]
logits[xi] = [logits[0][x] for x in i]
if (time.time() - t) > time_limit: if (time.time() - t) > time_limit:
print(f'WARNING: NMS time limit {time_limit}s exceeded') print(f'WARNING: NMS time limit {time_limit}s exceeded')
break # time limit exceeded break # time limit exceeded
return output, mad_output return output, mad_output, logits[0]
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer() def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment