tests: fix broken tests

This commit is contained in:
psychedelicious 2024-02-08 00:36:53 +11:00
parent aff44c0e58
commit 6d25789705
4 changed files with 17 additions and 10 deletions

View File

@ -1,15 +1,18 @@
import typing
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, TypeVar
from typing import TYPE_CHECKING, Optional, TypeVar
import torch
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
from invokeai.app.util.misc import uuid_string
if TYPE_CHECKING:
from invokeai.app.services.invoker import Invoker
T = TypeVar("T")
@ -31,7 +34,7 @@ class ObjectSerializerEphemeralDisk(ObjectSerializerBase[T]):
self._output_dir.mkdir(parents=True, exist_ok=True)
self.__obj_class_name: Optional[str] = None
def start(self, invoker: Invoker) -> None:
def start(self, invoker: "Invoker") -> None:
self._invoker = invoker
delete_all_result = self._delete_all()
if delete_all_result.deleted_count > 0:

View File

@ -1,11 +1,13 @@
from queue import Queue
from typing import Optional, TypeVar
from typing import TYPE_CHECKING, Optional, TypeVar
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
T = TypeVar("T")
if TYPE_CHECKING:
from invokeai.app.services.invoker import Invoker
class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
"""Provides a simple forward cache for an underlying storage. The cache is LRU and has a maximum size."""
@ -17,13 +19,13 @@ class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
self._cache_ids = Queue[str]()
self._max_cache_size = max_cache_size
def start(self, invoker: Invoker) -> None:
def start(self, invoker: "Invoker") -> None:
self._invoker = invoker
start_op = getattr(self._underlying_storage, "start", None)
if callable(start_op):
start_op(invoker)
def stop(self, invoker: Invoker) -> None:
def stop(self, invoker: "Invoker") -> None:
self._invoker = invoker
stop_op = getattr(self._underlying_storage, "stop", None)
if callable(stop_op):

View File

@ -60,7 +60,6 @@ def mock_services() -> InvocationServices:
image_records=None, # type: ignore
images=None, # type: ignore
invocation_cache=MemoryInvocationCache(max_cache_size=0),
latents=None, # type: ignore
logger=logging, # type: ignore
model_manager=None, # type: ignore
model_records=None, # type: ignore
@ -74,6 +73,8 @@ def mock_services() -> InvocationServices:
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
tensors=None,
conditioning=None,
)
@ -89,7 +90,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B
config=None,
context_data=None,
images=None,
latents=None,
tensors=None,
logger=None,
models=None,
util=None,

View File

@ -63,7 +63,6 @@ def mock_services() -> InvocationServices:
image_records=None, # type: ignore
images=None, # type: ignore
invocation_cache=MemoryInvocationCache(max_cache_size=0),
latents=None, # type: ignore
logger=logging, # type: ignore
model_manager=None, # type: ignore
model_records=None, # type: ignore
@ -77,6 +76,8 @@ def mock_services() -> InvocationServices:
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
tensors=None,
conditioning=None,
)