plots.py 17.5 KB
Newer Older
Glenn Jocher's avatar
Glenn Jocher committed
1
2
3
4
5
6
# Plotting utils

from copy import copy
from pathlib import Path

import cv2
7
import math
Glenn Jocher's avatar
Glenn Jocher committed
8
9
10
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
11
import pandas as pd
12
import seaborn as sn
Glenn Jocher's avatar
Glenn Jocher committed
13
14
import torch
import yaml
Glenn Jocher's avatar
Glenn Jocher committed
15
from PIL import Image, ImageDraw, ImageFont
Glenn Jocher's avatar
Glenn Jocher committed
16

17
from utils.general import xywh2xyxy, xyxy2xywh
Glenn Jocher's avatar
Glenn Jocher committed
18
19
from utils.metrics import fitness

Glenn Jocher's avatar
Glenn Jocher committed
20
# Settings: module-wide matplotlib defaults applied when this module is imported
matplotlib.rc('font', size=11)
matplotlib.use('Agg')  # for writing to files only
Glenn Jocher's avatar
Glenn Jocher committed
23

Glenn Jocher's avatar
Glenn Jocher committed
24

Glenn Jocher's avatar
Glenn Jocher committed
25
26
27
class Colors:
    """Ultralytics color palette (https://ultralytics.com/).

    Instances are callable: ``colors(i)`` returns palette entry ``int(i) % n`` as an (r, g, b)
    tuple of ints, and ``colors(i, bgr=True)`` returns the same color in (b, g, r) order.
    """

    def __init__(self):
        # hex_codes = matplotlib.colors.TABLEAU_COLORS.values()
        hex_codes = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
                     '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
        self.palette = [self.hex2rgb('#' + c) for c in hex_codes]  # list of (r, g, b) tuples
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        """Return palette color int(i) modulo the palette size, as RGB, or as BGR if bgr=True."""
        c = self.palette[int(i) % self.n]
        return (c[2], c[1], c[0]) if bgr else c

    @staticmethod
    def hex2rgb(h):  # rgb order (PIL)
        """Convert a '#RRGGBB' string to an (r, g, b) tuple of ints."""
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))


colors = Colors()  # create instance for 'from utils.plots import colors'
Glenn Jocher's avatar
Glenn Jocher committed
44
45
46
47
48
49
50
51
52
53
54
55


def hist2d(x, y, n=100):
    """Return, for every (x, y) point, the log of the number of points sharing its 2-D histogram bin.

    Used to color scatter plots by local point density (labels.png, evolve.png).
    `n` is the number of bin edges per axis, spaced linearly between each array's min and max.
    """
    edges = (np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n))
    counts, x_edges, y_edges = np.histogram2d(x, y, edges)
    # Map each point back to its bin; clipping puts points lying on the last edge into the last bin
    ix = np.clip(np.digitize(x, x_edges) - 1, 0, counts.shape[0] - 1)
    iy = np.clip(np.digitize(y, y_edges) - 1, 0, counts.shape[1] - 1)
    return np.log(counts[ix, iy])


def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
    """Zero-phase low-pass filter `data` with a Butterworth filter of the given `order`.

    `cutoff` and `fs` (sampling rate) must use the same frequency units. The filter runs forwards and
    backwards (scipy.signal.filtfilt), so the output has no phase shift.
    Reference: https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
    """
    from scipy.signal import butter, filtfilt

    nyquist = 0.5 * fs
    b, a = butter(order, cutoff / nyquist, btype='low', analog=False)  # cutoff normalised to Nyquist
    return filtfilt(b, a, data)  # forward-backward filter


68
def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
    """Draw one bounding box, with an optional filled label, onto image `im` in place using OpenCV.

    Args:
        x: box as (x1, y1, x2, y2) opposite corners; each value is converted with int().
        im: image array drawn on in place; must be contiguous (checked via im.data.contiguous).
        color: box and label-background color, passed straight to cv2.
        label: optional text drawn above the box's first corner.
        line_thickness: line thickness; if falsy it is derived from the image size.
    """
    assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_one_box() input image.'
    tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1  # line/font thickness
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3  # label background extends upwards from c1
        cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA)  # filled
        # White label text, matching plot_one_box_PIL's fill=(255, 255, 255)
        cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [255, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
Glenn Jocher's avatar
Glenn Jocher committed
80
81


82
def plot_one_box_PIL(box, im, color=(128, 128, 128), label=None, line_thickness=None):
    """Draw one bounding box, with an optional label, on image `im` using PIL and return the result.

    Args:
        box: box as (x1, y1, x2, y2) in pixels.
        im: image array accepted by PIL.Image.fromarray().
        color: box outline and label-background color.
        label: optional text drawn just above the box's (x1, y1) corner.
        line_thickness: outline width; defaults to max(int(min(image size) / 200), 2).

    Returns:
        The annotated image as a numpy array.
    """
    im = Image.fromarray(im)
    draw = ImageDraw.Draw(im)
    line_thickness = line_thickness or max(int(min(im.size) / 200), 2)
    draw.rectangle(box, width=line_thickness, outline=color)  # plot
    if label:
        try:
            font = ImageFont.truetype("Arial.ttf", size=max(round(max(im.size) / 40), 12))
        except OSError:  # Arial.ttf not available on this system: fall back to PIL's default font
            font = ImageFont.load_default()
        if hasattr(font, 'getbbox'):  # Pillow >= 8; getsize() was removed from FreeTypeFont in Pillow 10
            _, _, txt_width, txt_height = font.getbbox(label)
        else:
            txt_width, txt_height = font.getsize(label)
        draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=color)
        draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font)
    return np.asarray(im)
Glenn Jocher's avatar
Glenn Jocher committed
94
95


Glenn Jocher's avatar
Glenn Jocher committed
96
def plot_wh_methods():  # from utils.plots import *; plot_wh_methods()
    """Compare the two width-height anchor multiplication methods and save the plot as comparison.png.

    Plots exp(x) (labelled 'YOLOv3') against (2 * sigmoid(x)) ** 2 and ** 1.6 (labelled 'YOLOv5') for x in [-4, 4).
    See https://github.com/ultralytics/yolov3/issues/168
    """
    grid = np.arange(-4.0, 4.0, .1)
    v3_curve = np.exp(grid)
    v5_base = 2 * torch.sigmoid(torch.from_numpy(grid)).numpy()

    fig = plt.figure(figsize=(6, 3), tight_layout=True)
    curves = (
        (v3_curve, 'YOLOv3'),
        (v5_base ** 2, 'YOLOv5 ^2'),
        (v5_base ** 1.6, 'YOLOv5 ^1.6'),
    )
    for values, name in curves:
        plt.plot(grid, values, '.-', label=name)
    plt.xlim(left=-4, right=4)
    plt.ylim(bottom=0, top=6)
    plt.xlabel('input')
    plt.ylabel('output')
    plt.grid()
    plt.legend()
    fig.savefig('comparison.png', dpi=200)


Glenn Jocher's avatar
Glenn Jocher committed
116
def output_to_target(output):
    """Convert per-image model output into rows of [batch_id, class_id, x, y, w, h, conf].

    Each element of `output` is moved to the CPU and each of its rows is unpacked as (*box, conf, cls).
    Box coordinates go through xyxy2xywh (presumably xyxy -> center xywh; see utils.general).
    Returns an np.ndarray with one row per detection across the whole batch.
    """
    rows = [
        [batch_index, cls, *xyxy2xywh(np.array(box)[None])[0], conf]
        for batch_index, preds in enumerate(output)
        for *box, conf, cls in preds.cpu().numpy()
    ]
    return np.array(rows)


