refactor installer class hierarchy

Lincoln Stein
2023-10-09 13:56:28 -04:00
parent 33d4756c48
commit 4149d357bf
18 changed files with 901 additions and 1180 deletions

View File

@ -26,6 +26,7 @@ from ..services.invocation_services import InvocationServices
from ..services.invocation_stats import InvocationStatsService
from ..services.invoker import Invoker
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..services.download_manager import DownloadQueueService
from ..services.model_install_service import ModelInstallService
from ..services.model_loader_service import ModelLoadService
from ..services.model_record_service import ModelRecordServiceBase
@ -129,9 +130,14 @@ class ApiDependencies:
)
)
download_queue = DownloadQueueService(event_bus=events, config=config)
model_record_store = ModelRecordServiceBase.get_impl(config, conn=db_conn, lock=lock)
model_loader = ModelLoadService(config, model_record_store)
model_installer = ModelInstallService(config, model_record_store, events)
model_installer = ModelInstallService(config,
queue=download_queue,
store=model_record_store,
event_bus=events
)
services = InvocationServices(
events=events,
@ -146,6 +152,7 @@ class ApiDependencies:
configuration=config,
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
download_queue=download_queue,
model_record_store=model_record_store,
model_loader=model_loader,
model_installer=model_installer,

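Taken together, the wiring above now constructs the shared download queue first and passes it into the installer, rather than letting the installer build its own. A minimal sketch of the new startup order (a hedged reconstruction from the hunks above; `events`, `config`, `db_conn` and `lock` are assumed to be supplied by the surrounding ApiDependencies code):

    download_queue = DownloadQueueService(event_bus=events, config=config)
    model_record_store = ModelRecordServiceBase.get_impl(config, conn=db_conn, lock=lock)
    model_loader = ModelLoadService(config, model_record_store)
    # The installer now shares the queue instead of owning one.
    model_installer = ModelInstallService(
        config,
        queue=download_queue,
        store=model_record_store,
        event_bus=events,
    )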
View File

@ -19,9 +19,11 @@ from invokeai.backend.model_manager import (
ModelConfigBase,
SchedulerPredictionType,
UnknownModelException,
ModelSearch
)
from invokeai.backend.model_manager.download import DownloadJobStatus, UnknownJobIDException
from invokeai.backend.model_manager.merge import MergeInterpolationMethod
from invokeai.app.services.download_manager import DownloadJobStatus, UnknownJobIDException, DownloadJobRemoteSource
from invokeai.app.services.model_convert import MergeInterpolationMethod, ModelConvert
from invokeai.app.services.model_install_service import ModelInstallJob
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -39,7 +41,7 @@ class ModelsList(BaseModel):
models: List[InvokeAIModelConfig]
class ModelImportStatus(BaseModel):
class ModelDownloadStatus(BaseModel):
"""Return information about a background installation job."""
job_id: int
@ -56,7 +58,6 @@ class JobControlOperation(str, Enum):
CANCEL = "Cancel"
CHANGE_PRIORITY = "Change Priority"
@models_router.get(
"/",
operation_id="list_models",
@ -135,7 +136,7 @@ async def update_model(
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ModelImportStatus,
response_model=ModelDownloadStatus,
)
async def import_model(
location: str = Body(description="A model path, repo_id or URL to import"),
@ -147,7 +148,7 @@ async def import_model(
description="Which import jobs run first. Lower values run before higher ones.",
default=10,
),
) -> ModelImportStatus:
) -> ModelDownloadStatus:
"""
Add a model using its local path, repo_id, or remote URL.
@ -172,10 +173,10 @@ async def import_model(
installer = ApiDependencies.invoker.services.model_installer
result = installer.install_model(
location,
model_attributes={"prediction_type": SchedulerPredictionType(prediction_type)},
probe_override={"prediction_type": SchedulerPredictionType(prediction_type) if prediction_type else None},
priority=priority,
)
return ModelImportStatus(
return ModelDownloadStatus(
job_id=result.id,
source=result.source,
priority=result.priority,
@ -288,8 +289,12 @@ async def convert_model(
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
try:
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
installer = ApiDependencies.invoker.services.model_installer
model_config = installer.convert_model(key, dest_directory=dest)
converter = ModelConvert(
loader=ApiDependencies.invoker.services.model_loader,
installer=ApiDependencies.invoker.services.model_installer,
store=ApiDependencies.invoker.services.model_record_store
)
model_config = converter.convert_model(key, dest_directory=dest)
response = parse_obj_as(InvokeAIModelConfig, model_config.dict())
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=f"Model '{key}' not found: {str(e)}")
@ -316,8 +321,7 @@ async def search_for_models(
raise HTTPException(
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
)
return ApiDependencies.invoker.services.model_installer.search_for_models(search_path)
return ModelSearch().search(search_path)
@models_router.get(
"/ckpt_confs",
@ -330,7 +334,10 @@ async def search_for_models(
)
async def list_ckpt_configs() -> List[pathlib.Path]:
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
return ApiDependencies.invoker.services.model_installer.list_checkpoint_configs()
config = ApiDependencies.invoker.services.configuration
conf_path = config.legacy_conf_path
root_path = config.root_path
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
@models_router.post(
@ -349,7 +356,8 @@ async def sync_to_config() -> bool:
Call after making changes to models.yaml, autoimport directories
or models directory.
"""
ApiDependencies.invoker.services.model_installer.sync_to_config()
installer = ApiDependencies.invoker.services.model_installer
installer.sync_to_config()
return True
@ -383,7 +391,12 @@ async def merge_models(
try:
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
result: ModelConfigBase = ApiDependencies.invoker.services.model_installer.merge_models(
converter = ModelConvert(
loader=ApiDependencies.invoker.services.model_loader,
installer=ApiDependencies.invoker.services.model_installer,
store=ApiDependencies.invoker.services.model_record_store
)
result: ModelConfigBase = converter.merge_models(
model_keys=keys,
merged_model_name=merged_model_name,
alpha=alpha,
@ -409,14 +422,14 @@ async def merge_models(
400: {"description": "Bad request"},
},
status_code=200,
response_model=List[ModelImportStatus],
response_model=List[ModelDownloadStatus],
)
async def list_install_jobs() -> List[ModelImportStatus]:
async def list_install_jobs() -> List[ModelDownloadStatus]:
"""List active and pending model installation jobs."""
job_mgr = ApiDependencies.invoker.services.model_installer
jobs = job_mgr.list_install_jobs()
job_mgr = ApiDependencies.invoker.services.download_queue
jobs = job_mgr.list_jobs()
return [
ModelImportStatus(
ModelDownloadStatus(
job_id=x.id,
source=x.source,
priority=x.priority,
@ -424,32 +437,32 @@ async def list_install_jobs() -> List[ModelImportStatus]:
total_bytes=x.total_bytes,
status=x.status,
)
for x in jobs
for x in jobs if isinstance(x, ModelInstallJob)
]
@models_router.patch(
"/jobs/control/{operation}/{job_id}",
operation_id="control_install_jobs",
operation_id="control_download_jobs",
responses={
200: {"description": "The control job was updated successfully"},
400: {"description": "Bad request"},
404: {"description": "The job could not be found"},
},
status_code=200,
response_model=ModelImportStatus,
response_model=ModelDownloadStatus,
)
async def control_install_jobs(
job_id: int = Path(description="Install job_id for start, pause and cancel operations"),
async def control_download_jobs(
job_id: int = Path(description="Download/install job_id for start, pause and cancel operations"),
operation: JobControlOperation = Path(description="The operation to perform on the job."),
priority_delta: Optional[int] = Body(
description="Change in job priority for priority operations only. Negative numbers increase priority.",
default=None,
),
) -> ModelImportStatus:
) -> ModelDownloadStatus:
"""Start, pause, cancel, or change the run priority of a running model install job."""
logger = ApiDependencies.invoker.services.logger
job_mgr = ApiDependencies.invoker.services.model_installer
job_mgr = ApiDependencies.invoker.services.download_queue
try:
job = job_mgr.id_to_job(job_id)
@ -467,14 +480,19 @@ async def control_install_jobs(
else:
raise ValueError(f"Unknown operation {operation.value}")
bytes = 0
total_bytes = 0
if isinstance(job, DownloadJobRemoteSource):
bytes = job.bytes
total_bytes = job.total_bytes
return ModelImportStatus(
return ModelDownloadStatus(
job_id=job_id,
source=job.source,
priority=job.priority,
status=job.status,
bytes=job.bytes,
total_bytes=job.total_bytes,
bytes=bytes,
total_bytes=total_bytes,
)
except UnknownJobIDException as e:
raise HTTPException(status_code=404, detail=str(e))
@ -485,17 +503,17 @@ async def control_install_jobs(
@models_router.patch(
"/jobs/cancel_all",
operation_id="cancel_all_jobs",
operation_id="cancel_all_download_jobs",
responses={
204: {"description": "All jobs cancelled successfully"},
400: {"description": "Bad request"},
},
)
async def cancel_install_jobs():
async def cancel_all_download_jobs():
"""Cancel all model installation jobs."""
logger = ApiDependencies.invoker.services.logger
job_mgr = ApiDependencies.invoker.services.model_installer
logger.info("Cancelling all model installation jobs.")
job_mgr = ApiDependencies.invoker.services.download_queue
logger.info("Cancelling all download jobs.")
job_mgr.cancel_all_jobs()
return Response(status_code=204)
@ -510,7 +528,6 @@ async def cancel_install_jobs():
)
async def prune_jobs():
"""Prune all completed and errored jobs."""
logger = ApiDependencies.invoker.services.logger
mgr = ApiDependencies.invoker.services.model_installer
mgr = ApiDependencies.invoker.services.download_queue
mgr.prune_jobs()
return Response(status_code=204)

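Because the jobs endpoints now read from the shared download queue, which can also carry plain (non-install) download jobs, the handlers filter by job type before building their responses. A hedged sketch of the same pattern for service-layer callers:

    # The queue may hold both plain downloads and model installs;
    # keep only the installs, as list_install_jobs() does above.
    jobs = ApiDependencies.invoker.services.download_queue.list_jobs()
    install_jobs = [x for x in jobs if isinstance(x, ModelInstallJob)]
    for job in install_jobs:
        print(job.id, job.source, job.status)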
View File

@ -145,7 +145,7 @@ class BaseCommand(ABC, BaseModel):
"""A CLI command"""
# All commands must include a type name like this:
# type: Literal['your_command_name'] = 'your_command_name'
# Literal['your_command_name'] = 'your_command_name'
@classmethod
def get_all_subclasses(cls):

View File

@ -9,7 +9,18 @@ from typing import List, Optional, Union
from pydantic.networks import AnyHttpUrl
from invokeai.backend.model_manager.download import DownloadEventHandler, DownloadJobBase, DownloadQueue
from invokeai.backend.model_manager.download import ( # noqa F401
DownloadEventHandler,
DownloadJobBase,
DownloadJobPath,
DownloadJobStatus,
DownloadQueueBase,
ModelDownloadQueue,
ModelSourceMetadata,
UnknownJobIDException,
)
from invokeai.backend.model_manager.download import DownloadJobRemoteSource # noqa F401
from .events import EventServiceBase
@ -40,6 +51,22 @@ class DownloadQueueServiceBase(ABC):
"""
pass
@abstractmethod
def submit_download_job(
self,
job: DownloadJobBase,
start: bool = True,
):
"""
Submit a download job.
:param job: A DownloadJobBase
:param start: Immediately start job [True]
After execution, `job.id` will be set to a non-negative value.
"""
pass
@abstractmethod
def list_jobs(self) -> List[DownloadJobBase]:
"""
@ -76,6 +103,11 @@ class DownloadQueueServiceBase(ABC):
"""Cancel all active and enquedjobs."""
pass
@abstractmethod
def prune_jobs(self):
"""Prune completed and errored queue items from the job list."""
pass
@abstractmethod
def start_job(self, job: DownloadJobBase):
"""Start the job putting it into ENQUEUED state."""
@ -115,7 +147,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
"""Multithreaded queue for downloading models via URL or repo_id."""
_event_bus: EventServiceBase
_queue: DownloadQueue
_queue: DownloadQueueBase
def __init__(self, event_bus: EventServiceBase, **kwargs):
"""
@ -126,7 +158,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
e.g. `max_parallel_dl`.
"""
self._event_bus = event_bus
self._queue = DownloadQueue(**kwargs)
self._queue = ModelDownloadQueue(**kwargs)
def create_download_job(
self,
@ -149,6 +181,13 @@ class DownloadQueueService(DownloadQueueServiceBase):
event_handlers=event_handlers,
)
def submit_download_job(
self,
job: DownloadJobBase,
start: bool = True,
):
return self._queue.submit_download_job(job, start)
def list_jobs(self) -> List[DownloadJobBase]: # noqa D102
return self._queue.list_jobs()
@ -164,6 +203,9 @@ class DownloadQueueService(DownloadQueueServiceBase):
def cancel_all_jobs(self): # noqa D102
return self._queue.cancel_all_jobs()
def prune_jobs(self): # noqa D102
return self._queue.prune_jobs()
def start_job(self, job: DownloadJobBase): # noqa D102
return self._queue.start_job(job)

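The new `submit_download_job()` entry point lets a caller construct the job object itself and hand it to the queue, which is how the installer submits its pre-configured install jobs. A minimal sketch, assuming `DownloadJobPath` accepts `source` and `destination` keyword arguments as in the installer code elsewhere in this commit:

    from pathlib import Path

    queue_service = DownloadQueueService(event_bus=events)
    job = DownloadJobPath(
        source="/tmp/some-model.safetensors",  # hypothetical path
        destination=Path("/tmp/downloads"),
    )
    queue_service.submit_download_job(job, start=True)
    # After submission, job.id holds a non-negative queue-assigned ID.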
View File

@ -18,6 +18,7 @@ if TYPE_CHECKING:
from invokeai.app.services.invoker import InvocationProcessorABC
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.download_manager import DownloadQueueServiceBase
from invokeai.app.services.model_install_service import ModelInstallServiceBase
from invokeai.app.services.model_loader_service import ModelLoadServiceBase
from invokeai.app.services.model_record_service import ModelRecordServiceBase
@ -37,6 +38,7 @@ class InvocationServices:
graph_library: "ItemStorageABC[LibraryGraph]"
images: "ImageServiceABC"
latents: "LatentsStorageBase"
download_queue: "DownloadQueueServiceBase"
model_record_store: "ModelRecordServiceBase"
model_loader: "ModelLoadServiceBase"
model_installer: "ModelInstallServiceBase"
@ -59,6 +61,7 @@ class InvocationServices:
images: "ImageServiceABC",
latents: "LatentsStorageBase",
logger: "Logger",
download_queue: "DownloadQueueServiceBase",
model_record_store: "ModelRecordServiceBase",
model_loader: "ModelLoadServiceBase",
model_installer: "ModelInstallServiceBase",
@ -78,6 +81,7 @@ class InvocationServices:
self.images = images
self.latents = latents
self.logger = logger
self.download_queue = download_queue
self.model_record_store = model_record_store
self.model_loader = model_loader
self.model_installer = model_installer

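With `download_queue` threaded through `InvocationServices`, any consumer holding the services object can reach the shared queue. A one-line hedged sketch (`services` is an InvocationServices instance; the usage is hypothetical):

    active_jobs = services.download_queue.list_jobs()  # e.g. to surface download progress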
View File

@ -0,0 +1,190 @@
# Copyright 2023 Lincoln Stein and the InvokeAI Team
"""
Convert and merge models.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, List
from pydantic import Field
from pathlib import Path
from shutil import move, rmtree
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from .config import InvokeAIAppConfig
from .model_install_service import ModelInstallServiceBase
from .model_loader_service import ModelLoadServiceBase, ModelInfo
from .model_record_service import ModelRecordServiceBase, ModelConfigBase, ModelType, SubModelType
class ModelConvertBase(ABC):
"""Convert and merge models."""
@abstractmethod
def __init__(
cls,
loader: ModelLoadServiceBase,
installer: ModelInstallServiceBase,
store: ModelRecordServiceBase,
):
"""Initialize ModelConvert with loader, installer and configuration store."""
pass
@abstractmethod
def convert_model(
self,
key: str,
dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder.
It will delete the cached version as well as the
original checkpoint file if it is in the models directory.
:param key: Unique key of model.
:param dest_directory: Optional place to put the converted file. If not specified,
will be stored in the `models_dir`.
This will raise a ValueError unless the model is a checkpoint.
This will raise an UnknownModelException if key is unknown.
"""
pass
def merge_models(
self,
model_keys: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model keys to merge"
),
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Merge two to three diffusers pipeline models and save as a new model.
:param model_keys: List of 2-3 model unique keys to merge
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
pass
class ModelConvert(ModelConvertBase):
"""Implementation of ModelConvertBase."""
def __init__(
self,
loader: ModelLoadServiceBase,
installer: ModelInstallServiceBase,
store: ModelRecordServiceBase,
):
"""Initialize ModelConvert with loader, installer and configuration store."""
self.loader = loader
self.installer = installer
self.store = store
def convert_model(
self,
key: str,
dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder.
It will delete the cached version as well as the
original checkpoint file if it is in the models directory.
:param key: Unique key of model.
:param dest_directory: Optional place to put the converted file. If not specified,
will be stored in the `models_dir`.
This will raise a ValueError unless the model is a checkpoint.
This will raise an UnknownModelException if key is unknown.
"""
new_diffusers_path = None
config = InvokeAIAppConfig.get_config()
try:
info: ModelConfigBase = self.store.get_model(key)
if info.model_format != "checkpoint":
raise ValueError(f"not a checkpoint format model: {info.name}")
# We are taking advantage of a side effect of get_model() that converts checkpoints
# into cached diffusers directories stored at `path`. It doesn't matter
# what submodel type we request here, so we get the smallest.
submodel = {"submodel_type": SubModelType.Scheduler} if info.model_type == ModelType.Main else {}
converted_model: ModelInfo = self.loader.get_model(key, **submodel)
checkpoint_path = config.models_path / info.path
old_diffusers_path = config.models_path / converted_model.location
# new values to write in
update = info.dict()
update.pop("config")
update["model_format"] = "diffusers"
update["path"] = str(converted_model.location)
if dest_directory:
new_diffusers_path = Path(dest_directory) / info.name
if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
move(old_diffusers_path, new_diffusers_path)
update["path"] = new_diffusers_path.as_posix()
self.store.update_model(key, update)
result = self.installer.sync_model_path(key, ignore_hash_change=True)
except Exception as excp:
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
if new_diffusers_path:
rmtree(new_diffusers_path)
raise excp
if checkpoint_path.exists() and checkpoint_path.is_relative_to(config.models_path):
checkpoint_path.unlink()
return result
def merge_models(
self,
model_keys: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model keys to merge"
),
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Merge two to three diffusers pipeline models and save as a new model.
:param model_keys: List of 2-3 model unique keys to merge
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
merger = ModelMerger(self.store)
try:
if not merged_model_name:
merged_model_name = "+".join([self.store.get_model(x).name for x in model_keys])
raise Exception("not implemented")
result = merger.merge_diffusion_models_and_save(
model_keys=model_keys,
merged_model_name=merged_model_name,
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=merge_dest_directory,
)
except AssertionError as e:
raise ValueError(e)
return result

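Conversion and merging now live in this standalone `ModelConvert` object, assembled from the loader, installer and record store, exactly as the router code above does. A hedged usage sketch (the model keys are hypothetical placeholders):

    converter = ModelConvert(loader=loader, installer=installer, store=store)
    # Convert a checkpoint into diffusers format in the default models dir.
    new_config = converter.convert_model("0ab1...")  # hypothetical key
    # Merge two models with equal weighting.
    merged = converter.merge_models(
        model_keys=["0ab1...", "9cd2..."],  # hypothetical keys
        merged_model_name="my-merged-model",
        alpha=0.5,
    )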
View File

@ -1,41 +1,143 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
from __future__ import annotations
import shutil
from abc import abstractmethod
import re
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Union
from shutil import move, rmtree
from typing import Any, Callable, Dict, List, Optional, Set, Union, Literal
from pydantic import Field, parse_obj_as
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.backend.util import Chdir, InvokeAILogger, Logger
from invokeai.backend import get_precision
from invokeai.backend.model_manager import ModelConfigBase, ModelSearch
from invokeai.backend.model_manager.download import DownloadJobBase
from invokeai.backend.model_manager.install import ModelInstall, ModelInstallBase, ModelInstallJob
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.util.logging import InvokeAILogger
from .config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import (
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_manager.download.model_queue import (
HTTP_RE,
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE,
DownloadJobRepoID,
DownloadJobWithMetadata,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.models import InvalidModelException
from invokeai.backend.model_manager.probe import ModelProbe, ModelProbeInfo
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.storage import DuplicateModelException, ModelConfigStore
from .events import EventServiceBase
from .model_record_service import ModelRecordServiceBase
from .download_manager import (
DownloadQueueServiceBase,
DownloadQueueService,
DownloadJobBase,
DownloadJobPath,
DownloadEventHandler,
ModelSourceMetadata,
)
class ModelInstallServiceBase(ModelInstallBase): # This is an ABC
"""Responsible for downloading, installing and deleting models."""
class ModelInstallJob(DownloadJobBase):
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
model_key: Optional[str] = Field(
description="After model installation, this field will hold its primary key", default=None
)
probe_override: Optional[Dict[str, Any]] = Field(
description="Keys in this dict will override like-named attributes in the automatic probe info",
default=None,
)
class ModelInstallURLJob(DownloadJobWithMetadata, ModelInstallJob):
"""Job for installing URLs."""
class ModelInstallRepoIDJob(DownloadJobRepoID, ModelInstallJob):
"""Job for installing repo ids."""
class ModelInstallPathJob(DownloadJobPath, ModelInstallJob):
"""Job for installing local paths."""
ModelInstallEventHandler = Callable[["ModelInstallJob"], None]
class ModelInstallServiceBase(ABC):
"""Abstract base class for InvokeAI model installation."""
@abstractmethod
def __init__(
self, config: InvokeAIAppConfig, store: ModelRecordServiceBase, event_bus: Optional[EventServiceBase] = None
self,
config: Optional[InvokeAIAppConfig] = None,
queue: Optional[DownloadQueueServiceBase] = None,
store: Optional[ModelRecordServiceBase] = None,
event_bus: Optional[EventServiceBase] = None,
):
"""
Initialize a ModelInstallService instance.
Create ModelInstallService object.
:param config: InvokeAIAppConfig object
:param store: A ModelRecordServiceBase object to install into
:param event_bus: Optional EventServiceBase object. If provided,
installation and download events will be sent to the event bus as "model_event".
:param store: Optional ModelRecordServiceBase. If None passed,
defaults to the implementation selected by the app config.
:param config: Optional InvokeAIAppConfig. If None passed,
uses the system-wide default app config.
:param queue: Optional DownloadQueueServiceBase object. If None passed,
a default queue object will be created.
:param event_handlers: List of event handlers to pass to the queue object.
"""
pass
@property
@abstractmethod
def queue(self) -> DownloadQueueServiceBase:
"""Return the download queue used by the installer."""
pass
@property
@abstractmethod
def store(self) -> ModelRecordServiceBase:
"""Return the storage backend used by the installer."""
pass
@property
@abstractmethod
def config(self) -> InvokeAIAppConfig:
"""Return the app_config used by the installer."""
pass
@abstractmethod
def register_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]]) -> str:
"""
Probe and register the model at model_path.
:param model_path: Filesystem Path to the model.
:param overrides: Dict of attributes that will override probed values.
:returns id: The string ID of the registered model.
"""
pass
@abstractmethod
def install_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> str:
"""
Probe, register and install the model in the models directory.
This involves moving the model from its current location into
the models directory handled by InvokeAI.
:param model_path: Filesystem Path to the model.
:param overrides: Dictionary of model probe info fields that, if present, override probed values.
:returns id: The string ID of the installed model.
"""
pass
@ -43,327 +145,500 @@ class ModelInstallServiceBase(ModelInstallBase): # This is an ABC
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
priority: int = 10,
model_attributes: Optional[Dict[str, Any]] = None,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None,
metadata: Optional[ModelSourceMetadata] = None,
access_token: Optional[str] = None,
) -> ModelInstallJob:
"""Import a path, repo_id or URL. Returns an ModelInstallJob.
"""
Download and install the indicated model.
:param model_attributes: Additional attributes to supplement/override
the model information gained from automated probing.
:param priority: Queue priority. Lower values have higher priority.
This will download the model located at `source`,
probe it, and install it into the models directory.
This call is executed asynchronously in a separate
thread, and the returned object is a
invokeai.backend.model_manager.download.DownloadJobBase
object which can be interrogated to get the status of
the download and install process. Call our `wait_for_installs()`
method to wait for all downloads and installations to complete.
Typical usage:
job = model_manager.install(
'stabilityai/stable-diffusion-2-1',
model_attributes={'prediction_type": 'v-prediction'}
)
:param source: Either a URL or a HuggingFace repo_id.
:param inplace: If True, local paths will not be moved into
the models directory, but registered in place (the default).
:param variant: For HuggingFace models, this optional parameter
specifies which variant to download (e.g. 'fp16')
:param subfolder: When downloading HF repo_ids this can be used to
specify a subfolder of the HF repository to download from.
:param probe_override: Optional dict. Any fields in this dict
will override corresponding probe fields. Use it to override
`base_type`, `model_type`, `format`, `prediction_type` and `image_size`.
:param metadata: Use this to override the fields 'description`,
`author`, `tags`, `source` and `license`.
The result is a ModelInstallJob object, which provides
information on the asynchronous model download and install
process. A series of "install_model_event" events will be emitted
until the install is completed, cancelled or errors out.
:returns ModelInstallJob object.
The `inplace` flag does not affect the behavior of downloaded
models, which are always moved into the `models` directory.
Variants recognized by HuggingFace currently are:
1. onnx
2. openvino
3. fp16
4. None (usually returns fp32 model)
"""
pass
@abstractmethod
def list_install_jobs(self) -> List[ModelInstallJob]:
"""Return a series of active or enqueued ModelInstallJobs."""
pass
@abstractmethod
def id_to_job(self, id: int) -> ModelInstallJob:
"""Return the ModelInstallJob instance corresponding to the given job ID."""
pass
@abstractmethod
def start_job(self, job_id: int):
"""Start the given install job if it is paused or idle."""
pass
@abstractmethod
def pause_job(self, job_id: int):
"""Pause the given install job if it is paused or idle."""
pass
@abstractmethod
def cancel_job(self, job_id: int):
"""Cancel the given install job."""
pass
@abstractmethod
def cancel_all_jobs(self):
"""Cancel all installation jobs."""
pass
@abstractmethod
def prune_jobs(self):
"""Remove completed or errored install jobs."""
pass
@abstractmethod
def change_job_priority(self, job_id: int, delta: int):
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""
Change an install job's priority.
Wait for all pending installs to complete.
:param job_id: Job to change
:param delta: Value to increment or decrement priority.
This will block until all pending downloads have
completed, been cancelled, or errored out. It will
block indefinitely if one or more jobs are in the
paused state.
Lower values are higher priority. The default starting value is 10.
Thus to make this a really high priority job:
manager.change_job_priority(job_id, -10).
It will return a dict that maps the source model
path, URL or repo_id to the ID of the installed model.
"""
pass
@abstractmethod
def merge_models(
self,
model_keys: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model keys to merge"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
"""
Merge two to three diffusers pipeline models and save as a new model.
Recursively scan directory for new models and register or install them.
:param model_keys: List of 2-3 model unique keys to merge
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
:param scan_dir: Path to the directory to scan.
:param install: Install if True, otherwise register in place.
:returns list of IDs: Returns list of IDs of models registered/installed
"""
pass
@abstractmethod
def list_checkpoint_configs(self) -> List[Path]:
"""List the checkpoint config paths from ROOT/configs/stable-diffusion."""
def hash(self, model_path: Union[Path, str]) -> str:
"""
Compute and return the fast hash of the model.
:param model_path: Path to the model on disk.
:return str: FastHash of the model for use as an ID.
"""
pass
@abstractmethod
def search_for_models(self, directory: Path) -> Set[Path]:
"""Return list of all models found in the designated directory."""
pass
# implementation
class ModelInstallService(ModelInstall, ModelInstallServiceBase):
"""Responsible for managing models on disk and in memory."""
class ModelInstallService(ModelInstallServiceBase):
"""Model installer class handles installation from a local path."""
_app_config: InvokeAIAppConfig
_logger: Logger
_store: ModelConfigStore
_download_queue: DownloadQueueServiceBase
_async_installs: Dict[Union[str, Path, AnyHttpUrl], Optional[str]]
_installed: Set[str] = Field(default=set)
_tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads
_cached_model_paths: Set[Path] = Field(default=set) # used to speed up directory scanning
_precision: Literal["float16", "float32"] = Field(description="Floating point precision, string form")
_event_bus: Optional[EventServiceBase] = Field(description="an event bus to send install events to", default=None)
_legacy_configs: Dict[BaseModelType, Dict[ModelVariantType, Union[str, dict]]] = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
},
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
},
}
def __init__(
self, config: InvokeAIAppConfig, store: ModelRecordServiceBase, event_bus: Optional[EventServiceBase] = None
):
"""
Initialize a ModelInstallService instance.
:param config: InvokeAIAppConfig object
:param store: Either a ModelRecordService object or a ModelConfigStore
:param event_bus: Optional EventServiceBase object. If provided,
installation and download events will be sent to the event bus as "model_event".
"""
self,
config: Optional[InvokeAIAppConfig] = None,
queue: Optional[DownloadQueueServiceBase] = None,
store: Optional[ModelRecordServiceBase] = None,
event_bus: Optional[EventServiceBase] = None,
event_handlers: List[DownloadEventHandler] = [],
): # noqa D107 - use base class docstrings
self._app_config = config or InvokeAIAppConfig.get_config()
self._store = store or ModelRecordServiceBase.get_impl(self._app_config)
self._logger = InvokeAILogger.get_logger(config=self._app_config)
self._event_bus = event_bus
kwargs: Dict[str, Any] = {}
if self._event_bus:
kwargs.update(event_handlers=[self._event_bus.emit_model_event])
self._precision = get_precision()
logger = InvokeAILogger.get_logger()
super().__init__(store=store, config=config, logger=logger, **kwargs)
self._handlers = event_handlers
if self._event_bus:
self._handlers.append(self._event_bus.emit_model_event)
self._download_queue = queue or DownloadQueueService(
event_bus=event_bus,
config=self._app_config
)
self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
self._installed = set()
self._tmpdir = None
def start(self, invoker: Any): # Because .processor is giving circular import errors, declaring invoker an 'Any'
"""Call automatically at process start."""
self.scan_models_directory() # synchronize new/deleted models found in models directory
self.sync_to_config()
@property
def queue(self) -> DownloadQueueServiceBase:
"""Return the queue."""
return self._download_queue
@property
def store(self) -> ModelConfigStore:
"""Return the storage backend used by the installer."""
return self._store
@property
def config(self) -> InvokeAIAppConfig:
"""Return the app_config used by the installer."""
return self._app_config
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
priority: int = 10,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None,
metadata: Optional[ModelSourceMetadata] = None,
access_token: Optional[str] = None,
) -> DownloadJobBase: # noqa D102
queue = self._download_queue
variant = variant or ("fp16" if self._precision == "float16" else None)
job = self._make_download_job(
source, variant=variant, access_token=access_token, subfolder=subfolder, priority=priority
)
handler = (
self._complete_registration_handler
if inplace and Path(source).exists()
else self._complete_installation_handler
)
if isinstance(job, ModelInstallJob):
job.probe_override = probe_override
if metadata and isinstance(job, DownloadJobWithMetadata):
job.metadata = metadata
job.add_event_handler(handler)
self._async_installs[source] = None
queue.submit_download_job(job, True)
return job
def register_path(
self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None
) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = self._probe_model(model_path, overrides)
return self._register(model_path, info)
def install_path(
self,
model_path: Union[Path, str],
overrides: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = self._probe_model(model_path, overrides)
dest_path = self._app_config.models_path / info.base_type.value / info.model_type.value / model_path.name
new_path = self._move_model(model_path, dest_path)
new_hash = self.hash(new_path)
assert new_hash == info.hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
return self._register(
new_path,
info,
)
def unregister(self, key: str): # noqa D102
self._store.del_model(key)
def delete(self, key: str): # noqa D102
model = self._store.get_model(key)
path = self._app_config.models_path / model.path
if path.is_dir():
rmtree(path)
else:
path.unlink()
self.unregister(key)
def conditionally_delete(self, key: str): # noqa D102
"""Unregister the model. Delete its files only if they are within our models directory."""
model = self._store.get_model(key)
models_dir = self._app_config.models_path
model_path = models_dir / model.path
if model_path.is_relative_to(models_dir):
self.delete(key)
else:
self.unregister(key)
def _register(self, model_path: Path, info: ModelProbeInfo) -> str:
key: str = FastModelHash.hash(model_path)
model_path = model_path.absolute()
if model_path.is_relative_to(self._app_config.models_path):
model_path = model_path.relative_to(self._app_config.models_path)
registration_data = dict(
path=model_path.as_posix(),
name=model_path.name if model_path.is_dir() else model_path.stem,
base_model=info.base_type,
model_type=info.model_type,
model_format=info.format,
hash=key,
)
# add 'main' specific fields
if info.model_type == ModelType.Main:
if info.variant_type:
registration_data.update(variant=info.variant_type)
if info.format == ModelFormat.Checkpoint:
try:
config_file = self._legacy_configs[info.base_type][info.variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
if prediction_type := info.prediction_type:
config_file = config_file[prediction_type]
else:
self._logger.warning(
f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model"
)
config_file = config_file[SchedulerPredictionType.VPrediction]
registration_data.update(
config=Path(self._app_config.legacy_conf_dir, str(config_file)).as_posix(),
)
except KeyError as exc:
raise InvalidModelException(
"Configuration file for this checkpoint could not be determined"
) from exc
self._store.add_model(key, registration_data)
return key
def _move_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
new_path.parent.mkdir(parents=True, exist_ok=True)
# if path already exists then we jigger the name to make it unique
counter: int = 1
while new_path.exists():
path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
if not path.exists():
new_path = path
counter += 1
return move(old_path, new_path)
def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
info: ModelProbeInfo = ModelProbe.probe(Path(model_path))
if overrides: # used to override probe fields
for key, value in overrides.items():
try:
setattr(info, key, value) # skip validation errors
except Exception:
pass
return info
def _complete_installation_handler(self, job: DownloadJobBase):
assert isinstance(job, ModelInstallJob)
if job.status == "completed":
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
model_id = self.install_path(job.destination, job.probe_override)
info = self._store.get_model(model_id)
info.source = str(job.source)
if isinstance(job, DownloadJobWithMetadata):
metadata: ModelSourceMetadata = job.metadata
info.description = metadata.description or f"Imported model {info.name}"
info.name = metadata.name or info.name
info.author = metadata.author
info.tags = metadata.tags
info.license = metadata.license
info.thumbnail_url = metadata.thumbnail_url
self._store.update_model(model_id, info)
self._async_installs[job.source] = model_id
job.model_key = model_id
elif job.status == "error":
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
elif job.status == "cancelled":
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
jobs = self._download_queue.list_jobs()
if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
self._tmpdir.cleanup()
self._tmpdir = None
def _complete_registration_handler(self, job: DownloadJobBase):
assert isinstance(job, ModelInstallJob)
if job.status == "completed":
self._logger.info(f"{job.source}: Installing in place.")
model_id = self.register_path(job.destination, job.probe_override)
info = self._store.get_model(model_id)
info.source = str(job.source)
info.description = f"Imported model {info.name}"
self._store.update_model(model_id, info)
self._async_installs[job.source] = model_id
job.model_key = model_id
elif job.status == "error":
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
elif job.status == "cancelled":
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
def sync_model_path(self, key: str, ignore_hash_change: bool = False) -> ModelConfigBase:
"""
Move model into the location indicated by its basetype, type and name.
Call this after updating a model's attributes in order to move
the model's path into the location indicated by its basetype, type and
name. Applies only to models whose paths are within the root `models_dir`
directory.
May raise an UnknownModelException.
"""
model = self._store.get_model(key)
old_path = Path(model.path)
models_dir = self._app_config.models_path
if not old_path.is_relative_to(models_dir):
return model
new_path = models_dir / model.base_model.value / model.model_type.value / model.name
self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path)
model.hash = self.hash(new_path)
model.path = new_path.relative_to(models_dir).as_posix()
if model.hash != key:
assert (
ignore_hash_change
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
self._logger.info(f"Model has new hash {model.hash}, but will continue to be identified by {key}")
self._store.update_model(key, model)
return model
def _make_download_job(
self,
source: Union[str, Path, AnyHttpUrl],
variant: Optional[str] = None,
subfolder: Optional[str] = None,
access_token: Optional[str] = None,
priority: Optional[int] = 10,
) -> ModelInstallJob:
# Clean up a common source of error. Doesn't work with Paths.
if isinstance(source, str):
source = source.strip()
# In the event that we are being asked to install a path that is already on disk,
# we simply probe and register/install it. The job does not actually do anything, but we
# create one anyway in order to have similar behavior for local files, URLs and repo_ids.
if Path(source).exists(): # a path that is already on disk
destdir = source
return ModelInstallPathJob(source=source, destination=Path(destdir), event_handlers=self._handlers)
# choose a temporary directory inside the models directory
models_dir = self._app_config.models_path
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
cls = ModelInstallJob
if match := re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
cls = ModelInstallRepoIDJob
source = match.group(1)
subfolder = match.group(2) or subfolder
kwargs = dict(variant=variant, subfolder=subfolder)
elif re.match(HTTP_RE, str(source)):
cls = ModelInstallURLJob
kwargs = {}
else:
raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL")
return cls(
source=str(source),
destination=Path(self._tmpdir.name),
access_token=access_token,
priority=priority,
event_handlers=self._handlers,
**kwargs,
)
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""Pause until all installation jobs have completed."""
self._download_queue.join()
id_map = self._async_installs
self._async_installs = dict()
return id_map
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = set([Path(x.path) for x in self._store.all_models()])
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback)
self._installed = set()
search.search(scan_dir)
return list(self._installed)
def scan_models_directory(self):
"""
Scan the models directory for new and missing models.
New models will be registered with the storage backend. Missing models
will be unregistered.
"""
defunct_models = set()
installed = set()
with Chdir(self._app_config.models_path):
self._logger.info("Checking for models that have been moved or deleted from disk")
for model_config in self._store.all_models():
path = Path(model_config.path)
if not path.exists():
self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering")
defunct_models.add(model_config.key)
for key in defunct_models:
self.unregister(key)
self._logger.info(f"Scanning {self._app_config.models_path} for new models")
for cur_base_model in BaseModelType:
for cur_model_type in ModelType:
models_dir = Path(cur_base_model.value, cur_model_type.value)
installed.update(self.scan_directory(models_dir))
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
def sync_to_config(self):
"""Synchronize models on disk to those in memory."""
self.scan_models_directory()
if autoimport := self._app_config.autoimport_dir:
self._logger.info("Scanning autoimport directory for new models")
self.scan_directory(self._app_config.root_path / autoimport)
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
priority: int = 10,
model_attributes: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
"""
Add a model using a path, repo_id or URL.
def hash(self, model_path: Union[Path, str]) -> str: # noqa D102
return FastModelHash.hash(model_path)
:param model_attributes: Dictionary of ModelConfigBase fields to
attach to the model. When installing a URL or repo_id, some metadata
values, such as `tags`, will be automagically added.
:param priority: Queue priority for this install job. Lower value jobs
will run before higher value ones.
"""
self.logger.debug(f"add model {source}")
variant = "fp16" if self._precision == "float16" else None
job = self.install(
source,
probe_override=model_attributes,
variant=variant,
priority=priority,
)
assert isinstance(job, ModelInstallJob)
return job
def list_install_jobs(self) -> List[ModelInstallJob]:
"""Return a series of active or enqueued ModelInstallJobs."""
queue = self.queue
jobs: List[DownloadJobBase] = queue.list_jobs()
return [parse_obj_as(ModelInstallJob, x) for x in jobs] # downcast to proper type
def id_to_job(self, id: int) -> ModelInstallJob:
"""Return the ModelInstallJob instance corresponding to the given job ID."""
job = self.queue.id_to_job(id)
assert isinstance(job, ModelInstallJob)
return job
def start_job(self, job_id: int):
"""Start the given install job if it is paused or idle."""
queue = self.queue
queue.start_job(queue.id_to_job(job_id))
def pause_job(self, job_id: int):
"""Pause the given install job if it is paused or idle."""
queue = self.queue
queue.pause_job(queue.id_to_job(job_id))
def cancel_job(self, job_id: int):
"""Cancel the given install job."""
queue = self.queue
queue.cancel_job(queue.id_to_job(job_id))
def cancel_all_jobs(self):
"""Cancel all active install job."""
queue = self.queue
queue.cancel_all_jobs()
def prune_jobs(self):
"""Cancel all active install job."""
queue = self.queue
queue.prune_jobs()
def change_job_priority(self, job_id: int, delta: int):
"""
Change an install job's priority.
:param job_id: Job to change
:param delta: Value to increment or decrement priority.
Lower values are higher priority. The default starting value is 10.
Thus to make this a really high priority job:
manager.change_job_priority(job_id, -10).
"""
queue = self.queue
queue.change_priority(queue.id_to_job(job_id), delta)
def del_model(
self,
key: str,
delete_files: bool = False,
):
"""
Delete the named model from configuration.
If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well.
"""
model_info = self.store.get_model(key)
self.logger.debug(f"delete model {model_info.name}")
self.store.del_model(key)
if delete_files and Path(self._app_config.models_path / model_info.path).exists():
path = self._app_config.models_path / model_info.path
if path.is_dir():
shutil.rmtree(path)
else:
path.unlink()
def convert_model(
self,
key: str,
dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder.
Delete the cached
version and delete the original checkpoint file if it is in the models
directory.
:param key: Key of the model to convert
:param dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
This will raise a ValueError unless the model is a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
model_info = self.store.get_model(key)
self.logger.info(f"Converting model {model_info.name} into a diffusers")
return super().convert_model(key, dest_directory)
@property
def logger(self):
"""Get the logger associated with this instance."""
return self._logger
@property
def store(self):
"""Get the store associated with this instance."""
return self._store
def merge_models(
self,
model_keys: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model keys to merge"
),
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Merge two to three diffusers pipeline models and save as a new model.
:param model_keys: List of 2-3 model unique keys to merge
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
merger = ModelMerger(self.store)
def _scan_register(self, model: Path) -> bool:
if model in self._cached_model_paths:
return True
try:
if not merged_model_name:
merged_model_name = "+".join([self.store.get_model(x).name for x in model_keys])
raise Exception("not implemented")
id = self.register_path(model)
self.sync_model_path(id) # possibly move it to right place in `models`
self._logger.info(f"Registered {model.name} with id {id}")
self._installed.add(id)
except DuplicateModelException:
pass
return True
result = merger.merge_diffusion_models_and_save(
model_keys=model_keys,
merged_model_name=merged_model_name,
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=merge_dest_directory,
)
except AssertionError as e:
raise ValueError(e)
return result
def search_for_models(self, directory: Path) -> Set[Path]:
"""
Return list of all models found in the designated directory.
:param directory: Path to the directory to recursively search.
:returns: A list of model paths.
"""
return ModelSearch().search(directory)
def list_checkpoint_configs(self) -> List[Path]:
"""List the checkpoint config paths from ROOT/configs/stable-diffusion."""
config = self._app_config
conf_path = config.legacy_conf_path
root_path = config.root_path
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
def _scan_install(self, model: Path) -> bool:
if model in self._cached_model_paths:
return True
try:
id = self.install_path(model)
self._logger.info(f"Installed {model} with id {id}")
self._installed.add(id)
except DuplicateModelException:
pass
return True

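Putting the new service together, the intended install flow per the docstrings above is: submit one or more sources, then block on `wait_for_installs()` to map each source back to its installed model key. A hedged end-to-end sketch (the repo_id comes from the docstring examples in this commit; the URL is illustrative):

    installer = ModelInstallService()  # defaults: app config, fresh queue and store
    installer.install_model("stabilityai/stable-diffusion-2-1")
    installer.install_model("https://example.com/some-model.safetensors")
    ids = installer.wait_for_installs()  # blocks until all downloads settle
    key = ids["stabilityai/stable-diffusion-2-1"]  # source -> installed model key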
View File

@ -6,11 +6,12 @@ from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from pydantic import Field
from pathlib import Path
from invokeai.app.models.exceptions import CanceledException
from invokeai.backend.model_manager import ModelConfigStore, SubModelType
from invokeai.backend.model_manager.cache import CacheStats
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad, ModelConfigBase
from .config import InvokeAIAppConfig
from .model_record_service import ModelRecordServiceBase
@ -57,6 +58,8 @@ class ModelLoadServiceBase(ABC):
"""Reset model cache statistics for graph with graph_id."""
pass
# implementation
class ModelLoadService(ModelLoadServiceBase):
@ -137,3 +140,4 @@ class ModelLoadService(ModelLoadServiceBase):
model_key=model_key,
submodel=submodel,
)

View File

@ -8,6 +8,8 @@ from abc import abstractmethod
from pathlib import Path
from typing import Optional, Union
from invokeai.backend.model_manager import ModelConfigBase, ModelType, SubModelType
from invokeai.backend.model_manager.storage import (
ModelConfigStore,
ModelConfigStoreSQL,

View File

@ -14,7 +14,7 @@ import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import BaseModelType, ModelType
from invokeai.backend.model_manager.download.queue import DownloadJobRemoteSource
from invokeai.backend.model_manager.install import ModelInstall, ModelInstallJob, ModelSourceMetadata
from invokeai.app.services.model_install_service import ModelInstall, ModelInstallJob, ModelSourceMetadata
# name of the starter models file
INITIAL_MODELS = "INITIAL_MODELS.yaml"
@ -168,10 +168,9 @@ class InstallHelper(object):
self._add_required_models(selections.install_models)
for model in selections.install_models:
metadata = ModelSourceMetadata(description=model.description, name=model.name)
installer.install(
installer.install_model(
model.source,
subfolder=model.subfolder,
variant="fp16" if self._config.precision == "float16" else None,
access_token=HfFolder.get_token(),
metadata=metadata,
)

View File

@ -8,4 +8,4 @@ from .base import ( # noqa F401
UnknownJobIDException,
)
from .model_queue import ModelDownloadQueue, ModelSourceMetadata # noqa F401
from .queue import DownloadJobPath, DownloadJobURL, DownloadQueue # noqa F401
from .queue import DownloadJobPath, DownloadJobURL, DownloadQueue, DownloadJobRemoteSource # noqa F401

View File

@ -1,822 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Install/delete models.
Typical usage:
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import ModelInstall
from invokeai.backend.model_manager.storage import ModelConfigStoreSQL
from invokeai.backend.model_manager.download import DownloadQueue
config = InvokeAIAppConfig.get_config()
store = ModelConfigStoreSQL(config.db_path)
download = DownloadQueue()
installer = ModelInstall(store=store, config=config, download=download)
# register config, don't move path
id: str = installer.register_path('/path/to/model')
# register config, and install model in `models`
id: str = installer.install_path('/path/to/model')
# download some remote models and install them in the background
installer.install('stabilityai/stable-diffusion-2-1')
installer.install('https://civitai.com/api/download/models/154208')
installer.install('runwayml/stable-diffusion-v1-5')
installer.install('/home/user/models/stable-diffusion-v1-5', inplace=True)
installed_ids = installer.wait_for_installs()
id1 = installed_ids['stabilityai/stable-diffusion-2-1']
id2 = installed_ids['https://civitai.com/api/download/models/154208']
# unregister, don't delete
installer.unregister(id)
# unregister and delete model from disk
installer.delete_model(id)
# scan directory recursively and install all new models found
ids: List[str] = installer.scan_directory('/path/to/directory')
# Synchronize with the models directory, adding missing models and
# removing orphans
installer.scan_models_directory()
hash: str = installer.hash('/path/to/model') # should be same as id above
The following exceptions may be raised:
DuplicateModelException
UnknownModelTypeException
"""
import re
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import move, rmtree
from typing import Any, Callable, Dict, List, Optional, Set, Union
from pydantic import Field
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.backend.util import Chdir, InvokeAILogger, Logger
from .config import (
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
from .download import (
DownloadEventHandler,
DownloadJobBase,
DownloadJobPath,
DownloadJobURL,
DownloadQueueBase,
ModelDownloadQueue,
ModelSourceMetadata,
)
from .download.model_queue import (
HTTP_RE,
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE,
DownloadJobRepoID,
DownloadJobWithMetadata,
)
from .hash import FastModelHash
from .models import InvalidModelException
from .probe import ModelProbe, ModelProbeInfo
from .search import ModelSearch
from .storage import DuplicateModelException, ModelConfigStore
class ModelInstallJob(DownloadJobBase):
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
model_key: Optional[str] = Field(
description="After model installation, this field will hold its primary key", default=None
)
probe_override: Optional[Dict[str, Any]] = Field(
description="Keys in this dict will override like-named attributes in the automatic probe info",
default=None,
)
class ModelInstallURLJob(DownloadJobWithMetadata, ModelInstallJob):
"""Job for installing URLs."""
class ModelInstallRepoIDJob(DownloadJobRepoID, ModelInstallJob):
"""Job for installing repo ids."""
class ModelInstallPathJob(DownloadJobPath, ModelInstallJob):
"""Job for installing local paths."""
ModelInstallEventHandler = Callable[["ModelInstallJob"], None]
class ModelInstallBase(ABC):
"""Abstract base class for InvokeAI model installation."""
@abstractmethod
def __init__(
self,
config: Optional[InvokeAIAppConfig] = None,
store: Optional[ModelConfigStore] = None,
logger: Optional[InvokeAILogger] = None,
download: Optional[DownloadQueueBase] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
):
"""
Create ModelInstall object.
:param store: Optional ModelConfigStore. If None is passed,
defaults to `configs/models.yaml`.
:param config: Optional InvokeAIAppConfig. If None is passed,
uses the system-wide default app config.
:param logger: Optional InvokeAILogger. If None is passed,
uses the system-wide default logger.
:param download: Optional DownloadQueueBase object. If None is passed,
a default queue object will be created.
:param event_handlers: List of event handlers to pass to the queue object.
"""
pass
@property
@abstractmethod
def queue(self) -> DownloadQueueBase:
"""Return the download queue used by the installer."""
pass
@property
@abstractmethod
def store(self) -> ModelConfigStore:
"""Return the storage backend used by the installer."""
pass
@property
@abstractmethod
def config(self) -> InvokeAIAppConfig:
"""Return the app_config used by the installer."""
pass
@abstractmethod
def register_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]]) -> str:
"""
Probe and register the model at model_path.
:param model_path: Filesystem Path to the model.
:param overrides: Dict of attributes that will override probed values.
:returns id: The string ID of the registered model.
"""
pass
@abstractmethod
def install_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> str:
"""
Probe, register and install the model in the models directory.
This involves moving the model from its current location into
the models directory handled by InvokeAI.
:param model_path: Filesystem Path to the model.
:param overrides: Dictionary of model probe info fields that, if present, override probed values.
:returns id: The string ID of the installed model.
"""
pass
@abstractmethod
def install(
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
priority: int = 10,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None,
metadata: Optional[ModelSourceMetadata] = None,
access_token: Optional[str] = None,
) -> DownloadJobBase:
"""
Download and install the indicated model.
This will download the model located at `source`,
probe it, and install it into the models directory.
This call is executed asynchronously in a separate
thread, and the returned object is an
invokeai.backend.model_manager.download.DownloadJobBase
object that can be interrogated to get the status of
the download and install process. Call the `wait_for_installs()`
method to wait for all downloads and installations to complete.
:param source: Either a URL or a HuggingFace repo_id.
:param inplace: If True, local paths will not be moved into
the models directory, but registered in place (the default).
:param variant: For HuggingFace models, this optional parameter
specifies which variant to download (e.g. 'fp16')
:param subfolder: When downloading HF repo_ids this can be used to
specify a subfolder of the HF repository to download from.
:param probe_override: Optional dict. Any fields in this dict
will override corresponding probe fields. Use it to override
`base_type`, `model_type`, `format`, `prediction_type` and `image_size`.
:param metadata: Use this to override the fields `description`,
`author`, `tags`, `source` and `license`.
:returns: DownloadJobBase object.
The `inplace` flag does not affect the behavior of downloaded
models, which are always moved into the `models` directory.
Variants recognized by HuggingFace currently are:
1. onnx
2. openvino
3. fp16
4. None (usually returns fp32 model)
"""
pass
@abstractmethod
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""
Wait for all pending installs to complete.
This will block until all pending downloads have
completed, been cancelled, or errored out. It will
block indefinitely if one or more jobs are in the
paused state.
It will return a dict that maps the source model
path, URL or repo_id to the ID of the installed model.
"""
pass
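# Hedged sketch tying install() and wait_for_installs() together; the
# repo_id comes from the module docstring, the rest is illustrative:
#
#   job = installer.install("stabilityai/stable-diffusion-2-1", variant="fp16")
#   ids = installer.wait_for_installs()            # blocks until the queue drains
#   key = ids["stabilityai/stable-diffusion-2-1"]  # model ID, or None on failure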
@abstractmethod
def unregister(self, id: str):
"""
Unregister the model identified by id.
This removes the model from the registry without
deleting the underlying model from disk.
:param id: The string ID of the model to forget.
:raises UnknownModelException: In the event the ID is unknown.
"""
pass
@abstractmethod
def delete(self, id: str):
"""
Unregister and delete the model identified by id.
This removes the model from the registry and
deletes the underlying model from disk.
:param id: The string ID of the model to delete.
:raises UnknownModelException: In the event the ID is unknown.
:raises OSError: In the event the model cannot be deleted from disk.
"""
pass
@abstractmethod
def conditionally_delete(self, key: str): # noqa D102
"""Unregister the model. Delete its files only if they are within our models directory."""
pass
@abstractmethod
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
"""
Recursively scan directory for new models and register or install them.
:param scan_dir: Path to the directory to scan.
:param install: Install if True, otherwise register in place.
:returns: List of IDs of the models registered or installed.
"""
pass
@abstractmethod
def hash(self, model_path: Union[Path, str]) -> str:
"""
Compute and return the fast hash of the model.
:param model_path: Path to the model on disk.
:return str: FastHash of the model for use as an ID.
"""
pass
@abstractmethod
def convert_model(
self,
key: str,
dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder.
It will delete the cached version as well as the
original checkpoint file if it is in the models directory.
:param key: Unique key of model.
:param dest_directory: Optional place to put the converted model. If not
specified, it will be stored in the `models_dir`.
This will raise a ValueError unless the model is a checkpoint.
This will raise an UnknownModelException if key is unknown.
"""
pass
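# Hedged conversion example (the key is hypothetical; conversion only
# applies to checkpoint-format models):
#
#   new_config = installer.convert_model(key, dest_directory=Path('/path/to/out'))
#   assert new_config.model_format == "diffusers"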
@abstractmethod
def sync_model_path(self, key) -> ModelConfigBase:
"""
Move model into the location indicated by its basetype, type and name.
Call this after updating a model's attributes in order to move
the model's path into the location indicated by its basetype, type and
name. Applies only to models whose paths are within the root `models_dir`
directory.
May raise an UnknownModelException.
"""
pass
@abstractmethod
def sync_to_config(self):
"""Synchronize models on disk to those in memory."""
pass
@abstractmethod
def scan_models_directory(self):
"""
Scan the models directory for new and missing models.
New models will be added to the storage backend. Missing models
will be unregistered.
"""
pass
class ModelInstall(ModelInstallBase):
"""Model installer class handles installation from a local path."""
_app_config: InvokeAIAppConfig
_logger: Logger
_store: ModelConfigStore
_download_queue: DownloadQueueBase
_async_installs: Dict[Union[str, Path, AnyHttpUrl], Optional[str]]
_installed: Set[str] = Field(default=set)
_tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads
_cached_model_paths: Set[Path] = Field(default=set) # used to speed up directory scanning
_legacy_configs: Dict[BaseModelType, Dict[ModelVariantType, Union[str, dict]]] = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: {
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
},
ModelVariantType.Inpaint: {
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
},
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
},
}
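# Worked lookup example for the table above: an SD-2 "normal" checkpoint is
# resolved in two tiers, first by variant and then by prediction type:
#
#   self._legacy_configs[BaseModelType.StableDiffusion2][ModelVariantType.Normal]
#       [SchedulerPredictionType.Epsilon]  -> "v2-inference.yaml"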
def __init__(
self,
store: Optional[ModelConfigStore] = None,
config: Optional[InvokeAIAppConfig] = None,
logger: Optional[Logger] = None,
download: Optional[DownloadQueueBase] = None,
event_handlers: List[DownloadEventHandler] = [],
): # noqa D107 - use base class docstrings
self._app_config = config or InvokeAIAppConfig.get_config()
self._logger = logger or InvokeAILogger.get_logger(config=self._app_config)
self._store = store or ModelRecordServiceBase.get_impl(self._app_config)
self._download_queue = download or ModelDownloadQueue(config=self._app_config, event_handlers=event_handlers)
self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
self._installed = set()
self._tmpdir = None
@property
def queue(self) -> DownloadQueueBase:
"""Return the queue."""
return self._download_queue
@property
def store(self) -> ModelConfigStore:
"""Return the storage backend used by the installer."""
return self._store
@property
def config(self) -> InvokeAIAppConfig:
"""Return the app_config used by the installer."""
return self._app_config
def register_path(
self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None
) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = self._probe_model(model_path, overrides)
return self._register(model_path, info)
def _register(self, model_path: Path, info: ModelProbeInfo) -> str:
key: str = FastModelHash.hash(model_path)
model_path = model_path.absolute()
if model_path.is_relative_to(self._app_config.models_path):
model_path = model_path.relative_to(self._app_config.models_path)
registration_data = dict(
path=model_path.as_posix(),
name=model_path.name if model_path.is_dir() else model_path.stem,
base_model=info.base_type,
model_type=info.model_type,
model_format=info.format,
hash=key,
)
# add 'main' specific fields
if info.model_type == ModelType.Main:
if info.variant_type:
registration_data.update(variant=info.variant_type)
if info.format == ModelFormat.Checkpoint:
try:
config_file = self._legacy_configs[info.base_type][info.variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
if prediction_type := info.prediction_type:
config_file = config_file[prediction_type]
else:
self._logger.warning(
f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model"
)
config_file = config_file[SchedulerPredictionType.VPrediction]
registration_data.update(
config=Path(self._app_config.legacy_conf_dir, str(config_file)).as_posix(),
)
except KeyError as exc:
raise InvalidModelException(
"Configuration file for this checkpoint could not be determined"
) from exc
self._store.add_model(key, registration_data)
return key
def install_path(
self,
model_path: Union[Path, str],
overrides: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = self._probe_model(model_path, overrides)
dest_path = self._app_config.models_path / info.base_type.value / info.model_type.value / model_path.name
new_path = self._move_model(model_path, dest_path)
new_hash = self.hash(new_path)
assert new_hash == info.hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
return self._register(
new_path,
info,
)
def _move_model(self, old_path: Path, new_path: Path) -> Path:
if old_path == new_path:
return old_path
new_path.parent.mkdir(parents=True, exist_ok=True)
# if path already exists then we jigger the name to make it unique
counter: int = 1
while new_path.exists():
path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
if not path.exists():
new_path = path
counter += 1
return move(old_path, new_path)
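# Example of the loop above: if "model.safetensors" already exists at the
# destination, "model_01.safetensors", "model_02.safetensors", ... are tried
# until a free name is found.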
def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
info: ModelProbeInfo = ModelProbe.probe(Path(model_path))
if overrides: # used to override probe fields
for key, value in overrides.items():
try:
setattr(info, key, value) # skip validation errors
except Exception:
pass
return info
def unregister(self, key: str): # noqa D102
self._store.del_model(key)
def delete(self, key: str): # noqa D102
model = self._store.get_model(key)
path = self._app_config.models_path / model.path
if path.is_dir():
rmtree(path)
else:
path.unlink()
self.unregister(key)
def conditionally_delete(self, key: str): # noqa D102
"""Unregister the model. Delete its files only if they are within our models directory."""
model = self._store.get_model(key)
models_dir = self._app_config.models_path
model_path = models_dir / model.path
if model_path.is_relative_to(models_dir):
self.delete(key)
else:
self.unregister(key)
def install(
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
priority: int = 10,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None,
metadata: Optional[ModelSourceMetadata] = None,
access_token: Optional[str] = None,
) -> DownloadJobBase: # noqa D102
queue = self._download_queue
job = self._make_download_job(
source, variant=variant, access_token=access_token, subfolder=subfolder, priority=priority
)
handler = (
self._complete_registration_handler
if inplace and Path(source).exists()
else self._complete_installation_handler
)
if isinstance(job, ModelInstallJob):
job.probe_override = probe_override
if metadata and isinstance(job, DownloadJobWithMetadata):
job.metadata = metadata
job.add_event_handler(handler)
self._async_installs[source] = None
queue.submit_download_job(job, True)
return job
def _complete_installation_handler(self, job: DownloadJobBase):
assert isinstance(job, ModelInstallJob)
if job.status == "completed":
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
model_id = self.install_path(job.destination, job.probe_override)
info = self._store.get_model(model_id)
info.source = str(job.source)
if isinstance(job, DownloadJobWithMetadata):
metadata: ModelSourceMetadata = job.metadata
info.description = metadata.description or f"Imported model {info.name}"
info.name = metadata.name or info.name
info.author = metadata.author
info.tags = metadata.tags
info.license = metadata.license
info.thumbnail_url = metadata.thumbnail_url
self._store.update_model(model_id, info)
self._async_installs[job.source] = model_id
job.model_key = model_id
elif job.status == "error":
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
elif job.status == "cancelled":
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
jobs = self._download_queue.list_jobs()
if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
self._tmpdir.cleanup()
self._tmpdir = None
def _complete_registration_handler(self, job: DownloadJobBase):
assert isinstance(job, ModelInstallJob)
if job.status == "completed":
self._logger.info(f"{job.source}: Installing in place.")
model_id = self.register_path(job.destination, job.probe_override)
info = self._store.get_model(model_id)
info.source = str(job.source)
info.description = f"Imported model {info.name}"
self._store.update_model(model_id, info)
self._async_installs[job.source] = model_id
job.model_key = model_id
elif job.status == "error":
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
elif job.status == "cancelled":
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
def sync_model_path(self, key: str, ignore_hash_change: bool = False) -> ModelConfigBase:
"""
Move model into the location indicated by its basetype, type and name.
Call this after updating a model's attributes in order to move
the model's path into the location indicated by its basetype, type and
name. Applies only to models whose paths are within the root `models_dir`
directory.
May raise an UnknownModelException.
"""
model = self._store.get_model(key)
old_path = Path(model.path)
models_dir = self._app_config.models_path
if not old_path.is_relative_to(models_dir):
return model
new_path = models_dir / model.base_model.value / model.model_type.value / model.name
self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path)
model.hash = self.hash(new_path)
model.path = new_path.relative_to(models_dir).as_posix()
if model.hash != key:
assert (
ignore_hash_change
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
self._logger.info(f"Model has new hash {model.hash}, but will continue to be identified by {key}")
self._store.update_model(key, model)
return model
def _make_download_job(
self,
source: Union[str, Path, AnyHttpUrl],
variant: Optional[str] = None,
subfolder: Optional[str] = None,
access_token: Optional[str] = None,
priority: Optional[int] = 10,
) -> ModelInstallJob:
# Clean up a common source of error. Doesn't work with Paths.
if isinstance(source, str):
source = source.strip()
# In the event that we are being asked to install a path that is already on disk,
# we simply probe and register/install it. The job does not actually do anything, but we
# create one anyway in order to have similar behavior for local files, URLs and repo_ids.
if Path(source).exists(): # a path that is already on disk
destdir = source
return ModelInstallPathJob(source=source, destination=Path(destdir))
# choose a temporary directory inside the models directory
models_dir = self._app_config.models_path
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
cls = ModelInstallJob
if match := re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
cls = ModelInstallRepoIDJob
source = match.group(1)
subfolder = match.group(2) or subfolder
kwargs = dict(variant=variant, subfolder=subfolder)
elif re.match(HTTP_RE, str(source)):
cls = ModelInstallURLJob
kwargs = {}
else:
raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL")
return cls(
source=str(source),
destination=Path(self._tmpdir.name),
access_token=access_token,
priority=priority,
**kwargs,
)
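# How the dispatch above classifies a source (examples drawn from the
# module docstring):
#
#   "/home/user/models/stable-diffusion-v1-5"        -> ModelInstallPathJob
#   "runwayml/stable-diffusion-v1-5"                 -> ModelInstallRepoIDJob
#   "https://civitai.com/api/download/models/154208" -> ModelInstallURLJob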
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""Pause until all installation jobs have completed."""
self._download_queue.join()
id_map = self._async_installs
self._async_installs = dict()
return id_map
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = set([Path(x.path) for x in self._store.all_models()])
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback)
self._installed = set()
search.search(scan_dir)
return list(self._installed)
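# Hedged usage: register models found under a directory in place, or pass
# install=True to move them into the managed `models` tree:
#
#   new_ids = installer.scan_directory(Path('/path/to/directory'), install=True)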
def hash(self, model_path: Union[Path, str]) -> str: # noqa D102
return FastModelHash.hash(model_path)
def convert_model(
self,
key: str,
dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder.
It will delete the cached version as well as the
original checkpoint file if it is in the models directory.
:param key: Unique key of model.
:param dest_directory: Optional place to put the converted model. If not
specified, it will be stored in the `models_dir`.
This will raise a ValueError unless the model is a checkpoint.
This will raise an UnknownModelException if key is unknown.
"""
from .loader import ModelInfo, ModelLoad # to avoid circular imports
new_diffusers_path = None
try:
info: ModelConfigBase = self._store.get_model(key)
if info.model_format != "checkpoint":
raise ValueError(f"not a checkpoint format model: {info.name}")
# We are taking advantage of a side effect of get_model() that converts checkpoints
# into cached diffusers directories stored at `path`. It doesn't matter
# what submodel type we request here, so we get the smallest.
loader = ModelLoad(self._app_config, self.store)
submodel = {"submodel_type": SubModelType.Scheduler} if info.model_type == ModelType.Main else {}
converted_model: ModelInfo = loader.get_model(key, **submodel)
checkpoint_path = loader.resolve_model_path(info.path)
old_diffusers_path = loader.resolve_model_path(converted_model.location)
# new values to write in
update = info.dict()
update.pop("config")
update["model_format"] = "diffusers"
update["path"] = str(converted_model.location)
if dest_directory:
new_diffusers_path = Path(dest_directory) / info.name
if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
move(old_diffusers_path, new_diffusers_path)
update["path"] = new_diffusers_path.as_posix()
self._store.update_model(key, update)
result = self.sync_model_path(key, ignore_hash_change=True)
except Exception as excp:
# Something went wrong, so don't leave a dangling diffusers model in the directory or it will cause a duplicate-model error!
if new_diffusers_path:
rmtree(new_diffusers_path)
raise excp
if checkpoint_path.exists() and checkpoint_path.is_relative_to(self._app_config.models_path):
checkpoint_path.unlink()
return result
# the following two methods are callbacks to the ModelSearch object
def _scan_register(self, model: Path) -> bool:
if model in self._cached_model_paths:
return True
try:
id = self.register_path(model)
self.sync_model_path(id) # possibly move it to right place in `models`
self._logger.info(f"Registered {model.name} with id {id}")
self._installed.add(id)
except DuplicateModelException:
pass
return True
def _scan_install(self, model: Path) -> bool:
if model in self._cached_model_paths:
return True
try:
id = self.install_path(model)
self._logger.info(f"Installed {model} with id {id}")
self._installed.add(id)
except DuplicateModelException:
pass
return True
def sync_to_config(self):
"""Synchronize models on disk to those in memory."""
self.scan_models_directory()
if autoimport := self._app_config.autoimport_dir:
self._logger.info("Scanning autoimport directory for new models")
self.scan_directory(self._app_config.root_path / autoimport)
def scan_models_directory(self):
"""
Scan the models directory for new and missing models.
New models will be added to the storage backend. Missing models
will be unregistered.
"""
defunct_models = set()
installed = set()
with Chdir(self._app_config.models_path):
self._logger.info("Checking for models that have been moved or deleted from disk")
for model_config in self._store.all_models():
path = Path(model_config.path)
if not path.exists():
self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering")
defunct_models.add(model_config.key)
for key in defunct_models:
self.unregister(key)
self._logger.info(f"Scanning {self._app_config.models_path} for new models")
for cur_base_model in BaseModelType:
for cur_model_type in ModelType:
models_dir = Path(cur_base_model.value, cur_model_type.value)
installed.update(self.scan_directory(models_dir))
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")

View File

@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union
from shutil import move, rmtree
import torch

View File

@ -16,10 +16,10 @@ from diffusers import logging as dlogging
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_install_service import ModelInstallService
from . import BaseModelType, ModelConfigBase, ModelConfigStore, ModelType
from .config import MainConfig
from .loader import ModelLoad
class MergeInterpolationMethod(str, Enum):
@ -151,7 +151,7 @@ class ModelMerger(object):
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
# register model and get its unique key
installer = ModelInstall(store=self._store, config=self._config)
installer = ModelInstallService(store=self._store, config=self._config)
key = installer.register_path(dump_path)
# update model's config

View File

@ -7,8 +7,8 @@ import torch
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
from invokeai.app.services.model_record_service import ModelRecordServiceBase
from invokeai.app.services.model_install_service import ModelInstallService
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType, UnknownModelException
from invokeai.backend.model_manager.install import ModelInstall
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad
@ -30,11 +30,11 @@ def model_installer():
#
config = InvokeAIAppConfig(log_level="info")
model_store = ModelRecordServiceBase.get_impl(config)
return ModelInstall(model_store, config)
return ModelInstallService(store=model_store, config=config)
def install_and_load_model(
model_installer: ModelInstall,
model_installer: ModelInstallService,
model_path_id_or_url: Union[str, Path],
model_name: str,
base_model: BaseModelType,

View File

@ -26,9 +26,9 @@ from pydantic import BaseModel
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_install_service import ModelInstall, ModelInstallJob
from invokeai.backend.install.install_helper import InstallHelper, UnifiedModelInfo
from invokeai.backend.model_manager import BaseModelType, ModelType
from invokeai.backend.model_manager.install import ModelInstall, ModelInstallJob
from invokeai.backend.util import choose_precision, choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.widgets import (

View File

@ -59,6 +59,7 @@ def mock_services() -> InvocationServices:
conn=db_conn, table_name="graph_executions", lock=lock
)
return InvocationServices(
download_queue=None, # type: ignore
model_loader=None, # type: ignore
model_installer=None, # type: ignore
model_record_store=None, # type: ignore

View File

@ -3,4 +3,5 @@
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
from invokeai.backend.util.test_utils import model_installer, torch_device # noqa: F401
from invokeai.backend.util.test_utils import torch_device # noqa: F401
from invokeai.app.services.model_install_service import ModelInstallService # noqa: F401