diff --git a/CSM.py b/CSM.py
index 54dcd0e9f3b47f5c4247db95e367516c60b075ae..6346913df223905966a87ed1cf6619f19be15fb7 100644
--- a/CSM.py
+++ b/CSM.py
@@ -1,19 +1,23 @@
 import torch
 import torch.nn.functional as F
+import matplotlib.pyplot as plt
 
+from main import coco_class_names
 
-def calc_yolo_csms(imgs_and_preds: torch.Tensor,
-                   sign: bool = False,
-                   rescale: bool = True) -> torch.Tensor:
+
+def calc_yolo_person_csms(imgs_and_logits: torch.Tensor,
+                          sign: bool = True,
+                          rescale_factor = 0,
+                          loss_rescale_factor = 1) -> torch.Tensor:
     '''
-    computes the cosine similarity map for given input images X
+    computes the YOLO cosine similarity map for given input images X
 
     Parameters
     ---------
     model: torch model
-    imgs_and_preds: torch tensor; shape: (Batch_Size, Channels, Width, Height)
+    imgs_and_logits: torch tensor; shape: (Batch_Size, Channels, Width, Height)
     sign: use sign of gradients to calculate cosine similarity maps
-    rescale: rescale the logits before applying softmax -> solves gradient obfuscation problem of large logits
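+    rescale_factor / loss_rescale_factor: if rescale_factor > 0, logits are scaled to a maximum absolute value of rescale_factor; the cross-entropy loss is multiplied by loss_rescale_factor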
 
     Returns
     ---------
@@ -24,29 +28,42 @@ def calc_yolo_csms(imgs_and_preds: torch.Tensor,
     imgs = []
     clss = []
 
-    for tup in imgs_and_preds:
-        img, pred, frame, x1, y1, x2, y2 = tup
+    for tup in imgs_and_logits:
+        img, logit, frame, x1, y1, x2, y2 = tup
         if not img.requires_grad:
             img.requires_grad_()
-        logit = pred[5:]
-        cls = torch.argmax(pred[5:])
+        cls = torch.argmax(logit)
 
         # rescale network output to avoid gradient obfuscation
-        if rescale:
-            logit = logit / torch.max(torch.abs(logit)) * 10
+        if rescale_factor > 0:
+            logit = rescale_factor * logit / torch.max(torch.abs(logit))
 
+        # calculate all classes
         classes = len(logit)
-        # only first ten classes
-        classes = 10
+
+        # get top10 classes and person
+        classes = list(torch.sort(logit, descending=True)[1][:10])
+        class_names = [coco_class_names[c] for c in classes]
+
+        if classes[0] != 0:  # do not consider predictions if person is not the top prediction
+            continue
+
+        print(f'Top Ten: {class_names}')
 
         deltas = []
-        for c in range(classes):
+        for c in classes:
             #  calculate loss and compute gradient w.r.t. the input of the current class
             y = torch.ones(1, device="cuda", dtype=torch.long) * c
-            loss = F.cross_entropy(logit.unsqueeze(0), y)
-            frame_grad = torch.autograd.grad(loss, frame, retain_graph=True)[0][:, 5:]
+            loss = F.cross_entropy(logit.unsqueeze(0), y) * loss_rescale_factor
+            frame_grad = torch.autograd.grad(loss, frame, retain_graph=True)[0]
             img_grad = frame_grad[int(y1):int(y2), int(x1):int(x2), :]
 
+            try:
+                if torch.min(img_grad) == 0:
+                    print(f"img grad contains zero {c}")
+            except Exception as e:
+                print(e)
+
             #  take sign of gradient as in the original paper
             if sign:
                 img_grad = torch.sign(img_grad)
@@ -56,23 +73,22 @@ def calc_yolo_csms(imgs_and_preds: torch.Tensor,
         deltas = torch.stack(deltas)
         #  compute cosine similarity matrices
 
-        try:
-            deltas = torch.max(deltas, dim=-3).values  # take only the maximum value of all channels to compute the
-            deltas = deltas.view(classes, 1, -1)
-            norm = torch.norm(deltas, p=2, dim=2, keepdim=True)
+        deltas = torch.max(deltas, dim=-1).values  # take only the maximum value of all channels to compute the
+        deltas = deltas.view(len(classes), 1, -1)
+        norm = torch.norm(deltas, p=2, dim=2, keepdim=True)
+        # norm[norm == 0] = 1
+        if torch.min(norm) != 0:
             deltas = deltas / norm
-            deltas = deltas.transpose(0, 1)
-            csm = torch.matmul(deltas, deltas.transpose(1, 2))
-        except Exception as e:
-            print("error")
-            # raise e
+        else:
+            print("norm min contains 0")
+        deltas = deltas.transpose(0, 1)
+        csm = torch.matmul(deltas, deltas.transpose(1, 2))
 
         #  division by zero can lead to NaNs
         if torch.isnan(csm).any():
-            # raise Exception("NaNs in CSM!")
-            print("NaNs in csm")
+            raise Exception("NaNs in CSM!")
+            # print("NaNs in csm")
         else:
-            print(f'{deltas.mean()}')
             imgs.append(img)
             csms.append(csm)
             clss.append(cls)
diff --git a/main.py b/main.py
index 79cf10d0d0b2317262b9578442be47960552bcf0..5b43e98c875225373754896d35bcea565257e677 100644
--- a/main.py
+++ b/main.py
@@ -29,7 +29,7 @@ model = torch.hub.load('ultralytics/yolov5', 'yolov5l')  # or yolov5m, yolov5l,
 
 MIN_THRESHOLD = 0.00001
 
-classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
+coco_class_names = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
            "train", "truck", "boat", "traffic light", "fire hydrant",
            "stop sign", "parking meter", "bench", "bird", "cat", "dog",
            "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
@@ -40,8 +40,8 @@ classes = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus",
            "banana", "apple", "sandwich", "orange", "broccoli", "carrot",
            "hot dog", "pizza", "donut", "cake", "chair", "sofa", "pottedplant",
            "bed", "diningtable", "toilet", "tvmonitor", "laptop", "mouse", "remote",
-           "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator",
-           "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"]
+                    "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator",
+                    "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"]
 
 PATH = "saved_patches/realcat.jpg"
 PATH = "saved_patches/fatcat.jpg"
@@ -63,9 +63,9 @@ def show(imgs):
 
 
 def debug_preds():
-    detected_classes = [int(results.pred[0][i][-1]) for i in range(0, len(results.pred[0]))]
+    detected_classes = [int(detections.pred[0][i][-1]) for i in range(0, len(detections.pred[0]))]
     # print(detected_classes)
-    for det in results.pred[0]:
+    for det in detections.pred[0]:
         if int(det[-1]) == 15:  # cat
             print("Pred BB: ", end="")
             # print("x1:y1 : {}:{}".format(float(det[0]), float(det[1])))
@@ -165,23 +165,17 @@ def get_best_prediction(true_box, res, cls_nr):
 
 def calculate_csms(frame, predictions):
 
-    imgs_and_preds = []
-
-    for pred in predictions:
-        x1, y1, x2, y2, conf = pred[:5].float()
+    """Pair each detection's image crop with its logits and return CSM.calc_yolo_person_csms' (imgs, csms, cls)."""
+    imgs_and_logits = []
 
+    for idx, det in enumerate(predictions.pred[0]):
+        x1, y1, x2, y2, conf = det[:5].float()
         pred_img_section = frame.flip(2)[int(y1):int(y2), int(x1):int(x2), :]
-        tup = (pred_img_section, pred, frame, x1, y1, x2, y2)
-        # print(tup)
-        imgs_and_preds.append(tup)
+        imgs_and_logits.append((pred_img_section, predictions.logits[idx], frame, x1, y1, x2, y2))
 
-        # if conf > 0.8:
-        #     cls = classes[int(pred[5:].argmax())]
-            # print(f"{cls}: {conf} - {pred[:5].float()}")
-            # show(frame.flip(2)[int(y1):int(y2), int(x1):int(x2), :] / 255.)
-            # print("done")
 
-    imgs, csms, cls = CSM.calc_yolo_csms(imgs_and_preds)
+    # TODO insert non_max_suppression
+    imgs, csms, cls = CSM.calc_yolo_person_csms(imgs_and_logits, loss_rescale_factor=1000, rescale_factor=0)
 
     return imgs, csms, cls
 
@@ -258,11 +252,11 @@ if __name__ == "__main__":
             if not (fix_frame and frame_read):
                 # resize our captured frame if we need
                 frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA)
-                frame_original = torch.tensor(frame, dtype=torch.float32, requires_grad=True).cuda()
+                frame_original = torch.tensor(frame, dtype=torch.float32, requires_grad=True, device="cuda")
                 frame = frame_original.clone()
                 frame_read = True
 
-            results = None
+            detections = None
             for _ in range(transform_interval):
                 ctr += 1
 
@@ -283,19 +277,20 @@ if __name__ == "__main__":
                 frame = patch_applier(frame_original, trans_patch)
 
                 # detect object on our frame
-                if ctr % 1 == 0 or results is None:
-                    results, raw_results = model.forward_pt(frame)
+                if ctr % 1 == 0 or detections is None:
+                    detections, raw_results = model.forward_pt(frame)
 
                 if ctr % 1 == 0:
                     # debug_preds()
                     pass
 
                 # calculate Cosine Similarity Matrix
-                imgs, csms, clss = calculate_csms(frame, raw_results)
+                # imgs, csms, clss = calculate_csms(frame, raw_results)
+                imgs, csms, clss = calculate_csms(frame, detections)
                 for i in range(len(csms)):
                     # show only person predictions
                     if clss[i] == 0:
-                        show([imgs[i]/255, csms[i].T])
+                        show([torch.min(torch.ones_like(imgs[i]), imgs[i]/255), csms[i].T])
 
                 # iou, pred = get_best_prediction(bounding_box, raw_results, 15)  # get cat
                 iou, pred = get_best_prediction(bounding_box, raw_results, 0)  # get personal
@@ -336,16 +331,16 @@ if __name__ == "__main__":
             # sgn_grads = torch.sign(optimizer.param_groups[0]['params'][0].grad)
             # optimizer.param_groups[0]['params'][0].grad = sgn_grads
             # optimizer.step()
-            patch.data -= torch.sign(gradient_sum) * 0.001  # * 0 # TODO reactivate
+            patch.data -= torch.sign(gradient_sum) * 0.001
             patch.data = patch.detach().clone().clamp(MIN_THRESHOLD, 0.99999).data
             gradient_sum = 0
 
             # show us frame with detection
             # cv2.imshow("img", results_np.render()[0])
             try:
-                cv2.imshow("img", results.render()[0])
-            except Exception:
-                print("catproblem")
+                cv2.imshow("img", detections.render()[0])
+            except Exception as e:
+                print(f"catproblem {e}")
 
             key = cv2.waitKey(25) & 0xFF
             if key == ord("q"):
diff --git a/models/common.py b/models/common.py
index da1c25b39225afd4efcda96d500d74632ae924c5..7fa5a3d60f4ae2c2e5712394d5474b1d8a751189 100644
--- a/models/common.py
+++ b/models/common.py
@@ -328,13 +328,15 @@ class AutoShape(nn.Module):
             t.append(time_sync())
 
             # Post-process
-            y_sup, y_raw = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det)  # NMS
+            y_sup, y_raw, logits = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det)  # NMS
             # for i in range(n):
             #     scale_coords(shape1, y[i][:, :4], shape0[i])
 
             t.append(time_sync())
             # return Detections(imgs, y, files, t, self.names, x.shape)
-            return Detections([img], y_sup, None, t, self.names, x.shape), y_raw
+            detections = Detections([img], y_sup, None, t, self.names, x.shape)
+            detections.logits = logits
+            return detections, y_raw
 
     @torch.no_grad()
     def forward(self, imgs, size=640, augment=False, profile=False):
diff --git a/utils/general.py b/utils/general.py
index 5cdf9af7a3bb83d15a706db31c39424d294a81a8..3ee167d92c1440abb842078e6e5e1118101d79b6 100755
--- a/utils/general.py
+++ b/utils/general.py
@@ -542,6 +542,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
     t = time.time()
     output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
     mad_output = None
+    logits = []
     for xi, x in enumerate(prediction):  # image index, image inference
         # Apply constraints
         # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
@@ -577,6 +578,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
             x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
         else:  # best class only
             conf, j = x[:, 5:].max(1, keepdim=True)
+            logits.append(x[:, 5:][conf.view(-1) > conf_thres])
             x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
 
         # Filter by class
@@ -609,11 +611,12 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
                 i = i[iou.sum(1) > 1]  # require redundancy
 
         output[xi] = x[i]
+        logits[xi] = [logits[0][x] for x in i]
         if (time.time() - t) > time_limit:
             print(f'WARNING: NMS time limit {time_limit}s exceeded')
             break  # time limit exceeded
 
-    return output, mad_output
+    return output, mad_output, logits[0]
 
 
 def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()