Bare bones GenerationModule #324
Conversation
max_sequence_length: Optional[int] = None,
rank_microbatch_size: Optional[int] = None,
These actually are optional and if provided are used to warm up the RoPE cache (max_seq_len) and an MoE component (rank_microbatch_size).
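For context, a minimal sketch of what that warm-up could look like; the `build_rope_cache` hook name here is hypothetical, not olmo-core's actual API:

```python
from typing import Optional

import torch.nn as nn


def maybe_warm_up(model: nn.Module, max_sequence_length: Optional[int] = None) -> None:
    # Hypothetical hook: if the max sequence length is known up front, build the
    # RoPE cache once so the first generate() call doesn't pay the cost.
    if max_sequence_length is not None and hasattr(model, "build_rope_cache"):
        model.build_rope_cache(max_sequence_length)
```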
This is a great start. I just have a few comments so far.
Nit: maybe rename this to generation_module.py to be more specific. Also consider moving the transformer implementation to its own submodule.
src/olmo_core/generate/generation.py
from olmo_core.generate.config import GenerationConfig
from olmo_core.generate.selection import temperature_sampling
Nit: relative imports are nice to have at least within the same submodule.
Suggested change:
- from olmo_core.generate.config import GenerationConfig
- from olmo_core.generate.selection import temperature_sampling
+ from .config import GenerationConfig
+ from .selection import temperature_sampling
Args:
    checkpoint_dir: Path to checkpoint directory
    work_dir: Working directory for caching remote checkpoints
    process_group: Process group for distributed loading
    pre_download: Whether to pre-download remote checkpoints
    load_thread_count: Number of threads to use for loading the checkpoint

Raises:
    FileNotFoundError: If checkpoint directory doesn't exist
    RuntimeError: If checkpoint loading fails
Nit: this docstring syntax is inconsistent with our other docstrings.
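For example, assuming the Sphinx-style :param:/:raises: directives used elsewhere in the repo (the signature and defaults below are guesses for illustration only):

```python
from typing import Optional

import torch.distributed as dist


def load_checkpoint(
    checkpoint_dir: str,
    work_dir: Optional[str] = None,
    process_group: Optional[dist.ProcessGroup] = None,
    pre_download: bool = False,
    load_thread_count: Optional[int] = None,
) -> None:
    """
    Load model state from a (possibly remote) checkpoint directory.

    :param checkpoint_dir: Path to checkpoint directory.
    :param work_dir: Working directory for caching remote checkpoints.
    :param process_group: Process group for distributed loading.
    :param pre_download: Whether to pre-download remote checkpoints.
    :param load_thread_count: Number of threads to use for loading the checkpoint.

    :raises FileNotFoundError: If checkpoint directory doesn't exist.
    :raises RuntimeError: If checkpoint loading fails.
    """
    ...
```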
work_dir = Path(
    work_dir or (tempfile.mkdtemp() if get_rank(process_group) == 0 else "/tmp")
)
This assumes all ranks share the filesystem, which is usually only true for single-node jobs.
Alternatively just force the user to provide a work dir.
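A minimal sketch of that alternative (hypothetical helper, not part of this PR): require an explicit work_dir whenever more than one rank is involved, and only fall back to a temp dir for single-process runs.

```python
import tempfile
from pathlib import Path
from typing import Optional

import torch.distributed as dist


def resolve_work_dir(work_dir: Optional[str] = None) -> Path:
    if work_dir is not None:
        return Path(work_dir)
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        # No shared-filesystem assumption: make the caller pick a directory
        # that is valid on every node.
        raise ValueError("work_dir must be provided when running with multiple ranks")
    # Single-process case: a local temp dir is safe.
    return Path(tempfile.mkdtemp())
```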
Add a simple GenerationModule (analogous to TrainModule) that can be used to configure and run autoregressive next-token prediction.

As of now this supports:
1. Loading distributed checkpoints exactly as they were saved during training.
2. A temperature parameter for generation.
3. Using FSDP to shard larger models across multiple devices (mostly as a demonstration of how other types of parallelism can be worked in).
4. Attention masks passed through to SDPA so that batched generation with left-padding is supported (see the sketch below).

Note that this implementation is very inefficient compared to transformers or vllm, in part due to the lack of KV caching.
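To illustrate point 4, here is a self-contained sketch in plain PyTorch (not this PR's API) of combining a causal mask with a left-padding mask before passing it to SDPA:

```python
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 4, 6, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# True = real token; the second sequence has two left-padding positions.
pad_mask = torch.tensor([[True] * 6, [False, False] + [True] * 4])

# Combine the causal mask with the padding mask into a boolean mask of shape
# (batch, 1, seq, seq) that F.scaled_dot_product_attention accepts.
causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool))
attn_mask = causal[None, None, :, :] & pad_mask[:, None, None, :]

# Let padding queries attend to themselves so no row is fully masked (which
# would produce NaNs); outputs at padding positions are discarded anyway.
attn_mask = attn_mask | torch.eye(seq, dtype=torch.bool)[None, None, :, :]

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 6, 8])
```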