mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-03-02 04:59:06 +01:00
tests: fix test for breaking pydantic v2.12 change
Fixes a test failure introduced by https://github.com/pydantic/pydantic/pull/11957 TL;DR: "after" model validators should be instance methods, not class methods. Batch model updated to use an instance method, which fixes the failing test.
This commit is contained in:
parent
c0469ef633
commit
25f8ab24aa
@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
import json
|
||||
from itertools import chain, product
|
||||
from typing import Generator, Literal, Optional, TypeAlias, Union, cast
|
||||
from typing import Generator, Literal, Optional, TypeAlias, Union
|
||||
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
@ -15,7 +15,6 @@ from pydantic import (
|
||||
)
|
||||
from pydantic_core import to_jsonable_python
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
@ -137,20 +136,18 @@ class Batch(BaseModel):
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_batch_nodes_and_edges(cls, values):
|
||||
batch_data_collection = cast(Optional[BatchDataCollection], values.data)
|
||||
if batch_data_collection is None:
|
||||
return values
|
||||
graph = cast(Graph, values.graph)
|
||||
for batch_data_list in batch_data_collection:
|
||||
def validate_batch_nodes_and_edges(self):
|
||||
if self.data is None:
|
||||
return self
|
||||
for batch_data_list in self.data:
|
||||
for batch_data in batch_data_list:
|
||||
try:
|
||||
node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
|
||||
node = self.graph.get_node(batch_data.node_path)
|
||||
except NodeNotFoundError:
|
||||
raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
|
||||
if batch_data.field_name not in type(node).model_fields:
|
||||
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
|
||||
return values
|
||||
return self
|
||||
|
||||
@field_validator("graph")
|
||||
def validate_graph(cls, v: Graph):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user