diff --git a/detect.py b/detect.py
index cdac4f2137901cbe5c817ec265beb3a2a7cc1703..a2331e23b43e391fcd2287dee031a96924ab601c 100644
--- a/detect.py
+++ b/detect.py
@@ -12,6 +12,7 @@ import time
 from pathlib import Path
 
 import cv2
+import numpy as np
 import torch
 import torch.backends.cudnn as cudnn
 
@@ -51,6 +52,7 @@ def run(weights='yolov5s.pt',  # model.pt path(s)
         hide_labels=False,  # hide labels
         hide_conf=False,  # hide confidences
         half=False,  # use FP16 half-precision inference
+        tfl_int8=False,  # INT8 quantized TFLite model
         ):
     save_img = not nosave and not source.endswith('.txt')  # save inference images
     webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
@@ -68,7 +70,7 @@ def run(weights='yolov5s.pt',  # model.pt path(s)
     # Load model
     w = weights[0] if isinstance(weights, list) else weights
     classify, suffix = False, Path(w).suffix.lower()
-    pt, onnx, tflite, pb, graph_def = (suffix == x for x in ['.pt', '.onnx', '.tflite', '.pb', ''])  # backend
+    pt, onnx, tflite, pb, saved_model = (suffix == x for x in ['.pt', '.onnx', '.tflite', '.pb', ''])  # backend
     stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
     if pt:
         model = attempt_load(weights, map_location=device)  # load FP32 model
@@ -83,30 +85,49 @@ def run(weights='yolov5s.pt',  # model.pt path(s)
         check_requirements(('onnx', 'onnxruntime'))
         import onnxruntime
         session = onnxruntime.InferenceSession(w, None)
+    else:  # TensorFlow models
+        check_requirements(('tensorflow>=2.4.1',))
+        import tensorflow as tf
+        if pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+            def wrap_frozen_graph(gd, inputs, outputs):
+                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped import
+                return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
+                               tf.nest.map_structure(x.graph.as_graph_element, outputs))
+
+            graph_def = tf.Graph().as_graph_def()
+            graph_def.ParseFromString(open(w, 'rb').read())
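+            # "x:0" and "Identity:0" match the input/output tensor names written by the GraphDef export in models/tf.py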
+            frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
+        elif saved_model:
+            model = tf.keras.models.load_model(w)
+        elif tflite:
+            interpreter = tf.lite.Interpreter(model_path=w)  # load TFLite model
+            interpreter.allocate_tensors()  # allocate
+            input_details = interpreter.get_input_details()  # inputs
+            output_details = interpreter.get_output_details()  # outputs
     imgsz = check_img_size(imgsz, s=stride)  # check image size
 
     # Dataloader
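+    # auto=pt: PyTorch models take the minimal stride-multiple letterbox; exported models need the exact imgsz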
     if webcam:
         view_img = check_imshow()
         cudnn.benchmark = True  # set True to speed up constant image size inference
-        dataset = LoadStreams(source, img_size=imgsz, stride=stride)
+        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
         bs = len(dataset)  # batch_size
     else:
-        dataset = LoadImages(source, img_size=imgsz, stride=stride)
+        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
         bs = 1  # batch_size
     vid_path, vid_writer = [None] * bs, [None] * bs
 
     # Run inference
     if pt and device.type != 'cpu':
-        model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
+        model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters())))  # run once
     t0 = time.time()
     for path, img, im0s, vid_cap in dataset:
-        if pt:
+        if onnx:
+            img = img.astype('float32')
+        else:
             img = torch.from_numpy(img).to(device)
             img = img.half() if half else img.float()  # uint8 to fp16/32
-        elif onnx:
-            img = img.astype('float32')
-        img /= 255.0  # 0 - 255 to 0.0 - 1.0
+        img = img / 255.0  # 0 - 255 to 0.0 - 1.0
         if len(img.shape) == 3:
             img = img[None]  # expand for batch dim
 
@@ -117,6 +138,27 @@ def run(weights='yolov5s.pt',  # model.pt path(s)
             pred = model(img, augment=augment, visualize=visualize)[0]
         elif onnx:
             pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
+        else:  # tensorflow model (tflite, pb, saved_model)
+            imn = img.permute(0, 2, 3, 1).cpu().numpy()  # NCHW torch tensor to NHWC numpy image
+            if pb:
+                pred = frozen_func(x=tf.constant(imn)).numpy()
+            elif saved_model:
+                pred = model(imn, training=False).numpy()
+            elif tflite:
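+                # int8 TFLite I/O uses an affine mapping: real = (quantized - zero_point) * scale,
+                # so inputs are quantized here and the raw uint8 outputs are de-quantized below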
+                if tfl_int8:
+                    scale, zero_point = input_details[0]['quantization']
+                    imn = (imn / scale + zero_point).astype(np.uint8)
+                interpreter.set_tensor(input_details[0]['index'], imn)
+                interpreter.invoke()
+                pred = interpreter.get_tensor(output_details[0]['index'])
+                if tfl_int8:
+                    scale, zero_point = output_details[0]['quantization']
+                    pred = (pred.astype(np.float32) - zero_point) * scale
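+            # TF exports emit xywh normalized to 0-1 (see tf_Detect), so scale back to pixel units before NMS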
+            pred[..., 0] *= imgsz[1]  # x
+            pred[..., 1] *= imgsz[0]  # y
+            pred[..., 2] *= imgsz[1]  # w
+            pred[..., 3] *= imgsz[0]  # h
+            pred = torch.tensor(pred)
 
         # NMS
         pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
@@ -202,9 +244,9 @@ def run(weights='yolov5s.pt',  # model.pt path(s)
 
 def parse_opt():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pt', help='model.pt path(s)')
+    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.pb', help='model path(s)')
     parser.add_argument('--source', type=str, default='data/images', help='file/dir/URL/glob, 0 for webcam')
-    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
+    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
     parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
     parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
     parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
@@ -226,7 +268,9 @@ def parse_opt():
     parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
     parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
     parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
+    parser.add_argument('--tfl-int8', action='store_true', help='INT8 quantized TFLite model')
     opt = parser.parse_args()
+    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand single size to [h, w]
     return opt
 
 
diff --git a/models/experimental.py b/models/experimental.py
index 7dfaf9611bec8349d4e38e168ef9af8c00237cd4..e25a4e1779fa7847f79e8570e7cf0a23e845a9f0 100644
--- a/models/experimental.py
+++ b/models/experimental.py
@@ -85,14 +85,18 @@ class Ensemble(nn.ModuleList):
         return y, None  # inference, train output
 
 
-def attempt_load(weights, map_location=None, inplace=True):
+def attempt_load(weights, map_location=None, inplace=True, fuse=True):
     from models.yolo import Detect, Model
 
     # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
         ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
-        model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model
+        if fuse:
+            model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model
+        else:
+            model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval())  # without layer fuse
+
 
     # Compatibility updates
     for m in model.modules():
diff --git a/models/tf.py b/models/tf.py
new file mode 100644
index 0000000000000000000000000000000000000000..40e7d20a9d846da4f311e33e8f099ea25a7e0ac4
--- /dev/null
+++ b/models/tf.py
@@ -0,0 +1,558 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+TensorFlow/Keras and TFLite versions of YOLOv5
+Authored by https://github.com/zldrobit in PR https://github.com/ultralytics/yolov5/pull/1127
+
+Usage:
+    $ python models/tf.py --weights yolov5s.pt --cfg yolov5s.yaml
+
+Export int8 TFLite models:
+    $ python models/tf.py --weights yolov5s.pt --cfg models/yolov5s.yaml --tfl-int8 \
+        --source path/to/images/ --ncalib 100
+
+Detection:
+    $ python detect.py --weights yolov5s.pb          --img 320
+    $ python detect.py --weights yolov5s_saved_model --img 320
+    $ python detect.py --weights yolov5s-fp16.tflite --img 320
+    $ python detect.py --weights yolov5s-int8.tflite --img 320 --tfl-int8
+
+For TensorFlow.js:
+    $ python models/tf.py --weights yolov5s.pt --cfg models/yolov5s.yaml --img 320 --tf-nms --agnostic-nms
+    $ pip install tensorflowjs
+    $ tensorflowjs_converter \
+          --input_format=tf_frozen_model \
+          --output_node_names='Identity,Identity_1,Identity_2,Identity_3' \
+          yolov5s.pb \
+          web_model
+    $ # Edit web_model/model.json to sort Identity* in ascending order
+    $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
+    $ npm install
+    $ ln -s ../../yolov5/web_model public/web_model
+    $ npm start
+"""
+
+import argparse
+import logging
+import os
+import sys
+import traceback
+from copy import deepcopy
+from pathlib import Path
+
+sys.path.append('./')  # to run '$ python *.py' files in subdirectories
+
+import numpy as np
+import tensorflow as tf
+import torch
+import torch.nn as nn
+import yaml
+from tensorflow import keras
+from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
+
+from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3
+from models.experimental import MixConv2d, CrossConv, attempt_load
+from models.yolo import Detect
+from utils.datasets import LoadImages
+from utils.general import make_divisible, check_file, check_dataset
+
+logger = logging.getLogger(__name__)
+
+
+class tf_BN(keras.layers.Layer):
+    # TensorFlow BatchNormalization wrapper
+    def __init__(self, w=None):
+        super(tf_BN, self).__init__()
+        self.bn = keras.layers.BatchNormalization(
+            beta_initializer=keras.initializers.Constant(w.bias.numpy()),
+            gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
+            moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),
+            moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),
+            epsilon=w.eps)
+
+    def call(self, inputs):
+        return self.bn(inputs)
+
+
+class tf_Pad(keras.layers.Layer):
+    def __init__(self, pad):
+        super(tf_Pad, self).__init__()
+        self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
+
+    def call(self, inputs):
+        return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
+
+
+class tf_Conv(keras.layers.Layer):
+    # Standard convolution
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
+        # ch_in, ch_out, kernel, stride, padding, groups, activation, weights
+        super(tf_Conv, self).__init__()
+        assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
+        assert isinstance(k, int), "Convolution with multiple kernels is not allowed."
+        # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
+        # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
+
+        conv = keras.layers.Conv2D(
+            c2, k, s, 'SAME' if s == 1 else 'VALID', use_bias=False,
+            kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()))
+        self.conv = conv if s == 1 else keras.Sequential([tf_Pad(autopad(k, p)), conv])
+        self.bn = tf_BN(w.bn) if hasattr(w, 'bn') else tf.identity
+
+        # YOLOv5 activations
+        if isinstance(w.act, nn.LeakyReLU):
+            self.act = (lambda x: keras.activations.relu(x, alpha=0.1)) if act else tf.identity
+        elif isinstance(w.act, nn.Hardswish):
+            self.act = (lambda x: x * tf.nn.relu6(x + 3) * 0.166666667) if act else tf.identity
+        elif isinstance(w.act, nn.SiLU):
+            self.act = (lambda x: keras.activations.swish(x)) if act else tf.identity
+
+    def call(self, inputs):
+        return self.act(self.bn(self.conv(inputs)))
+
+
+class tf_Focus(keras.layers.Layer):
+    # Focus wh information into c-space
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
+        # ch_in, ch_out, kernel, stride, padding, groups
+        super(tf_Focus, self).__init__()
+        self.conv = tf_Conv(c1 * 4, c2, k, s, p, g, act, w.conv)
+
+    def call(self, inputs):  # x(b,w,h,c) -> y(b,w/2,h/2,4c)
+        # inputs = inputs / 255.  # normalize 0-255 to 0-1
+        return self.conv(tf.concat([inputs[:, ::2, ::2, :],
+                                    inputs[:, 1::2, ::2, :],
+                                    inputs[:, ::2, 1::2, :],
+                                    inputs[:, 1::2, 1::2, :]], 3))
+
+
+class tf_Bottleneck(keras.layers.Layer):
+    # Standard bottleneck
+    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None):  # ch_in, ch_out, shortcut, groups, expansion
+        super(tf_Bottleneck, self).__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
+        self.cv2 = tf_Conv(c_, c2, 3, 1, g=g, w=w.cv2)
+        self.add = shortcut and c1 == c2
+
+    def call(self, inputs):
+        return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))
+
+
+class tf_Conv2d(keras.layers.Layer):
+    # Substitution for PyTorch nn.Conv2d
+    def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
+        super(tf_Conv2d, self).__init__()
+        assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
+        self.conv = keras.layers.Conv2D(
+            c2, k, s, 'VALID', use_bias=bias,
+            kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()),
+            bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None, )
+
+    def call(self, inputs):
+        return self.conv(inputs)
+
+
+class tf_BottleneckCSP(keras.layers.Layer):
+    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
+        # ch_in, ch_out, number, shortcut, groups, expansion
+        super(tf_BottleneckCSP, self).__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
+        self.cv2 = tf_Conv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
+        self.cv3 = tf_Conv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
+        self.cv4 = tf_Conv(2 * c_, c2, 1, 1, w=w.cv4)
+        self.bn = tf_BN(w.bn)
+        self.act = lambda x: keras.activations.relu(x, alpha=0.1)
+        self.m = keras.Sequential([tf_Bottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
+
+    def call(self, inputs):
+        y1 = self.cv3(self.m(self.cv1(inputs)))
+        y2 = self.cv2(inputs)
+        return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))
+
+
+class tf_C3(keras.layers.Layer):
+    # CSP Bottleneck with 3 convolutions
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
+        # ch_in, ch_out, number, shortcut, groups, expansion
+        super(tf_C3, self).__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
+        self.cv2 = tf_Conv(c1, c_, 1, 1, w=w.cv2)
+        self.cv3 = tf_Conv(2 * c_, c2, 1, 1, w=w.cv3)
+        self.m = keras.Sequential([tf_Bottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
+
+    def call(self, inputs):
+        return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
+
+
+class tf_SPP(keras.layers.Layer):
+    # Spatial pyramid pooling layer used in YOLOv3-SPP
+    def __init__(self, c1, c2, k=(5, 9, 13), w=None):
+        super(tf_SPP, self).__init__()
+        c_ = c1 // 2  # hidden channels
+        self.cv1 = tf_Conv(c1, c_, 1, 1, w=w.cv1)
+        self.cv2 = tf_Conv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
+        self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]
+
+    def call(self, inputs):
+        x = self.cv1(inputs)
+        return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
+
+
+class tf_Detect(keras.layers.Layer):
+    def __init__(self, nc=80, anchors=(), ch=(), w=None):  # detection layer
+        super(tf_Detect, self).__init__()
+        self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
+        self.nc = nc  # number of classes
+        self.no = nc + 5  # number of outputs per anchor
+        self.nl = len(anchors)  # number of detection layers
+        self.na = len(anchors[0]) // 2  # number of anchors
+        self.grid = [tf.zeros(1)] * self.nl  # init grid
+        self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
+        self.anchor_grid = tf.reshape(tf.convert_to_tensor(w.anchor_grid.numpy(), dtype=tf.float32),
+                                      [self.nl, 1, -1, 1, 2])
+        self.m = [tf_Conv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
+        self.export = False  # onnx export
+        self.training = True  # set to False after building model
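+        # grids are built once here because the exported TF/TFLite graph runs at a fixed opt.img_size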
+        for i in range(self.nl):
+            ny, nx = opt.img_size[0] // self.stride[i], opt.img_size[1] // self.stride[i]
+            self.grid[i] = self._make_grid(nx, ny)
+
+    def call(self, inputs):
+        # x = x.copy()  # for profiling
+        z = []  # inference output
+        self.training |= self.export
+        x = []
+        for i in range(self.nl):
+            x.append(self.m[i](inputs[i]))
+            # x(bs,20,20,255) to x(bs,3,20,20,85)
+            ny, nx = opt.img_size[0] // self.stride[i], opt.img_size[1] // self.stride[i]
+            x[i] = tf.transpose(tf.reshape(x[i], [-1, ny * nx, self.na, self.no]), [0, 2, 1, 3])
+
+            if not self.training:  # inference
+                y = tf.sigmoid(x[i])
+                xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
+                # Normalize xywh to 0-1 to reduce calibration error
+                xy /= tf.constant([[opt.img_size[1], opt.img_size[0]]], dtype=tf.float32)
+                wh /= tf.constant([[opt.img_size[1], opt.img_size[0]]], dtype=tf.float32)
+                y = tf.concat([xy, wh, y[..., 4:]], -1)
+                z.append(tf.reshape(y, [-1, 3 * ny * nx, self.no]))
+
+        return x if self.training else (tf.concat(z, 1), x)
+
+    @staticmethod
+    def _make_grid(nx=20, ny=20):
+        # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
+        # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
+        xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
+        return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)
+
+
+class tf_Upsample(keras.layers.Layer):
+    def __init__(self, size, scale_factor, mode, w=None):
+        super(tf_Upsample, self).__init__()
+        assert scale_factor == 2, "scale_factor must be 2"
+        # self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
+        if opt.tf_raw_resize:
+            # with default arguments: align_corners=False, half_pixel_centers=False
+            self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
+                                                                       size=(x.shape[1] * 2, x.shape[2] * 2))
+        else:
+            self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * 2, x.shape[2] * 2), method=mode)
+
+    def call(self, inputs):
+        return self.upsample(inputs)
+
+
+class tf_Concat(keras.layers.Layer):
+    def __init__(self, dimension=1, w=None):
+        super(tf_Concat, self).__init__()
+        assert dimension == 1, "only channel concat (NCHW dim 1 -> NHWC dim 3) is supported"
+        self.d = 3
+
+    def call(self, inputs):
+        return tf.concat(inputs, self.d)
+
+
+def parse_model(d, ch, model):  # model_dict, input_channels(3)
+    logger.info('\n%3s%18s%3s%10s  %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
+    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
+    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
+    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)
+
+    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
+    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
+        m_str = m
+        m = eval(m) if isinstance(m, str) else m  # eval strings
+        for j, a in enumerate(args):
+            try:
+                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
+            except:
+                pass
+
+        n = max(round(n * gd), 1) if n > 1 else n  # depth gain
+        if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
+            c1, c2 = ch[f], args[0]
+            c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
+
+            args = [c1, c2, *args[1:]]
+            if m in [BottleneckCSP, C3]:
+                args.insert(2, n)
+                n = 1
+        elif m is nn.BatchNorm2d:
+            args = [ch[f]]
+        elif m is Concat:
+            c2 = sum([ch[-1 if x == -1 else x + 1] for x in f])
+        elif m is Detect:
+            args.append([ch[x + 1] for x in f])
+            if isinstance(args[1], int):  # number of anchors
+                args[1] = [list(range(args[1] * 2))] * len(f)
+        else:
+            c2 = ch[f]
+
+        tf_m = eval('tf_' + m_str.replace('nn.', ''))
+        m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \
+            else tf_m(*args, w=model.model[i])  # module
+
+        torch_m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args)  # module
+        t = str(m)[8:-2].replace('__main__.', '')  # module type
+        np = sum([x.numel() for x in torch_m_.parameters()])  # number params
+        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
+        logger.info('%3s%18s%3s%10.0f  %-40s%-30s' % (i, f, n, np, t, args))  # print
+        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
+        layers.append(m_)
+        ch.append(c2)
+    return keras.Sequential(layers), sorted(save)
+
+
+class tf_Model():
+    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, model=None):  # cfg, input channels, number of classes, PyTorch model
+        super(tf_Model, self).__init__()
+        if isinstance(cfg, dict):
+            self.yaml = cfg  # model dict
+        else:  # is *.yaml
+            import yaml  # for torch hub
+            self.yaml_file = Path(cfg).name
+            with open(cfg) as f:
+                self.yaml = yaml.load(f, Loader=yaml.FullLoader)  # model dict
+
+        # Define model
+        if nc and nc != self.yaml['nc']:
+            print('Overriding %s nc=%g with nc=%g' % (cfg, self.yaml['nc'], nc))
+            self.yaml['nc'] = nc  # override yaml value
+        self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model)  # model, savelist, ch_out
+
+    def predict(self, inputs, profile=False):
+        y = []  # outputs
+        x = inputs
+        for i, m in enumerate(self.model.layers):
+            if m.f != -1:  # if not from previous layer
+                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
+
+            x = m(x)  # run
+            y.append(x if m.i in self.savelist else None)  # save output
+
+        # Add TensorFlow NMS
+        if opt.tf_nms:
+            boxes = xywh2xyxy(x[0][..., :4])
+            probs = x[0][:, :, 4:5]
+            classes = x[0][:, :, 5:]
+            scores = probs * classes
+            if opt.agnostic_nms:
+                nms = agnostic_nms_layer()((boxes, classes, scores))
+                return nms, x[1]
+            else:
+                boxes = tf.expand_dims(boxes, 2)
+                nms = tf.image.combined_non_max_suppression(
+                    boxes, scores, opt.topk_per_class, opt.topk_all, opt.iou_thres, opt.score_thres, clip_boxes=False)
+                return nms, x[1]
+
+        return x[0]  # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]
+        # x = x[0][0]  # [x(1,6300,85), ...] to x(6300,85)
+        # xywh = x[..., :4]  # x(6300,4) boxes
+        # conf = x[..., 4:5]  # x(6300,1) confidences
+        # cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1))  # x(6300,1)  classes
+        # return tf.concat([conf, cls, xywh], 1)
+
+
+class agnostic_nms_layer(keras.layers.Layer):
+    # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
+    def call(self, input):
+        return tf.map_fn(agnostic_nms, input,
+                         fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
+                         name='agnostic_nms')
+
+
+def agnostic_nms(x):
+    boxes, classes, scores = x
+    class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
+    scores_inp = tf.reduce_max(scores, -1)
+    selected_inds = tf.image.non_max_suppression(
+        boxes, scores_inp, max_output_size=opt.topk_all, iou_threshold=opt.iou_thres, score_threshold=opt.score_thres)
+    selected_boxes = tf.gather(boxes, selected_inds)
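+    # pad results to a fixed topk_all length so tf.map_fn returns uniformly-shaped outputs for every image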
+    padded_boxes = tf.pad(selected_boxes,
+                          paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
+                          mode="CONSTANT", constant_values=0.0)
+    selected_scores = tf.gather(scores_inp, selected_inds)
+    padded_scores = tf.pad(selected_scores,
+                           paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]]],
+                           mode="CONSTANT", constant_values=-1.0)
+    selected_classes = tf.gather(class_inds, selected_inds)
+    padded_classes = tf.pad(selected_classes,
+                            paddings=[[0, opt.topk_all - tf.shape(selected_boxes)[0]]],
+                            mode="CONSTANT", constant_values=-1.0)
+    valid_detections = tf.shape(selected_inds)[0]
+    return padded_boxes, padded_scores, padded_classes, valid_detections
+
+
+def xywh2xyxy(xywh):
+    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+    x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
+    return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)
+
+
+def representative_dataset_gen():
+    # Representative dataset for use with converter.representative_dataset
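+    # yields float32 NHWC images scaled to [0, 1]; the converter runs them to calibrate int8 scales and zero-points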
+    n = 0
+    for path, img, im0s, vid_cap in dataset:
+        # Get sample input data as a numpy array in a method of your choosing.
+        n += 1
+        input = np.transpose(img, [1, 2, 0])
+        input = np.expand_dims(input, axis=0).astype(np.float32)
+        input /= 255.0
+        yield [input]
+        if n >= opt.ncalib:
+            break
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='cfg path')
+    parser.add_argument('--weights', type=str, default='yolov5s.pt', help='weights path')
+    parser.add_argument('--img-size', nargs='+', type=int, default=[320, 320], help='image size')  # height, width
+    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
+    parser.add_argument('--dynamic-batch-size', action='store_true', help='dynamic batch size')
+    parser.add_argument('--source', type=str, default='../data/coco128.yaml', help='dir of images or data.yaml file')
+    parser.add_argument('--ncalib', type=int, default=100, help='number of calibration images')
+    parser.add_argument('--tfl-int8', action='store_true', dest='tfl_int8', help='export TFLite int8 model')
+    parser.add_argument('--tf-nms', action='store_true', dest='tf_nms', help='TF NMS (without TFLite export)')
+    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
+    parser.add_argument('--tf-raw-resize', action='store_true', dest='tf_raw_resize',
+                        help='use tf.raw_ops.ResizeNearestNeighbor for resize')
+    parser.add_argument('--topk-per-class', type=int, default=100, help='topk per class to keep in NMS')
+    parser.add_argument('--topk-all', type=int, default=100, help='topk for all classes to keep in NMS')
+    parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
+    parser.add_argument('--score-thres', type=float, default=0.4, help='score threshold for NMS')
+    opt = parser.parse_args()
+    opt.cfg = check_file(opt.cfg)  # check file
+    opt.img_size *= 2 if len(opt.img_size) == 1 else 1  # expand single size to [h, w]
+    print(opt)
+
+    # Input
+    img = torch.zeros((opt.batch_size, 3, *opt.img_size))  # image size(1,3,320,192) iDetection
+
+    # Load PyTorch model
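+    # fuse=False keeps Conv and BatchNorm layers separate so tf_Conv/tf_BN can copy the original weights and BN stats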
+    model = attempt_load(opt.weights, map_location=torch.device('cpu'), inplace=True, fuse=False)
+    model.model[-1].export = False  # keep Detect() layer export=False so the dry run returns inference output
+    y = model(img)  # dry run
+    nc = y[0].shape[-1] - 5
+
+    # TensorFlow saved_model export
+    try:
+        print('\nStarting TensorFlow saved_model export with TensorFlow %s...' % tf.__version__)
+        tf_model = tf_Model(opt.cfg, model=model, nc=nc)
+        img = tf.zeros((opt.batch_size, *opt.img_size, 3))  # NHWC Input for TensorFlow
+
+        m = tf_model.model.layers[-1]
+        assert isinstance(m, tf_Detect), "the last layer must be Detect"
+        m.training = False
+        y = tf_model.predict(img)
+
+        inputs = keras.Input(shape=(*opt.img_size, 3), batch_size=None if opt.dynamic_batch_size else opt.batch_size)
+        keras_model = keras.Model(inputs=inputs, outputs=tf_model.predict(inputs))
+        keras_model.summary()
+        path = opt.weights.replace('.pt', '_saved_model')  # filename
+        keras_model.save(path, save_format='tf')
+        print('TensorFlow saved_model export success, saved as %s' % path)
+    except Exception as e:
+        print('TensorFlow saved_model export failure: %s' % e)
+        traceback.print_exc(file=sys.stdout)
+
+    # TensorFlow GraphDef export
+    try:
+        print('\nStarting TensorFlow GraphDef export with TensorFlow %s...' % tf.__version__)
+
+        # https://github.com/leimao/Frozen_Graph_TensorFlow
+        full_model = tf.function(lambda x: keras_model(x))
+        full_model = full_model.get_concrete_function(
+            tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
+
+        frozen_func = convert_variables_to_constants_v2(full_model)
+        frozen_func.graph.as_graph_def()
+        f = opt.weights.replace('.pt', '.pb')  # filename
+        tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
+                          logdir=os.path.dirname(f),
+                          name=os.path.basename(f),
+                          as_text=False)
+
+        print('TensorFlow GraphDef export success, saved as %s' % f)
+    except Exception as e:
+        print('TensorFlow GraphDef export failure: %s' % e)
+        traceback.print_exc(file=sys.stdout)
+
+    # TFLite model export
+    if not opt.tf_nms:
+        try:
+            print('\nStarting TFLite export with TensorFlow %s...' % tf.__version__)
+
+            # fp32 TFLite model export ---------------------------------------------------------------------------------
+            # converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+            # converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+            # converter.allow_custom_ops = False
+            # converter.experimental_new_converter = True
+            # tflite_model = converter.convert()
+            # f = opt.weights.replace('.pt', '.tflite')  # filename
+            # open(f, "wb").write(tflite_model)
+
+            # fp16 TFLite model export ---------------------------------------------------------------------------------
+            converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+            converter.optimizations = [tf.lite.Optimize.DEFAULT]
+            # converter.representative_dataset = representative_dataset_gen
+            # converter.target_spec.supported_types = [tf.float16]
+            converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+            converter.allow_custom_ops = False
+            converter.experimental_new_converter = True
+            tflite_model = converter.convert()
+            f = opt.weights.replace('.pt', '-fp16.tflite')  # filename
+            open(f, "wb").write(tflite_model)
+            print('\nTFLite export success, saved as %s' % f)
+
+            # int8 TFLite model export ---------------------------------------------------------------------------------
+            if opt.tfl_int8:
+                # Representative Dataset
+                if opt.source.endswith('.yaml'):
+                    with open(check_file(opt.source)) as f:
+                        data = yaml.load(f, Loader=yaml.FullLoader)  # data dict
+                        check_dataset(data)  # check
+                    opt.source = data['train']
+                dataset = LoadImages(opt.source, img_size=opt.img_size, auto=False)
+                converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+                converter.optimizations = [tf.lite.Optimize.DEFAULT]
+                converter.representative_dataset = representative_dataset_gen
+                converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+                converter.inference_input_type = tf.uint8  # or tf.int8
+                converter.inference_output_type = tf.uint8  # or tf.int8
+                converter.allow_custom_ops = False
+                converter.experimental_new_converter = True
+                converter.experimental_new_quantizer = False
+                tflite_model = converter.convert()
+                f = opt.weights.replace('.pt', '-int8.tflite')  # filename
+                open(f, "wb").write(tflite_model)
+                print('\nTFLite (int8) export success, saved as %s' % f)
+
+        except Exception as e:
+            print('\nTFLite export failure: %s' % e)
+            traceback.print_exc(file=sys.stdout)
diff --git a/requirements.txt b/requirements.txt
index f1629eafc65a2432f1c383a9c9bae5c914ae2a69..f6361d591f1be62c94e8fad2292974472742db8b 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,6 +23,7 @@ pandas
 # coremltools>=4.1
 # onnx>=1.9.0
 # scikit-learn==0.19.2  # for coreml quantization
+# tensorflow==2.4.1  # for TFLite export
 
 # extras --------------------------------------
 # Cython  # for pycocotools https://github.com/cocodataset/cocoapi/issues/172
diff --git a/utils/datasets.py b/utils/datasets.py
index 7d831cd632307f1a7f16deb447b5afe3f9475e0f..52b02899432589e63c514f98ad6f6a753fea29d6 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -155,7 +155,7 @@ class _RepeatSampler(object):
 
 
 class LoadImages:  # for inference
-    def __init__(self, path, img_size=640, stride=32):
+    def __init__(self, path, img_size=640, stride=32, auto=True):
         p = str(Path(path).absolute())  # os-agnostic absolute path
         if '*' in p:
             files = sorted(glob.glob(p, recursive=True))  # glob
@@ -176,6 +176,7 @@ class LoadImages:  # for inference
         self.nf = ni + nv  # number of files
         self.video_flag = [False] * ni + [True] * nv
         self.mode = 'image'
+        self.auto = auto
         if any(videos):
             self.new_video(videos[0])  # new video
         else:
@@ -217,7 +218,7 @@ class LoadImages:  # for inference
             print(f'image {self.count}/{self.nf} {path}: ', end='')
 
         # Padded resize
-        img = letterbox(img0, self.img_size, stride=self.stride)[0]
+        img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]
 
         # Convert
         img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
@@ -276,7 +277,7 @@ class LoadWebcam:  # for inference
 
 
 class LoadStreams:  # multiple IP or RTSP cameras
-    def __init__(self, sources='streams.txt', img_size=640, stride=32):
+    def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
         self.mode = 'stream'
         self.img_size = img_size
         self.stride = stride
@@ -290,6 +291,7 @@ class LoadStreams:  # multiple IP or RTSP cameras
         n = len(sources)
         self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
         self.sources = [clean_str(x) for x in sources]  # clean source names for later
+        self.auto = auto
         for i, s in enumerate(sources):  # index, source
             # Start thread to read frames from video stream
             print(f'{i + 1}/{n}: {s}... ', end='')
@@ -312,7 +314,7 @@ class LoadStreams:  # multiple IP or RTSP cameras
         print('')  # newline
 
         # check for common shapes
-        s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0)  # shapes
+        s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs], 0)  # shapes
         self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
         if not self.rect:
             print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
@@ -341,7 +343,7 @@ class LoadStreams:  # multiple IP or RTSP cameras
 
         # Letterbox
         img0 = self.imgs.copy()
-        img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]
+        img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]
 
         # Stack
         img = np.stack(img, 0)