Skip to content

Commit

Permalink
bugfix test
Browse files Browse the repository at this point in the history
  • Loading branch information
brycedrennan committed Jan 19, 2024
1 parent da29ddc commit cf2a084
Showing 1 changed file with 3 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
)
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
self.last_step_first_order = last_step_first_order
self._first_step_has_been_run = False

def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because:
Expand Down Expand Up @@ -80,7 +81,6 @@ def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tens
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_noise_std = self.noise_std[previous_timestep]
current_noise_std = self.noise_std[current_timestep]

estimation_delta = (current_data_estimation - next_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
)
Expand All @@ -105,7 +105,8 @@ def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | N
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
self.estimated_data.append(estimated_denoised_data)

if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1):
if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1) or not self._first_step_has_been_run:
self._first_step_has_been_run = True
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)

return self.multistep_dpm_solver_second_order_update(x=x, step=step)

0 comments on commit cf2a084

Please sign in to comment.