small efficiency improvements to prioritized replay buffer
josephdviviano committed Apr 2, 2024
1 parent 00cab17 commit a670356
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions src/gfn/containers/replay_buffer.py
@@ -167,7 +167,6 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):
             training_objects, terminating_states = training_objects
 
         to_add = len(training_objects)
-
         self._is_full |= len(self) + to_add >= self.capacity
 
         # The buffer isn't full yet.
@@ -177,18 +176,14 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):
         # Our buffer is full and we will prioritize diverse, high reward additions.
         else:
             # Sort the incoming elements by their logrewards.
-            ix = torch.argsort(training_objects._log_rewards, descending=True)
+            ix = torch.argsort(training_objects.log_rewards, descending=True)
             training_objects = training_objects[ix]
 
             # Filter all batch logrewards lower than the smallest logreward in buffer.
             min_reward_in_buffer = self.training_objects.log_rewards.min()
-            idx_bigger_rewards = training_objects.log_rewards > min_reward_in_buffer
+            idx_bigger_rewards = training_objects.log_rewards >= min_reward_in_buffer
             training_objects = training_objects[idx_bigger_rewards]
 
-            # Compute all pairwise distances between the batch and the buffer.
-            curr_dim = training_objects.last_states.batch_shape[0]
-            buffer_dim = self.training_objects.last_states.batch_shape[0]
-
             # TODO: Concatenate input with final state for conditional GFN.
             # if self.is_conditional:
             #     batch = torch.cat(
@@ -199,37 +194,41 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):
             #         [self.storage["input"], self.storage["final_state"]],
             #         dim=-1,
             #     )
-            batch = training_objects.last_states.tensor.float()
-            buffer = self.training_objects.last_states.tensor.float()
 
             # Filter the batch for diverse final_states with high reward.
+            batch = training_objects.last_states.tensor.float()
+            batch_dim = training_objects.last_states.batch_shape[0]
             batch_batch_dist = torch.cdist(
-                batch.view(curr_dim, -1).unsqueeze(0),
-                batch.view(curr_dim, -1).unsqueeze(0),
+                batch.view(batch_dim, -1).unsqueeze(0),
+                batch.view(batch_dim, -1).unsqueeze(0),
                 p=self.p_norm_distance,
             ).squeeze(0)
 
             # Finds the min distance at each row, and removes rows below the cutoff.
             r, w = torch.triu_indices(*batch_batch_dist.shape)  # Remove upper diag.
             batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max
             batch_batch_dist = batch_batch_dist.min(-1)[0]
+            idx_batch_batch = batch_batch_dist > self.cutoff_distance
+            training_objects = training_objects[idx_batch_batch]
+
-            # Filter the batch for diverse final_states w.r.t the buffer.
+            # Compute all pairwise distances between the remaining batch and the buffer.
+            batch = training_objects.last_states.tensor.float()
+            buffer = self.training_objects.last_states.tensor.float()
+            batch_dim = training_objects.last_states.batch_shape[0]
+            buffer_dim = self.training_objects.last_states.batch_shape[0]
             batch_buffer_dist = (
                 torch.cdist(
-                    batch.view(curr_dim, -1).unsqueeze(0),
+                    batch.view(batch_dim, -1).unsqueeze(0),
                     buffer.view(buffer_dim, -1).unsqueeze(0),
                     p=self.p_norm_distance,
                 )
                 .squeeze(0)
-                .min(-1)[0]
+                .min(-1)[0]  # Min calculated over rows, i.e., over the batch elements.
             )
 
-            # Remove non-diverse examples according to the above distances.
-            idx_batch_batch = batch_batch_dist > self.cutoff_distance
+            # Filter the batch for diverse final_states w.r.t the buffer.
             idx_batch_buffer = batch_buffer_dist > self.cutoff_distance
-            idx_diverse = idx_batch_batch & idx_batch_buffer
-
-            training_objects = training_objects[idx_diverse]
+            training_objects = training_objects[idx_batch_buffer]
 
             # If any training object remain after filtering, add them.
             if len(training_objects):
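To make the new ordering concrete: the commit filters the incoming batch against itself before comparing it to the buffer, so the second (batch-vs-buffer) torch.cdist call runs on fewer rows. Below is a minimal, self-contained sketch of the same filtering logic on raw tensors. The function name filter_diverse_high_reward and its tensor-only signature are illustrative assumptions and not part of torchgfn, which operates on Trajectories/Transitions containers as shown in the diff above.

import torch


def filter_diverse_high_reward(
    batch_states: torch.Tensor,        # (B, *state_dims) terminal states of the incoming batch
    batch_log_rewards: torch.Tensor,   # (B,)
    buffer_states: torch.Tensor,       # (N, *state_dims) terminal states already stored
    buffer_log_rewards: torch.Tensor,  # (N,)
    cutoff_distance: float,
    p_norm: float = 2.0,
) -> torch.Tensor:
    # Sort the batch by log-reward (descending) and drop anything strictly below
    # the smallest log-reward currently in the buffer (the diff keeps >=).
    order = torch.argsort(batch_log_rewards, descending=True)
    keep = order[batch_log_rewards[order] >= buffer_log_rewards.min()]
    if keep.numel() == 0:
        return keep

    # Self-diversity: pairwise distances within the reduced batch. Masking the
    # upper triangle (including the diagonal) means each element is only compared
    # to earlier, higher-reward elements; the highest-reward element always survives.
    flat = batch_states[keep].flatten(start_dim=1).float()
    d = torch.cdist(flat.unsqueeze(0), flat.unsqueeze(0), p=p_norm).squeeze(0)
    r, c = torch.triu_indices(*d.shape)
    d[r, c] = torch.finfo(d.dtype).max
    keep = keep[d.min(dim=-1).values > cutoff_distance]
    if keep.numel() == 0:
        return keep

    # Buffer-diversity: only the survivors are compared against the buffer, which
    # is where the efficiency gain comes from (fewer rows in this cdist).
    flat = batch_states[keep].flatten(start_dim=1).float()
    buf = buffer_states.flatten(start_dim=1).float()
    d = torch.cdist(flat.unsqueeze(0), buf.unsqueeze(0), p=p_norm).squeeze(0)
    keep = keep[d.min(dim=-1).values > cutoff_distance]
    return keep

The function returns the indices of batch elements that pass all three filters; a buffer would then append only those elements, as the final if len(training_objects): branch does in the diff.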
