changed how calculations are done so we don't pass a States instance to multiprocessing, changed assert to warning.
josephdviviano committed Jul 18, 2024
1 parent 421a167 commit 5dbd15d
Showing 1 changed file with 52 additions and 38 deletions.
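In miniature, the pattern this commit moves to (an illustrative sketch, not the torchgfn API; batch_reward and the toy batches are invented here): worker processes receive plain tensors, which pickle cheaply, instead of a rich States instance, and the height sanity check warns rather than asserts.

    import multiprocessing
    import warnings

    import torch

    def batch_reward(batch: torch.Tensor) -> float:
        # Workers receive raw tensors; tensors pickle cheaply, whereas a full
        # States wrapper would have to be serialized for every worker.
        return batch.float().sum().item()

    if __name__ == "__main__":
        height = 4
        if height <= 4:  # Warn instead of assert: the run continues.
            warnings.warn("height <= 4 can lead to unsolvable environments.")

        multiprocessing.set_start_method("fork")  # Unix-only; matches the file below.
        with multiprocessing.Pool(2) as pool:
            totals = pool.map(batch_reward, [torch.arange(4), torch.arange(4, 8)])
        print(sum(totals))  # 6.0 + 22.0 = 28.0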
90 changes: 52 additions & 38 deletions src/gfn/gym/hypergrid.py
@@ -2,14 +2,14 @@
 Copied and Adapted from https://github.com/Tikquuss/GflowNets_Tutorial
 """
 
-from typing import Literal, Tuple
-from math import gcd, log
-from functools import reduce
-from decimal import Decimal
 import itertools
-import torch
 import multiprocessing
+import warnings
+from decimal import Decimal
+from functools import reduce
+from math import gcd, log
+from time import time
+from typing import Literal, Tuple
+
+import torch
 from einops import rearrange
@@ -21,7 +21,6 @@
 from gfn.preprocessors import EnumPreprocessor, IdentityPreprocessor
 from gfn.states import DiscreteStates
 
-
 multiprocessing.set_start_method("fork")  # multiprocessing-torch compatibility.
 
 
@@ -85,28 +84,32 @@ def __init__(
             all_states. Might have intractable space complexity for very large
             problems.
         """
-        assert height > 4, "height <= 4 can lead to unsolvable environments."
+        if height <= 4:
+            warnings.warn("+ Warning: height <= 4 can lead to unsolvable environments.")
 
         self.ndim = ndim
         self.height = height
         self.R0 = R0
         self.R1 = R1
         self.R2 = R2
         self.reward_cos = reward_cos
-        self._all_states = None  # Populated at first request.
-        self._log_partition = None  # Populated at first request.
-        self._true_dist = None  # Populated at first request.
+        self._all_states = None  # Populated optionally in init.
+        self._log_partition = None  # Populated optionally in init.
+        self._true_dist_pmf = None  # Populated at first request.
        self.calculate_partition = calculate_partition
         self.calculate_all_states = calculate_all_states
 
         # Pre-computes these values when printing.
         if self.calculate_all_states:
-            print("+ Environment has {} states".format(len(self.all_states)))
+            self._calculate_all_states_tensor()
+            print("+ Environment has {} states".format(len(self._all_states)))
         if self.calculate_partition:
-            print("+ Environment log partition is {}".format(self.log_partition))
+            self._calculate_log_partition()
+            print("+ Environment log partition is {}".format(self._log_partition))
 
         # This scale is used to stabilize some calculations.
-        self.scale_factor = smallest_multiplier_to_integers([R0, R1, R2])
+        # self.scale_factor = smallest_multiplier_to_integers([R0, R1, R2])
+        self.scale_factor = 1
 
         s0 = torch.zeros(ndim, dtype=torch.long, device=torch.device(device_str))
         sf = torch.full(
@@ -225,22 +228,9 @@ def n_states(self) -> int:
     def n_terminating_states(self) -> int:
         return self.n_states
 
-    @property
-    def true_dist_pmf(self) -> torch.Tensor:
-        """Returns the pmf over all states in the hypergrid."""
-        if not self._true_dist and self.calculate_all_states:
-            assert torch.all(
-                self.get_states_indices(self.all_states)
-                == torch.arange(self.n_states, device=self.device)
-            )
-            self._true_dist = self.reward(self.all_states)
-            self._true_dist /= self._true_dist.sum()
-
-        return self._true_dist
-
-    @property
-    def log_partition(self, batch_size: int = 20_000) -> float:
-        """Returns the log partition of the complete hypergrid.
+    # Functions for calculating the true log partition function / state enumeration.
+    def _calculate_log_partition(self, batch_size: int = 20_000):
+        """Calculates the log partition of the complete hypergrid.
 
         Args:
             batch_size: Compute this number of hypergrid indices in parallel.
@@ -262,7 +252,9 @@ def log_partition(self, batch_size: int = 20_000) -> float:
                 batch_size,
             ):
                 batch = torch.LongTensor(list(batch))
-                rewards = self.reward(batch)  # Operates on raw tensors due to multiprocessing.
+                rewards = self.reward(
+                    batch
+                )  # Operates on raw tensors due to multiprocessing.
                 total_reward += rewards.sum().item()  # Accumulate.
                 n_found += batch.shape[0]
 
@@ -279,12 +271,12 @@ def log_partition(self, batch_size: int = 20_000) -> float:
 
         self._log_partition = total_log_reward
 
-        return self._log_partition
-
-    @property
-    def all_states(self, batch_size: int = 20_000) -> DiscreteStates:
-        """Returns a tensor of all hypergrid states."""
+    def _calculate_all_states_tensor(self, batch_size: int = 20_000):
+        """Enumerates all states of the complete hypergrid.
+
+        Args:
+            batch_size: Compute this number of hypergrid indices in parallel.
+        """
         if self._all_states is None and self.calculate_all_states:
             start_time = time()
             all_states = []
@@ -305,14 +297,36 @@ def all_states(self, batch_size: int = 20_000) -> DiscreteStates:
                 )
             )
 
-            self._all_states = self.States(all_states)
+            self._all_states = all_states
 
+    # These properties are optionally available according to the flags set in init.
+    @property
+    def true_dist_pmf(self) -> torch.Tensor:
+        """Returns the pmf over all states in the hypergrid."""
+        if not self._true_dist_pmf and self.calculate_all_states:
+            assert torch.all(
+                self.get_states_indices(self.all_states)
+                == torch.arange(self.n_states, device=self.device)
+            )
+            self._true_dist_pmf = self.reward(self.all_states)
+            self._true_dist_pmf /= self._true_dist_pmf.sum()
+
+        return self._true_dist_pmf
+
-        return self._all_states
 
+    @property
+    def log_partition(self) -> float:
+        return self._log_partition
+
+    @property
+    def all_states(self) -> DiscreteStates:
+        """Returns a tensor of all hypergrid states as a States instance."""
+        return self.States(self._all_states)
+
     @property
     def terminating_states(self) -> DiscreteStates:
         return self.all_states
 
+    # Helper methods for enumerating all possible states.
     def _generate_combinations_chunk(self, numbers, n, start, end):
         """Generate combinations with replacement for the specified range."""
         # islice accesses a subset of the full iterator - each job does unique work.
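With this refactor, the expensive enumeration and partition computation happen once in __init__, gated by the two flags, and the properties only read cached results (all_states wraps the cached raw tensor into a States instance on access). A usage sketch, assuming HyperGrid is importable from gfn.gym as in torchgfn and using the constructor arguments shown in the diff:

    from gfn.gym import HyperGrid

    env = HyperGrid(
        ndim=2,
        height=8,
        calculate_all_states=True,  # Enumerate and cache the raw states tensor up front.
        calculate_partition=True,   # Compute and cache the log partition up front.
    )
    print(env.log_partition)        # Cached float; no recomputation on access.
    print(env.all_states)           # Cached tensor wrapped into a States instance.
    print(env.true_dist_pmf.sum())  # Should be ~1.0.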
