Skip to content

Commit 0748109

Browse files
Add function to update W&B run name with sequence number (#7)
1 parent 21742c7 commit 0748109

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

amp_rsl_rl/runners/amp_on_policy_runner.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,12 +234,46 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals
234234
)
235235
elif self.logger_type == "wandb":
236236
from rsl_rl.utils.wandb_utils import WandbSummaryWriter
237+
import wandb
238+
239+
# Update the run name with a sequence number. This function is useful to
240+
# replicate the same behaviour of rsl-rl-lib before v2.3.0
241+
def update_run_name_with_sequence(prefix: str) -> None:
242+
# Retrieve the current wandb run details (project and entity)
243+
project = wandb.run.project
244+
entity = wandb.run.entity
245+
246+
# Use wandb's API to list all runs in your project
247+
api = wandb.Api()
248+
runs = api.runs(f"{entity}/{project}")
249+
250+
max_num = 0
251+
# Iterate through runs to extract the numeric suffix after the prefix.
252+
for run in runs:
253+
if run.name.startswith(prefix):
254+
# Extract the numeric part from the run name.
255+
numeric_suffix = run.name[
256+
len(prefix) :
257+
] # e.g., from "prefix564", get "564"
258+
try:
259+
run_num = int(numeric_suffix)
260+
if run_num > max_num:
261+
max_num = run_num
262+
except ValueError:
263+
continue
264+
265+
# Increment to get the new run number
266+
new_num = max_num + 1
267+
new_run_name = f"{prefix}{new_num}"
268+
269+
# Update the wandb run's name
270+
wandb.run.name = new_run_name
271+
print("Updated run name to:", wandb.run.name)
237272

238273
self.writer = WandbSummaryWriter(
239274
log_dir=self.log_dir, flush_secs=10, cfg=self.cfg
240275
)
241-
242-
import wandb
276+
update_run_name_with_sequence(prefix=self.cfg["wandb_project"])
243277

244278
wandb.gym.monitor()
245279
self.writer.log_config(

amp_rsl_rl/utils/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,10 @@
1010
from .motion_loader import AMPLoader, download_amp_dataset_from_hf
1111
from .exporter import export_policy_as_onnx
1212

13-
__all__ = ["Normalizer", "RunningMeanStd", "AMPLoader", "download_amp_dataset_from_hf", "export_policy_as_onnx"]
13+
__all__ = [
14+
"Normalizer",
15+
"RunningMeanStd",
16+
"AMPLoader",
17+
"download_amp_dataset_from_hf",
18+
"export_policy_as_onnx",
19+
]

0 commit comments

Comments
 (0)