
Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories.

base repository: kaldi-asr/kaldi
base: 5ccb4565782175db0e6d6650a47104e86282e8db
head repository: kaldi-asr/kaldi
compare: cd351bb31c98f9d540c409478cbf2c5fef1853ca
Showing with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions egs/aishell/s10/chain/train.py
@@ -34,7 +34,7 @@
 from libs.nnet3.train.dropout_schedule import _get_dropout_proportions
 from model import get_chain_model
 from options import get_args
-from sgd_max_change import SgdMaxChange
+#from sgd_max_change import SgdMaxChange
 
 def get_objf(batch, model, device, criterion, opts, den_graph, training, optimizer=None, dropout=0.):
     feature, supervision = batch
@@ -68,20 +68,20 @@ def get_objf(batch, model, device, criterion, opts, den_graph, training, optimiz
                                             supervision, nnet_output,
                                             xent_output)
     objf = objf_l2_term_weight[0]
-    change = 0
     if training:
         optimizer.zero_grad()
         objf.backward()
-        # clip_grad_value_(model.parameters(), 5.0)
-        _, change = optimizer.step()
+        clip_grad_value_(model.parameters(), 5.0)
+        optimizer.step()
 
     objf_l2_term_weight = objf_l2_term_weight.detach().cpu()
 
     total_objf = objf_l2_term_weight[0].item()
     total_weight = objf_l2_term_weight[2].item()
     total_frames = nnet_output.shape[0]
 
-    return total_objf, total_weight, total_frames, change
+    return total_objf, total_weight, total_frames
 
 
 def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
     total_objf = 0.
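
For reference, the net effect of this hunk on the training branch is that gradients are value-clipped again and a plain optimizer step is used, so no max-change statistic is computed or returned. A minimal sketch of what clip_grad_value_ does (the parameter and gradient below are illustrative, not taken from train.py):

    import torch
    from torch.nn.utils import clip_grad_value_

    w = torch.nn.Parameter(torch.zeros(3))
    w.grad = torch.tensor([-12.0, 0.5, 7.0])  # pretend backward() produced these gradients

    clip_grad_value_(w, 5.0)  # clamps every gradient element into [-5.0, 5.0] in place
    print(w.grad)             # tensor([-5.0000,  0.5000,  5.0000])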
@@ -91,7 +91,7 @@ def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
     model.eval()
 
     for batch_idx, (pseudo_epoch, batch) in enumerate(dataloader):
-        objf, weight, frames, _ = get_objf(
+        objf, weight, frames = get_objf(
             batch, model, device, criterion, opts, den_graph, False)
         total_objf += objf
         total_weight += weight
@@ -117,7 +117,7 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
                          len(dataloader)) / (len(dataloader) * num_epochs)
         _, dropout = _get_dropout_proportions(
             dropout_schedule, data_fraction)[0]
-        curr_batch_objf, curr_batch_weight, curr_batch_frames, curr_batch_change = get_objf(
+        curr_batch_objf, curr_batch_weight, curr_batch_frames = get_objf(
             batch, model, device, criterion, opts, den_graph, True, optimizer, dropout=dropout)
 
         total_objf += curr_batch_objf
@@ -128,13 +128,13 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
             logging.info(
                 'Device ({}) processing batch {}, current pseudo-epoch is {}/{}({:.6f}%), '
                 'global average objf: {:.6f} over {} '
-                'frames, current batch average objf: {:.6f} over {} frames, minibatch change: {:.6f}, epoch {}'
+                'frames, current batch average objf: {:.6f} over {} frames, epoch {}'
                 .format(
                     device.index, batch_idx, pseudo_epoch, len(dataloader),
                     float(pseudo_epoch) / len(dataloader) * 100,
                     total_objf / total_weight, total_frames,
                     curr_batch_objf / curr_batch_weight,
-                    curr_batch_frames, curr_batch_change, current_epoch))
+                    curr_batch_frames, current_epoch))
 
         if valid_dataloader and batch_idx % 1000 == 0:
             total_valid_objf, total_valid_weight, total_valid_frames = get_validation_objf(
@@ -168,11 +168,6 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
                 dropout,
                 pseudo_epoch + current_epoch * len(dataloader))
 
-            tf_writer.add_scalar(
-                'train/current_batch_change',
-                curr_batch_change,
-                pseudo_epoch + current_epoch * len(dataloader))
-
             state_dict = model.state_dict()
             for key, value in state_dict.items():
                 # skip batchnorm parameters
@@ -307,7 +302,8 @@ def process_job(learning_rate, device_id=None, local_rank=None):
     else:
         valid_dataloader = None
 
-    optimizer = SgdMaxChange(model.parameters(),
+    #optimizer = SgdMaxChange(model.parameters(),
+    optimizer = optim.Adam(model.parameters(),
                            lr=learning_rate,
                            weight_decay=5e-4)
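
Taken together, the change swaps the experimental SgdMaxChange optimizer for a stock Adam optimizer with L2 weight decay and relies on gradient value clipping rather than the max-change mechanism. Below is a minimal, self-contained sketch of the new setup with a single update step; the model, minibatch, and learning rate are placeholders (in train.py the learning rate is passed into process_job), and only the optimizer arguments and the clip value come from the diff:

    import torch
    import torch.optim as optim
    from torch.nn.utils import clip_grad_value_

    model = torch.nn.Linear(40, 10)   # placeholder for the chain model
    learning_rate = 1e-3              # placeholder; train.py receives this as an argument

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=5e-4)  # same hyperparameters as in the diff

    feature = torch.randn(16, 40)              # placeholder minibatch
    objf = model(feature).pow(2).mean()        # stand-in for the chain objective

    optimizer.zero_grad()
    objf.backward()
    clip_grad_value_(model.parameters(), 5.0)  # re-enabled gradient value clipping
    optimizer.step()                           # plain Adam step; no max-change statistic returned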