2023-05-05 23:32:28 +00:00
|
|
|
"""This module manages the InvokeAI `models.yaml` file, mapping
|
2023-06-25 20:04:43 +00:00
|
|
|
symbolic diffusers model names to the paths and repo_ids used by the
|
|
|
|
underlying `from_pretrained()` call.
|
2023-05-08 03:18:17 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
SYNOPSIS:
|
2023-05-06 19:58:44 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
mgr = ModelManager('/home/phi/invokeai/configs/models.yaml')
|
|
|
|
sd1_5 = mgr.get_model('stable-diffusion-v1-5',
|
|
|
|
model_type=ModelType.Main,
|
|
|
|
base_model=BaseModelType.StableDiffusion1,
|
|
|
|
submodel_type=SubModelType.UNet)
|
|
|
|
with sd1_5 as unet:
|
|
|
|
run_some_inference(unet)
|
2023-05-06 19:58:44 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
FETCHING MODELS:
|
2023-05-06 19:58:44 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
Models are described using four attributes:
|
2023-05-13 18:44:44 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
1) model_name -- the symbolic name for the model
|
2023-05-06 19:58:44 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
2) ModelType -- an enum describing the type of the model. Currently
|
|
|
|
defined types are:
|
|
|
|
ModelType.Main -- a full model capable of generating images
|
|
|
|
ModelType.Vae -- a VAE model
|
|
|
|
ModelType.Lora -- a LoRA or LyCORIS fine-tune
|
|
|
|
ModelType.TextualInversion -- a textual inversion embedding
|
|
|
|
ModelType.ControlNet -- a ControlNet model
|
2023-05-06 19:58:44 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
3) BaseModelType -- an enum indicating the stable diffusion base model, one of:
|
|
|
|
BaseModelType.StableDiffusion1
|
|
|
|
BaseModelType.StableDiffusion2
|
2023-05-06 19:58:44 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
4) SubModelType (optional) -- an enum that refers to one of the submodels contained
|
|
|
|
within the main model. Values are:
|
2023-05-06 19:58:44 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
SubModelType.UNet
|
|
|
|
SubModelType.TextEncoder
|
|
|
|
SubModelType.Tokenizer
|
|
|
|
SubModelType.Scheduler
|
|
|
|
SubModelType.SafetyChecker
SubModelType.Vae
|
2023-05-06 19:58:44 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
To fetch a model, use `manager.get_model()`. This takes the symbolic
|
|
|
|
name of the model, the ModelType, the BaseModelType and the
|
|
|
|
SubModelType. The latter is required for ModelType.Main.
|
2023-05-06 19:58:44 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
get_model() will return a ModelInfo object that can then be used in
|
|
|
|
context to retrieve the model and move it into GPU VRAM (on GPU
|
|
|
|
systems).
|
|
|
|
|
|
|
|
A typical example is:
|
|
|
|
|
|
|
|
sd1_5 = mgr.get_model('stable-diffusion-v1-5',
|
|
|
|
model_type=ModelType.Main,
|
|
|
|
base_model=BaseModelType.StableDiffusion1,
|
2023-07-01 18:32:58 +00:00
|
|
|
submodel_type=SubModelType.UNet)
|
2023-06-25 20:04:43 +00:00
|
|
|
with sd1_5 as unet:
|
|
|
|
run_some_inference(unet)
|
2023-05-08 03:18:17 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
The ModelInfo object provides a number of useful fields describing the
|
|
|
|
model, including:
|
2023-05-08 03:18:17 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
name -- symbolic name of the model
|
|
|
|
base_model -- base model (BaseModelType)
|
|
|
|
type -- model type (ModelType)
|
|
|
|
location -- path to the model file
|
|
|
|
precision -- torch precision of the model
|
|
|
|
hash -- unique sha256 checksum for this model
|
2023-05-08 03:18:17 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
SUBMODELS:
|
2023-06-09 03:11:53 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
When fetching a main model, you must specify the submodel. Retrieval
|
|
|
|
of full pipelines is not supported.
|
|
|
|
|
|
|
|
vae_info = mgr.get_model('stable-diffusion-1.5',
|
|
|
|
model_type = ModelType.Main,
|
|
|
|
base_model = BaseModelType.StableDiffusion1,
|
|
|
|
submodel_type = SubModelType.Vae
|
|
|
|
)
|
|
|
|
with vae_info as vae:
|
|
|
|
do_something(vae)
|
|
|
|
|
|
|
|
This rule does not apply to controlnets, embeddings, loras and standalone
|
|
|
|
VAEs, which do not have submodels.
|
|
|
|
|
|
|
|
LISTING MODELS
|
|
|
|
|
|
|
|
The model_names() method will return a list of Tuples describing each
|
|
|
|
model it knows about:
|
|
|
|
|
|
|
|
>>> mgr.model_names()
|
|
|
|
[
|
|
|
|
('stable-diffusion-1.5', <BaseModelType.StableDiffusion1: 'sd-1'>, <ModelType.Main: 'main'>),
|
|
|
|
('stable-diffusion-2.1', <BaseModelType.StableDiffusion2: 'sd-2'>, <ModelType.Main: 'main'>),
|
|
|
|
('inpaint', <BaseModelType.StableDiffusion1: 'sd-1'>, <ModelType.ControlNet: 'controlnet'>),
|
|
|
|
('Ink scenery', <BaseModelType.StableDiffusion1: 'sd-1'>, <ModelType.Lora: 'lora'>),
|
|
|
|
...
|
|
|
|
]
|
|
|
|
|
|
|
|
The tuple is in the correct order to pass to get_model():
|
|
|
|
|
|
|
|
for m in mgr.model_names():
|
|
|
|
info = mgr.get_model(*m)
|
|
|
|
|
|
|
|
In contrast, the list_models() method returns a list of dicts, each
|
|
|
|
providing information about a model defined in models.yaml. For example:
|
|
|
|
|
|
|
|
>>> models = mgr.list_models()
|
|
|
|
>>> json.dumps(models[0])
|
2023-07-18 20:33:19 +00:00
|
|
|
{"path": "/home/lstein/invokeai-main/models/sd-1/controlnet/canny",
|
|
|
|
"model_format": "diffusers",
|
|
|
|
"model_name": "canny",
|
|
|
|
"base_model": "sd-1",
|
2023-06-25 20:04:43 +00:00
|
|
|
"model_type": "controlnet"
|
|
|
|
}
|
|
|
|
|
|
|
|
You can filter by model type and base model as shown here:
|
|
|
|
|
2023-07-18 20:33:19 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
controlnets = mgr.list_models(model_type=ModelType.ControlNet,
|
|
|
|
base_model=BaseModelType.StableDiffusion1)
|
|
|
|
for c in controlnets:
|
|
|
|
name = c['model_name']
|
|
|
|
format = c['model_format']
|
|
|
|
path = c['path']
|
|
|
|
type = c['model_type']
|
|
|
|
# etc
|
|
|
|
|
|
|
|
ADDING AND REMOVING MODELS
|
|
|
|
|
|
|
|
At startup time, the `models` directory will be scanned for
|
|
|
|
checkpoints, diffusers pipelines, controlnets, LoRAs and TI
|
|
|
|
embeddings. New entries will be added to the model manager and defunct
|
|
|
|
ones removed. Anything that is a main model (ModelType.Main) will be
|
|
|
|
added to models.yaml. For scanning to succeed, files need to be in
|
|
|
|
their proper places. For example, a controlnet folder built on the
|
|
|
|
stable diffusion 2 base, will need to be placed in
|
|
|
|
`models/sd-2/controlnet`.
|
|
|
|
|
|
|
|
Layout of the `models` directory:
|
|
|
|
|
|
|
|
models
|
|
|
|
├── sd-1
|
2023-07-18 20:33:19 +00:00
|
|
|
│ ├── controlnet
|
|
|
|
│ ├── lora
|
|
|
|
│ ├── main
|
|
|
|
│ └── embedding
|
2023-06-25 20:04:43 +00:00
|
|
|
├── sd-2
|
2023-07-18 20:33:19 +00:00
|
|
|
│ ├── controlnet
|
|
|
|
│ ├── lora
|
|
|
|
│ ├── main
|
2023-06-25 20:04:43 +00:00
|
|
|
│ └── embedding
|
|
|
|
└── core
|
|
|
|
├── face_reconstruction
|
|
|
|
│ ├── codeformer
|
|
|
|
│ └── gfpgan
|
|
|
|
├── sd-conversion
|
|
|
|
│ ├── clip-vit-large-patch14 - tokenizer, text_encoder subdirs
|
|
|
|
│ ├── stable-diffusion-2 - tokenizer, text_encoder subdirs
|
|
|
|
│ └── stable-diffusion-safety-checker
|
|
|
|
└── upscaling
|
|
|
|
└─── esrgan
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Loras, textual_inversion and controlnet models are not listed
|
2023-06-09 03:11:53 +00:00
|
|
|
explicitly in models.yaml, but are added to the in-memory data
|
|
|
|
structure at initialization time by scanning the models directory. The
|
|
|
|
in-memory data structure can be resynchronized by calling
|
2023-06-25 20:04:43 +00:00
|
|
|
`manager.scan_models_directory()`.
|
2023-05-08 03:18:17 +00:00
|
|
|
|
2023-06-27 16:30:53 +00:00
|
|
|
Files and folders placed inside the `autoimport` paths (paths
|
|
|
|
defined in `invokeai.yaml`) will also be scanned for new models at
|
|
|
|
initialization time and added to `models.yaml`. Files will not be
|
|
|
|
moved from this location but preserved in-place. These directories
|
|
|
|
are:
|
|
|
|
|
|
|
|
configuration default description
|
|
|
|
------------- ------- -----------
|
|
|
|
autoimport_dir autoimport/main main models
|
|
|
|
lora_dir autoimport/lora LoRA/LyCORIS models
|
|
|
|
embedding_dir autoimport/embedding TI embeddings
|
|
|
|
controlnet_dir autoimport/controlnet ControlNet models
|
|
|
|
|
|
|
|
In actuality, models located in any of these directories are scanned
|
|
|
|
to determine their type, so it isn't strictly necessary to organize
|
|
|
|
the different types in this way. This entry in `invokeai.yaml` will
|
|
|
|
recursively scan all subdirectories within `autoimport`, scan model
|
|
|
|
files it finds, and import them if recognized.
|
|
|
|
|
|
|
|
Paths:
|
|
|
|
autoimport_dir: autoimport
|
2023-05-08 03:18:17 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
A model can be manually added using `add_model()` using the model's
|
|
|
|
name, base model, type and a dict of model attributes. See
|
|
|
|
`invokeai/backend/model_management/models` for the attributes required
|
|
|
|
by each model type.
|
2023-05-08 03:18:17 +00:00
|
|
|
|
2023-07-18 20:33:19 +00:00
|
|
|
A model can be deleted using `del_model()`, providing the same
|
2023-06-25 20:04:43 +00:00
|
|
|
identifying information as `get_model()`.
|
2023-05-08 03:18:17 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
The `heuristic_import()` method will take a set of strings
|
|
|
|
corresponding to local paths, remote URLs, and repo_ids, probe the
|
|
|
|
object to determine what type of model it is (if any), and import new
|
|
|
|
models into the manager. If passed a directory, it will recursively
|
|
|
|
scan it for models to import. The return value is a set of the models
|
|
|
|
successfully added.
|
2023-05-08 03:18:17 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
MODELS.YAML
|
2023-05-08 03:18:17 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
The general format of a models.yaml section is:
|
2023-05-08 03:18:17 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
base-model/type-of-model/name-of-model:
|
|
|
|
path: /path/to/local/file/or/directory
|
|
|
|
description: a description
|
|
|
|
format: diffusers|checkpoint
|
|
|
|
variant: normal|inpaint|depth
|
2023-06-09 03:11:53 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
The type of model is given in the stanza key, and is one of
|
|
|
|
{main, vae, lora, controlnet, textual}
|
2023-06-09 03:11:53 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
The format indicates whether the model is organized as a diffusers
|
|
|
|
folder with model subdirectories, or is contained in a single
|
|
|
|
checkpoint or safetensors file.
|
2023-06-09 03:11:53 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
The path points to a file or directory on disk. If a relative path,
|
|
|
|
the root is the InvokeAI ROOTDIR.
|
2023-06-27 16:30:53 +00:00
|
|
|
|
2023-02-28 05:31:15 +00:00
|
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
|
2023-06-13 15:05:12 +00:00
|
|
|
import hashlib
|
2023-07-29 18:50:04 +00:00
|
|
|
import os
|
2023-02-28 05:31:15 +00:00
|
|
|
import textwrap
|
2023-07-29 18:50:04 +00:00
|
|
|
import types
|
2023-05-05 23:32:28 +00:00
|
|
|
from dataclasses import dataclass
|
2023-02-28 05:31:15 +00:00
|
|
|
from pathlib import Path
|
2023-07-05 13:05:05 +00:00
|
|
|
from shutil import rmtree, move
|
2023-08-06 05:02:28 +00:00
|
|
|
from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable
|
2023-02-28 05:31:15 +00:00
|
|
|
|
|
|
|
import torch
|
2023-07-29 18:50:04 +00:00
|
|
|
import yaml
|
2023-02-28 05:31:15 +00:00
|
|
|
from omegaconf import OmegaConf
|
|
|
|
from omegaconf.dictconfig import DictConfig
|
2023-07-03 23:32:54 +00:00
|
|
|
from pydantic import BaseModel, Field
|
2023-06-11 03:12:21 +00:00
|
|
|
|
2023-05-13 18:44:44 +00:00
|
|
|
import invokeai.backend.util.logging as logger
|
2023-05-26 00:41:26 +00:00
|
|
|
from invokeai.app.services.config import InvokeAIAppConfig
|
2023-06-25 22:50:15 +00:00
|
|
|
from invokeai.backend.util import CUDA_DEVICE, Chdir
|
2023-06-10 14:41:48 +00:00
|
|
|
from .model_cache import ModelCache, ModelLocker
|
2023-07-14 15:14:33 +00:00
|
|
|
from .model_search import ModelSearch
|
2023-06-23 20:35:39 +00:00
|
|
|
from .models import (
|
|
|
|
BaseModelType,
|
|
|
|
ModelType,
|
|
|
|
SubModelType,
|
|
|
|
ModelError,
|
|
|
|
SchedulerPredictionType,
|
|
|
|
MODEL_CLASSES,
|
2023-07-21 18:14:33 +00:00
|
|
|
ModelConfigBase,
|
|
|
|
ModelNotFoundException,
|
|
|
|
InvalidModelException,
|
|
|
|
DuplicateModelException,
|
2023-07-30 03:02:31 +00:00
|
|
|
ModelBase,
|
2023-07-08 01:09:10 +00:00
|
|
|
)
|
2023-06-11 16:51:50 +00:00
|
|
|
|
2023-05-13 18:44:44 +00:00
|
|
|
# Schema version written into the "__metadata__" stanza of models.yaml.
# Numbering began with InvokeAI release 3; it does not have to follow the
# application's release number, but starting there avoids confusion.
CONFIG_FILE_VERSION = "3.0.0"
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-05-05 23:32:28 +00:00
|
|
|
|
|
|
|
@dataclass
class ModelInfo:
    """Handle returned by ``ModelManager.get_model()``.

    Describes one (sub)model and doubles as a context manager: entering it
    delegates to the wrapped ``ModelLocker`` and yields whatever that locker's
    ``__enter__`` returns.
    """

    context: ModelLocker  # locker that the with-statement delegates to
    name: str  # symbolic model name
    base_model: BaseModelType
    type: ModelType  # the requested SubModelType if one was used, else the ModelType
    hash: str
    location: Union[Path, str]  # path of the (possibly converted) model on disk
    precision: torch.dtype
    _cache: Optional[ModelCache] = None

    def __enter__(self):
        return self.context.__enter__()

    def __exit__(self, *args, **kwargs):
        # Forward the locker's return value: a truthy result from the wrapped
        # context manager means "suppress the exception", and dropping it here
        # would silently override that decision.
        return self.context.__exit__(*args, **kwargs)
|
2023-02-28 05:31:15 +00:00
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-07-03 23:32:54 +00:00
|
|
|
# Outcome of registering a model: the name it ended up with, its type and
# base model, and the configuration object that was stored for it.
# (A '#' comment rather than a docstring, so the generated schema is unchanged.)
class AddModelResult(BaseModel):
    name: str = Field(..., description="The name of the model after installation")
    model_type: ModelType = Field(..., description="The type of model")
    base_model: BaseModelType = Field(..., description="The base model")
    config: ModelConfigBase = Field(..., description="The configuration of the model")
|
2023-02-28 05:31:15 +00:00
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-07-03 23:32:54 +00:00
|
|
|
MAX_CACHE_SIZE = 6.0  # default model cache size passed to ModelCache, in GB
|
2023-06-09 03:11:53 +00:00
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-06-11 01:49:09 +00:00
|
|
|
# Parsed contents of the "__metadata__" stanza at the top of models.yaml.
class ConfigMeta(BaseModel):
    version: str = Field(...)  # schema version string, e.g. CONFIG_FILE_VERSION
|
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-02-28 05:31:15 +00:00
|
|
|
class ModelManager(object):
|
2023-04-05 21:25:42 +00:00
|
|
|
"""
|
2023-05-05 23:32:28 +00:00
|
|
|
High-level interface to model management.
|
2023-04-05 21:25:42 +00:00
|
|
|
"""
|
|
|
|
|
2023-04-29 14:48:50 +00:00
|
|
|
logger: types.ModuleType = logger
|
|
|
|
|
2023-02-28 05:31:15 +00:00
|
|
|
def __init__(
    self,
    config: Union[Path, DictConfig, str],
    device_type: torch.device = CUDA_DEVICE,
    precision: torch.dtype = torch.float16,
    max_cache_size=MAX_CACHE_SIZE,
    sequential_offload=False,
    logger: types.ModuleType = logger,
):
    """Create a manager backed by a models.yaml file or an in-memory config.

    :param config: path (str or Path) to models.yaml, or an already-parsed
        DictConfig. A path that does not exist yet is initialized as a new,
        empty models.yaml before being loaded.
    :param device_type: torch device models execute on (defaults to CUDA).
    :param precision: torch dtype models are loaded at (defaults to float16).
    :param max_cache_size: cache size limit handed to ModelCache (GB).
    :param sequential_offload: passed through to ModelCache.
    :param logger: logger module used for warnings and handed to the cache.
    :raises ValueError: if ``config`` is not a str, Path or DictConfig.

    NOTE: the "__metadata__" stanza is popped from ``config`` (this mutates a
    caller-supplied DictConfig); a config lacking it currently raises.
    """
    self.config_path = None
    if not isinstance(config, (str, Path, DictConfig)):
        raise ValueError("config argument must be an OmegaConf object, a Path or a string")

    if not isinstance(config, DictConfig):
        # A filesystem location: create the file on first use, then parse it.
        self.config_path = Path(config)
        if not self.config_path.exists():
            logger.warning(f"The file {self.config_path} was not found. Initializing a new file")
            self.initialize_model_config(self.config_path)
        config = OmegaConf.load(self.config_path)

    # TODO: metadata not found
    # TODO: version check
    metadata = config.pop("__metadata__")
    self.config_meta = ConfigMeta(**metadata)

    self.app_config = InvokeAIAppConfig.get_config()
    self.logger = logger
    self.cache = ModelCache(
        max_cache_size=max_cache_size,
        max_vram_cache_size=self.app_config.max_vram_cache_size,
        execution_device=device_type,
        precision=precision,
        sequential_offload=sequential_offload,
        logger=logger,
    )

    self._read_models(config)
|
|
|
|
|
|
|
|
def _read_models(self, config: Optional[DictConfig] = None):
    """(Re)build the in-memory model table from a parsed models.yaml.

    :param config: parsed models.yaml contents. When ``None``, the file at
        ``self.config_path`` is re-read; if there is no config path the call
        is a no-op and the current table is kept. An *empty* config is a
        valid (model-less) configuration and is used as-is.
    """
    # Only a missing config triggers a reload. An empty mapping (e.g. a fresh
    # models.yaml whose "__metadata__" was already popped) must still
    # initialize self.models / self.cache_keys.
    if config is None:
        if not self.config_path:
            return
        config = OmegaConf.load(self.config_path)

    self.models = dict()
    for model_key, model_config in config.items():
        # private/metadata stanzas such as "__metadata__" are not models
        if model_key.startswith("_"):
            continue
        _, base_model, model_type = self.parse_key(model_key)
        model_class = self._get_implementation(base_model, model_type)
        # models.yaml spells it "format"; the config classes expect "model_format"
        model_config["model_format"] = model_config.pop("format")
        self.models[model_key] = model_class.create_config(**model_config)

    # TODO: check config version number and update on disk/RAM if necessary
    self.cache_keys = dict()

    # add controlnet, lora and textual_inversion models from disk
    self.scan_models_directory()
|
2023-06-09 03:11:53 +00:00
|
|
|
|
2023-07-14 17:45:16 +00:00
|
|
|
def sync_to_config(self):
    """Re-read ``models.yaml`` after it has been modified outside this process.

    Rebuilds the in-memory model table from ``self.config_path`` (including a
    fresh models-directory scan) and resets ``self.cache_keys``. If the manager
    was built from an in-memory DictConfig (no config path), nothing happens.
    NOTE(review): the previous comment claimed this also releases otherwise
    unreferenced models from memory; _read_models does not touch self.cache
    directly, so confirm before relying on that.
    """
    self._read_models()
|
|
|
|
|
2023-07-30 18:53:12 +00:00
|
|
|
def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
    """Report whether a model is registered under the given identifiers.

    :param model_name: symbolic name of the model in models.yaml
    :param base_model: BaseModelType enum indicating the base model used by this model
    :param model_type: ModelType enum indicating the type of model
    :param rescan: if True and the model is not yet known, rescan the models
        directory for this base model/type and check once more
    """
    if self.create_key(model_name, base_model, model_type) in self.models:
        return True
    if not rescan:
        return False
    # Not registered yet -- its files may have just been copied into place.
    self.scan_models_directory(base_model=base_model, model_type=model_type)
    return self.model_exists(model_name, base_model, model_type, rescan=False)
|
2023-05-12 20:13:34 +00:00
|
|
|
|
2023-06-11 20:10:15 +00:00
|
|
|
@classmethod
def create_key(
    cls,
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
) -> str:
    """Build the ``base_model/model_type/model_name`` key that indexes ``self.models``.

    The enum arguments are coerced and their ``.value`` used explicitly,
    because f-string interpolation of (str, Enum) members changed in
    Python 3.11.
    """
    base_value = BaseModelType(base_model).value
    type_value = ModelType(model_type).value
    return f"{base_value}/{type_value}/{model_name}"
|
2023-05-12 20:13:34 +00:00
|
|
|
|
2023-06-16 03:32:33 +00:00
|
|
|
@classmethod
def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]:
    """Split a ``base_model/model_type/model_name`` key into its typed parts.

    Inverse of :meth:`create_key`. Only the first two "/" separators are
    significant, so the model name itself may contain "/".

    :returns: ``(model_name, base_model, model_type)``
    :raises ValueError: if the key has fewer than three "/"-separated parts.
    :raises Exception: if the model type or base model is not a known enum value.
    """
    parts = model_key.split("/", 2)
    if len(parts) < 3:
        raise ValueError(f"Malformed model key {model_key!r}: expected 'base_model/model_type/model_name'")
    base_model_str, model_type_str, model_name = parts

    # Enum lookup of an unknown value raises ValueError; catch only that
    # (a bare except would also swallow KeyboardInterrupt/SystemExit) and
    # keep the original error as the cause.
    try:
        model_type = ModelType(model_type_str)
    except ValueError as e:
        raise Exception(f"Unknown model type: {model_type_str}") from e

    try:
        base_model = BaseModelType(base_model_str)
    except ValueError as e:
        raise Exception(f"Unknown base model: {base_model_str}") from e

    return (model_name, base_model, model_type)
|
|
|
|
|
2023-06-26 00:07:54 +00:00
|
|
|
def _get_model_cache_path(self, model_path):
    """Return where the converted copy of ``model_path`` is kept.

    The location is ``.cache/<md5 hex digest of str(model_path)>``, resolved
    through ``self.resolve_model_path``.
    """
    digest = hashlib.md5(str(model_path).encode()).hexdigest()
    return self.resolve_model_path(Path(".cache", digest))
|
2023-06-26 00:07:54 +00:00
|
|
|
|
2023-07-08 19:13:51 +00:00
|
|
|
@classmethod
def initialize_model_config(cls, config_path: Path):
    """Create an empty models.yaml at ``config_path``.

    The file holds only the "__metadata__" stanza, stamped with
    ``CONFIG_FILE_VERSION``. Missing parent directories are created.
    """
    config_path = Path(config_path)
    # Without this, a fresh install whose config directory doesn't exist yet
    # fails with FileNotFoundError.
    config_path.parent.mkdir(parents=True, exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as yaml_file:
        # Use the module constant rather than a duplicated literal so the two
        # cannot drift apart.
        yaml_file.write(yaml.dump({"__metadata__": {"version": CONFIG_FILE_VERSION}}))
|
|
|
|
|
2023-05-12 20:13:34 +00:00
|
|
|
def get_model(
    self,
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
    submodel_type: Optional[SubModelType] = None,
) -> ModelInfo:
    """Look up a model registered in models.yaml and return a ModelInfo for it.

    :param model_name: symbolic name of the model in models.yaml
    :param base_model: BaseModelType enum indicating the base model used by this model
    :param model_type: ModelType enum indicating the type of model to return
    :param submodel_type: optional SubModelType enum naming the portion of
        the model to retrieve (e.g. SubModelType.Vae)
    :raises ModelNotFoundException: if the model is unknown, or if its files
        are missing and its class is not saved to the config.
    :raises Exception: if the files of a model saved to the config are missing
        (the entry is flagged with ModelError.NotFound first).
    """
    key = self.create_key(model_name, base_model, model_type)
    if not self.model_exists(model_name, base_model, model_type, rescan=True):
        raise ModelNotFoundException(f"Model not found - {key}")

    config = self._get_model_config(base_model, model_name, model_type)
    path, overridden = self._get_model_path(config, submodel_type)
    if overridden:
        # The config points this submodel at its own files; load those as a
        # standalone model of the submodel's type.
        model_type, submodel_type = submodel_type, None
    implementation = self._get_implementation(base_model, model_type)

    if not path.exists():
        missing = f'Files for model "{key}" not found at {path}'
        if not implementation.save_to_config:
            self.models.pop(key, None)
            raise ModelNotFoundException(missing)
        self.models[key].error = ModelError.NotFound
        raise Exception(missing)

    # TODO: is it accurate to use path as id
    converted = implementation.convert_if_required(
        base_model=base_model,
        model_path=str(path),  # TODO: refactor str/Path types logic
        output_path=self._get_model_cache_path(path),
        config=config,
    )

    context = self.cache.get_model(
        model_path=converted,
        model_class=implementation,
        base_model=base_model,
        model_type=model_type,
        submodel=submodel_type,
    )
    # remember which cache entries belong to this model so del_model can evict them
    self.cache_keys.setdefault(key, set()).add(context.key)

    return ModelInfo(
        context=context,
        name=model_name,
        base_model=base_model,
        type=submodel_type or model_type,
        hash="<NO_HASH>",  # TODO: compute a real hash
        location=converted,  # TODO:
        precision=self.cache.precision,
        _cache=self.cache,
    )
|
2023-04-05 21:25:42 +00:00
|
|
|
|
2023-07-31 16:08:46 +00:00
|
|
|
def _get_model_path(
    self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, bool]:
    """Extract a model's filesystem path from its config.

    When ``submodel_type`` is given and the config has a non-None attribute of
    that name (e.g. a ``vae`` override on a main model), that path replaces
    ``model_config.path``.

    :return: ``(path, is_submodel_override)``: the resolved Path of the model
        (or submodel), and whether a submodel override path was used.
    """
    model_path = model_config.path
    is_submodel_override = False

    # Does the config explicitly override the submodel?
    if submodel_type is not None:
        submodel_path = getattr(model_config, submodel_type, None)
        if submodel_path is not None:
            model_path = submodel_path
            is_submodel_override = True

    return self.resolve_model_path(model_path), is_submodel_override
|
|
|
|
|
2023-08-05 22:22:23 +00:00
|
|
|
def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
    """Return the registered config object for a model.

    :raises ModelNotFoundException: if no model is registered under that key.
    """
    key = self.create_key(model_name, base_model, model_type)
    if key not in self.models:
        raise ModelNotFoundException(f"Model not found - {key}")
    return self.models[key]
|
|
|
|
|
2023-07-29 04:03:27 +00:00
|
|
|
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
    """Return the concrete class registered in MODEL_CLASSES for (base_model, model_type)."""
    return MODEL_CLASSES[base_model][model_type]
|
|
|
|
|
2023-07-30 03:02:31 +00:00
|
|
|
def _instantiate(
    self,
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
    submodel_type: Optional[SubModelType] = None,
) -> ModelBase:
    """Make a new instance of this model, without loading it.

    Class selection matches ``get_model()``: when the config overrides the
    requested submodel with its own path, the implementation registered for
    ``submodel_type`` is used and receives that type.
    """
    model_config = self._get_model_config(base_model, model_name, model_type)
    model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
    if is_submodel_override:
        # Same switch get_model() performs: the override path holds a
        # standalone model of the submodel's type.
        model_type = submodel_type
    # NOTE: non-overridden submodels still use the parent model's class.
    constructor = self._get_implementation(base_model, model_type)
    return constructor(model_path, base_model, model_type)
|
|
|
|
|
2023-05-13 18:44:44 +00:00
|
|
|
def model_info(
    self,
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
) -> Union[dict, None]:
    """Return the model's config as a plain dict, or None if it is unknown.

    Fields still at their default values are omitted
    (``.dict(exclude_defaults=True)``).
    """
    key = self.create_key(model_name, base_model, model_type)
    if key not in self.models:
        return None  # TODO: None or empty dict on not found
    return self.models[key].dict(exclude_defaults=True)
|
2023-02-28 05:31:15 +00:00
|
|
|
|
2023-06-10 00:14:10 +00:00
|
|
|
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
    """Return ``(model_name, base_model, model_type)`` for every known model.

    Tuples follow the iteration order of ``self.models``.
    """
    return list(map(self.parse_key, self.models))
|
2023-02-28 05:31:15 +00:00
|
|
|
|
2023-07-06 03:13:01 +00:00
|
|
|
def list_model(
    self,
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
) -> Union[dict, None]:
    """Describe a single installed model in the list_models() dict format.

    Returns the first matching dict, or None when there is no match.
    """
    matches = self.list_models(base_model, model_type, model_name)
    return matches[0] if matches else None
|
2023-07-06 03:13:01 +00:00
|
|
|
|
2023-06-10 00:14:10 +00:00
|
|
|
def list_models(
    self,
    base_model: Optional[BaseModelType] = None,
    model_type: Optional[ModelType] = None,
    model_name: Optional[str] = None,
) -> list[dict]:
    """Return a list of model description dicts, optionally filtered.

    Each dict holds the model's config fields (defaults excluded) plus
    ``model_name``, ``base_model`` and ``model_type``; any ``path`` entry is
    converted to an absolute path string. Results are sorted
    case-insensitively by model key.

    :param base_model: only return models with this base model
    :param model_type: only return models of this type
    :param model_name: only return models with this name (applied even when
        base_model/model_type are not given)
    :raises ModelNotFoundException: if all three filters are given and no
        such model is registered.
    """
    if model_name and base_model and model_type:
        # Fully specified: direct lookup of a single key.
        model_keys = [self.create_key(model_name, base_model, model_type)]
    else:
        model_keys = sorted(self.models, key=str.casefold)

    models = []
    for model_key in model_keys:
        model_config = self.models.get(model_key)
        if not model_config:
            self.logger.error(f"Unknown model {model_name}")
            raise ModelNotFoundException(f"Unknown model {model_name}")

        cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
        if base_model is not None and cur_base_model != base_model:
            continue
        if model_type is not None and cur_model_type != model_type:
            continue
        # Previously the name was ignored unless all three filters were set,
        # so list_models(model_name="x") returned every model.
        if model_name and cur_model_name != model_name:
            continue

        model_dict = dict(
            **model_config.dict(exclude_defaults=True),
            # OpenAPIModelInfoBase
            model_name=cur_model_name,
            base_model=cur_base_model,
            model_type=cur_model_type,
        )

        # expose paths as absolute to help web UI
        if path := model_dict.get("path"):
            model_dict["path"] = str(self.resolve_model_path(path))
        models.append(model_dict)

    return models
|
|
|
|
|
|
|
|
def print_models(self) -> None:
    """
    Print one line per configured model: name, model type and description.

    Each entry returned by ``list_models()`` is a flat dict describing a single
    model (``model_name``, ``model_type`` and, when set, ``description``).
    This needs to be redone as a proper table.
    """

    def as_text(value) -> str:
        # Enum members are shown by their value; missing fields print as "".
        if value is None:
            return ""
        return str(getattr(value, "value", value))

    # TODO: redo
    for model_dict in self.list_models():
        name = as_text(model_dict.get("model_name"))
        kind = as_text(model_dict.get("model_type"))
        description = as_text(model_dict.get("description"))
        print(f"{name:25s} {kind:10s} {description}")
|
2023-02-28 05:31:15 +00:00
|
|
|
|
2023-06-16 03:32:33 +00:00
|
|
|
# Tested - LS
|
2023-05-12 20:13:34 +00:00
|
|
|
def del_model(
    self,
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
):
    """
    Delete the named model.

    The entry is removed from ``self.models`` and its in-memory cached
    instances are evicted. Its conversion cache on disk is removed. If the
    model's files live inside the configured models directory, they are
    deleted too. The updated configuration is then committed.

    :param model_name: symbolic name of the model
    :param base_model: base model the entry is registered under
    :param model_type: type of the model
    :raises ModelNotFoundException: if no such model is configured
    """
    model_key = self.create_key(model_name, base_model, model_type)
    model_cfg = self.models.pop(model_key, None)

    if model_cfg is None:
        raise ModelNotFoundException(f"Unknown model {model_key}")

    # note: evicting is not guaranteed to release memory (the model can have other references)
    for cache_id in self.cache_keys.pop(model_key, []):
        self.cache.uncache_model(cache_id)

    model_path = self.resolve_model_path(model_cfg.path)

    # The conversion cache may be a directory or a single file.
    cache_path = self._get_model_cache_path(model_path)
    if cache_path.is_dir():
        rmtree(str(cache_path))
    elif cache_path.exists():
        cache_path.unlink()

    # if model inside invoke models folder - delete files. Files that were
    # already removed by hand are tolerated, so the config still gets committed.
    if model_path.is_relative_to(self.app_config.models_path):
        if model_path.is_dir():
            rmtree(str(model_path))
        elif model_path.exists():
            model_path.unlink()
    self.commit()
|
2023-02-28 05:31:15 +00:00
|
|
|
|
2023-06-25 20:04:43 +00:00
|
|
|
# LS: tested
|
2023-02-28 05:31:15 +00:00
|
|
|
def add_model(
    self,
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
    model_attributes: dict,
    clobber: bool = False,
) -> AddModelResult:
    """
    Add a model definition, or replace one when ``clobber=True``.

    The attributes are validated through the model class' ``create_config()``.
    The resulting config is stored in memory and committed to the config file.

    If a definition with the same name/base/type already exists and *clobber*
    is False, an ``Exception`` is raised. When an existing definition is
    replaced, its conversion cache and any in-memory cached instances are
    discarded.

    A ``path`` attribute that lies inside the models directory is stored
    relative to it. The caller's *model_attributes* dict is not modified.

    The returned AddModelResult carries the name, type, base and new config.
    """
    # Work on a shallow copy so the caller's dict is left untouched.
    model_attributes = dict(model_attributes)

    # relativize paths as they go in - this makes it easier to move the models directory around
    if path := model_attributes.get("path"):
        model_attributes["path"] = str(self.relative_model_path(Path(path)))

    model_class = self._get_implementation(base_model, model_type)
    model_config = model_class.create_config(**model_attributes)
    model_key = self.create_key(model_name, base_model, model_type)

    if model_key in self.models and not clobber:
        raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')

    old_model = self.models.pop(model_key, None)
    if old_model is not None:
        # TODO: if path changed and old_model.path inside models folder should we delete this too?

        # remove conversion cache as config changed
        old_model_path = self.resolve_model_path(old_model.path)
        old_model_cache = self._get_model_cache_path(old_model_path)
        if old_model_cache.exists():
            if old_model_cache.is_dir():
                rmtree(str(old_model_cache))
            else:
                old_model_cache.unlink()

        # remove in-memory cache
        # note: this is not guaranteed to release memory (the model can have other references)
        cache_ids = self.cache_keys.pop(model_key, [])
        for cache_id in cache_ids:
            self.cache.uncache_model(cache_id)

    self.models[model_key] = model_config
    self.commit()

    return AddModelResult(
        name=model_name,
        model_type=model_type,
        base_model=base_model,
        config=model_config,
    )
|
2023-06-26 00:07:54 +00:00
|
|
|
|
2023-07-15 03:03:18 +00:00
|
|
|
def rename_model(
    self,
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
    new_name: Optional[str] = None,
    new_base: Optional[BaseModelType] = None,
):
    """
    Rename or rebase a model.

    If the model's files live inside the models directory, they are moved to
    ``<models_path>/<new_base>/<model_type>/<new_name>`` and the stored path is
    updated. Single-file models keep their file extension. Conversion and
    in-memory caches of the old entry are discarded, and the configuration is
    committed.

    :param model_name: current name of the model
    :param base_model: current base model
    :param model_type: type of the model (not changed by this call)
    :param new_name: new name; defaults to the current name
    :param new_base: new base model; defaults to the current base model
    :raises ModelNotFoundException: if the model is not configured
    :raises ValueError: if a model with the new name/base/type already exists
    """
    if new_name is None and new_base is None:
        self.logger.error(f"rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
        return

    model_key = self.create_key(model_name, base_model, model_type)
    model_cfg = self.models.get(model_key, None)
    if not model_cfg:
        raise ModelNotFoundException(f"Unknown model: {model_key}")

    old_path = self.resolve_model_path(model_cfg.path)
    new_name = new_name or model_name
    new_base = new_base or base_model
    new_key = self.create_key(new_name, new_base, model_type)
    if new_key in self.models:
        raise ValueError(f'Attempt to overwrite existing model definition "{new_key}"')

    # if this is a model file/directory that we manage ourselves, we need to move it
    if old_path.is_relative_to(self.app_config.models_path):
        # Single-file models (e.g. checkpoints) keep their extension, so they can
        # still be recognized by suffix after the move.
        target_name = new_name
        if old_path.is_file() and old_path.suffix and not new_name.endswith(old_path.suffix):
            target_name = new_name + old_path.suffix
        new_path = self.resolve_model_path(
            Path(
                BaseModelType(new_base).value,
                ModelType(model_type).value,
                target_name,
            )
        )
        # The destination <base>/<type> folder may not exist yet (e.g. when rebasing).
        new_path.parent.mkdir(parents=True, exist_ok=True)
        move(old_path, new_path)
        model_cfg.path = str(new_path.relative_to(self.app_config.models_path))

    # clean up caches
    old_model_cache = self._get_model_cache_path(old_path)
    if old_model_cache.exists():
        if old_model_cache.is_dir():
            rmtree(str(old_model_cache))
        else:
            old_model_cache.unlink()

    cache_ids = self.cache_keys.pop(model_key, [])
    for cache_id in cache_ids:
        self.cache.uncache_model(cache_id)

    self.models.pop(model_key, None)  # delete
    self.models[new_key] = model_cfg
    self.commit()
|
2023-07-18 20:33:19 +00:00
|
|
|
|
2023-07-05 13:05:05 +00:00
|
|
|
def convert_model(
    self,
    model_name: str,
    base_model: BaseModelType,
    model_type: Literal[ModelType.Main, ModelType.Vae],
    dest_directory: Optional[Path] = None,
) -> AddModelResult:
    """
    Convert a checkpoint file into a diffusers folder and re-register the model.

    The cached diffusers conversion is moved to its final location. The original
    checkpoint file is deleted if it lives inside the models directory.

    :param model_name: Name of the model to convert
    :param base_model: Base model type
    :param model_type: Type of model ['vae' or 'main']
    :param dest_directory: Optional parent folder for the diffusers output.
        Defaults to ``<models_path>/<base>/<type>``.
    :raises FileNotFoundError: if ``model_info()`` returns None
    :raises ValueError: if the model is not a checkpoint, or the destination exists

    The result is the ``AddModelResult`` of re-adding the converted model.
    """
    info = self.model_info(model_name, base_model, model_type)

    if info is None:
        raise FileNotFoundError(f"model not found: {model_name}")

    if info["model_format"] != "checkpoint":
        raise ValueError(f"not a checkpoint format model: {model_name}")

    # We are taking advantage of a side effect of get_model() that converts check points
    # into cached diffusers directories stored at `location`. It doesn't matter
    # what submodeltype we request here, so we get the smallest.
    submodel = {"submodel_type": SubModelType.Scheduler} if model_type == ModelType.Main else {}
    model = self.get_model(
        model_name,
        base_model,
        model_type,
        **submodel,
    )
    checkpoint_path = self.resolve_model_path(info["path"])
    old_diffusers_path = self.resolve_model_path(model.location)
    new_diffusers_path = (
        dest_directory or self.app_config.models_path / base_model.value / model_type.value
    ) / model_name
    if new_diffusers_path.exists():
        raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")

    try:
        move(old_diffusers_path, new_diffusers_path)
        info["model_format"] = "diffusers"
        info["path"] = (
            str(new_diffusers_path)
            if dest_directory
            else str(new_diffusers_path.relative_to(self.app_config.models_path))
        )
        # Not every config carries a "config" entry (it is not needed for diffusers).
        info.pop("config", None)

        result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True)
    except BaseException:
        # something went wrong, so don't leave dangling diffusers model in directory or it will
        # cause a duplicate model error! The move may have failed before creating the folder, and a
        # failing cleanup must not mask the original exception.
        if new_diffusers_path.exists():
            rmtree(new_diffusers_path, ignore_errors=True)
        raise

    if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path):
        checkpoint_path.unlink()

    return result
|
2023-07-18 20:33:19 +00:00
|
|
|
|
2023-07-29 17:00:43 +00:00
|
|
|
def resolve_model_path(self, path: Union[Path, str]) -> Path:
    """
    Anchor *path* at the configured models directory.

    Relative paths are joined onto ``app_config.models_path``. Under pathlib
    join rules, an absolute *path* comes back unchanged.
    """
    models_root = self.app_config.models_path
    return models_root / path
|
|
|
|
|
2023-07-29 17:00:07 +00:00
|
|
|
def relative_model_path(self, model_path: Path) -> Path:
    """
    Express *model_path* relative to ``app_config.models_path`` if it lies inside it.

    Paths outside the models directory are returned unchanged.
    """
    models_root = self.app_config.models_path
    if not model_path.is_relative_to(models_root):
        return model_path
    return model_path.relative_to(models_root)
|
|
|
|
|
2023-02-28 05:31:15 +00:00
|
|
|
def search_models(self, search_folder):
    """
    Recursively find single-file models (``*.ckpt`` / ``*.safetensors``) under *search_folder*.

    Weight files that belong to diffusers-format folders are skipped. Those
    are files named exactly ``model.safetensors`` or
    ``diffusion_pytorch_model.safetensors``. Other files whose names merely
    end in "model.safetensors" are still returned.

    :param search_folder: directory to search
    :returns: tuple ``(search_folder, found_models)``. Each element of
        ``found_models`` is a dict with ``name`` (file stem) and ``location``
        (the resolved path, using forward slashes).
    """
    self.logger.info(f"Finding Models In: {search_folder}")
    root = Path(search_folder)

    # Diffusers folders store their submodel weights under these fixed names.
    diffusers_weight_names = {"model.safetensors", "diffusion_pytorch_model.safetensors"}

    found_models = []
    # .ckpt files first, then .safetensors, matching the original ordering.
    for pattern in ("**/*.ckpt", "**/*.safetensors"):
        for file in root.glob(pattern):
            if not file.is_file() or file.name in diffusers_weight_names:
                continue
            location = str(file.resolve()).replace("\\", "/")
            found_models.append({"name": file.stem, "location": location})

    return search_folder, found_models
|
|
|
|
|
2023-08-01 07:55:13 +00:00
|
|
|
def commit(self, conf_file: Optional[Path] = None) -> None:
    """
    Write current configuration out to the indicated file.

    The file is first written to ``new_config.tmp`` next to the target and
    then moved into place with ``os.replace``. If writing fails, a warning is
    logged, the temporary file is removed, and the existing config is left
    untouched.

    :param conf_file: target file; defaults to ``self.config_path``. It is
        resolved against ``app_config.root_path``.
    """
    data_to_save = dict()
    data_to_save["__metadata__"] = self.config_meta.dict()

    for model_key, model_config in self.models.items():
        model_name, base_model, model_type = self.parse_key(model_key)
        model_class = self._get_implementation(base_model, model_type)
        if model_class.save_to_config:
            # TODO: or exclude_unset better fits here?
            entry = model_config.dict(exclude_defaults=True, exclude={"error"})
            # alias for config file (guarded in case the field was excluded as a default)
            if "model_format" in entry:
                entry["format"] = entry.pop("model_format")
            data_to_save[model_key] = entry

    yaml_str = OmegaConf.to_yaml(data_to_save)
    config_file_path = conf_file or self.config_path
    assert config_file_path is not None, "no config file path to write to"
    config_file_path = self.app_config.root_path / config_file_path
    tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
    try:
        with open(tmpfile, "w", encoding="utf-8") as outfile:
            outfile.write(self.preamble())
            outfile.write(yaml_str)
        os.replace(tmpfile, config_file_path)
    except OSError as err:
        self.logger.warning(f"Could not modify the config file at {config_file_path}")
        self.logger.warning(err)
        # Best effort: don't leave a half-written temporary file behind.
        import contextlib

        with contextlib.suppress(OSError):
            os.remove(tmpfile)
|
2023-02-28 05:31:15 +00:00
|
|
|
|
|
|
|
def preamble(self) -> str:
    """
    Build the comment header that ``commit()`` writes at the top of the config file.

    The text starts with a newline, and every comment line ends with a newline.
    """
    header_lines = [
        "# This file describes the alternative machine learning models",
        "# available to InvokeAI script.",
        "#",
        "# To add a new model, follow the examples below. Each",
        "# model requires a model config file, a weights file,",
        "# and the width and height of the images it",
        "# was trained on.",
    ]
    return "\n" + "".join(line + "\n" for line in header_lines)
|
|
|
|
|
2023-06-26 00:07:54 +00:00
|
|
|
def scan_models_directory(
    self,
    base_model: Optional[BaseModelType] = None,
    model_type: Optional[ModelType] = None,
):
    """
    Synchronize the in-memory model table with the models directory.

    1. Configured models whose files no longer exist are removed from ``self.models``.
    2. Files and folders under ``<models_path>/<base>/<type>/`` that are not
       registered yet are probed and added.
    3. The autoimport directories are scanned via ``scan_autoimport_directory()``.

    The configuration is committed if anything new was found and a config path is set.

    :param base_model: if given, only discover new models for this base model (step 2)
    :param model_type: if given, only discover new models of this type (step 2)
    """
    loaded_files = set()
    new_models_found = False

    self.logger.info(f"Scanning {self.app_config.models_path} for new models")
    with Chdir(self.app_config.models_path):
        # Pass 1: prune entries whose files have disappeared.
        for model_key, model_config in list(self.models.items()):
            model_name, cur_base_model, cur_model_type = self.parse_key(model_key)

            # Patch for relative path bug in older models.yaml - paths should not
            # start with a hard-coded 'models' directory. This will also fix up
            # models.yaml when committed. Whole path components are compared so
            # that names like "models_extra/..." are left alone.
            path_parts = Path(model_config.path).parts
            if path_parts and path_parts[0] == "models":
                model_config.path = str(Path(*path_parts[1:]))

            model_path = self.resolve_model_path(model_config.path).absolute()
            if not model_path.exists():
                model_class = self._get_implementation(cur_base_model, cur_model_type)
                if model_class.save_to_config:
                    model_config.error = ModelError.NotFound
                    self.models.pop(model_key, None)
                else:
                    self.models.pop(model_key, None)
            else:
                loaded_files.add(model_path)

        # Pass 2: register unknown files/folders under <models_path>/<base>/<type>/.
        for cur_base_model in BaseModelType:
            if base_model is not None and cur_base_model != base_model:
                continue

            for cur_model_type in ModelType:
                if model_type is not None and cur_model_type != model_type:
                    continue
                model_class = self._get_implementation(cur_base_model, cur_model_type)
                models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))

                if not models_dir.exists():
                    continue  # TODO: or create all folders?

                for model_path in models_dir.iterdir():
                    if model_path not in loaded_files:  # TODO: check
                        model_name = model_path.name if model_path.is_dir() else model_path.stem
                        model_key = self.create_key(model_name, cur_base_model, cur_model_type)

                        try:
                            if model_key in self.models:
                                raise DuplicateModelException(f"Model with key {model_key} added twice")

                            model_path = self.relative_model_path(model_path)
                            model_config: ModelConfigBase = model_class.probe_config(str(model_path))
                            self.models[model_key] = model_config
                            new_models_found = True
                        except DuplicateModelException as e:
                            self.logger.warning(e)
                        except InvalidModelException:
                            self.logger.warning(f"Not a valid model: {model_path}")
                        except NotImplementedError as e:
                            self.logger.warning(e)

    imported_models = self.scan_autoimport_directory()
    if (new_models_found or imported_models) and self.config_path:
        self.commit()
|
2023-06-23 20:35:39 +00:00
|
|
|
|
2023-07-29 17:00:07 +00:00
|
|
|
def scan_autoimport_directory(self) -> Dict[str, AddModelResult]:
    """
    Scan the autoimport directories (if defined) and import any new models found.

    The searched directories are ``autoimport_dir``, ``lora_dir``,
    ``embedding_dir`` and ``controlnet_dir`` from the app config. Each is taken
    relative to ``root_path``, and unset entries are skipped. Models whose paths
    are already registered are ignored.

    Before scanning, the core SD VAE (``core/convert/sd-vae-ft-mse``) is
    imported on a best-effort basis so that the UI can use it.

    :returns: the newly imported models, as accumulated from
        ``ModelInstall.heuristic_import()``
    """
    # avoid circular import
    from invokeai.backend.install.model_install_backend import ModelInstall
    from invokeai.frontend.install.model_install import ask_user_for_prediction_type

    class ScanAndImport(ModelSearch):
        def __init__(self, directories, logger, ignore: Set[Path], installer: ModelInstall):
            super().__init__(directories, logger)
            self.installer = installer
            self.ignore = ignore

        def on_search_started(self):
            self.new_models_found = dict()

        def on_model_found(self, model: Path):
            if model not in self.ignore:
                self.new_models_found.update(self.installer.heuristic_import(model))

        def on_search_completed(self):
            self.logger.info(
                f"Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models"
            )

        def models_found(self):
            return self.new_models_found

    config = self.app_config

    # LS: hacky
    # Patch in the SD VAE from core so that it is available for use by the UI.
    # This is best-effort: a failure here (e.g. the VAE is missing or already
    # registered) must not abort the scan. It is logged instead of silently
    # discarded, and KeyboardInterrupt/SystemExit are no longer swallowed.
    try:
        self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))})
    except Exception as e:
        self.logger.debug(f"Could not import core SD VAE: {e}")

    installer = ModelInstall(
        config=self.app_config,
        model_manager=self,
        prediction_type_helper=ask_user_for_prediction_type,
    )
    known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()}
    directories = {
        config.root_path / x
        for x in [
            config.autoimport_dir,
            config.lora_dir,
            config.embedding_dir,
            config.controlnet_dir,
        ]
        if x
    }
    scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
    scanner.search()

    return scanner.models_found()
|
2023-06-23 20:35:39 +00:00
|
|
|
|
|
|
|
def heuristic_import(
    self,
    items_to_import: Set[str],
    prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> Dict[str, AddModelResult]:
    """Import a collection of paths, repo_ids or URLs and return what was installed.

    :param items_to_import: Set of strings identifying the models to import.
    :param prediction_type_helper: Optional callback. It receives the Path of a
        Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.

    The prediction type helper is needed to tell apart models based on
    Stable Diffusion 2 Base (which need SchedulerPredictionType.Epsilon) from
    Stable Diffusion 768 models (which need SchedulerPredictionType.VPrediction).
    This generally cannot be determined programmatically, so the helper usually
    asks the user to choose.

    The result is a dict of successfully installed models. It merges the
    dicts returned by ``ModelInstall.heuristic_import()`` for each item.
    The configuration is committed after all items have been processed.

    May raise the following exceptions:
      - ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL
      - ValueError - a corresponding model already exists
    """
    # avoid circular import here
    from invokeai.backend.install.model_install_backend import ModelInstall

    installer = ModelInstall(
        config=self.app_config,
        prediction_type_helper=prediction_type_helper,
        model_manager=self,
    )

    imported = {}
    for item in items_to_import:
        imported.update(installer.heuristic_import(item))

    self.commit()
    return imported
|