diff --git a/tests/test_node_graph.py b/tests/test_node_graph.py index b0358c08ba..0a4ce77538 100644 --- a/tests/test_node_graph.py +++ b/tests/test_node_graph.py @@ -9,7 +9,9 @@ from invokeai.app.invocations.baseinvocation import ( invocation, invocation_output, ) +from invokeai.app.invocations.math import AddInvocation from invokeai.app.invocations.primitives import ( + ColorInvocation, FloatCollectionInvocation, FloatInvocation, IntegerInvocation, @@ -689,9 +691,6 @@ def test_any_accepts_any(): def test_iterate_accepts_collection(): - """We need to update the validation for Collect -> Iterate to traverse to the Iterate - node's output and compare that against the item type of the Collect node's collection. Until - then, Collect nodes may not output into Iterate nodes.""" g = Graph() n1 = IntegerInvocation(id="1", value=1) n2 = IntegerInvocation(id="2", value=2) @@ -706,9 +705,36 @@ def test_iterate_accepts_collection(): e3 = create_edge(n3.id, "collection", n4.id, "collection") g.add_edge(e1) g.add_edge(e2) - # Once we fix the validation logic as described, this should should not raise an error - with pytest.raises(InvalidEdgeError, match="Cannot connect collector to iterator"): - g.add_edge(e3) + g.add_edge(e3) + + +def test_iterate_validates_collection_inputs_against_iterator_outputs(): + g = Graph() + n1 = IntegerInvocation(id="1", value=1) + n2 = IntegerInvocation(id="2", value=2) + n3 = CollectInvocation(id="3") + n4 = IterateInvocation(id="4") + n5 = AddInvocation(id="5") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + g.add_node(n4) + g.add_node(n5) + e1 = create_edge(n1.id, "value", n3.id, "item") + e2 = create_edge(n2.id, "value", n3.id, "item") + e3 = create_edge(n3.id, "collection", n4.id, "collection") + e4 = create_edge(n4.id, "item", n5.id, "a") + g.add_edge(e1) + g.add_edge(e2) + g.add_edge(e3) + # Not throwing on this line indicates the collector's input types validated successfully against the iterator's output types + g.add_edge(e4) + with pytest.raises(InvalidEdgeError, match="Iterator collection type must match all iterator output types"): + # Connect iterator to a node with a different type than the collector inputs which is not allowed + n6 = ColorInvocation(id="6") + g.add_node(n6) + e5 = create_edge(n4.id, "item", n6.id, "color") + g.add_edge(e5) def test_graph_can_generate_schema():