"""Export a YOLOv5 *.pt model to TorchScript, ONNX, CoreML formats

Usage:
    $ python path/to/export.py --weights yolov5s.pt --img 640 --batch 1
"""

import argparse
import sys
import time
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

FILE = Path(__file__).absolute()
sys.path.append(FILE.parents[0].as_posix())  # add yolov5/ to path

from models.common import Conv
from models.yolo import Detect
from models.experimental import attempt_load
from utils.activations import Hardswish, SiLU
from utils.general import colorstr, check_img_size, check_requirements, file_size, set_logging
from utils.torch_utils import select_device


def export_torchscript(model, img, file, optimize):
    # TorchScript model export
    prefix = colorstr('TorchScript:')
    try:
        print(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = file.with_suffix('.torchscript.pt')
        ts = torch.jit.trace(model, img, strict=False)
        (optimize_for_mobile(ts) if optimize else ts).save(f)
        print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        return ts
    except Exception as e:
        print(f'{prefix} export failure: {e}')
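
# Example usage of the exported TorchScript model (a minimal sketch, assuming
# the export above produced e.g. yolov5s.torchscript.pt):
#   ts = torch.jit.load('yolov5s.torchscript.pt')
#   pred = ts(torch.zeros(1, 3, 640, 640))  # dummy FP32 input, shape (batch, 3, height, width)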


def export_onnx(model, img, file, opset, train, dynamic, simplify):
    # ONNX model export
    prefix = colorstr('ONNX:')
    try:
        check_requirements(('onnx', 'onnx-simplifier'))
        import onnx

        print(f'\n{prefix} starting export with onnx {onnx.__version__}...')
        f = file.with_suffix('.onnx')
        torch.onnx.export(model, img, f, verbose=False, opset_version=opset,
                          training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
                          do_constant_folding=not train,
                          input_names=['images'],
                          output_names=['output'],
                          dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
                                        'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
                                        } if dynamic else None)

        # Checks
        model_onnx = onnx.load(f)  # load onnx model
        onnx.checker.check_model(model_onnx)  # check onnx model
        # print(onnx.helper.printable_graph(model_onnx.graph))  # print

        # Simplify
        if simplify:
            try:
                import onnxsim

                print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
                model_onnx, check = onnxsim.simplify(
                    model_onnx,
                    dynamic_input_shape=dynamic,
                    input_shapes={'images': list(img.shape)} if dynamic else None)
                assert check, 'simplified ONNX model could not be validated'
                onnx.save(model_onnx, f)
            except Exception as e:
                print(f'{prefix} simplifier failure: {e}')
        print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
        print(f"{prefix} run --dynamic ONNX model inference with: 'python detect.py --weights {f}'")
    except Exception as e:
        print(f'{prefix} export failure: {e}')
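
# Example inference with the exported ONNX model (a minimal sketch, assuming
# onnxruntime and numpy are installed; 'images' matches input_names above):
#   import numpy as np
#   import onnxruntime
#   session = onnxruntime.InferenceSession('yolov5s.onnx')
#   pred = session.run(None, {'images': np.zeros((1, 3, 640, 640), dtype=np.float32)})[0]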


def export_coreml(model, img, file):
    # CoreML model export
    prefix = colorstr('CoreML:')
    try:
        import coremltools as ct

        print(f'\n{prefix} starting export with coremltools {ct.__version__}...')
        f = file.with_suffix('.mlmodel')
        model.train()  # CoreML export requires train() mode (avoids Detect() grid construction during tracing)
        ts = torch.jit.trace(model, img, strict=False)  # TorchScript model
        model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
        model.save(f)
        print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        print(f'{prefix} export failure: {e}')
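
# Example loading of the exported CoreML model (a minimal sketch; running
# predictions with coremltools requires macOS):
#   import coremltools as ct
#   mlmodel = ct.models.MLModel('yolov5s.mlmodel')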


def run(weights='./yolov5s.pt',  # weights path
        img_size=(640, 640),  # image (height, width)
        batch_size=1,  # batch size
        device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        include=('torchscript', 'onnx', 'coreml'),  # include formats
        half=False,  # FP16 half-precision export
        inplace=False,  # set YOLOv5 Detect() inplace=True
        train=False,  # model.train() mode
        optimize=False,  # TorchScript: optimize for mobile
        dynamic=False,  # ONNX: dynamic axes
        simplify=False,  # ONNX: simplify model
        opset=12,  # ONNX: opset version
        ):
    t = time.time()
    include = [x.lower() for x in include]
    img_size *= 2 if len(img_size) == 1 else 1  # expand
    file = Path(weights)

    # Load PyTorch model
    device = select_device(device)
    assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
    model = attempt_load(weights, map_location=device)  # load FP32 model
    names = model.names  # class names

    # Input
    gs = int(max(model.stride))  # grid size (max stride)
    img_size = [check_img_size(x, gs) for x in img_size]  # verify img_size values are multiples of gs
    img = torch.zeros(batch_size, 3, *img_size).to(device)  # image size(1,3,320,192) iDetection

    # Update model
    if half:
        img, model = img.half(), model.half()  # to FP16
    model.train() if train else model.eval()  # training mode = no Detect() layer grid construction
    for k, m in model.named_modules():
        if isinstance(m, Conv):  # assign export-friendly activations
            if isinstance(m.act, nn.Hardswish):
                m.act = Hardswish()
            elif isinstance(m.act, nn.SiLU):
                m.act = SiLU()
        elif isinstance(m, Detect):
            m.inplace = inplace
            m.onnx_dynamic = dynamic
            # m.forward = m.forward_export  # assign forward (optional)

    for _ in range(2):
        y = model(img)  # dry runs
    print(f"\n{colorstr('PyTorch:')} starting from {weights} ({file_size(weights):.1f} MB)")

    # Exports
    if 'torchscript' in include:
        export_torchscript(model, img, file, optimize)
    if 'onnx' in include:
        export_onnx(model, img, file, opset, train, dynamic, simplify)
    if 'coreml' in include:
        export_coreml(model, img, file)

    # Finish
    print(f'\nExport complete ({time.time() - t:.2f}s)'
          f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
          f'\nVisualize with https://netron.app')
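
# Example of calling run() directly from Python rather than the CLI
# (a minimal sketch, assuming yolov5s.pt is present in the working directory):
#   run(weights='yolov5s.pt', include=('onnx',), simplify=True)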


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image (height, width)')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--include', nargs='+', default=['torchscript', 'onnx', 'coreml'], help='include formats')
    parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
    parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
    parser.add_argument('--train', action='store_true', help='model.train() mode')
    parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
    parser.add_argument('--dynamic', action='store_true', help='ONNX: dynamic axes')
    parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
    parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
    opt = parser.parse_args()
    return opt


def main(opt):
    set_logging()
    print(colorstr('export: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
    run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)