feat(nodes): support collect -> iterate node connections w/ validation

This commit is contained in:
psychedelicious 2025-02-05 10:42:14 +11:00
parent d66cd4e81b
commit 50cb27cd0b

View File

@ -529,10 +529,6 @@ class Graph(BaseModel):
if err is not None:
raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}")
# Validate that we are not connecting collector to iterator (currently unsupported)
if isinstance(from_node, CollectInvocation) and isinstance(to_node, IterateInvocation):
raise InvalidEdgeError(f"Cannot connect collector to iterator ({edge})")
# Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any]
if (
isinstance(from_node, CollectInvocation)
@ -638,8 +634,10 @@ class Graph(BaseModel):
if len(inputs) > 1:
return "Iterator may only have one input edge"
input_node = self.get_node(inputs[0].node_id)
# Get input and output fields (the fields linked to the iterator's input/output)
input_field_type = get_output_field_type(self.get_node(inputs[0].node_id), inputs[0].field)
input_field_type = get_output_field_type(input_node, inputs[0].field)
output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]
# Input type must be a list
@ -651,6 +649,22 @@ class Graph(BaseModel):
if not all((are_connection_types_compatible(input_field_item_type, t) for t in output_field_types)):
return "Iterator outputs must connect to an input with a matching type"
# Collector input type must match all iterator output types
if isinstance(input_node, CollectInvocation):
# Traverse the graph to find the first collector input edge. Collectors validate that their collection
# inputs are all of the same type, so we can use the first input edge to determine the collector's type
first_collector_input_edge = self._get_input_edges(input_node.id, "item")[0]
first_collector_input_type = get_output_field_type(
self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
)
resolved_collector_type = (
first_collector_input_type
if get_origin(first_collector_input_type) is None
else get_args(first_collector_input_type)
)
if not all((are_connection_types_compatible(resolved_collector_type, t) for t in output_field_types)):
return "Iterator collection type must match all iterator output types"
return None
def _is_collector_connection_valid(