mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor installer class hierarchy
This commit is contained in:
@ -26,6 +26,7 @@ from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
from ..services.download_manager import DownloadQueueService
|
||||
from ..services.model_install_service import ModelInstallService
|
||||
from ..services.model_loader_service import ModelLoadService
|
||||
from ..services.model_record_service import ModelRecordServiceBase
|
||||
@ -129,9 +130,14 @@ class ApiDependencies:
|
||||
)
|
||||
)
|
||||
|
||||
download_queue = DownloadQueueService(event_bus=events, config=config)
|
||||
model_record_store = ModelRecordServiceBase.get_impl(config, conn=db_conn, lock=lock)
|
||||
model_loader = ModelLoadService(config, model_record_store)
|
||||
model_installer = ModelInstallService(config, model_record_store, events)
|
||||
model_installer = ModelInstallService(config,
|
||||
queue=download_queue,
|
||||
store=model_record_store,
|
||||
event_bus=events
|
||||
)
|
||||
|
||||
services = InvocationServices(
|
||||
events=events,
|
||||
@ -146,6 +152,7 @@ class ApiDependencies:
|
||||
configuration=config,
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
logger=logger,
|
||||
download_queue=download_queue,
|
||||
model_record_store=model_record_store,
|
||||
model_loader=model_loader,
|
||||
model_installer=model_installer,
|
||||
|
@ -19,9 +19,11 @@ from invokeai.backend.model_manager import (
|
||||
ModelConfigBase,
|
||||
SchedulerPredictionType,
|
||||
UnknownModelException,
|
||||
ModelSearch
|
||||
)
|
||||
from invokeai.backend.model_manager.download import DownloadJobStatus, UnknownJobIDException
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod
|
||||
from invokeai.app.services.download_manager import DownloadJobStatus, UnknownJobIDException, DownloadJobRemoteSource
|
||||
from invokeai.app.services.model_convert import MergeInterpolationMethod, ModelConvert
|
||||
from invokeai.app.services.model_install_service import ModelInstallJob
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
@ -39,7 +41,7 @@ class ModelsList(BaseModel):
|
||||
models: List[InvokeAIModelConfig]
|
||||
|
||||
|
||||
class ModelImportStatus(BaseModel):
|
||||
class ModelDownloadStatus(BaseModel):
|
||||
"""Return information about a background installation job."""
|
||||
|
||||
job_id: int
|
||||
@ -56,7 +58,6 @@ class JobControlOperation(str, Enum):
|
||||
CANCEL = "Cancel"
|
||||
CHANGE_PRIORITY = "Change Priority"
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
operation_id="list_models",
|
||||
@ -135,7 +136,7 @@ async def update_model(
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ModelImportStatus,
|
||||
response_model=ModelDownloadStatus,
|
||||
)
|
||||
async def import_model(
|
||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||
@ -147,7 +148,7 @@ async def import_model(
|
||||
description="Which import jobs run first. Lower values run before higher ones.",
|
||||
default=10,
|
||||
),
|
||||
) -> ModelImportStatus:
|
||||
) -> ModelDownloadStatus:
|
||||
"""
|
||||
Add a model using its local path, repo_id, or remote URL.
|
||||
|
||||
@ -172,10 +173,10 @@ async def import_model(
|
||||
installer = ApiDependencies.invoker.services.model_installer
|
||||
result = installer.install_model(
|
||||
location,
|
||||
model_attributes={"prediction_type": SchedulerPredictionType(prediction_type)},
|
||||
probe_override={"prediction_type": SchedulerPredictionType(prediction_type) if prediction_type else None},
|
||||
priority=priority,
|
||||
)
|
||||
return ModelImportStatus(
|
||||
return ModelDownloadStatus(
|
||||
job_id=result.id,
|
||||
source=result.source,
|
||||
priority=result.priority,
|
||||
@ -288,8 +289,12 @@ async def convert_model(
|
||||
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
||||
try:
|
||||
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
||||
installer = ApiDependencies.invoker.services.model_installer
|
||||
model_config = installer.convert_model(key, dest_directory=dest)
|
||||
converter = ModelConvert(
|
||||
loader=ApiDependencies.invoker.services.model_loader,
|
||||
installer=ApiDependencies.invoker.services.model_installer,
|
||||
store=ApiDependencies.invoker.services.model_record_store
|
||||
)
|
||||
model_config = converter.convert_model(key, dest_directory=dest)
|
||||
response = parse_obj_as(InvokeAIModelConfig, model_config.dict())
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{key}' not found: {str(e)}")
|
||||
@ -316,8 +321,7 @@ async def search_for_models(
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
|
||||
)
|
||||
return ApiDependencies.invoker.services.model_installer.search_for_models(search_path)
|
||||
|
||||
return ModelSearch().search(search_path)
|
||||
|
||||
@models_router.get(
|
||||
"/ckpt_confs",
|
||||
@ -330,7 +334,10 @@ async def search_for_models(
|
||||
)
|
||||
async def list_ckpt_configs() -> List[pathlib.Path]:
|
||||
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
||||
return ApiDependencies.invoker.services.model_installer.list_checkpoint_configs()
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
conf_path = config.legacy_conf_path
|
||||
root_path = config.root_path
|
||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
|
||||
|
||||
|
||||
@models_router.post(
|
||||
@ -349,7 +356,8 @@ async def sync_to_config() -> bool:
|
||||
Call after making changes to models.yaml, autoimport directories
|
||||
or models directory.
|
||||
"""
|
||||
ApiDependencies.invoker.services.model_installer.sync_to_config()
|
||||
installer = ApiDependencies.invoker.services.model_installer
|
||||
installer.sync_to_config()
|
||||
return True
|
||||
|
||||
|
||||
@ -383,7 +391,12 @@ async def merge_models(
|
||||
try:
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
result: ModelConfigBase = ApiDependencies.invoker.services.model_installer.merge_models(
|
||||
converter = ModelConvert(
|
||||
loader=ApiDependencies.invoker.services.model_loader,
|
||||
installer=ApiDependencies.invoker.services.model_installer,
|
||||
store=ApiDependencies.invoker.services.model_record_store
|
||||
)
|
||||
result: ModelConfigBase = converter.merge_models(
|
||||
model_keys=keys,
|
||||
merged_model_name=merged_model_name,
|
||||
alpha=alpha,
|
||||
@ -409,14 +422,14 @@ async def merge_models(
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=List[ModelImportStatus],
|
||||
response_model=List[ModelDownloadStatus],
|
||||
)
|
||||
async def list_install_jobs() -> List[ModelImportStatus]:
|
||||
async def list_install_jobs() -> List[ModelDownloadStatus]:
|
||||
"""List active and pending model installation jobs."""
|
||||
job_mgr = ApiDependencies.invoker.services.model_installer
|
||||
jobs = job_mgr.list_install_jobs()
|
||||
job_mgr = ApiDependencies.invoker.services.download_queue
|
||||
jobs = job_mgr.list_jobs()
|
||||
return [
|
||||
ModelImportStatus(
|
||||
ModelDownloadStatus(
|
||||
job_id=x.id,
|
||||
source=x.source,
|
||||
priority=x.priority,
|
||||
@ -424,32 +437,32 @@ async def list_install_jobs() -> List[ModelImportStatus]:
|
||||
total_bytes=x.total_bytes,
|
||||
status=x.status,
|
||||
)
|
||||
for x in jobs
|
||||
for x in jobs if isinstance(x, ModelInstallJob)
|
||||
]
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/jobs/control/{operation}/{job_id}",
|
||||
operation_id="control_install_jobs",
|
||||
operation_id="control_download_jobs",
|
||||
responses={
|
||||
200: {"description": "The control job was updated successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The job could not be found"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=ModelImportStatus,
|
||||
response_model=ModelDownloadStatus,
|
||||
)
|
||||
async def control_install_jobs(
|
||||
job_id: int = Path(description="Install job_id for start, pause and cancel operations"),
|
||||
async def control_download_jobs(
|
||||
job_id: int = Path(description="Download/install job_id for start, pause and cancel operations"),
|
||||
operation: JobControlOperation = Path(description="The operation to perform on the job."),
|
||||
priority_delta: Optional[int] = Body(
|
||||
description="Change in job priority for priority operations only. Negative numbers increase priority.",
|
||||
default=None,
|
||||
),
|
||||
) -> ModelImportStatus:
|
||||
) -> ModelDownloadStatus:
|
||||
"""Start, pause, cancel, or change the run priority of a running model install job."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
job_mgr = ApiDependencies.invoker.services.model_installer
|
||||
job_mgr = ApiDependencies.invoker.services.download_queue
|
||||
try:
|
||||
job = job_mgr.id_to_job(job_id)
|
||||
|
||||
@ -467,14 +480,19 @@ async def control_install_jobs(
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown operation {operation.value}")
|
||||
bytes = 0
|
||||
total_bytes = 0
|
||||
if isinstance(job, DownloadJobRemoteSource):
|
||||
bytes = job.bytes
|
||||
total_bytes = job.total_bytes
|
||||
|
||||
return ModelImportStatus(
|
||||
return ModelDownloadStatus(
|
||||
job_id=job_id,
|
||||
source=job.source,
|
||||
priority=job.priority,
|
||||
status=job.status,
|
||||
bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
bytes=bytes,
|
||||
total_bytes=total_bytes,
|
||||
)
|
||||
except UnknownJobIDException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@ -485,17 +503,17 @@ async def control_install_jobs(
|
||||
|
||||
@models_router.patch(
|
||||
"/jobs/cancel_all",
|
||||
operation_id="cancel_all_jobs",
|
||||
operation_id="cancel_all_download_jobs",
|
||||
responses={
|
||||
204: {"description": "All jobs cancelled successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def cancel_install_jobs():
|
||||
async def cancel_all_download_jobs():
|
||||
"""Cancel all model installation jobs."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
job_mgr = ApiDependencies.invoker.services.model_installer
|
||||
logger.info("Cancelling all model installation jobs.")
|
||||
job_mgr = ApiDependencies.invoker.services.download_queue
|
||||
logger.info("Cancelling all download jobs.")
|
||||
job_mgr.cancel_all_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
@ -510,7 +528,6 @@ async def cancel_install_jobs():
|
||||
)
|
||||
async def prune_jobs():
|
||||
"""Prune all completed and errored jobs."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
mgr = ApiDependencies.invoker.services.model_installer
|
||||
mgr = ApiDependencies.invoker.services.download_queue
|
||||
mgr.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
|
@ -145,7 +145,7 @@ class BaseCommand(ABC, BaseModel):
|
||||
"""A CLI command"""
|
||||
|
||||
# All commands must include a type name like this:
|
||||
# type: Literal['your_command_name'] = 'your_command_name'
|
||||
# Literal['your_command_name'] = 'your_command_name'
|
||||
|
||||
@classmethod
|
||||
def get_all_subclasses(cls):
|
||||
|
@ -9,7 +9,18 @@ from typing import List, Optional, Union
|
||||
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.backend.model_manager.download import DownloadEventHandler, DownloadJobBase, DownloadQueue
|
||||
from invokeai.backend.model_manager.download import ( # noqa F401
|
||||
DownloadEventHandler,
|
||||
DownloadJobBase,
|
||||
DownloadJobPath,
|
||||
DownloadJobStatus,
|
||||
DownloadQueueBase,
|
||||
ModelDownloadQueue,
|
||||
ModelSourceMetadata,
|
||||
UnknownJobIDException,
|
||||
)
|
||||
from invokeai.backend.model_manager.download import DownloadJobRemoteSource # noqa F401
|
||||
|
||||
|
||||
from .events import EventServiceBase
|
||||
|
||||
@ -40,6 +51,22 @@ class DownloadQueueServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def submit_download_job(
|
||||
self,
|
||||
job: DownloadJobBase,
|
||||
start: bool = True,
|
||||
):
|
||||
"""
|
||||
Submit a download job.
|
||||
|
||||
:param job: A DownloadJobBase
|
||||
:param start: Immediately start job [True]
|
||||
|
||||
After execution, `job.id` will be set to a non-negative value.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_jobs(self) -> List[DownloadJobBase]:
|
||||
"""
|
||||
@ -76,6 +103,11 @@ class DownloadQueueServiceBase(ABC):
|
||||
"""Cancel all active and enquedjobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prune_jobs(self):
|
||||
"""Prune completed and errored queue items from the job list."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self, job: DownloadJobBase):
|
||||
"""Start the job putting it into ENQUEUED state."""
|
||||
@ -115,7 +147,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
"""Multithreaded queue for downloading models via URL or repo_id."""
|
||||
|
||||
_event_bus: EventServiceBase
|
||||
_queue: DownloadQueue
|
||||
_queue: DownloadQueueBase
|
||||
|
||||
def __init__(self, event_bus: EventServiceBase, **kwargs):
|
||||
"""
|
||||
@ -126,7 +158,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
e.g. `max_parallel_dl`.
|
||||
"""
|
||||
self._event_bus = event_bus
|
||||
self._queue = DownloadQueue(**kwargs)
|
||||
self._queue = ModelDownloadQueue(**kwargs)
|
||||
|
||||
def create_download_job(
|
||||
self,
|
||||
@ -149,6 +181,13 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
event_handlers=event_handlers,
|
||||
)
|
||||
|
||||
def submit_download_job(
|
||||
self,
|
||||
job: DownloadJobBase,
|
||||
start: bool = True,
|
||||
):
|
||||
return self._queue.submit_download_job(job, start)
|
||||
|
||||
def list_jobs(self) -> List[DownloadJobBase]: # noqa D102
|
||||
return self._queue.list_jobs()
|
||||
|
||||
@ -164,6 +203,9 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
def cancel_all_jobs(self): # noqa D102
|
||||
return self._queue.cancel_all_jobs()
|
||||
|
||||
def prune_jobs(self, job: DownloadJobBase): # noqa D102
|
||||
return self._queue.prune_jobs()
|
||||
|
||||
def start_job(self, job: DownloadJobBase): # noqa D102
|
||||
return self._queue.start_job(job)
|
||||
|
||||
|
@ -18,6 +18,7 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||
from invokeai.app.services.item_storage import ItemStorageABC
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
from invokeai.app.services.download_manager import DownloadQueueServiceBase
|
||||
from invokeai.app.services.model_install_service import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_loader_service import ModelLoadServiceBase
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
@ -37,6 +38,7 @@ class InvocationServices:
|
||||
graph_library: "ItemStorageABC[LibraryGraph]"
|
||||
images: "ImageServiceABC"
|
||||
latents: "LatentsStorageBase"
|
||||
download_queue: "DownloadQueueServiceBase"
|
||||
model_record_store: "ModelRecordServiceBase"
|
||||
model_loader: "ModelLoadServiceBase"
|
||||
model_installer: "ModelInstallServiceBase"
|
||||
@ -59,6 +61,7 @@ class InvocationServices:
|
||||
images: "ImageServiceABC",
|
||||
latents: "LatentsStorageBase",
|
||||
logger: "Logger",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
model_record_store: "ModelRecordServiceBase",
|
||||
model_loader: "ModelLoadServiceBase",
|
||||
model_installer: "ModelInstallServiceBase",
|
||||
@ -78,6 +81,7 @@ class InvocationServices:
|
||||
self.images = images
|
||||
self.latents = latents
|
||||
self.logger = logger
|
||||
self.download_queue = download_queue
|
||||
self.model_record_store = model_record_store
|
||||
self.model_loader = model_loader
|
||||
self.model_installer = model_installer
|
||||
|
190
invokeai/app/services/model_convert.py
Normal file
190
invokeai/app/services/model_convert.py
Normal file
@ -0,0 +1,190 @@
|
||||
# Copyright 2023 Lincoln Stein and the InvokeAI Team
|
||||
|
||||
"""
|
||||
Convert and merge models.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import Field
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from .config import InvokeAIAppConfig
|
||||
from .model_install_service import ModelInstallServiceBase
|
||||
from .model_loader_service import ModelLoadServiceBase, ModelInfo
|
||||
from .model_record_service import ModelRecordServiceBase, ModelConfigBase, ModelType, SubModelType
|
||||
|
||||
|
||||
class ModelConvertBase(ABC):
|
||||
"""Convert and merge models."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
cls,
|
||||
loader: ModelLoadServiceBase,
|
||||
installer: ModelInstallServiceBase,
|
||||
store: ModelRecordServiceBase,
|
||||
):
|
||||
"""Initialize ModelConvert with loader, installer and configuration store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
key: str,
|
||||
dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder.
|
||||
|
||||
It will delete the cached version ans well as the
|
||||
original checkpoint file if it is in the models directory.
|
||||
:param key: Unique key of model.
|
||||
:dest_directory: Optional place to put converted file. If not specified,
|
||||
will be stored in the `models_dir`.
|
||||
|
||||
This will raise a ValueError unless the model is a checkpoint.
|
||||
This will raise an UnknownModelException if key is unknown.
|
||||
"""
|
||||
pass
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_keys: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model keys to merge"
|
||||
),
|
||||
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
|
||||
:param model_keys: List of 2-3 model unique keys to merge
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
|
||||
class ModelConvert(ModelConvertBase):
|
||||
"""Implementation of ModelConvertBase."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loader: ModelLoadServiceBase,
|
||||
installer: ModelInstallServiceBase,
|
||||
store: ModelRecordServiceBase,
|
||||
):
|
||||
"""Initialize ModelConvert with loader, installer and configuration store."""
|
||||
self.loader = loader
|
||||
self.installer = installer
|
||||
self.store = store
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
key: str,
|
||||
dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder.
|
||||
|
||||
It will delete the cached version ans well as the
|
||||
original checkpoint file if it is in the models directory.
|
||||
:param key: Unique key of model.
|
||||
:dest_directory: Optional place to put converted file. If not specified,
|
||||
will be stored in the `models_dir`.
|
||||
|
||||
This will raise a ValueError unless the model is a checkpoint.
|
||||
This will raise an UnknownModelException if key is unknown.
|
||||
"""
|
||||
new_diffusers_path = None
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
try:
|
||||
info: ModelConfigBase = self.store.get_model(key)
|
||||
|
||||
if info.model_format != "checkpoint":
|
||||
raise ValueError(f"not a checkpoint format model: {info.name}")
|
||||
|
||||
# We are taking advantage of a side effect of get_model() that converts check points
|
||||
# into cached diffusers directories stored at `path`. It doesn't matter
|
||||
# what submodel type we request here, so we get the smallest.
|
||||
submodel = {"submodel_type": SubModelType.Scheduler} if info.model_type == ModelType.Main else {}
|
||||
converted_model: ModelInfo = self.loader.get_model(key, **submodel)
|
||||
|
||||
checkpoint_path = config.models_path / info.path
|
||||
old_diffusers_path = config.models_path / converted_model.location
|
||||
|
||||
# new values to write in
|
||||
update = info.dict()
|
||||
update.pop("config")
|
||||
update["model_format"] = "diffusers"
|
||||
update["path"] = str(converted_model.location)
|
||||
|
||||
if dest_directory:
|
||||
new_diffusers_path = Path(dest_directory) / info.name
|
||||
if new_diffusers_path.exists():
|
||||
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
||||
move(old_diffusers_path, new_diffusers_path)
|
||||
update["path"] = new_diffusers_path.as_posix()
|
||||
|
||||
self.store.update_model(key, update)
|
||||
result = self.installer.sync_model_path(key, ignore_hash_change=True)
|
||||
except Exception as excp:
|
||||
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
|
||||
if new_diffusers_path:
|
||||
rmtree(new_diffusers_path)
|
||||
raise excp
|
||||
|
||||
if checkpoint_path.exists() and checkpoint_path.is_relative_to(config.models_path):
|
||||
checkpoint_path.unlink()
|
||||
|
||||
return result
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_keys: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model keys to merge"
|
||||
),
|
||||
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
|
||||
:param model_keys: List of 2-3 model unique keys to merge
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
merger = ModelMerger(self.store)
|
||||
try:
|
||||
if not merged_model_name:
|
||||
merged_model_name = "+".join([self.store.get_model(x).name for x in model_keys])
|
||||
raise Exception("not implemented")
|
||||
|
||||
result = merger.merge_diffusion_models_and_save(
|
||||
model_keys=model_keys,
|
||||
merged_model_name=merged_model_name,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=merge_dest_directory,
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise ValueError(e)
|
||||
return result
|
@ -1,41 +1,143 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
from abc import abstractmethod
|
||||
import re
|
||||
import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
||||
from shutil import move, rmtree
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union, Literal
|
||||
|
||||
from pydantic import Field, parse_obj_as
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger, Logger
|
||||
from invokeai.backend import get_precision
|
||||
from invokeai.backend.model_manager import ModelConfigBase, ModelSearch
|
||||
from invokeai.backend.model_manager.download import DownloadJobBase
|
||||
from invokeai.backend.model_manager.install import ModelInstall, ModelInstallBase, ModelInstallJob
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.download.model_queue import (
|
||||
HTTP_RE,
|
||||
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE,
|
||||
DownloadJobRepoID,
|
||||
DownloadJobWithMetadata,
|
||||
)
|
||||
from invokeai.backend.model_manager.hash import FastModelHash
|
||||
from invokeai.backend.model_manager.models import InvalidModelException
|
||||
from invokeai.backend.model_manager.probe import ModelProbe, ModelProbeInfo
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.model_manager.storage import DuplicateModelException, ModelConfigStore
|
||||
from .events import EventServiceBase
|
||||
from .model_record_service import ModelRecordServiceBase
|
||||
from .download_manager import (
|
||||
DownloadQueueServiceBase,
|
||||
DownloadQueueService,
|
||||
DownloadJobBase,
|
||||
DownloadJobPath,
|
||||
DownloadEventHandler,
|
||||
ModelSourceMetadata,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallServiceBase(ModelInstallBase): # This is an ABC
|
||||
"""Responsible for downloading, installing and deleting models."""
|
||||
class ModelInstallJob(DownloadJobBase):
|
||||
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
|
||||
|
||||
model_key: Optional[str] = Field(
|
||||
description="After model installation, this field will hold its primary key", default=None
|
||||
)
|
||||
probe_override: Optional[Dict[str, Any]] = Field(
|
||||
description="Keys in this dict will override like-named attributes in the automatic probe info",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallURLJob(DownloadJobWithMetadata, ModelInstallJob):
|
||||
"""Job for installing URLs."""
|
||||
|
||||
|
||||
class ModelInstallRepoIDJob(DownloadJobRepoID, ModelInstallJob):
|
||||
"""Job for installing repo ids."""
|
||||
|
||||
|
||||
class ModelInstallPathJob(DownloadJobPath, ModelInstallJob):
|
||||
"""Job for installing local paths."""
|
||||
|
||||
|
||||
ModelInstallEventHandler = Callable[["ModelInstallJob"], None]
|
||||
|
||||
|
||||
class ModelInstallServiceBase(ABC):
|
||||
"""Abstract base class for InvokeAI model installation."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self, config: InvokeAIAppConfig, store: ModelRecordServiceBase, event_bus: Optional[EventServiceBase] = None
|
||||
self,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
queue: Optional[DownloadQueueServiceBase] = None,
|
||||
store: Optional[ModelRecordServiceBase] = None,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a ModelInstallService instance.
|
||||
Create ModelInstallService object.
|
||||
|
||||
:param config: InvokeAIAppConfig object
|
||||
:param store: A ModelRecordServiceBase object install to
|
||||
:param event_bus: Optional EventServiceBase object. If provided,
|
||||
installation and download events will be sent to the event bus as "model_event".
|
||||
:param store: Optional ModelConfigStore. If None passed,
|
||||
defaults to `configs/models.yaml`.
|
||||
:param config: Optional InvokeAIAppConfig. If None passed,
|
||||
uses the system-wide default app config.
|
||||
:param logger: Optional InvokeAILogger. If None passed,
|
||||
uses the system-wide default logger.
|
||||
:param download: Optional DownloadQueueServiceBase object. If None passed,
|
||||
a default queue object will be created.
|
||||
:param event_handlers: List of event handlers to pass to the queue object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def queue(self) -> DownloadQueueServiceBase:
|
||||
"""Return the download queue used by the installer."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def store(self) -> ModelRecordServiceBase:
|
||||
"""Return the storage backend used by the installer."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def config(self) -> InvokeAIAppConfig:
|
||||
"""Return the app_config used by the installer."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def register_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Probe and register the model at model_path.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param overrides: Dict of attributes that will override probed values.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def install_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> str:
|
||||
"""
|
||||
Probe, register and install the model in the models directory.
|
||||
|
||||
This involves moving the model from its current location into
|
||||
the models directory handled by InvokeAI.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param overrides: Dictionary of model probe info fields that, if present, override probed values.
|
||||
:returns id: The string ID of the installed model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -43,327 +145,500 @@ class ModelInstallServiceBase(ModelInstallBase): # This is an ABC
|
||||
def install_model(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
inplace: bool = True,
|
||||
priority: int = 10,
|
||||
model_attributes: Optional[Dict[str, Any]] = None,
|
||||
variant: Optional[str] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
probe_override: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[ModelSourceMetadata] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""Import a path, repo_id or URL. Returns an ModelInstallJob.
|
||||
"""
|
||||
Download and install the indicated model.
|
||||
|
||||
:param model_attributes: Additional attributes to supplement/override
|
||||
the model information gained from automated probing.
|
||||
:param priority: Queue priority. Lower values have higher priority.
|
||||
This will download the model located at `source`,
|
||||
probe it, and install it into the models directory.
|
||||
This call is executed asynchronously in a separate
|
||||
thread, and the returned object is a
|
||||
invokeai.backend.model_manager.download.DownloadJobBase
|
||||
object which can be interrogated to get the status of
|
||||
the download and install process. Call our `wait_for_installs()`
|
||||
method to wait for all downloads and installations to complete.
|
||||
|
||||
Typical usage:
|
||||
job = model_manager.install(
|
||||
'stabilityai/stable-diffusion-2-1',
|
||||
model_attributes={'prediction_type": 'v-prediction'}
|
||||
)
|
||||
:param source: Either a URL or a HuggingFace repo_id.
|
||||
:param inplace: If True, local paths will not be moved into
|
||||
the models directory, but registered in place (the default).
|
||||
:param variant: For HuggingFace models, this optional parameter
|
||||
specifies which variant to download (e.g. 'fp16')
|
||||
:param subfolder: When downloading HF repo_ids this can be used to
|
||||
specify a subfolder of the HF repository to download from.
|
||||
:param probe_override: Optional dict. Any fields in this dict
|
||||
will override corresponding probe fields. Use it to override
|
||||
`base_type`, `model_type`, `format`, `prediction_type` and `image_size`.
|
||||
:param metadata: Use this to override the fields 'description`,
|
||||
`author`, `tags`, `source` and `license`.
|
||||
|
||||
The result is an ModelInstallJob object, which provides
|
||||
information on the asynchronous model download and install
|
||||
process. A series of "install_model_event" events will be emitted
|
||||
until the install is completed, cancelled or errors out.
|
||||
:returns ModelInstallJob object.
|
||||
|
||||
The `inplace` flag does not affect the behavior of downloaded
|
||||
models, which are always moved into the `models` directory.
|
||||
|
||||
Variants recognized by HuggingFace currently are:
|
||||
1. onnx
|
||||
2. openvino
|
||||
3. fp16
|
||||
4. None (usually returns fp32 model)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_install_jobs(self) -> List[ModelInstallJob]:
|
||||
"""Return a series of active or enqueued ModelInstallJobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def id_to_job(self, id: int) -> ModelInstallJob:
|
||||
"""Return the ModelInstallJob instance corresponding to the given job ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self, job_id: int):
|
||||
"""Start the given install job if it is paused or idle."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pause_job(self, job_id: int):
|
||||
"""Pause the given install job if it is paused or idle."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, job_id: int):
|
||||
"""Cancel the given install job."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_all_jobs(self):
|
||||
"""Cancel all installation jobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prune_jobs(self):
|
||||
"""Remove completed or errored install jobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def change_job_priority(self, job_id: int, delta: int):
|
||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
||||
"""
|
||||
Change an install job's priority.
|
||||
Wait for all pending installs to complete.
|
||||
|
||||
:param job_id: Job to change
|
||||
:param delta: Value to increment or decrement priority.
|
||||
This will block until all pending downloads have
|
||||
completed, been cancelled, or errored out. It will
|
||||
block indefinitely if one or more jobs are in the
|
||||
paused state.
|
||||
|
||||
Lower values are higher priority. The default starting value is 10.
|
||||
Thus to make this a really high priority job:
|
||||
manager.change_job_priority(-10).
|
||||
It will return a dict that maps the source model
|
||||
path, URL or repo_id to the ID of the installed model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def merge_models(
|
||||
self,
|
||||
model_keys: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model keys to merge"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
Recursively scan directory for new models and register or install them.
|
||||
|
||||
:param model_keys: List of 2-3 model unique keys to merge
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
:param scan_dir: Path to the directory to scan.
|
||||
:param install: Install if True, otherwise register in place.
|
||||
:returns list of IDs: Returns list of IDs of models registered/installed
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""List the checkpoint config paths from ROOT/configs/stable-diffusion."""
|
||||
def hash(self, model_path: Union[Path, str]) -> str:
|
||||
"""
|
||||
Compute and return the fast hash of the model.
|
||||
|
||||
:param model_path: Path to the model on disk.
|
||||
:return str: FastHash of the model for use as an ID.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_for_models(self, directory: Path) -> Set[Path]:
|
||||
"""Return list of all models found in the designated directory."""
|
||||
pass
|
||||
|
||||
|
||||
# implementation
|
||||
class ModelInstallService(ModelInstall, ModelInstallServiceBase):
|
||||
"""Responsible for managing models on disk and in memory."""
|
||||
class ModelInstallService(ModelInstallServiceBase):
|
||||
"""Model installer class handles installation from a local path."""
|
||||
|
||||
_app_config: InvokeAIAppConfig
|
||||
_logger: Logger
|
||||
_store: ModelConfigStore
|
||||
_download_queue: DownloadQueueServiceBase
|
||||
_async_installs: Dict[Union[str, Path, AnyHttpUrl], Optional[str]]
|
||||
_installed: Set[str] = Field(default=set)
|
||||
_tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads
|
||||
_cached_model_paths: Set[Path] = Field(default=set) # used to speed up directory scanning
|
||||
_precision: Literal["float16", "float32"] = Field(description="Floating point precision, string form")
|
||||
_event_bus: Optional[EventServiceBase] = Field(description="an event bus to send install events to", default=None)
|
||||
|
||||
_legacy_configs: Dict[BaseModelType, Dict[ModelVariantType, Union[str, dict]]] = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelVariantType.Normal: "v1-inference.yaml",
|
||||
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelVariantType.Normal: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
|
||||
},
|
||||
ModelVariantType.Inpaint: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
|
||||
},
|
||||
},
|
||||
BaseModelType.StableDiffusionXL: {
|
||||
ModelVariantType.Normal: "sd_xl_base.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, config: InvokeAIAppConfig, store: ModelRecordServiceBase, event_bus: Optional[EventServiceBase] = None
|
||||
):
|
||||
"""
|
||||
Initialize a ModelInstallService instance.
|
||||
|
||||
:param config: InvokeAIAppConfig object
|
||||
:param store: Either a ModelRecordService object or a ModelConfigStore
|
||||
:param event_bus: Optional EventServiceBase object. If provided,
|
||||
|
||||
Installation and download events will be sent to the event bus as "model_event".
|
||||
"""
|
||||
self,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
queue: Optional[DownloadQueueServiceBase] = None,
|
||||
store: Optional[ModelRecordServiceBase] = None,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
): # noqa D107 - use base class docstrings
|
||||
self._app_config = config or InvokeAIAppConfig.get_config()
|
||||
self._store = store or ModelRecordServiceBase.get_impl(self._app_config)
|
||||
self._logger = InvokeAILogger.get_logger(config=self._app_config)
|
||||
self._event_bus = event_bus
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if self._event_bus:
|
||||
kwargs.update(event_handlers=[self._event_bus.emit_model_event])
|
||||
self._precision = get_precision()
|
||||
logger = InvokeAILogger.get_logger()
|
||||
super().__init__(store=store, config=config, logger=logger, **kwargs)
|
||||
self._handlers = event_handlers
|
||||
if self._event_bus:
|
||||
self._handlers.append(self._event_bus.emit_model_event)
|
||||
|
||||
self._download_queue = queue or DownloadQueueService(
|
||||
event_bus=event_bus,
|
||||
config=self._app_config
|
||||
)
|
||||
self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
|
||||
self._installed = set()
|
||||
self._tmpdir = None
|
||||
|
||||
def start(self, invoker: Any): # Because .processor is giving circular import errors, declaring invoker an 'Any'
|
||||
"""Call automatically at process start."""
|
||||
self.scan_models_directory() # synchronize new/deleted models found in models directory
|
||||
self.sync_to_config()
|
||||
|
||||
@property
|
||||
def queue(self) -> DownloadQueueServiceBase:
|
||||
"""Return the queue."""
|
||||
return self._download_queue
|
||||
|
||||
@property
|
||||
def store(self) -> ModelConfigStore:
|
||||
"""Return the storage backend used by the installer."""
|
||||
return self._store
|
||||
|
||||
@property
|
||||
def config(self) -> InvokeAIAppConfig:
|
||||
"""Return the app_config used by the installer."""
|
||||
return self._app_config
|
||||
|
||||
def install_model(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
inplace: bool = True,
|
||||
priority: int = 10,
|
||||
variant: Optional[str] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
probe_override: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[ModelSourceMetadata] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> DownloadJobBase: # noqa D102
|
||||
queue = self._download_queue
|
||||
variant = variant or ("fp16" if self._precision == "float16" else None)
|
||||
|
||||
job = self._make_download_job(
|
||||
source, variant=variant, access_token=access_token, subfolder=subfolder, priority=priority
|
||||
)
|
||||
handler = (
|
||||
self._complete_registration_handler
|
||||
if inplace and Path(source).exists()
|
||||
else self._complete_installation_handler
|
||||
)
|
||||
if isinstance(job, ModelInstallJob):
|
||||
job.probe_override = probe_override
|
||||
if metadata and isinstance(job, DownloadJobWithMetadata):
|
||||
job.metadata = metadata
|
||||
job.add_event_handler(handler)
|
||||
|
||||
self._async_installs[source] = None
|
||||
queue.submit_download_job(job, True)
|
||||
return job
|
||||
|
||||
def register_path(
|
||||
self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = self._probe_model(model_path, overrides)
|
||||
return self._register(model_path, info)
|
||||
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
overrides: Optional[Dict[str, Any]] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = self._probe_model(model_path, overrides)
|
||||
|
||||
dest_path = self._app_config.models_path / info.base_type.value / info.model_type.value / model_path.name
|
||||
new_path = self._move_model(model_path, dest_path)
|
||||
new_hash = self.hash(new_path)
|
||||
assert new_hash == info.hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
||||
return self._register(
|
||||
new_path,
|
||||
info,
|
||||
)
|
||||
|
||||
def unregister(self, key: str): # noqa D102
|
||||
self._store.del_model(key)
|
||||
|
||||
def delete(self, key: str): # noqa D102
|
||||
model = self._store.get_model(key)
|
||||
path = self._app_config.models_path / model.path
|
||||
if path.is_dir():
|
||||
rmtree(path)
|
||||
else:
|
||||
path.unlink()
|
||||
self.unregister(key)
|
||||
|
||||
def conditionally_delete(self, key: str): # noqa D102
|
||||
"""Unregister the model. Delete its files only if they are within our models directory."""
|
||||
model = self._store.get_model(key)
|
||||
models_dir = self._app_config.models_path
|
||||
model_path = models_dir / model.path
|
||||
if model_path.is_relative_to(models_dir):
|
||||
self.delete(key)
|
||||
else:
|
||||
self.unregister(key)
|
||||
|
||||
def _register(self, model_path: Path, info: ModelProbeInfo) -> str:
|
||||
key: str = FastModelHash.hash(model_path)
|
||||
|
||||
model_path = model_path.absolute()
|
||||
if model_path.is_relative_to(self._app_config.models_path):
|
||||
model_path = model_path.relative_to(self._app_config.models_path)
|
||||
|
||||
registration_data = dict(
|
||||
path=model_path.as_posix(),
|
||||
name=model_path.name if model_path.is_dir() else model_path.stem,
|
||||
base_model=info.base_type,
|
||||
model_type=info.model_type,
|
||||
model_format=info.format,
|
||||
hash=key,
|
||||
)
|
||||
# add 'main' specific fields
|
||||
if info.model_type == ModelType.Main:
|
||||
if info.variant_type:
|
||||
registration_data.update(variant=info.variant_type)
|
||||
if info.format == ModelFormat.Checkpoint:
|
||||
try:
|
||||
config_file = self._legacy_configs[info.base_type][info.variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
if prediction_type := info.prediction_type:
|
||||
config_file = config_file[prediction_type]
|
||||
else:
|
||||
self._logger.warning(
|
||||
f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model"
|
||||
)
|
||||
config_file = config_file[SchedulerPredictionType.VPrediction]
|
||||
registration_data.update(
|
||||
config=Path(self._app_config.legacy_conf_dir, str(config_file)).as_posix(),
|
||||
)
|
||||
except KeyError as exc:
|
||||
raise InvalidModelException(
|
||||
"Configuration file for this checkpoint could not be determined"
|
||||
) from exc
|
||||
self._store.add_model(key, registration_data)
|
||||
return key
|
||||
|
||||
def _move_model(self, old_path: Path, new_path: Path) -> Path:
|
||||
if old_path == new_path:
|
||||
return old_path
|
||||
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# if path already exists then we jigger the name to make it unique
|
||||
counter: int = 1
|
||||
while new_path.exists():
|
||||
path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
|
||||
if not path.exists():
|
||||
new_path = path
|
||||
counter += 1
|
||||
return move(old_path, new_path)
|
||||
|
||||
def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
|
||||
info: ModelProbeInfo = ModelProbe.probe(Path(model_path))
|
||||
if overrides: # used to override probe fields
|
||||
for key, value in overrides.items():
|
||||
try:
|
||||
setattr(info, key, value) # skip validation errors
|
||||
except Exception:
|
||||
pass
|
||||
return info
|
||||
|
||||
def _complete_installation_handler(self, job: DownloadJobBase):
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
|
||||
model_id = self.install_path(job.destination, job.probe_override)
|
||||
info = self._store.get_model(model_id)
|
||||
info.source = str(job.source)
|
||||
if isinstance(job, DownloadJobWithMetadata):
|
||||
metadata: ModelSourceMetadata = job.metadata
|
||||
info.description = metadata.description or f"Imported model {info.name}"
|
||||
info.name = metadata.name or info.name
|
||||
info.author = metadata.author
|
||||
info.tags = metadata.tags
|
||||
info.license = metadata.license
|
||||
info.thumbnail_url = metadata.thumbnail_url
|
||||
self._store.update_model(model_id, info)
|
||||
self._async_installs[job.source] = model_id
|
||||
job.model_key = model_id
|
||||
elif job.status == "error":
|
||||
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
|
||||
elif job.status == "cancelled":
|
||||
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
|
||||
jobs = self._download_queue.list_jobs()
|
||||
if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
|
||||
self._tmpdir.cleanup()
|
||||
self._tmpdir = None
|
||||
|
||||
def _complete_registration_handler(self, job: DownloadJobBase):
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Installing in place.")
|
||||
model_id = self.register_path(job.destination, job.probe_override)
|
||||
info = self._store.get_model(model_id)
|
||||
info.source = str(job.source)
|
||||
info.description = f"Imported model {info.name}"
|
||||
self._store.update_model(model_id, info)
|
||||
self._async_installs[job.source] = model_id
|
||||
job.model_key = model_id
|
||||
elif job.status == "error":
|
||||
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
|
||||
elif job.status == "cancelled":
|
||||
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
|
||||
|
||||
def sync_model_path(self, key: str, ignore_hash_change: bool = False) -> ModelConfigBase:
|
||||
"""
|
||||
Move model into the location indicated by its basetype, type and name.
|
||||
|
||||
Call this after updating a model's attributes in order to move
|
||||
the model's path into the location indicated by its basetype, type and
|
||||
name. Applies only to models whose paths are within the root `models_dir`
|
||||
directory.
|
||||
|
||||
May raise an UnknownModelException.
|
||||
"""
|
||||
model = self._store.get_model(key)
|
||||
old_path = Path(model.path)
|
||||
models_dir = self._app_config.models_path
|
||||
|
||||
if not old_path.is_relative_to(models_dir):
|
||||
return model
|
||||
|
||||
new_path = models_dir / model.base_model.value / model.model_type.value / model.name
|
||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||
new_path = self._move_model(old_path, new_path)
|
||||
model.hash = self.hash(new_path)
|
||||
model.path = new_path.relative_to(models_dir).as_posix()
|
||||
if model.hash != key:
|
||||
assert (
|
||||
ignore_hash_change
|
||||
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
|
||||
self._logger.info(f"Model has new hash {model.hash}, but will continue to be identified by {key}")
|
||||
self._store.update_model(key, model)
|
||||
return model
|
||||
|
||||
def _make_download_job(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
variant: Optional[str] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
priority: Optional[int] = 10,
|
||||
) -> ModelInstallJob:
|
||||
# Clean up a common source of error. Doesn't work with Paths.
|
||||
if isinstance(source, str):
|
||||
source = source.strip()
|
||||
|
||||
# In the event that we are being asked to install a path that is already on disk,
|
||||
# we simply probe and register/install it. The job does not actually do anything, but we
|
||||
# create one anyway in order to have similar behavior for local files, URLs and repo_ids.
|
||||
if Path(source).exists(): # a path that is already on disk
|
||||
destdir = source
|
||||
return ModelInstallPathJob(source=source, destination=Path(destdir), event_handlers=self._handlers)
|
||||
|
||||
# choose a temporary directory inside the models directory
|
||||
models_dir = self._app_config.models_path
|
||||
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
|
||||
|
||||
cls = ModelInstallJob
|
||||
if match := re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
|
||||
cls = ModelInstallRepoIDJob
|
||||
source = match.group(1)
|
||||
subfolder = match.group(2) or subfolder
|
||||
kwargs = dict(variant=variant, subfolder=subfolder)
|
||||
elif re.match(HTTP_RE, str(source)):
|
||||
cls = ModelInstallURLJob
|
||||
kwargs = {}
|
||||
else:
|
||||
raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL")
|
||||
return cls(
|
||||
source=str(source),
|
||||
destination=Path(self._tmpdir.name),
|
||||
access_token=access_token,
|
||||
priority=priority,
|
||||
event_handlers=self._handlers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
||||
"""Pause until all installation jobs have completed."""
|
||||
self._download_queue.join()
|
||||
id_map = self._async_installs
|
||||
self._async_installs = dict()
|
||||
return id_map
|
||||
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||
self._cached_model_paths = set([Path(x.path) for x in self._store.all_models()])
|
||||
callback = self._scan_install if install else self._scan_register
|
||||
search = ModelSearch(on_model_found=callback)
|
||||
self._installed = set()
|
||||
search.search(scan_dir)
|
||||
return list(self._installed)
|
||||
|
||||
def scan_models_directory(self):
|
||||
"""
|
||||
Scan the models directory for new and missing models.
|
||||
|
||||
New models will be added to the storage backend. Missing models
|
||||
will be deleted.
|
||||
"""
|
||||
defunct_models = set()
|
||||
installed = set()
|
||||
|
||||
with Chdir(self._app_config.models_path):
|
||||
self._logger.info("Checking for models that have been moved or deleted from disk")
|
||||
for model_config in self._store.all_models():
|
||||
path = Path(model_config.path)
|
||||
if not path.exists():
|
||||
self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering")
|
||||
defunct_models.add(model_config.key)
|
||||
for key in defunct_models:
|
||||
self.unregister(key)
|
||||
|
||||
self._logger.info(f"Scanning {self._app_config.models_path} for new models")
|
||||
for cur_base_model in BaseModelType:
|
||||
for cur_model_type in ModelType:
|
||||
models_dir = Path(cur_base_model.value, cur_model_type.value)
|
||||
installed.update(self.scan_directory(models_dir))
|
||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||
|
||||
def sync_to_config(self):
|
||||
"""Synchronize models on disk to those in memory."""
|
||||
self.scan_models_directory()
|
||||
if autoimport := self._app_config.autoimport_dir:
|
||||
self._logger.info("Scanning autoimport directory for new models")
|
||||
self.scan_directory(self._app_config.root_path / autoimport)
|
||||
|
||||
def install_model(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
priority: int = 10,
|
||||
model_attributes: Optional[Dict[str, Any]] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""
|
||||
Add a model using a path, repo_id or URL.
|
||||
def hash(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
return FastModelHash.hash(model_path)
|
||||
|
||||
:param model_attributes: Dictionary of ModelConfigBase fields to
|
||||
attach to the model. When installing a URL or repo_id, some metadata
|
||||
values, such as `tags` will be automagically added.
|
||||
:param priority: Queue priority for this install job. Lower value jobs
|
||||
will run before higher value ones.
|
||||
"""
|
||||
self.logger.debug(f"add model {source}")
|
||||
variant = "fp16" if self._precision == "float16" else None
|
||||
job = self.install(
|
||||
source,
|
||||
probe_override=model_attributes,
|
||||
variant=variant,
|
||||
priority=priority,
|
||||
)
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
return job
|
||||
|
||||
def list_install_jobs(self) -> List[ModelInstallJob]:
|
||||
"""Return a series of active or enqueued ModelInstallJobs."""
|
||||
queue = self.queue
|
||||
jobs: List[DownloadJobBase] = queue.list_jobs()
|
||||
return [parse_obj_as(ModelInstallJob, x) for x in jobs] # downcast to proper type
|
||||
|
||||
def id_to_job(self, id: int) -> ModelInstallJob:
|
||||
"""Return the ModelInstallJob instance corresponding to the given job ID."""
|
||||
job = self.queue.id_to_job(id)
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
return job
|
||||
|
||||
def start_job(self, job_id: int):
|
||||
"""Start the given install job if it is paused or idle."""
|
||||
queue = self.queue
|
||||
queue.start_job(queue.id_to_job(job_id))
|
||||
|
||||
def pause_job(self, job_id: int):
|
||||
"""Pause the given install job if it is paused or idle."""
|
||||
queue = self.queue
|
||||
queue.pause_job(queue.id_to_job(job_id))
|
||||
|
||||
def cancel_job(self, job_id: int):
|
||||
"""Cancel the given install job."""
|
||||
queue = self.queue
|
||||
queue.cancel_job(queue.id_to_job(job_id))
|
||||
|
||||
def cancel_all_jobs(self):
|
||||
"""Cancel all active install job."""
|
||||
queue = self.queue
|
||||
queue.cancel_all_jobs()
|
||||
|
||||
def prune_jobs(self):
|
||||
"""Cancel all active install job."""
|
||||
queue = self.queue
|
||||
queue.prune_jobs()
|
||||
|
||||
def change_job_priority(self, job_id: int, delta: int):
|
||||
"""
|
||||
Change an install job's priority.
|
||||
|
||||
:param job_id: Job to change
|
||||
:param delta: Value to increment or decrement priority.
|
||||
|
||||
Lower values are higher priority. The default starting value is 10.
|
||||
Thus to make this a really high priority job:
|
||||
manager.change_job_priority(-10).
|
||||
"""
|
||||
queue = self.queue
|
||||
queue.change_priority(queue.id_to_job(job_id), delta)
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
key: str,
|
||||
delete_files: bool = False,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration.
|
||||
|
||||
If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well.
|
||||
"""
|
||||
model_info = self.store.get_model(key)
|
||||
self.logger.debug(f"delete model {model_info.name}")
|
||||
self.store.del_model(key)
|
||||
if delete_files and Path(self._app_config.models_path / model_info.path).exists():
|
||||
path = Path(model_info.path)
|
||||
if path.is_dir():
|
||||
shutil.rmtree(path)
|
||||
else:
|
||||
path.unlink()
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
key: str,
|
||||
dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder.
|
||||
|
||||
Delete the cached
|
||||
version and delete the original checkpoint file if it is in the models
|
||||
directory.
|
||||
|
||||
:param key: Key of the model to convert
|
||||
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
|
||||
|
||||
This will raise a ValueError unless the model is a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
model_info = self.store.get_model(key)
|
||||
self.logger.info(f"Converting model {model_info.name} into a diffusers")
|
||||
return super().convert_model(key, dest_directory)
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
"""Get the logger associated with this instance."""
|
||||
return self._logger
|
||||
|
||||
@property
|
||||
def store(self):
|
||||
"""Get the store associated with this instance."""
|
||||
return self._store
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_keys: List[str] = Field(
|
||||
default=None, min_items=2, max_items=3, description="List of model keys to merge"
|
||||
),
|
||||
merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
|
||||
:param model_keys: List of 2-3 model unique keys to merge
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
merger = ModelMerger(self.store)
|
||||
def _scan_register(self, model: Path) -> bool:
|
||||
if model in self._cached_model_paths:
|
||||
return True
|
||||
try:
|
||||
if not merged_model_name:
|
||||
merged_model_name = "+".join([self.store.get_model(x).name for x in model_keys])
|
||||
raise Exception("not implemented")
|
||||
id = self.register_path(model)
|
||||
self.sync_model_path(id) # possibly move it to right place in `models`
|
||||
self._logger.info(f"Registered {model.name} with id {id}")
|
||||
self._installed.add(id)
|
||||
except DuplicateModelException:
|
||||
pass
|
||||
return True
|
||||
|
||||
result = merger.merge_diffusion_models_and_save(
|
||||
model_keys=model_keys,
|
||||
merged_model_name=merged_model_name,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=merge_dest_directory,
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise ValueError(e)
|
||||
return result
|
||||
|
||||
def search_for_models(self, directory: Path) -> Set[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
|
||||
:param directory: Path to the directory to recursively search.
|
||||
returns a list of model paths
|
||||
"""
|
||||
return ModelSearch().search(directory)
|
||||
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""List the checkpoint config paths from ROOT/configs/stable-diffusion."""
|
||||
config = self._app_config
|
||||
conf_path = config.legacy_conf_path
|
||||
root_path = config.root_path
|
||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
|
||||
def _scan_install(self, model: Path) -> bool:
|
||||
if model in self._cached_model_paths:
|
||||
return True
|
||||
try:
|
||||
id = self.install_path(model)
|
||||
self._logger.info(f"Installed {model} with id {id}")
|
||||
self._installed.add(id)
|
||||
except DuplicateModelException:
|
||||
pass
|
||||
return True
|
||||
|
@ -6,11 +6,12 @@ from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.backend.model_manager import ModelConfigStore, SubModelType
|
||||
from invokeai.backend.model_manager.cache import CacheStats
|
||||
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad
|
||||
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad, ModelConfigBase
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
from .model_record_service import ModelRecordServiceBase
|
||||
@ -57,6 +58,8 @@ class ModelLoadServiceBase(ABC):
|
||||
"""Reset model cache statistics for graph with graph_id."""
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
# implementation
|
||||
class ModelLoadService(ModelLoadServiceBase):
|
||||
@ -137,3 +140,4 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
)
|
||||
|
||||
|
@ -8,6 +8,8 @@ from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from invokeai.backend.model_manager import ModelConfigBase, ModelType, SubModelType
|
||||
|
||||
from invokeai.backend.model_manager.storage import (
|
||||
ModelConfigStore,
|
||||
ModelConfigStoreSQL,
|
||||
|
@ -14,7 +14,7 @@ import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.download.queue import DownloadJobRemoteSource
|
||||
from invokeai.backend.model_manager.install import ModelInstall, ModelInstallJob, ModelSourceMetadata
|
||||
from invokeai.app.services.model_install_service import ModelInstall, ModelInstallJob, ModelSourceMetadata
|
||||
|
||||
# name of the starter models file
|
||||
INITIAL_MODELS = "INITIAL_MODELS.yaml"
|
||||
@ -168,10 +168,9 @@ class InstallHelper(object):
|
||||
self._add_required_models(selections.install_models)
|
||||
for model in selections.install_models:
|
||||
metadata = ModelSourceMetadata(description=model.description, name=model.name)
|
||||
installer.install(
|
||||
installer.install_model(
|
||||
model.source,
|
||||
subfolder=model.subfolder,
|
||||
variant="fp16" if self._config.precision == "float16" else None,
|
||||
access_token=HfFolder.get_token(),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
@ -8,4 +8,4 @@ from .base import ( # noqa F401
|
||||
UnknownJobIDException,
|
||||
)
|
||||
from .model_queue import ModelDownloadQueue, ModelSourceMetadata # noqa F401
|
||||
from .queue import DownloadJobPath, DownloadJobURL, DownloadQueue # noqa F401
|
||||
from .queue import DownloadJobPath, DownloadJobURL, DownloadQueue, DownloadJobRemoteSource # noqa F401
|
||||
|
@ -1,822 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Install/delete models.
|
||||
|
||||
Typical usage:
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import ModelInstall
|
||||
from invokeai.backend.model_manager.storage import ModelConfigStoreSQL
|
||||
from invokeai.backend.model_manager.download import DownloadQueue
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
store = ModelConfigStoreSQL(config.db_path)
|
||||
download = DownloadQueue()
|
||||
installer = ModelInstall(store=store, config=config, download=download)
|
||||
|
||||
# register config, don't move path
|
||||
id: str = installer.register_path('/path/to/model')
|
||||
|
||||
# register config, and install model in `models`
|
||||
id: str = installer.install_path('/path/to/model')
|
||||
|
||||
# download some remote models and install them in the background
|
||||
installer.install('stabilityai/stable-diffusion-2-1')
|
||||
installer.install('https://civitai.com/api/download/models/154208')
|
||||
installer.install('runwayml/stable-diffusion-v1-5')
|
||||
installer.install('/home/user/models/stable-diffusion-v1-5', inplace=True)
|
||||
|
||||
installed_ids = installer.wait_for_installs()
|
||||
id1 = installed_ids['stabilityai/stable-diffusion-2-1']
|
||||
id2 = installed_ids['https://civitai.com/api/download/models/154208']
|
||||
|
||||
# unregister, don't delete
|
||||
installer.unregister(id)
|
||||
|
||||
# unregister and delete model from disk
|
||||
installer.delete_model(id)
|
||||
|
||||
# scan directory recursively and install all new models found
|
||||
ids: List[str] = installer.scan_directory('/path/to/directory')
|
||||
|
||||
# Synchronize with the models directory, adding missing models and
|
||||
# removing orphans
|
||||
installer.scan_models_directory()
|
||||
|
||||
hash: str = installer.hash('/path/to/model') # should be same as id above
|
||||
|
||||
The following exceptions may be raised:
|
||||
DuplicateModelException
|
||||
UnknownModelTypeException
|
||||
"""
|
||||
import re
|
||||
import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger, Logger
|
||||
|
||||
from .config import (
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from .download import (
|
||||
DownloadEventHandler,
|
||||
DownloadJobBase,
|
||||
DownloadJobPath,
|
||||
DownloadJobURL,
|
||||
DownloadQueueBase,
|
||||
ModelDownloadQueue,
|
||||
ModelSourceMetadata,
|
||||
)
|
||||
from .download.model_queue import (
|
||||
HTTP_RE,
|
||||
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE,
|
||||
DownloadJobRepoID,
|
||||
DownloadJobWithMetadata,
|
||||
)
|
||||
from .hash import FastModelHash
|
||||
from .models import InvalidModelException
|
||||
from .probe import ModelProbe, ModelProbeInfo
|
||||
from .search import ModelSearch
|
||||
from .storage import DuplicateModelException, ModelConfigStore
|
||||
|
||||
|
||||
class ModelInstallJob(DownloadJobBase):
|
||||
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
|
||||
|
||||
model_key: Optional[str] = Field(
|
||||
description="After model installation, this field will hold its primary key", default=None
|
||||
)
|
||||
probe_override: Optional[Dict[str, Any]] = Field(
|
||||
description="Keys in this dict will override like-named attributes in the automatic probe info",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallURLJob(DownloadJobWithMetadata, ModelInstallJob):
|
||||
"""Job for installing URLs."""
|
||||
|
||||
|
||||
class ModelInstallRepoIDJob(DownloadJobRepoID, ModelInstallJob):
|
||||
"""Job for installing repo ids."""
|
||||
|
||||
|
||||
class ModelInstallPathJob(DownloadJobPath, ModelInstallJob):
|
||||
"""Job for installing local paths."""
|
||||
|
||||
|
||||
ModelInstallEventHandler = Callable[["ModelInstallJob"], None]
|
||||
|
||||
|
||||
class ModelInstallBase(ABC):
|
||||
"""Abstract base class for InvokeAI model installation."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
store: Optional[ModelConfigStore] = None,
|
||||
logger: Optional[InvokeAILogger] = None,
|
||||
download: Optional[DownloadQueueBase] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
):
|
||||
"""
|
||||
Create ModelInstall object.
|
||||
|
||||
:param store: Optional ModelConfigStore. If None passed,
|
||||
defaults to `configs/models.yaml`.
|
||||
:param config: Optional InvokeAIAppConfig. If None passed,
|
||||
uses the system-wide default app config.
|
||||
:param logger: Optional InvokeAILogger. If None passed,
|
||||
uses the system-wide default logger.
|
||||
:param download: Optional DownloadQueueBase object. If None passed,
|
||||
a default queue object will be created.
|
||||
:param event_handlers: List of event handlers to pass to the queue object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def queue(self) -> DownloadQueueBase:
|
||||
"""Return the download queue used by the installer."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def store(self) -> ModelConfigStore:
|
||||
"""Return the storage backend used by the installer."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def config(self) -> InvokeAIAppConfig:
|
||||
"""Return the app_config used by the installer."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def register_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Probe and register the model at model_path.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param overrides: Dict of attributes that will override probed values.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def install_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> str:
|
||||
"""
|
||||
Probe, register and install the model in the models directory.
|
||||
|
||||
This involves moving the model from its current location into
|
||||
the models directory handled by InvokeAI.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param overrides: Dictionary of model probe info fields that, if present, override probed values.
|
||||
:returns id: The string ID of the installed model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def install(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
inplace: bool = True,
|
||||
priority: int = 10,
|
||||
variant: Optional[str] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
probe_override: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[ModelSourceMetadata] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> DownloadJobBase:
|
||||
"""
|
||||
Download and install the indicated model.
|
||||
|
||||
This will download the model located at `source`,
|
||||
probe it, and install it into the models directory.
|
||||
This call is executed asynchronously in a separate
|
||||
thread, and the returned object is a
|
||||
invokeai.backend.model_manager.download.DownloadJobBase
|
||||
object which can be interrogated to get the status of
|
||||
the download and install process. Call our `wait_for_installs()`
|
||||
method to wait for all downloads and installations to complete.
|
||||
|
||||
:param source: Either a URL or a HuggingFace repo_id.
|
||||
:param inplace: If True, local paths will not be moved into
|
||||
the models directory, but registered in place (the default).
|
||||
:param variant: For HuggingFace models, this optional parameter
|
||||
specifies which variant to download (e.g. 'fp16')
|
||||
:param subfolder: When downloading HF repo_ids this can be used to
|
||||
specify a subfolder of the HF repository to download from.
|
||||
:param probe_override: Optional dict. Any fields in this dict
|
||||
will override corresponding probe fields. Use it to override
|
||||
`base_type`, `model_type`, `format`, `prediction_type` and `image_size`.
|
||||
:param metadata: Use this to override the fields 'description`,
|
||||
`author`, `tags`, `source` and `license`.
|
||||
:returns DownloadQueueBase object.
|
||||
|
||||
The `inplace` flag does not affect the behavior of downloaded
|
||||
models, which are always moved into the `models` directory.
|
||||
|
||||
Variants recognized by HuggingFace currently are:
|
||||
1. onnx
|
||||
2. openvino
|
||||
3. fp16
|
||||
4. None (usually returns fp32 model)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
||||
"""
|
||||
Wait for all pending installs to complete.
|
||||
|
||||
This will block until all pending downloads have
|
||||
completed, been cancelled, or errored out. It will
|
||||
block indefinitely if one or more jobs are in the
|
||||
paused state.
|
||||
|
||||
It will return a dict that maps the source model
|
||||
path, URL or repo_id to the ID of the installed model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unregister(self, id: str):
|
||||
"""
|
||||
Unregister the model identified by id.
|
||||
|
||||
This removes the model from the registry without
|
||||
deleting the underlying model from disk.
|
||||
|
||||
:param id: The string ID of the model to forget.
|
||||
:raises UnknownModelException: In the event the ID is unknown.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, id: str):
|
||||
"""
|
||||
Unregister and delete the model identified by id.
|
||||
|
||||
This removes the model from the registry and
|
||||
deletes the underlying model from disk.
|
||||
|
||||
:param id: The string ID of the model to forget.
|
||||
:raises UnknownModelException: In the event the ID is unknown.
|
||||
:raises OSError: In the event the model cannot be deleted from disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def conditionally_delete(self, key: str): # noqa D102
|
||||
"""Unregister the model. Delete its files only if they are within our models directory."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]:
|
||||
"""
|
||||
Recursively scan directory for new models and register or install them.
|
||||
|
||||
:param scan_dir: Path to the directory to scan.
|
||||
:param install: Install if True, otherwise register in place.
|
||||
:returns list of IDs: Returns list of IDs of models registered/installed
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def hash(self, model_path: Union[Path, str]) -> str:
|
||||
"""
|
||||
Compute and return the fast hash of the model.
|
||||
|
||||
:param model_path: Path to the model on disk.
|
||||
:return str: FastHash of the model for use as an ID.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
key: str,
|
||||
dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder.
|
||||
|
||||
It will delete the cached version ans well as the
|
||||
original checkpoint file if it is in the models directory.
|
||||
:param key: Unique key of model.
|
||||
:dest_directory: Optional place to put converted file. If not specified,
|
||||
will be stored in the `models_dir`.
|
||||
|
||||
This will raise a ValueError unless the model is a checkpoint.
|
||||
This will raise an UnknownModelException if key is unknown.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sync_model_path(self, key) -> ModelConfigBase:
|
||||
"""
|
||||
Move model into the location indicated by its basetype, type and name.
|
||||
|
||||
Call this after updating a model's attributes in order to move
|
||||
the model's path into the location indicated by its basetype, type and
|
||||
name. Applies only to models whose paths are within the root `models_dir`
|
||||
directory.
|
||||
|
||||
May raise an UnknownModelException.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""Synchronize models on disk to those in memory."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def scan_models_directory(self):
|
||||
"""
|
||||
Scan the models directory for new and missing models.
|
||||
|
||||
New models will be added to the storage backend. Missing models
|
||||
will be deleted.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelInstall(ModelInstallBase):
|
||||
"""Model installer class handles installation from a local path."""
|
||||
|
||||
_app_config: InvokeAIAppConfig
|
||||
_logger: Logger
|
||||
_store: ModelConfigStore
|
||||
_download_queue: DownloadQueueBase
|
||||
_async_installs: Dict[Union[str, Path, AnyHttpUrl], Optional[str]]
|
||||
_installed: Set[str] = Field(default=set)
|
||||
_tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads
|
||||
_cached_model_paths: Set[Path] = Field(default=set) # used to speed up directory scanning
|
||||
|
||||
_legacy_configs: Dict[BaseModelType, Dict[ModelVariantType, Union[str, dict]]] = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelVariantType.Normal: "v1-inference.yaml",
|
||||
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelVariantType.Normal: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
|
||||
},
|
||||
ModelVariantType.Inpaint: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
|
||||
},
|
||||
},
|
||||
BaseModelType.StableDiffusionXL: {
|
||||
ModelVariantType.Normal: "sd_xl_base.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store: Optional[ModelConfigStore] = None,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
logger: Optional[Logger] = None,
|
||||
download: Optional[DownloadQueueBase] = None,
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
): # noqa D107 - use base class docstrings
|
||||
self._app_config = config or InvokeAIAppConfig.get_config()
|
||||
self._logger = logger or InvokeAILogger.get_logger(config=self._app_config)
|
||||
self._store = store or ModelRecordServiceBase.get_impl(self._app_config)
|
||||
self._download_queue = download or ModelDownloadQueue(config=self._app_config, event_handlers=event_handlers)
|
||||
self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
|
||||
self._installed = set()
|
||||
self._tmpdir = None
|
||||
|
||||
@property
|
||||
def queue(self) -> DownloadQueueBase:
|
||||
"""Return the queue."""
|
||||
return self._download_queue
|
||||
|
||||
@property
|
||||
def store(self) -> ModelConfigStore:
|
||||
"""Return the storage backend used by the installer."""
|
||||
return self._store
|
||||
|
||||
@property
|
||||
def config(self) -> InvokeAIAppConfig:
|
||||
"""Return the app_config used by the installer."""
|
||||
return self._app_config
|
||||
|
||||
def register_path(
|
||||
self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = self._probe_model(model_path, overrides)
|
||||
return self._register(model_path, info)
|
||||
|
||||
def _register(self, model_path: Path, info: ModelProbeInfo) -> str:
|
||||
key: str = FastModelHash.hash(model_path)
|
||||
|
||||
model_path = model_path.absolute()
|
||||
if model_path.is_relative_to(self._app_config.models_path):
|
||||
model_path = model_path.relative_to(self._app_config.models_path)
|
||||
|
||||
registration_data = dict(
|
||||
path=model_path.as_posix(),
|
||||
name=model_path.name if model_path.is_dir() else model_path.stem,
|
||||
base_model=info.base_type,
|
||||
model_type=info.model_type,
|
||||
model_format=info.format,
|
||||
hash=key,
|
||||
)
|
||||
# add 'main' specific fields
|
||||
if info.model_type == ModelType.Main:
|
||||
if info.variant_type:
|
||||
registration_data.update(variant=info.variant_type)
|
||||
if info.format == ModelFormat.Checkpoint:
|
||||
try:
|
||||
config_file = self._legacy_configs[info.base_type][info.variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
if prediction_type := info.prediction_type:
|
||||
config_file = config_file[prediction_type]
|
||||
else:
|
||||
self._logger.warning(
|
||||
f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model"
|
||||
)
|
||||
config_file = config_file[SchedulerPredictionType.VPrediction]
|
||||
registration_data.update(
|
||||
config=Path(self._app_config.legacy_conf_dir, str(config_file)).as_posix(),
|
||||
)
|
||||
except KeyError as exc:
|
||||
raise InvalidModelException(
|
||||
"Configuration file for this checkpoint could not be determined"
|
||||
) from exc
|
||||
self._store.add_model(key, registration_data)
|
||||
return key
|
||||
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
overrides: Optional[Dict[str, Any]] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = self._probe_model(model_path, overrides)
|
||||
|
||||
dest_path = self._app_config.models_path / info.base_type.value / info.model_type.value / model_path.name
|
||||
new_path = self._move_model(model_path, dest_path)
|
||||
new_hash = self.hash(new_path)
|
||||
assert new_hash == info.hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
|
||||
return self._register(
|
||||
new_path,
|
||||
info,
|
||||
)
|
||||
|
||||
def _move_model(self, old_path: Path, new_path: Path) -> Path:
|
||||
if old_path == new_path:
|
||||
return old_path
|
||||
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# if path already exists then we jigger the name to make it unique
|
||||
counter: int = 1
|
||||
while new_path.exists():
|
||||
path = new_path.with_stem(new_path.stem + f"_{counter:02d}")
|
||||
if not path.exists():
|
||||
new_path = path
|
||||
counter += 1
|
||||
return move(old_path, new_path)
|
||||
|
||||
def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
|
||||
info: ModelProbeInfo = ModelProbe.probe(Path(model_path))
|
||||
if overrides: # used to override probe fields
|
||||
for key, value in overrides.items():
|
||||
try:
|
||||
setattr(info, key, value) # skip validation errors
|
||||
except Exception:
|
||||
pass
|
||||
return info
|
||||
|
||||
def unregister(self, key: str): # noqa D102
|
||||
self._store.del_model(key)
|
||||
|
||||
def delete(self, key: str): # noqa D102
|
||||
model = self._store.get_model(key)
|
||||
path = self._app_config.models_path / model.path
|
||||
if path.is_dir():
|
||||
rmtree(path)
|
||||
else:
|
||||
path.unlink()
|
||||
self.unregister(key)
|
||||
|
||||
def conditionally_delete(self, key: str): # noqa D102
|
||||
"""Unregister the model. Delete its files only if they are within our models directory."""
|
||||
model = self._store.get_model(key)
|
||||
models_dir = self._app_config.models_path
|
||||
model_path = models_dir / model.path
|
||||
if model_path.is_relative_to(models_dir):
|
||||
self.delete(key)
|
||||
else:
|
||||
self.unregister(key)
|
||||
|
||||
def install(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
inplace: bool = True,
|
||||
priority: int = 10,
|
||||
variant: Optional[str] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
probe_override: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[ModelSourceMetadata] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> DownloadJobBase: # noqa D102
|
||||
queue = self._download_queue
|
||||
|
||||
job = self._make_download_job(
|
||||
source, variant=variant, access_token=access_token, subfolder=subfolder, priority=priority
|
||||
)
|
||||
handler = (
|
||||
self._complete_registration_handler
|
||||
if inplace and Path(source).exists()
|
||||
else self._complete_installation_handler
|
||||
)
|
||||
if isinstance(job, ModelInstallJob):
|
||||
job.probe_override = probe_override
|
||||
if metadata and isinstance(job, DownloadJobWithMetadata):
|
||||
job.metadata = metadata
|
||||
job.add_event_handler(handler)
|
||||
|
||||
self._async_installs[source] = None
|
||||
queue.submit_download_job(job, True)
|
||||
return job
|
||||
|
||||
def _complete_installation_handler(self, job: DownloadJobBase):
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
|
||||
model_id = self.install_path(job.destination, job.probe_override)
|
||||
info = self._store.get_model(model_id)
|
||||
info.source = str(job.source)
|
||||
if isinstance(job, DownloadJobWithMetadata):
|
||||
metadata: ModelSourceMetadata = job.metadata
|
||||
info.description = metadata.description or f"Imported model {info.name}"
|
||||
info.name = metadata.name or info.name
|
||||
info.author = metadata.author
|
||||
info.tags = metadata.tags
|
||||
info.license = metadata.license
|
||||
info.thumbnail_url = metadata.thumbnail_url
|
||||
self._store.update_model(model_id, info)
|
||||
self._async_installs[job.source] = model_id
|
||||
job.model_key = model_id
|
||||
elif job.status == "error":
|
||||
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
|
||||
elif job.status == "cancelled":
|
||||
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
|
||||
jobs = self._download_queue.list_jobs()
|
||||
if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
|
||||
self._tmpdir.cleanup()
|
||||
self._tmpdir = None
|
||||
|
||||
def _complete_registration_handler(self, job: DownloadJobBase):
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Installing in place.")
|
||||
model_id = self.register_path(job.destination, job.probe_override)
|
||||
info = self._store.get_model(model_id)
|
||||
info.source = str(job.source)
|
||||
info.description = f"Imported model {info.name}"
|
||||
self._store.update_model(model_id, info)
|
||||
self._async_installs[job.source] = model_id
|
||||
job.model_key = model_id
|
||||
elif job.status == "error":
|
||||
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
|
||||
elif job.status == "cancelled":
|
||||
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
|
||||
|
||||
def sync_model_path(self, key: str, ignore_hash_change: bool = False) -> ModelConfigBase:
|
||||
"""
|
||||
Move model into the location indicated by its basetype, type and name.
|
||||
|
||||
Call this after updating a model's attributes in order to move
|
||||
the model's path into the location indicated by its basetype, type and
|
||||
name. Applies only to models whose paths are within the root `models_dir`
|
||||
directory.
|
||||
|
||||
May raise an UnknownModelException.
|
||||
"""
|
||||
model = self._store.get_model(key)
|
||||
old_path = Path(model.path)
|
||||
models_dir = self._app_config.models_path
|
||||
|
||||
if not old_path.is_relative_to(models_dir):
|
||||
return model
|
||||
|
||||
new_path = models_dir / model.base_model.value / model.model_type.value / model.name
|
||||
self._logger.info(f"Moving {model.name} to {new_path}.")
|
||||
new_path = self._move_model(old_path, new_path)
|
||||
model.hash = self.hash(new_path)
|
||||
model.path = new_path.relative_to(models_dir).as_posix()
|
||||
if model.hash != key:
|
||||
assert (
|
||||
ignore_hash_change
|
||||
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
|
||||
self._logger.info(f"Model has new hash {model.hash}, but will continue to be identified by {key}")
|
||||
self._store.update_model(key, model)
|
||||
return model
|
||||
|
||||
def _make_download_job(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
variant: Optional[str] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
priority: Optional[int] = 10,
|
||||
) -> ModelInstallJob:
|
||||
# Clean up a common source of error. Doesn't work with Paths.
|
||||
if isinstance(source, str):
|
||||
source = source.strip()
|
||||
|
||||
# In the event that we are being asked to install a path that is already on disk,
|
||||
# we simply probe and register/install it. The job does not actually do anything, but we
|
||||
# create one anyway in order to have similar behavior for local files, URLs and repo_ids.
|
||||
if Path(source).exists(): # a path that is already on disk
|
||||
destdir = source
|
||||
return ModelInstallPathJob(source=source, destination=Path(destdir))
|
||||
|
||||
# choose a temporary directory inside the models directory
|
||||
models_dir = self._app_config.models_path
|
||||
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
|
||||
|
||||
cls = ModelInstallJob
|
||||
if match := re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
|
||||
cls = ModelInstallRepoIDJob
|
||||
source = match.group(1)
|
||||
subfolder = match.group(2) or subfolder
|
||||
kwargs = dict(variant=variant, subfolder=subfolder)
|
||||
elif re.match(HTTP_RE, str(source)):
|
||||
cls = ModelInstallURLJob
|
||||
kwargs = {}
|
||||
else:
|
||||
raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL")
|
||||
return cls(
|
||||
source=str(source),
|
||||
destination=Path(self._tmpdir.name),
|
||||
access_token=access_token,
|
||||
priority=priority,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
|
||||
"""Pause until all installation jobs have completed."""
|
||||
self._download_queue.join()
|
||||
id_map = self._async_installs
|
||||
self._async_installs = dict()
|
||||
return id_map
|
||||
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||
self._cached_model_paths = set([Path(x.path) for x in self._store.all_models()])
|
||||
callback = self._scan_install if install else self._scan_register
|
||||
search = ModelSearch(on_model_found=callback)
|
||||
self._installed = set()
|
||||
search.search(scan_dir)
|
||||
return list(self._installed)
|
||||
|
||||
def hash(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
return FastModelHash.hash(model_path)
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
key: str,
|
||||
dest_directory: Optional[Path] = None,
|
||||
) -> ModelConfigBase:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder.
|
||||
|
||||
It will delete the cached version ans well as the
|
||||
original checkpoint file if it is in the models directory.
|
||||
:param key: Unique key of model.
|
||||
:dest_directory: Optional place to put converted file. If not specified,
|
||||
will be stored in the `models_dir`.
|
||||
|
||||
This will raise a ValueError unless the model is a checkpoint.
|
||||
This will raise an UnknownModelException if key is unknown.
|
||||
"""
|
||||
from .loader import ModelInfo, ModelLoad # to avoid circular imports
|
||||
|
||||
new_diffusers_path = None
|
||||
|
||||
try:
|
||||
info: ModelConfigBase = self._store.get_model(key)
|
||||
|
||||
if info.model_format != "checkpoint":
|
||||
raise ValueError(f"not a checkpoint format model: {info.name}")
|
||||
|
||||
# We are taking advantage of a side effect of get_model() that converts check points
|
||||
# into cached diffusers directories stored at `path`. It doesn't matter
|
||||
# what submodel type we request here, so we get the smallest.
|
||||
loader = ModelLoad(self._app_config, self.store)
|
||||
submodel = {"submodel_type": SubModelType.Scheduler} if info.model_type == ModelType.Main else {}
|
||||
converted_model: ModelInfo = loader.get_model(key, **submodel)
|
||||
|
||||
checkpoint_path = loader.resolve_model_path(info.path)
|
||||
old_diffusers_path = loader.resolve_model_path(converted_model.location)
|
||||
|
||||
# new values to write in
|
||||
update = info.dict()
|
||||
update.pop("config")
|
||||
update["model_format"] = "diffusers"
|
||||
update["path"] = str(converted_model.location)
|
||||
|
||||
if dest_directory:
|
||||
new_diffusers_path = Path(dest_directory) / info.name
|
||||
if new_diffusers_path.exists():
|
||||
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
||||
move(old_diffusers_path, new_diffusers_path)
|
||||
update["path"] = new_diffusers_path.as_posix()
|
||||
|
||||
self._store.update_model(key, update)
|
||||
result = self.sync_model_path(key, ignore_hash_change=True)
|
||||
except Exception as excp:
|
||||
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
|
||||
if new_diffusers_path:
|
||||
rmtree(new_diffusers_path)
|
||||
raise excp
|
||||
|
||||
if checkpoint_path.exists() and checkpoint_path.is_relative_to(self._app_config.models_path):
|
||||
checkpoint_path.unlink()
|
||||
|
||||
return result
|
||||
|
||||
# the following two methods are callbacks to the ModelSearch object
|
||||
def _scan_register(self, model: Path) -> bool:
|
||||
if model in self._cached_model_paths:
|
||||
return True
|
||||
try:
|
||||
id = self.register_path(model)
|
||||
self.sync_model_path(id) # possibly move it to right place in `models`
|
||||
self._logger.info(f"Registered {model.name} with id {id}")
|
||||
self._installed.add(id)
|
||||
except DuplicateModelException:
|
||||
pass
|
||||
return True
|
||||
|
||||
def _scan_install(self, model: Path) -> bool:
|
||||
if model in self._cached_model_paths:
|
||||
return True
|
||||
try:
|
||||
id = self.install_path(model)
|
||||
self._logger.info(f"Installed {model} with id {id}")
|
||||
self._installed.add(id)
|
||||
except DuplicateModelException:
|
||||
pass
|
||||
return True
|
||||
|
||||
def sync_to_config(self):
|
||||
"""Synchronize models on disk to those in memory."""
|
||||
self.scan_models_directory()
|
||||
if autoimport := self._app_config.autoimport_dir:
|
||||
self._logger.info("Scanning autoimport directory for new models")
|
||||
self.scan_directory(self._app_config.root_path / autoimport)
|
||||
|
||||
def scan_models_directory(self):
|
||||
"""
|
||||
Scan the models directory for new and missing models.
|
||||
|
||||
New models will be added to the storage backend. Missing models
|
||||
will be deleted.
|
||||
"""
|
||||
defunct_models = set()
|
||||
installed = set()
|
||||
|
||||
with Chdir(self._app_config.models_path):
|
||||
self._logger.info("Checking for models that have been moved or deleted from disk")
|
||||
for model_config in self._store.all_models():
|
||||
path = Path(model_config.path)
|
||||
if not path.exists():
|
||||
self._logger.info(f"{model_config.name}: path {path.as_posix()} no longer exists. Unregistering")
|
||||
defunct_models.add(model_config.key)
|
||||
for key in defunct_models:
|
||||
self.unregister(key)
|
||||
|
||||
self._logger.info(f"Scanning {self._app_config.models_path} for new models")
|
||||
for cur_base_model in BaseModelType:
|
||||
for cur_model_type in ModelType:
|
||||
models_dir = Path(cur_base_model.value, cur_model_type.value)
|
||||
installed.update(self.scan_directory(models_dir))
|
||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
from shutil import move, rmtree
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -16,10 +16,10 @@ from diffusers import logging as dlogging
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_install_service import ModelInstallService
|
||||
|
||||
from . import BaseModelType, ModelConfigBase, ModelConfigStore, ModelType
|
||||
from .config import MainConfig
|
||||
from .loader import ModelLoad
|
||||
|
||||
|
||||
class MergeInterpolationMethod(str, Enum):
|
||||
@ -151,7 +151,7 @@ class ModelMerger(object):
|
||||
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
|
||||
|
||||
# register model and get its unique key
|
||||
installer = ModelInstall(store=self._store, config=self._config)
|
||||
installer = ModelInstallService(store=self._store, config=self._config)
|
||||
key = installer.register_path(dump_path)
|
||||
|
||||
# update model's config
|
||||
|
@ -7,8 +7,8 @@ import torch
|
||||
|
||||
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.app.services.model_install_service import ModelInstallService
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType, UnknownModelException
|
||||
from invokeai.backend.model_manager.install import ModelInstall
|
||||
from invokeai.backend.model_manager.loader import ModelInfo, ModelLoad
|
||||
|
||||
|
||||
@ -30,11 +30,11 @@ def model_installer():
|
||||
#
|
||||
config = InvokeAIAppConfig(log_level="info")
|
||||
model_store = ModelRecordServiceBase.get_impl(config)
|
||||
return ModelInstall(model_store, config)
|
||||
return ModelInstallService(store=model_store, config=config)
|
||||
|
||||
|
||||
def install_and_load_model(
|
||||
model_installer: ModelInstall,
|
||||
model_installer: ModelInstallService,
|
||||
model_path_id_or_url: Union[str, Path],
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
|
@ -26,9 +26,9 @@ from pydantic import BaseModel
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_install_service import ModelInstall, ModelInstallJob
|
||||
from invokeai.backend.install.install_helper import InstallHelper, UnifiedModelInfo
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.install import ModelInstall, ModelInstallJob
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.widgets import (
|
||||
|
@ -59,6 +59,7 @@ def mock_services() -> InvocationServices:
|
||||
conn=db_conn, table_name="graph_executions", lock=lock
|
||||
)
|
||||
return InvocationServices(
|
||||
download_queue=None, # type: ignore
|
||||
model_loader=None, # type: ignore
|
||||
model_installer=None, # type: ignore
|
||||
model_record_store=None, # type: ignore
|
||||
|
@ -3,4 +3,5 @@
|
||||
|
||||
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
|
||||
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
|
||||
from invokeai.backend.util.test_utils import model_installer, torch_device # noqa: F401
|
||||
from invokeai.backend.util.test_utils import torch_device # noqa: F401
|
||||
from invokeai.app.services.model_install_service import ModelInstallService # noqa: F401
|
||||
|
Reference in New Issue
Block a user