mirror of
https://github.com/invoke-ai/InvokeAI
synced 2026-05-03 12:11:58 +02:00
feat(nodes): support collect -> iterate node connections w/ validation
This commit is contained in:
parent
d66cd4e81b
commit
50cb27cd0b
@ -529,10 +529,6 @@ class Graph(BaseModel):
|
||||
if err is not None:
|
||||
raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}")
|
||||
|
||||
# Validate that we are not connecting collector to iterator (currently unsupported)
|
||||
if isinstance(from_node, CollectInvocation) and isinstance(to_node, IterateInvocation):
|
||||
raise InvalidEdgeError(f"Cannot connect collector to iterator ({edge})")
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any]
|
||||
if (
|
||||
isinstance(from_node, CollectInvocation)
|
||||
@ -638,8 +634,10 @@ class Graph(BaseModel):
|
||||
if len(inputs) > 1:
|
||||
return "Iterator may only have one input edge"
|
||||
|
||||
input_node = self.get_node(inputs[0].node_id)
|
||||
|
||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||
input_field_type = get_output_field_type(self.get_node(inputs[0].node_id), inputs[0].field)
|
||||
input_field_type = get_output_field_type(input_node, inputs[0].field)
|
||||
output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]
|
||||
|
||||
# Input type must be a list
|
||||
@ -651,6 +649,22 @@ class Graph(BaseModel):
|
||||
if not all((are_connection_types_compatible(input_field_item_type, t) for t in output_field_types)):
|
||||
return "Iterator outputs must connect to an input with a matching type"
|
||||
|
||||
# Collector input type must match all iterator output types
|
||||
if isinstance(input_node, CollectInvocation):
|
||||
# Traverse the graph to find the first collector input edge. Collectors validate that their collection
|
||||
# inputs are all of the same type, so we can use the first input edge to determine the collector's type
|
||||
first_collector_input_edge = self._get_input_edges(input_node.id, "item")[0]
|
||||
first_collector_input_type = get_output_field_type(
|
||||
self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
|
||||
)
|
||||
resolved_collector_type = (
|
||||
first_collector_input_type
|
||||
if get_origin(first_collector_input_type) is None
|
||||
else get_args(first_collector_input_type)
|
||||
)
|
||||
if not all((are_connection_types_compatible(resolved_collector_type, t) for t in output_field_types)):
|
||||
return "Iterator collection type must match all iterator output types"
|
||||
|
||||
return None
|
||||
|
||||
def _is_collector_connection_valid(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user