mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
450 lines
16 KiB
Python
450 lines
16 KiB
Python
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
|
|
|
from __future__ import annotations
|
|
|
|
import shutil
|
|
from abc import ABC, abstractmethod
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Union, Set, Literal
|
|
|
|
from pydantic import Field, parse_obj_as
|
|
from pydantic.networks import AnyHttpUrl
|
|
|
|
from invokeai.backend import get_precision
|
|
from invokeai.backend.model_manager.install import ModelInstallBase, ModelInstall, ModelInstallJob
|
|
from invokeai.backend.model_manager import (
|
|
ModelConfigBase,
|
|
ModelSearch,
|
|
)
|
|
from invokeai.backend.model_manager.storage import ModelConfigStore
|
|
from invokeai.backend.model_manager.download import DownloadJobBase
|
|
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
|
from invokeai.backend.util.logging import InvokeAILogger, Logger
|
|
|
|
from .config import InvokeAIAppConfig
|
|
from .events import EventServiceBase
|
|
from .model_record_service import ModelRecordServiceBase
|
|
|
|
|
|
class ModelInstallServiceBase(ABC):
    """Abstract interface for downloading, installing and deleting models."""

    @abstractmethod
    def __init__(
        self,
        config: InvokeAIAppConfig,
        store: Union[ModelConfigStore, ModelRecordServiceBase],
        event_bus: Optional[EventServiceBase] = None
    ):
        """
        Create a model install service.

        :param config: InvokeAIAppConfig object
        :param store: Either a ModelRecordServiceBase object or a ModelConfigStore
        :param event_bus: Optional EventServiceBase object. If provided,
           installation and download events will be sent to the event bus
           as "model_event".
        """

    @abstractmethod
    def convert_model(
        self,
        key: str,
        convert_dest_directory: Path,
    ) -> ModelConfigBase:
        """
        Convert a checkpoint file into a diffusers folder.

        The cached version (if any) is deleted, and the original checkpoint
        file is deleted as well if it lives in the models directory.

        :param key: Unique key for the model to convert.
        :param convert_dest_directory: Save the converted model to the
           designated directory (`models/etc/etc` by default)

        :raises ValueError: if the model is not a checkpoint, or if a
           similarly-named diffusers directory is already in place.
        """

    @abstractmethod
    def install_model(
        self,
        source: Union[str, Path, AnyHttpUrl],
        priority: int = 10,
        model_attributes: Optional[Dict[str, Any]] = None,
    ) -> ModelInstallJob:
        """
        Import a path, repo_id or URL and return a ModelInstallJob.

        :param source: Local path, repo_id or URL of the model to install.
        :param priority: Queue priority. Lower values have higher priority.
        :param model_attributes: Additional attributes to supplement/override
           the model information gained from automated probing.

        Typical usage:
          job = model_manager.install_model(
                   'stabilityai/stable-diffusion-2-1',
                   model_attributes={'prediction_type': 'v-prediction'}
          )

        The returned ModelInstallJob provides information on the
        asynchronous model download and install process. A series of
        "install_model_event" events will be emitted until the install
        is completed, cancelled or errors out.
        """

    @abstractmethod
    def list_install_jobs(self) -> List[ModelInstallJob]:
        """Return a series of active or enqueued ModelInstallJobs."""

    @abstractmethod
    def id_to_job(self, id: int) -> ModelInstallJob:
        """Return the ModelInstallJob instance corresponding to the given job ID."""

    @abstractmethod
    def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
        """
        Block until every pending install has finished.

        Returns once all pending downloads have completed, been cancelled,
        or errored out. Blocks indefinitely if one or more jobs are in the
        paused state.

        :returns: a dict that maps each source model path, URL or repo_id
           to the ID of the installed model.
        """

    @abstractmethod
    def start_job(self, job_id: int):
        """Start the given install job if it is paused or idle."""

    @abstractmethod
    def pause_job(self, job_id: int):
        """Pause the given install job."""

    @abstractmethod
    def cancel_job(self, job_id: int):
        """Cancel the given install job."""

    @abstractmethod
    def cancel_all_jobs(self):
        """Cancel all active jobs."""

    @abstractmethod
    def prune_jobs(self):
        """Remove completed or errored install jobs."""

    @abstractmethod
    def change_job_priority(self, job_id: int, delta: int):
        """
        Change an install job's priority.

        :param job_id: Job to change
        :param delta: Value to increment or decrement priority.

        Lower values are higher priority. The default starting value is 10.
        Thus to make this a really high priority job:
           manager.change_job_priority(job_id, -10).
        """

    @abstractmethod
    def merge_models(
        self,
        model_keys: List[str] = Field(
            default=None, min_items=2, max_items=3, description="List of model keys to merge"
        ),
        merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
        alpha: Optional[float] = 0.5,
        interp: Optional[MergeInterpolationMethod] = None,
        force: Optional[bool] = False,
        merge_dest_directory: Optional[Path] = None,
    ) -> ModelConfigBase:
        """
        Merge two to three diffusers pipeline models and save as a new model.

        :param model_keys: List of 2-3 model unique keys to merge
        :param merged_model_name: Name of destination merged model
        :param alpha: Alpha strength to apply to 2nd and 3rd model
        :param interp: Interpolation method. None (default)
        :param force: Passed through to the merger; semantics are defined
           there (not visible from this interface).
        :param merge_dest_directory: Save the merged model to the designated
           directory (with 'merged_model_name' appended)
        """

    @abstractmethod
    def list_checkpoint_configs(self) -> List[Path]:
        """List the checkpoint config paths from ROOT/configs/stable-diffusion."""

    @abstractmethod
    def search_for_models(self, directory: Path) -> Set[Path]:
        """Return the set of all models found in the designated directory."""

    @abstractmethod
    def sync_to_config(self):
        """
        Synchronize the in-memory models with on-disk.

        Re-read models.yaml, rescan the models directory, and reimport models
        in the autoimport directories. Call after making changes outside the
        model manager API.
        """
