diff --git a/torchcfm/optimal_transport.py b/torchcfm/optimal_transport.py index 2155d18..b901e8f 100644 --- a/torchcfm/optimal_transport.py +++ b/torchcfm/optimal_transport.py @@ -86,7 +86,9 @@ def sample_map(self, pi, batch_size, replace=True): """ p = pi.flatten() p = p / p.sum() - choices = np.random.choice(pi.shape[0] * pi.shape[1], p=p, size=batch_size, replace=replace) + choices = np.random.choice( + pi.shape[0] * pi.shape[1], p=p, size=batch_size, replace=replace + ) return np.divmod(choices, pi.shape[1]) def sample_plan(self, x0, x1, replace=True):