def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
    """Plot a batch of images as a square mosaic with their boxes, optionally saving it to `fname`.

    Args:
        images: torch.Tensor or np.ndarray unpacked as (batch, _, height, width). If the first image's
            maximum is <= 1, the batch is multiplied by 255 (out of place, so the caller's data is untouched).
        targets: rows [batch_id, class_id, x, y, w, h] (labels) or [..., conf] (predictions, 7 columns).
            Boxes are xywh, either normalized (max <= 1.01) or in pixels.
        paths: optional per-image paths; each file name (trimmed to 40 chars) is written on its tile.
        fname: output file; if falsy, nothing is saved and the mosaic is not resized.
        names: optional class names indexed by class id.
        max_size: tiles are downscaled so that max(height, width) <= max_size.
        max_subplots: maximum number of images placed in the mosaic.

    Returns:
        The mosaic as a uint8 np.ndarray.
    """
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.cpu().numpy()

    # un-normalise. Out of place on purpose: for a CPU float tensor, .numpy() above shares memory with the
    # caller's tensor, so an in-place `*=` would corrupt the caller's batch.
    if np.max(images[0]) <= 1:
        images = images * 255

    tl = 3  # line thickness
    tf = max(tl - 1, 1)  # font thickness
    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs ** 0.5)  # number of subplots (square)

    # Check if we should resize
    scale_factor = max_size / max(h, w)
    if scale_factor < 1:
        h = math.ceil(scale_factor * h)
        w = math.ceil(scale_factor * w)

    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i, img in enumerate(images):
        if i == max_subplots:  # if last batch has fewer images than we expect
            break

        # Tile origin; the grid is filled column by column (x from i // ns, y from i % ns)
        block_x = int(w * (i // ns))
        block_y = int(h * (i % ns))

        img = img.transpose(1, 2, 0)
        if scale_factor < 1:
            img = cv2.resize(img, (w, h))

        mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
        if len(targets) > 0:
            image_targets = targets[targets[:, 0] == i]
            boxes = xywh2xyxy(image_targets[:, 2:6]).T
            classes = image_targets[:, 1].astype('int')
            labels = image_targets.shape[1] == 6  # labels if no conf column
            conf = None if labels else image_targets[:, 6]  # check for confidence presence (label vs pred)

            if boxes.shape[1]:
                if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
                    boxes[[0, 2]] *= w  # scale to pixels
                    boxes[[1, 3]] *= h
                elif scale_factor < 1:  # absolute coords need scale if image scales
                    boxes *= scale_factor
            boxes[[0, 2]] += block_x
            boxes[[1, 3]] += block_y
            for j, box in enumerate(boxes.T):
                cls = int(classes[j])
                color = colors(cls)
                cls = names[cls] if names else cls
                if labels or conf[j] > 0.25:  # 0.25 conf thresh
                    label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
                    plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)

        # Draw image filename labels
        if paths:
            label = Path(paths[i]).name[:40]  # trim to 40 char
            t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
            cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
                        lineType=cv2.LINE_AA)

        # Image border
        cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)

    if fname:
        r = min(1280. / max(h, w) / ns, 1.0)  # ratio to limit image size
        mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
        # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB))  # cv2 save
        Image.fromarray(mosaic).save(fname)  # PIL save
    return mosaic


def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
    """Simulate `epochs` scheduler steps, plot the learning rate and save it to save_dir/LR.png.

    The value recorded per epoch is optimizer.param_groups[0]['lr'] after scheduler.step(). The scheduler
    state and the optimizer's param-group hyperparameters are snapshotted first and restored afterwards,
    so the caller's optimizer and scheduler are left unchanged.
    """
    # A shallow copy() is not enough: a copied scheduler still references the original optimizer, and a copied
    # optimizer shares its param_groups, so stepping would overwrite the caller's learning rates. Snapshot instead.
    scheduler_state = scheduler.state_dict()
    group_hyps = [{k: v for k, v in g.items() if k != 'params'} for g in optimizer.param_groups]
    y = []
    try:
        for _ in range(epochs):
            scheduler.step()
            y.append(optimizer.param_groups[0]['lr'])
    finally:  # restore even if stepping fails part-way
        scheduler.load_state_dict(scheduler_state)
        for g, hyp in zip(optimizer.param_groups, group_hyps):
            g.update(hyp)
    plt.plot(y, '.-', label='LR')
    plt.xlabel('epoch')
    plt.ylabel('LR')
    plt.grid()
    plt.xlim(0, epochs)
    plt.ylim(0)
    plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
    plt.close()
Glenn Jocher's avatar
Glenn Jocher committed
218
219


220
221
222
def plot_val_txt():  # from utils.plots import *; plot_val_txt()
    """Plot histograms of the box centers stored in val.txt.

    Writes a 2-D center histogram to hist2d.png and separate x / y histograms to hist1d.png.
    """
    data = np.loadtxt('val.txt', dtype=np.float32)
    xywh = xyxy2xywh(data[:, :4])
    center_x, center_y = xywh[:, 0], xywh[:, 1]

    _, axis = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
    axis.hist2d(center_x, center_y, bins=600, cmax=10, cmin=0)
    axis.set_aspect('equal')
    plt.savefig('hist2d.png', dpi=300)

    _, axes = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
    for panel, values in zip(axes, (center_x, center_y)):
        panel.hist(values, bins=600)
    plt.savefig('hist1d.png', dpi=200)


Glenn Jocher's avatar
Glenn Jocher committed
237
def plot_targets_txt():  # from utils.plots import *; plot_targets_txt()
    """Histogram the first four columns of targets.txt (x, y, width, height) and save to targets.jpg."""
    columns = np.loadtxt('targets.txt', dtype=np.float32).T
    titles = ('x targets', 'y targets', 'width targets', 'height targets')
    panels = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
    for k, title in enumerate(titles):
        values = columns[k]
        panel = panels[k]
        panel.hist(values, bins=100, label='%.3g +/- %.3g' % (values.mean(), values.std()))
        panel.legend()
        panel.set_title(title)
    plt.savefig('targets.jpg', dpi=200)


250
def plot_study_txt(path='', x=None):  # from utils.plots import *; plot_study_txt()
    """Plot speed/accuracy study results from the study*.txt files generated by val.py.

    Each study*.txt under `path` becomes one curve of loaded row y[3] (titled mAP@.5:.95, x100) against
    y[5] (titled t_inference), from row 1 up to the row with the best y[3]. A fixed EfficientDet reference
    curve is drawn alongside. The figure is saved as '<name of path>.png' in the current working directory.

    Args:
        path: directory searched for study*.txt files.
        x: optional x values for the per-metric plots (only used when plot2 is True); defaults to row indices.
    """
    plot2 = False  # plot additional results
    if plot2:
        ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()

    fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
    # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
    for f in sorted(Path(path).glob('study*.txt')):
        y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
        # Per-file x values. Do not rebind the `x` argument: after the first file it would no longer be None,
        # and files with a different number of rows would reuse the first file's x.
        xf = np.arange(y.shape[1]) if x is None else np.array(x)
        if plot2:
            s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
            for i in range(7):
                ax[i].plot(xf, y[i], '.-', linewidth=2, markersize=8)
                ax[i].set_title(s[i])

        j = y[3].argmax() + 1  # plot each curve only up to its best y[3] row
        ax2.plot(y[5, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
                 label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))

    ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
             'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')

    ax2.grid(alpha=0.2)
    ax2.set_yticks(np.arange(20, 60, 5))
    ax2.set_xlim(0, 57)
    ax2.set_ylim(30, 55)
    ax2.set_xlabel('GPU Speed (ms/img)')
    ax2.set_ylabel('COCO AP val')
    ax2.legend(loc='lower right')
    plt.savefig(str(Path(path).name) + '.png', dpi=300)
Glenn Jocher's avatar
Glenn Jocher committed
282
283


284
def plot_labels(labels, names=(), save_dir=Path('')):
    """Plot dataset label statistics and save labels_correlogram.jpg and labels.jpg to `save_dir`.

    Args:
        labels: np.ndarray with rows [class, x, y, width, height] (column names as used below);
            it is not modified.
        names: optional class names, used as x tick labels when there are between 1 and 29 of them.
        save_dir: output directory (Path).
    """
    print('Plotting labels... ')
    c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
    nc = int(c.max() + 1)  # number of classes
    x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])

    # seaborn correlogram
    sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
    plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
    plt.close()

    # matplotlib labels
    matplotlib.use('svg')  # faster
    ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
    y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
    # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)]  # update colors bug #3195
    ax[0].set_ylabel('instances')
    if 0 < len(names) < 30:
        ax[0].set_xticks(range(len(names)))
        ax[0].set_xticklabels(names, rotation=90, fontsize=10)
    else:
        ax[0].set_xlabel('classes')
    sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
    sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)

    # rectangles: up to 1000 boxes re-centred at (0.5, 0.5), scaled to a 2000x2000 white canvas.
    # Work on a copy of just those rows so the caller's labels array is not overwritten.
    rects = labels[:1000].copy()
    rects[:, 1:3] = 0.5  # center
    rects[:, 1:] = xywh2xyxy(rects[:, 1:]) * 2000
    img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
    draw = ImageDraw.Draw(img)
    for cls, *box in rects:
        draw.rectangle(box, width=1, outline=colors(cls))  # plot
    ax[1].imshow(img)
    ax[1].axis('off')

    for a in [0, 1, 2, 3]:
        for s in ['top', 'right', 'left', 'bottom']:
            ax[a].spines[s].set_visible(False)

    plt.savefig(save_dir / 'labels.jpg', dpi=200)
    matplotlib.use('Agg')
    plt.close()


328
def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
    """Plot iDetection per-image logs (frames*.txt in `save_dir`) and save idetection_profile.png.

    Usage: from utils.plots import *; profile_idetection()
    The first 90 and last 30 rows of each log are dropped; `start`/`stop` then select the plotted rows
    (stop=0 means up to the end). `labels` optionally gives one legend label per file; otherwise the
    file name is used. A file that fails to load or plot is reported with a warning and skipped.
    """
    axes = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
    titles = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery',
              'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
    files = list(Path(save_dir).glob('frames*.txt'))
    for file_index, file in enumerate(files):
        try:
            results = np.loadtxt(file, ndmin=2).T[:, 90:-30]  # clip first and last rows
            rows = results.shape[1]  # number of rows
            end = min(stop, rows) if stop else rows
            idx = np.arange(start, end)
            results = results[:, idx]
            t = results[0] - results[0].min()  # set t0=0s
            results[0] = idx
            for i, a in enumerate(axes):
                if i >= len(results):
                    a.remove()
                    continue
                legend = labels[file_index] if len(labels) else file.stem.replace('frames_', '')
                a.plot(t, results[i], marker='.', label=legend, linewidth=1, markersize=5)
                a.set_title(titles[i])
                a.set_xlabel('time (s)')
                for side in ['top', 'right']:
                    a.spines[side].set_visible(False)
        except Exception as e:
            print('Warning: Plotting error for %s; %s' % (file, e))

    axes[1].legend()
    plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
