Commit 808bcad3 (unverified), authored by Jiacong Fang, committed by GitHub
Browse files

Add TensorFlow and TFLite export (#1127)



* Add models/tf.py for TensorFlow and TFLite export

* Set auto=False for int8 calibration

* Update requirements.txt for TensorFlow and TFLite export

* Read anchors directly from PyTorch weights

* Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export

* Remove check_anchor_order, check_file, set_logging from import

* Reformat code and optimize imports

* Autodownload model and check cfg

* Update --source path, img-size to 320, single output

* Adjust representative_dataset

* Put representative dataset in tfl_int8 block

* detect.py TF inference

* weights to string

* weights to string

* cleanup tf.py

* Add --dynamic-batch-size

* Add xywh normalization to reduce calibration error

* Update requirements.txt

TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error

* Fix imports

Move C3 from models.experimental to models.common

* Add models/tf.py for TensorFlow and TFLite export

* Set auto=False for int8 calibration

* Update requirements.txt for TensorFlow and TFLite export

* Read anchors directly from PyTorch weights

* Add --tf-nms to append NMS in TensorFlow SavedModel and GraphDef export

* Remove check_anchor_order, check_file, set_logging from import

* Reformat code and optimize imports

* Autodownload model and check cfg

* Update --source path, img-size to 320, single output

* Adjust representative_dataset

* detect.py TF inference

* Put representative dataset in tfl_int8 block

* weights to string

* weights to string

* cleanup tf.py

* Add --dynamic-batch-size

* Add xywh normalization to reduce calibration error

* Update requirements.txt

TensorFlow 2.3.1 -> 2.4.0 to avoid int8 quantization error

* Fix imports

Move C3 from models.experimental to models.common

* implement C3() and SiLU()

* Fix reshape dim to support dynamic batching

* Add epsilon argument in tf_BN, which is different between TF and PT

* Set stride to None if not using PyTorch, and do not warm up without PyTorch

* Add list support in check_img_size()

* Add list input support in detect.py

* sys.path.append('./') to run from yolov5/

* Add int8 quantization support for TensorFlow 2.5

* Add get_coco128.sh

* Remove --no-tfl-detect in models/tf.py (Use tf-android-tfl-detect branch for EdgeTPU)

* Update requirements.txt

* Replace torch.load() with attempt_load()

* Update requirements.txt

* Add --tf-raw-resize to set half_pixel_centers=False

* Add --agnostic-nms for TF class-agnostic NMS

* Cleanup after merge

* Cleanup2 after merge

* Cleanup3 after merge

* Add tf.py docstring with credit and usage

* pb, saved_model and tflite use only one model in detect.py

* Add use cases in docstring of tf.py

* Remove redundant `stride` definition

* Remove keras direct import

* Fix `check_requirements(('tensorflow>=2.4.1',))`

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent f3e3f760
...@@ -12,6 +12,7 @@ import time ...@@ -12,6 +12,7 @@ import time
from pathlib import Path from pathlib import Path
import cv2 import cv2
import numpy as np
import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
...@@ -51,6 +52,7 @@ def run(weights='yolov5s.pt', # model.pt path(s) ...@@ -51,6 +52,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
hide_labels=False, # hide labels hide_labels=False, # hide labels
hide_conf=False, # hide confidences hide_conf=False, # hide confidences
half=False, # use FP16 half-precision inference half=False, # use FP16 half-precision inference
tfl_int8=False, # INT8 quantized TFLite model
): ):
save_img = not nosave and not source.endswith('.txt') # save inference images save_img = not nosave and not source.endswith('.txt') # save inference images
webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
...@@ -68,7 +70,7 @@ def run(weights='yolov5s.pt', # model.pt path(s) ...@@ -68,7 +70,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
# Load model # Load model
w = weights[0] if isinstance(weights, list) else weights w = weights[0] if isinstance(weights, list) else weights
classify, suffix = False, Path(w).suffix.lower() classify, suffix = False, Path(w).suffix.lower()
pt, onnx, tflite, pb, graph_def = (suffix == x for x in ['.pt', '.onnx', '.tflite', '.pb', '']) # backend pt, onnx, tflite, pb, saved_model = (suffix == x for x in ['.pt', '.onnx', '.tflite', '.pb', '']) # backend
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
if pt: if pt:
model = attempt_load(weights, map_location=device) # load FP32 model model = attempt_load(weights, map_location=device) # load FP32 model
...@@ -83,30 +85,49 @@ def run(weights='yolov5s.pt', # model.pt path(s) ...@@ -83,30 +85,49 @@ def run(weights='yolov5s.pt', # model.pt path(s)
check_requirements(('onnx', 'onnxruntime')) check_requirements(('onnx', 'onnxruntime'))
import onnxruntime import onnxruntime
session = onnxruntime.InferenceSession(w, None) session = onnxruntime.InferenceSession(w, None)
else: # TensorFlow models
check_requirements(('tensorflow>=2.4.1',))
import tensorflow as tf
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
def wrap_frozen_graph(gd, inputs, outputs):
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped import
return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
tf.nest.map_structure(x.graph.as_graph_element, outputs))
graph_def = tf.Graph().as_graph_def()
graph_def.ParseFromString(open(w, 'rb').read())
frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
elif saved_model:
model = tf.keras.models.load_model(w)
elif tflite:
interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
imgsz = check_img_size(imgsz, s=stride) # check image size imgsz = check_img_size(imgsz, s=stride) # check image size
# Dataloader # Dataloader
if webcam: if webcam:
view_img = check_imshow() view_img = check_imshow()
cudnn.benchmark = True # set True to speed up constant image size inference cudnn.benchmark = True # set True to speed up constant image size inference
dataset = LoadStreams(source, img_size=imgsz, stride=stride) dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
bs = len(dataset) # batch_size bs = len(dataset) # batch_size
else: else:
dataset = LoadImages(source, img_size=imgsz, stride=stride) dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
bs = 1 # batch_size bs = 1 # batch_size
vid_path, vid_writer = [None] * bs, [None] * bs vid_path, vid_writer = [None] * bs, [None] * bs
# Run inference # Run inference
if pt and device.type != 'cpu': if pt and device.type != 'cpu':
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters()))) # run once
t0 = time.time() t0 = time.time()
for path, img, im0s, vid_cap in dataset: for path, img, im0s, vid_cap in dataset:
if pt: if onnx:
img = img.astype('float32')
else:
img = torch.from_numpy(img).to(device) img = torch.from_numpy(img).to(device)
img = img.half() if half else img.float() # uint8 to fp16/32 img = img.half() if half else img.float() # uint8 to fp16/32
elif onnx: img = img / 255.0 # 0 - 255 to 0.0 - 1.0
img = img.astype('float32')
img /= 255.0 # 0 - 255 to 0.0 - 1.0
if len(img.shape) == 3: if len(img.shape) == 3:
img = img[None] # expand for batch dim img = img[None] # expand for batch dim
...@@ -117,6 +138,27 @@ def run(weights='yolov5s.pt', # model.pt path(s) ...@@ -117,6 +138,27 @@ def run(weights='yolov5s.pt', # model.pt path(s)
pred = model(img, augment=augment, visualize=visualize)[0] pred = model(img, augment=augment, visualize=visualize)[0]
elif onnx: elif onnx:
pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img})) pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
else: # tensorflow model (tflite, pb, saved_model)
imn = img.permute(0, 2, 3, 1).cpu().numpy() # image in numpy
if pb:
pred = frozen_func(x=tf.constant(imn)).numpy()
elif saved_model:
pred = model(imn, training=False).numpy()
elif tflite:
if tfl_int8:
scale, zero_point = input_details[0]['quantization']
imn = (imn / scale + zero_point).astype(np.uint8)
interpreter.set_tensor(input_details[0]['index'], imn)
interpreter.invoke()
pred = interpreter.get_tensor(output_details[0]['index'])
if tfl_int8:
scale, zero_point = output_details[0]['quantization']
pred = (pred.astype(np.float32) - zero_point) * scale
pred[..., 0] *= imgsz[1] # x
pred[..., 1] *= imgsz[0] # y
pred[..., 2] *= imgsz[1] # w
pred[..., 3] *= imgsz[0] # h
pred = torch.tensor(pred)
# NMS # NMS
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
...@@ -202,9 +244,9 @@ def run(weights='yolov5s.pt', # model.pt path(s) ...@@ -202,9 +244,9 @@ def run(weights='yolov5s.pt', # model.pt path(s)
def parse_opt(): def parse_opt():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)') parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pb', help='model.pt path(s)')
parser.add_argument('--source', type=str, default='data/images', help='file/dir/URL/glob, 0 for webcam') parser.add_argument('--source', type=str, default='data/images', help='file/dir/URL/glob, 0 for webcam')
parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)') parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold') parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold') parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image') parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
...@@ -226,7 +268,9 @@ def parse_opt(): ...@@ -226,7 +268,9 @@ def parse_opt():
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels') parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences') parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--tfl-int8', action='store_true', help='INT8 quantized TFLite model')
opt = parser.parse_args() opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
return opt return opt
......
...@@ -85,14 +85,18 @@ class Ensemble(nn.ModuleList): ...@@ -85,14 +85,18 @@ class Ensemble(nn.ModuleList):
return y, None # inference, train output return y, None # inference, train output
def attempt_load(weights, map_location=None, inplace=True): def attempt_load(weights, map_location=None, inplace=True, fuse=True):
from models.yolo import Detect, Model from models.yolo import Detect, Model
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
model = Ensemble() model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]: for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location=map_location) # load ckpt = torch.load(attempt_download(w), map_location=map_location) # load
if fuse:
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
else:
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # without layer fuse
# Compatibility updates # Compatibility updates
for m in model.modules(): for m in model.modules():
......
This diff is collapsed.
...@@ -23,6 +23,7 @@ pandas ...@@ -23,6 +23,7 @@ pandas
# coremltools>=4.1 # coremltools>=4.1
# onnx>=1.9.0 # onnx>=1.9.0
# scikit-learn==0.19.2 # for coreml quantization # scikit-learn==0.19.2 # for coreml quantization
# tensorflow==2.4.1 # for TFLite export
# extras -------------------------------------- # extras --------------------------------------
# Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172 # Cython # for pycocotools https://github.com/cocodataset/cocoapi/issues/172
......
...@@ -155,7 +155,7 @@ class _RepeatSampler(object): ...@@ -155,7 +155,7 @@ class _RepeatSampler(object):
class LoadImages: # for inference class LoadImages: # for inference
def __init__(self, path, img_size=640, stride=32): def __init__(self, path, img_size=640, stride=32, auto=True):
p = str(Path(path).absolute()) # os-agnostic absolute path p = str(Path(path).absolute()) # os-agnostic absolute path
if '*' in p: if '*' in p:
files = sorted(glob.glob(p, recursive=True)) # glob files = sorted(glob.glob(p, recursive=True)) # glob
...@@ -176,6 +176,7 @@ class LoadImages: # for inference ...@@ -176,6 +176,7 @@ class LoadImages: # for inference
self.nf = ni + nv # number of files self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv self.video_flag = [False] * ni + [True] * nv
self.mode = 'image' self.mode = 'image'
self.auto = auto
if any(videos): if any(videos):
self.new_video(videos[0]) # new video self.new_video(videos[0]) # new video
else: else:
...@@ -217,7 +218,7 @@ class LoadImages: # for inference ...@@ -217,7 +218,7 @@ class LoadImages: # for inference
print(f'image {self.count}/{self.nf} {path}: ', end='') print(f'image {self.count}/{self.nf} {path}: ', end='')
# Padded resize # Padded resize
img = letterbox(img0, self.img_size, stride=self.stride)[0] img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]
# Convert # Convert
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
...@@ -276,7 +277,7 @@ class LoadWebcam: # for inference ...@@ -276,7 +277,7 @@ class LoadWebcam: # for inference
class LoadStreams: # multiple IP or RTSP cameras class LoadStreams: # multiple IP or RTSP cameras
def __init__(self, sources='streams.txt', img_size=640, stride=32): def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
self.mode = 'stream' self.mode = 'stream'
self.img_size = img_size self.img_size = img_size
self.stride = stride self.stride = stride
...@@ -290,6 +291,7 @@ class LoadStreams: # multiple IP or RTSP cameras ...@@ -290,6 +291,7 @@ class LoadStreams: # multiple IP or RTSP cameras
n = len(sources) n = len(sources)
self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
self.sources = [clean_str(x) for x in sources] # clean source names for later self.sources = [clean_str(x) for x in sources] # clean source names for later
self.auto = auto
for i, s in enumerate(sources): # index, source for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream # Start thread to read frames from video stream
print(f'{i + 1}/{n}: {s}... ', end='') print(f'{i + 1}/{n}: {s}... ', end='')
...@@ -312,7 +314,7 @@ class LoadStreams: # multiple IP or RTSP cameras ...@@ -312,7 +314,7 @@ class LoadStreams: # multiple IP or RTSP cameras
print('') # newline print('') # newline
# check for common shapes # check for common shapes
s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0) # shapes s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs], 0) # shapes
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
if not self.rect: if not self.rect:
print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.') print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
...@@ -341,7 +343,7 @@ class LoadStreams: # multiple IP or RTSP cameras ...@@ -341,7 +343,7 @@ class LoadStreams: # multiple IP or RTSP cameras
# Letterbox # Letterbox
img0 = self.imgs.copy() img0 = self.imgs.copy()
img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0] img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]
# Stack # Stack
img = np.stack(img, 0) img = np.stack(img, 0)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment