Unverified Commit b74929c9 authored by Kalen Michael's avatar Kalen Michael Committed by GitHub
Browse files

Add `train.py` and `val.py` callbacks (#4220)



* added callbacks

* Update callbacks.py

* Update train.py

* Update val.py

* Fix CamlCase add staticmethod

* Refactor logger into callbacks

* Cleanup

* New callback on_val_image_end()

* Add curves and results images to TensorBoard

Co-authored-by: default avatarGlenn Jocher <glenn.jocher@ultralytics.com>
parent d8f18834
......@@ -34,7 +34,7 @@ from utils.autoanchor import check_anchors
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
check_requirements, print_mutation, set_logging, one_cycle, colorstr
check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods
from utils.downloads import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_labels, plot_evolution
......@@ -42,6 +42,7 @@ from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_di
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.metrics import fitness
from utils.loggers import Loggers
from utils.callbacks import Callbacks
LOGGER = logging.getLogger(__name__)
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
......@@ -52,6 +53,7 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
def train(hyp, # path/to/hyp.yaml or hyp dictionary
opt,
device,
callbacks=Callbacks()
):
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
......@@ -77,12 +79,16 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# Loggers
if RANK in [-1, 0]:
loggers = Loggers(save_dir, weights, opt, hyp, LOGGER).start() # loggers dict
loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
if loggers.wandb:
data_dict = loggers.wandb.data_dict
if resume:
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp
# Register actions
for k in methods(loggers):
callbacks.register_action(k, callback=getattr(loggers, k))
# Config
plots = not evolve # create plots
cuda = device.type != 'cpu'
......@@ -215,13 +221,15 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
# model._initialize_biases(cf.to(device))
if plots:
plot_labels(labels, names, save_dir, loggers)
plot_labels(labels, names, save_dir)
# Anchors
if not opt.noautoanchor:
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
model.half().float() # pre-reduce anchor precision
callbacks.on_pretrain_routine_end()
# DDP mode
if cuda and RANK != -1:
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
......@@ -329,8 +337,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots)
callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots)
# end batch ------------------------------------------------------------------------------------------------
# Scheduler
......@@ -339,7 +346,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
if RANK in [-1, 0]:
# mAP
loggers.on_train_epoch_end(epoch)
callbacks.on_train_epoch_end(epoch=epoch)
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
final_epoch = epoch + 1 == epochs
if not noval or final_epoch: # Calculate mAP
......@@ -353,14 +360,14 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
save_json=is_coco and final_epoch,
verbose=nc < 50 and final_epoch,
plots=plots and final_epoch,
loggers=loggers,
callbacks=callbacks,
compute_loss=compute_loss)
# Update best mAP
fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
if fi > best_fitness:
best_fitness = fi
loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi)
callbacks.on_fit_epoch_end(mloss, results, lr, epoch, best_fitness, fi)
# Save model
if (not nosave) or (final_epoch and not evolve): # if save
......@@ -377,7 +384,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
if best_fitness == fi:
torch.save(ckpt, best)
del ckpt
loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi)
callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)
# end epoch ----------------------------------------------------------------------------------------------------
# end training -----------------------------------------------------------------------------------------------------
......@@ -400,7 +407,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
for f in last, best:
if f.exists():
strip_optimizer(f) # strip optimizers
loggers.on_train_end(last, best, plots)
callbacks.on_train_end(last, best, plots, epoch)
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
torch.cuda.empty_cache()
return results
......@@ -448,6 +456,7 @@ def parse_opt(known=False):
def main(opt):
# Checks
set_logging(RANK)
if RANK in [-1, 0]:
print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
......
#!/usr/bin/env python
class Callbacks:
""""
Handles all registered callbacks for YOLOv5 Hooks
"""
_callbacks = {
'on_pretrain_routine_start': [],
'on_pretrain_routine_end': [],
'on_train_start': [],
'on_train_epoch_start': [],
'on_train_batch_start': [],
'optimizer_step': [],
'on_before_zero_grad': [],
'on_train_batch_end': [],
'on_train_epoch_end': [],
'on_val_start': [],
'on_val_batch_start': [],
'on_val_image_end': [],
'on_val_batch_end': [],
'on_val_end': [],
'on_fit_epoch_end': [], # fit = train + val
'on_model_save': [],
'on_train_end': [],
'teardown': [],
}
def __init__(self):
return
def register_action(self, hook, name='', callback=None):
"""
Register a new action to a callback hook
Args:
hook The callback hook name to register the action to
name The name of the action
callback The callback to fire
"""
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
assert callable(callback), f"callback '{callback}' is not callable"
self._callbacks[hook].append({'name': name, 'callback': callback})
def get_registered_actions(self, hook=None):
""""
Returns all the registered actions by callback hook
Args:
hook The name of the hook to check, defaults to all
"""
if hook:
return self._callbacks[hook]
else:
return self._callbacks
@staticmethod
def run_callbacks(register, *args, **kwargs):
"""
Loop through the registered actions and fire all callbacks
"""
for logger in register:
# print(f"Running callbacks.{logger['callback'].__name__}()")
logger['callback'](*args, **kwargs)
def on_pretrain_routine_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each pretraining routine
"""
self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs)
def on_pretrain_routine_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each pretraining routine
"""
self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs)
def on_train_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each training
"""
self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs)
def on_train_epoch_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each training epoch
"""
self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs)
def on_train_batch_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each training batch
"""
self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs)
def optimizer_step(self, *args, **kwargs):
"""
Fires all registered callbacks on each optimizer step
"""
self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs)
def on_before_zero_grad(self, *args, **kwargs):
"""
Fires all registered callbacks before zero grad
"""
self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs)
def on_train_batch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each training batch
"""
self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs)
def on_train_epoch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each training epoch
"""
self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs)
def on_val_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of the validation
"""
self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs)
def on_val_batch_start(self, *args, **kwargs):
"""
Fires all registered callbacks at the start of each validation batch
"""
self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs)
def on_val_image_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each val image
"""
self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs)
def on_val_batch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each validation batch
"""
self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs)
def on_val_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of the validation
"""
self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs)
def on_fit_epoch_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of each fit (train+val) epoch
"""
self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs)
def on_model_save(self, *args, **kwargs):
"""
Fires all registered callbacks after each model save
"""
self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs)
def on_train_end(self, *args, **kwargs):
"""
Fires all registered callbacks at the end of training
"""
self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs)
def teardown(self, *args, **kwargs):
"""
Fires all registered callbacks before teardown
"""
self.run_callbacks(self._callbacks['teardown'], *args, **kwargs)
......@@ -67,6 +67,11 @@ def try_except(func):
return handler
def methods(instance):
# Get class/instance methods
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
def set_logging(rank=-1, verbose=True):
logging.basicConfig(
format="%(message)s",
......
......@@ -29,10 +29,12 @@ class Loggers():
self.hyp = hyp
self.logger = logger # for printing results to console
self.include = include
self.keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', # metrics
'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
'x/lr0', 'x/lr1', 'x/lr2'] # params
for k in LOGGERS:
setattr(self, k, None) # init empty logger dictionary
def start(self):
self.csv = True # always log to csv
# Message
......@@ -57,7 +59,11 @@ class Loggers():
else:
self.wandb = None
return self
def on_pretrain_routine_end(self):
# Callback runs on pre-train routine end
paths = self.save_dir.glob('*labels*.jpg') # training labels
if self.wandb:
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
# Callback runs on train batch end
......@@ -78,8 +84,8 @@ class Loggers():
if self.wandb:
self.wandb.current_epoch = epoch + 1
def on_val_batch_end(self, pred, predn, path, names, im):
# Callback runs on train batch end
def on_val_image_end(self, pred, predn, path, names, im):
# Callback runs on val image end
if self.wandb:
self.wandb.val_one_image(pred, predn, path, names, im)
......@@ -89,25 +95,20 @@ class Loggers():
files = sorted(self.save_dir.glob('val*.jpg'))
self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})
def on_train_val_end(self, mloss, results, lr, epoch, best_fitness, fi):
# Callback runs on val end during training
def on_fit_epoch_end(self, mloss, results, lr, epoch, best_fitness, fi):
# Callback runs at the end of each fit (train+val) epoch
vals = list(mloss) + list(results) + lr
keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', # metrics
'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
'x/lr0', 'x/lr1', 'x/lr2'] # params
x = {k: v for k, v in zip(keys, vals)} # dict
x = {k: v for k, v in zip(self.keys, vals)} # dict
if self.csv:
file = self.save_dir / 'results.csv'
n = len(x) + 1 # number of cols
s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # add header
s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + self.keys)).rstrip(',') + '\n') # add header
with open(file, 'a') as f:
f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
if self.tb:
for k, v in x.items():
self.tb.add_scalar(k, v, epoch) # TensorBoard
self.tb.add_scalar(k, v, epoch)
if self.wandb:
self.wandb.log(x)
......@@ -119,20 +120,22 @@ class Loggers():
if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
def on_train_end(self, last, best, plots):
def on_train_end(self, last, best, plots, epoch):
# Callback runs on training end
if plots:
plot_results(dir=self.save_dir) # save results.png
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
if self.tb:
from PIL import Image
import numpy as np
for f in files:
self.tb.add_image(f.stem, np.asarray(Image.open(f)), epoch, dataformats='HWC')
if self.wandb:
wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
wandb.log_artifact(str(best if best.exists() else last), type='model',
name='run_' + self.wandb.wandb_run.id + '_model',
aliases=['latest', 'best', 'stripped'])
self.wandb.finish_run()
def log_images(self, paths):
# Log images
if self.wandb:
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
......@@ -281,7 +281,7 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx
plt.savefig(str(Path(path).name) + '.png', dpi=300)
def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
def plot_labels(labels, names=(), save_dir=Path('')):
# plot dataset labels
print('Plotting labels... ')
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
......@@ -324,10 +324,6 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
matplotlib.use('Agg')
plt.close()
# loggers
if loggers:
loggers.log_images(save_dir.glob('*labels*.jpg'))
def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
# Plot hyperparameter evolution results in evolve.txt
......
......@@ -25,7 +25,7 @@ from utils.general import coco80_to_coco91_class, check_dataset, check_file, che
from utils.metrics import ap_per_class, ConfusionMatrix
from utils.plots import plot_images, output_to_target, plot_study_txt
from utils.torch_utils import select_device, time_sync
from utils.loggers import Loggers
from utils.callbacks import Callbacks
def save_one_txt(predn, save_conf, shape, file):
......@@ -97,7 +97,7 @@ def run(data,
dataloader=None,
save_dir=Path(''),
plots=True,
loggers=Loggers(),
callbacks=Callbacks(),
compute_loss=None,
):
# Initialize/load model and set device
......@@ -213,7 +213,7 @@ def run(data,
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
if save_json:
save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
loggers.on_val_batch_end(pred, predn, path, names, img[si])
callbacks.on_val_image_end(pred, predn, path, names, img[si])
# Plot images
if plots and batch_i < 3:
......@@ -250,7 +250,7 @@ def run(data,
# Plots
if plots:
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
loggers.on_val_end()
callbacks.on_val_end()
# Save JSON
if save_json and len(jdict):
......@@ -282,7 +282,7 @@ def run(data,
model.float() # for training
if not training:
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
print(f"Results saved to {save_dir}{s}")
print(f"Results saved to {colorstr('bold', save_dir)}{s}")
maps = np.zeros(nc) + map
for i, c in enumerate(ap_class):
maps[c] = ap[i]
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment