Skip to content

Commit

Permalink
feat: check types in splines to allow amp (#52)
Browse files Browse the repository at this point in the history
* Cast outputs to allow amp

* Only cast if necessary
  • Loading branch information
VincentStimper authored Oct 26, 2023
1 parent 9607072 commit 374c6e4
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions normflows/utils/splines.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def unconstrained_rational_quadratic_spline(
top = tail_bound

(
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
outputs_masked,
logabsdet_masked
) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
Expand All @@ -87,6 +87,12 @@ def unconstrained_rational_quadratic_spline(
min_bin_height=min_bin_height,
min_derivative=min_derivative,
)
if outputs.dtype == outputs_masked.dtype and logabsdet.dtype == logabsdet_masked.dtype:
outputs[inside_interval_mask] = outputs_masked
logabsdet[inside_interval_mask] = logabsdet_masked
else:
outputs[inside_interval_mask] = outputs_masked.to(outputs.dtype)
logabsdet[inside_interval_mask] = logabsdet_masked.to(logabsdet.dtype)

return outputs, logabsdet

Expand Down

0 comments on commit 374c6e4

Please sign in to comment.