mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix model manager documentation
This commit is contained in:
parent
c3c4a71173
commit
60b37b7ff4
@ -374,7 +374,8 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
|
||||||
|
|
||||||
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
||||||
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
|
autoimport_dir : Path = Field(default='models/autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
|
||||||
|
autoconvert_dir : Path = Field(default=None, description='Deprecated configuration option.', category='Paths')
|
||||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||||
models_dir : Path = Field(default='./models', description='Path to the models directory', category='Paths')
|
models_dir : Path = Field(default='./models', description='Path to the models directory', category='Paths')
|
||||||
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
|
||||||
|
@ -179,9 +179,9 @@ class ModelInstall(object):
|
|||||||
self.mgr.commit()
|
self.mgr.commit()
|
||||||
|
|
||||||
if selections.autoscan_on_startup and Path(selections.scan_directory).is_dir():
|
if selections.autoscan_on_startup and Path(selections.scan_directory).is_dir():
|
||||||
update_autoconvert_dir(selections.scan_directory)
|
update_autoimport_dir(selections.scan_directory)
|
||||||
else:
|
else:
|
||||||
update_autoconvert_dir(None)
|
update_autoimport_dir(None)
|
||||||
|
|
||||||
def heuristic_install(self, model_path_id_or_url: Union[str,Path]):
|
def heuristic_install(self, model_path_id_or_url: Union[str,Path]):
|
||||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||||
@ -375,13 +375,13 @@ class ModelInstall(object):
|
|||||||
'''
|
'''
|
||||||
return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()}
|
return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()}
|
||||||
|
|
||||||
def update_autoconvert_dir(autodir: Path):
|
def update_autoimport_dir(autodir: Path):
|
||||||
'''
|
'''
|
||||||
Update the "autoconvert_dir" option in invokeai.yaml
|
Update the "autoimport_dir" option in invokeai.yaml
|
||||||
'''
|
'''
|
||||||
invokeai_config_path = config.init_file_path
|
invokeai_config_path = config.init_file_path
|
||||||
conf = OmegaConf.load(invokeai_config_path)
|
conf = OmegaConf.load(invokeai_config_path)
|
||||||
conf.InvokeAI.Paths.autoconvert_dir = str(autodir) if autodir else None
|
conf.InvokeAI.Paths.autoimport_dir = str(autodir) if autodir else None
|
||||||
yaml = OmegaConf.to_yaml(conf)
|
yaml = OmegaConf.to_yaml(conf)
|
||||||
tmpfile = invokeai_config_path.parent / "new_config.tmp"
|
tmpfile = invokeai_config_path.parent / "new_config.tmp"
|
||||||
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
||||||
|
@ -1,53 +1,193 @@
|
|||||||
"""This module manages the InvokeAI `models.yaml` file, mapping
|
"""This module manages the InvokeAI `models.yaml` file, mapping
|
||||||
symbolic diffusers model names to the paths and repo_ids used
|
symbolic diffusers model names to the paths and repo_ids used by the
|
||||||
by the underlying `from_pretrained()` call.
|
underlying `from_pretrained()` call.
|
||||||
|
|
||||||
For fetching models, use manager.get_model('symbolic name'). This will
|
SYNOPSIS:
|
||||||
return a ModelInfo object that contains the following attributes:
|
|
||||||
|
|
||||||
* context -- a context manager Generator that loads and locks the
|
|
||||||
model into GPU VRAM and returns the model for use.
|
|
||||||
See below for usage.
|
|
||||||
* name -- symbolic name of the model
|
|
||||||
* type -- SubModelType of the model
|
|
||||||
* hash -- unique hash for the model
|
|
||||||
* location -- path or repo_id of the model
|
|
||||||
* revision -- revision of the model if coming from a repo id,
|
|
||||||
e.g. 'fp16'
|
|
||||||
* precision -- torch precision of the model
|
|
||||||
|
|
||||||
Typical usage:
|
mgr = ModelManager('/home/phi/invokeai/configs/models.yaml')
|
||||||
|
sd1_5 = mgr.get_model('stable-diffusion-v1-5',
|
||||||
|
model_type=ModelType.Main,
|
||||||
|
base_model=BaseModelType.StableDiffusion1,
|
||||||
|
submodel_type=SubModelType.Unet)
|
||||||
|
with sd1_5 as unet:
|
||||||
|
run_some_inference(unet)
|
||||||
|
|
||||||
from invokeai.backend import ModelManager
|
FETCHING MODELS:
|
||||||
|
|
||||||
manager = ModelManager(
|
Models are described using four attributes:
|
||||||
config='./configs/models.yaml',
|
|
||||||
max_cache_size=8
|
|
||||||
) # gigabytes
|
|
||||||
|
|
||||||
model_info = manager.get_model('stable-diffusion-1.5', SubModelType.Diffusers)
|
1) model_name -- the symbolic name for the model
|
||||||
with model_info.context as my_model:
|
|
||||||
my_model.latents_from_embeddings(...)
|
|
||||||
|
|
||||||
The manager uses the underlying ModelCache class to keep
|
2) ModelType -- an enum describing the type of the model. Currently
|
||||||
frequently-used models in RAM and move them into GPU as needed for
|
defined types are:
|
||||||
generation operations. The optional `max_cache_size` argument
|
ModelType.Main -- a full model capable of generating images
|
||||||
indicates the maximum size the cache can grow to, in gigabytes. The
|
ModelType.Vae -- a VAE model
|
||||||
underlying ModelCache object can be accessed using the manager's "cache"
|
ModelType.Lora -- a LoRA or LyCORIS fine-tune
|
||||||
attribute.
|
ModelType.TextualInversion -- a textual inversion embedding
|
||||||
|
ModelType.ControlNet -- a ControlNet model
|
||||||
|
|
||||||
Because the model manager can return multiple different types of
|
3) BaseModelType -- an enum indicating the stable diffusion base model, one of:
|
||||||
models, you may wish to add additional type checking on the class
|
BaseModelType.StableDiffusion1
|
||||||
of model returned. To do this, provide the option `model_type`
|
BaseModelType.StableDiffusion2
|
||||||
parameter:
|
|
||||||
|
|
||||||
model_info = manager.get_model(
|
4) SubModelType (optional) -- an enum that refers to one of the submodels contained
|
||||||
'clip-tokenizer',
|
within the main model. Values are:
|
||||||
model_type=SubModelType.Tokenizer
|
|
||||||
)
|
|
||||||
|
|
||||||
This will raise an InvalidModelError if the format defined in the
|
SubModelType.UNet
|
||||||
config file doesn't match the requested model type.
|
SubModelType.TextEncoder
|
||||||
|
SubModelType.Tokenizer
|
||||||
|
SubModelType.Scheduler
|
||||||
|
SubModelType.SafetyChecker
|
||||||
|
|
||||||
|
To fetch a model, use `manager.get_model()`. This takes the symbolic
|
||||||
|
name of the model, the ModelType, the BaseModelType and the
|
||||||
|
SubModelType. The latter is required for ModelType.Main.
|
||||||
|
|
||||||
|
get_model() will return a ModelInfo object that can then be used in
|
||||||
|
context to retrieve the model and move it into GPU VRAM (on GPU
|
||||||
|
systems).
|
||||||
|
|
||||||
|
A typical example is:
|
||||||
|
|
||||||
|
sd1_5 = mgr.get_model('stable-diffusion-v1-5',
|
||||||
|
model_type=ModelType.Main,
|
||||||
|
base_model=BaseModelType.StableDiffusion1,
|
||||||
|
submodel_type=SubModelType.Unet)
|
||||||
|
with sd1_5 as unet:
|
||||||
|
run_some_inference(unet)
|
||||||
|
|
||||||
|
The ModelInfo object provides a number of useful fields describing the
|
||||||
|
model, including:
|
||||||
|
|
||||||
|
name -- symbolic name of the model
|
||||||
|
base_model -- base model (BaseModelType)
|
||||||
|
type -- model type (ModelType)
|
||||||
|
location -- path to the model file
|
||||||
|
precision -- torch precision of the model
|
||||||
|
hash -- unique sha256 checksum for this model
|
||||||
|
|
||||||
|
SUBMODELS:
|
||||||
|
|
||||||
|
When fetching a main model, you must specify the submodel. Retrieval
|
||||||
|
of full pipelines is not supported.
|
||||||
|
|
||||||
|
vae_info = mgr.get_model('stable-diffusion-1.5',
|
||||||
|
model_type = ModelType.Main,
|
||||||
|
base_model = BaseModelType.StableDiffusion1,
|
||||||
|
submodel_type = SubModelType.Vae
|
||||||
|
)
|
||||||
|
with vae_info as vae:
|
||||||
|
do_something(vae)
|
||||||
|
|
||||||
|
This rule does not apply to controlnets, embeddings, loras and standalone
|
||||||
|
VAEs, which do not have submodels.
|
||||||
|
|
||||||
|
LISTING MODELS
|
||||||
|
|
||||||
|
The model_names() method will return a list of Tuples describing each
|
||||||
|
model it knows about:
|
||||||
|
|
||||||
|
>> mgr.model_names()
|
||||||
|
[
|
||||||
|
('stable-diffusion-1.5', <BaseModelType.StableDiffusion1: 'sd-1'>, <ModelType.Main: 'main'>),
|
||||||
|
('stable-diffusion-2.1', <BaseModelType.StableDiffusion2: 'sd-2'>, <ModelType.Main: 'main'>),
|
||||||
|
('inpaint', <BaseModelType.StableDiffusion1: 'sd-1'>, <ModelType.ControlNet: 'controlnet'>)
|
||||||
|
('Ink scenery', <BaseModelType.StableDiffusion1: 'sd-1'>, <ModelType.Lora: 'lora'>)
|
||||||
|
...
|
||||||
|
]
|
||||||
|
|
||||||
|
The tuple is in the correct order to pass to get_model():
|
||||||
|
|
||||||
|
for m in mgr.model_names():
|
||||||
|
info = get_model(*m)
|
||||||
|
|
||||||
|
In contrast, the list_models() method returns a list of dicts, each
|
||||||
|
providing information about a model defined in models.yaml. For example:
|
||||||
|
|
||||||
|
>>> models = mgr.list_models()
|
||||||
|
>>> json.dumps(models[0])
|
||||||
|
{"path": "/home/lstein/invokeai-main/models/sd-1/controlnet/canny",
|
||||||
|
"model_format": "diffusers",
|
||||||
|
"name": "canny",
|
||||||
|
"base_model": "sd-1",
|
||||||
|
"type": "controlnet"
|
||||||
|
}
|
||||||
|
|
||||||
|
You can filter by model type and base model as shown here:
|
||||||
|
|
||||||
|
|
||||||
|
controlnets = mgr.list_models(model_type=ModelType.ControlNet,
|
||||||
|
base_model=BaseModelType.StableDiffusion1)
|
||||||
|
for c in controlnets:
|
||||||
|
name = c['name']
|
||||||
|
format = c['model_format']
|
||||||
|
path = c['path']
|
||||||
|
type = c['type']
|
||||||
|
# etc
|
||||||
|
|
||||||
|
ADDING AND REMOVING MODELS
|
||||||
|
|
||||||
|
At startup time, the `models` directory will be scanned for
|
||||||
|
checkpoints, diffusers pipelines, controlnets, LoRAs and TI
|
||||||
|
embeddings. New entries will be added to the model manager and defunct
|
||||||
|
ones removed. Anything that is a main model (ModelType.Main) will be
|
||||||
|
added to models.yaml. For scanning to succeed, files need to be in
|
||||||
|
their proper places. For example, a controlnet folder built on the
|
||||||
|
stable diffusion 2 base, will need to be placed in
|
||||||
|
`models/sd-2/controlnet`.
|
||||||
|
|
||||||
|
Layout of the `models` directory:
|
||||||
|
|
||||||
|
models
|
||||||
|
├── sd-1
|
||||||
|
│ ├── controlnet
|
||||||
|
│ ├── lora
|
||||||
|
│ ├── main
|
||||||
|
│ └── embedding
|
||||||
|
├── sd-2
|
||||||
|
│ ├── controlnet
|
||||||
|
│ ├── lora
|
||||||
|
│ ├── main
|
||||||
|
│ └── embedding
|
||||||
|
└── core
|
||||||
|
├── face_reconstruction
|
||||||
|
│ ├── codeformer
|
||||||
|
│ └── gfpgan
|
||||||
|
├── sd-conversion
|
||||||
|
│ ├── clip-vit-large-patch14 - tokenizer, text_encoder subdirs
|
||||||
|
│ ├── stable-diffusion-2 - tokenizer, text_encoder subdirs
|
||||||
|
│ └── stable-diffusion-safety-checker
|
||||||
|
└── upscaling
|
||||||
|
└─── esrgan
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigMeta(BaseModel):Loras, textual_inversion and controlnet models are not listed
|
||||||
|
explicitly in models.yaml, but are added to the in-memory data
|
||||||
|
structure at initialization time by scanning the models directory. The
|
||||||
|
in-memory data structure can be resynchronized by calling
|
||||||
|
`manager.scan_models_directory()`.
|
||||||
|
|
||||||
|
Files and folders placed inside the `autoimport_dir` (path defined in
|
||||||
|
`invokeai.yaml`, defaulting to `ROOTDIR/autoimport` will also be
|
||||||
|
scanned for new models at initialization time and added to
|
||||||
|
`models.yaml`. Files will not be moved from this location but
|
||||||
|
preserved in-place.
|
||||||
|
|
||||||
|
A model can be manually added using `add_model()` using the model's
|
||||||
|
name, base model, type and a dict of model attributes. See
|
||||||
|
`invokeai/backend/model_management/models` for the attributes required
|
||||||
|
by each model type.
|
||||||
|
|
||||||
|
A model can be deleted using `del_model()`, providing the same
|
||||||
|
identifying information as `get_model()`
|
||||||
|
|
||||||
|
The `heuristic_import()` method will take a set of strings
|
||||||
|
corresponding to local paths, remote URLs, and repo_ids, probe the
|
||||||
|
object to determine what type of model it is (if any), and import new
|
||||||
|
models into the manager. If passed a directory, it will recursively
|
||||||
|
scan it for models to import. The return value is a set of the models
|
||||||
|
successfully added.
|
||||||
|
|
||||||
MODELS.YAML
|
MODELS.YAML
|
||||||
|
|
||||||
@ -56,94 +196,18 @@ The general format of a models.yaml section is:
|
|||||||
type-of-model/name-of-model:
|
type-of-model/name-of-model:
|
||||||
path: /path/to/local/file/or/directory
|
path: /path/to/local/file/or/directory
|
||||||
description: a description
|
description: a description
|
||||||
format: folder|ckpt|safetensors|pt
|
format: diffusers|checkpoint
|
||||||
base: SD-1|SD-2
|
variant: normal|inpaint|depth
|
||||||
subfolder: subfolder-name
|
|
||||||
|
|
||||||
The type of model is given in the stanza key, and is one of
|
The type of model is given in the stanza key, and is one of
|
||||||
{diffusers, ckpt, vae, text_encoder, tokenizer, unet, scheduler,
|
{main, vae, lora, controlnet, textual}
|
||||||
safety_checker, feature_extractor, lora, textual_inversion,
|
|
||||||
controlnet}, and correspond to items in the SubModelType enum defined
|
|
||||||
in model_cache.py
|
|
||||||
|
|
||||||
The format indicates whether the model is organized as a folder with
|
The format indicates whether the model is organized as a diffusers
|
||||||
model subdirectories, or is contained in a single checkpoint or
|
folder with model subdirectories, or is contained in a single
|
||||||
safetensors file.
|
checkpoint or safetensors file.
|
||||||
|
|
||||||
One, but not both, of repo_id and path are provided. repo_id is the
|
|
||||||
HuggingFace repository ID of the model, and path points to the file or
|
|
||||||
directory on disk.
|
|
||||||
|
|
||||||
If subfolder is provided, then the model exists in a subdirectory of
|
|
||||||
the main model. These are usually named after the model type, such as
|
|
||||||
"unet".
|
|
||||||
|
|
||||||
This example summarizes the two ways of getting a non-diffuser model:
|
|
||||||
|
|
||||||
text_encoder/clip-test-1:
|
|
||||||
format: folder
|
|
||||||
path: /path/to/folder
|
|
||||||
description: Returns standalone CLIPTextModel
|
|
||||||
|
|
||||||
text_encoder/clip-test-2:
|
|
||||||
format: folder
|
|
||||||
repo_id: /path/to/folder
|
|
||||||
subfolder: text_encoder
|
|
||||||
description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)
|
|
||||||
|
|
||||||
SUBMODELS:
|
|
||||||
|
|
||||||
It is also possible to fetch an isolated submodel from a diffusers
|
|
||||||
model. Use the `submodel` parameter to select which part:
|
|
||||||
|
|
||||||
vae = manager.get_model('stable-diffusion-1.5',submodel=SubModelType.Vae)
|
|
||||||
with vae.context as my_vae:
|
|
||||||
print(type(my_vae))
|
|
||||||
# "AutoencoderKL"
|
|
||||||
|
|
||||||
DIRECTORY_SCANNING:
|
|
||||||
|
|
||||||
Loras, textual_inversion and controlnet models are usually not listed
|
|
||||||
explicitly in models.yaml, but are added to the in-memory data
|
|
||||||
structure at initialization time by scanning the models directory. The
|
|
||||||
in-memory data structure can be resynchronized by calling
|
|
||||||
`manager.scan_models_directory`.
|
|
||||||
|
|
||||||
DISAMBIGUATION:
|
|
||||||
|
|
||||||
You may wish to use the same name for a related family of models. To
|
|
||||||
do this, disambiguate the stanza key with the model and and format
|
|
||||||
separated by "/". Example:
|
|
||||||
|
|
||||||
tokenizer/clip-large:
|
|
||||||
format: tokenizer
|
|
||||||
path: /path/to/folder
|
|
||||||
description: Returns standalone tokenizer
|
|
||||||
|
|
||||||
text_encoder/clip-large:
|
|
||||||
format: text_encoder
|
|
||||||
path: /path/to/folder
|
|
||||||
description: Returns standalone text encoder
|
|
||||||
|
|
||||||
You can now use the `model_type` argument to indicate which model you
|
|
||||||
want:
|
|
||||||
|
|
||||||
tokenizer = mgr.get('clip-large',model_type=SubModelType.Tokenizer)
|
|
||||||
encoder = mgr.get('clip-large',model_type=SubModelType.TextEncoder)
|
|
||||||
|
|
||||||
OTHER FUNCTIONS:
|
|
||||||
|
|
||||||
Other methods provided by ModelManager support importing, editing,
|
|
||||||
converting and deleting models.
|
|
||||||
|
|
||||||
IMPORTANT CHANGES AND LIMITATIONS SINCE 2.3:
|
|
||||||
|
|
||||||
1. Only local paths are supported. Repo_ids are no longer accepted. This
|
|
||||||
simplifies the logic.
|
|
||||||
|
|
||||||
2. VAEs can't be swapped in and out at load time. They must be baked
|
|
||||||
into the model when downloaded or converted.
|
|
||||||
|
|
||||||
|
The path points to a file or directory on disk. If a relative path,
|
||||||
|
the root is the InvokeAI ROOTDIR.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@ -185,7 +249,6 @@ class ModelInfo():
|
|||||||
hash: str
|
hash: str
|
||||||
location: Union[Path, str]
|
location: Union[Path, str]
|
||||||
precision: torch.dtype
|
precision: torch.dtype
|
||||||
revision: str = None
|
|
||||||
_cache: ModelCache = None
|
_cache: ModelCache = None
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
@ -201,31 +264,6 @@ class InvalidModelError(Exception):
|
|||||||
MAX_CACHE_SIZE = 6.0 # GB
|
MAX_CACHE_SIZE = 6.0 # GB
|
||||||
|
|
||||||
|
|
||||||
# layout of the models directory:
|
|
||||||
# models
|
|
||||||
# ├── sd-1
|
|
||||||
# │ ├── controlnet
|
|
||||||
# │ ├── lora
|
|
||||||
# │ ├── pipeline
|
|
||||||
# │ └── textual_inversion
|
|
||||||
# ├── sd-2
|
|
||||||
# │ ├── controlnet
|
|
||||||
# │ ├── lora
|
|
||||||
# │ ├── pipeline
|
|
||||||
# │ └── textual_inversion
|
|
||||||
# └── core
|
|
||||||
# ├── face_reconstruction
|
|
||||||
# │ ├── codeformer
|
|
||||||
# │ └── gfpgan
|
|
||||||
# ├── sd-conversion
|
|
||||||
# │ ├── clip-vit-large-patch14 - tokenizer, text_encoder subdirs
|
|
||||||
# │ ├── stable-diffusion-2 - tokenizer, text_encoder subdirs
|
|
||||||
# │ └── stable-diffusion-safety-checker
|
|
||||||
# └── upscaling
|
|
||||||
# └─── esrgan
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigMeta(BaseModel):
|
class ConfigMeta(BaseModel):
|
||||||
version: str
|
version: str
|
||||||
|
|
||||||
@ -330,44 +368,14 @@ class ModelManager(object):
|
|||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel_type: Optional[SubModelType] = None
|
submodel_type: Optional[SubModelType] = None
|
||||||
):
|
)->ModelInfo:
|
||||||
"""Given a model named identified in models.yaml, return
|
"""Given a model named identified in models.yaml, return
|
||||||
an ModelInfo object describing it.
|
an ModelInfo object describing it.
|
||||||
:param model_name: symbolic name of the model in models.yaml
|
:param model_name: symbolic name of the model in models.yaml
|
||||||
:param model_type: ModelType enum indicating the type of model to return
|
:param model_type: ModelType enum indicating the type of model to return
|
||||||
|
:param base_model: BaseModelType enum indicating the base model used by this model
|
||||||
:param submode_typel: an ModelType enum indicating the portion of
|
:param submode_typel: an ModelType enum indicating the portion of
|
||||||
the model to retrieve (e.g. ModelType.Vae)
|
the model to retrieve (e.g. ModelType.Vae)
|
||||||
|
|
||||||
If not provided, the model_type will be read from the `format` field
|
|
||||||
of the corresponding stanza. If provided, the model_type will be used
|
|
||||||
to disambiguate stanzas in the configuration file. The default is to
|
|
||||||
assume a diffusers pipeline. The behavior is illustrated here:
|
|
||||||
|
|
||||||
[models.yaml]
|
|
||||||
diffusers/test1:
|
|
||||||
repo_id: foo/bar
|
|
||||||
description: Typical diffusers pipeline
|
|
||||||
|
|
||||||
lora/test1:
|
|
||||||
repo_id: /tmp/loras/test1.safetensors
|
|
||||||
description: Typical lora file
|
|
||||||
|
|
||||||
test1_pipeline = mgr.get_model('test1')
|
|
||||||
# returns a StableDiffusionGeneratorPipeline
|
|
||||||
|
|
||||||
test1_vae1 = mgr.get_model('test1', submodel=ModelType.Vae)
|
|
||||||
# returns the VAE part of a diffusers model as an AutoencoderKL
|
|
||||||
|
|
||||||
test1_vae2 = mgr.get_model('test1', model_type=ModelType.Diffusers, submodel=ModelType.Vae)
|
|
||||||
# does the same thing as the previous statement. Note that model_type
|
|
||||||
# is for the parent model, and submodel is for the part
|
|
||||||
|
|
||||||
test1_lora = mgr.get_model('test1', model_type=ModelType.Lora)
|
|
||||||
# returns a LoRA embed (as a 'dict' of tensors)
|
|
||||||
|
|
||||||
test1_encoder = mgr.get_modelI('test1', model_type=ModelType.TextEncoder)
|
|
||||||
# raises an InvalidModelError
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
@ -511,7 +519,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
def print_models(self) -> None:
|
def print_models(self) -> None:
|
||||||
"""
|
"""
|
||||||
Print a table of models, their descriptions
|
Print a table of models and their descriptions. This needs to be redone
|
||||||
"""
|
"""
|
||||||
# TODO: redo
|
# TODO: redo
|
||||||
for model_type, model_dict in self.list_models().items():
|
for model_type, model_dict in self.list_models().items():
|
||||||
@ -552,7 +560,7 @@ class ModelManager(object):
|
|||||||
else:
|
else:
|
||||||
model_path.unlink()
|
model_path.unlink()
|
||||||
|
|
||||||
# TODO: test when ui implemented
|
# LS: tested
|
||||||
def add_model(
|
def add_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -694,11 +702,18 @@ class ModelManager(object):
|
|||||||
items_to_import: Set[str],
|
items_to_import: Set[str],
|
||||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||||
)->Set[str]:
|
)->Set[str]:
|
||||||
'''
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
Import a list of paths, repo_ids or URLs. Returns the
|
successfully imported items.
|
||||||
set of successfully imported items. The prediction_type_helper
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
is a callback that receives the Path of a checkpoint or diffusers
|
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||||
model and returns a SchedulerPredictionType (or None).
|
|
||||||
|
The prediction type helper is necessary to distinguish between
|
||||||
|
models based on Stable Diffusion 2 Base (requiring
|
||||||
|
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||||
|
(requiring SchedulerPredictionType.VPrediction). It is
|
||||||
|
generally impossible to do this programmatically, so the
|
||||||
|
prediction_type_helper usually asks the user to choose.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
# avoid circular import here
|
# avoid circular import here
|
||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
@ -716,6 +731,3 @@ class ModelManager(object):
|
|||||||
|
|
||||||
self.commit()
|
self.commit()
|
||||||
return successfully_installed
|
return successfully_installed
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -126,7 +126,10 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
|
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
fields = inspect.get_annotations(value)
|
if hasattr(inspect,'get_annotations'):
|
||||||
|
fields = inspect.get_annotations(value)
|
||||||
|
else:
|
||||||
|
fields = value.__annotations__
|
||||||
try:
|
try:
|
||||||
field = fields["model_format"]
|
field = fields["model_format"]
|
||||||
except:
|
except:
|
||||||
|
@ -323,7 +323,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
FileBox,
|
FileBox,
|
||||||
max_height=3,
|
max_height=3,
|
||||||
name=label,
|
name=label,
|
||||||
value=str(config.autoconvert_dir) if config.autoconvert_dir else None,
|
value=str(config.autoimport_dir) if config.autoimport_dir else None,
|
||||||
select_dir=True,
|
select_dir=True,
|
||||||
must_exist=True,
|
must_exist=True,
|
||||||
use_two_lines=False,
|
use_two_lines=False,
|
||||||
@ -336,7 +336,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
autoscan_on_startup = self.add_widget_intelligent(
|
autoscan_on_startup = self.add_widget_intelligent(
|
||||||
npyscreen.Checkbox,
|
npyscreen.Checkbox,
|
||||||
name="Scan and import from this directory each time InvokeAI starts",
|
name="Scan and import from this directory each time InvokeAI starts",
|
||||||
value=config.autoconvert_dir is not None,
|
value=config.autoimport_dir is not None,
|
||||||
relx=4,
|
relx=4,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user