From 04e50e945809a914563c6c391b86ae939b838a9a Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 2 Apr 2024 17:38:55 -0400 Subject: [PATCH] reworking of prioritized replay buffer logic --- src/gfn/containers/replay_buffer.py | 65 +++++++++++++++-------------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index fe1b1110..2cf9fcc6 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -132,7 +132,9 @@ def __init__( capacity: the size of the buffer. objects_type: the type of buffer (transitions, trajectories, or states). cutoff_distance: threshold used to determine if new last_states are - different enough from those already contained in the buffer. + different enough from those already contained in the buffer. If the + cutoff is negative, all diversity caclulations are skipped (since all + norms are >= 0). p_norm_distance: p-norm distance value to pass to torch.cdist, for the determination of novel states. """ @@ -195,40 +197,41 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]): # dim=-1, # ) - # Filter the batch for diverse final_states with high reward. - batch = training_objects.last_states.tensor.float() - batch_dim = training_objects.last_states.batch_shape[0] - batch_batch_dist = torch.cdist( - batch.view(batch_dim, -1).unsqueeze(0), - batch.view(batch_dim, -1).unsqueeze(0), - p=self.p_norm_distance, - ).squeeze(0) - - # Finds the min distance at each row, and removes rows below the cutoff. - r, w = torch.triu_indices(*batch_batch_dist.shape) # Remove upper diag. - batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max - batch_batch_dist = batch_batch_dist.min(-1)[0] - idx_batch_batch = batch_batch_dist > self.cutoff_distance - training_objects = training_objects[idx_batch_batch] - - # Compute all pairwise distances between the remaining batch and the buffer. - batch = training_objects.last_states.tensor.float() - buffer = self.training_objects.last_states.tensor.float() - batch_dim = training_objects.last_states.batch_shape[0] - buffer_dim = self.training_objects.last_states.batch_shape[0] - batch_buffer_dist = ( - torch.cdist( + if self.cutoff_distance >= 0: + # Filter the batch for diverse final_states with high reward. + batch = training_objects.last_states.tensor.float() + batch_dim = training_objects.last_states.batch_shape[0] + batch_batch_dist = torch.cdist( + batch.view(batch_dim, -1).unsqueeze(0), batch.view(batch_dim, -1).unsqueeze(0), - buffer.view(buffer_dim, -1).unsqueeze(0), p=self.p_norm_distance, + ).squeeze(0) + + # Finds the min distance at each row, and removes rows below the cutoff. + r, w = torch.triu_indices(*batch_batch_dist.shape) # Remove upper diag. + batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max + batch_batch_dist = batch_batch_dist.min(-1)[0] + idx_batch_batch = batch_batch_dist > self.cutoff_distance + training_objects = training_objects[idx_batch_batch] + + # Compute all pairwise distances between the remaining batch & buffer. + batch = training_objects.last_states.tensor.float() + buffer = self.training_objects.last_states.tensor.float() + batch_dim = training_objects.last_states.batch_shape[0] + buffer_dim = self.training_objects.last_states.batch_shape[0] + batch_buffer_dist = ( + torch.cdist( + batch.view(batch_dim, -1).unsqueeze(0), + buffer.view(buffer_dim, -1).unsqueeze(0), + p=self.p_norm_distance, + ) + .squeeze(0) + .min(-1)[0] # Min calculated over rows - the batch elements. ) - .squeeze(0) - .min(-1)[0] # Min calculated over rows, i.e., over the batch elements. - ) - # Filter the batch for diverse final_states w.r.t the buffer. - idx_batch_buffer = batch_buffer_dist > self.cutoff_distance - training_objects = training_objects[idx_batch_buffer] + # Filter the batch for diverse final_states w.r.t the buffer. + idx_batch_buffer = batch_buffer_dist > self.cutoff_distance + training_objects = training_objects[idx_batch_buffer] # If any training object remain after filtering, add them. if len(training_objects):