Commit ec38740 ("first commit", 1 parent: ea6ed7a)

688 files changed: +130805 −89 lines


.gitignore

Lines changed: 6 additions & 1 deletion

@@ -136,4 +136,9 @@ dmypy.json
 
 # Project files
 sketch*
-iclr2025_mbr_uncertainty/sketch*
+iclr2025_mbr_uncertainty/sketch*
+
+alias/
+output/
+work/
+*wandb*

NOTICE.txt

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 -------------------------------------------------------------------------------
-Copyright 2024
+Copyright 2025
 Ubiquitous Knowledge Processing (UKP) Lab
 Technische Universität Darmstadt
 

README.md

Lines changed: 80 additions & 63 deletions

@@ -1,112 +1,129 @@
-<p align="center">
-  <img src='logo.png' width='200'>
-</p>
+# Uncertainty-Aware Decoding with Minimum Bayes' Risk - ICLR 2025
 
-# iclr2025_mbr_uncertainty
-[![Arxiv](https://img.shields.io/badge/Arxiv-YYMM.NNNNN-red?style=flat-square&logo=arxiv&logoColor=white)](https://put-here-your-paper.com)
+[![Arxiv](https://img.shields.io/badge/Arxiv-YYMM.NNNNN-red?style=flat-square&logo=arxiv&logoColor=white)](https://arxiv.org/search/cs?searchtype=author&query=Daheim,+N)
 [![License](https://img.shields.io/github/license/UKPLab/iclr2025-mbr-uncertainty)](https://opensource.org/licenses/Apache-2.0)
-[![Python Versions](https://img.shields.io/badge/Python-3.9-blue.svg?style=flat&logo=python&logoColor=white)](https://www.python.org/)
-[![CI](https://github.com/UKPLab/iclr2025-mbr-uncertainty/actions/workflows/main.yml/badge.svg)](https://github.com/UKPLab/iclr2025-mbr-uncertainty/actions/workflows/main.yml)
 
-This is the official template for new Python projects at UKP Lab. It was adapted for the needs of UKP Lab from the excellent [python-project-template](https://github.com/rochacbruno/python-project-template/) by [rochacbruno](https://github.com/rochacbruno).
+This is the repository for ``Uncertainty-Aware Decoding with Minimum Bayes' Risk'' (ICLR 2025). The repo template is adapted from [python-project-template](https://github.com/rochacbruno/python-project-template/) by [rochacbruno](https://github.com/rochacbruno).
 
-It should help you start your project and give you continuous status updates on the development through [GitHub Actions](https://docs.github.com/en/actions).
 
-> **Abstract:** The study of natural language processing (NLP) has gained increasing importance in recent years, with applications ranging from machine translation to sentiment analysis. Properly managing Python projects in this domain is of paramount importance to ensure reproducibility and facilitate collaboration. The template provides a structured starting point for projects and offers continuous status updates on development through GitHub Actions. Key features include a basic setup.py file for installation, packaging, and distribution, documentation structure using mkdocs, testing structure using pytest, code linting with pylint, and entry points for executing the program with basic CLI argument parsing. Additionally, the template incorporates continuous integration using GitHub Actions with jobs to check, lint, and test the project, ensuring robustness and reliability throughout the development process.
 
-Contact person: [Federico Tiblias](mailto:[email protected])
+> **Abstract:** Despite their outstanding performance in the majority of scenarios, contemporary language models still occasionally generate undesirable outputs, for example, hallucinated text. While such behaviors have previously been linked to uncertainty, there is a notable lack of methods that actively consider uncertainty during text generation. In this work, we show how Minimum Bayes’ Risk (MBR) decoding, which selects model generations according to an expected risk, can be generalized into a principled uncertainty-aware decoding method. In short, we account for model uncertainty during decoding by incorporating a posterior over model parameters into MBR’s computation of expected risk. We show that this modified expected risk is useful for both choosing outputs and deciding when to abstain from generation and can provide improvements without incurring overhead. We benchmark different methods for learning posteriors and show that performance improves with prediction diversity.
+
+Contact person: [Nico Daheim](mailto:[email protected])
 
 [UKP Lab](https://www.ukp.tu-darmstadt.de/) | [TU Darmstadt](https://www.tu-darmstadt.de/)
 
 Don't hesitate to send us an e-mail or report an issue, if something is broken (and it shouldn't be) or if you have further questions.
 
-
 ## Getting Started
 
-> **DO NOT CLONE OR FORK**
+The repository contains code to run uncertainty-aware MBR, as well as to train models using [huggingface transformers](https://github.com/huggingface/transformers) and [fairseq](https://github.com/facebookresearch/fairseq).
+Both have been adapted in this repository to accommodate training with variational learning using the [IVON optimizer](https://openreview.net/forum?id=cXBv07GKvk), for which we use [the official implementation](https://github.com/team-approx-bayes/ivon). Our MBR implementation is based on the [implementation](https://github.com/deep-spin/qaware-decode) of [``Quality-Aware Decoding for Neural Machine Translation''](https://aclanthology.org/2022.naacl-main.100.pdf).
 
-If you want to set up this template:
+If you are only interested in experiments based on huggingface, then running
+```
+pip install -r requirements.txt
+```
+will install all necessary packages.
 
-1. Request a repository on UKP Lab's GitHub by following the standard procedure on the wiki. It will install the template directly. Alternatively, set it up in your personal GitHub account by clicking **[Use this template](https://github.com/rochacbruno/python-project-template/generate)**.
-2. Wait until the first run of CI finishes. Github Actions will commit to your new repo with a "✅ Ready to clone and code" message.
-3. Delete optional files:
-   - If you don't need automatic documentation generation, you can delete folder `docs`, file `.github\workflows\docs.yml` and `mkdocs.yml`
-   - If you don't want automatic testing, you can delete folder `tests` and file `.github\workflows\tests.yml`
-4. Prepare a virtual environment:
-```bash
-python -m venv .venv
-source .venv/bin/activate
-pip install .
-pip install -r requirements-dev.txt # Only needed for development
+When using fairseq, the following has to be run in addition:
+```
+cd fairseq/
+pip install --editable ./
 ```
-5. Adapt anything else (for example this file) to your project.
 
-6. Read the file [ABOUT_THIS_TEMPLATE.md](ABOUT_THIS_TEMPLATE.md) for more information about development.
+The experiments of the paper were organized using the workflow manager [Sisyphus](https://github.com/rwth-i6/sisyphus). If you would like to make use of it, too, then please run:
+```
+git clone git@github.com:rwth-i6/sisyphus.git
+cd sisyphus/
+pip install -r requirements.txt
+cd ..
+mkdir alias
+mkdir output
+mkdir work
+```
+Sisyphus will use the directories as follows:
+1. `alias`: It's possible to define an alias for each job to identify it quickly (by default, a hash is appended to the job class name as an identifier), and Sisyphus adds a symlink to the job under the alias.
+2. `output`: `tk.register_output("name", job_class.file)` registers an output under the filename `name` in the output folder that symlinks to `job_class.file`.
+3. `work`: All jobs will be placed here under their hash.
 
 ## Usage
 
-### Using the classes
+### Running experiments using Sisyphus
 
-To import classes/methods of `iclr2025_mbr_uncertainty` from inside the package itself you can use relative imports:
+Examples for training with Sisyphus are found in the `config/` folder.
+Running either training using huggingface or fairseq on a Slurm cluster only requires
+```
+cd iclr2025-mbr-uncertainty
+sisyphus/sis --config config config/huggingface.py
+```
+or
+```
+cd iclr2025-mbr-uncertainty
+sisyphus/sis --config config config/fairseq.py
+```
 
-```py
-from .base import BaseClass # Notice how I omit the package name
+The examples will run finetuning with LoRA of GEMMA-2B on IWSLT17 and a from-scratch training of a Transformer-base model on IWSLT14, respectively.
+The examples also show how to use our sequence-level MBR methods and single-model MBR baselines.
+Token-level posteriors can be used easily in fairseq according to the documentation in their [repo](https://github.com/facebookresearch/fairseq).
 
-BaseClass().something()
-```
+For each part of the pipeline, Sisyphus Jobs are defined that wrap python scripts for training, decoding, MBR, and evaluation.
 
-To import classes/methods from outside the package (e.g. when you want to use the package in some other project) you can instead refer to the package name:
+### Using scripts
 
-```py
-from iclr2025_mbr_uncertainty import BaseClass # Notice how I omit the file name
-from iclr2025_mbr_uncertainty.subpackage import SubPackageClass # Here it's necessary because it's a subpackage
+For training, there is an example configuration file in `scripts`. The file can be invoked via:
 
-BaseClass().something()
-SubPackageClass().something()
 ```
+cd huggingface/code/
+python3 train.py ../../scripts/train_config.json
+```
+The example will run a similar training to the config in `config/huggingface.py` and train GEMMA-2B on IWSLT17 using LoRA.
 
-### Using scripts
+Similarly, decoding can be run by
+```
+cd huggingface/code/
+python3 predict.py ../../scripts/search_config.json
+```
+Here, the config files describe all relevant parameters, such as the model to be used, the dataset, a prompt, the random seed, whether to sample during decoding, etc.
 
-This is how you can use `iclr2025_mbr_uncertainty` from command line:
+For MBR, the script in `mbr/mbr.py` can be used.
+The arguments follow the implementation from [`qaware-decode`](https://github.com/deep-spin/qaware-decode), but there are two important changes:
 
-```bash
-$ python -m iclr2025_mbr_uncertainty
-```
+The first argument can be the path to a single prediction file, but also a semicolon-separated concatenation of multiple paths to prediction files, to perform uncertainty-aware decoding via model combination.
 
-### Expected results
+Then, using `--flatten` concatenates all these hypothesis sets for each sample, i.e. performs Eq. 9, while not passing the argument will calculate utilities individually for each sample and then sum them, i.e. perform Eq. 10.
 
-After running the experiments, you should expect the following results:
+For evaluation, the script in `huggingface/code/evaluation.py` can be used. Besides predictions, hypothesis set size, etc., the argument `eval_task` has to be passed, which selects the metrics for the given task, for example, ROUGE for summarization.
 
-(Feel free to describe your expected results here...)
+### Expected results
 
-### Parameter description
+After running the jobs in `config/huggingface.py`, the results should closely match our MBR results on IWSLT17 using GEMMA-2B in Table 1, where we average over 4 seeds.
 
-* `x, --xxxx`: This parameter does something nice
 
-* ...
+### Code Structure
 
-* `z, --zzzz`: This parameter does something even nicer
+The code is mainly based on the concept of ``methods'' that are found in the `/code/mbr/methods/` folder, which wrap all of the functionality needed to reproduce a certain method:
+1. Defining and loading Trainer and Data Collator classes
+2. Loading all datasets
+3. Defining and applying the preprocessing methods, defined in `/code/mbr/methods/preprocessing`
 
-## Development
+To understand how the method classes are structured, it's best to check `code/mbr/methods/base.py`, which defines a base class from which all methods inherit.
 
-Read the FAQs in [ABOUT_THIS_TEMPLATE.md](ABOUT_THIS_TEMPLATE.md) to learn more about how this template works and where you should put your classes & methods. Make sure you've correctly installed `requirements-dev.txt` dependencies
+The main entry point for the code is `/code/mbr/main.py`, which handles loading method classes, models, and running the Trainers.
 
 ## Cite
 
 Please use the following citation:
 
 ```
-@InProceedings{smith:20xx:CONFERENCE_TITLE,
-  author    = {Smith, John},
-  title     = {My Paper Title},
-  booktitle = {Proceedings of the 20XX Conference on XXXX},
-  month     = mmm,
-  year      = {20xx},
-  address   = {Gotham City, USA},
-  publisher = {Association for XXX},
-  pages     = {XXXX--XXXX},
-  url       = {http://xxxx.xxx}
+@inproceedings{
+daheim2025uncertaintyaware,
+title={Uncertainty-Aware Decoding with Minimum Bayes' Risk},
+author={Nico Daheim and Clara Meister and Thomas M{\"o}llenhoff and Iryna Gurevych},
+booktitle={The Thirteenth International Conference on Learning Representations},
+year={2025},
+url={https://openreview.net/forum?id=hPpyUv1XyQ}
 }
 ```
 
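The `--flatten` behavior that the README describes for `mbr/mbr.py` (pool all models' hypothesis sets vs. sum per-model utilities) can be sketched as below. This is an illustrative reading only, not the repo's implementation: `mbr_select` and the `overlap` utility are hypothetical names, and a real setup would use a metric such as BLEU or COMET as the utility.

```python
from itertools import chain

def overlap(a, b):
    """Toy utility: Jaccard overlap of unigrams (stand-in for BLEU/COMET)."""
    ta, tb = set(a.split()), set(b.split())
    return len(ta & tb) / max(len(ta | tb), 1)

def mbr_select(hypothesis_sets, utility, flatten=False):
    """Choose the hypothesis with the highest expected utility.

    hypothesis_sets: one hypothesis list per model / posterior sample.
    flatten=True  -> pool all sets into a single support set (the
                     concatenated reading of the --flatten flag).
    flatten=False -> score each candidate against every model's set
                     separately and sum the per-set expected utilities.
    """
    candidates = list(chain.from_iterable(hypothesis_sets))
    support_sets = [candidates] if flatten else hypothesis_sets

    def score(h):
        # Sum of average utilities against each support set.
        return sum(sum(utility(h, r) for r in s) / len(s) for s in support_sets)

    return max(candidates, key=score)
```

With two models' hypothesis sets, e.g. `mbr_select([["the cat sat", "the cat sat down"], ["a cat sat", "the cat sat"]], overlap, flatten=True)`, the hypothesis that most sets agree on wins under either setting.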

config/fairseq.py

Lines changed: 113 additions & 0 deletions (new file)

import os
import sys

import numpy as np

sys.setrecursionlimit(2500)

# ------------------------------ Sisyphus -------------------------------------

import sisyphus.toolkit as tk
from ukp.fairseq.evaluation import FairseqEvalJob
from ukp.fairseq.search import FairseqSearchJob, MBRJob
from ukp.fairseq.training import FairseqTrainingJob

Path = tk.Path

async def ivon_iwslt14_base():
    for mc_samples in [2]:
        for optimizer in ["ivon"]:
            config = [
                "/path/to/data-bin/iwslt14.tokenized.de-en",  # binarized data for iwslt14
                "--arch", "transformer",
                "--share-decoder-input-output-embed",
                "--optimizer", optimizer,
                "--clip-norm", "1.0",
                "--lr", "0.15",
                "--lr-scheduler", "inverse_sqrt",
                "--warmup-updates", "4000",
                "--clip-radius", "0.001",
                "--dropout", "0.2",
                "--weight-decay", "0.0001",
                "--criterion", "cross_entropy",
                "--max-tokens", "4096",
                "--eval-bleu",
                "--eval-bleu-args", '{\"beam\": 5, \"max_len_a\": 1.2, \"max_len_b\": 10}',
                "--eval-bleu-detok", "moses",
                "--eval-bleu-remove-bpe",
                "--best-checkpoint-metric", "bleu",
                "--maximize-best-checkpoint-metric",
                "--patience", "3",
                "--batch-size", "1024",
                "--ess", "1e8",
                "--hess-init", "0.1",
                "--seed", "1",
                "--mc-samples", str(mc_samples),
            ]

            train_job = FairseqTrainingJob(
                config
            )
            train_job.add_alias(f"{optimizer}_comparison_mc_{mc_samples}")
            tk.register_output(f"fairseq_example/{optimizer}_ivon_iwslt14_trafo_base_{mc_samples}", train_job.out_checkpoints_dir)

            model_path = os.path.join(train_job.out_checkpoints_dir.get_path(), "checkpoint_best.pt")

            config = [
                "/path/to/data-bin/iwslt14.tokenized.de-en",  # binarized data for iwslt14
                "--path", model_path,
                "--batch-size", "128",
                "--beam", "4",
                "--nbest", "4",
                "--remove-bpe",
                "--sampling",
                "--lenpen", "0.6",
                "--sample-params",
                "--num-mc-samples", "1"
            ]

            search_job = FairseqSearchJob(config, train_job.out_checkpoints_dir)
            tk.register_output(f"example/{optimizer}_iwslt14_trafo_base_out.txt", search_job.out_hyp_file)

            config = [
                "/path/to/data-bin/iwslt14.tokenized.de-en",  # binarized data for iwslt14
                "--path", model_path,
                "--batch-size", "128",
                "--beam", "4",
                "--nbest", "4",
                "--remove-bpe",
                "--sampling",
                "--lenpen", "0.6",
                "--sample-params",
                "--num-mc-samples", "1"
            ]

            search_job = FairseqSearchJob(config, train_job.out_checkpoints_dir)
            tk.register_output(f"example/{optimizer}_iwslt14_trafo_base_out.txt", search_job.out_hyp_file)

            eval_job = FairseqEvalJob(
                search_job.out_hyp_file,
                search_job.out_tgt_file,
                nbest=4
            )
            tk.register_output(f"example/{optimizer}ivon_iwslt14_trafo_base_out.metrics.json", eval_job.out_metrics_file)

            mbr_job = MBRJob(
                search_job.out_hyp_file,
                search_job.out_src_file,
                4
            )
            tk.register_output(f"example/{optimizer}ivon_iwslt14_trafo_base_out.mbr.txt", mbr_job.out_hyp_file)

            eval_job = FairseqEvalJob(
                mbr_job.out_hyp_file,
                search_job.out_tgt_file,
                nbest=1
            )
            tk.register_output(f"example/{optimizer}ivon_iwslt14_trafo_base_out.mbr.metrics.json", eval_job.out_metrics_file)

async def async_main():
    await ivon_iwslt14_base()

async def py():
    await async_main()
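The pipeline above selects the MBR hypothesis by expected risk; the paper's abstract additionally notes that this expected risk can be used to decide when to abstain from generation. A minimal sketch of such a decision rule, under stated assumptions: the `decode_or_abstain` helper, the Jaccard `overlap` utility, and the `risk_threshold` value are all illustrative, not part of the repo.

```python
def overlap(a, b):
    """Toy utility: Jaccard overlap of unigrams (stand-in for a real metric)."""
    ta, tb = set(a.split()), set(b.split())
    return len(ta & tb) / max(len(ta | tb), 1)

def expected_risk(candidate, hypotheses, utility):
    """1 - average utility of `candidate` against the hypothesis set
    (the candidate itself is part of the support, as in sampling-based MBR)."""
    return 1.0 - sum(utility(candidate, h) for h in hypotheses) / len(hypotheses)

def decode_or_abstain(hypotheses, utility, risk_threshold=0.5):
    """Return (best_hypothesis, risk), or (None, risk) to abstain when even
    the lowest expected risk stays above the threshold."""
    best = min(hypotheses, key=lambda h: expected_risk(h, hypotheses, utility))
    risk = expected_risk(best, hypotheses, utility)
    return (best, risk) if risk <= risk_threshold else (None, risk)
```

The intuition: when sampled hypotheses agree, the best candidate's expected risk is low and it is returned; when they disagree wildly, every candidate has high expected risk and the decoder abstains.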