|
|
|
|
|
|
# implementation
|
|
class ModelInstallService(ModelInstallServiceBase):
    """Responsible for managing models on disk and in memory."""

    # NOTE: this class derives from an ABC, not a pydantic model, so the
    # Field(...) values below act only as class-level placeholders that
    # document each attribute; __init__ assigns the real values.
    _installer: ModelInstallBase = Field(description="ModelInstall object for the current process")
    _config: InvokeAIAppConfig = Field(description="App configuration object")
    _precision: Literal['float16', 'float32'] = Field(description="Floating point precision, string form")
    _event_bus: Optional[EventServiceBase] = Field(description="an event bus to send install events to", default=None)
    _logger: Logger = Field(description="logger module")

    def __init__(
        self,
        config: InvokeAIAppConfig,
        store: Union[ModelConfigStore, ModelRecordServiceBase],
        event_bus: Optional[EventServiceBase] = None
    ):
        """
        Initialize a ModelInstallService instance.

        :param config: InvokeAIAppConfig object
        :param store: Either a ModelRecordService object or a ModelConfigStore
        :param event_bus: Optional EventServiceBase object. If provided,
           installation and download events will be sent to the event bus
           as "model_event".
        """
        self._event_bus = event_bus
        self._config = config
        kwargs: Dict[str, Any] = {}
        # Only register an event handler when a bus was actually supplied.
        if self._event_bus:
            kwargs.update(event_handlers=[self._event_bus.emit_model_event])
        self._precision = get_precision()
        self._installer = ModelInstall(store, config, **kwargs)
        self._logger = InvokeAILogger.get_logger()

    def start(self, invoker: Any):  # Because .processor is giving circular import errors, declaring invoker an 'Any'
        """Call automatically at process start."""
        self._installer.scan_models_directory()  # synchronize new/deleted models found in models directory

    def install_model(
        self,
        source: Union[str, Path, AnyHttpUrl],
        priority: int = 10,
        model_attributes: Optional[Dict[str, Any]] = None,
    ) -> ModelInstallJob:
        """
        Add a model using a path, repo_id or URL.

        :param source: Local path, repo_id or URL of the model to install.
        :param priority: Queue priority for this install job. Lower value jobs
           will run before higher value ones.
        :param model_attributes: Dictionary of ModelConfigBase fields to
           attach to the model. When installing a URL or repo_id, some metadata
           values, such as `tags` will be automagically added.
        :returns: the ModelInstallJob created by the underlying installer.
        """
        self.logger.debug(f"add model {source}")
        # Request the fp16 variant only when running at half precision.
        variant = "fp16" if self._precision == "float16" else None
        job = self._installer.install(
            source,
            probe_override=model_attributes,
            variant=variant,
            priority=priority,
        )
        assert isinstance(job, ModelInstallJob)
        return job

    def list_install_jobs(self) -> List[ModelInstallJob]:
        """Return a series of active or enqueued ModelInstallJobs."""
        queue = self._installer.queue
        jobs: List[DownloadJobBase] = queue.list_jobs()
        return [parse_obj_as(ModelInstallJob, x) for x in jobs]  # downcast to proper type

    def id_to_job(self, id: int) -> ModelInstallJob:
        """Return the ModelInstallJob instance corresponding to the given job ID."""
        job = self._installer.queue.id_to_job(id)
        assert isinstance(job, ModelInstallJob)
        return job

    def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
        """
        Wait for all pending installs to complete.

        This will block until all pending downloads have
        completed, been cancelled, or errored out. It will
        block indefinitely if one or more jobs are in the
        paused state.

        It will return a dict that maps the source model
        path, URL or repo_id to the ID of the installed model.
        """
        return self._installer.wait_for_installs()

    def start_job(self, job_id: int):
        """Start the given install job if it is paused or idle."""
        queue = self._installer.queue
        queue.start_job(queue.id_to_job(job_id))

    def pause_job(self, job_id: int):
        """Pause the given install job."""
        queue = self._installer.queue
        queue.pause_job(queue.id_to_job(job_id))

    def cancel_job(self, job_id: int):
        """Cancel the given install job."""
        queue = self._installer.queue
        queue.cancel_job(queue.id_to_job(job_id))

    def cancel_all_jobs(self):
        """Cancel all active install jobs."""
        # Fixed: previously read the never-assigned `self._loader`.
        queue = self._installer.queue
        queue.cancel_all_jobs()

    def prune_jobs(self):
        """Remove completed or errored install jobs."""
        # Fixed: previously read the never-assigned `self._loader`.
        queue = self._installer.queue
        queue.prune_jobs()

    def change_job_priority(self, job_id: int, delta: int):
        """
        Change an install job's priority.

        :param job_id: Job to change
        :param delta: Value to increment or decrement priority.

        Lower values are higher priority. The default starting value is 10.
        Thus to make this a really high priority job:
           manager.change_job_priority(job_id, -10).
        """
        queue = self._installer.queue
        queue.change_priority(queue.id_to_job(job_id), delta)

    def del_model(
        self,
        key: str,
        delete_files: bool = False,
    ):
        """
        Delete the named model from configuration.

        :param key: Unique key of the model to delete.
        :param delete_files: If true, the underlying weight file or diffusers
           directory will be deleted as well (when it exists on disk).
        """
        model_info = self.store.get_model(key)
        self.logger.debug(f"delete model {model_info.name}")
        self.store.del_model(key)
        if not delete_files:
            return
        # Delete exactly the path whose existence is checked. The previous
        # code tested `models_path / model_info.path` but then removed the
        # bare `model_info.path`, which resolves against the current working
        # directory instead of the models directory.
        path = Path(self._config.models_path / model_info.path)
        if not path.exists():
            return
        if path.is_dir():
            shutil.rmtree(path)
        else:
            path.unlink()

    def convert_model(
        self,
        key: str,
        convert_dest_directory: Path,
    ) -> ModelConfigBase:
        """
        Convert a checkpoint file into a diffusers folder.

        Delete the cached version and delete the original checkpoint file
        if it is in the models directory.

        :param key: Key of the model to convert
        :param convert_dest_directory: Save the converted model to the
           designated directory (`models/etc/etc` by default)

        :raises ValueError: if the model is not a checkpoint, or if a
           similarly-named diffusers directory is already in place
           (as documented for the underlying installer; not checked here).
        """
        model_info = self.store.get_model(key)
        self.logger.info(f"Converting model {model_info.name} into a diffusers")
        return self._installer.convert_model(key, convert_dest_directory)

    @property
    def logger(self):
        """Get the logger associated with this instance."""
        # Fixed: previously returned `self._loader.logger`, but `_loader`
        # is never assigned; the logger is stored in `_logger` by __init__.
        return self._logger

    @property
    def store(self):
        """Get the store associated with this instance."""
        return self._installer.store

    def merge_models(
        self,
        model_keys: List[str] = Field(
            default=None, min_items=2, max_items=3, description="List of model keys to merge"
        ),
        merged_model_name: Optional[str] = Field(default=None, description="Name of destination model after merging"),
        alpha: Optional[float] = 0.5,
        interp: Optional[MergeInterpolationMethod] = None,
        force: Optional[bool] = False,
        merge_dest_directory: Optional[Path] = None,
    ) -> ModelConfigBase:
        """
        Merge two to three diffusers pipeline models and save as a new model.

        :param model_keys: List of 2-3 model unique keys to merge
        :param merged_model_name: Name of destination merged model
        :param alpha: Alpha strength to apply to 2nd and 3rd model
        :param interp: Interpolation method. None (default)
        :param force: Passed through unchanged to the merger.
        :param merge_dest_directory: Save the merged model to the designated
           directory (with 'merged_model_name' appended)

        :raises ValueError: if the merger fails an internal assertion.
        :raises Exception: ("not implemented") when no merged_model_name
           is supplied.
        """
        merger = ModelMerger(self.store)
        try:
            if not merged_model_name:
                merged_model_name = "+".join([self.store.get_model(x).name for x in model_keys])
                # NOTE(review): auto-naming is computed but then rejected as
                # unimplemented. Confirm this raise belongs inside the `if`
                # and is still wanted.
                raise Exception("not implemented")

            result = merger.merge_diffusion_models_and_save(
                model_keys=model_keys,
                merged_model_name=merged_model_name,
                alpha=alpha,
                interp=interp,
                force=force,
                merge_dest_directory=merge_dest_directory,
            )
        except AssertionError as e:
            # Surface merger assertion failures as ValueError, keeping the
            # original exception as the cause for debugging.
            raise ValueError(e) from e
        return result

    def search_for_models(self, directory: Path) -> Set[Path]:
        """
        Return the set of all models found in the designated directory.

        :param directory: Path to the directory to recursively search.
        :returns: the model paths reported by ModelSearch.
        """
        return ModelSearch().search(directory)

    def sync_to_config(self):
        """
        Synchronize the model manager to the database.

        Re-read models.yaml, rescan the models directory, and reimport models
        in the autoimport directories. Call after making changes outside the
        model manager API.
        """
        return self._installer.sync_to_config()

    def list_checkpoint_configs(self) -> List[Path]:
        """List the checkpoint config paths from ROOT/configs/stable-diffusion."""
        config = self._config
        conf_path = config.legacy_conf_path
        root_path = config.root_path
        # Report every YAML file under the legacy config directory relative to the root.
        return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
|