diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index 41ca93551a..a8a6590f68 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -3,8 +3,6 @@ import threading import pytest -from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache - # This import must happen before other invoke imports or test in other files(!!) break from .test_nodes import ( # isort: split PromptCollectionTestInvocation, @@ -17,7 +15,9 @@ import sqlite3 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation +from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig from invokeai.app.services.graph import CollectInvocation, Graph, GraphExecutionState, IterateInvocation, LibraryGraph +from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache from invokeai.app.services.invocation_queue import MemoryInvocationQueue from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_stats import InvocationStatsService @@ -61,7 +61,7 @@ def mock_services() -> InvocationServices: graph_execution_manager=graph_execution_manager, performance_statistics=InvocationStatsService(graph_execution_manager), processor=DefaultInvocationProcessor(), - configuration=None, # type: ignore + configuration=InvokeAIAppConfig(node_cache_size=0), # type: ignore session_queue=None, # type: ignore session_processor=None, # type: ignore invocation_cache=MemoryInvocationCache(), # type: ignore diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 7dc5cf57b3..c3b508f675 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -4,6 +4,8 @@ import threading import pytest +from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig + # This import must happen before other invoke imports or test in other files(!!) break from .test_nodes import ( # isort: split ErrorInvocation, @@ -14,7 +16,6 @@ from .test_nodes import ( # isort: split wait_until, ) -from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache from invokeai.app.services.invocation_queue import MemoryInvocationQueue @@ -70,10 +71,10 @@ def mock_services() -> InvocationServices: graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), performance_statistics=InvocationStatsService(graph_execution_manager), - configuration=InvokeAIAppConfig(), + configuration=InvokeAIAppConfig(node_cache_size=0), # type: ignore session_queue=None, # type: ignore session_processor=None, # type: ignore - invocation_cache=MemoryInvocationCache(), + invocation_cache=MemoryInvocationCache(max_cache_size=0), ) @@ -102,7 +103,7 @@ def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph): # @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor") def test_can_invoke(mock_invoker: Invoker, simple_graph): g = mock_invoker.create_execution_state(graph=simple_graph) - invocation_id = mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g) + invocation_id = mock_invoker.invoke(queue_item_id=1, queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g) assert invocation_id is not None def has_executed_any(g: GraphExecutionState): @@ -120,7 +121,7 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph): def test_can_invoke_all(mock_invoker: Invoker, simple_graph): g = mock_invoker.create_execution_state(graph=simple_graph) invocation_id = mock_invoker.invoke( - queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True + queue_item_id=1, queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True ) assert invocation_id is not None @@ -140,7 +141,7 @@ def test_handles_errors(mock_invoker: Invoker): g = mock_invoker.create_execution_state() g.graph.add_node(ErrorInvocation(id="1")) - mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True) + mock_invoker.invoke(queue_item_id=1, queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True) def has_executed_all(g: GraphExecutionState): g = mock_invoker.services.graph_execution_manager.get(g.id)