model_manager_service now mostly type correct

This commit is contained in:
Lincoln Stein
2023-09-14 21:12:31 -04:00
parent 171d789646
commit 716a1b6423
11 changed files with 345 additions and 393 deletions

View File

@ -25,8 +25,6 @@ from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsSto
from ..services.model_manager_service import ModelManagerService
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.model_manager_service import ModelManagerService
from ..services.download_manager import DownloadQueueService
from ..services.invocation_stats import InvocationStatsService
from .events import FastAPIEventService
@ -128,7 +126,6 @@ class ApiDependencies:
processor=DefaultInvocationProcessor(),
configuration=config,
performance_statistics=InvocationStatsService(graph_execution_manager),
download_manager=DownloadQueueService(event_bus=events),
logger=logger,
)

View File

@ -195,5 +195,5 @@ class EventServiceBase:
def emit_model_download_event(self, job: DownloadJobBase):
"""Emit event when the status of a download job changes."""
self.dispatch( # use dispatch() directly here because we are not a session event.
event_name="download_job_event", payload=dict(job=job)
event_name="install_model_event", payload=dict(job=job)
)

View File

@ -14,7 +14,6 @@ if TYPE_CHECKING:
from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
from invokeai.app.services.invoker import InvocationProcessorABC
from invokeai.app.services.download_manager import DownloadQueueServiceBase
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
@ -36,7 +35,6 @@ class InvocationServices:
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
download_manager: Optional["DownloadQueueServiceBase"]
queue: "InvocationQueueABC"
def __init__(
@ -54,7 +52,6 @@ class InvocationServices:
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
download_manager: Optional["DownloadQueueServiceBase"] = None, # optional for now pending design decisions
):
self.board_images = board_images
self.boards = boards
@ -67,7 +64,6 @@ class InvocationServices:
self.latents = latents
self.logger = logger
self.model_manager = model_manager
self.download_manager = download_manager
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

View File

