Commit a9a98806 authored by Pavlo Beylin

Switch to torch tensors.

parent 3d9e90a3
@@ -4,6 +4,10 @@
 import cv2
 import time
 import matplotlib
+import models
+from models.common import Detections
+from utils.general import scale_coords
 matplotlib.use('TkAgg')
 import matplotlib.pyplot as plt
@@ -70,14 +74,14 @@ def read_image(path):
 def extract_bounding_box(patch):
-    mask = torch.where(torch.tensor(patch) < 0.1, torch.zeros(patch.shape), torch.ones(patch.shape)).sum(2)
+    mask = torch.where(patch < 0.1, torch.zeros(patch.shape).cuda(), torch.ones(patch.shape).cuda()).sum(2)

-    bb_x1 = mask.sum(0).nonzero()[0]
-    bb_y1 = mask.sum(1).nonzero()[0]
-    bb_x2 = mask.sum(0).nonzero()[-1]
-    bb_y2 = mask.sum(1).nonzero()[-1]
+    bb_x1 = torch.nonzero(mask.sum(0))[0]
+    bb_y1 = torch.nonzero(mask.sum(1))[0]
+    bb_x2 = torch.nonzero(mask.sum(0))[-1]
+    bb_y2 = torch.nonzero(mask.sum(1))[-1]

-    return torch.stack([bb_x1, bb_y1, bb_x2, bb_y2], axis=0).sum(1)
+    return torch.stack([bb_x1, bb_y1, bb_x2, bb_y2]).sum(1)
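For reference, this is what the rewritten `extract_bounding_box` computes: thresholding at 0.1 yields a 2-D occupancy mask, and the first and last nonzero entries of the column and row sums give the patch corners. A minimal CPU sketch (not part of the commit; the committed version keeps everything on CUDA):

```python
import torch

# 2x2 bright square at rows 3..4, cols 5..6 of an otherwise dark 10x10x3 patch
patch = torch.zeros(10, 10, 3)
patch[3:5, 5:7, :] = 1.0

# threshold -> per-pixel mask, summed over the channel dimension
mask = torch.where(patch < 0.1, torch.zeros(patch.shape), torch.ones(patch.shape)).sum(2)
bb_x1 = torch.nonzero(mask.sum(0))[0]   # first non-empty column -> 5
bb_y1 = torch.nonzero(mask.sum(1))[0]   # first non-empty row    -> 3
bb_x2 = torch.nonzero(mask.sum(0))[-1]  # last non-empty column  -> 6
bb_y2 = torch.nonzero(mask.sum(1))[-1]  # last non-empty row     -> 4
print(torch.stack([bb_x1, bb_y1, bb_x2, bb_y2]).sum(1))  # tensor([5, 3, 6, 4])
```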
@@ -85,7 +89,8 @@
 def get_best_prediction(true_box, res, cls_nr):
     best_prediction = None
     for pred in res.pred[0]:
-        if int(pred[-1]) != cls_nr:
+        pred_cls_nr = int(pred[-1])
+        if pred_cls_nr != cls_nr:
             continue
         pred_dist = torch.dist(true_box.cuda(), pred[:4])
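The hunk cuts off before the comparison, but the shape of the loop is clear: among predictions of the requested class, keep the one whose box is closest to the true patch box. A self-contained sketch with assumed YOLOv5 prediction rows `[x1, y1, x2, y2, conf, cls]`; the compare-and-update logic below is an assumption, not shown in the diff:

```python
import torch

def get_best_prediction_sketch(true_box, preds, cls_nr):
    # keep the class-cls_nr prediction whose xyxy box is nearest (L2) to true_box
    best_dist, best_prediction = None, None
    for pred in preds:  # pred: [x1, y1, x2, y2, conf, cls]
        if int(pred[-1]) != cls_nr:
            continue
        pred_dist = torch.dist(true_box, pred[:4])
        if best_dist is None or pred_dist < best_dist:
            best_dist, best_prediction = pred_dist, pred
    return best_prediction

preds = torch.tensor([[ 0.,  0., 10., 10., 0.9, 15.],
                      [40., 40., 90., 90., 0.8, 15.],
                      [ 5.,  5., 12., 12., 0.7,  0.]])
print(get_best_prediction_sketch(torch.tensor([1., 1., 11., 11.]), preds, 15))
# -> tensor([ 0.,  0., 10., 10.,  0.9, 15.])
```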
@@ -116,7 +121,7 @@ if __name__ == "__main__":
         raise IOError("We cannot open webcam")

     patch = read_image(PATH)
-    patch_np = torch.transpose(patch.T, 0, 1).numpy()
+    # patch_np = torch.transpose(patch.T, 0, 1).numpy()

     img_size_x = 640
     img_size_y = 480
@@ -126,36 +131,48 @@
         ctr += 1
         ret, frame = cap.read()

         # resize our captured frame if we need
         frame = cv2.resize(frame, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_AREA)
+        frame = torch.tensor(frame).cuda()

         # cv2.imshow("Web cam input", frame)

         # transform patch (every couple of frames)
         if ctr % 100 == 0:
             trans_patch = patch_transformer(patch.cuda(), torch.ones([1, 14, 5]).cuda(), img_size_x, img_size_y,
                                             do_rotate=True, rand_loc=True)
-            trans_patch_np = torch.transpose(trans_patch[0][0].T, 0, 1).detach().cpu().numpy()
+            # trans_patch_np = torch.transpose(trans_patch[0][0].T, 0, 1).detach().cpu().numpy()
+            trans_patch = torch.transpose(trans_patch[0][0].T, 0, 1)

             # extract bounding box (x1, y1, x2, y2)
-            bounding_box = extract_bounding_box(trans_patch_np)
+            bounding_box = extract_bounding_box(trans_patch)
             print("True BB: {} {} {} {}".format(int(bounding_box[0]), int(bounding_box[1]), int(bounding_box[2]),
                                                 int(bounding_box[3])))

         # apply patch
-        frame = patch_applier(frame, trans_patch_np)
+        frame = patch_applier(frame, trans_patch)

         # detect object on our frame
-        results = model(frame.copy())
+        # results_np = model(frame.detach().cpu().numpy())  # for displaying
+        # frame = torch.transpose(frame, 0, 1).T.unsqueeze(0)
+        results = model.forward_pt(frame)

+        # Post-process
+        # y = models.common.non_max_suppression(y, conf_thres=model.conf, iou_thres=model.iou, classes=model.classes)

         if ctr % 100 == 0:
-            # debug_preds()
+            debug_preds()
             pass

         pred_box = get_best_prediction(bounding_box, results, 15)  # get cats

         if pred_box is not None:
-            print("P:{}".format(pred_box[-2]))
+            # print("P:{}".format(pred_box[-2]))
+            # pred_box[-2].backward()
             pass

         # show us frame with detection
+        # cv2.imshow("img", results_np.render()[0])
         cv2.imshow("img", results.render()[0])

         if cv2.waitKey(25) & 0xFF == ord("q"):
             cv2.destroyAllWindows()
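The commented-out `pred_box[-2].backward()` marks where the attack will eventually plug in: now that the frame, patch application, and detection all stay inside torch, the matched detection's confidence is differentiable with respect to the patch. A hypothetical optimizer step under that assumption (the patch must be a leaf tensor with `requires_grad=True`; none of this is in the commit):

```python
import torch

patch = patch.detach().cuda().requires_grad_(True)  # leaf tensor the attack updates
optimizer = torch.optim.Adam([patch], lr=0.01)

# ... inside the capture loop, after get_best_prediction(...):
if pred_box is not None:
    loss = pred_box[-2]        # detection confidence for the matched box
    optimizer.zero_grad()
    loss.backward()            # flows back through forward_pt and patch_applier
    optimizer.step()           # nudge the patch to suppress the detection
```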
@@ -287,6 +287,54 @@ class AutoShape(nn.Module):
         LOGGER.info('AutoShape already enabled, skipping... ')  # model already converted to model.autoshape()
         return self

+    def forward_pt(self, img, size=640, augment=False, profile=False):
+        # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
+        #   file:       imgs = 'data/images/zidane.jpg'  # str or PosixPath
+        #   URI:             = 'https://ultralytics.com/images/zidane.jpg'
+        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
+        #   PIL:             = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
+        #   numpy:           = np.zeros((640,1280,3))  # HWC
+        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
+        #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
+
+        t = [time_sync()]
+        p = next(self.model.parameters())  # for device and type
+
+        # Pre-process
+        # n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])  # number of images, list of images
+        # shape0, shape1, files = [], [], []  # image and inference shapes, filenames
+        # for i, im in enumerate(imgs):
+        #     f = f'image{i}'  # filename
+        #     if im.shape[0] < 5:  # image in CHW
+        #         im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
+        #     im = im[..., :3] if im.ndim == 3 else np.tile(im[..., None], 3)  # enforce 3ch input
+        #     s = im.shape[:2]  # HWC
+        #     shape0.append(s)  # image shape
+        #     g = (size / max(s))  # gain
+        #     shape1.append([y * g for y in s])
+        #     imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
+        # shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
+        # x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs]  # pad
+        # x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
+        # x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
+        x = img.type_as(p) / 255.  # uint8 to fp16/32
+        x = torch.transpose(x, 0, 1).unsqueeze(-1).T
+        t.append(time_sync())
+
+        with amp.autocast(enabled=p.device.type != 'cpu'):
+            # Inference
+            y = self.model(x, augment, profile)[0]  # forward
+            t.append(time_sync())
+
+            # Post-process
+            y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes, max_det=self.max_det)  # NMS
+            # for i in range(n):
+            #     scale_coords(shape1, y[i][:, :4], shape0[i])
+            t.append(time_sync())
+
+            # return Detections(imgs, y, files, t, self.names, x.shape)
+            return Detections([img], y, None, t, self.names, x.shape)

     @torch.no_grad()
     def forward(self, imgs, size=640, augment=False, profile=False):
         # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
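With the numpy letterboxing bypassed, the only layout work left in `forward_pt` is the line `x = torch.transpose(x, 0, 1).unsqueeze(-1).T`, which turns a single HWC frame into the BCHW batch the backbone expects: `(H, W, C)` becomes `(W, H, C)`, then `(W, H, C, 1)`, and `.T` reverses all dimensions to give `(1, C, H, W)`. A quick shape check (not from the commit), assuming a 480x640 frame:

```python
import torch

x = torch.rand(480, 640, 3)   # HWC frame, already scaled to 0-1
x = torch.transpose(x, 0, 1)  # (640, 480, 3)     swap H and W
x = x.unsqueeze(-1)           # (640, 480, 3, 1)  add a trailing singleton axis
x = x.T                       # (1, 3, 480, 640)  .T reverses all dims -> BCHW
print(x.shape)                # torch.Size([1, 3, 480, 640])
```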
@@ -359,7 +407,8 @@ class Detections:
         self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
         self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
         self.n = len(self.pred)  # number of images (batch size)
-        self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3))  # timestamps (ms)
+        if times is not None:
+            self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3))  # timestamps (ms)
         self.s = shape  # inference BCHW shape

     def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
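Note that the guard skips the timing math when no timestamps are passed, but it also leaves `self.t` unset in that case, so any later read of the attribute (upstream `Detections.print()` formats `self.t`) raises `AttributeError`. A defensive variant that keeps the attribute defined; this is a suggestion, not the commit's code:

```python
# inside Detections.__init__ (suggested variant, not the commit's code)
self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) if times is not None else None  # ms
```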
@@ -16,8 +16,8 @@ class PatchApplier(nn.Module):
         super(PatchApplier, self).__init__()

     def forward(self, img, patch):
-        img = torch.where(torch.tensor(patch < 1e-05), torch.tensor(img)/256, torch.tensor(patch))*256
-        return img.detach().numpy()
+        img = torch.where(patch < 1e-05, img/256, patch) * 256
+        return img

 class MedianPool2d(nn.Module):
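The `PatchApplier` rewrite is the heart of the commit: the old version round-tripped through `torch.tensor(...)` re-wrapping and `.detach().numpy()`, both of which cut the autograd graph, while the tensor-only `torch.where` keeps the patch differentiable. A minimal CPU check (not from the commit) that gradients now reach the patch:

```python
import torch

img = torch.full((4, 4, 3), 128.0)                # fake frame in 0..255
patch = torch.zeros(4, 4, 3, requires_grad=True)  # leaf tensor, dark background
with torch.no_grad():
    patch[1:3, 1:3, :] = 0.5                      # patch content in 0..1

out = torch.where(patch < 1e-05, img / 256, patch) * 256  # same blend as forward()
out.sum().backward()
print(patch.grad[1:3, 1:3, 0])  # all 256: gradient flows where the patch is applied
```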
@@ -68,6 +68,9 @@ def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
 def plot_one_box(box, im, color=(128, 128, 128), txt_color=(255, 255, 255), label=None, line_width=3, use_pil=False):
+    if isinstance(im, torch.Tensor):
+        im = im.detach().cpu().numpy()
+
     # Plots one xyxy box on image im with label
     assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
     lw = line_width or max(int(min(im.size) / 200), 2)  # line width
@@ -85,6 +88,7 @@ def plot_one_box(box, im, color=(128, 128, 128), txt_color=(255, 255, 255), label=None, line_width=3, use_pil=False):
     else:  # use OpenCV
         c1, c2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
         cv2.rectangle(im, c1, c2, color, thickness=lw, lineType=cv2.LINE_AA)
+
         if label:
             tf = max(lw - 1, 1)  # font thickness
             txt_width, txt_height = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=tf)[0]
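With the new tensor check at the top of `plot_one_box`, callers can pass the torch frame directly. One subtlety: for a CPU tensor, `.detach().cpu().numpy()` returns a view of the same memory, so OpenCV draws into the original frame, whereas a CUDA tensor is copied by `.cpu()` and the drawing is lost to the caller. A usage sketch (assumes a CPU tensor and `plot_one_box` from this module in scope):

```python
import torch

frame = torch.randint(0, 255, (480, 640, 3), dtype=torch.uint8)  # CPU HWC frame
plot_one_box([50, 60, 200, 220], frame, label="cat 0.87")
# the box is visible in `frame`: the numpy view shares the CPU tensor's memory
```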