Skip to content

Commit

Permalink
Adjust AdversarialBalancing to numpy>=2
Browse files Browse the repository at this point in the history
replace `row_stack` with `vstack` and `np.NaN` with `np.nan`.

Signed-off-by: Ehud-Karavani <[email protected]>
  • Loading branch information
ehudkr committed Jul 25, 2024
1 parent 3551319 commit 48f0ae1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _run(self, X, A, w_init=None, is_train=True, use_stabilized=None, **select_k
if not np.all(unique_treatments == np.arange(n_treatments)):
raise AssertionError("Treatment values in `a` must be indexed 0, 1, 2, ...")
self.iterative_models_ = np.empty((n_treatments, self.iterations), dtype=object)
self.iterative_normalizing_consts_ = np.full((n_treatments, self.iterations), np.NaN)
self.iterative_normalizing_consts_ = np.full((n_treatments, self.iterations), np.nan)

self.discriminator_loss_ = np.zeros((n_treatments, self.iterations))
self.treatments_frequency_ = _compute_treatments_frequency(A)
Expand All @@ -147,7 +147,7 @@ def _run(self, X, A, w_init=None, is_train=True, use_stabilized=None, **select_k
# population ("source population"),
# and the samples with label -1 are the population under treatment a ("target population").
# Labels 1 and -1 (rather than 0) are used because of the later exponential loss function
X_augm = np.row_stack((X, X[A == a])) # create the augmented dataset
X_augm = np.vstack((X, X[A == a])) # create the augmented dataset
y = np.ones((X_augm.shape[0]))
y[X.shape[0]:] *= -1 # subpopulation of current treatment (a) has y== -1
target_pop_mask = y == -1
Expand Down
2 changes: 1 addition & 1 deletion causallib/contrib/tests/test_adversarial_balancing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class TestAdversarialBalancing(unittest.TestCase):
def create_identical_treatment_groups_data(n=100):
np.random.seed(42)
X = np.random.rand(n, 3)
X = np.row_stack((X, X)) # Duplicate identical samples
X = np.vstack((X, X)) # Duplicate identical samples
a = np.array([1] * n + [0] * n) # Give duplicated samples different treatment assignment
X, a = pd.DataFrame(X), pd.Series(a)
return X, a
Expand Down

0 comments on commit 48f0ae1

Please sign in to comment.