@ -2,38 +2,35 @@
from __future__ import annotations
import shutil
from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path
from types import ModuleType
from invokeai.backend.model_manager import (
BaseModelType,
DownloadJobBase,
MergeInterpolationMethod,
ModelConfigBase,
ModelInfo,
ModelInstallJob,
ModelLoader,
ModelMerger,
ModelSearch,
ModelType,
SchedulerPredictionType,
SubModelType,
UnknownModelException,
DuplicateModelException
)
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.cache import CacheStats
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Union, Dict, Any
import torch
from pydantic import Field
from pydantic.networks import AnyHttpUrl
from invokeai.app.models.exceptions import CanceledException
from ...backend.util import choose_precision, choose_torch_device
from .config import InvokeAIAppConfig
if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
from ..invocations.baseinvocation import InvocationContext
class ModelManagerServiceBase(ABC):
@ -43,7 +40,6 @@ class ModelManagerServiceBase(ABC):
def __init__(
self,
config: InvokeAIAppConfig,
logger: ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
@ -55,17 +51,17 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def get_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""Retrieve the indicated model with name and type.
submodel can be used to get a part (such as the vae)
of a diffusers pipeline."""
"""Retrieve the indicated model identified by key.
:param key: Unique key returned by the ModelConfigStore module.
:param submodel_type: Submodel to return (required for main models)
:param context" Optional InvocationContext, used in event reporting.
"""
pass
@property
@ -75,15 +71,13 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
self,
key: str,
) -> bool:
pass
@abstractmethod
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
def model_info(self, key: str) -> ModelConfigBase:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
Uses the exact format as the omegaconf stanza.
@ -91,63 +85,60 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
def list_models(self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[ModelConfigBase]:
"""
Return a dict of models in the format:
{ model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
Return a list of ModelConfigBases that match the base, type and name criteria.
:param base_model: Filter by the base model type.
:param model_type: Filter by the model type.
:param model_name: Filter by the model name.
"""
pass
@abstractmethod
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase:
"""
Return information about the model using the same format as list_models()
Return information about the model using the same format as list_models().
If there are more than one model that match, raises a DuplicateModelException.
If no model matches, raises an UnknownModelException
"""
pass
model_configs = self.list_models(
model_name=model_name,
base_model=base_model,
model_type=model_type
)
if len(model_configs) > 1:
raise DuplicateModelException("More than one model share the same name and type: {base_model}/{model_type}/{model_name}")
if len(model_configs) == 0:
raise UnknownModelException("No known model with name and type: {base_model}/{model_type}/{model_name}")
return model_configs[0]
@abstractmethod
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
def all_models(self) -> List[ModelConfigBase]:
"""
Returns a list of all the model names known.
Returns a list of all the models.
"""
pass
return self.list_models()
@abstractmethod
def add_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> InstallJobBase:
self,
model_path: Path,
probe_overrides: Optional[Dict[str, Any]] = None,
wait: bool = False
) -> ModelInstallJob:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
Add a model using its path, with a dictionary of attributes. Will fail with an
assertion error if the name already exists.
"""
pass
@abstractmethod
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
self,
key: str,
new_config: Union[dict, ModelConfigBase],
) -> ModelConfigBase:
"""
Update the named model with a dictionary of attributes. Will fail with a
@ -155,36 +146,32 @@ class ModelManagerServiceBase(ABC):
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
the model key is unknown.
"""
pass
@abstractmethod
def del_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
self,
key: str,
delete_files: bool = False
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk.
Delete the named model from configuration. If delete_files
is true, then the underlying file or directory will be
deleted as well.
"""
pass
@abstractmethod
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str,
):
self,
key: str,
new_name: str,
) -> ModelConfigBase:
"""
Rename the indicated model.
"""
pass
return self.update_model(key, {"name": new_name})
@abstractmethod
def list_checkpoint_configs(self) -> List[Path]:
@ -195,18 +182,17 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Literal[ModelType.Main, ModelType.Vae],
) -> InstallJobBase:
self,
key: str,
convert_dest_directory: Path,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
Convert a checkpoint file into a diffusers folder.
This will delete the cached version if there is any and delete the original
checkpoint file if it is in the models directory.
:param key: Unique key for the model to convert.
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
This will raise a ValueError unless the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
@ -215,37 +201,34 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def heuristic_import(
self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> InstallJobBase:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
def install_model (
self,
source: Union[str, Path, AnyHttpUrl],
model_attributes: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
"""Import a path, repo_id or URL. Returns an ModelInstallJob.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
:param model_attributes: Additional attributes to supplement/override
the model information gained from automated probing.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
Typical usage:
job = model_manager.install(
'stabilityai/stable-diffusion-2-1',
model_attributes={'prediction_type": 'v-prediction'}
)
The result is an ModelInstallJob object, which provides
information on the asynchronous model download and install
process. A series of "install_model_event" events will be emitted
until the install is completed, cancelled or errors out.
"""
pass
@abstractmethod
def merge_models(
self,
model_names: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model names to merge"
),
base_model: Union[BaseModelType, str] = Field(
default=None, description="Base model shared by all models to be merged"
model_keys: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model keys to merge"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
@ -255,8 +238,7 @@ class ModelManagerServiceBase(ABC):
) -> ModelConfigBase:
"""
Merge two to three diffusrs pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param model_keys: List of 2-3 model unique keys to merge
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
@ -287,24 +269,16 @@ class ModelManagerServiceBase(ABC):
"""
pass
@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
original file/database used to initialize the object.
"""
pass
# simple implementation
# implementation
class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory"""
_loader: ModelLoader = Field(description="InvokeAIAppConfig object for the current process")
def __init__(
self,
config: InvokeAIAppConfig,
logger: Logger,
):
"""
Initialize with the path to the models.yaml config file.
@ -312,218 +286,164 @@ class ModelManagerService(ModelManagerServiceBase):
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
if config.model_conf_path and config.model_conf_path.exists():
config_file = config.model_conf_path
else:
config_file = config.root_dir / "configs/models.yaml"
logger.debug(f"Config file={config_file}")
device = torch.device(choose_torch_device())
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
logger.info(f"GPU device = {device} {device_name}")
precision = config.precision
if precision == "auto":
precision = choose_precision(device)
dtype = torch.float32 if precision == "float32" else torch.float16
# this is transitional backward compatibility
# support for the deprecated `max_loaded_models`
# configuration value. If present, then the
# cache size is set to 2.5 GB times
# the number of max_loaded_models. Otherwise
# use new `ram_cache_size` config setting
max_cache_size = config.ram_cache_size
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
sequential_offload = config.sequential_guidance
self.mgr = ModelManager(
config=config_file,
device_type=device,
precision=dtype,
max_cache_size=max_cache_size,
sequential_offload=sequential_offload,
logger=logger,
)
logger.info("Model manager service initialized")
self._loader = ModelLoader(config)
def get_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""
Retrieve the indicated model. submodel can be used to get a
part (such as the vae) of a diffusers mode.
"""
model_info: ModelInfo = self._loader.get_model(key, submodel_type)
# we can emit model loading events if we are executing with access to the invocation context
if context:
self._emit_load_event(
context=context,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
)
model_info = self.mgr.get_model(
model_name,
base_model,
model_type,
submodel,
)
if context:
self._emit_load_event(
context=context,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
key=key,
submodel_type=submodel_type,
model_info=model_info,
)
return model_info
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
self,
key: str,
) -> bool:
"""
Given a model name, returns True if it is a valid
Given a model key, returns True if it is a valid
identifier.
"""
return self.mgr.model_exists(
model_name,
base_model,
model_type,
)
return self._loader.store.exists(key)
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
def model_info(self, key: str) -> ModelConfigBase:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
"""
return self.mgr.model_info(model_name, base_model, model_type)
return self._loader.store.get_model(key)
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
return self.mgr.model_names()
# def all_models(self) -> List[ModelConfigBase] -- defined in base class, same as list_models()
# def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -- defined in base class
def list_models(
self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
) -> list[dict]:
def list_models(self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[ModelConfigBase]:
"""
Return a list of models.
Return a ModelConfigBase object for each model in the database.
"""
return self.mgr.list_models(base_model, model_type)
return self._loader.store.search_by_name(model_name, base_model, model_type)
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
"""
Return information about the model using the same format as list_models()
"""
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
def add_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> InstallJobBase:
self,
model_path: Path,
model_attributes: Optional[dict] = None,
wait: bool = False
) -> ModelInstallJob:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
Add a model using its path, with a dictionary of attributes. Will fail with an
assertion error if the name already exists.
"""
self.logger.debug(f"add/update model {model_name}")
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
self.logger.debug(f"add/update model {model_path}")
return self._loader.installer.install(
model_path,
probe_override=model_attributes,
)
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
model_attributes: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
"""
Add a model using its path, with a dictionary of attributes. Will fail with an
assertion error if the name already exists.
"""
self.logger.debug(f"add/update model {source}")
variant = 'fp16' if self._loader.precision == 'float16' else None
return self._loader.installer.install(
source,
probe_override=model_attributes,
variant=variant,
)
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> InstallJobBase:
self,
key: str,
new_config: Union[dict, ModelConfigBase],
) -> ModelConfigBase:
"""
Update the named model with a dictionary of attributes. Will fail with a
UnknownModelException exception if the name does not already exist.
UnknownModelException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
the model key is unknown.
"""
self.logger.debug(f"update model {model_name}")
if not self.model_exists(model_name, base_model, model_type):
raise UnknownModelException(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
model_info = self.model_info(key)
self.logger.debug(f"update model {model_info.name}")
self.logger.warning("TO DO: write code to move models around if base or type change")
return self._loader.store.update_model(key, new_config)
def del_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
self,
key: str,
delete_files: bool = False,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well.
"""
self.logger.debug(f"delete model {model_name}")
self.mgr.del_model(model_name, base_model, model_type)
self.mgr.commit()
model_info = self.model_info(key)
self.logger.debug(f"delete model {model_info.name}")
self._loader.store.del_model(key)
if delete_files and Path(model_info.path).exists():
path = Path(model_info)
if path.is_dir():
shutil.rmtree(path)
else:
path.unlink()
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Literal[ModelType.Main, ModelType.Vae],
convert_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
) -> InstallJobBase:
self,
key: str,
convert_dest_directory: Path,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
:param key: Key of the model to convert
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
This will raise a ValueError unless the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
self.logger.debug(f"convert model {model_name}")
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
model_info = self.model_info(key)
self.logger.debug(f"convert model {model_info.name}")
self.logger.warning('This is not implemented yet')
return self._loader.convert_model(key, convert_dest_directory)
def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics for graph with graph_id.
Reset model cache statistics. Is this used?
"""
self.mgr.cache.stats = cache_stats
def commit(self, conf_file: Optional[Path] = None):
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
original file/database used to initialize the object.
"""
return self.mgr.commit(conf_file)
self._loader.collect_cache_stats(cache_stats)
def _emit_load_event(
self,
@ -557,51 +477,22 @@ class ModelManagerService(ModelManagerServiceBase):
@property
def logger(self):
return self.mgr.logger
def heuristic_import(
self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> dict[str, InstallJobBase]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
"""
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
return self._loader.logger
def merge_models(
self,
model_names: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model names to merge"
),
base_model: Union[BaseModelType, str] = Field(
default=None, description="Base model shared by all models to be merged"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: bool = False,
merge_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
) -> str:
self,
model_keys: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model keys to merge"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Merge two to three diffusrs pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param model_keys: List of 2-3 model unique keys to merge
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
@ -609,9 +500,9 @@ class ModelManagerService(ModelManagerServiceBase):
"""
merger = ModelMerger(self.mgr)
try:
self.logger.error('ModelMerger needs to be rewritten.')
result = merger.merge_diffusion_models_and_save(
model_names=model_names,
base_model=base_model,
model_keys=model_keys,
merged_model_name=merged_model_name,
alpha=alpha,
interp=interp,
@ -626,8 +517,7 @@ class ModelManagerService(ModelManagerServiceBase):
"""
Return list of all models found in the designated directory.
"""
search = FindModels([directory], self.logger)
return search.list_models()
return ModelSearch().search(directory)
def sync_to_config(self):
"""
@ -635,7 +525,7 @@ class ModelManagerService(ModelManagerServiceBase):
in the autoimport directories. Call after making changes outside the
model manager API.
"""
return self.mgr.sync_to_config()
return self._loader.sync_to_config()
def list_checkpoint_configs(self) -> List[Path]:
"""
@ -648,24 +538,13 @@ class ModelManagerService(ModelManagerServiceBase):
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: Optional[str] = None,
new_base: Optional[BaseModelType] = None,
key: str,
new_name: str,
):
"""
Rename the indicated model. Can provide a new name and/or a new base.
:param model_name: Current name of the model
:param base_model: Current base of the model
:param model_type: Model type (can't be changed)
Rename the indicated model to the new name.
:param key: Unique key for the model.
:param new_name: New name for the model
:param new_base: New base for the model
"""
self.mgr.rename_model(
base_model=base_model,
model_type=model_type,
model_name=model_name,
new_name=new_name,
new_base=new_base,
)
return self.update_model(key, {"name": new_name})

View File

@ -16,7 +16,7 @@ from .config import ( # noqa F401
)
from .lora import ONNXModelPatcher, ModelPatcher
from .loader import ModelLoader, ModelInfo # noqa F401
from .install import ModelInstall, DownloadJobBase # noqa F401
from .install import ModelInstall, ModelInstallJob # noqa F401
from .probe import ModelProbe, InvalidModelException # noqa F401
from .storage import (
UnknownModelException,

View File

@ -136,6 +136,11 @@ class ModelConfigBase(BaseModel):
v = list(v)
return v
def update(self, attributes: dict):
"""Update the object with fields in dict."""
for key, value in attributes.items():
setattr(self, key, value) # may raise a validation error
class CheckpointConfig(ModelConfigBase):
"""Model config for checkpoint-style models."""

View File

@ -329,14 +329,14 @@ class DownloadQueue(DownloadQueueBase):
self._dones += 1
self._queue.task_done()
def _fetch_metadata(self, job: DownloadJobBase) -> Tuple[AnyHttpUrl, ModelSourceMetadata]:
def _get_metadata_and_url(self, job: DownloadJobBase) -> AnyHttpUrl:
"""
Fetch metadata from certain well-known URLs.
The metadata will be stashed in job.metadata, if found
Return the download URL.
"""
metadata = ModelSourceMetadata()
metadata = job.metadata
url = job.source
metadata_url = url
try:
@ -344,12 +344,14 @@ class DownloadQueue(DownloadQueueBase):
if match := re.match(CIVITAI_MODEL_DOWNLOAD + r"(\d+)", metadata_url):
version = match.group(1)
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
metadata.thumbnail_url = resp["images"][0]["url"]
metadata.description = (
f"Trigger terms: {(', ').join(resp['trainedWords'])}"
if resp["trainedWords"]
else resp["description"]
)
metadata.thumbnail_url = metadata.thumbnail_url \
or resp["images"][0]["url"]
metadata.description = metadata.description \
or (
f"Trigger terms: {(', ').join(resp['trainedWords'])}"
if resp["trainedWords"]
else resp["description"]
)
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"])
# a Civitai model page
@ -360,21 +362,22 @@ class DownloadQueue(DownloadQueueBase):
# note that we munge the URL here to get the download URL of the first model
url = resp["modelVersions"][0]["downloadUrl"]
metadata.author = resp["creator"]["username"]
metadata.tags = resp["tags"]
metadata.thumbnail_url = resp["modelVersions"][0]["images"][0]["url"]
metadata.license = f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
metadata.author = metadata.author or resp["creator"]["username"]
metadata.tags = metadata.tags or resp["tags"]
metadata.thumbnail_url = metadata.thumbnail_url \
or resp["modelVersions"][0]["images"][0]["url"]
metadata.license = metadata.license \
or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
except (HTTPError, KeyError, TypeError, JSONDecodeError) as excp:
self._logger.warn(excp)
# update metadata and return the download url
return url, metadata
# return the download url
return url
def _download_with_resume(self, job: DownloadJobBase):
"""Do the actual download."""
try:
url, metadata = self._fetch_metadata(job)
job.metadata = metadata
url = self._get_metadata_and_url(job)
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
@ -602,7 +605,6 @@ class DownloadQueue(DownloadQueueBase):
"""Call when the source is a Path or pathlike object."""
source = Path(job.source).resolve()
destination = Path(job.destination).resolve()
job.metadata = ModelSourceMetadata()
try:
if source != destination:
shutil.move(source, destination)

View File

@ -53,7 +53,7 @@ import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
from typing import Optional, List, Union, Dict, Set
from typing import Optional, List, Union, Dict, Set, Any
from pydantic import Field
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
@ -79,8 +79,8 @@ class ModelInstallJob(DownloadJobBase):
model_key: Optional[str] = Field(
description="After model installation, this field will hold its primary key", default=None
)
probe_info: Optional[ModelProbeInfo] = Field(
description="If provided, information here will be used instead of probing the model.",
probe_override: Optional[Dict[str, Any]] = Field(
description="Keys in this dict will override like-named attributes in the automatic probe info",
default=None,
)
@ -316,9 +316,12 @@ class ModelInstall(ModelInstallBase):
"""Return the queue."""
return self._download_queue
def register_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str: # noqa D102
def register_path(self,
model_path: Union[Path, str],
overrides: Optional[Dict[str, Any]] = None
) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = info or ModelProbe.probe(model_path)
info: ModelProbeInfo = self._probe_model(model_path, overrides)
return self._register(model_path, info)
def _register(self, model_path: Path, info: ModelProbeInfo) -> str:
@ -351,12 +354,13 @@ class ModelInstall(ModelInstallBase):
return key
def install_path(
self,
model_path: Union[Path, str],
info: Optional[ModelProbeInfo] = None,
self,
model_path: Union[Path, str],
overrides: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = info or ModelProbe.probe(model_path)
info: ModelProbeInfo = self._probe_model(model_path, overrides)
dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
dest_path.parent.mkdir(parents=True, exist_ok=True)
@ -371,6 +375,16 @@ class ModelInstall(ModelInstallBase):
info,
)
def _probe_model(self,
model_path: Union[Path, str],
overrides: Optional[Dict[str,Any]] = None
) -> ModelProbeInfo:
info: ModelProbeInfo = ModelProbe.probe(model_path)
if overrides: # used to override probe fields
for key, value in overrides.items():
setattr(info, key, value) # may generate a pydantic validation error
return info
def unregister(self, key: str): # noqa D102
self._store.del_model(key)
@ -380,12 +394,12 @@ class ModelInstall(ModelInstallBase):
self.unregister(key)
def install(
self,
source: Union[str, Path, AnyHttpUrl],
info: Optional[ModelProbeInfo] = None,
inplace: bool = True,
variant: Optional[str] = None,
access_token: Optional[str] = None,
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
variant: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
) -> DownloadJobBase: # noqa D102
queue = self._download_queue
@ -395,8 +409,8 @@ class ModelInstall(ModelInstallBase):
if inplace and Path(source).exists()
else self._complete_installation_handler
)
job.probe_override = probe_override
job.add_event_handler(handler)
job.probe_info = info
self._async_installs[source] = None
queue.submit_download_job(job, True)
@ -405,7 +419,7 @@ class ModelInstall(ModelInstallBase):
def _complete_installation_handler(self, job: DownloadJobBase):
if job.status == "completed":
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
model_id = self.install_path(job.destination, job.probe_info)
model_id = self.install_path(job.destination, job.probe_override)
info = self._store.get_model(model_id)
info.source = str(job.source)
metadata: ModelSourceMetadata = job.metadata
@ -429,7 +443,7 @@ class ModelInstall(ModelInstallBase):
def _complete_registration_handler(self, job: DownloadJobBase):
if job.status == "completed":
self._logger.info(f"{job.source}: Installing in place.")
model_id = self.register_path(job.destination, job.probe_info)
model_id = self.register_path(job.destination, job.probe_override)
info = self._store.get_model(model_id)
info.source = str(job.source)
info.description = f"Imported model {info.name}"

View File

@ -14,7 +14,7 @@ from invokeai.backend.util import choose_precision, choose_torch_device, InvokeA
from .config import BaseModelType, ModelType, SubModelType, ModelConfigBase
from .install import ModelInstallBase, ModelInstall
from .storage import ModelConfigStore, get_config_store
from .cache import ModelCache, ModelLocker
from .cache import ModelCache, ModelLocker, CacheStats
from .models import InvalidModelException, ModelBase, MODEL_CLASSES
@ -69,6 +69,34 @@ class ModelLoaderBase(ABC):
"""Return the ModelInstallBase object that supports this loader."""
pass
@property
@abstractmethod
def logger(self) -> InvokeAILogger:
"""Return the current logger."""
pass
@abstractmethod
def collect_cache_stats(
self,
cache_stats: CacheStats
):
"""Replace cache statistics."""
pass
@property
@abstractmethod
def precision(self) -> str:
"""Return 'float32' or 'float16'."""
pass
@abstractmethod
def sync_to_config(self):
"""
Reinitialize the store to sync in-memory and in-disk
versions.
"""
pass
class ModelLoader(ModelLoaderBase):
"""Implementation of ModelLoaderBase."""
@ -79,6 +107,7 @@ class ModelLoader(ModelLoaderBase):
_cache: ModelCache
_logger: InvokeAILogger
_cache_keys: dict
_models_file: Path
def __init__(
self,
@ -102,6 +131,7 @@ class ModelLoader(ModelLoaderBase):
self._logger = InvokeAILogger.getLogger()
self._installer = ModelInstall(store=self._store, logger=self._logger, config=self._app_config)
self._cache_keys = dict()
self._models_file = models_file
device = torch.device(choose_torch_device())
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
precision = choose_precision(device) if config.precision == "auto" else config.precision
@ -130,11 +160,21 @@ class ModelLoader(ModelLoaderBase):
"""Return the ModelConfigStore instance used by this class."""
return self._store
@property
def precision(self) -> str:
"""Return 'float32' or 'float16'."""
return self._cache.precision
@property
def installer(self) -> ModelInstallBase:
"""Return the ModelInstallBase instance used by this class."""
return self._installer
@property
def logger(self) -> InvokeAILogger:
"""Return the current logger."""
return self._logger
def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo:
"""
Get the ModelInfo corresponding to the model with key "key".
@ -188,6 +228,12 @@ class ModelLoader(ModelLoaderBase):
_cache=self._cache,
)
def collect_cache_stats(
self,
cache_stats: CacheStats
):
self._cache.stats = cache_stats
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
"""Get the concrete implementation class for a specific model type."""
model_class = MODEL_CLASSES[base_model][model_type]
@ -220,6 +266,10 @@ class ModelLoader(ModelLoaderBase):
model_path = self._resolve_model_path(model_path)
return model_path, is_submodel_override
def sync_to_config(self):
self._store = get_config_store(self._models_file)
self._scan_models_directory()
def _scan_models_directory(self):
defunct_models = set()
installed = set()

View File

@ -124,17 +124,26 @@ class ModelMerger(object):
dump_path = (dump_path / merged_model_name).as_posix()
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
attributes = dict(
path=dump_path,
description=f"Merge of models {', '.join(model_names)}",
model_format="diffusers",
variant=ModelVariantType.Normal.value,
vae=vae,
)
return self.manager.add_model(
merged_model_name,
base_model=base_model,
# register model and get its unique key
info = ModelProbeInfo(
model_type=ModelType.Main,
model_attributes=attributes,
clobber=True,
base_type=base_model,
format="diffusers",
)
key = self.manager.installer.register_path(
model_path=dump_path,
info=info,
)
# update model's config
model_config = self.manager.store.get_model(key)
model_config.update(
dict(
name=merged_model_name,
description=f"Merge of models {', '.join(model_names)}",
vae=vae,
)
)
self.manager.store.update_model(key, model_config)
return model_config

View File

@ -1,5 +1,5 @@
__metadata__:
version: 3.1.0
version: 3.1.1
ed799245c762f6d0a9ddfd4e31fdb010:
name: sdxl-base-1-0
path: sdxl/main/SDXL base 1_0