model_manager_service now mostly type correct
@@ -25,8 +25,6 @@ from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsSto
from ..services.model_manager_service import ModelManagerService
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.model_manager_service import ModelManagerService
from ..services.download_manager import DownloadQueueService
from ..services.invocation_stats import InvocationStatsService
from .events import FastAPIEventService

@@ -128,7 +126,6 @@ class ApiDependencies:
processor=DefaultInvocationProcessor(),
configuration=config,
performance_statistics=InvocationStatsService(graph_execution_manager),
download_manager=DownloadQueueService(event_bus=events),
logger=logger,
)
@@ -195,5 +195,5 @@ class EventServiceBase:
def emit_model_download_event(self, job: DownloadJobBase):
"""Emit event when the status of a download job changes."""
self.dispatch(  # use dispatch() directly here because we are not a session event.
event_name="download_job_event", payload=dict(job=job)
event_name="install_model_event", payload=dict(job=job)
)
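The event emitted here is renamed from "download_job_event" to "install_model_event". A minimal subscriber sketch, assuming the fastapi_events local handler that the API layer uses elsewhere; the handler function itself is hypothetical, and the payload shape follows the dispatch() call above:

from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event

@local_handler.register(event_name="install_model_event")
async def on_install_model_event(event: Event):
    _event_name, payload = event
    job = payload["job"]  # the DownloadJobBase passed to dispatch() above
    print(f"{job.source}: status={job.status}")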
@@ -14,7 +14,6 @@ if TYPE_CHECKING:
from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
from invokeai.app.services.invoker import InvocationProcessorABC
from invokeai.app.services.download_manager import DownloadQueueServiceBase
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.model_manager_service import ModelManagerServiceBase

@@ -36,7 +35,6 @@ class InvocationServices:
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
download_manager: Optional["DownloadQueueServiceBase"]
queue: "InvocationQueueABC"

def __init__(

@@ -54,7 +52,6 @@ class InvocationServices:
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
download_manager: Optional["DownloadQueueServiceBase"] = None,  # optional for now pending design decisions
):
self.board_images = board_images
self.boards = boards

@@ -67,7 +64,6 @@ class InvocationServices:
self.latents = latents
self.logger = logger
self.model_manager = model_manager
self.download_manager = download_manager
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue
@@ -2,38 +2,35 @@

from __future__ import annotations

import shutil
from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path
from types import ModuleType

from invokeai.backend.model_manager import (
BaseModelType,
DownloadJobBase,
MergeInterpolationMethod,
ModelConfigBase,
ModelInfo,
ModelInstallJob,
ModelLoader,
ModelMerger,
ModelSearch,
ModelType,
SchedulerPredictionType,
SubModelType,
UnknownModelException,
DuplicateModelException
)
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.cache import CacheStats
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Union, Dict, Any

import torch
from pydantic import Field
from pydantic.networks import AnyHttpUrl

from invokeai.app.models.exceptions import CanceledException

from ...backend.util import choose_precision, choose_torch_device
from .config import InvokeAIAppConfig

if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
from ..invocations.baseinvocation import InvocationContext


class ModelManagerServiceBase(ABC):

@@ -43,7 +40,6 @@ class ModelManagerServiceBase(ABC):
def __init__(
self,
config: InvokeAIAppConfig,
logger: ModuleType,
):
"""
Initialize with the path to the models.yaml config file.

@@ -55,17 +51,17 @@ class ModelManagerServiceBase(ABC):
@abstractmethod
def get_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""Retrieve the indicated model with name and type.
submodel can be used to get a part (such as the vae)
of a diffusers pipeline."""
"""Retrieve the indicated model identified by key.

:param key: Unique key returned by the ModelConfigStore module.
:param submodel_type: Submodel to return (required for main models)
:param context: Optional InvocationContext, used in event reporting.
"""
pass

@property

@@ -75,15 +71,13 @@ class ModelManagerServiceBase(ABC):

@abstractmethod
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
self,
key: str,
) -> bool:
pass

@abstractmethod
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
def model_info(self, key: str) -> ModelConfigBase:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
Uses the exact format as the omegaconf stanza.

@@ -91,63 +85,60 @@ class ModelManagerServiceBase(ABC):
pass

@abstractmethod
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
def list_models(self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[ModelConfigBase]:
"""
Return a dict of models in the format:
{ model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
Return a list of ModelConfigBases that match the base, type and name criteria.
:param base_model: Filter by the base model type.
:param model_type: Filter by the model type.
:param model_name: Filter by the model name.
"""
pass

@abstractmethod
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> ModelConfigBase:
"""
Return information about the model using the same format as list_models()
Return information about the model using the same format as list_models().
If more than one model matches, raises a DuplicateModelException.
If no model matches, raises an UnknownModelException.
"""
pass
model_configs = self.list_models(
model_name=model_name,
base_model=base_model,
model_type=model_type
)
if len(model_configs) > 1:
raise DuplicateModelException(f"More than one model shares the same name and type: {base_model}/{model_type}/{model_name}")
if len(model_configs) == 0:
raise UnknownModelException(f"No known model with name and type: {base_model}/{model_type}/{model_name}")
return model_configs[0]
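The new default implementation narrows list_models() down to a single match. A usage sketch, where mm stands for any ModelManagerServiceBase implementation and the model name is illustrative:

try:
    config = mm.list_model(
        model_name="stable-diffusion-v1-5",  # hypothetical name
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
    )
    print(config.name, config.path)
except DuplicateModelException:
    print("more than one match; tighten the criteria")
except UnknownModelException:
    print("no matching model")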
@abstractmethod
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
def all_models(self) -> List[ModelConfigBase]:
"""
Returns a list of all the model names known.
Returns a list of all the models.
"""
pass
return self.list_models()

@abstractmethod
def add_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> InstallJobBase:
self,
model_path: Path,
probe_overrides: Optional[Dict[str, Any]] = None,
wait: bool = False
) -> ModelInstallJob:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
Add a model using its path, with a dictionary of attributes. Will fail with an
assertion error if the name already exists.
"""
pass

@abstractmethod
def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
self,
key: str,
new_config: Union[dict, ModelConfigBase],
) -> ModelConfigBase:
"""
Update the named model with a dictionary of attributes. Will fail with a

@@ -155,36 +146,32 @@ class ModelManagerServiceBase(ABC):

On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
the model key is unknown.
"""
pass

@abstractmethod
def del_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
self,
key: str,
delete_files: bool = False
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk.
Delete the named model from configuration. If delete_files
is true, then the underlying file or directory will be
deleted as well.
"""
pass

@abstractmethod
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str,
):
self,
key: str,
new_name: str,
) -> ModelConfigBase:
"""
Rename the indicated model.
"""
pass
return self.update_model(key, {"name": new_name})
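rename_model() is now a one-line wrapper over update_model(). A sketch; the key value is illustrative, taken from the models.yaml stanza at the end of this diff:

config = mm.rename_model(key="ed799245c762f6d0a9ddfd4e31fdb010", new_name="sdxl-base")
assert config.name == "sdxl-base"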
@abstractmethod
def list_checkpoint_configs(self) -> List[Path]:

@@ -195,18 +182,17 @@ class ModelManagerServiceBase(ABC):

@abstractmethod
def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Literal[ModelType.Main, ModelType.Vae],
) -> InstallJobBase:
self,
key: str,
convert_dest_directory: Path,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']
Convert a checkpoint file into a diffusers folder.

This will delete the cached version if there is any and delete the original
checkpoint file if it is in the models directory.
:param key: Unique key for the model to convert.
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)

This will raise a ValueError if the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers

@@ -215,37 +201,34 @@ class ModelManagerServiceBase(ABC):
pass

@abstractmethod
def heuristic_import(
self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> InstallJobBase:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
model_attributes: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
"""Import a path, repo_id or URL. Returns a ModelInstallJob.

The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilon) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.
:param model_attributes: Additional attributes to supplement/override
the model information gained from automated probing.

The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
Typical usage:
job = model_manager.install_model(
'stabilityai/stable-diffusion-2-1',
model_attributes={'prediction_type': 'v-prediction'}
)

The result is a ModelInstallJob object, which provides
information on the asynchronous model download and install
process. A series of "install_model_event" events will be emitted
until the install is completed, cancelled or errors out.
"""
pass

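A sketch of driving the new API end to end, following the docstring above. The polling loop and the terminal status strings other than "completed" are assumptions; the model_key field comes from ModelInstallJob further down in this diff:

import time

job = model_manager.install_model(
    "stabilityai/stable-diffusion-2-1",
    model_attributes={"prediction_type": "v-prediction"},
)
while job.status not in ("completed", "error", "cancelled"):  # assumed status values
    time.sleep(0.5)
print(f"finished: status={job.status}, key={job.model_key}")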
@abstractmethod
def merge_models(
self,
model_names: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model names to merge"
),
base_model: Union[BaseModelType, str] = Field(
default=None, description="Base model shared by all models to be merged"
model_keys: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model keys to merge"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,

@@ -255,8 +238,7 @@ class ModelManagerServiceBase(ABC):
) -> ModelConfigBase:
"""
Merge two to three diffusers pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param model_keys: List of 2-3 model unique keys to merge
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)

@@ -287,24 +269,16 @@ class ModelManagerServiceBase(ABC):
"""
pass

@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
original file/database used to initialize the object.
"""
pass

# simple implementation
# implementation
class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory"""

_loader: ModelLoader = Field(description="ModelLoader object for the current process")

def __init__(
self,
config: InvokeAIAppConfig,
logger: Logger,
):
"""
Initialize with the path to the models.yaml config file.

@@ -312,218 +286,164 @@ class ModelManagerService(ModelManagerServiceBase):
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
if config.model_conf_path and config.model_conf_path.exists():
config_file = config.model_conf_path
else:
config_file = config.root_dir / "configs/models.yaml"

logger.debug(f"Config file={config_file}")

device = torch.device(choose_torch_device())
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
logger.info(f"GPU device = {device} {device_name}")

precision = config.precision
if precision == "auto":
precision = choose_precision(device)
dtype = torch.float32 if precision == "float32" else torch.float16

# this is transitional backward compatibility
# support for the deprecated `max_loaded_models`
# configuration value. If present, then the
# cache size is set to 2.5 GB times
# the number of max_loaded_models. Otherwise
# use new `ram_cache_size` config setting
max_cache_size = config.ram_cache_size

logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")

sequential_offload = config.sequential_guidance

self.mgr = ModelManager(
config=config_file,
device_type=device,
precision=dtype,
max_cache_size=max_cache_size,
sequential_offload=sequential_offload,
logger=logger,
)
logger.info("Model manager service initialized")
self._loader = ModelLoader(config)
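With the device, precision, and cache bookkeeping delegated to ModelLoader, constructing the rewritten service reduces to a sketch like this (InvokeAIAppConfig.get_config() is the usual way the app obtains its config singleton):

from logging import getLogger
from invokeai.app.services.config import InvokeAIAppConfig

mm = ModelManagerService(config=InvokeAIAppConfig.get_config(), logger=getLogger(__name__))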
def get_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
self,
key: str,
submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""
Retrieve the indicated model. submodel can be used to get a
part (such as the vae) of a diffusers model.
"""

model_info: ModelInfo = self._loader.get_model(key, submodel_type)

# we can emit model loading events if we are executing with access to the invocation context
if context:
self._emit_load_event(
context=context,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
)

model_info = self.mgr.get_model(
model_name,
base_model,
model_type,
submodel,
)

if context:
self._emit_load_event(
context=context,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
key=key,
submodel_type=submodel_type,
model_info=model_info,
)

return model_info

def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
self,
key: str,
) -> bool:
"""
Given a model name, returns True if it is a valid
Given a model key, returns True if it is a valid
identifier.
"""
return self.mgr.model_exists(
model_name,
base_model,
model_type,
)
return self._loader.store.exists(key)

def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
def model_info(self, key: str) -> ModelConfigBase:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
"""
return self.mgr.model_info(model_name, base_model, model_type)
return self._loader.store.get_model(key)

def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
"""
Returns a list of all the model names known.
"""
return self.mgr.model_names()
# def all_models(self) -> List[ModelConfigBase] -- defined in base class, same as list_models()
# def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -- defined in base class

def list_models(
self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
) -> list[dict]:
def list_models(self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[ModelConfigBase]:
"""
Return a list of models.
Return a ModelConfigBase object for each model in the database.
"""
return self.mgr.list_models(base_model, model_type)
return self._loader.store.search_by_name(model_name, base_model, model_type)

def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
"""
Return information about the model using the same format as list_models()
"""
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)

def add_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> InstallJobBase:
self,
model_path: Path,
model_attributes: Optional[dict] = None,
wait: bool = False
) -> ModelInstallJob:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
Add a model using its path, with a dictionary of attributes. Will fail with an
assertion error if the name already exists.
"""
self.logger.debug(f"add/update model {model_name}")
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
self.logger.debug(f"add/update model {model_path}")
return self._loader.installer.install(
model_path,
probe_override=model_attributes,
)

def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
model_attributes: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
"""
Add a model using its path, with a dictionary of attributes. Will fail with an
assertion error if the name already exists.
"""
self.logger.debug(f"add/update model {source}")
variant = 'fp16' if self._loader.precision == 'float16' else None
return self._loader.installer.install(
source,
probe_override=model_attributes,
variant=variant,
)

def update_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
) -> InstallJobBase:
self,
key: str,
new_config: Union[dict, ModelConfigBase],
) -> ModelConfigBase:
"""
Update the named model with a dictionary of attributes. Will fail with a
UnknownModelException exception if the name does not already exist.
UnknownModelException if the name does not already exist.

On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
the model key is unknown.
"""
self.logger.debug(f"update model {model_name}")
if not self.model_exists(model_name, base_model, model_type):
raise UnknownModelException(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
model_info = self.model_info(key)
self.logger.debug(f"update model {model_info.name}")
self.logger.warning("TO DO: write code to move models around if base or type change")
return self._loader.store.update_model(key, new_config)

def del_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
self,
key: str,
delete_files: bool = False,
):
"""
Delete the named model from configuration. If delete_files is true,
then the underlying weight file or diffusers directory will be deleted
as well.
"""
self.logger.debug(f"delete model {model_name}")
self.mgr.del_model(model_name, base_model, model_type)
self.mgr.commit()
model_info = self.model_info(key)
self.logger.debug(f"delete model {model_info.name}")
self._loader.store.del_model(key)
if delete_files and Path(model_info.path).exists():
path = Path(model_info.path)
if path.is_dir():
shutil.rmtree(path)
else:
path.unlink()

def convert_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: Literal[ModelType.Main, ModelType.Vae],
convert_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
) -> InstallJobBase:
self,
key: str,
convert_dest_directory: Path,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
:param model_name: Name of the model to convert
:param base_model: Base model type
:param model_type: Type of model ['vae' or 'main']

:param key: Key of the model to convert
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)

This will raise a ValueError if the model is not a checkpoint. It will
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
self.logger.debug(f"convert model {model_name}")
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
model_info = self.model_info(key)
self.logger.debug(f"convert model {model_info.name}")
self.logger.warning('This is not implemented yet')
return self._loader.convert_model(key, convert_dest_directory)

def collect_cache_stats(self, cache_stats: CacheStats):
"""
Reset model cache statistics for graph with graph_id.
Reset model cache statistics. Is this used?
"""
self.mgr.cache.stats = cache_stats

def commit(self, conf_file: Optional[Path] = None):
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
original file/database used to initialize the object.
"""
return self.mgr.commit(conf_file)
self._loader.collect_cache_stats(cache_stats)

def _emit_load_event(
self,

@@ -557,51 +477,22 @@ class ModelManagerService(ModelManagerServiceBase):

@property
def logger(self):
return self.mgr.logger

def heuristic_import(
self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> dict[str, InstallJobBase]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.

The prediction type helper is necessary to distinguish between
models based on Stable Diffusion 2 Base (requiring
SchedulerPredictionType.Epsilon) and Stable Diffusion 768
(requiring SchedulerPredictionType.VPrediction). It is
generally impossible to do this programmatically, so the
prediction_type_helper usually asks the user to choose.

The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
"""
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)

return self._loader.logger

def merge_models(
self,
model_names: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model names to merge"
),
base_model: Union[BaseModelType, str] = Field(
default=None, description="Base model shared by all models to be merged"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: bool = False,
merge_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
) -> str:
self,
model_keys: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model keys to merge"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Merge two to three diffusers pipeline models and save as a new model.
:param model_names: List of 2-3 models to merge
:param base_model: Base model to use for all models
:param model_keys: List of 2-3 model unique keys to merge
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)

@@ -609,9 +500,9 @@ class ModelManagerService(ModelManagerServiceBase):
"""
merger = ModelMerger(self.mgr)
try:
self.logger.error('ModelMerger needs to be rewritten.')
result = merger.merge_diffusion_models_and_save(
model_names=model_names,
base_model=base_model,
model_keys=model_keys,
merged_model_name=merged_model_name,
alpha=alpha,
interp=interp,

@@ -626,8 +517,7 @@ class ModelManagerService(ModelManagerServiceBase):
"""
Return list of all models found in the designated directory.
"""
search = FindModels([directory], self.logger)
return search.list_models()
return ModelSearch().search(directory)

def sync_to_config(self):
"""

@@ -635,7 +525,7 @@ class ModelManagerService(ModelManagerServiceBase):
in the autoimport directories. Call after making changes outside the
model manager API.
"""
return self.mgr.sync_to_config()
return self._loader.sync_to_config()

def list_checkpoint_configs(self) -> List[Path]:
"""

@@ -648,24 +538,13 @@ class ModelManagerService(ModelManagerServiceBase):

def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: Optional[str] = None,
new_base: Optional[BaseModelType] = None,
key: str,
new_name: str,
):
"""
Rename the indicated model. Can provide a new name and/or a new base.
:param model_name: Current name of the model
:param base_model: Current base of the model
:param model_type: Model type (can't be changed)
Rename the indicated model to the new name.

:param key: Unique key for the model.
:param new_name: New name for the model
:param new_base: New base for the model
"""
self.mgr.rename_model(
base_model=base_model,
model_type=model_type,
model_name=model_name,
new_name=new_name,
new_base=new_base,
)
return self.update_model(key, {"name": new_name})
@@ -16,7 +16,7 @@ from .config import ( # noqa F401
)
from .lora import ONNXModelPatcher, ModelPatcher
from .loader import ModelLoader, ModelInfo  # noqa F401
from .install import ModelInstall, DownloadJobBase  # noqa F401
from .install import ModelInstall, ModelInstallJob  # noqa F401
from .probe import ModelProbe, InvalidModelException  # noqa F401
from .storage import (
UnknownModelException,
@@ -136,6 +136,11 @@ class ModelConfigBase(BaseModel):
v = list(v)
return v

def update(self, attributes: dict):
"""Update the object with fields in dict."""
for key, value in attributes.items():
setattr(self, key, value)  # may raise a validation error


class CheckpointConfig(ModelConfigBase):
"""Model config for checkpoint-style models."""
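The new update() helper above mutates fields in place through setattr(), so assignment validators can still reject bad values. A sketch with illustrative field values:

cfg = store.get_model(key)                    # any ModelConfigBase subclass
cfg.update({"description": "retrained on portraits",
            "name": "portraits-v1"})          # setattr may raise a validation error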
@@ -329,14 +329,14 @@ class DownloadQueue(DownloadQueueBase):
self._dones += 1
self._queue.task_done()

def _fetch_metadata(self, job: DownloadJobBase) -> Tuple[AnyHttpUrl, ModelSourceMetadata]:
def _get_metadata_and_url(self, job: DownloadJobBase) -> AnyHttpUrl:
"""
Fetch metadata from certain well-known URLs.

The metadata will be stashed in job.metadata, if found
Return the download URL.
"""
metadata = ModelSourceMetadata()
metadata = job.metadata
url = job.source
metadata_url = url
try:

@@ -344,12 +344,14 @@ class DownloadQueue(DownloadQueueBase):
if match := re.match(CIVITAI_MODEL_DOWNLOAD + r"(\d+)", metadata_url):
version = match.group(1)
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
metadata.thumbnail_url = resp["images"][0]["url"]
metadata.description = (
f"Trigger terms: {(', ').join(resp['trainedWords'])}"
if resp["trainedWords"]
else resp["description"]
)
metadata.thumbnail_url = metadata.thumbnail_url \
or resp["images"][0]["url"]
metadata.description = metadata.description \
or (
f"Trigger terms: {(', ').join(resp['trainedWords'])}"
if resp["trainedWords"]
else resp["description"]
)
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"])

# a Civitai model page

@@ -360,21 +362,22 @@ class DownloadQueue(DownloadQueueBase):
# note that we munge the URL here to get the download URL of the first model
url = resp["modelVersions"][0]["downloadUrl"]

metadata.author = resp["creator"]["username"]
metadata.tags = resp["tags"]
metadata.thumbnail_url = resp["modelVersions"][0]["images"][0]["url"]
metadata.license = f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
metadata.author = metadata.author or resp["creator"]["username"]
metadata.tags = metadata.tags or resp["tags"]
metadata.thumbnail_url = metadata.thumbnail_url \
or resp["modelVersions"][0]["images"][0]["url"]
metadata.license = metadata.license \
or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
except (HTTPError, KeyError, TypeError, JSONDecodeError) as excp:
self._logger.warn(excp)

# update metadata and return the download url
return url, metadata
# return the download url
return url

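The rewritten method treats job.metadata as caller-seeded: each "x = x or fetched" line only fills fields the caller left empty. An illustrative sketch of the precedence:

metadata = ModelSourceMetadata(author="alice")        # pre-seeded by the caller
metadata.author = metadata.author or "civitai-user"   # stays "alice"
metadata.tags = metadata.tags or ["sd-1.5"]           # empty, so filled from the response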
def _download_with_resume(self, job: DownloadJobBase):
"""Do the actual download."""
try:
url, metadata = self._fetch_metadata(job)
job.metadata = metadata
url = self._get_metadata_and_url(job)

header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"

@@ -602,7 +605,6 @@ class DownloadQueue(DownloadQueueBase):
"""Call when the source is a Path or pathlike object."""
source = Path(job.source).resolve()
destination = Path(job.destination).resolve()
job.metadata = ModelSourceMetadata()
try:
if source != destination:
shutil.move(source, destination)

@@ -53,7 +53,7 @@ import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
from typing import Optional, List, Union, Dict, Set
from typing import Optional, List, Union, Dict, Set, Any
from pydantic import Field
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig

@@ -79,8 +79,8 @@ class ModelInstallJob(DownloadJobBase):
model_key: Optional[str] = Field(
description="After model installation, this field will hold its primary key", default=None
)
probe_info: Optional[ModelProbeInfo] = Field(
description="If provided, information here will be used instead of probing the model.",
probe_override: Optional[Dict[str, Any]] = Field(
description="Keys in this dict will override like-named attributes in the automatic probe info",
default=None,
)

@@ -316,9 +316,12 @@ class ModelInstall(ModelInstallBase):
"""Return the queue."""
return self._download_queue

def register_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:  # noqa D102
def register_path(self,
model_path: Union[Path, str],
overrides: Optional[Dict[str, Any]] = None
) -> str:  # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = info or ModelProbe.probe(model_path)
info: ModelProbeInfo = self._probe_model(model_path, overrides)
return self._register(model_path, info)

def _register(self, model_path: Path, info: ModelProbeInfo) -> str:

@@ -351,12 +354,13 @@ class ModelInstall(ModelInstallBase):
return key

def install_path(
self,
model_path: Union[Path, str],
info: Optional[ModelProbeInfo] = None,
self,
model_path: Union[Path, str],
overrides: Optional[Dict[str, Any]] = None,
) -> str:  # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = info or ModelProbe.probe(model_path)
info: ModelProbeInfo = self._probe_model(model_path, overrides)

dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
dest_path.parent.mkdir(parents=True, exist_ok=True)

@@ -371,6 +375,16 @@ class ModelInstall(ModelInstallBase):
info,
)

def _probe_model(self,
model_path: Union[Path, str],
overrides: Optional[Dict[str, Any]] = None
) -> ModelProbeInfo:
info: ModelProbeInfo = ModelProbe.probe(model_path)
if overrides:  # used to override probe fields
for key, value in overrides.items():
setattr(info, key, value)  # may generate a pydantic validation error
return info

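_probe_model() lets callers pin probe fields that cannot be detected automatically. A sketch; the path is hypothetical, and base_type is one of the ModelProbeInfo fields used elsewhere in this diff:

info = installer._probe_model(
    "models/my-finetune.safetensors",                 # hypothetical path
    overrides={"base_type": BaseModelType.StableDiffusion2},
)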
def unregister(self, key: str):  # noqa D102
self._store.del_model(key)

@@ -380,12 +394,12 @@ class ModelInstall(ModelInstallBase):
self.unregister(key)

def install(
self,
source: Union[str, Path, AnyHttpUrl],
info: Optional[ModelProbeInfo] = None,
inplace: bool = True,
variant: Optional[str] = None,
access_token: Optional[str] = None,
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
variant: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
) -> DownloadJobBase:  # noqa D102
queue = self._download_queue

@@ -395,8 +409,8 @@ class ModelInstall(ModelInstallBase):
if inplace and Path(source).exists()
else self._complete_installation_handler
)
job.probe_override = probe_override
job.add_event_handler(handler)
job.probe_info = info

self._async_installs[source] = None
queue.submit_download_job(job, True)

@@ -405,7 +419,7 @@ class ModelInstall(ModelInstallBase):
def _complete_installation_handler(self, job: DownloadJobBase):
if job.status == "completed":
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
model_id = self.install_path(job.destination, job.probe_info)
model_id = self.install_path(job.destination, job.probe_override)
info = self._store.get_model(model_id)
info.source = str(job.source)
metadata: ModelSourceMetadata = job.metadata

@@ -429,7 +443,7 @@ class ModelInstall(ModelInstallBase):
def _complete_registration_handler(self, job: DownloadJobBase):
if job.status == "completed":
self._logger.info(f"{job.source}: Installing in place.")
model_id = self.register_path(job.destination, job.probe_info)
model_id = self.register_path(job.destination, job.probe_override)
info = self._store.get_model(model_id)
info.source = str(job.source)
info.description = f"Imported model {info.name}"

@@ -14,7 +14,7 @@ from invokeai.backend.util import choose_precision, choose_torch_device, InvokeA
from .config import BaseModelType, ModelType, SubModelType, ModelConfigBase
from .install import ModelInstallBase, ModelInstall
from .storage import ModelConfigStore, get_config_store
from .cache import ModelCache, ModelLocker
from .cache import ModelCache, ModelLocker, CacheStats
from .models import InvalidModelException, ModelBase, MODEL_CLASSES


@@ -69,6 +69,34 @@ class ModelLoaderBase(ABC):
"""Return the ModelInstallBase object that supports this loader."""
pass

@property
@abstractmethod
def logger(self) -> InvokeAILogger:
"""Return the current logger."""
pass


@abstractmethod
def collect_cache_stats(
self,
cache_stats: CacheStats
):
"""Replace cache statistics."""
pass

@property
@abstractmethod
def precision(self) -> str:
"""Return 'float32' or 'float16'."""
pass

@abstractmethod
def sync_to_config(self):
"""
Reinitialize the store to sync in-memory and in-disk
versions.
"""
pass

class ModelLoader(ModelLoaderBase):
"""Implementation of ModelLoaderBase."""

@@ -79,6 +107,7 @@ class ModelLoader(ModelLoaderBase):
_cache: ModelCache
_logger: InvokeAILogger
_cache_keys: dict
_models_file: Path

def __init__(
self,

@@ -102,6 +131,7 @@ class ModelLoader(ModelLoaderBase):
self._logger = InvokeAILogger.getLogger()
self._installer = ModelInstall(store=self._store, logger=self._logger, config=self._app_config)
self._cache_keys = dict()
self._models_file = models_file
device = torch.device(choose_torch_device())
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
precision = choose_precision(device) if config.precision == "auto" else config.precision

@@ -130,11 +160,21 @@ class ModelLoader(ModelLoaderBase):
"""Return the ModelConfigStore instance used by this class."""
return self._store

@property
def precision(self) -> str:
"""Return 'float32' or 'float16'."""
return self._cache.precision

@property
def installer(self) -> ModelInstallBase:
"""Return the ModelInstallBase instance used by this class."""
return self._installer

@property
def logger(self) -> InvokeAILogger:
"""Return the current logger."""
return self._logger

def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo:
"""
Get the ModelInfo corresponding to the model with key "key".

@@ -188,6 +228,12 @@ class ModelLoader(ModelLoaderBase):
_cache=self._cache,
)

def collect_cache_stats(
self,
cache_stats: CacheStats
):
self._cache.stats = cache_stats

def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
"""Get the concrete implementation class for a specific model type."""
model_class = MODEL_CLASSES[base_model][model_type]

@@ -220,6 +266,10 @@ class ModelLoader(ModelLoaderBase):
model_path = self._resolve_model_path(model_path)
return model_path, is_submodel_override

def sync_to_config(self):
self._store = get_config_store(self._models_file)
self._scan_models_directory()

def _scan_models_directory(self):
defunct_models = set()
installed = set()

@@ -124,17 +124,26 @@ class ModelMerger(object):
dump_path = (dump_path / merged_model_name).as_posix()

merged_pipe.save_pretrained(dump_path, safe_serialization=True)
attributes = dict(
path=dump_path,
description=f"Merge of models {', '.join(model_names)}",
model_format="diffusers",
variant=ModelVariantType.Normal.value,
vae=vae,
)
return self.manager.add_model(
merged_model_name,
base_model=base_model,

# register model and get its unique key
info = ModelProbeInfo(
model_type=ModelType.Main,
model_attributes=attributes,
clobber=True,
base_type=base_model,
format="diffusers",
)
key = self.manager.installer.register_path(
model_path=dump_path,
info=info,
)

# update model's config
model_config = self.manager.store.get_model(key)
model_config.update(
dict(
name=merged_model_name,
description=f"Merge of models {', '.join(model_names)}",
vae=vae,
)
)
self.manager.store.update_model(key, model_config)
return model_config

@@ -1,5 +1,5 @@
__metadata__:
version: 3.1.0
version: 3.1.1
ed799245c762f6d0a9ddfd4e31fdb010:
name: sdxl-base-1-0
path: sdxl/main/SDXL base 1_0