Skip to content
Snippets Groups Projects
Select Git revision
  • poly-speedup
  • master default protected
  • debug-partition-size
  • wta-generator
  • fixes
  • bench-hex
  • ci-artifacts
  • new-monoids
  • stack
  • sumbag
  • tutorial
  • web
  • features/disable-sanity
  • ghc-8.4.4
  • linux-bin-artifacts
  • syntax-doc
  • ci-stack
  • rationals
  • double-round
  • init-time
  • group-weight
21 results

BenchLexer.hs

Blame
  • main.py 14.66 KiB
    import datetime
    
    import numpy as np
    from PIL import Image
    import torch
    import cv2
    import time
    import math
    import matplotlib
    from torch import optim
    
    import CSM
    import models
    from models.common import Detections
    from utils.external import TotalVariation
    from utils.general import scale_coords
    
    matplotlib.use('TkAgg')
    import matplotlib.pyplot as plt
    
    # Model
    from torchvision.transforms import transforms
    
    from patch_transformer import PatchTransformer, PatchApplier
    
    model = torch.hub.load('ultralytics/yolov5', 'yolov5l')  # other sizes: yolov5s, yolov5m, yolov5x

    # model = torch.hub.load('ultralytics/yolov3', 'yolov3')

    # Values below this count as "empty" patch pixels (extract_bounding_box); it is also
    # the lower clamp bound applied to the patch after every update step.
    MIN_THRESHOLD = 0.00001

    # Class names in index order: index i is the name of class i as used in this file
    # (e.g. 0 = "person", 15 = "cat").
    coco_class_names = [
        "person", "bicycle", "car", "motorbike", "aeroplane", "bus",
        "train", "truck", "boat", "traffic light", "fire hydrant",
        "stop sign", "parking meter", "bench", "bird", "cat", "dog",
        "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
        "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
        "skis", "snowboard", "sports ball", "kite", "baseball bat",
        "baseball glove", "skateboard", "surfboard", "tennis racket",
        "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
        "banana", "apple", "sandwich", "orange", "broccoli", "carrot",
        "hot dog", "pizza", "donut", "cake", "chair", "sofa", "pottedplant",
        "bed", "diningtable", "toilet", "tvmonitor", "laptop", "mouse", "remote",
        "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator",
        "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
    ]

    # Start image of the patch being optimized. Other start images that were used:
    # saved_patches/realcat.jpg, saved_patches/fatcat.jpg, saved_patches/smallcat.jpg
    PATH = "saved_patches/person.jpg"
    # Edge length in pixels that read_image() resizes the patch to.
    PATCH_SIZE = 300

    total_variation = TotalVariation()
    
    
    def show(imgs):
        """Draw each image of ``imgs`` in the top row of a 2-row matplotlib figure and show it.

        Each entry is detached and moved to the CPU before drawing. An entry that fails to
        convert or draw is reported on stdout and skipped, so one bad crop does not abort
        the preview. An empty ``imgs`` shows nothing.

        :param imgs: Sequence of tensors to display.
        """
        if not imgs:
            return
        # squeeze=False keeps axarr 2-D even for a single column, so axarr[0, i] is valid
        # for len(imgs) == 1 as well.
        _, axarr = plt.subplots(2, len(imgs), squeeze=False)
        for i, img in enumerate(imgs):
            try:
                axarr[0, i].imshow(img.detach().cpu())
            except Exception as err:  # best effort: report and keep drawing the rest
                print(f"show: cannot display image {i}: {err}")
        plt.show()
    
    
    def debug_preds(dets=None, cls_nr=15):
        """Print the box of every detection of class ``cls_nr`` (default 15, "cat").

        Each printed line has the form ``Pred BB: x1 y1 x2 y2 (score):``.

        :param dets: Object exposing ``pred[0]`` as an iterable of rows whose first four
            values are the box, whose last value is the class index and whose
            second-to-last value is printed in parentheses (presumably the confidence).
            Defaults to the module-level ``detections`` assigned by the main loop.
        :param cls_nr: Class index to report.
        """
        if dets is None:
            dets = detections  # global set in the __main__ loop
        for det in dets.pred[0]:
            if int(det[-1]) != cls_nr:
                continue
            x1, y1, x2, y2 = (int(v) for v in det[:4])
            print("Pred BB: {} {} {} {} ({}):".format(x1, y1, x2, y2, float(det[-2])))
    
    
    # from https://github.com/wangzh0ng/adversarial_yolo2
    def read_image(path, size=None):
        """
        Read an input image to be used as a patch.

        The image is converted to RGB, resized to a ``size`` x ``size`` square and turned
        into a tensor with torchvision's ``ToTensor``.

        :param path: Path to the image to be read.
        :param size: Edge length of the square output in pixels; defaults to ``PATCH_SIZE``.
        :return: Returns the transformed patch as a pytorch Tensor.
        """
        if size is None:
            size = PATCH_SIZE
        # The context manager closes the file handle PIL keeps open; convert() loads the
        # pixels and returns a new image, so it stays valid after the file is closed.
        with Image.open(path) as img:
            patch_img = img.convert('RGB')
        patch_img = transforms.Resize((size, size))(patch_img)
        return transforms.ToTensor()(patch_img)
    
    
    def extract_bounding_box(patch, threshold=None):
        """Return the tight box ``[x1, y1, x2, y2]`` around the non-empty pixels of ``patch``.

        A pixel is empty when every value along dim 2 (presumably the colour channels)
        is below ``threshold``. x indexes dim 1, y indexes dim 0, and both bounds are
        inclusive. Written for a 3-D ``patch`` (presumably (H, W, C) -- TODO confirm
        against PatchTransformer's output).

        :param patch: Tensor to scan; the work stays on whatever device it lives on.
        :param threshold: Values below this are empty; defaults to ``MIN_THRESHOLD``.
        :return: int64 tensor ``[x1, y1, x2, y2]`` on ``patch``'s device.
        :raises IndexError: if every pixel is empty (zero-sized patch).
        """
        if threshold is None:
            threshold = MIN_THRESHOLD
        # Per-pixel count of values that are not below the threshold. Using the negated
        # comparison (rather than >=) matches the original torch.where(patch < t, 0, 1).
        mask = (~(patch < threshold)).sum(2)

        cols = torch.nonzero(mask.sum(0))  # x positions holding any content
        rows = torch.nonzero(mask.sum(1))  # y positions holding any content

        return torch.stack([cols[0], rows[0], cols[-1], rows[-1]]).sum(1)
    
    
    def get_avg_prediction(res, cls_nr, min_conf=None):
        """Return the mean score of class ``cls_nr`` over the raw prediction rows ``res``.

        The score of class ``cls_nr`` is read from column ``cls_nr + 5`` of each row;
        columns 5 onward are treated as the per-class scores.

        :param res: Iterable of 1-D prediction rows (tensors), or None.
        :param cls_nr: Class index whose score is averaged.
        :param min_conf: If given, only rows whose highest class score exceeds it are
            averaged. The default (None) averages every row.
        :return: The mean score, or 0 when ``res`` is None or no row was averaged.
        """
        if res is None:
            return 0

        total = 0
        count = 0
        for pred in res:
            if min_conf is not None and not bool(pred[5:].max() > min_conf):
                continue
            count += 1
            total += pred[cls_nr + 5]

        return total / (count if count > 0 else 1)
    
    
    # source https://www.pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/
    def bb_intersection_over_union(boxA, boxB):
        """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

        Coordinates are inclusive pixel indices, so a box covers ``x2 - x1 + 1`` by
        ``y2 - y1 + 1`` pixels. Elements may be plain numbers or 0-d tensors.

        :return: The IoU, or 0.0 when the union area is not positive (degenerate boxes)
            instead of dividing by zero.
        """
        # determine the (x, y)-coordinates of the intersection rectangle
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])
        # area of the intersection rectangle (0 when the boxes do not overlap)
        interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
        # areas of both boxes
        boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
        boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
        # union = sum of both areas minus the doubly counted intersection
        union = boxAArea + boxBArea - interArea
        if union <= 0:
            # degenerate boxes: nothing to compare, avoid ZeroDivisionError
            return 0.0
        return interArea / float(union)
    
    
    def save_image(image):
        """Save ``image`` as ``saved_patches/<unix time>.jpg`` and show a preview.

        :param image: Anything torchvision's ``ToPILImage`` accepts (in this file: the
            patch tensor). Tensors are detached and copied to the CPU first, so a patch
            with ``requires_grad`` can be converted.
        """
        import os  # function-scope import keeps this change self-contained

        print("save image called!")
        if isinstance(image, torch.Tensor):
            image = image.detach().cpu()
        im = transforms.ToPILImage('RGB')(image)
        # Write the file before plt.show(), which can block until the preview window is
        # closed; create the target directory on first use.
        os.makedirs("saved_patches", exist_ok=True)
        im.save(f"saved_patches/{time.time()}.jpg")
        plt.imshow(im)
        plt.show()
    
    
    def get_best_prediction(true_box, res, cls_nr):
        """Pick the prediction overlapping ``true_box`` most and return its class score.

        :param true_box: Reference box ``(x1, y1, x2, y2)``; the main loop passes the
            patch's bounding box.
        :param res: Iterable of raw prediction rows (tensors), or None. Columns 0-3 are
            compared as a box with ``true_box``; column ``cls_nr + 5`` is the score returned.
        :param cls_nr: Class index whose score is returned.
        :return: Tuple ``(max_iou, score)``. When ``res`` is None or empty the result is
            ``(0.0, None)``.
        """
        max_iou = 0.0
        best_prediction = None
        if res is None:
            return max_iou, best_prediction

        score_col = cls_nr + 5
        for pred in res:
            pred_iou = bb_intersection_over_union(true_box, pred[:4].float())
            # ">=" keeps the last of equally good rows, so a row is selected even when
            # none of them overlaps the box (IoU 0).
            if pred_iou >= max_iou:  # and pred[5:].max() > 0.1:
                max_iou = pred_iou
                best_prediction = pred[score_col]

        return max_iou, best_prediction
    
    
    def calculate_csms(frame, predictions):
        """Crop every detection out of ``frame`` and compute cosine similarity matrices.

        :param frame: Image tensor indexed as (row, column, dim 2); dim 2 is reversed
            before cropping (presumably OpenCV BGR -> RGB).
        :param predictions: Object exposing ``pred[0]`` (rows starting with
            ``x1, y1, x2, y2``) and ``logits`` indexed in step with ``pred[0]``.
        :return: ``(imgs, csms, cls)`` as returned by ``CSM.calc_yolo_person_csms``.
        """
        # Flip once instead of once per detection: flip() copies the whole frame and the
        # result is the same for every detection.
        flipped = frame.flip(2)

        imgs_and_logits = []
        for i, det in enumerate(predictions.pred[0]):
            x1, y1, x2, y2 = det[:4].float()
            crop = flipped[int(y1):int(y2), int(x1):int(x2), :]
            imgs_and_logits.append((crop, predictions.logits[i], frame, x1, y1, x2, y2))

        # TODO insert non_max_suppression
        imgs, csms, cls = CSM.calc_yolo_person_csms(imgs_and_logits, rescale_factor=0, loss_rescale_factor=1000)

        return imgs, csms, cls
    
    
    if __name__ == "__main__":
        # Interactive loop: read webcam frames, paste the transformed patch into them,
        # run the detector and update the patch pixels with a signed-gradient step on a
        # loss built from the score of the prediction that best overlaps the patch.
        # Keys toggle the patch transformations (see the key handling in the loop).

        # init
        patch_transformer = PatchTransformer().cuda()
        patch_applier = PatchApplier().cuda()

        # set start time to current time
        start_time = time.time()

        # displays the frame rate every 2 seconds
        display_time = 2

        # Set primary FPS to 0
        fps = 0

        # we create the video capture object cap
        cap = cv2.VideoCapture(0)
        if not cap.isOpened():
            raise IOError("We cannot open webcam")

        patch = read_image(PATH)
        # patch = torch.rand_like(patch)
        patch.requires_grad = True

        # Not stepped below (the update is a manual signed-gradient step); kept for the
        # commented-out optimizer variant.
        optimizer = optim.Adam([patch], lr=0.0001, amsgrad=True)
        gradient_sum = 0

        # img_size_x = 640
        img_size_x = 480
        img_size_y = 480

        # Launch settings: flags passed to patch_transformer; each can be toggled at runtime.
        move = False
        rotate = False
        taper = False
        resize = True
        squeeze = False
        gauss = False
        obfuscate = False
        stretch = False
        transform_interval = 1
        angle_step = 5  # degrees added/removed per "9"/"3" key press
        tv_factor = 1  # weight of the total_variation term in the loss

        ctr = -1
        frame_read = False
        fix_frame = False
        bounding_box = None  # patch box in the frame; None when extraction failed
        patch_transformer.maxangle = 5 / 180 * math.pi
        patch_transformer.minangle = - 5 / 180 * math.pi
        loss = None
        try:
            while True:
                if not (fix_frame and frame_read):
                    ret, frame = cap.read()
                    if not ret:
                        # The capture reported failure, so there is no frame to process.
                        print("could not read frame from webcam")
                        break

                    # cut image
                    frame = frame[:, :img_size_x, :]

                with torch.set_grad_enabled(True):
                    # with torch.autograd.detect_anomaly():

                    if not (fix_frame and frame_read):
                        # resize our captured frame if we need
                        frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA)
                        frame_original = torch.tensor(frame, dtype=torch.float32, requires_grad=True, device="cuda")
                        frame = frame_original.clone()
                        frame_read = True

                    detections = None
                    for _ in range(transform_interval):
                        ctr += 1

                        # transform patch (`ctr % 1 == 0` is always true: placeholder for
                        # transforming only every N frames)
                        if ctr % 1 == 0:
                            trans_patch = patch_transformer(patch.cuda(), torch.ones([1, 14, 5]).cuda(), img_size_x, img_size_y,
                                                            do_rotate=rotate, rand_loc=move, rand_size=resize,
                                                            rand_squeeze=squeeze, gauss=gauss, obfuscate=obfuscate,
                                                            stretch=stretch, do_taper=taper)

                            # extract bounding box (x1, y1, x2, y2)
                            try:
                                bounding_box = extract_bounding_box(trans_patch)
                            except Exception:
                                # Reset rather than reuse a stale box from an earlier frame
                                # (or leave it undefined on the very first frame).
                                bounding_box = None
                                print("zero-sized patch ... ")

                        # apply patch
                        frame = patch_applier(frame_original, trans_patch)

                        # detect object on our frame
                        if ctr % 1 == 0 or detections is None:
                            detections, raw_results = model.forward_pt(frame)

                        # debug_preds()  # print the predicted cat boxes

                        # calculate Cosine Similarity Matrix
                        # imgs, csms, clss = calculate_csms(frame, raw_results)
                        imgs, csms, clss = calculate_csms(frame, detections)
                        for i in range(len(csms)):
                            # show only person predictions
                            if clss[i] == 0:
                                show([torch.min(torch.ones_like(imgs[i]), imgs[i]/255), csms[i].T])

                        if bounding_box is None:
                            # no patch box to match predictions against: no loss this frame
                            iou, pred = 0.0, None
                        else:
                            # Target class 0 ("person"). Other targets used: 15 cat,
                            # 12 parking meter, 11 stop sign, 8 boat, 62 tvmonitor, 42 fork.
                            # get_avg_prediction(raw_results, cls) averages over all rows instead.
                            iou, pred = get_best_prediction(bounding_box, raw_results, 0)

                        if pred is not None:
                            # loss: -1 * pred raises the target score; use +1 * pred to suppress it
                            loss = -1 * pred

                            # total variation loss component
                            tv_loss = total_variation(patch)
                            loss += tv_factor * tv_loss

                            # IoU loss component (low iou = high loss)
                            loss += 0.1 * (1 - iou)

                            if not isinstance(loss, torch.Tensor):
                                # nothing to backpropagate through
                                loss = None

                    if loss is None:
                        print("loss is None")
                    else:
                        loss.backward(retain_graph=True)
                        loss = None
                        gradient_sum += patch.grad
                        # backward() accumulates into patch.grad; clear it so each step
                        # follows the current gradient, not the sum of all past ones.
                        patch.grad = None

                        # sgn_grads = torch.sign(optimizer.param_groups[0]['params'][0].grad)
                        # optimizer.param_groups[0]['params'][0].grad = sgn_grads
                        # optimizer.step()
                        # signed-gradient step, then clamp pixels to [MIN_THRESHOLD, 0.99999]
                        patch.data -= torch.sign(gradient_sum) * 0.001
                        patch.data = patch.detach().clone().clamp(MIN_THRESHOLD, 0.99999).data
                        gradient_sum = 0

                    # show us frame with detection (also when no loss was computed, so
                    # the window keeps refreshing and the keys below keep working)
                    try:
                        cv2.imshow("img", detections.render()[0])
                    except Exception as e:
                        print(f"catproblem {e}")

                    key = cv2.waitKey(25) & 0xFF
                    if key == ord("q"):
                        break
                    elif key == ord("u"):
                        move = not move
                        print("Move: {}".format(move))
                    elif key == ord("o"):
                        rotate = not rotate
                        print("Rotate: {}".format(rotate))
                    elif key == ord("t"):
                        resize = not resize
                        print("Resize: {}".format(resize))
                    elif key == ord("z"):
                        squeeze = not squeeze
                        print("Squeeze: {}".format(squeeze))
                    elif key == ord("g"):
                        gauss = not gauss
                        print("Gauss: {}".format(gauss))
                    elif key == ord("p"):
                        taper = not taper
                        print("Taper: {}".format(taper))
                    elif key == ord("h"):
                        obfuscate = not obfuscate
                        print(f"Obfuscate: {obfuscate}")
                    elif key == ord("e"):
                        stretch = not stretch
                        print(f"Stretch: {stretch}")
                    elif key == ord("+"):
                        # shift the transformer's size range up
                        patch_transformer.maxsize += 0.01
                        patch_transformer.minsize += 0.01
                        print(f"Size {patch_transformer.minsize}")
                    elif key == ord("-"):
                        # shift the transformer's size range down
                        patch_transformer.maxsize -= 0.01
                        patch_transformer.minsize -= 0.01
                        print(f"Size {patch_transformer.minsize}")
                    elif key == ord("9"):
                        # widen the angle range by angle_step degrees on each side, up to +-180 degrees
                        patch_transformer.maxangle = min(patch_transformer.maxangle + (math.pi * angle_step / 180), math.pi)
                        patch_transformer.minangle = max(patch_transformer.minangle - (math.pi * angle_step / 180), -math.pi)
                        print("Transformer MaxAngle: {}°".format(patch_transformer.maxangle / math.pi * 180))
                    elif key == ord("3"):
                        # narrow the angle range by angle_step degrees on each side, down to 0
                        patch_transformer.maxangle = max(patch_transformer.maxangle - (math.pi * angle_step / 180), 0)
                        patch_transformer.minangle = min(patch_transformer.minangle + (math.pi * angle_step / 180), 0)
                        print("Transformer MaxAngle: {}°".format(patch_transformer.maxangle / math.pi * 180))
                    elif key == ord("s"):
                        save_image(patch)
                    elif key == ord("f"):
                        fix_frame = not fix_frame
                        print("Fix Frame: {}".format(fix_frame))
                    elif key == ord("a"):
                        tv_factor += 1
                        print("Total Variation Loss Factor: {}".format(tv_factor))
                    elif key == ord("y"):
                        tv_factor -= 1
                        print("Total Variation Loss Factor: {}".format(tv_factor))

                # calculate FPS
                fps += 1
                elapsed = time.time() - start_time
                if elapsed > display_time:
                    # print("FPS:", fps / elapsed)
                    fps = 0
                    start_time = time.time()
                # time.sleep(0.2)
        finally:
            # release the camera and close windows on quit, read failure or error
            cap.release()
            cv2.destroyAllWindows()