mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-04-07 15:35:07 +02:00
[feat] Make model prober recognize yet another LoRA format (#5296)
## What type of PR is this? (check all applicable)
- [ ] Refactor
- [X] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission
## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:
## Have you updated all relevant documentation?
- [X] Yes
- [ ] No
## Description
This adds a probe for the SDXL LoRA format found in the wild at
https://civitai.com/models/224641.
## Related Tickets & Documents
<!--
For pull requests that relate or close an issue, please include them
below.
For example having the text: "closes #1234" would connect the current
pull
request to issue 1234. And when we merge the pull request, Github will
automatically close the issue.
-->
See discord message at:
https://discord.com/channels/1020123559063990373/1149510134058471514/1184982133912113182
## QA Instructions, Screenshots, Recordings
Try installing the SDXL LoRA at the URL given above.
## Merge Plan
This can be merged when approved.
## Added/updated tests?
- [ ] Yes
- [X] No : we do not yet have a comprehensive suite of models to test
probing on.
## [optional] Are there any post deployment tasks we need to perform?
This commit is contained in:
commit
fc150acde5
@ -9,7 +9,7 @@ def lora_token_vector_length(checkpoint: dict) -> int:
|
||||
:param checkpoint: The checkpoint
|
||||
"""
|
||||
|
||||
def _get_shape_1(key, tensor, checkpoint):
|
||||
def _get_shape_1(key: str, tensor, checkpoint) -> int:
|
||||
lora_token_vector_length = None
|
||||
|
||||
if "." not in key:
|
||||
@ -57,6 +57,10 @@ def lora_token_vector_length(checkpoint: dict) -> int:
|
||||
for key, tensor in checkpoint.items():
|
||||
if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key):
|
||||
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
||||
elif key.startswith("lora_unet_") and (
|
||||
"time_emb_proj.lora_down" in key
|
||||
): # recognizes format at https://civitai.com/models/224641
|
||||
lora_token_vector_length = _get_shape_1(key, tensor, checkpoint)
|
||||
elif key.startswith("lora_te") and "_self_attn_" in key:
|
||||
tmp_length = _get_shape_1(key, tensor, checkpoint)
|
||||
if key.startswith("lora_te_"):
|
||||
|
||||
@ -400,6 +400,8 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_vector_length == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif token_vector_length == 1280:
|
||||
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
|
||||
elif token_vector_length == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user