Commit 20ac0a35 authored by Pavlo Beylin's avatar Pavlo Beylin
Browse files

CSM calculation for non-max-suppressed predictions for the 10 best classes.

parent d0ced529
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from main import coco_class_names
def calc_yolo_csms(imgs_and_preds: torch.Tensor,
sign: bool = False,
rescale: bool = True) -> torch.Tensor:
def calc_yolo_person_csms(imgs_and_logits: torch.Tensor,
sign: bool = True,
rescale_factor = 0,
loss_rescale_factor = 1) -> torch.Tensor:
'''
computes the cosine similarity map for given input images X
computes the YOLO cosine similarity map for given input images X
Parameters
---------
model: torch model
imgs_and_preds: torch tensor; shape: (Batch_Size, Channels, Width, Height)
imgs_and_logits: torch tensor; shape: (Batch_Size, Channels, Width, Height)
sign: use sign of gradients to calculate cosine similarity maps
rescale: rescale the logits before applying softmax -> solves gradient obfuscation problem of large logits
Returns
---------
......@@ -24,29 +28,42 @@ def calc_yolo_csms(imgs_and_preds: torch.Tensor,
imgs = []
clss = []
for tup in imgs_and_preds:
img, pred, frame, x1, y1, x2, y2 = tup
for tup in imgs_and_logits:
img, logit, frame, x1, y1, x2, y2 = tup
if not img.requires_grad:
img.requires_grad_()
logit = pred[5:]
cls = torch.argmax(pred[5:])
cls = torch.argmax(logit)
# rescale network output to avoid gradient obfuscation
if rescale:
logit = logit / torch.max(torch.abs(logit)) * 10
if rescale_factor > 0:
logit = rescale_factor * logit / torch.max(torch.abs(logit))
# calculate all classes
classes = len(logit)
# only first ten classes
classes = 10
# get top10 classes and person
classes = list(torch.sort(logit, descending=True)[1][:10])
class_names = [coco_class_names[c] for c in classes]
if classes[0] != 0: # do not consider predictions if person is not the top prediction
continue
print(f'Top Ten: {class_names}')
deltas = []
for c in range(classes):
for c in classes:
# calculate loss and compute gradient w.r.t. the input of the current class
y = torch.ones(1, device="cuda", dtype=torch.long) * c
loss = F.cross_entropy(logit.unsqueeze(0), y)
frame_grad = torch.autograd.grad(loss, frame, retain_graph=True)[0][:, 5:]
loss = F.cross_entropy(logit.unsqueeze(0), y) * loss_rescale_factor
frame_grad = torch.autograd.grad(loss, frame, retain_graph=True)[0]
img_grad = frame_grad[int(y1):int(y2), int(x1):int(x2), :]
try:
if torch.min(img_grad) == 0:
print(f"img grad contains zero {c}")
except Exception as e:
print(e)
# take sign of gradient as in the original paper
if sign:
img_grad = torch.sign(img_grad)
......@@ -56,23 +73,22 @@ def calc_yolo_csms(imgs_and_preds: torch.Tensor,
deltas = torch.stack(deltas)
# compute cosine similarity matrices
try:
deltas = torch.max(deltas, dim=-3).values # take only the maximum value of all channels to compute the
deltas = deltas.view(classes, 1, -1)
deltas = torch.max(deltas, dim=-1).values # take only the maximum value of all channels to compute the
deltas = deltas.view(len(classes), 1, -1)
norm = torch.norm(deltas, p=2, dim=2, keepdim=True)
# norm[norm == 0] = 1
if torch.min(norm) != 0:
deltas = deltas / norm
else:
print("norm min contains 0")
deltas = deltas.transpose(0, 1)
csm = torch.matmul(deltas, deltas.transpose(1, 2))
except Exception as e:
print("error")
# raise e
# division by zero can lead to NaNs
if torch.isnan(csm).any():
# raise Exception("NaNs in CSM!")
print("NaNs in csm")
raise Exception("NaNs in CSM!")
# print("NaNs in csm")
else:
print(f'{deltas.mean()}')
imgs.append(img)
csms.append(csm)
clss.append(cls)
......
......@@ -29,7 +29,7 @@ model = torch.hub.load('ultralytics/yolov5', 'yolov5l') # or yolov5m, yolov5l,
MIN_THRESHOLD = 0.00001
classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
coco_class_names = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
"train", "truck", "boat", "traffic light", "fire hydrant",
"stop sign", "parking meter", "bench", "bird", "cat", "dog",
"horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
......@@ -63,9 +63,9 @@ def show(imgs):
def debug_preds():
detected_classes = [int(results.pred[0][i][-1]) for i in range(0, len(results.pred[0]))]
detected_classes = [int(detections.pred[0][i][-1]) for i in range(0, len(detections.pred[0]))]
# print(detected_classes)
for det in results.pred[0]:
for det in detections.pred[0]:
if int(det[-1]) == 15: # cat
print("Pred BB: ", end="")
# print("x1:y1 : {}:{}".format(float(det[0]), float(det[1])))
......@@ -165,23 +165,17 @@ def get_best_prediction(true_box, res, cls_nr):
def calculate_csms(frame, predictions):
imgs_and_preds = []
for pred in predictions:
x1, y1, x2, y2, conf = pred[:5].float()
imgs_and_logits = []
for i in range(len(predictions.pred[0])):
x1, y1, x2, y2, conf = predictions.pred[0][i][:5].float()
pred_img_section = frame.flip(2)[int(y1):int(y2), int(x1):int(x2), :]
tup = (pred_img_section, pred, frame, x1, y1, x2, y2)
# print(tup)
imgs_and_preds.append(tup)
tup = (pred_img_section, predictions.logits[i], frame, x1, y1, x2, y2)
imgs_and_logits.append(tup)
# if conf > 0.8:
# cls = classes[int(pred[5:].argmax())]
# print(f"{cls}: {conf} - {pred[:5].float()}")
# show(frame.flip(2)[int(y1):int(y2), int(x1):int(x2), :] / 255.)
# print("done")
imgs, csms, cls = CSM.calc_yolo_csms(imgs_and_preds)
# TODO insert non_max_suppression
imgs, csms, cls = CSM.calc_yolo_person_csms(imgs_and_logits, rescale_factor=0, loss_rescale_factor=1000)
return imgs, csms, cls
......@@ -258,11 +252,11 @@ if __name__ == "__main__":
if not (fix_frame and frame_read):
# resize our captured frame if we need
frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA)
frame_original = torch.tensor(frame, dtype=torch.float32, requires_grad=True).cuda()
frame_original = torch.tensor(frame, dtype=torch.float32, requires_grad=True, device="cuda")
frame = frame_original.clone()
frame_read = True
results = None
detections = None
for _ in range(transform_interval):
ctr += 1
......@@ -283,19 +277,20 @@ if __name__ == "__main__":
frame = patch_applier(frame_original, trans_patch)
# detect object on our frame
if ctr % 1 == 0 or results is None:
results, raw_results = model.forward_pt(frame)
if ctr % 1 == 0 or detections is None:
detections, raw_results = model.forward_pt(frame)
if ctr % 1 == 0:
# debug_preds()
pass
# calculate Cosine Similarity Matrix
imgs, csms, clss = calculate_csms(frame, raw_results)
# imgs, csms, clss = calculate_csms(frame, raw_results)
imgs, csms, clss = calculate_csms(frame, detections)
for i in range(len(csms)):
# show only person predictions
if clss[i] == 0:
show([imgs[i]/255, csms[i].T])
show([torch.min(torch.ones_like(imgs[i]), imgs[i]/255), csms[i].T])
# iou, pred = get_best_prediction(bounding_box, raw_results, 15) # get cat
iou, pred = get_best_prediction(bounding_box, raw_results, 0) # get personal
......@@ -336,16 +331,16 @@ if __name__ == "__main__":
# sgn_grads = torch.sign(optimizer.param_groups[0]['params'][0].grad)
# optimizer.param_groups[0]['params'][0].grad = sgn_grads
# optimizer.step()
patch.data -= torch.sign(gradient_sum) * 0.001 # * 0 # TODO reactivate
patch.data -= torch.sign(gradient_sum) * 0.001
patch.data = patch.detach().clone().clamp(MIN_THRESHOLD, 0.99999).data
gradient_sum = 0
# show us frame with detection
# cv2.imshow("img", results_np.render()[0])
try:
cv2.imshow("img", results.render()[0])
except Exception:
print("catproblem")
cv2.imshow("img", detections.render()[0])
except Exception as e:
print(f"catproblem {e}")
key = cv2.waitKey(25) & 0xFF
if key == ord("q"):
......
......@@ -328,13 +328,15 @@ class AutoShape(nn.Module):
t.append(time_sync())
# Post-process
y_sup, y_raw = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS
y_sup, y_raw, logits = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS
# for i in range(n):
# scale_coords(shape1, y[i][:, :4], shape0[i])
t.append(time_sync())
# return Detections(imgs, y, files, t, self.names, x.shape)
return Detections([img], y_sup, None, t, self.names, x.shape), y_raw
detections = Detections([img], y_sup, None, t, self.names, x.shape)
detections.logits = logits
return detections, y_raw
@torch.no_grad()
def forward(self, imgs, size=640, augment=False, profile=False):
......
......@@ -542,6 +542,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
t = time.time()
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
mad_output = None
logits = []
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
......@@ -577,6 +578,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
else: # best class only
conf, j = x[:, 5:].max(1, keepdim=True)
logits.append(x[:, 5:][conf.view(-1) > conf_thres])
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
# Filter by class
......@@ -609,11 +611,12 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i]
logits[xi] = [logits[0][x] for x in i]
if (time.time() - t) > time_limit:
print(f'WARNING: NMS time limit {time_limit}s exceeded')
break # time limit exceeded
return output, mad_output
return output, mad_output, logits[0]
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment