
Commit 32e5fdb

Added ElasticSampler and PyTorch Elastic ImageNet example (horovod#2297)
Signed-off-by: Travis Addair <[email protected]>
1 parent 41b8152 commit 32e5fdb

File tree

12 files changed: +881 −101 lines changed

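For orientation, the elastic training pattern introduced by this commit boils down to the condensed sketch below. It is not taken verbatim from the diff: the synthetic TensorDataset, the Linear model, and the fixed epoch count are placeholders standing in for the real ImageNet pipeline shown further down.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import horovod.torch as hvd

hvd.init()

# Synthetic stand-in for a real dataset (the full example uses ImageFolder).
dataset = torch.utils.data.TensorDataset(
    torch.randn(256, 10), torch.randint(0, 2, (256,)))

# ElasticSampler (added by this commit) partitions samples among the current
# workers and can skip samples that were already processed before a reset.
train_sampler = hvd.elastic.ElasticSampler(dataset)
train_loader = torch.utils.data.DataLoader(
    dataset, batch_size=32, sampler=train_sampler)

model = nn.Linear(10, 2)
optimizer = hvd.DistributedOptimizer(
    optim.SGD(model.parameters(), lr=0.01),
    named_parameters=model.named_parameters())

# TorchState tracks everything that must survive a change in workers.
state = hvd.elastic.TorchState(model=model, optimizer=optimizer,
                               train_sampler=train_sampler,
                               epoch=0, batch=0)


@hvd.elastic.run
def full_train(state):
    while state.epoch < 3:
        for idx, (data, target) in enumerate(train_loader):
            state.batch = idx
            optimizer.zero_grad()
            F.cross_entropy(model(data), target).backward()
            optimizer.step()
        state.epoch += 1
        state.batch = 0
        state.train_sampler.set_epoch(state.epoch)
        state.commit()  # checkpoint progress across workers


full_train(state)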

docs/api.rst

Lines changed: 4 additions & 0 deletions
@@ -4,22 +4,26 @@ API
 horovod.tensorflow
 ------------------
 .. automodule:: horovod.tensorflow
+.. automodule:: horovod.tensorflow.elastic

 horovod.tensorflow.keras
 ------------------------
 .. automodule:: horovod.tensorflow.keras
 .. automodule:: horovod.tensorflow.keras.callbacks
    :special-members: __init__
+.. automodule:: horovod.tensorflow.keras.elastic

 horovod.keras
 -------------
 .. automodule:: horovod.keras
 .. automodule:: horovod.keras.callbacks
    :special-members: __init__
+.. automodule:: horovod.keras.elastic

 horovod.torch
 -------------
 .. automodule:: horovod.torch
+.. automodule:: horovod.torch.elastic

 horovod.mxnet
 -------------

docs/mocks.py

Lines changed: 6 additions & 0 deletions
@@ -72,6 +72,7 @@ def _dummy():
     'torch.nn.modules.batchnorm',
     'torch.utils',
     'torch.utils.data',
+    'torch.utils.data.distributed',
     'torch.utils.tensorboard',

     'mxnet',
@@ -112,6 +113,11 @@ def _dummy():
                 }
             },
         },
+        'utils': {
+            'data': {
+                'Sampler': MagicMock,
+            },
+        },
     },
     'pyspark': {
         'ml': {
Lines changed: 340 additions & 0 deletions
@@ -0,0 +1,340 @@
import torch
import argparse
import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms, models
import horovod.torch as hvd
import os
import math
from tqdm import tqdm

# Training settings
parser = argparse.ArgumentParser(description='Elastic PyTorch ImageNet Example',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--train-dir', default=os.path.expanduser('~/imagenet/train'),
                    help='path to training data')
parser.add_argument('--val-dir', default=os.path.expanduser('~/imagenet/validation'),
                    help='path to validation data')
parser.add_argument('--log-dir', default='./logs',
                    help='tensorboard log directory')
parser.add_argument('--checkpoint-format', default='./checkpoint-{epoch}.pth.tar',
                    help='checkpoint file format')

# Horovod settings
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
                    help='use fp16 compression during allreduce')
parser.add_argument('--batches-per-allreduce', type=int, default=1,
                    help='number of batches processed locally before '
                         'executing allreduce across workers; it multiplies '
                         'total batch size.')
parser.add_argument('--use-adasum', action='store_true', default=False,
                    help='use adasum algorithm to do reduction')
parser.add_argument('--gradient-predivide-factor', type=float, default=1.0,
                    help='apply gradient predivide factor in optimizer (default: 1.0)')

# Elastic Horovod settings
parser.add_argument('--batches-per-commit', type=int, default=100,
                    help='number of batches processed before calling `state.commit()`; '
                         'commits prevent losing progress if an error occurs, but slow '
                         'down training.')
parser.add_argument('--batches-per-host-check', type=int, default=10,
                    help='number of batches processed before calling `state.check_host_updates()`; '
                         'this check is very fast compared to state.commit() (which calls this '
                         'as part of the commit process), but it still incurs some cost due '
                         'to the broadcast, so we may not want to perform it every batch.')

# Default settings from https://arxiv.org/abs/1706.02677.
parser.add_argument('--batch-size', type=int, default=32,
                    help='input batch size for training')
parser.add_argument('--val-batch-size', type=int, default=32,
                    help='input batch size for validation')
parser.add_argument('--epochs', type=int, default=90,
                    help='number of epochs to train')
parser.add_argument('--base-lr', type=float, default=0.0125,
                    help='learning rate for a single GPU')
parser.add_argument('--warmup-epochs', type=float, default=5,
                    help='number of warmup epochs')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum')
parser.add_argument('--wd', type=float, default=0.00005,
                    help='weight decay')

parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=42,
                    help='random seed')


def train(state):
    model.train()
    epoch = state.epoch
    train_loss = Metric('train_loss')
    train_accuracy = Metric('train_accuracy')

    batch_offset = state.batch
    with tqdm(total=len(train_loader),
              desc='Train Epoch #{}'.format(epoch + 1),
              disable=not verbose) as t:
        for idx, (data, target) in enumerate(train_loader):
            # Elastic Horovod: update the current batch index this epoch
            # and commit / check for host updates. Do not check hosts when
            # we commit as it would be redundant.
            state.batch = batch_idx = batch_offset + idx
            if args.batches_per_commit > 0 and \
                    state.batch % args.batches_per_commit == 0:
                state.commit()
            elif args.batches_per_host_check > 0 and \
                    state.batch % args.batches_per_host_check == 0:
                state.check_host_updates()

            adjust_learning_rate(epoch, batch_idx)

            if args.cuda:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            # Split data into sub-batches of size batch_size
            for i in range(0, len(data), args.batch_size):
                data_batch = data[i:i + args.batch_size]
                target_batch = target[i:i + args.batch_size]
                output = model(data_batch)
                train_accuracy.update(accuracy(output, target_batch))
                loss = F.cross_entropy(output, target_batch)
                train_loss.update(loss)
                # Average gradients among sub-batches
                loss.div_(math.ceil(float(len(data)) / args.batch_size))
                loss.backward()

            # Elastic Horovod: record which samples were processed this batch
            # so we do not reprocess them if a reset event occurs
            state.train_sampler.record_batch(idx, allreduce_batch_size)

            # Gradient is applied across all ranks
            optimizer.step()
            t.set_postfix({'loss': train_loss.avg.item(),
                           'accuracy': 100. * train_accuracy.avg.item()})
            t.update(1)

    if log_writer:
        log_writer.add_scalar('train/loss', train_loss.avg, epoch)
        log_writer.add_scalar('train/accuracy', train_accuracy.avg, epoch)

    state.commit()


def validate(epoch):
    model.eval()
    val_loss = Metric('val_loss')
    val_accuracy = Metric('val_accuracy')

    with tqdm(total=len(val_loader),
              desc='Validate Epoch #{}'.format(epoch + 1),
              disable=not verbose) as t:
        with torch.no_grad():
            for data, target in val_loader:
                if args.cuda:
                    data, target = data.cuda(), target.cuda()
                output = model(data)

                val_loss.update(F.cross_entropy(output, target))
                val_accuracy.update(accuracy(output, target))
                t.set_postfix({'loss': val_loss.avg.item(),
                               'accuracy': 100. * val_accuracy.avg.item()})
                t.update(1)

    if log_writer:
        log_writer.add_scalar('val/loss', val_loss.avg, epoch)
        log_writer.add_scalar('val/accuracy', val_accuracy.avg, epoch)


# Horovod: using `lr = base_lr * hvd.size()` from the very beginning leads to worse final
# accuracy. Scale the learning rate `lr = base_lr` ---> `lr = base_lr * hvd.size()` during
# the first five epochs. See https://arxiv.org/abs/1706.02677 for details.
# After the warmup, reduce learning rate by 10x on the 30th, 60th and 80th epochs.
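# As a concrete illustration (not part of the original example): with 8 workers
# and warmup_epochs=5, lr_adj below starts near 1/8 at the beginning of epoch 0
# and ramps linearly to 1.0 by the end of epoch 5, since at epoch == warmup_epochs
# the warmup term becomes 1/size * (warmup * (size - 1) / warmup + 1) = 1.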
def adjust_learning_rate(epoch, batch_idx):
    if epoch < args.warmup_epochs:
        epoch += float(batch_idx + 1) / len(train_loader)
        lr_adj = 1. / hvd.size() * (epoch * (hvd.size() - 1) / args.warmup_epochs + 1)
    elif epoch < 30:
        lr_adj = 1.
    elif epoch < 60:
        lr_adj = 1e-1
    elif epoch < 80:
        lr_adj = 1e-2
    else:
        lr_adj = 1e-3
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.base_lr * hvd.size() * args.batches_per_allreduce * lr_adj


def accuracy(output, target):
    # get the index of the max log-probability
    pred = output.max(1, keepdim=True)[1]
    return pred.eq(target.view_as(pred)).cpu().float().mean()


def save_checkpoint(epoch):
    if hvd.rank() == 0:
        filepath = args.checkpoint_format.format(epoch=epoch + 1)
        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(state, filepath)


def end_epoch(state):
    state.epoch += 1
    state.batch = 0
    state.train_sampler.set_epoch(state.epoch)
    state.commit()


# Horovod: average metrics from distributed training.
class Metric(object):
    def __init__(self, name):
        self.name = name
        self.sum = torch.tensor(0.)
        self.n = torch.tensor(0.)

    def update(self, val):
        self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
        self.n += 1

    @property
    def avg(self):
        return self.sum / self.n


@hvd.elastic.run
def full_train(state):
    while state.epoch < args.epochs:
        train(state)
        validate(state.epoch)
        save_checkpoint(state.epoch)
        end_epoch(state)


if __name__ == '__main__':
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    allreduce_batch_size = args.batch_size * args.batches_per_allreduce

    hvd.init()
    torch.manual_seed(args.seed)

    if args.cuda:
        # Horovod: pin GPU to local rank.
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(args.seed)

    cudnn.benchmark = True

    # Horovod: print logs on the first worker.
    verbose = 1 if hvd.rank() == 0 else 0

    # Horovod: write TensorBoard logs on first worker.
    log_writer = SummaryWriter(args.log_dir) if hvd.rank() == 0 else None

    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(4)

    kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
    # When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to prevent
    # issues with Infiniband implementations that are not fork-safe
    if (kwargs.get('num_workers', 0) > 0 and hasattr(mp, '_supports_context') and
            mp._supports_context and 'forkserver' in mp.get_all_start_methods()):
        kwargs['multiprocessing_context'] = 'forkserver'

    train_dataset = \
        datasets.ImageFolder(args.train_dir,
                             transform=transforms.Compose([
                                 transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                      std=[0.229, 0.224, 0.225])
                             ]))
    # Elastic Horovod: use ElasticSampler to partition data among workers.
    train_sampler = hvd.elastic.ElasticSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=allreduce_batch_size,
        sampler=train_sampler,
        **kwargs)

    val_dataset = \
        datasets.ImageFolder(args.val_dir,
                             transform=transforms.Compose([
                                 transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                      std=[0.229, 0.224, 0.225])
                             ]))
    val_sampler = hvd.elastic.ElasticSampler(val_dataset)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.val_batch_size,
        sampler=val_sampler,
        **kwargs)

    # Set up standard ResNet-50 model.
    model = models.resnet50()

    # By default, Adasum doesn't need scaling up learning rate.
    # For sum/average with gradient accumulation: scale learning rate by batches_per_allreduce
    lr_scaler = args.batches_per_allreduce * hvd.size() if not args.use_adasum else 1

    if args.cuda:
        # Move model to GPU.
        model.cuda()
        # If using GPU Adasum allreduce, scale learning rate by local_size.
        if args.use_adasum and hvd.nccl_built():
            lr_scaler = args.batches_per_allreduce * hvd.local_size()

    # Horovod: scale learning rate by the number of GPUs.
    optimizer = optim.SGD(model.parameters(),
                          lr=(args.base_lr *
                              lr_scaler),
                          momentum=args.momentum, weight_decay=args.wd)

    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters(),
        compression=compression,
        backward_passes_per_step=args.batches_per_allreduce,
        op=hvd.Adasum if args.use_adasum else hvd.Average,
        gradient_predivide_factor=args.gradient_predivide_factor)

    # Restore from a previous checkpoint, if one is available.
    # Horovod: restore on the first worker, which will broadcast weights to other workers.
    resume_from_epoch = 0
    if hvd.rank() == 0:
        for try_epoch in range(args.epochs, 0, -1):
            if os.path.exists(args.checkpoint_format.format(epoch=try_epoch)):
                resume_from_epoch = try_epoch
                break

        if resume_from_epoch > 0:
            filepath = args.checkpoint_format.format(epoch=resume_from_epoch)
            checkpoint = torch.load(filepath)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])

    state = hvd.elastic.TorchState(model=model,
                                   optimizer=optimizer,
                                   train_sampler=train_sampler,
                                   val_sampler=val_sampler,
                                   epoch=resume_from_epoch,
                                   batch=0)

    full_train(state)
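For reference, an example like this is launched through horovodrun's elastic mode; the worker counts, discovery script, and script name below are deployment-specific placeholders, and the exact options may vary by Horovod version:

horovodrun -np 8 --min-np 4 --max-np 12 \
    --host-discovery-script ./discover_hosts.sh \
    python elastic_train.py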