358
359


Glenn Jocher's avatar
Glenn Jocher committed
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
def plot_evolve(evolve_csv=Path('path/to/evolve.csv')):  # from utils.plots import *; plot_evolve()
    """Plot hyperparameter-evolution results from evolve.csv and save them next to it as a .png.

    One scatter subplot (up to 6x5) is drawn per CSV column after the first 7, against fitness() of the
    rows, colored by local point density. The row with maximum fitness is marked with '+', and its value
    for each column is printed.

    Args:
        evolve_csv: path to evolve.csv (str or Path).
    """
    evolve_csv = Path(evolve_csv)  # accept str paths too; .with_suffix() below needs a Path
    data = pd.read_csv(evolve_csv)
    keys = [x.strip() for x in data.columns]
    x = data.values
    f = fitness(x)
    j = np.argmax(f)  # max fitness index
    fig = plt.figure(figsize=(10, 12), tight_layout=True)
    # Temporary rc context: the smaller font must not leak into plots made after this function returns
    with matplotlib.rc_context({'font.size': 8}):
        for i, k in enumerate(keys[7:]):
            v = x[:, 7 + i]
            mu = v[j]  # best single result
            plt.subplot(6, 5, i + 1)
            plt.scatter(v, f, c=hist2d(v, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
            plt.plot(mu, f.max(), 'k+', markersize=15)
            plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9})
            if i % 5 != 0:
                plt.yticks([])
            print('%15s: %.3g' % (k, mu))
        save_path = evolve_csv.with_suffix('.png')  # filename
        plt.savefig(save_path, dpi=200)  # saved inside the context so the figure is drawn with font size 8
    plt.close(fig)
    print(f'Saved {save_path}')


def plot_results(file='path/to/results.csv', dir=''):
    """Plot training results from results*.csv files and save results.png in the same directory.

    Usage: from utils.plots import *; plot_results('path/to/results.csv')

    Args:
        file: a results.csv path; every results*.csv in its parent directory is plotted.
        dir: directory searched instead when `file` is empty.

    Raises:
        AssertionError: if no results*.csv file is found.
    """
    save_dir = Path(file).parent if file else Path(dir)
    files = list(save_dir.glob('results*.csv'))
    # Check before creating the figure so a failed call does not leave an open figure behind
    assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
    fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
    ax = ax.ravel()
    for f in files:
        try:
            data = pd.read_csv(f)
            s = [x.strip() for x in data.columns]
            x = data.values[:, 0]
            # CSV column shown in each of the 10 panels, in panel order
            for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
                y = data.values[:, j]
                # y[y == 0] = np.nan  # don't show zero values
                ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
                ax[i].set_title(s[j], fontsize=12)
        except Exception as e:
            print(f'Warning: Plotting error for {f}: {e}')
    ax[1].legend()
    fig.savefig(save_dir / 'results.png', dpi=200)
    plt.close(fig)  # release the figure; pyplot otherwise keeps it open
407
408


409
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
410
    """
411
    x:              Features to be visualized
412
    module_type:    Module type
413
    stage:          Module stage within model
414
    n:              Maximum number of feature maps to plot
415
    save_dir:       Directory to save results
416
    """
417
418
419
420
421
    if 'Detect' not in module_type:
        batch, channels, height, width = x.shape  # batch, channels, height, width
        if height > 1 and width > 1:
            f = f"stage{stage}_{module_type.split('.')[-1]}_features.png"  # filename

422
            blocks = torch.chunk(x[0].cpu(), channels, dim=0)  # select batch index 0, block by channels
423
            n = min(n, channels)  # number of plots
424
425
426
            fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)  # 8 rows x n/8 cols
            ax = ax.ravel()
            plt.subplots_adjust(wspace=0.05, hspace=0.05)
427
428
429
430
431
            for i in range(n):
                ax[i].imshow(blocks[i].squeeze())  # cmap='gray'
                ax[i].axis('off')

            print(f'Saving {save_dir / f}... ({n}/{channels})')
432
            plt.savefig(save_dir / f, dpi=300, bbox_inches='tight')