Commit df1aa540 authored by Pavlo Beylin
Browse files

Enable Backpropagation.

parent a9a98806
import numpy as np
from PIL import Image
import torch
import cv2
......@@ -121,7 +122,7 @@ if __name__ == "__main__":
raise IOError("We cannot open webcam")
patch = read_image(PATH)
# patch_np = torch.transpose(patch.T, 0, 1).numpy()
patch.requires_grad = True
img_size_x = 640
img_size_y = 480
......@@ -132,10 +133,11 @@ if __name__ == "__main__":
ret, frame = cap.read()
with torch.set_grad_enabled(True):
# with torch.autograd.detect_anomaly():
# resize our captured frame if we need
frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA)
frame = torch.tensor(frame).cuda()
# cv2.imshow("Web cam input", frame)
frame = torch.tensor(frame, dtype=torch.float32, requires_grad=True).cuda()
# transform patch (every couple of frames)
if ctr % 100 == 0:
......@@ -153,13 +155,8 @@ if __name__ == "__main__":
frame = patch_applier(frame, trans_patch)
# detect object on our frame
# results_np = model(frame.detach().cpu().numpy()) # for displaying
# frame = torch.transpose(frame, 0, 1).T.unsqueeze(0)
results = model.forward_pt(frame)
# Post-process
# y = models.common.non_max_suppression(y, conf_thres=model.conf, iou_thres=model.iou, classes=model.classes)
if ctr % 100 == 0:
debug_preds()
pass
......@@ -167,8 +164,9 @@ if __name__ == "__main__":
pred_box = get_best_prediction(bounding_box, results, 15) # get cats
if pred_box is not None:
# print("P:{}".format(pred_box[-2]))
# pred_box[-2].backwards()
print("P:{}".format(pred_box[-2]))
loss = -1 * pred_box[-2]
loss.backward(retain_graph=True)
pass
# show us frame with detection
......
......@@ -287,6 +287,7 @@ class AutoShape(nn.Module):
LOGGER.info('AutoShape already enabled, skipping... ') # model already converted to model.autoshape()
return self
@torch.enable_grad()
def forward_pt(self, img, size=640, augment=False, profile=False):
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
# file: imgs = 'data/images/zidane.jpg' # str or PosixPath
......@@ -327,13 +328,13 @@ class AutoShape(nn.Module):
t.append(time_sync())
# Post-process
y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS
y_sup = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det) # NMS
# for i in range(n):
# scale_coords(shape1, y[i][:, :4], shape0[i])
t.append(time_sync())
# return Detections(imgs, y, files, t, self.names, x.shape)
return Detections([img], y, None, t, self.names, x.shape)
return Detections([img], y_sup, None, t, self.names, x.shape)
@torch.no_grad()
def forward(self, imgs, size=640, augment=False, profile=False):
......@@ -411,6 +412,7 @@ class Detections:
self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms)
self.s = shape # inference BCHW shape
@torch.no_grad()
def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
str = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '
......
......@@ -60,14 +60,15 @@ class Detect(nn.Module):
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
y = x[i].sigmoid()
y_out = y.clone()
if self.inplace:
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
y_out[..., 0:2] = (y_out[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
y_out[..., 2:4] = (y_out[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh
y = torch.cat((xy, wh, y[..., 4:]), -1)
z.append(y.view(bs, -1, self.no))
xy = (y_out[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
wh = (y_out[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh
y_out = torch.cat((xy, wh, y_out[..., 4:]), -1)
z.append(y_out.view(bs, -1, self.no))
return x if self.training else (torch.cat(z, 1), x)
......
......@@ -560,7 +560,9 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
continue
# Compute conf
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
# x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
x_hat = x.detach().clone()
x[:, 5:] *= x_hat[:, 4:5] # conf = obj_conf * cls_conf
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
box = xywh2xyxy(x[:, :4])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment