Commit df1aa540 authored by Pavlo Beylin

Enable Backpropagation.

parent a9a98806
+import numpy as np
 from PIL import Image
 import torch
 import cv2
...
@@ -121,7 +122,7 @@ if __name__ == "__main__":
         raise IOError("We cannot open webcam")

     patch = read_image(PATH)
-    # patch_np = torch.transpose(patch.T, 0, 1).numpy()
+    patch.requires_grad = True

     img_size_x = 640
     img_size_y = 480
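Note: setting requires_grad = True turns the patch into the optimization variable, a leaf tensor whose .grad field is filled by the backward pass later in the loop. A minimal sketch of that mechanism (illustrative shape and loss, not the project's read_image or detector):

    import torch

    patch = torch.rand(3, 300, 300)   # stand-in for read_image(PATH); shape is an assumption
    patch.requires_grad = True        # leaf tensor: backward() will populate patch.grad

    loss = -patch.mean()              # stand-in for the detection-confidence loss used below
    loss.backward()
    print(patch.grad.shape)           # torch.Size([3, 300, 300]) -- one gradient per patch pixel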
@@ -132,51 +133,48 @@ if __name__ == "__main__":
         ret, frame = cap.read()

-        # resize our captured frame if we need
-        frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA)
-        frame = torch.tensor(frame).cuda()
-        # cv2.imshow("Web cam input", frame)
-
-        # transform patch (every couple of frames)
-        if ctr % 100 == 0:
-            trans_patch = patch_transformer(patch.cuda(), torch.ones([1, 14, 5]).cuda(), img_size_x, img_size_y,
-                                            do_rotate=True, rand_loc=True)
-            # trans_patch_np = torch.transpose(trans_patch[0][0].T, 0, 1).detach().cpu().numpy()
-            trans_patch = torch.transpose(trans_patch[0][0].T, 0, 1)
-
-        # extract bounding box (x1, y1, x2, y2)
-        bounding_box = extract_bounding_box(trans_patch)
-        print("True BB: {} {} {} {}".format(int(bounding_box[0]), int(bounding_box[1]), int(bounding_box[2]),
-                                            int(bounding_box[3])))
-
-        # apply patch
-        frame = patch_applier(frame, trans_patch)
-
-        # detect object on our frame
-        # results_np = model(frame.detach().cpu().numpy())  # for displaying
-        # frame = torch.transpose(frame, 0, 1).T.unsqueeze(0)
-        results = model.forward_pt(frame)
-
-        # Post-process
-        # y = models.common.non_max_suppression(y, conf_thres=model.conf, iou_thres=model.iou, classes=model.classes)
-        if ctr % 100 == 0:
-            debug_preds()
-            pass
-
-        pred_box = get_best_prediction(bounding_box, results, 15)  # get cats
-
-        if pred_box is not None:
-            # print("P:{}".format(pred_box[-2]))
-            # pred_box[-2].backwards()
-            pass
-
-        # show us frame with detection
-        # cv2.imshow("img", results_np.render()[0])
-        cv2.imshow("img", results.render()[0])
-        if cv2.waitKey(25) & 0xFF == ord("q"):
-            cv2.destroyAllWindows()
-            break
+        with torch.set_grad_enabled(True):
+            # with torch.autograd.detect_anomaly():
+            # resize our captured frame if we need
+            frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA)
+            frame = torch.tensor(frame, dtype=torch.float32, requires_grad=True).cuda()
+
+            # transform patch (every couple of frames)
+            if ctr % 100 == 0:
+                trans_patch = patch_transformer(patch.cuda(), torch.ones([1, 14, 5]).cuda(), img_size_x, img_size_y,
+                                                do_rotate=True, rand_loc=True)
+                # trans_patch_np = torch.transpose(trans_patch[0][0].T, 0, 1).detach().cpu().numpy()
+                trans_patch = torch.transpose(trans_patch[0][0].T, 0, 1)
+
+            # extract bounding box (x1, y1, x2, y2)
+            bounding_box = extract_bounding_box(trans_patch)
+            print("True BB: {} {} {} {}".format(int(bounding_box[0]), int(bounding_box[1]), int(bounding_box[2]),
+                                                int(bounding_box[3])))
+
+            # apply patch
+            frame = patch_applier(frame, trans_patch)
+
+            # detect object on our frame
+            results = model.forward_pt(frame)
+
+            if ctr % 100 == 0:
+                debug_preds()
+                pass
+
+            pred_box = get_best_prediction(bounding_box, results, 15)  # get cats
+
+            if pred_box is not None:
+                print("P:{}".format(pred_box[-2]))
+                loss = -1 * pred_box[-2]
+                loss.backward(retain_graph=True)
+                pass
+
+            # show us frame with detection
+            # cv2.imshow("img", results_np.render()[0])
+            cv2.imshow("img", results.render()[0])
+            if cv2.waitKey(25) & 0xFF == ord("q"):
+                cv2.destroyAllWindows()
+                break

         # calculate FPS
         fps += 1
...
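Note: the loop above only computes gradients; loss = -1 * pred_box[-2] followed by loss.backward(retain_graph=True) accumulates the gradient of the confidence into patch.grad, but the patch itself is not yet updated in this commit. A hypothetical FGSM-style update that could follow the backward call (the learning rate, the sign step and the pixel range are all assumptions, not part of the commit):

    with torch.no_grad():                        # manual update, outside the autograd graph
        if patch.grad is not None:
            patch -= 0.01 * patch.grad.sign()    # step against the gradient of the loss
            patch.clamp_(0, 255)                 # keep pixels valid (range is an assumption)
            patch.grad.zero_()                   # reset accumulation before the next frame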
@@ -287,6 +287,7 @@ class AutoShape(nn.Module):
             LOGGER.info('AutoShape already enabled, skipping... ')  # model already converted to model.autoshape()
             return self

+    @torch.enable_grad()
     def forward_pt(self, img, size=640, augment=False, profile=False):
         # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
         #   file:  imgs = 'data/images/zidane.jpg'  # str or PosixPath
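Note: the new @torch.enable_grad() decorator presumably makes forward_pt record an autograd graph even when it is reached from a context where gradients are globally disabled (the stock AutoShape.forward below keeps its @torch.no_grad()). A minimal sketch of that behaviour with a toy function:

    import torch

    @torch.enable_grad()
    def score(x):
        return (x * 3).sum()     # graph is recorded despite the caller's no_grad context

    x = torch.ones(2, requires_grad=True)
    with torch.no_grad():        # e.g. an inference-only caller
        s = score(x)

    s.backward()
    print(x.grad)                # tensor([3., 3.])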
@@ -327,13 +328,13 @@ class AutoShape(nn.Module):
         t.append(time_sync())

         # Post-process
-        y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det)  # NMS
+        y_sup = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det)  # NMS
         # for i in range(n):
         #     scale_coords(shape1, y[i][:, :4], shape0[i])

         t.append(time_sync())
         # return Detections(imgs, y, files, t, self.names, x.shape)
-        return Detections([img], y, None, t, self.names, x.shape)
+        return Detections([img], y_sup, None, t, self.names, x.shape)

     @torch.no_grad()
     def forward(self, imgs, size=640, augment=False, profile=False):
@@ -411,6 +412,7 @@ class Detections:
         self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3))  # timestamps (ms)
         self.s = shape  # inference BCHW shape

+    @torch.no_grad()
     def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
         for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
             str = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '
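Note: decorating display() with @torch.no_grad() is presumably there because the detections it renders can now require grad, and converting such tensors to NumPy for drawing fails unless the conversion happens in a grad-free context. A small illustration of that difference (hypothetical image tensor, not the Detections internals):

    import torch

    im = torch.rand(4, 4, 3, requires_grad=True)

    try:
        (im * 255).numpy()            # result still requires grad -> RuntimeError
    except RuntimeError as e:
        print("without no_grad:", e)

    with torch.no_grad():
        arr = (im * 255).numpy()      # ops under no_grad produce grad-free tensors
    print(arr.shape)                  # (4, 4, 3)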
@@ -60,14 +60,15 @@ class Detect(nn.Module):
                     self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                 y = x[i].sigmoid()
+                y_out = y.clone()
                 if self.inplace:
-                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
-                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                    y_out[..., 0:2] = (y_out[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    y_out[..., 2:4] = (y_out[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                 else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
-                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
-                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2)  # wh
-                    y = torch.cat((xy, wh, y[..., 4:]), -1)
-                z.append(y.view(bs, -1, self.no))
+                    xy = (y_out[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    wh = (y_out[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2)  # wh
+                    y_out = torch.cat((xy, wh, y_out[..., 4:]), -1)
+                z.append(y_out.view(bs, -1, self.no))

         return x if self.training else (torch.cat(z, 1), x)
...
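Note: writing the decoded xy/wh values back into y in place invalidates the output that sigmoid saved for its backward pass, so backpropagation through the Detect head fails; decoding into y_out = y.clone() leaves the saved tensor untouched while the decoded values still carry gradients back to the network. A minimal reproduction of the problem and of the clone fix (toy tensors, not the Detect head):

    import torch

    a = torch.rand(4, requires_grad=True)

    y = a.sigmoid()                    # sigmoid saves its output for backward
    try:
        y[0:2] = y[0:2] * 2. - 0.5     # in-place write bumps the saved tensor's version
        y.sum().backward()
    except RuntimeError as e:
        print("in-place write broke backward:", e)

    y = a.sigmoid()
    y_out = y.clone()                  # decode into a copy instead
    y_out[0:2] = y_out[0:2] * 2. - 0.5
    y_out.sum().backward()             # gradients reach a through the clone
    print(a.grad)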
@@ -560,7 +560,9 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None,
             continue

         # Compute conf
-        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
+        # x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
+        x_hat = x.detach().clone()
+        x[:, 5:] *= x_hat[:, 4:5]  # conf = obj_conf * cls_conf

         # Box (center x, center y, width, height) to (x1, y1, x2, y2)
         box = xywh2xyxy(x[:, :4])
...
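Note: in the original line x[:, 5:] *= x[:, 4:5] both operands are views of the same tensor, so the in-place multiply modifies values that autograd may still need for the backward pass, which raises a RuntimeError once the predictions require grad. Multiplying by a detached copy treats the objectness score as a constant, so gradients flow only through the class scores and backward succeeds. A small sketch of the resulting pattern (toy 85-column predictions, mirroring the intent rather than the full NMS):

    import torch

    pred = torch.rand(20, 85, requires_grad=True)   # toy [xywh, obj, 80 class scores] rows
    x = pred[pred[:, 4] > 0.1]                      # candidate filtering, as in non_max_suppression

    x_hat = x.detach().clone()                      # detached copy: obj_conf acts as a constant
    x[:, 5:] *= x_hat[:, 4:5]                       # conf = obj_conf * cls_conf
    x.sum().backward()                              # backward succeeds; pred.grad is populated
    print(pred.grad.abs().sum() > 0)                # tensor(True)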