InvokeAI/invokeai/backend/anima/conditioning_data.py
4pointoh f0d09c34a8
feat: add Anima model support (#8961)
* feat: add Anima model support

* schema

* image to image

* regional guidance

* loras

* last fixes

* tests

* fix attributions

* fix attributions

* refactor to use diffusers reference

* fix an additional lora type

* some adjustments to follow flux 2 paper implementation

* use t5 from model manager instead of downloading

* make lora identification more reliable

* fix: resolve lint errors in anima module

Remove unused variable, fix import ordering, inline dict() call,
and address minor lint issues across anima-related files.

* Chore Ruff format again

* fix regional guidance error

* fix(anima): validate unexpected keys after strict=False checkpoint loading

Capture the load_state_dict result and raise RuntimeError on unexpected
keys (indicating a corrupted or incompatible checkpoint), while logging
a warning for missing keys (expected for inv_freq buffers regenerated
at runtime).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(anima): make model loader submodel fields required instead of Optional

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(anima): add Classification.Prototype to LoRA loaders, fix exception types

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(anima): fix replace-all in key conversion, warn on DoRA+LoKR, unify grouping functions

- Use key.replace(old, new, 1) in _convert_kohya_unet_key and _convert_kohya_te_key to avoid replacing multiple occurrences
- Upgrade DoRA+LoKR dora_scale strip from logger.debug to logger.warning since it represents data loss
- Replace _group_kohya_keys and _group_by_layer with a single _group_keys_by_layer function parameterized by extra_suffixes, with _KOHYA_KNOWN_SUFFIXES and _PEFT_EXTRA_SUFFIXES constants
- Add test_empty_state_dict_returns_empty_model to verify empty input produces a model with no layers

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(anima): add safety cap for Qwen3 sequence length to prevent OOM

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(anima): add denoising range validation, fix closure capture, add edge case tests

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(anima): add T5 to metadata, fix dead code, decouple scheduler type guard

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(anima): update VAE field description for required field

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* chore: regenerate frontend types after upstream merge

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* chore: ruff format anima_denoise.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix(anima): add T5 encoder metadata recall handler

The T5 encoder was added to generation metadata but had no recall
handler, so it wasn't restored when recalling from metadata.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* chore(frontend): add regression test for buildAnimaGraph

Add tests for CFG gating (negative conditioning omitted when cfgScale <= 1)
and basic graph structure (model loader, text encoder, denoise nodes).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* only show 0.6b for anima

* dont show 0.6b for other models

* schema

* Anima preview 3

* fix ci

---------

Co-authored-by: Your Name <you@example.com>
Co-authored-by: kappacommit <samwolfe40@gmail.com>
Co-authored-by: Alexander Eichhorn <alex@eichhorn.dev>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-04-09 12:04:11 -04:00


"""Anima text conditioning data structures.
Anima uses a dual-conditioning scheme:
- Qwen3 0.6B hidden states (continuous embeddings)
- T5-XXL token IDs (discrete IDs, embedded by the LLM Adapter inside the transformer)
Both are produced by the text encoder invocation and stored together.
For regional prompting, multiple conditionings (each with an optional spatial mask)
are concatenated and processed together. The LLM Adapter runs on each region's
conditioning separately, producing per-region context vectors that are concatenated
for the DiT's cross-attention layers. An attention mask restricts which image tokens
attend to which regional context tokens.
"""
from dataclasses import dataclass
import torch
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
@dataclass
class AnimaTextConditioning:
    """Anima text conditioning with Qwen3 hidden states, T5-XXL token IDs, and an optional mask.

    Attributes:
        qwen3_embeds: Text embeddings from the Qwen3 0.6B encoder.
            Shape: (seq_len, hidden_size) where hidden_size=1024.
        t5xxl_ids: T5-XXL token IDs for the same prompt.
            Shape: (seq_len,).
        t5xxl_weights: Per-token weights for prompt weighting.
            Shape: (seq_len,). Defaults to all ones if not provided.
        mask: Optional binary mask for regional prompting. If None, the prompt is global.
            Shape: (1, 1, img_seq_len) where img_seq_len = (H // patch_size) * (W // patch_size).
    """

    qwen3_embeds: torch.Tensor
    t5xxl_ids: torch.Tensor
    t5xxl_weights: torch.Tensor | None = None
    mask: torch.Tensor | None = None


@dataclass
class AnimaRegionalTextConditioning:
    """Container for multiple regional text conditionings processed by the LLM Adapter.

    After the LLM Adapter processes each region's conditioning, the outputs are concatenated.
    The DiT cross-attention then uses an attention mask to restrict which image tokens
    attend to which region's context tokens.

    Attributes:
        context_embeds: Concatenated LLM Adapter outputs from all regional prompts.
            Shape: (total_context_len, 1024).
        image_masks: List of binary masks, one per regional prompt.
            A None entry means that prompt is global (applies to the entire image).
            Each mask has shape (1, 1, img_seq_len).
        context_ranges: List of ranges indicating which portion of context_embeds
            corresponds to each regional prompt.
    """

    context_embeds: torch.Tensor
    image_masks: list[torch.Tensor | None]
    context_ranges: list[Range]