@@ -87,25 +87,46 @@ def feed_forward_generator(
         self,
         num_mini_batch: int,
         mini_batch_size: int,
+        allow_replacement: bool = True,
     ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]:
         """
-        Yield mini-batches of (state, next_state) tuples from the buffer.
-
-        Args:
-            num_mini_batch (int): Number of mini-batches to generate.
-            mini_batch_size (int): Number of samples per mini-batch.
-
-        Yields:
-            Tuple[Tensor, Tensor]: A mini-batch of states and next_states.
+        Yield `num_mini_batch` mini-batches of (state, next_state) tuples from the buffer,
+        each of length `mini_batch_size`.
+
+        If the total number of requested samples is larger than the number of
+        items currently stored (`len(self)`), the method will
+
+        * raise an error when `allow_replacement=False`;
+        * silently sample **with replacement** when `allow_replacement=True`
+          (the default).
+
+        Args
+        ----
+        num_mini_batch : int
+        mini_batch_size : int
+        allow_replacement : bool, optional
+            Whether to allow sampling with replacement when the request
+            exceeds the number of stored transitions.
         """
         total = num_mini_batch * mini_batch_size
-        assert (
-            total <= self.num_samples
-        ), f"Not enough samples in buffer: requested {total}, but have {self.num_samples}"

-        # Generate a random permutation of valid indices on-device
-        indices = torch.randperm(self.num_samples, device=self.device)[:total]
+        # Sampling with replacement might yield duplicate samples, which can affect training dynamics
+        if total > self.num_samples:
+            if not allow_replacement:
+                raise ValueError(
+                    f"Not enough samples in buffer: requested {total}, "
+                    f"but have {self.num_samples}"
+                )
+            # Permute-then-modulo
+            cycles = (total + self.num_samples - 1) // self.num_samples
+            big_size = self.num_samples * cycles
+            big_perm = torch.randperm(big_size, device=self.device)
+            indices = big_perm[:total] % self.num_samples
+        else:
+            # Sample WITHOUT replacement
+            indices = torch.randperm(self.num_samples, device=self.device)[:total]

+        # Yield the mini-batches
         for i in range(num_mini_batch):
             batch_idx = indices[i * mini_batch_size : (i + 1) * mini_batch_size]
             yield self.states[batch_idx], self.next_states[batch_idx]
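
The permute-then-modulo branch can be checked in isolation. Below is a minimal sketch with made-up sizes (`num_samples = 10`, `total = 25`; neither value comes from the PR) showing that the trick returns exactly `total` in-range indices while bounding how often any single index can repeat:

```python
import torch

# Illustrative sizes, not taken from the PR.
num_samples = 10   # transitions currently stored
total = 25         # num_mini_batch * mini_batch_size

# Number of full copies of the index range needed to cover the request.
cycles = (total + num_samples - 1) // num_samples   # ceil(25 / 10) == 3
big_size = num_samples * cycles                      # 30

# Permute the enlarged range, keep the first `total` entries, and fold them
# back into [0, num_samples) so indices may repeat.
big_perm = torch.randperm(big_size)
indices = big_perm[:total] % num_samples

assert indices.shape == (total,)
assert int(indices.max()) < num_samples

# Every residue class occurs exactly `cycles` times in [0, big_size),
# so no transition can be drawn more than `cycles` times.
counts = torch.bincount(indices, minlength=num_samples)
assert int(counts.max()) <= cycles
```

Note that this is not i.i.d. sampling with replacement: duplicates are capped at `cycles` per transition, which keeps the oversampling roughly balanced across the buffer instead of concentrating on a few indices.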
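For an end-to-end view of the new `allow_replacement` path, here is a self-contained sketch. The `_ToyBuffer` class, its field sizes, and the one-layer model are invented for illustration; only the generator logic mirrors the diff. The loop deliberately requests more samples than are stored so the with-replacement branch is exercised:

```python
from typing import Generator, Tuple

import torch
import torch.nn as nn


class _ToyBuffer:
    """Minimal stand-in exposing the attributes the generator relies on."""

    def __init__(self, num_samples: int, state_dim: int, device: str = "cpu") -> None:
        self.device = torch.device(device)
        self.num_samples = num_samples
        self.states = torch.randn(num_samples, state_dim, device=self.device)
        self.next_states = torch.randn(num_samples, state_dim, device=self.device)

    def feed_forward_generator(
        self,
        num_mini_batch: int,
        mini_batch_size: int,
        allow_replacement: bool = True,
    ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]:
        total = num_mini_batch * mini_batch_size
        if total > self.num_samples:
            if not allow_replacement:
                raise ValueError(f"Not enough samples: requested {total}, have {self.num_samples}")
            cycles = (total + self.num_samples - 1) // self.num_samples
            big_perm = torch.randperm(self.num_samples * cycles, device=self.device)
            indices = big_perm[:total] % self.num_samples
        else:
            indices = torch.randperm(self.num_samples, device=self.device)[:total]
        for i in range(num_mini_batch):
            batch_idx = indices[i * mini_batch_size : (i + 1) * mini_batch_size]
            yield self.states[batch_idx], self.next_states[batch_idx]


buffer = _ToyBuffer(num_samples=100, state_dim=8)
model = nn.Linear(8, 8)                                   # toy next-state predictor
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

# 4 * 64 = 256 > 100 stored transitions, so indices are drawn with replacement.
for state, next_state in buffer.feed_forward_generator(num_mini_batch=4, mini_batch_size=64):
    loss = loss_fn(model(state), next_state)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Passing `allow_replacement=False` in the same situation would raise the `ValueError` instead, which is useful when silent duplication of transitions is undesirable.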