From 50cb27cd0b6b232d707dbfe72992bac7de655ec4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 5 Feb 2025 10:42:14 +1100 Subject: [PATCH] feat(nodes): support collect -> iterate node connections w/ validation --- invokeai/app/services/shared/graph.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 0b86d6b13f..2d425a7515 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -529,10 +529,6 @@ class Graph(BaseModel): if err is not None: raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}") - # Validate that we are not connecting collector to iterator (currently unsupported) - if isinstance(from_node, CollectInvocation) and isinstance(to_node, IterateInvocation): - raise InvalidEdgeError(f"Cannot connect collector to iterator ({edge})") - # Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any] if ( isinstance(from_node, CollectInvocation) @@ -638,8 +634,10 @@ class Graph(BaseModel): if len(inputs) > 1: return "Iterator may only have one input edge" + input_node = self.get_node(inputs[0].node_id) + # Get input and output fields (the fields linked to the iterator's input/output) - input_field_type = get_output_field_type(self.get_node(inputs[0].node_id), inputs[0].field) + input_field_type = get_output_field_type(input_node, inputs[0].field) output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs] # Input type must be a list @@ -651,6 +649,22 @@ class Graph(BaseModel): if not all((are_connection_types_compatible(input_field_item_type, t) for t in output_field_types)): return "Iterator outputs must connect to an input with a matching type" + # Collector input type must match all iterator output types + if isinstance(input_node, CollectInvocation): + # Traverse the graph to find the first collector input edge. Collectors validate that their collection + # inputs are all of the same type, so we can use the first input edge to determine the collector's type + first_collector_input_edge = self._get_input_edges(input_node.id, "item")[0] + first_collector_input_type = get_output_field_type( + self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field + ) + resolved_collector_type = ( + first_collector_input_type + if get_origin(first_collector_input_type) is None + else get_args(first_collector_input_type) + ) + if not all((are_connection_types_compatible(resolved_collector_type, t) for t in output_field_types)): + return "Iterator collection type must match all iterator output types" + return None def _is_collector_connection_valid(