From bcb85e100db45afea6441fe4ca617bcf4372572a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 8 Feb 2024 00:36:53 +1100 Subject: [PATCH] tests: fix broken tests --- .../object_serializer_ephemeral_disk.py | 9 ++++++--- .../object_serializer_forward_cache.py | 10 ++++++---- tests/aa_nodes/test_graph_execution_state.py | 5 +++-- tests/aa_nodes/test_invoker.py | 3 ++- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py index 9545d1714d..880848a142 100644 --- a/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py +++ b/invokeai/app/services/object_serializer/object_serializer_ephemeral_disk.py @@ -1,15 +1,18 @@ import typing from dataclasses import dataclass from pathlib import Path -from typing import Optional, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar import torch -from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError from invokeai.app.util.misc import uuid_string +if TYPE_CHECKING: + from invokeai.app.services.invoker import Invoker + + T = TypeVar("T") @@ -31,7 +34,7 @@ class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]): self._output_dir.mkdir(parents=True, exist_ok=True) self.__obj_class_name: Optional[str] = None - def start(self, invoker: Invoker) -> None: + def start(self, invoker: "Invoker") -> None: self._invoker = invoker delete_all_result = self._delete_all() if delete_all_result.deleted_count > 0: diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py index 2a4ecdd844..c8ca13982c 100644 --- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -1,11 +1,13 @@ from queue import Queue -from typing import Optional, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar -from invokeai.app.services.invoker import Invoker from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase T = TypeVar("T") +if TYPE_CHECKING: + from invokeai.app.services.invoker import Invoker + class ObjectSerializerForwardCache(ObjectSerializerBase[T]): """Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size.""" @@ -17,13 +19,13 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]): self._cache_ids = Queue[str]() self._max_cache_size = max_cache_size - def start(self, invoker: Invoker) -> None: + def start(self, invoker: "Invoker") -> None: self._invoker = invoker start_op = getattr(self._underlying_storage, "start", None) if callable(start_op): start_op(invoker) - def stop(self, invoker: Invoker) -> None: + def stop(self, invoker: "Invoker") -> None: self._invoker = invoker stop_op = getattr(self._underlying_storage, "stop", None) if callable(stop_op): diff --git a/tests/aa_nodes/test_graph_execution_state.py b/tests/aa_nodes/test_graph_execution_state.py index aba7c5694f..27d2d2230a 100644 --- a/tests/aa_nodes/test_graph_execution_state.py +++ b/tests/aa_nodes/test_graph_execution_state.py @@ -60,7 +60,6 @@ def mock_services() -> InvocationServices: image_records=None, # type: ignore images=None, # type: ignore invocation_cache=MemoryInvocationCache(max_cache_size=0), - latents=None, # type: ignore logger=logging, # type: ignore model_manager=None, # type: ignore model_records=None, # type: ignore @@ -74,6 +73,8 @@ def mock_services() -> InvocationServices: session_queue=None, # type: ignore urls=None, # type: ignore workflow_records=None, # type: ignore + tensors=None, + conditioning=None, ) @@ -89,7 +90,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B config=None, context_data=None, images=None, - latents=None, + tensors=None, logger=None, models=None, util=None, diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index 2ae4eab58a..437ea0f00d 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -63,7 +63,6 @@ def mock_services() -> InvocationServices: image_records=None, # type: ignore images=None, # type: ignore invocation_cache=MemoryInvocationCache(max_cache_size=0), - latents=None, # type: ignore logger=logging, # type: ignore model_manager=None, # type: ignore model_records=None, # type: ignore @@ -77,6 +76,8 @@ def mock_services() -> InvocationServices: session_queue=None, # type: ignore urls=None, # type: ignore workflow_records=None, # type: ignore + tensors=None, + conditioning=None, )