mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix issues with module import order breaking pytest node tests
This commit is contained in:
@ -19,6 +19,7 @@ from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
from ..services.download_manager import DownloadQueueService
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_file_storage import DiskImageFileStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
@ -26,7 +27,6 @@ from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
from ..services.download_manager import DownloadQueueService
|
||||
from ..services.model_install_service import ModelInstallService
|
||||
from ..services.model_loader_service import ModelLoadService
|
||||
from ..services.model_record_service import ModelRecordServiceBase
|
||||
@ -133,11 +133,7 @@ class ApiDependencies:
|
||||
download_queue = DownloadQueueService(event_bus=events, config=config)
|
||||
model_record_store = ModelRecordServiceBase.get_impl(config, conn=db_conn, lock=lock)
|
||||
model_loader = ModelLoadService(config, model_record_store)
|
||||
model_installer = ModelInstallService(config,
|
||||
queue=download_queue,
|
||||
store=model_record_store,
|
||||
event_bus=events
|
||||
)
|
||||
model_installer = ModelInstallService(config, queue=download_queue, store=model_record_store, event_bus=events)
|
||||
|
||||
services = InvocationServices(
|
||||
events=events,
|
||||
|
@ -11,19 +11,19 @@ from pydantic import BaseModel, parse_obj_as
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.download_manager import DownloadJobRemoteSource, DownloadJobStatus, UnknownJobIDException
|
||||
from invokeai.app.services.model_convert import MergeInterpolationMethod, ModelConvert
|
||||
from invokeai.app.services.model_install_service import ModelInstallJob
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager import (
|
||||
OPENAPI_MODEL_CONFIGS,
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelConfigBase,
|
||||
ModelSearch,
|
||||
SchedulerPredictionType,
|
||||
UnknownModelException,
|
||||
ModelSearch
|
||||
)
|
||||
from invokeai.app.services.download_manager import DownloadJobStatus, UnknownJobIDException, DownloadJobRemoteSource
|
||||
from invokeai.app.services.model_convert import MergeInterpolationMethod, ModelConvert
|
||||
from invokeai.app.services.model_install_service import ModelInstallJob
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
@ -58,6 +58,7 @@ class JobControlOperation(str, Enum):
|
||||
CANCEL = "Cancel"
|
||||
CHANGE_PRIORITY = "Change Priority"
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
operation_id="list_models",
|
||||
@ -292,7 +293,7 @@ async def convert_model(
|
||||
converter = ModelConvert(
|
||||
loader=ApiDependencies.invoker.services.model_loader,
|
||||
installer=ApiDependencies.invoker.services.model_installer,
|
||||
store=ApiDependencies.invoker.services.model_record_store
|
||||
store=ApiDependencies.invoker.services.model_record_store,
|
||||
)
|
||||
model_config = converter.convert_model(key, dest_directory=dest)
|
||||
response = parse_obj_as(InvokeAIModelConfig, model_config.dict())
|
||||
@ -323,6 +324,7 @@ async def search_for_models(
|
||||
)
|
||||
return ModelSearch().search(search_path)
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/ckpt_confs",
|
||||
operation_id="list_ckpt_configs",
|
||||
@ -394,7 +396,7 @@ async def merge_models(
|
||||
converter = ModelConvert(
|
||||
loader=ApiDependencies.invoker.services.model_loader,
|
||||
installer=ApiDependencies.invoker.services.model_installer,
|
||||
store=ApiDependencies.invoker.services.model_record_store
|
||||
store=ApiDependencies.invoker.services.model_record_store,
|
||||
)
|
||||
result: ModelConfigBase = converter.merge_models(
|
||||
model_keys=keys,
|
||||
@ -437,7 +439,8 @@ async def list_install_jobs() -> List[ModelDownloadStatus]:
|
||||
total_bytes=x.total_bytes,
|
||||
status=x.status,
|
||||
)
|
||||
for x in jobs if isinstance(x, ModelInstallJob)
|
||||
for x in jobs
|
||||
if isinstance(x, ModelInstallJob)
|
||||
]
|
||||
|
||||
|
||||
|
@ -4,6 +4,7 @@ from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
|
@ -171,6 +171,7 @@ two configs are kept in separate sections of the config file:
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
|
||||
|
||||
|
@ -5,10 +5,11 @@ Model download service.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.backend.model_manager.download import DownloadJobRemoteSource # noqa F401
|
||||
from invokeai.backend.model_manager.download import ( # noqa F401
|
||||
DownloadEventHandler,
|
||||
DownloadJobBase,
|
||||
@ -19,10 +20,9 @@ from invokeai.backend.model_manager.download import ( # noqa F401
|
||||
ModelSourceMetadata,
|
||||
UnknownJobIDException,
|
||||
)
|
||||
from invokeai.backend.model_manager.download import DownloadJobRemoteSource # noqa F401
|
||||
|
||||
|
||||
from .events import EventServiceBase
|
||||
if TYPE_CHECKING:
|
||||
from .events import EventServiceBase
|
||||
|
||||
|
||||
class DownloadQueueServiceBase(ABC):
|
||||
@ -146,10 +146,10 @@ class DownloadQueueServiceBase(ABC):
|
||||
class DownloadQueueService(DownloadQueueServiceBase):
|
||||
"""Multithreaded queue for downloading models via URL or repo_id."""
|
||||
|
||||
_event_bus: EventServiceBase
|
||||
_event_bus: Optional["EventServiceBase"] = None
|
||||
_queue: DownloadQueueBase
|
||||
|
||||
def __init__(self, event_bus: EventServiceBase, **kwargs):
|
||||
def __init__(self, event_bus: Optional["EventServiceBase"] = None, **kwargs):
|
||||
"""
|
||||
Initialize new DownloadQueueService object.
|
||||
|
||||
|
@ -9,6 +9,7 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||
from invokeai.app.services.boards import BoardServiceABC
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download_manager import DownloadQueueServiceBase
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||
from invokeai.app.services.images import ImageServiceABC
|
||||
@ -18,7 +19,6 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||
from invokeai.app.services.item_storage import ItemStorageABC
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
from invokeai.app.services.download_manager import DownloadQueueServiceBase
|
||||
from invokeai.app.services.model_install_service import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_loader_service import ModelLoadServiceBase
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
|
@ -7,17 +7,18 @@ Convert and merge models.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import Field
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
from .model_install_service import ModelInstallServiceBase
|
||||
from .model_loader_service import ModelLoadServiceBase, ModelInfo
|
||||
from .model_record_service import ModelRecordServiceBase, ModelConfigBase, ModelType, SubModelType
|
||||
from .model_loader_service import ModelInfo, ModelLoadServiceBase
|
||||
from .model_record_service import ModelConfigBase, ModelRecordServiceBase, ModelType, SubModelType
|
||||
|
||||
|
||||
class ModelConvertBase(ABC):
|
||||
@ -25,19 +26,19 @@ class ModelConvertBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
cls,
|
||||
loader: ModelLoadServiceBase,
|
||||
installer: ModelInstallServiceBase,
|
||||
store: ModelRecordServiceBase,
|
||||
cls,
|
||||
loader: ModelLoadServiceBase,
|
||||
installer: ModelInstallServiceBase,
|
||||
store: ModelRecordServiceBase,
|
||||
):
|
||||
"""Initialize ModelConvert with loader, installer and configuration store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
key: str,
|
||||
dest_directory: Optional[Path] = None,
|
||||
self,
|
||||
key: str,
|
||||
dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder.
|
||||
@ -75,14 +76,15 @@ class ModelConvertBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelConvert(ModelConvertBase):
|
||||
"""Implementation of ModelConvertBase."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loader: ModelLoadServiceBase,
|
||||
installer: ModelInstallServiceBase,
|
||||
store: ModelRecordServiceBase,
|
||||
self,
|
||||
loader: ModelLoadServiceBase,
|
||||
installer: ModelInstallServiceBase,
|
||||
store: ModelRecordServiceBase,
|
||||
):
|
||||
"""Initialize ModelConvert with loader, installer and configuration store."""
|
||||
self.loader = loader
|
||||
@ -90,9 +92,9 @@ class ModelConvert(ModelConvertBase):
|
||||
self.store = store
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
key: str,
|
||||
dest_directory: Optional[Path] = None,
|
||||
self,
|
||||
key: str,
|
||||
dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder.
|
||||
|
@ -5,14 +5,13 @@ import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union, Literal
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import Field, parse_obj_as
|
||||
from pydantic import Field
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger, Logger
|
||||
from invokeai.backend import get_precision
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
@ -34,13 +33,17 @@ from invokeai.backend.model_manager.models import InvalidModelException
|
||||
from invokeai.backend.model_manager.probe import ModelProbe, ModelProbeInfo
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.model_manager.storage import DuplicateModelException, ModelConfigStore
|
||||
from .events import EventServiceBase
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger, Logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .events import EventServiceBase
|
||||
|
||||
from .download_manager import (
|
||||
DownloadQueueServiceBase,
|
||||
DownloadQueueService,
|
||||
DownloadEventHandler,
|
||||
DownloadJobBase,
|
||||
DownloadJobPath,
|
||||
DownloadEventHandler,
|
||||
DownloadQueueService,
|
||||
DownloadQueueServiceBase,
|
||||
ModelSourceMetadata,
|
||||
)
|
||||
|
||||
@ -81,7 +84,7 @@ class ModelInstallServiceBase(ABC):
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
queue: Optional[DownloadQueueServiceBase] = None,
|
||||
store: Optional[ModelRecordServiceBase] = None,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
):
|
||||
"""
|
||||
Create ModelInstallService object.
|
||||
@ -227,6 +230,7 @@ class ModelInstallServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelInstallService(ModelInstallServiceBase):
|
||||
"""Model installer class handles installation from a local path."""
|
||||
|
||||
@ -239,7 +243,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
_tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads
|
||||
_cached_model_paths: Set[Path] = Field(default=set) # used to speed up directory scanning
|
||||
_precision: Literal["float16", "float32"] = Field(description="Floating point precision, string form")
|
||||
_event_bus: Optional[EventServiceBase] = Field(description="an event bus to send install events to", default=None)
|
||||
_event_bus: Optional["EventServiceBase"] = Field(description="an event bus to send install events to", default=None)
|
||||
|
||||
_legacy_configs: Dict[BaseModelType, Dict[ModelVariantType, Union[str, dict]]] = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
@ -269,7 +273,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
queue: Optional[DownloadQueueServiceBase] = None,
|
||||
store: Optional[ModelRecordServiceBase] = None,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
): # noqa D107 - use base class docstrings
|
||||
self._app_config = config or InvokeAIAppConfig.get_config()
|
||||
@ -281,10 +285,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
if self._event_bus:
|
||||
self._handlers.append(self._event_bus.emit_model_event)
|
||||
|
||||
self._download_queue = queue or DownloadQueueService(
|
||||
event_bus=event_bus,
|
||||
config=self._app_config
|
||||
)
|
||||
self._download_queue = queue or DownloadQueueService(event_bus=event_bus, config=self._app_config)
|
||||
self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
|
||||
self._installed = set()
|
||||
self._tmpdir = None
|
||||
@ -318,7 +319,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
probe_override: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[ModelSourceMetadata] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> DownloadJobBase: # noqa D102
|
||||
) -> ModelInstallJob: # noqa D102
|
||||
queue = self._download_queue
|
||||
variant = variant or ("fp16" if self._precision == "float16" else None)
|
||||
|
||||
|
@ -3,15 +3,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.backend.model_manager import ModelConfigStore, SubModelType
|
||||
from invokeai.backend.model_manager.cache import CacheStats
|
||||
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad, ModelConfigBase
|
||||
from invokeai.backend.model_manager.loader import ModelConfigBase, ModelInfo, ModelLoad
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
from .model_record_service import ModelRecordServiceBase
|
||||
@ -58,8 +58,6 @@ class ModelLoadServiceBase(ABC):
|
||||
"""Reset model cache statistics for graph with graph_id."""
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
# implementation
|
||||
class ModelLoadService(ModelLoadServiceBase):
|
||||
@ -140,4 +138,3 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
)
|
||||
|
||||
|
@ -8,8 +8,7 @@ from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from invokeai.backend.model_manager import ModelConfigBase, BaseModelType, ModelType, SubModelType
|
||||
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelConfigBase, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.storage import (
|
||||
ModelConfigStore,
|
||||
ModelConfigStoreSQL,
|
||||
|
@ -12,9 +12,9 @@ from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_install_service import ModelInstall, ModelInstallJob, ModelSourceMetadata
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.download.queue import DownloadJobRemoteSource
|
||||
from invokeai.app.services.model_install_service import ModelInstall, ModelInstallJob, ModelSourceMetadata
|
||||
|
||||
# name of the starter models file
|
||||
INITIAL_MODELS = "INITIAL_MODELS.yaml"
|
||||
|
@ -8,4 +8,4 @@ from .base import ( # noqa F401
|
||||
UnknownJobIDException,
|
||||
)
|
||||
from .model_queue import ModelDownloadQueue, ModelSourceMetadata # noqa F401
|
||||
from .queue import DownloadJobPath, DownloadJobURL, DownloadQueue, DownloadJobRemoteSource # noqa F401
|
||||
from .queue import DownloadJobPath, DownloadJobRemoteSource, DownloadJobURL, DownloadQueue # noqa F401
|
||||
|
@ -278,11 +278,11 @@ class ModelDownloadQueue(DownloadQueue):
|
||||
# including checkpoint files, different EMA versions, etc.
|
||||
# This filters out just the file types needed for the model
|
||||
for x in sibs:
|
||||
if x.rfilename.endswith(('.json', '.txt')):
|
||||
if x.rfilename.endswith((".json", ".txt")):
|
||||
paths.append(x.rfilename)
|
||||
elif x.rfilename.endswith(('learned_embeds.bin', 'ip_adapter.bin')):
|
||||
elif x.rfilename.endswith(("learned_embeds.bin", "ip_adapter.bin")):
|
||||
paths.append(x.rfilename)
|
||||
elif re.search(r'model(\.[^.]+)?\.(safetensors|bin)$', x.rfilename):
|
||||
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin)$", x.rfilename):
|
||||
paths.append(x.rfilename)
|
||||
|
||||
sizes = {x.rfilename: x.size for x in sibs}
|
||||
|
@ -5,8 +5,8 @@ import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
from shutil import move, rmtree
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -6,8 +6,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.app.services.model_install_service import ModelInstallService
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType, UnknownModelException
|
||||
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad
|
||||
|
||||
|
@ -49,6 +49,7 @@ def mock_services() -> InvocationServices:
|
||||
conn=db_conn, table_name="graph_executions", lock=lock
|
||||
)
|
||||
return InvocationServices(
|
||||
download_queue=None, # type: ignore
|
||||
model_loader=None, # type: ignore
|
||||
model_installer=None, # type: ignore
|
||||
model_record_store=None, # type: ignore
|
||||
|
@ -1,7 +1,8 @@
|
||||
# conftest.py is a special pytest file. Fixtures defined in this file will be accessible to all tests in this directory
|
||||
# without needing to explicitly import them. (https://docs.pytest.org/en/6.2.x/fixture.html)
|
||||
|
||||
from invokeai.app.services.model_install_service import ModelInstallService # noqa: F401
|
||||
|
||||
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
|
||||
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
|
||||
from invokeai.backend.util.test_utils import torch_device # noqa: F401
|
||||
from invokeai.app.services.model_install_service import ModelInstallService # noqa: F401
|
||||
|
Reference in New Issue
Block a user