Pre-compute trajectory guidance schedule params rather than calculating on each step.

This commit is contained in:
Ryan Dick 2024-09-20 20:18:06 +00:00
parent a4a0cc6d10
commit 16ca540ece

View File

@ -49,7 +49,6 @@ class TrajectoryGuidanceExtension:
"""
assert 0.0 <= trajectory_guidance_strength <= 1.0
self._init_latents = init_latents
self._trajectory_guidance_strength = trajectory_guidance_strength
if inpaint_mask is None:
# The inpaing mask is None, so we initialize a mask with a single value of 1.0.
# This value will be broadcasted and treated as a mask of all 1s.
@ -57,6 +56,13 @@ class TrajectoryGuidanceExtension:
else:
self._inpaint_mask = inpaint_mask
# Calculate the params that define the trajectory guidance schedule.
# These mappings from trajectory_guidance_strength have no theoretical basis - they were tuned manually.
self._trajectory_guidance_strength = trajectory_guidance_strength
self._change_ratio_at_t_1 = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.0)(self._trajectory_guidance_strength)
self._change_ratio_at_cutoff = 1.0
self._t_cutoff = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.5)(self._trajectory_guidance_strength)
def _apply_mask_gradient_adjustment(self, t_prev: float) -> torch.Tensor:
"""Applies inpaint mask gradient adjustment and returns the inpaint mask to be used at the current timestep."""
# As we progress through the denoising process, we promote gradient regions of the mask to have a full weight of
@ -75,24 +81,25 @@ class TrajectoryGuidanceExtension:
return mask
def _get_change_ratio(self, t_prev: float) -> float:
"""Get the change_ratio for t_prev based on the change schedule."""
change_ratio = 1.0
if t_prev > self._t_cutoff:
# If we are before the cutoff, linearly interpolate between the change_ratio at t=1.0 and the change_ratio
# at the cutoff.
change_ratio = build_line(
x1=1.0, y1=self._change_ratio_at_t_1, x2=self._t_cutoff, y2=self._change_ratio_at_cutoff
)(t_prev)
return change_ratio
def step(
self, t_curr_latents: torch.Tensor, pred_noise: torch.Tensor, t_curr: float, t_prev: float
) -> torch.Tensor:
# Handle gradient cutoff.
mask = self._apply_mask_gradient_adjustment(t_prev)
# Calculate the change_ratio based on the trajectory_guidance_strength.
# These mappings from trajectory_guidance_strength have no theoretical basis - they were tuned manually.
change_ratio_at_t_1 = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.0)(self._trajectory_guidance_strength)
change_ratio_at_cutoff = 1.0
t_cutoff = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.5)(self._trajectory_guidance_strength)
change_ratio = 1.0
if t_prev > t_cutoff:
# If we are before the cutoff, linearly interpolate between the change_ratio at t=1.0 and the change_ratio
# at the cutoff.
change_ratio = build_line(x1=1.0, y1=change_ratio_at_t_1, x2=t_cutoff, y2=change_ratio_at_cutoff)(t_prev)
mask = mask * change_ratio
mask = mask * self._get_change_ratio(t_prev)
# NOTE(ryand): During inpainting, it is common to guide the denoising process by noising the initial latents for
# the current timestep and then blending the predicted intermediate latents with the noised initial latents.