mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
b7938d9ca9
* fix(config): fix typing issues in `config/`

`config/invokeai_config.py`:
- use `Optional` for things that are optional
- fix typing of `ram_cache_size()` and `vram_cache_size()`
- remove unused and incorrectly typed method `autoconvert_path`
- fix types and logic for `parse_args()`, in which `InvokeAIAppConfig.initconf` *must* be a `DictConfig`, but function would allow it to be set as a `ListConfig`, which presumably would cause issues elsewhere

`config/base.py`:
- use `cls` for first arg of class methods
- use `Optional` for things that are optional
- fix minor type issue related to setting of `env_prefix`
- remove unused `add_subparser()` method, which calls `add_parser()` on an `ArgumentParser` (method only available on the `_SubParsersAction` object, which is returned from `ArgumentParser.add_subparsers()`)

* feat: queued generation and batches

Due to a very messy branch with broad addition of `isort` on `main` alongside it, some git surgery was needed to get an agreeable git history. This commit represents all of the work on queued generation. See PR for notes.

* chore: flake8, isort, black

* fix(nodes): fix incorrect service stop() method

* fix(nodes): improve names of a few variables

* fix(tests): fix up tests after changes to batches/queue

* feat(tests): add unit tests for session queue helper functions

* feat(ui): dynamic prompts is always enabled

* feat(queue): add queue_status_changed event

* feat(ui): wip queue graphs

* feat(nodes): move cleanup til after invoker startup

* feat(nodes): add cancel_by_batch_ids

* feat(ui): wip batch graphs & UI

* fix(nodes): remove `Batch.batch_id` from required

* fix(ui): cleanup and use fixedCacheKey for all mutations

* fix(ui): remove orphaned nodes from canvas graphs

* fix(nodes): fix cancel_by_batch_ids result count

* fix(ui): only show cancel batch tooltip when batches were canceled

* chore: isort

* fix(api): return `[""]` when dynamic prompts generates no prompts

Just a simple fallback so we always have a prompt.

* feat(ui): dynamicPrompts.combinatorial is always on

There seems to be little purpose in using the combinatorial generation for dynamic prompts. I've disabled it by hiding it from the UI and defaulting combinatorial to true. If we want to enable it again in the future it's straightforward to do so.

* feat: add queue_id & support logic

* feat(ui): fix upscale button

It prepends the upscale operation to the queue.

* feat(nodes): return queue item when enqueuing a single graph

This facilitates one-off graph async workflows in the client.

* feat(ui): move controlnet autoprocess to queue

* fix(ui): fix non-serializable DOMRect in redux state

* feat(ui): QueueTable performance tweaks

* feat(ui): update queue list

Queue items expand to show the full queue item. Just as JSON for now.

* wip threaded session_processor

* feat(nodes,ui): fully migrate queue to session_processor

* feat(nodes,ui): add processor events

* feat(ui): ui tweaks

* feat(nodes,ui): consolidate events, reduce network requests

* feat(ui): cleanup & abstract queue hooks

* feat(nodes): optimize batch permutation

Use a generator to do only as much work as is needed.

Previously, though we only ended up creating exactly as many queue items as was needed, there was still some intermediary work that calculated *all* permutations. When that number was very high, the system had a very hard time and used a lot of memory.

The logic has been refactored to use a generator. Additionally, the batch validators are optimized to return early and use less memory.

* feat(ui): add seed behaviour parameter

This dynamic prompts parameter allows the seed to be randomized per prompt or per iteration:
- Per iteration: Use the same seed for all prompts in a single dynamic prompt expansion
- Per prompt: Use a different seed for every single prompt

"Per iteration" is appropriate for exploring the latent space with a stable starting noise, while "Per prompt" provides more variation.

* fix(ui): remove extraneous random seed nodes from linear graphs

* fix(ui): fix controlnet autoprocess not working when queue is running

* feat(queue): add timestamps to queue status updates

Also show execution time in queue list

* feat(queue): change all execution-related events to use the `queue_id` as the room, also include `queue_item_id` in InvocationQueueItem

This allows for much simpler handling of queue items.

* feat(api): deprecate sessions router

* chore(backend): tidy logging in `dependencies.py`

* fix(backend): respect `use_memory_db`

* feat(backend): add `config.log_sql` (enables sql trace logging)

* feat: add invocation cache

Supersedes #4574

The invocation cache provides simple node memoization functionality. Nodes that use the cache are memoized and not re-executed if their inputs haven't changed. Instead, the stored output is returned.

## Results

This feature provides anywhere from a significant to a massive performance improvement.

The improvement is most marked on large batches of generations where you only change a couple things (e.g. different seed or prompt for each iteration) and on low-VRAM systems, where skipping an extraneous model load is a big deal.

## Overview

A new `invocation_cache` service is added to handle the caching. There's not much to it.

All nodes now inherit a boolean `use_cache` field from `BaseInvocation`. This is a node field and not a class attribute, because specific instances of nodes may want to opt in or out of caching.

The recently-added `invoke_internal()` method on `BaseInvocation` is used as an entrypoint for the cache logic.

To create a cache key, the invocation is first serialized using pydantic's provided `json()` method, skipping the unique `id` field. Then python's very fast builtin `hash()` is used to create an integer key. All implementations of `InvocationCacheBase` must provide a class method `create_key()` which accepts an invocation and outputs a string or integer key.

## In-Memory Implementation

An in-memory implementation is provided. In this implementation, the node outputs are stored in memory as python classes. The in-memory cache does not persist across application restarts.

Max node cache size is added as `node_cache_size` under the `Generation` config category.

It defaults to 512 - this number is up for discussion, but given that these are relatively lightweight pydantic models, I think it's safe to up this even higher.

Note that the cache isn't storing the big stuff - tensors and images are stored on disk, and outputs include only references to them.

## Node Definition

The default for all nodes is to use the cache. The `@invocation` decorator now accepts an optional `use_cache: bool` argument to override the default of `True`.

Non-deterministic nodes, however, should set this to `False`. Currently, all random-stuff nodes, including `dynamic_prompt`, are set to `False`.

The field name `use_cache` is now effectively a reserved field name and possibly a breaking change if any community nodes use this as a field name. In hindsight, all our reserved field names should have been prefixed with underscores or something.

## One Gotcha

Leaf nodes probably want to opt out of the cache, because if they are cached, their outputs are not saved again.

If you run the same graph multiple times, you only end up with a single image output, because the image storage side-effects are in the `invoke()` method, which is bypassed if we have a cache hit.

## Linear UI

The linear graphs _almost_ just work, but due to the gotcha, we need to be careful about the final image-outputting node. To resolve this, a `SaveImageInvocation` node is added and used in the linear graphs.

This node is similar to `ImagePrimitive`, except it saves a copy of its input image, and has `use_cache` set to `False` by default.

This is now the leaf node in all linear graphs, and is the only node in those graphs with `use_cache == False` _and_ the only node with `is_intermediate == False`.

## Workflow Editor

All nodes now have a footer with a new `Use Cache [ ]` checkbox. It defaults to the value set by the invocation in its python definition, but can be changed by the user.

The workflow/node validation logic has been updated to migrate old workflows to use the new default values for `use_cache`. Users may still want to review the settings that have been chosen. In the event of catastrophic failure when running this migration, the default value of `True` is applied, as this is correct for most nodes.

Users should consider saving their workflows after loading them in and having them updated.

## Future Enhancements - Callback

A future enhancement would be to provide a callback to the `use_cache` flag that would be run as the node is executed to determine, based on its own internal state, if the cache should be used or not.

This would be useful for `DynamicPromptInvocation`, where the deterministic behaviour is determined by the `combinatorial: bool` field.

## Future Enhancements - Persisted Cache

Similar to how the latents storage is backed by disk, the invocation cache could be persisted to the database or disk. We'd need to be very careful about deserializing outputs, but it's perhaps worth exploring in the future.

* fix(ui): fix queue list item width

* feat(nodes): do not send the whole node on every generator progress

* feat(ui): strip out old logic related to sessions

Things like `isProcessing` are no longer relevant with queue. Removed them all & updated everything to be appropriate for queue. May be a few little quirks I've missed...

* feat(ui): fix up param collapse labels

* feat(ui): click queue count to go to queue tab

* tidy(queue): update comment, query format

* feat(ui): fix progress bar when canceling

* fix(ui): fix circular dependency

* feat(nodes): bail on node caching logic if `node_cache_size == 0`

* feat(nodes): handle KeyError on node cache pop

* feat(nodes): bypass cache codepath if cache is disabled

more better no do thing

* fix(ui): reset api cache on connect/disconnect

* feat(ui): prevent enqueue when no prompts generated

* feat(ui): add queue controls to workflow editor

* feat(ui): update floating buttons & other incidental UI tweaks

* fix(ui): fix missing/incorrect translation keys

* fix(tests): add config service to mock invocation services

invoking needs access to `node_cache_size` to occur

* optionally remove pause/resume buttons from queue UI

* option to disable prepending

* chore(ui): remove unused file

* feat(queue): remove `order_id` entirely, `item_id` is now an autoinc pk

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
674 lines
24 KiB
Python
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team

from __future__ import annotations

from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union

import torch
from pydantic import Field

from invokeai.app.models.exceptions import CanceledException
from invokeai.backend.model_management import (
    AddModelResult,
    BaseModelType,
    MergeInterpolationMethod,
    ModelInfo,
    ModelManager,
    ModelMerger,
    ModelNotFoundException,
    ModelType,
    SchedulerPredictionType,
    SubModelType,
)
from invokeai.backend.model_management.model_cache import CacheStats
from invokeai.backend.model_management.model_search import FindModels

from ...backend.util import choose_precision, choose_torch_device
from .config import InvokeAIAppConfig

if TYPE_CHECKING:
    from ..invocations.baseinvocation import BaseInvocation, InvocationContext


class ModelManagerServiceBase(ABC):
    """Responsible for managing models on disk and in memory"""

    @abstractmethod
    def __init__(
        self,
        config: InvokeAIAppConfig,
        logger: ModuleType,
    ):
        """
        Initialize the service from the application config and a logger.
        Implementations derive the models.yaml location, torch device,
        precision, RAM cache size and sequential offload setting from
        `config`.
        """
        pass

    @abstractmethod
    def get_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        submodel: Optional[SubModelType] = None,
        node: Optional[BaseInvocation] = None,
        context: Optional[InvocationContext] = None,
    ) -> ModelInfo:
        """Retrieve the indicated model with name and type.
        submodel can be used to get a part (such as the vae)
        of a diffusers pipeline."""
        pass

    @property
    @abstractmethod
    def logger(self):
        pass

    @abstractmethod
    def model_exists(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
    ) -> bool:
        pass

    @abstractmethod
    def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
        """
        Given a model name returns a dict-like (OmegaConf) object describing it.
        Uses the exact format as the omegaconf stanza.
        """
        pass

    @abstractmethod
    def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
        """
        Return a dict of models in the format:
        { model_type1:
            { model_name1: {'status': 'active'|'cached'|'not loaded',
                            'model_name' : name,
                            'model_type' : SDModelType,
                            'description': description,
                            'format': 'folder'|'safetensors'|'ckpt'
                           },
              model_name2: { etc }
            },
          model_type2:
            { model_name_n: etc
            }
        }
        """
        pass

    @abstractmethod
    def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
        """
        Return information about the model using the same format as list_models()
        """
        pass

    @abstractmethod
    def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
        """
        Returns a list of all the model names known.
        """
        pass

    @abstractmethod
    def add_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        model_attributes: dict,
        clobber: bool = False,
    ) -> AddModelResult:
        """
        Add a model with the given name and a dictionary of attributes. Will fail
        with an assertion error if the name already exists. Pass clobber=True to overwrite.
        On a successful update, the config will be changed in memory. Will fail
        with an assertion error if provided attributes are incorrect or
        the model name is missing. Call commit() to write changes to disk.
        """
        pass

    @abstractmethod
    def update_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        model_attributes: dict,
    ) -> AddModelResult:
        """
        Update the named model with a dictionary of attributes. Will fail with a
        ModelNotFoundException if the name does not already exist.

        On a successful update, the config will be changed in memory. Will fail
        with an assertion error if provided attributes are incorrect or
        the model name is missing. Call commit() to write changes to disk.
        """
        pass

    @abstractmethod
    def del_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
    ):
        """
        Delete the named model from the configuration. Depending on the
        implementation, the underlying weight file or diffusers directory
        may be deleted as well. Call commit() to write to disk.
        """
        pass

    @abstractmethod
    def rename_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        new_name: str,
    ):
        """
        Rename the indicated model.
        """
        pass

    @abstractmethod
    def list_checkpoint_configs(self) -> List[Path]:
        """
        List the checkpoint config paths from ROOT/configs/stable-diffusion.
        """
        pass

    @abstractmethod
    def convert_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: Literal[ModelType.Main, ModelType.Vae],
    ) -> AddModelResult:
        """
        Convert a checkpoint file into a diffusers folder, deleting the cached
        version and deleting the original checkpoint file if it is in the models
        directory.
        :param model_name: Name of the model to convert
        :param base_model: Base model type
        :param model_type: Type of model ['vae' or 'main']

        This will raise a ValueError if the model is not a checkpoint. It will
        also raise a ValueError in the event that there is a similarly-named diffusers
        directory already in place.
        """
        pass

    @abstractmethod
    def heuristic_import(
        self,
        items_to_import: set[str],
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> dict[str, AddModelResult]:
        """Import a set of paths, repo_ids or URLs.
        :param items_to_import: Set of strings corresponding to models to be imported.
        :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.

        The prediction type helper is necessary to distinguish between
        models based on Stable Diffusion 2 Base (requiring
        SchedulerPredictionType.Epsilon) and Stable Diffusion 2 768
        (requiring SchedulerPredictionType.VPrediction). It is
        generally impossible to do this programmatically, so the
        prediction_type_helper usually asks the user to choose.

        The returned dict maps each successfully installed item to the
        AddModelResult describing the newly-created model configuration.
        """
        pass

    @abstractmethod
    def merge_models(
        self,
        model_names: List[str] = Field(
            default=None, min_items=2, max_items=3, description="List of model names to merge"
        ),
        base_model: Union[BaseModelType, str] = Field(
            default=None, description="Base model shared by all models to be merged"
        ),
        merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
        alpha: Optional[float] = 0.5,
        interp: Optional[MergeInterpolationMethod] = None,
        force: Optional[bool] = False,
        merge_dest_directory: Optional[Path] = None,
    ) -> AddModelResult:
        """
        Merge two to three diffusers pipeline models and save as a new model.
        :param model_names: List of 2-3 models to merge
        :param base_model: Base model to use for all models
        :param merged_model_name: Name of destination merged model
        :param alpha: Alpha strength to apply to the second and third models
        :param interp: Interpolation method. None (default)
        :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
        """
        pass

    @abstractmethod
    def search_for_models(self, directory: Path) -> List[Path]:
        """
        Return list of all models found in the designated directory.
        """
        pass

    @abstractmethod
    def sync_to_config(self):
        """
        Re-read models.yaml, rescan the models directory, and reimport models
        in the autoimport directories. Call after making changes outside the
        model manager API.
        """
        pass

    @abstractmethod
    def collect_cache_stats(self, cache_stats: CacheStats):
        """
        Record subsequent model cache statistics in the provided CacheStats
        object (for example, one created per graph execution).
        """
        pass

    @abstractmethod
    def commit(self, conf_file: Optional[Path] = None) -> None:
        """
        Write current configuration out to the indicated file.
        If no conf_file is provided, then replaces the
        original file/database used to initialize the object.
        """
        pass


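# Sketch (hypothetical; not part of the model manager API): one possible
# `prediction_type_helper` callback for heuristic_import(). As the docstring
# above explains, an SD-2 checkpoint's prediction type usually cannot be
# detected programmatically, so this helper simply asks the user. The name
# `_ask_prediction_type` and its `ask` parameter are illustrative assumptions.
# To stay runnable without InvokeAI it returns the strings "epsilon" and
# "v_prediction", which are assumed to match the SchedulerPredictionType
# values; a real helper should return SchedulerPredictionType members.
def _ask_prediction_type(checkpoint, ask=input) -> str:
    """Prompt via `ask` until the user picks 1 (SD-2 base) or 2 (SD-2 768)."""
    choice = ""
    while choice not in ("1", "2"):
        choice = ask(
            f"{checkpoint}: is this [1] an SD-2 base model (epsilon) "
            "or [2] an SD-2 768 model (v-prediction)? "
        ).strip()
    return "epsilon" if choice == "1" else "v_prediction"


# Possible usage, assuming the enum values above (the path is a placeholder):
#   helper = lambda path: SchedulerPredictionType(_ask_prediction_type(path))
#   results = model_manager_service.heuristic_import({"/path/to/sd2.ckpt"}, helper)
# Re-prompting on invalid input keeps the callback total: it always returns
# one of the two prediction types rather than raising mid-import.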
# simple implementation
class ModelManagerService(ModelManagerServiceBase):
    """Responsible for managing models on disk and in memory"""

    def __init__(
        self,
        config: InvokeAIAppConfig,
        logger: Logger,
    ):
        """
        Initialize the service from the application config. The models.yaml
        path, torch device, precision, RAM cache size and sequential offload
        setting are all derived from `config`. If `config.precision` is
        "auto", the precision is chosen based on the detected device.
        """
        if config.model_conf_path and config.model_conf_path.exists():
            config_file = config.model_conf_path
        else:
            config_file = config.root_dir / "configs/models.yaml"

        logger.debug(f"Config file={config_file}")

        device = torch.device(choose_torch_device())
        device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
        logger.info(f"GPU device = {device} {device_name}")

        precision = config.precision
        if precision == "auto":
            precision = choose_precision(device)
        dtype = torch.float32 if precision == "float32" else torch.float16

        # The RAM cache size is read directly from the `ram_cache_size`
        # config setting. Backward compatibility with the deprecated
        # `max_loaded_models` setting (which sized the cache at 2.5 GB
        # per model) is not handled in this method.
        max_cache_size = config.ram_cache_size

        logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")

        sequential_offload = config.sequential_guidance

        self.mgr = ModelManager(
            config=config_file,
            device_type=device,
            precision=dtype,
            max_cache_size=max_cache_size,
            sequential_offload=sequential_offload,
            logger=logger,
        )
        logger.info("Model manager service initialized")

    def get_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        submodel: Optional[SubModelType] = None,
        context: Optional[InvocationContext] = None,
    ) -> ModelInfo:
        """
        Retrieve the indicated model. submodel can be used to get a
        part (such as the vae) of a diffusers model. If an invocation
        context is provided, model load started/completed events are
        emitted around the load.
        """

        # we can emit model loading events if we are executing with access to the invocation context
        if context:
            self._emit_load_event(
                context=context,
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
                submodel=submodel,
            )

        model_info = self.mgr.get_model(
            model_name,
            base_model,
            model_type,
            submodel,
        )

        if context:
            self._emit_load_event(
                context=context,
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
                submodel=submodel,
                model_info=model_info,
            )

        return model_info

    def model_exists(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
    ) -> bool:
        """
        Given a model name, returns True if it is a valid
        identifier.
        """
        return self.mgr.model_exists(
            model_name,
            base_model,
            model_type,
        )

    def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
        """
        Given a model name returns a dict-like (OmegaConf) object describing it.
        """
        return self.mgr.model_info(model_name, base_model, model_type)

    def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
        """
        Returns a list of all the model names known.
        """
        return self.mgr.model_names()

    def list_models(
        self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
    ) -> list[dict]:
        """
        Return a list of models.
        """
        return self.mgr.list_models(base_model, model_type)

    def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
        """
        Return information about the model using the same format as list_models()
        """
        return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)

    def add_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        model_attributes: dict,
        clobber: bool = False,
    ) -> AddModelResult:
        """
        Add a model with the given name and a dictionary of attributes. Will fail
        with an assertion error if the name already exists. Pass clobber=True to overwrite.
        On a successful update, the config will be changed in memory. Will fail
        with an assertion error if provided attributes are incorrect or
        the model name is missing. Call commit() to write changes to disk.
        """
        self.logger.debug(f"add/update model {model_name}")
        return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)

    def update_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        model_attributes: dict,
    ) -> AddModelResult:
        """
        Update the named model with a dictionary of attributes. Will fail with a
        ModelNotFoundException if the name does not already exist.
        On a successful update, the config will be changed in memory. Will fail
        with an assertion error if provided attributes are incorrect or
        the model name is missing. Call commit() to write changes to disk.
        """
        self.logger.debug(f"update model {model_name}")
        if not self.model_exists(model_name, base_model, model_type):
            raise ModelNotFoundException(f"Unknown model {model_name}")
        return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)

    def del_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
    ):
        """
        Delete the named model from the configuration and commit the change
        to disk immediately. The underlying ModelManager may also delete the
        model's weight file or diffusers directory.
        """
        self.logger.debug(f"delete model {model_name}")
        self.mgr.del_model(model_name, base_model, model_type)
        self.mgr.commit()

    def convert_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: Literal[ModelType.Main, ModelType.Vae],
        convert_dest_directory: Optional[Path] = None,
    ) -> AddModelResult:
        """
        Convert a checkpoint file into a diffusers folder, deleting the cached
        version and deleting the original checkpoint file if it is in the models
        directory.
        :param model_name: Name of the model to convert
        :param base_model: Base model type
        :param model_type: Type of model ['vae' or 'main']
        :param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)

        This will raise a ValueError if the model is not a checkpoint. It will
        also raise a ValueError in the event that there is a similarly-named diffusers
        directory already in place.
        """
        self.logger.debug(f"convert model {model_name}")
        return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)

    def collect_cache_stats(self, cache_stats: CacheStats):
        """
        Attach the provided CacheStats object to the model cache so that
        subsequent cache activity is recorded in it.
        """
        self.mgr.cache.stats = cache_stats

    def commit(self, conf_file: Optional[Path] = None):
        """
        Write current configuration out to the indicated file.
        If no conf_file is provided, then replaces the
        original file/database used to initialize the object.
        """
        return self.mgr.commit(conf_file)

    def _emit_load_event(
        self,
        context: InvocationContext,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        submodel: Optional[SubModelType] = None,
        model_info: Optional[ModelInfo] = None,
    ):
        if context.services.queue.is_canceled(context.graph_execution_state_id):
            raise CanceledException()

        if model_info:
            context.services.events.emit_model_load_completed(
                queue_id=context.queue_id,
                queue_item_id=context.queue_item_id,
                graph_execution_state_id=context.graph_execution_state_id,
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
                submodel=submodel,
                model_info=model_info,
            )
        else:
            context.services.events.emit_model_load_started(
                queue_id=context.queue_id,
                queue_item_id=context.queue_item_id,
                graph_execution_state_id=context.graph_execution_state_id,
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
                submodel=submodel,
            )

    @property
    def logger(self):
        return self.mgr.logger

    def heuristic_import(
        self,
        items_to_import: set[str],
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> dict[str, AddModelResult]:
        """Import a set of paths, repo_ids or URLs.
        :param items_to_import: Set of strings corresponding to models to be imported.
        :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.

        The prediction type helper is necessary to distinguish between
        models based on Stable Diffusion 2 Base (requiring
        SchedulerPredictionType.Epsilon) and Stable Diffusion 2 768
        (requiring SchedulerPredictionType.VPrediction). It is
        generally impossible to do this programmatically, so the
        prediction_type_helper usually asks the user to choose.

        The returned dict maps each successfully installed item to the
        AddModelResult describing the newly-created model configuration.
        """
        return self.mgr.heuristic_import(items_to_import, prediction_type_helper)

    def merge_models(
        self,
        model_names: List[str] = Field(
            default=None, min_items=2, max_items=3, description="List of model names to merge"
        ),
        base_model: Union[BaseModelType, str] = Field(
            default=None, description="Base model shared by all models to be merged"
        ),
        merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
        alpha: float = 0.5,
        interp: Optional[MergeInterpolationMethod] = None,
        force: bool = False,
        merge_dest_directory: Optional[Path] = None,
    ) -> AddModelResult:
        """
        Merge two to three diffusers pipeline models and save as a new model.
        :param model_names: List of 2-3 models to merge
        :param base_model: Base model to use for all models
        :param merged_model_name: Name of destination merged model
        :param alpha: Alpha strength to apply to the second and third models
        :param interp: Interpolation method. None (default)
        :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
        """
        merger = ModelMerger(self.mgr)
        try:
            result = merger.merge_diffusion_models_and_save(
                model_names=model_names,
                base_model=base_model,
                merged_model_name=merged_model_name,
                alpha=alpha,
                interp=interp,
                force=force,
                merge_dest_directory=merge_dest_directory,
            )
        except AssertionError as e:
            raise ValueError(e) from e
        return result

    def search_for_models(self, directory: Path) -> List[Path]:
        """
        Return list of all models found in the designated directory.
        """
        search = FindModels([directory], self.logger)
        return search.list_models()

    def sync_to_config(self):
        """
        Re-read models.yaml, rescan the models directory, and reimport models
        in the autoimport directories. Call after making changes outside the
        model manager API.
        """
        return self.mgr.sync_to_config()

    def list_checkpoint_configs(self) -> List[Path]:
        """
        List the checkpoint config paths from ROOT/configs/stable-diffusion.
        """
        config = self.mgr.app_config
        conf_path = config.legacy_conf_path
        root_path = config.root_path
        return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]

    def rename_model(
        self,
        model_name: str,
        base_model: BaseModelType,
        model_type: ModelType,
        new_name: Optional[str] = None,
        new_base: Optional[BaseModelType] = None,
    ):
        """
        Rename the indicated model. Can provide a new name and/or a new base.
        :param model_name: Current name of the model
        :param base_model: Current base of the model
        :param model_type: Model type (can't be changed)
        :param new_name: New name for the model
        :param new_base: New base for the model
        """
        self.mgr.rename_model(
            base_model=base_model,
            model_type=model_type,
            model_name=model_name,
            new_name=new_name,
            new_base=new_base,
        )
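

# Sketch (hypothetical; not part of the service API): the event bracketing that
# get_model() performs through _emit_load_event(), reduced to plain callables so
# it runs without InvokeAI. `is_canceled`, `emit` and `load` stand in for
# context.services.queue.is_canceled(...), the events.emit_model_load_* calls
# and self.mgr.get_model(...); `_SketchCanceled` stands in for CanceledException.
# All of these names are illustrative assumptions.
class _SketchCanceled(Exception):
    pass


def _load_with_events(load, emit, is_canceled):
    """Check for cancellation, emit "started", load, re-check, emit "completed"."""
    if is_canceled():
        raise _SketchCanceled()
    emit("model_load_started", None)
    model_info = load()
    # _emit_load_event() checks cancellation on every call, so a cancel that
    # arrives during a long load is noticed as soon as the load returns.
    if is_canceled():
        raise _SketchCanceled()
    emit("model_load_completed", model_info)
    return model_info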