import torch
import argparse
import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms, models
import horovod.torch as hvd
import os
import math
from tqdm import tqdm
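
# Elastic Horovod: this script is meant to be launched with a variable number
# of workers, e.g. (worker counts and discovery script are illustrative):
#   horovodrun -np 8 --min-np 4 --max-np 12 \
#       --host-discovery-script ./discover_hosts.sh python <this script>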

# Training settings
parser = argparse.ArgumentParser(description='Elastic PyTorch ImageNet Example',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--train-dir', default=os.path.expanduser('~/imagenet/train'),
                    help='path to training data')
parser.add_argument('--val-dir', default=os.path.expanduser('~/imagenet/validation'),
                    help='path to validation data')
parser.add_argument('--log-dir', default='./logs',
                    help='tensorboard log directory')
parser.add_argument('--checkpoint-format', default='./checkpoint-{epoch}.pth.tar',
                    help='checkpoint file format')

# Horovod settings
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
                    help='use fp16 compression during allreduce')
parser.add_argument('--batches-per-allreduce', type=int, default=1,
                    help='number of batches processed locally before '
                         'executing allreduce across workers; it multiplies '
                         'total batch size.')
parser.add_argument('--use-adasum', action='store_true', default=False,
                    help='use adasum algorithm to do reduction')
parser.add_argument('--gradient-predivide-factor', type=float, default=1.0,
                    help='apply gradient predivide factor in optimizer (default: 1.0)')

# Elastic Horovod settings
parser.add_argument('--batches-per-commit', type=int, default=100,
                    help='number of batches processed before calling `state.commit()`; '
                         'commits prevent losing progress if an error occurs, but slow '
                         'down training.')
parser.add_argument('--batches-per-host-check', type=int, default=10,
                    help='number of batches processed before calling '
                         '`state.check_host_updates()`; this check is very fast '
                         'compared to `state.commit()` (which calls it as part of the '
                         'commit process), but it still incurs some cost due to the '
                         'broadcast, so we may not want to perform it every batch.')

# Default settings from https://arxiv.org/abs/1706.02677.
parser.add_argument('--batch-size', type=int, default=32,
                    help='input batch size for training')
parser.add_argument('--val-batch-size', type=int, default=32,
                    help='input batch size for validation')
parser.add_argument('--epochs', type=int, default=90,
                    help='number of epochs to train')
parser.add_argument('--base-lr', type=float, default=0.0125,
                    help='learning rate for a single GPU')
parser.add_argument('--warmup-epochs', type=float, default=5,
                    help='number of warmup epochs')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum')
parser.add_argument('--wd', type=float, default=0.00005,
                    help='weight decay')

parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=42,
                    help='random seed')


def train(state):
    model.train()
    epoch = state.epoch
    train_loss = Metric('train_loss')
    train_accuracy = Metric('train_accuracy')

    batch_offset = state.batch
    with tqdm(total=len(train_loader),
              desc='Train Epoch #{}'.format(epoch + 1),
              disable=not verbose) as t:
        for idx, (data, target) in enumerate(train_loader):
            # Elastic Horovod: update the current batch index this epoch
            # and commit / check for host updates. Do not check hosts when
            # we commit as it would be redundant.
            state.batch = batch_idx = batch_offset + idx
            if args.batches_per_commit > 0 and \
                    state.batch % args.batches_per_commit == 0:
                state.commit()
            elif args.batches_per_host_check > 0 and \
                    state.batch % args.batches_per_host_check == 0:
                state.check_host_updates()
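
            # Elastic Horovod: `check_host_updates()` raises HostsUpdatedInterrupt
            # when the worker set has changed; `hvd.elastic.run` then restores the
            # last committed state and re-enters the wrapped function.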

            adjust_learning_rate(epoch, batch_idx)

            if args.cuda:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            # Split data into sub-batches of size batch_size
            for i in range(0, len(data), args.batch_size):
                data_batch = data[i:i + args.batch_size]
                target_batch = target[i:i + args.batch_size]
                output = model(data_batch)
                train_accuracy.update(accuracy(output, target_batch))
                loss = F.cross_entropy(output, target_batch)
                train_loss.update(loss)
                # Average gradients among sub-batches
                loss.div_(math.ceil(float(len(data)) / args.batch_size))
                loss.backward()
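                # Dividing each sub-batch loss by the number of sub-batches makes
                # the accumulated gradient equal to the gradient of the mean loss
                # over the full allreduce batch.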

            # Elastic Horovod: record which samples were processed this batch
            # so we do not reprocess them if a reset event occurs
            state.train_sampler.record_batch(idx, allreduce_batch_size)

            # Gradient is applied across all ranks
            optimizer.step()
            t.set_postfix({'loss': train_loss.avg.item(),
                           'accuracy': 100. * train_accuracy.avg.item()})

            t.update(1)

    if log_writer:
        log_writer.add_scalar('train/loss', train_loss.avg, epoch)
        log_writer.add_scalar('train/accuracy', train_accuracy.avg, epoch)

    state.commit()


def validate(epoch):
    model.eval()
    val_loss = Metric('val_loss')
    val_accuracy = Metric('val_accuracy')

    with tqdm(total=len(val_loader),
              desc='Validate Epoch #{}'.format(epoch + 1),
              disable=not verbose) as t:
        with torch.no_grad():
            for data, target in val_loader:
                if args.cuda:
                    data, target = data.cuda(), target.cuda()
                output = model(data)

                val_loss.update(F.cross_entropy(output, target))
                val_accuracy.update(accuracy(output, target))
                t.set_postfix({'loss': val_loss.avg.item(),
                               'accuracy': 100. * val_accuracy.avg.item()})
                t.update(1)

    if log_writer:
        log_writer.add_scalar('val/loss', val_loss.avg, epoch)
        log_writer.add_scalar('val/accuracy', val_accuracy.avg, epoch)


# Horovod: using `lr = base_lr * hvd.size()` from the very beginning leads to worse final
# accuracy. Scale the learning rate `lr = base_lr` ---> `lr = base_lr * hvd.size()` during
# the first five epochs. See https://arxiv.org/abs/1706.02677 for details.
# After the warmup, reduce the learning rate by 10x at the 30th, 60th, and 80th epochs.
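# For example, with the defaults (base_lr=0.0125, batch_size=32,
# batches_per_allreduce=1) on 8 workers, the post-warmup learning rate is
# 0.0125 * 8 = 0.1 for an effective batch of 256, matching the linear
# scaling rule from the paper above.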
def adjust_learning_rate(epoch, batch_idx):
    if epoch < args.warmup_epochs:
        epoch += float(batch_idx + 1) / len(train_loader)
        lr_adj = 1. / hvd.size() * (epoch * (hvd.size() - 1) / args.warmup_epochs + 1)
    elif epoch < 30:
        lr_adj = 1.
    elif epoch < 60:
        lr_adj = 1e-1
    elif epoch < 80:
        lr_adj = 1e-2
    else:
        lr_adj = 1e-3
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.base_lr * hvd.size() * args.batches_per_allreduce * lr_adj


def accuracy(output, target):
    # get the index of the max log-probability
    pred = output.max(1, keepdim=True)[1]
    return pred.eq(target.view_as(pred)).cpu().float().mean()


def save_checkpoint(epoch):
    if hvd.rank() == 0:
        filepath = args.checkpoint_format.format(epoch=epoch + 1)
        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(state, filepath)


def end_epoch(state):
    state.epoch += 1
    state.batch = 0
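    # Advance the sampler to the new epoch; this also clears its record of
    # already-processed indices so the next epoch sees the full dataset.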
    state.train_sampler.set_epoch(state.epoch)
    state.commit()


# Horovod: average metrics from distributed training.
class Metric(object):
    def __init__(self, name):
        self.name = name
        self.sum = torch.tensor(0.)
        self.n = torch.tensor(0.)

    def update(self, val):
        self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
        self.n += 1

    @property
    def avg(self):
        return self.sum / self.n


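# Elastic Horovod: `hvd.elastic.run` wraps the training loop so that when a
# worker fails or hosts are added/removed, the last committed state is
# restored on the active workers and the function is re-executed from there.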
@hvd.elastic.run
def full_train(state):
    while state.epoch < args.epochs:
        train(state)
        validate(state.epoch)
        save_checkpoint(state.epoch)
        end_epoch(state)


if __name__ == '__main__':
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    allreduce_batch_size = args.batch_size * args.batches_per_allreduce
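    # Each rank draws `allreduce_batch_size` samples per step and splits them
    # into sub-batches of `batch_size`, so the effective global batch is
    # batch_size * batches_per_allreduce * hvd.size().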

    hvd.init()
    torch.manual_seed(args.seed)

    if args.cuda:
        # Horovod: pin GPU to local rank.
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(args.seed)

    cudnn.benchmark = True

    # Horovod: print logs on the first worker.
    verbose = 1 if hvd.rank() == 0 else 0

    # Horovod: write TensorBoard logs on first worker.
    log_writer = SummaryWriter(args.log_dir) if hvd.rank() == 0 else None

    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(4)

    kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
    # When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to prevent
    # issues with Infiniband implementations that are not fork-safe.
    if (kwargs.get('num_workers', 0) > 0 and hasattr(mp, '_supports_context') and
            mp._supports_context and 'forkserver' in mp.get_all_start_methods()):
        kwargs['multiprocessing_context'] = 'forkserver'

    train_dataset = \
        datasets.ImageFolder(args.train_dir,
                             transform=transforms.Compose([
                                 transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                      std=[0.229, 0.224, 0.225])
                             ]))
    # Elastic Horovod: use ElasticSampler to partition data among workers.
    train_sampler = hvd.elastic.ElasticSampler(train_dataset)
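    # The sampler re-partitions the remaining indices whenever the worker set
    # changes; together with `record_batch()` in train(), samples already
    # processed in the current epoch are not replayed after a reset.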
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=allreduce_batch_size,
        sampler=train_sampler,
        **kwargs)

    val_dataset = \
        datasets.ImageFolder(args.val_dir,
                             transform=transforms.Compose([
                                 transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                      std=[0.229, 0.224, 0.225])
                             ]))
    val_sampler = hvd.elastic.ElasticSampler(val_dataset)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.val_batch_size,
        sampler=val_sampler,
        **kwargs)

    # Set up standard ResNet-50 model.
    model = models.resnet50()

    # By default, Adasum doesn't need scaling up learning rate.
    # For sum/average with gradient accumulation: scale learning rate by batches_per_allreduce.
    lr_scaler = args.batches_per_allreduce * hvd.size() if not args.use_adasum else 1

    if args.cuda:
        # Move model to GPU.
        model.cuda()
        # If using GPU Adasum allreduce, scale learning rate by local_size.
        if args.use_adasum and hvd.nccl_built():
            lr_scaler = args.batches_per_allreduce * hvd.local_size()

    # Horovod: scale learning rate by the number of GPUs.
    optimizer = optim.SGD(model.parameters(),
                          lr=(args.base_lr * lr_scaler),
                          momentum=args.momentum, weight_decay=args.wd)

    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters(),
        compression=compression,
        backward_passes_per_step=args.batches_per_allreduce,
        op=hvd.Adasum if args.use_adasum else hvd.Average,
        gradient_predivide_factor=args.gradient_predivide_factor)
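    # Setting `backward_passes_per_step` to `batches_per_allreduce` defers the
    # allreduce until all sub-batch gradients have been accumulated, rather
    # than synchronizing after every backward pass.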

    # Restore from the most recent checkpoint, if one exists.
    # Horovod: restore on the first worker, which will broadcast weights to other workers.
    resume_from_epoch = 0
    if hvd.rank() == 0:
        for try_epoch in range(args.epochs, 0, -1):
            if os.path.exists(args.checkpoint_format.format(epoch=try_epoch)):
                resume_from_epoch = try_epoch
                break

        if resume_from_epoch > 0:
            filepath = args.checkpoint_format.format(epoch=resume_from_epoch)
            checkpoint = torch.load(filepath)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])

    state = hvd.elastic.TorchState(model=model,
                                   optimizer=optimizer,
                                   train_sampler=train_sampler,
                                   val_sampler=val_sampler,
                                   epoch=resume_from_epoch,
                                   batch=0)
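
    # Elastic Horovod: objects registered with TorchState are saved in memory
    # on each `state.commit()` and synchronized across workers when training
    # resumes after a reset event.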

    full_train(state)