fix issues with module import order breaking pytest node tests

This commit is contained in:
Lincoln Stein
2023-10-09 22:43:00 -04:00
parent 4bab724288
commit 67607f053d
17 changed files with 71 additions and 69 deletions

View File

@ -19,6 +19,7 @@ from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
from ..services.default_graphs import create_system_graphs
from ..services.download_manager import DownloadQueueService
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_file_storage import DiskImageFileStorage
from ..services.invocation_queue import MemoryInvocationQueue
@ -26,7 +27,6 @@ from ..services.invocation_services import InvocationServices
from ..services.invocation_stats import InvocationStatsService
from ..services.invoker import Invoker
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..services.download_manager import DownloadQueueService
from ..services.model_install_service import ModelInstallService
from ..services.model_loader_service import ModelLoadService
from ..services.model_record_service import ModelRecordServiceBase
@ -133,11 +133,7 @@ class ApiDependencies:
download_queue = DownloadQueueService(event_bus=events, config=config)
model_record_store = ModelRecordServiceBase.get_impl(config, conn=db_conn, lock=lock)
model_loader = ModelLoadService(config, model_record_store)
model_installer = ModelInstallService(config,
queue=download_queue,
store=model_record_store,
event_bus=events
)
model_installer = ModelInstallService(config, queue=download_queue, store=model_record_store, event_bus=events)
services = InvocationServices(
events=events,

View File

@ -11,19 +11,19 @@ from pydantic import BaseModel, parse_obj_as
from starlette.exceptions import HTTPException
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.download_manager import DownloadJobRemoteSource, DownloadJobStatus, UnknownJobIDException
from invokeai.app.services.model_convert import MergeInterpolationMethod, ModelConvert
from invokeai.app.services.model_install_service import ModelInstallJob
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_manager import (
OPENAPI_MODEL_CONFIGS,
DuplicateModelException,
InvalidModelException,
ModelConfigBase,
ModelSearch,
SchedulerPredictionType,
UnknownModelException,
ModelSearch
)
from invokeai.app.services.download_manager import DownloadJobStatus, UnknownJobIDException, DownloadJobRemoteSource
from invokeai.app.services.model_convert import MergeInterpolationMethod, ModelConvert
from invokeai.app.services.model_install_service import ModelInstallJob
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -58,6 +58,7 @@ class JobControlOperation(str, Enum):
CANCEL = "Cancel"
CHANGE_PRIORITY = "Change Priority"
@models_router.get(
"/",
operation_id="list_models",
@ -292,7 +293,7 @@ async def convert_model(
converter = ModelConvert(
loader=ApiDependencies.invoker.services.model_loader,
installer=ApiDependencies.invoker.services.model_installer,
store=ApiDependencies.invoker.services.model_record_store
store=ApiDependencies.invoker.services.model_record_store,
)
model_config = converter.convert_model(key, dest_directory=dest)
response = parse_obj_as(InvokeAIModelConfig, model_config.dict())
@ -323,6 +324,7 @@ async def search_for_models(
)
return ModelSearch().search(search_path)
@models_router.get(
"/ckpt_confs",
operation_id="list_ckpt_configs",
@ -394,7 +396,7 @@ async def merge_models(
converter = ModelConvert(
loader=ApiDependencies.invoker.services.model_loader,
installer=ApiDependencies.invoker.services.model_installer,
store=ApiDependencies.invoker.services.model_record_store
store=ApiDependencies.invoker.services.model_record_store,
)
result: ModelConfigBase = converter.merge_models(
model_keys=keys,
@ -437,7 +439,8 @@ async def list_install_jobs() -> List[ModelDownloadStatus]:
total_bytes=x.total_bytes,
status=x.status,
)
for x in jobs if isinstance(x, ModelInstallJob)
for x in jobs
if isinstance(x, ModelInstallJob)
]

View File

@ -4,6 +4,7 @@ from typing import List, Optional
from pydantic import BaseModel, Field
from invokeai.backend.model_manager import SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,

View File

@ -171,6 +171,7 @@ two configs are kept in separate sections of the config file:
from __future__ import annotations
import os
import sys
from pathlib import Path
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints

View File

@ -5,10 +5,11 @@ Model download service.
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from typing import TYPE_CHECKING, Any, List, Optional, Union
from pydantic.networks import AnyHttpUrl
from invokeai.backend.model_manager.download import DownloadJobRemoteSource # noqa F401
from invokeai.backend.model_manager.download import ( # noqa F401
DownloadEventHandler,
DownloadJobBase,
@ -19,10 +20,9 @@ from invokeai.backend.model_manager.download import ( # noqa F401
ModelSourceMetadata,
UnknownJobIDException,
)
from invokeai.backend.model_manager.download import DownloadJobRemoteSource # noqa F401
from .events import EventServiceBase
if TYPE_CHECKING:
from .events import EventServiceBase
class DownloadQueueServiceBase(ABC):
@ -146,10 +146,10 @@ class DownloadQueueServiceBase(ABC):
class DownloadQueueService(DownloadQueueServiceBase):
"""Multithreaded queue for downloading models via URL or repo_id."""
_event_bus: EventServiceBase
_event_bus: Optional["EventServiceBase"] = None
_queue: DownloadQueueBase
def __init__(self, event_bus: EventServiceBase, **kwargs):
def __init__(self, event_bus: Optional["EventServiceBase"] = None, **kwargs):
"""
Initialize new DownloadQueueService object.

View File

@ -9,6 +9,7 @@ if TYPE_CHECKING:
from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download_manager import DownloadQueueServiceBase
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.images import ImageServiceABC
@ -18,7 +19,6 @@ if TYPE_CHECKING:
from invokeai.app.services.invoker import InvocationProcessorABC
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.download_manager import DownloadQueueServiceBase
from invokeai.app.services.model_install_service import ModelInstallServiceBase
from invokeai.app.services.model_loader_service import ModelLoadServiceBase
from invokeai.app.services.model_record_service import ModelRecordServiceBase

View File

@ -7,17 +7,18 @@ Convert and merge models.
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, List
from pydantic import Field
from pathlib import Path
from shutil import move, rmtree
from typing import List, Optional
from pydantic import Field
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from .config import InvokeAIAppConfig
from .model_install_service import ModelInstallServiceBase
from .model_loader_service import ModelLoadServiceBase, ModelInfo
from .model_record_service import ModelRecordServiceBase, ModelConfigBase, ModelType, SubModelType
from .model_loader_service import ModelInfo, ModelLoadServiceBase
from .model_record_service import ModelConfigBase, ModelRecordServiceBase, ModelType, SubModelType
class ModelConvertBase(ABC):
@ -25,19 +26,19 @@ class ModelConvertBase(ABC):
@abstractmethod
def __init__(
cls,
loader: ModelLoadServiceBase,
installer: ModelInstallServiceBase,
store: ModelRecordServiceBase,
cls,
loader: ModelLoadServiceBase,
installer: ModelInstallServiceBase,
store: ModelRecordServiceBase,
):
"""Initialize ModelConvert with loader, installer and configuration store."""
pass
@abstractmethod
def convert_model(
self,
key: str,
dest_directory: Optional[Path] = None,
self,
key: str,
dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder.
@ -75,14 +76,15 @@ class ModelConvertBase(ABC):
"""
pass
class ModelConvert(ModelConvertBase):
"""Implementation of ModelConvertBase."""
def __init__(
self,
loader: ModelLoadServiceBase,
installer: ModelInstallServiceBase,
store: ModelRecordServiceBase,
self,
loader: ModelLoadServiceBase,
installer: ModelInstallServiceBase,
store: ModelRecordServiceBase,
):
"""Initialize ModelConvert with loader, installer and configuration store."""
self.loader = loader
@ -90,9 +92,9 @@ class ModelConvert(ModelConvertBase):
self.store = store
def convert_model(
self,
key: str,
dest_directory: Optional[Path] = None,
self,
key: str,
dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder.

View File

@ -5,14 +5,13 @@ import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import move, rmtree
from typing import Any, Callable, Dict, List, Optional, Set, Union, Literal
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Set, Union
from pydantic import Field, parse_obj_as
from pydantic import Field
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.backend.util import Chdir, InvokeAILogger, Logger
from invokeai.backend import get_precision
from invokeai.backend.model_manager.config import (
BaseModelType,
@ -34,13 +33,17 @@ from invokeai.backend.model_manager.models import InvalidModelException
from invokeai.backend.model_manager.probe import ModelProbe, ModelProbeInfo
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.storage import DuplicateModelException, ModelConfigStore
from .events import EventServiceBase
from invokeai.backend.util import Chdir, InvokeAILogger, Logger
if TYPE_CHECKING:
from .events import EventServiceBase
from .download_manager import (
DownloadQueueServiceBase,
DownloadQueueService,
DownloadEventHandler,
DownloadJobBase,
DownloadJobPath,
DownloadEventHandler,
DownloadQueueService,
DownloadQueueServiceBase,
ModelSourceMetadata,
)
@ -81,7 +84,7 @@ class ModelInstallServiceBase(ABC):
config: Optional[InvokeAIAppConfig] = None,
queue: Optional[DownloadQueueServiceBase] = None,
store: Optional[ModelRecordServiceBase] = None,
event_bus: Optional[EventServiceBase] = None,
event_bus: Optional["EventServiceBase"] = None,
):
"""
Create ModelInstallService object.
@ -227,6 +230,7 @@ class ModelInstallServiceBase(ABC):
"""
pass
class ModelInstallService(ModelInstallServiceBase):
"""Model installer class handles installation from a local path."""
@ -239,7 +243,7 @@ class ModelInstallService(ModelInstallServiceBase):
_tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads
_cached_model_paths: Set[Path] = Field(default=set) # used to speed up directory scanning
_precision: Literal["float16", "float32"] = Field(description="Floating point precision, string form")
_event_bus: Optional[EventServiceBase] = Field(description="an event bus to send install events to", default=None)
_event_bus: Optional["EventServiceBase"] = Field(description="an event bus to send install events to", default=None)
_legacy_configs: Dict[BaseModelType, Dict[ModelVariantType, Union[str, dict]]] = {
BaseModelType.StableDiffusion1: {
@ -269,7 +273,7 @@ class ModelInstallService(ModelInstallServiceBase):
config: Optional[InvokeAIAppConfig] = None,
queue: Optional[DownloadQueueServiceBase] = None,
store: Optional[ModelRecordServiceBase] = None,
event_bus: Optional[EventServiceBase] = None,
event_bus: Optional["EventServiceBase"] = None,
event_handlers: List[DownloadEventHandler] = [],
): # noqa D107 - use base class docstrings
self._app_config = config or InvokeAIAppConfig.get_config()
@ -281,10 +285,7 @@ class ModelInstallService(ModelInstallServiceBase):
if self._event_bus:
self._handlers.append(self._event_bus.emit_model_event)
self._download_queue = queue or DownloadQueueService(
event_bus=event_bus,
config=self._app_config
)
self._download_queue = queue or DownloadQueueService(event_bus=event_bus, config=self._app_config)
self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
self._installed = set()
self._tmpdir = None
@ -318,7 +319,7 @@ class ModelInstallService(ModelInstallServiceBase):
probe_override: Optional[Dict[str, Any]] = None,
metadata: Optional[ModelSourceMetadata] = None,
access_token: Optional[str] = None,
) -> DownloadJobBase: # noqa D102
) -> ModelInstallJob: # noqa D102
queue = self._download_queue
variant = variant or ("fp16" if self._precision == "float16" else None)

View File

@ -3,15 +3,15 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from pydantic import Field
from pathlib import Path
from invokeai.app.models.exceptions import CanceledException
from invokeai.backend.model_manager import ModelConfigStore, SubModelType
from invokeai.backend.model_manager.cache import CacheStats
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad, ModelConfigBase
from invokeai.backend.model_manager.loader import ModelConfigBase, ModelInfo, ModelLoad
from .config import InvokeAIAppConfig
from .model_record_service import ModelRecordServiceBase
@ -58,8 +58,6 @@ class ModelLoadServiceBase(ABC):
"""Reset model cache statistics for graph with graph_id."""
pass
# implementation
class ModelLoadService(ModelLoadServiceBase):
@ -140,4 +138,3 @@ class ModelLoadService(ModelLoadServiceBase):
model_key=model_key,
submodel=submodel,
)

View File

@ -8,8 +8,7 @@ from abc import abstractmethod
from pathlib import Path
from typing import Optional, Union
from invokeai.backend.model_manager import ModelConfigBase, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager import BaseModelType, ModelConfigBase, ModelType, SubModelType
from invokeai.backend.model_manager.storage import (
ModelConfigStore,
ModelConfigStoreSQL,

View File

@ -12,9 +12,9 @@ from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_install_service import ModelInstall, ModelInstallJob, ModelSourceMetadata
from invokeai.backend.model_manager import BaseModelType, ModelType
from invokeai.backend.model_manager.download.queue import DownloadJobRemoteSource
from invokeai.app.services.model_install_service import ModelInstall, ModelInstallJob, ModelSourceMetadata
# name of the starter models file
INITIAL_MODELS = "INITIAL_MODELS.yaml"

View File

@ -8,4 +8,4 @@ from .base import ( # noqa F401
UnknownJobIDException,
)
from .model_queue import ModelDownloadQueue, ModelSourceMetadata # noqa F401
from .queue import DownloadJobPath, DownloadJobURL, DownloadQueue, DownloadJobRemoteSource # noqa F401
from .queue import DownloadJobPath, DownloadJobRemoteSource, DownloadJobURL, DownloadQueue # noqa F401

View File

@ -278,11 +278,11 @@ class ModelDownloadQueue(DownloadQueue):
# including checkpoint files, different EMA versions, etc.
# This filters out just the file types needed for the model
for x in sibs:
if x.rfilename.endswith(('.json', '.txt')):
if x.rfilename.endswith((".json", ".txt")):
paths.append(x.rfilename)
elif x.rfilename.endswith(('learned_embeds.bin', 'ip_adapter.bin')):
elif x.rfilename.endswith(("learned_embeds.bin", "ip_adapter.bin")):
paths.append(x.rfilename)
elif re.search(r'model(\.[^.]+)?\.(safetensors|bin)$', x.rfilename):
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin)$", x.rfilename):
paths.append(x.rfilename)
sizes = {x.rfilename: x.size for x in sibs}

View File

@ -5,8 +5,8 @@ import hashlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union
from shutil import move, rmtree
from typing import Optional, Tuple, Union
import torch

View File

@ -6,8 +6,8 @@ import pytest
import torch
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.app.services.model_install_service import ModelInstallService
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType, UnknownModelException
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad

View File

@ -49,6 +49,7 @@ def mock_services() -> InvocationServices:
conn=db_conn, table_name="graph_executions", lock=lock
)
return InvocationServices(
download_queue=None, # type: ignore
model_loader=None, # type: ignore
model_installer=None, # type: ignore
model_record_store=None, # type: ignore

View File

@ -1,7 +1,8 @@
# conftest.py is a special pytest file. Fixtures defined in this file will be accessible to all tests in this directory
# without needing to explicitly import them. (https://docs.pytest.org/en/6.2.x/fixture.html)
from invokeai.app.services.model_install_service import ModelInstallService # noqa: F401
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
from invokeai.backend.util.test_utils import torch_device # noqa: F401
from invokeai.app.services.model_install_service import ModelInstallService # noqa: F401