diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index 2bc870ca..fe1b1110 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -167,7 +167,6 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]): training_objects, terminating_states = training_objects to_add = len(training_objects) - self._is_full |= len(self) + to_add >= self.capacity # The buffer isn't full yet. @@ -177,18 +176,14 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]): # Our buffer is full and we will prioritize diverse, high reward additions. else: # Sort the incoming elements by their logrewards. - ix = torch.argsort(training_objects._log_rewards, descending=True) + ix = torch.argsort(training_objects.log_rewards, descending=True) training_objects = training_objects[ix] # Filter all batch logrewards lower than the smallest logreward in buffer. min_reward_in_buffer = self.training_objects.log_rewards.min() - idx_bigger_rewards = training_objects.log_rewards > min_reward_in_buffer + idx_bigger_rewards = training_objects.log_rewards >= min_reward_in_buffer training_objects = training_objects[idx_bigger_rewards] - # Compute all pairwise distances between the batch and the buffer. - curr_dim = training_objects.last_states.batch_shape[0] - buffer_dim = self.training_objects.last_states.batch_shape[0] - # TODO: Concatenate input with final state for conditional GFN. # if self.is_conditional: # batch = torch.cat( @@ -199,37 +194,41 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]): # [self.storage["input"], self.storage["final_state"]], # dim=-1, # ) - batch = training_objects.last_states.tensor.float() - buffer = self.training_objects.last_states.tensor.float() # Filter the batch for diverse final_states with high reward. + batch = training_objects.last_states.tensor.float() + batch_dim = training_objects.last_states.batch_shape[0] batch_batch_dist = torch.cdist( - batch.view(curr_dim, -1).unsqueeze(0), - batch.view(curr_dim, -1).unsqueeze(0), + batch.view(batch_dim, -1).unsqueeze(0), + batch.view(batch_dim, -1).unsqueeze(0), p=self.p_norm_distance, ).squeeze(0) + # Finds the min distance at each row, and removes rows below the cutoff. r, w = torch.triu_indices(*batch_batch_dist.shape) # Remove upper diag. batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max batch_batch_dist = batch_batch_dist.min(-1)[0] + idx_batch_batch = batch_batch_dist > self.cutoff_distance + training_objects = training_objects[idx_batch_batch] - # Filter the batch for diverse final_states w.r.t the buffer. + # Compute all pairwise distances between the remaining batch and the buffer. + batch = training_objects.last_states.tensor.float() + buffer = self.training_objects.last_states.tensor.float() + batch_dim = training_objects.last_states.batch_shape[0] + buffer_dim = self.training_objects.last_states.batch_shape[0] batch_buffer_dist = ( torch.cdist( - batch.view(curr_dim, -1).unsqueeze(0), + batch.view(batch_dim, -1).unsqueeze(0), buffer.view(buffer_dim, -1).unsqueeze(0), p=self.p_norm_distance, ) .squeeze(0) - .min(-1)[0] + .min(-1)[0] # Min calculated over rows, i.e., over the batch elements. ) - # Remove non-diverse examples according to the above distances. - idx_batch_batch = batch_batch_dist > self.cutoff_distance + # Filter the batch for diverse final_states w.r.t the buffer. idx_batch_buffer = batch_buffer_dist > self.cutoff_distance - idx_diverse = idx_batch_batch & idx_batch_buffer - - training_objects = training_objects[idx_diverse] + training_objects = training_objects[idx_batch_buffer] # If any training object remain after filtering, add them. if len(training_objects):