tests(nodes): add test for collect -> iterate type validation

This commit is contained in:
psychedelicious 2025-02-05 10:41:59 +11:00
parent 8556a2558e
commit d66cd4e81b

View File

@ -9,7 +9,9 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.math import AddInvocation
from invokeai.app.invocations.primitives import (
ColorInvocation,
FloatCollectionInvocation,
FloatInvocation,
IntegerInvocation,
@ -689,9 +691,6 @@ def test_any_accepts_any():
def test_iterate_accepts_collection():
"""We need to update the validation for Collect -> Iterate to traverse to the Iterate
node's output and compare that against the item type of the Collect node's collection. Until
then, Collect nodes may not output into Iterate nodes."""
g = Graph()
n1 = IntegerInvocation(id="1", value=1)
n2 = IntegerInvocation(id="2", value=2)
@ -706,9 +705,36 @@ def test_iterate_accepts_collection():
e3 = create_edge(n3.id, "collection", n4.id, "collection")
g.add_edge(e1)
g.add_edge(e2)
# Once we fix the validation logic as described, this should should not raise an error
with pytest.raises(InvalidEdgeError, match="Cannot connect collector to iterator"):
g.add_edge(e3)
g.add_edge(e3)
def test_iterate_validates_collection_inputs_against_iterator_outputs():
g = Graph()
n1 = IntegerInvocation(id="1", value=1)
n2 = IntegerInvocation(id="2", value=2)
n3 = CollectInvocation(id="3")
n4 = IterateInvocation(id="4")
n5 = AddInvocation(id="5")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
g.add_node(n4)
g.add_node(n5)
e1 = create_edge(n1.id, "value", n3.id, "item")
e2 = create_edge(n2.id, "value", n3.id, "item")
e3 = create_edge(n3.id, "collection", n4.id, "collection")
e4 = create_edge(n4.id, "item", n5.id, "a")
g.add_edge(e1)
g.add_edge(e2)
g.add_edge(e3)
# Not throwing on this line indicates the collector's input types validated successfully against the iterator's output types
g.add_edge(e4)
with pytest.raises(InvalidEdgeError, match="Iterator collection type must match all iterator output types"):
# Connect iterator to a node with a different type than the collector inputs which is not allowed
n6 = ColorInvocation(id="6")
g.add_node(n6)
e5 = create_edge(n4.id, "item", n6.id, "color")
g.add_edge(e5)
def test_graph_can_generate_schema():