
Commit 1c66792

sarthakgarg authored and facebook-github-bot committed
Implementation of the paper "Jointly Learning to Align and Translate with Transformer Models" (facebookresearch#877)
Summary: Pull Request resolved: fairinternal/fairseq-py#877

This PR implements the guided alignment training described in "Jointly Learning to Align and Translate with Transformer Models" (https://arxiv.org/abs/1909.02074). In short, it trains selected attention heads of the Transformer model with external alignments computed by statistical alignment toolkits. During inference, the attention probabilities from the trained heads can be used to extract reliable alignments. In our experiments, guided alignment training caused no regression in translation quality.

Pull Request resolved: facebookresearch#1095
Differential Revision: D17170337
Pulled By: myleott
fbshipit-source-id: daa418bef70324d7088dbb30aa2adf9f95774859
1 parent acb6fba commit 1c66792

20 files changed: +899 / -61 lines
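As a rough illustration of the inference-time idea in the summary above (reading alignments off an alignment-supervised attention head), here is a minimal sketch; the function name, shapes, and values are illustrative assumptions, not the fairseq API.

```python
import torch

def extract_hard_alignment(attn, src_len, tgt_len):
    """Illustrative only: given a (tgt_len, src_len) matrix of attention
    probabilities from the supervised head(s), emit one 'src-tgt' pair per
    target position by taking the argmax over source positions."""
    best_src = attn[:tgt_len, :src_len].argmax(dim=1)
    return [f"{int(s)}-{t}" for t, s in enumerate(best_src)]

# Toy cross-attention for a 3-token target over a 3-token source.
attn = torch.tensor([[0.8, 0.1, 0.1],
                     [0.2, 0.7, 0.1],
                     [0.1, 0.2, 0.7]])
print(extract_hard_alignment(attn, src_len=3, tgt_len=3))  # ['0-0', '1-1', '2-2']
```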

README.md

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,7 @@ Fairseq provides reference implementations of various sequence-to-sequence model
 - [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
 - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
 - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
 - **Non-autoregressive Transformers**
   - Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
   - Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
@@ -100,6 +101,7 @@ as well as example training and evaluation commands.
 - [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available

 We also have more detailed READMEs to reproduce results from specific papers:
+- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
 - [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
 - [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
 - [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
examples/joint_alignment_translation/README.md

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@

# Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)

This page includes instructions for training models described in [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](https://arxiv.org/abs/1909.02074).

## Training a joint alignment-translation model on WMT'18 En-De

##### 1. Extract and preprocess the WMT'18 En-De data
```bash
./prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh
```

##### 2. Generate alignments with a statistical alignment toolkit, e.g. Giza++ or FastAlign.
In this example, we use FastAlign.
```bash
git clone git@github.com:clab/fast_align.git
pushd fast_align
mkdir build
cd build
cmake ..
make
popd
ALIGN=fast_align/build/fast_align
paste bpe.32k/train.en bpe.32k/train.de | awk -F '\t' '{print $1 " ||| " $2}' > bpe.32k/train.en-de
$ALIGN -i bpe.32k/train.en-de -d -o -v > bpe.32k/train.align
```

##### 3. Preprocess the dataset with the alignments generated above.
```bash
fairseq-preprocess \
    --source-lang en --target-lang de \
    --trainpref bpe.32k/train \
    --validpref bpe.32k/valid \
    --testpref bpe.32k/test \
    --align-suffix align \
    --destdir binarized/ \
    --joined-dictionary \
    --workers 32
```

##### 4. Train a model
```bash
fairseq-train \
    binarized \
    --arch transformer_wmt_en_de_big_align --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --activation-fn relu \
    --lr 0.0002 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
    --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
    --max-tokens 3500 --label-smoothing 0.1 \
    --save-dir ./checkpoints --log-interval 1000 --max-update 60000 \
    --keep-interval-updates -1 --save-interval-updates 0 \
    --load-alignments --criterion label_smoothed_cross_entropy_with_alignment \
    --fp16
```

Note that the `--fp16` flag requires that you have CUDA 9.1 or greater and a Volta GPU or newer.

If you want to train the above model with big batches (assuming your machine has 8 GPUs):
- add `--update-freq 8` to simulate training on 8x8=64 GPUs
- increase the learning rate; 0.0007 works well for big batches

##### 5. Evaluate and generate the alignments (BPE level)
```bash
fairseq-generate \
    binarized --gen-subset test --print-alignment \
    --source-lang en --target-lang de \
    --path checkpoints/checkpoint_best.pt --beam 5 --nbest 1
```

##### 6. Other resources.
The code for:
1. preparing alignment test sets
2. converting BPE-level alignments to token-level alignments
3. symmetrizing bidirectional alignments
4. evaluating alignments using the AER metric
can be found [here](https://github.com/lilt/alignment-scripts); a minimal sketch of item 2 appears below.
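The following is an illustrative sketch of item 2 only, assuming subword-nmt/fastBPE-style `@@` continuation markers and fast_align-style `i-j` (source-target) pairs; the scripts linked above are the reference implementation.

```python
def bpe_to_word_map(bpe_tokens):
    """Map each BPE position to the index of the word it belongs to,
    assuming '@@' marks a token that continues into the next one."""
    word_idx, mapping = 0, []
    for tok in bpe_tokens:
        mapping.append(word_idx)
        if not tok.endswith('@@'):
            word_idx += 1
    return mapping

def convert_alignment(bpe_align, src_bpe, tgt_bpe):
    """Collapse BPE-level 'i-j' pairs into unique word-level 'i-j' pairs."""
    src_map, tgt_map = bpe_to_word_map(src_bpe), bpe_to_word_map(tgt_bpe)
    pairs = {(src_map[int(i)], tgt_map[int(j)])
             for i, j in (p.split('-') for p in bpe_align.split())}
    return ' '.join(f'{i}-{j}' for i, j in sorted(pairs))

# 'Haus@@ katze' is one target word split into two BPE units.
print(convert_alignment('0-0 1-1', ['house', 'cat'], ['Haus@@', 'katze']))  # 0-0 1-0
```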
## Citation

```bibtex
@inproceedings{garg2019jointly,
  title = {Jointly Learning to Align and Translate with Transformer Models},
  author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias},
  booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
  address = {Hong Kong},
  month = {November},
  url = {https://arxiv.org/abs/1909.02074},
  year = {2019},
}
```
examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@

#!/bin/bash

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git

SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl

URLS=(
    "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
    "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
    "http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz"
    "http://data.statmt.org/wmt18/translation-task/rapid2016.tgz"
    "http://data.statmt.org/wmt17/translation-task/dev.tgz"
    "http://statmt.org/wmt14/test-full.tgz"
)
CORPORA=(
    "training/europarl-v7.de-en"
    "commoncrawl.de-en"
    "training-parallel-nc-v13/news-commentary-v13.de-en"
    "rapid2016.de-en"
)

if [ ! -d "$SCRIPTS" ]; then
    echo "Please set SCRIPTS variable correctly to point to Moses scripts."
    exit 1
fi

src=en
tgt=de
lang=en-de
prep=wmt18_en_de
tmp=$prep/tmp
orig=orig
dev=dev/newstest2012
codes=32000
bpe=bpe.32k

mkdir -p $orig $tmp $prep $bpe

cd $orig

for ((i=0;i<${#URLS[@]};++i)); do
    url=${URLS[i]}
    file=$(basename $url)
    if [ -f $file ]; then
        echo "$file already exists, skipping download"
    else
        wget "$url"
        if [ -f $file ]; then
            echo "$url successfully downloaded."
        else
            echo "$url not successfully downloaded."
            exit 1
        fi
        if [ ${file: -4} == ".tgz" ]; then
            tar zxvf $file
        elif [ ${file: -4} == ".tar" ]; then
            tar xvf $file
        fi
    fi
done
cd ..

echo "pre-processing train data..."
for l in $src $tgt; do
    rm -rf $tmp/train.tags.$lang.tok.$l
    for f in "${CORPORA[@]}"; do
        cat $orig/$f.$l | \
            perl $REM_NON_PRINT_CHAR | \
            perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/train.tags.$lang.tok.$l
    done
done

echo "pre-processing test data..."
for l in $src $tgt; do
    if [ "$l" == "$src" ]; then
        t="src"
    else
        t="ref"
    fi
    grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
        sed -e 's/<seg id="[0-9]*">\s*//g' | \
        sed -e 's/\s*<\/seg>\s*//g' | \
        sed -e "s/\’/\'/g" | \
        perl $TOKENIZER -threads 8 -l $l -no-escape > $tmp/test.$l
    echo ""
done

# apply length filtering before BPE
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train 1 100

# use newstest2012 for valid
echo "pre-processing valid data..."
for l in $src $tgt; do
    rm -rf $tmp/valid.$l
    cat $orig/$dev.$l | \
        perl $REM_NON_PRINT_CHAR | \
        perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/valid.$l
done

mkdir output
mv $tmp/{train,valid,test}.{$src,$tgt} output

# BPE
git clone git@github.com:glample/fastBPE.git
pushd fastBPE
g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
popd
fastBPE/fast learnbpe $codes output/train.$src output/train.$tgt > $bpe/codes
for split in {train,valid,test}; do for lang in {en,de}; do fastBPE/fast applybpe $bpe/$split.$lang output/$split.$lang $bpe/codes; done; done

fairseq/binarizer.py

Lines changed: 16 additions & 0 deletions
@@ -52,6 +52,22 @@ def replaced_consumer(word, idx):
                 line = f.readline()
         return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}

+    @staticmethod
+    def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1):
+        nseq = 0
+
+        with open(filename, 'r') as f:
+            f.seek(offset)
+            line = safe_readline(f)
+            while line:
+                if end > 0 and f.tell() > end:
+                    break
+                ids = alignment_parser(line)
+                nseq += 1
+                consumer(ids)
+                line = f.readline()
+        return {'nseq': nseq}
+
     @staticmethod
     def find_offsets(filename, num_chunks):
         with open(filename, 'r', encoding='utf-8') as f:
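A minimal usage sketch (not part of this diff) of how `binarize_alignments` can be driven; the `parse_alignment` helper and the consumer below are hypothetical stand-ins for whatever parser and dataset writer the caller supplies, and it is assumed the new method lands on the `Binarizer` class in `fairseq.binarizer` with fast_align-style `i-j` alignment lines as input.

```python
import torch

from fairseq.binarizer import Binarizer  # assumed location of the new static method


def parse_alignment(line):
    """Hypothetical parser: turn a line such as '0-0 1-2 2-1' into an
    N x 2 IntTensor of (source, target) index pairs."""
    pairs = [pair.split('-') for pair in line.split()]
    return torch.IntTensor([[int(src), int(tgt)] for src, tgt in pairs])


collected = []


def consumer(ids):
    # fairseq-preprocess would append each tensor to an indexed dataset;
    # for illustration we simply collect the parsed tensors in memory.
    collected.append(ids)


# Stream the whole file (offset/end select a byte range when preprocessing
# with multiple workers) and report how many sequences were consumed.
stats = Binarizer.binarize_alignments('bpe.32k/train.align', parse_alignment, consumer)
print(stats['nseq'], len(collected))
```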
fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

from fairseq import utils

from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
from . import register_criterion


@register_criterion('label_smoothed_cross_entropy_with_alignment')
class LabelSmoothedCrossEntropyCriterionWithAlignment(LabelSmoothedCrossEntropyCriterion):

    def __init__(self, args, task):
        super().__init__(args, task)
        self.alignment_lambda = args.alignment_lambda

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        super(LabelSmoothedCrossEntropyCriterionWithAlignment,
              LabelSmoothedCrossEntropyCriterionWithAlignment).add_args(parser)
        parser.add_argument('--alignment-lambda', default=0.05, type=float, metavar='D',
                            help='weight for the alignment loss')

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }

        alignment_loss = None

        # Compute alignment loss only for training set and non-dummy batches.
        if 'alignments' in sample and sample['alignments'] is not None:
            alignment_loss = self.compute_alignment_loss(sample, net_output)

        if alignment_loss is not None:
            logging_output['alignment_loss'] = utils.item(alignment_loss.data)
            loss += self.alignment_lambda * alignment_loss

        return loss, sample_size, logging_output

    def compute_alignment_loss(self, sample, net_output):
        attn_prob = net_output[1]['attn']
        bsz, tgt_sz, src_sz = attn_prob.shape
        attn = attn_prob.view(bsz * tgt_sz, src_sz)

        align = sample['alignments']
        align_weights = sample['align_weights'].float()

        if len(align) > 0:
            # Alignment loss computation. align (shape [:, 2]) contains the src-tgt index
            # pairs corresponding to the alignments. align_weights (shape [:]) contains the
            # 1 / frequency of a tgt index for normalizing.
            loss = -((attn[align[:, 1][:, None], align[:, 0][:, None]]).log() * align_weights[:, None]).sum()
        else:
            return None

        return loss

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        return {
            'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
            'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2) if ntokens > 0 else 0.,
            'alignment_loss': sum(log.get('alignment_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
        }
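To make the indexing in `compute_alignment_loss` concrete, here is a small self-contained sketch (not part of the commit) that reproduces the same expression on a toy batch; all shapes and values are invented for illustration.

```python
import torch

# Toy attention from the supervised head(s), already flattened to
# (bsz * tgt_len, src_len) as in compute_alignment_loss (bsz=1, tgt_len=2, src_len=3).
attn = torch.tensor([[0.7, 0.2, 0.1],
                     [0.1, 0.8, 0.1]])

# Each row of `align` is a (src_index, tgt_index) pair; `align_weights`
# holds 1 / frequency of each target index, used for normalization.
align = torch.tensor([[0, 0],   # source 0 aligned to target 0
                      [1, 1]])  # source 1 aligned to target 1
align_weights = torch.tensor([1.0, 1.0])

# Same expression as the criterion: gather p(src j | tgt i) for each pair,
# take the negative log, weight it, and sum.
loss = -((attn[align[:, 1][:, None], align[:, 0][:, None]]).log()
         * align_weights[:, None]).sum()
print(loss)  # -(log 0.7 + log 0.8) ≈ 0.5798
```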
