Commit cdd25e5

Make amp_ppo compatible with rsl-rl v2.3.0 (#4)
1 parent 5a59524 commit cdd25e5

5 files changed: +181 -23 lines changed

amp_rsl_rl/algorithms/amp_ppo.py

Lines changed: 8 additions & 6 deletions
@@ -106,7 +106,7 @@ def __init__(
         # The discriminator expects concatenated observations, so the replay buffer uses half the dimension.
         obs_dim: int = self.discriminator.input_dim // 2
         self.amp_storage: ReplayBuffer = ReplayBuffer(
-            obs_dim, amp_replay_buffer_size, device
+            obs_dim=obs_dim, buffer_size=amp_replay_buffer_size, device=device
         )
         self.amp_data: AMPLoader = amp_data
         self.amp_normalizer: Optional[Any] = amp_normalizer
@@ -172,11 +172,13 @@ def init_storage(
             Shape of the actions taken by the policy.
         """
         self.storage = RolloutStorage(
-            num_envs,
-            num_transitions_per_env,
-            actor_obs_shape,
-            critic_obs_shape,
-            action_shape,
+            training_type="rl",
+            num_envs=num_envs,
+            num_transitions_per_env=num_transitions_per_env,
+            obs_shape=actor_obs_shape,
+            privileged_obs_shape=critic_obs_shape,
+            actions_shape=action_shape,
+            rnd_state_shape=None,
             device=self.device,
         )
 
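
For reference, a minimal sketch of what the updated call maps onto in rsl-rl v2.3.0, whose RolloutStorage constructor is keyword-based and adds training_type and rnd_state_shape. The import path, environment count, and shapes below are illustrative assumptions, not values from this repository:

# Illustrative sketch only: keyword arguments as used in the diff above.
# The import path and all numeric shapes are assumptions for the example.
from rsl_rl.storage import RolloutStorage

storage = RolloutStorage(
    training_type="rl",            # new argument; AMP_PPO uses plain RL storage
    num_envs=4096,
    num_transitions_per_env=24,
    obs_shape=(48,),               # actor observation shape
    privileged_obs_shape=(48,),    # critic observation shape
    actions_shape=(12,),
    rnd_state_shape=None,          # no RND state is stored
    device="cpu",
)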

amp_rsl_rl/runners/amp_on_policy_runner.py

Lines changed: 21 additions & 15 deletions
@@ -9,27 +9,22 @@
 import os
 import statistics
 import time
-import rsl_rl.utils
-import torch
 from collections import deque
+
+import torch
 from torch.utils.tensorboard import SummaryWriter as TensorboardSummaryWriter
 
 import rsl_rl
+import rsl_rl.utils
+from rsl_rl.env import VecEnv
+from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, EmpiricalNormalization
+from rsl_rl.utils import store_code_state
+
 from amp_rsl_rl.utils import Normalizer
 from amp_rsl_rl.utils import AMPLoader
 from amp_rsl_rl.algorithms import AMP_PPO
 from amp_rsl_rl.networks import Discriminator
-
-from isaaclab_rl.rsl_rl import (
-    export_policy_as_onnx,
-)
-
-from rsl_rl.env import VecEnv
-
-from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, EmpiricalNormalization
-
-
-from rsl_rl.utils import store_code_state
+from amp_rsl_rl.utils import export_policy_as_onnx
 
 
 class AMPOnPolicyRunner:
@@ -508,13 +503,24 @@ def log(self, locs: dict, width: int = 80, pad: int = 35):
         # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")
 
         log_string += ep_string
+
+        # make the eta in H:M:S
+        eta_seconds = (
+            self.tot_time
+            / (locs["it"] + 1)
+            * (locs["num_learning_iterations"] - locs["it"])
+        )
+
+        # Convert seconds to H:M:S
+        eta_h, rem = divmod(eta_seconds, 3600)
+        eta_m, eta_s = divmod(rem, 60)
+
         log_string += (
             f"""{'-' * width}\n"""
             f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n"""
             f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n"""
             f"""{'Total time:':>{pad}} {self.tot_time:.2f}s\n"""
-            f"""{'ETA:':>{pad}} {self.tot_time / (locs['it'] + 1) * (
-                locs['num_learning_iterations'] - locs['it']):.1f}s\n"""
+            f"""{'ETA:':>{pad}} {int(eta_h)}h {int(eta_m)}m {int(eta_s)}s\n"""
         )
         print(log_string)
 
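
The ETA formatting change is plain arithmetic; a standalone sketch with a made-up elapsed-time value:

# Made-up numbers, just to show the divmod-based H:M:S split used above.
eta_seconds = 7384.2                      # estimated seconds remaining
eta_h, rem = divmod(eta_seconds, 3600)    # 2.0 hours, 184.2 s left over
eta_m, eta_s = divmod(rem, 60)            # 3.0 minutes, 4.2 s
print(f"ETA: {int(eta_h)}h {int(eta_m)}m {int(eta_s)}s")   # prints: ETA: 2h 3m 4s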

amp_rsl_rl/utils/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -8,5 +8,6 @@
 
 from .utils import Normalizer, RunningMeanStd
 from .motion_loader import AMPLoader, download_amp_dataset_from_hf
+from .exporter import export_policy_as_onnx
 
-__all__ = ["Normalizer", "RunningMeanStd", "AMPLoader", "download_amp_dataset_from_hf"]
+__all__ = ["Normalizer", "RunningMeanStd", "AMPLoader", "download_amp_dataset_from_hf", "export_policy_as_onnx"]

amp_rsl_rl/utils/exporter.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
+# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Code taken from https://github.com/isaac-sim/IsaacLab/blob/5716d5600a1a0e45345bc01342a70bd81fac7889/source/isaaclab_rl/isaaclab_rl/rsl_rl/exporter.py
+
+import copy
+import os
+import torch
+
+
+def export_policy_as_onnx(
+    actor_critic: object,
+    path: str,
+    normalizer: object | None = None,
+    filename="policy.onnx",
+    verbose=False,
+):
+    """Export policy into a Torch ONNX file.
+
+    Args:
+        actor_critic: The actor-critic torch module.
+        normalizer: The empirical normalizer module. If None, Identity is used.
+        path: The path to the saving directory.
+        filename: The name of exported ONNX file. Defaults to "policy.onnx".
+        verbose: Whether to print the model summary. Defaults to False.
+    """
+    if not os.path.exists(path):
+        os.makedirs(path, exist_ok=True)
+    policy_exporter = _OnnxPolicyExporter(actor_critic, normalizer, verbose)
+    policy_exporter.export(path, filename)
+
+
+"""
+Helper Classes - Private.
+"""
+
+
+class _TorchPolicyExporter(torch.nn.Module):
+    """Exporter of actor-critic into JIT file."""
+
+    def __init__(self, actor_critic, normalizer=None):
+        super().__init__()
+        self.actor = copy.deepcopy(actor_critic.actor)
+        self.is_recurrent = actor_critic.is_recurrent
+        if self.is_recurrent:
+            self.rnn = copy.deepcopy(actor_critic.memory_a.rnn)
+            self.rnn.cpu()
+            self.register_buffer(
+                "hidden_state",
+                torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size),
+            )
+            self.register_buffer(
+                "cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)
+            )
+            self.forward = self.forward_lstm
+            self.reset = self.reset_memory
+        # copy normalizer if exists
+        if normalizer:
+            self.normalizer = copy.deepcopy(normalizer)
+        else:
+            self.normalizer = torch.nn.Identity()
+
+    def forward_lstm(self, x):
+        x = self.normalizer(x)
+        x, (h, c) = self.rnn(x.unsqueeze(0), (self.hidden_state, self.cell_state))
+        self.hidden_state[:] = h
+        self.cell_state[:] = c
+        x = x.squeeze(0)
+        return self.actor(x)
+
+    def forward(self, x):
+        return self.actor(self.normalizer(x))
+
+    @torch.jit.export
+    def reset(self):
+        pass
+
+    def reset_memory(self):
+        self.hidden_state[:] = 0.0
+        self.cell_state[:] = 0.0
+
+    def export(self, path, filename):
+        os.makedirs(path, exist_ok=True)
+        path = os.path.join(path, filename)
+        self.to("cpu")
+        traced_script_module = torch.jit.script(self)
+        traced_script_module.save(path)
+
+
+class _OnnxPolicyExporter(torch.nn.Module):
+    """Exporter of actor-critic into ONNX file."""
+
+    def __init__(self, actor_critic, normalizer=None, verbose=False):
+        super().__init__()
+        self.verbose = verbose
+        self.actor = copy.deepcopy(actor_critic.actor)
+        self.is_recurrent = actor_critic.is_recurrent
+        if self.is_recurrent:
+            self.rnn = copy.deepcopy(actor_critic.memory_a.rnn)
+            self.rnn.cpu()
+            self.forward = self.forward_lstm
+        # copy normalizer if exists
+        if normalizer:
+            self.normalizer = copy.deepcopy(normalizer)
+        else:
+            self.normalizer = torch.nn.Identity()
+
+    def forward_lstm(self, x_in, h_in, c_in):
+        x_in = self.normalizer(x_in)
+        x, (h, c) = self.rnn(x_in.unsqueeze(0), (h_in, c_in))
+        x = x.squeeze(0)
+        return self.actor(x), h, c
+
+    def forward(self, x):
+        return self.actor(self.normalizer(x))
+
+    def export(self, path, filename):
+        self.to("cpu")
+        if self.is_recurrent:
+            obs = torch.zeros(1, self.rnn.input_size)
+            h_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)
+            c_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)
+            actions, h_out, c_out = self(obs, h_in, c_in)
+            torch.onnx.export(
+                self,
+                (obs, h_in, c_in),
+                os.path.join(path, filename),
+                export_params=True,
+                opset_version=11,
+                verbose=self.verbose,
+                input_names=["obs", "h_in", "c_in"],
+                output_names=["actions", "h_out", "c_out"],
+                dynamic_axes={},
+            )
+        else:
+            obs = torch.zeros(1, self.actor[0].in_features)
+            torch.onnx.export(
+                self,
+                obs,
+                os.path.join(path, filename),
+                export_params=True,
+                opset_version=11,
+                verbose=self.verbose,
+                input_names=["obs"],
+                output_names=["actions"],
+                dynamic_axes={},
+            )
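
A hypothetical usage sketch of the new exporter. The runner.alg.actor_critic and runner.obs_normalizer attribute names are assumptions about how a trained AMPOnPolicyRunner exposes its policy, not part of this commit:

# Hypothetical usage; `runner` stands in for a trained AMPOnPolicyRunner.
from amp_rsl_rl.utils import export_policy_as_onnx

export_policy_as_onnx(
    actor_critic=runner.alg.actor_critic,    # actor-critic torch module (assumed attribute)
    path="exported",                         # directory is created if it does not exist
    normalizer=runner.obs_normalizer,        # or None to export without normalization (assumed attribute)
    filename="policy.onnx",
    verbose=False,
)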

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ dependencies = [
     "numpy>=1.21.0",
     "scipy>=1.7.0",
     "torch>=1.10.0",
-    "rsl-rl-lib>=1.0.0",
+    "rsl-rl-lib>=2.3.0",
 ]
 dynamic = ["version"]
 
