diff --git a/invokeai/backend/flux/trajectory_guidance_extension.py b/invokeai/backend/flux/trajectory_guidance_extension.py index 9b779e3338..b6329a1b6b 100644 --- a/invokeai/backend/flux/trajectory_guidance_extension.py +++ b/invokeai/backend/flux/trajectory_guidance_extension.py @@ -49,7 +49,6 @@ class TrajectoryGuidanceExtension: """ assert 0.0 <= trajectory_guidance_strength <= 1.0 self._init_latents = init_latents - self._trajectory_guidance_strength = trajectory_guidance_strength if inpaint_mask is None: # The inpaing mask is None, so we initialize a mask with a single value of 1.0. # This value will be broadcasted and treated as a mask of all 1s. @@ -57,6 +56,13 @@ class TrajectoryGuidanceExtension: else: self._inpaint_mask = inpaint_mask + # Calculate the params that define the trajectory guidance schedule. + # These mappings from trajectory_guidance_strength have no theoretical basis - they were tuned manually. + self._trajectory_guidance_strength = trajectory_guidance_strength + self._change_ratio_at_t_1 = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.0)(self._trajectory_guidance_strength) + self._change_ratio_at_cutoff = 1.0 + self._t_cutoff = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.5)(self._trajectory_guidance_strength) + def _apply_mask_gradient_adjustment(self, t_prev: float) -> torch.Tensor: """Applies inpaint mask gradient adjustment and returns the inpaint mask to be used at the current timestep.""" # As we progress through the denoising process, we promote gradient regions of the mask to have a full weight of @@ -75,24 +81,25 @@ class TrajectoryGuidanceExtension: return mask + def _get_change_ratio(self, t_prev: float) -> float: + """Get the change_ratio for t_prev based on the change schedule.""" + change_ratio = 1.0 + if t_prev > self._t_cutoff: + # If we are before the cutoff, linearly interpolate between the change_ratio at t=1.0 and the change_ratio + # at the cutoff. + change_ratio = build_line( + x1=1.0, y1=self._change_ratio_at_t_1, x2=self._t_cutoff, y2=self._change_ratio_at_cutoff + )(t_prev) + + return change_ratio + def step( self, t_curr_latents: torch.Tensor, pred_noise: torch.Tensor, t_curr: float, t_prev: float ) -> torch.Tensor: # Handle gradient cutoff. mask = self._apply_mask_gradient_adjustment(t_prev) - # Calculate the change_ratio based on the trajectory_guidance_strength. - # These mappings from trajectory_guidance_strength have no theoretical basis - they were tuned manually. - change_ratio_at_t_1 = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.0)(self._trajectory_guidance_strength) - change_ratio_at_cutoff = 1.0 - t_cutoff = build_line(x1=0.0, y1=1.0, x2=1.0, y2=0.5)(self._trajectory_guidance_strength) - change_ratio = 1.0 - if t_prev > t_cutoff: - # If we are before the cutoff, linearly interpolate between the change_ratio at t=1.0 and the change_ratio - # at the cutoff. - change_ratio = build_line(x1=1.0, y1=change_ratio_at_t_1, x2=t_cutoff, y2=change_ratio_at_cutoff)(t_prev) - - mask = mask * change_ratio + mask = mask * self._get_change_ratio(t_prev) # NOTE(ryand): During inpainting, it is common to guide the denoising process by noising the initial latents for # the current timestep and then blending the predicted intermediate latents with the noised initial latents.