
[Feature] add TensorBoard support to the training framework #811

Open · wants to merge 6 commits into main
5 changes: 5 additions & 0 deletions olmo/config.py
@@ -1200,6 +1200,11 @@ class TrainConfig(BaseConfig):
    Weights & Biases configuration.
    """

    tensorboard_path: Optional[str] = None
    """
    Path to the TensorBoard log output directory.
    """

    speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig)
    """
    Speed monitor configuration.
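Since `tensorboard_path` is just another optional field on `TrainConfig`, a YAML training config would presumably enable it with a single line; the path below is a placeholder, not something from this PR:

tensorboard_path: /results/tb_logs/my_run

Leaving the field unset (the default `None`) keeps current behavior, since the trainer only constructs the writer when the path is set.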
43 changes: 43 additions & 0 deletions olmo/tensorbard_logger.py
@@ -0,0 +1,43 @@
import logging
import os.path as osp

logger = logging.getLogger(__name__)

try:
    from torch.utils.tensorboard import SummaryWriter
    HAS_TENSORBOARD = True
except ImportError:
    # Keep this module importable when tensorboard is not installed; callers are
    # expected to check HAS_TENSORBOARD before instantiating the writer.
    SummaryWriter = object  # type: ignore[assignment, misc]
    HAS_TENSORBOARD = False


class TBNewSummaryWriter(SummaryWriter):
    """A `SummaryWriter` with helpers for logging metric dictionaries and nested configs."""

    def __init__(self, log_dir=None, comment="", **kwargs):
        super().__init__(log_dir, comment, **kwargs)

    def add_scalar_dict(self, dictionary, global_step, tag=None):
        """Log each entry of `dictionary` with the built-in `add_scalar()`.

        If `tag` is given, all entries are grouped under that tag (e.g. "train/loss").
        """
        for name, val in dictionary.items():
            if tag is not None:
                name = osp.join(tag, name)
            self.add_scalar(name, val, global_step)

    def log(self, dictionary, global_step, tag=None):
        self.add_scalar_dict(dictionary, global_step, tag)

    def write_args_to_tensorboard(self, args, iteration=0, prefix=""):
        """Recursively write a (possibly nested) arguments dict to TensorBoard as text."""
        if prefix:
            prefix = f"{prefix}."
        for arg in args.keys():
            arg_text = f"{prefix}{arg}"
            if isinstance(args[arg], dict):
                self.write_args_to_tensorboard(args[arg], iteration, prefix=arg_text)
            else:
                self.add_text(arg_text, str(args[arg]), global_step=iteration)
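A minimal usage sketch of the writer above, assuming `torch` is installed with TensorBoard support; the log directory, metric names, and values are illustrative only, not part of this PR:

from olmo.tensorbard_logger import HAS_TENSORBOARD, TBNewSummaryWriter

if HAS_TENSORBOARD:
    writer = TBNewSummaryWriter(log_dir="runs/example")  # placeholder directory
    # Nested config dicts become dotted text entries such as "optimizer.learning_rate".
    writer.write_args_to_tensorboard({"optimizer": {"learning_rate": 3e-4}}, iteration=0)
    # Scalar dicts are grouped under the optional tag, producing "train/loss" and "train/ppl".
    writer.log({"loss": 2.31, "ppl": 10.07}, global_step=100, tag="train")
    writer.close()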
20 changes: 20 additions & 0 deletions olmo/train.py
@@ -58,6 +58,7 @@
    synchronize_value,
)
from .util import upload
from olmo.tensorbard_logger import TBNewSummaryWriter, HAS_TENSORBOARD

__all__ = ["SpeedMonitor", "LRMonitor", "Trainer"]

@@ -240,6 +241,15 @@ def __post_init__(self):
                self.loss_fn = fused_loss_fn
            else:
                raise NameError("`fused_loss_fn` is not defined. Please ensure that `flash_attn` is installed.")
        self.logger = None
        if self.cfg.tensorboard_path:
            if HAS_TENSORBOARD:
                log_dir = Path(self.cfg.tensorboard_path)
                log_dir.mkdir(parents=True, exist_ok=True)
                self.logger = TBNewSummaryWriter(log_dir=log_dir)
                self.logger.write_args_to_tensorboard(self.cfg.asdict())
            else:
                logger.warning("Failed to import the TensorBoard writer; TensorBoard logs will not be written.")

    @property
    def dataset(self) -> IterableDataset:
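With the writer created in `__post_init__`, any run that sets `tensorboard_path` writes standard TensorBoard event files into that directory; assuming a normal TensorBoard installation they can be inspected with the usual command (the path is a placeholder):

tensorboard --logdir /results/tb_logs/my_run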
@@ -1126,6 +1136,8 @@ def fit(self):
            eval_metrics = self.eval()
            if wandb.run is not None:
                wandb.log(eval_metrics, step=self.global_step)
            if self.logger is not None:
                self.logger.log(eval_metrics, global_step=self.global_step)

        # Set model to 'train' mode.
        self.dist_model.train()
@@ -1141,6 +1153,8 @@
            self.log_metrics_to_console("Pre-train system metrics", sys_metrics)
            if wandb.run is not None:
                wandb.log(sys_metrics, step=0)
            if self.logger is not None:
                self.logger.log(sys_metrics, global_step=0)

        # Python Profiler stuff
        if self.cfg.python_profiling:
@@ -1251,6 +1265,8 @@ def on_trace_ready(p):
                    and self.global_step % self.cfg.wandb.log_interval == 0
                ):
                    wandb.log(metrics, step=self.global_step)
                if self.logger is not None:
                    self.logger.log(metrics, global_step=self.global_step)

                # Check if/when run should be canceled.
                if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
@@ -1317,6 +1333,8 @@ def on_trace_ready(p):
                    # Log metrics to W&B.
                    if wandb.run is not None:
                        wandb.log(eval_metrics, step=self.global_step)
                    if self.logger is not None:
                        self.logger.log(eval_metrics, global_step=self.global_step)

                    # Reset speed monitor so that we don't count the time taken to run evaluations.
                    speed_monitor.reset()
@@ -1387,6 +1405,8 @@ def close(self, exit_code: int = 0) -> None:
        gc.disable()
        if wandb.run is not None:
            wandb.finish(exit_code=exit_code, quiet=True)
        if self.logger is not None:
            self.logger.close()

    def __enter__(self) -> Trainer:
        return self