mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-03-04 05:59:05 +01:00
tests(nodes): add test for collect -> iterate type validation
This commit is contained in:
parent
8556a2558e
commit
d66cd4e81b
@ -9,7 +9,9 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.math import AddInvocation
|
||||
from invokeai.app.invocations.primitives import (
|
||||
ColorInvocation,
|
||||
FloatCollectionInvocation,
|
||||
FloatInvocation,
|
||||
IntegerInvocation,
|
||||
@ -689,9 +691,6 @@ def test_any_accepts_any():
|
||||
|
||||
|
||||
def test_iterate_accepts_collection():
|
||||
"""We need to update the validation for Collect -> Iterate to traverse to the Iterate
|
||||
node's output and compare that against the item type of the Collect node's collection. Until
|
||||
then, Collect nodes may not output into Iterate nodes."""
|
||||
g = Graph()
|
||||
n1 = IntegerInvocation(id="1", value=1)
|
||||
n2 = IntegerInvocation(id="2", value=2)
|
||||
@ -706,9 +705,36 @@ def test_iterate_accepts_collection():
|
||||
e3 = create_edge(n3.id, "collection", n4.id, "collection")
|
||||
g.add_edge(e1)
|
||||
g.add_edge(e2)
|
||||
# Once we fix the validation logic as described, this should should not raise an error
|
||||
with pytest.raises(InvalidEdgeError, match="Cannot connect collector to iterator"):
|
||||
g.add_edge(e3)
|
||||
g.add_edge(e3)
|
||||
|
||||
|
||||
def test_iterate_validates_collection_inputs_against_iterator_outputs():
|
||||
g = Graph()
|
||||
n1 = IntegerInvocation(id="1", value=1)
|
||||
n2 = IntegerInvocation(id="2", value=2)
|
||||
n3 = CollectInvocation(id="3")
|
||||
n4 = IterateInvocation(id="4")
|
||||
n5 = AddInvocation(id="5")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
g.add_node(n3)
|
||||
g.add_node(n4)
|
||||
g.add_node(n5)
|
||||
e1 = create_edge(n1.id, "value", n3.id, "item")
|
||||
e2 = create_edge(n2.id, "value", n3.id, "item")
|
||||
e3 = create_edge(n3.id, "collection", n4.id, "collection")
|
||||
e4 = create_edge(n4.id, "item", n5.id, "a")
|
||||
g.add_edge(e1)
|
||||
g.add_edge(e2)
|
||||
g.add_edge(e3)
|
||||
# Not throwing on this line indicates the collector's input types validated successfully against the iterator's output types
|
||||
g.add_edge(e4)
|
||||
with pytest.raises(InvalidEdgeError, match="Iterator collection type must match all iterator output types"):
|
||||
# Connect iterator to a node with a different type than the collector inputs which is not allowed
|
||||
n6 = ColorInvocation(id="6")
|
||||
g.add_node(n6)
|
||||
e5 = create_edge(n4.id, "item", n6.id, "color")
|
||||
g.add_edge(e5)
|
||||
|
||||
|
||||
def test_graph_can_generate_schema():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user