Fix preview image to work well with FLUX trajectory guidance.

This commit is contained in:
Ryan Dick 2024-09-20 21:08:41 +00:00
parent cd3a7bdb5e
commit a43a045b04
2 changed files with 8 additions and 9 deletions

View File

@ -38,12 +38,12 @@ def denoise(
)
if traj_guidance_extension is not None:
img = traj_guidance_extension.step(t_curr_latents=img, pred_noise=pred, t_curr=t_curr, t_prev=t_prev)
# TODO(ryand): Generate a better preview image.
preview_img = img
else:
preview_img = img - t_curr * pred
img = img + (t_prev - t_curr) * pred
pred = traj_guidance_extension.update_noise(
t_curr_latents=img, pred_noise=pred, t_curr=t_curr, t_prev=t_prev
)
preview_img = img - t_curr * pred
img = img + (t_prev - t_curr) * pred
step_callback(
PipelineIntermediateState(

View File

@ -96,7 +96,7 @@ class TrajectoryGuidanceExtension:
assert 0.0 - eps <= change_ratio <= 1.0 + eps
return change_ratio
def step(
def update_noise(
self, t_curr_latents: torch.Tensor, pred_noise: torch.Tensor, t_curr: float, t_prev: float
) -> torch.Tensor:
# Handle gradient cutoff.
@ -131,5 +131,4 @@ class TrajectoryGuidanceExtension:
# Blend the init_traj_noise with the pred_noise according to the inpaint mask and the trajectory guidance.
noise = pred_noise * mask + init_traj_noise * (1.0 - mask)
# Take a denoising step.
return t_curr_latents + (t_prev - t_curr) * noise
return noise