Commit fb59fb3

Enhance ReplayBuffer's feed_forward_generator to support sampling with replacement (#14)
1 parent 4fdc58f · commit fb59fb3

2 files changed: +37 -15 lines


amp_rsl_rl/algorithms/amp_ppo.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -376,10 +376,11 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
 
         # Generator for policy-generated AMP transitions.
         amp_policy_generator = self.amp_storage.feed_forward_generator(
-            self.num_learning_epochs * self.num_mini_batches,
-            self.storage.num_envs
+            num_mini_batch=self.num_learning_epochs * self.num_mini_batches,
+            mini_batch_size=self.storage.num_envs
             * self.storage.num_transitions_per_env
             // self.num_mini_batches,
+            allow_replacement=True,
         )
 
         # Generator for expert AMP data.
```
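Rewriting the positional arguments as keywords makes the call site self-documenting, and `allow_replacement=True` matters because the total request can exceed what the AMP replay buffer currently holds. A rough sketch of the sizes involved (the numeric values below are hypothetical placeholders, not taken from the repository):

```python
# Hypothetical configuration values, for illustration only.
num_learning_epochs = 5
num_mini_batches = 4
num_envs = 1024
num_transitions_per_env = 24

# The two arguments as computed in the diff above:
num_mini_batch = num_learning_epochs * num_mini_batches                   # 20
mini_batch_size = num_envs * num_transitions_per_env // num_mini_batches  # 6144

# Total samples the generator must produce:
total = num_mini_batch * mini_batch_size                                  # 122880

# A single rollout adds only num_envs * num_transitions_per_env = 24576
# transitions, so early in training the buffer can hold fewer than `total`
# samples; allow_replacement=True reuses samples instead of failing.
```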

amp_rsl_rl/storage/replay_buffer.py

Lines changed: 34 additions & 13 deletions
```diff
@@ -87,25 +87,46 @@ def feed_forward_generator(
         self,
         num_mini_batch: int,
         mini_batch_size: int,
+        allow_replacement: bool = True,
     ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]:
         """
-        Yield mini-batches of (state, next_state) tuples from the buffer.
-
-        Args:
-            num_mini_batch (int): Number of mini-batches to generate.
-            mini_batch_size (int): Number of samples per mini-batch.
-
-        Yields:
-            Tuple[Tensor, Tensor]: A mini-batch of states and next_states.
+        Yield `num_mini_batch` mini-batches of (state, next_state) tuples from the buffer,
+        each of length `mini_batch_size`.
+
+        If the total number of requested samples is larger than the number of
+        items currently stored (`len(self)`), the method will
+
+        * raise an error when `allow_replacement=False`;
+        * silently sample **with replacement** when `allow_replacement=True`
+          (the default).
+
+        Args
+        ----
+        num_mini_batch : int
+        mini_batch_size : int
+        allow_replacement : bool, optional
+            Whether to allow sampling with replacement when the request
+            exceeds the number of stored transitions.
         """
         total = num_mini_batch * mini_batch_size
-        assert (
-            total <= self.num_samples
-        ), f"Not enough samples in buffer: requested {total}, but have {self.num_samples}"
 
-        # Generate a random permutation of valid indices on-device
-        indices = torch.randperm(self.num_samples, device=self.device)[:total]
+        # Sampling with replacement might yield duplicate samples, which can affect training dynamics.
+        if total > self.num_samples:
+            if not allow_replacement:
+                raise ValueError(
+                    f"Not enough samples in buffer: requested {total}, "
+                    f"but have {self.num_samples}"
+                )
+            # Permute-then-modulo: shuffle ceil(total / num_samples) copies of
+            # the index range, then fold back with % so no index repeats more
+            # than `cycles` times.
+            cycles = (total + self.num_samples - 1) // self.num_samples
+            big_size = self.num_samples * cycles
+            big_perm = torch.randperm(big_size, device=self.device)
+            indices = big_perm[:total] % self.num_samples
+        else:
+            # Sample WITHOUT replacement: a random permutation of valid indices, on-device.
+            indices = torch.randperm(self.num_samples, device=self.device)[:total]
 
+        # Yield the mini-batches.
         for i in range(num_mini_batch):
             batch_idx = indices[i * mini_batch_size : (i + 1) * mini_batch_size]
             yield self.states[batch_idx], self.next_states[batch_idx]
```
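To see the new sampling path in isolation, here is a self-contained sketch of the permute-then-modulo scheme (a standalone re-implementation for illustration; `sample_indices` is a hypothetical helper, not part of the module):

```python
import torch


def sample_indices(num_samples: int, total: int, device: str = "cpu") -> torch.Tensor:
    """Draw `total` indices in [0, num_samples), mirroring the buffer's logic."""
    if total <= num_samples:
        # Without replacement: a truncated random permutation.
        return torch.randperm(num_samples, device=device)[:total]
    # With replacement: shuffle ceil(total / num_samples) copies of the index
    # range, then fold back into range with %. Every index appears at most
    # `cycles` times, so coverage stays close to uniform.
    cycles = (total + num_samples - 1) // num_samples
    big_perm = torch.randperm(num_samples * cycles, device=device)
    return big_perm[:total] % num_samples


# A buffer holding 10 transitions asked for 25 samples forces replacement.
idx = sample_indices(num_samples=10, total=25)
print(torch.bincount(idx, minlength=10))  # each index appears 2 or 3 times
```

Compared with an independent `torch.randint` draw, this keeps the batches balanced: no stored transition can be repeated more than `ceil(total / num_samples)` times.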
