convert add_model(), del_model(), list_models() etc to use bifurcated names

Lincoln Stein
2023-05-13 14:44:44 -04:00
parent bc96727cbe
commit 72967bf118
6 changed files with 217 additions and 128 deletions

View File

@@ -227,7 +227,7 @@ class Generate:
# don't accept invalid models
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
model = model or fallback
if not self.model_manager.valid_model(model):
if not self.model_manager.model_exists(model):
logger.warning(
f'"{model}" is not a known model name; falling back to {fallback}.'
)
@@ -877,7 +877,7 @@ class Generate:
# the model cache does the loading and offloading
cache = self.model_manager
if not cache.valid_model(model_name):
if not cache.model_exists(model_name):
raise KeyError(
f'** "{model_name}" is not a known model name. Cannot change.'
)

View File

@@ -9,6 +9,7 @@ return a SDModelInfo object that contains the following attributes:
model into GPU VRAM and returns the model for use.
See below for usage.
* name -- symbolic name of the model
* type -- SDModelType of the model
* hash -- unique hash for the model
* location -- path or repo_id of the model
* revision -- revision of the model if coming from a repo id,
@@ -26,7 +27,7 @@ Typical usage:
max_cache_size=8
) # gigabytes
model_info = manager.get_model('stable-diffusion-1.5')
model_info = manager.get_model('stable-diffusion-1.5', SDModelType.diffusers)
with model_info.context as my_model:
my_model.latents_from_embeddings(...)
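As a sketch extending this usage (attribute names as listed at the top of
this docstring; per those notes, the model is only moved into GPU VRAM when
the context is entered):
model_info = manager.get_model('stable-diffusion-1.5', SDModelType.diffusers)
print(model_info.name, model_info.type, model_info.hash, model_info.location)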
@@ -54,15 +55,20 @@ MODELS.YAML
The general format of a models.yaml section is:
name-of-model:
format: diffusers|ckpt|vae|text_encoder|tokenizer...
type-of-model/name-of-model:
format: folder|ckpt|safetensors
repo_id: owner/repo
path: /path/to/local/file/or/directory
subfolder: subfolder-name
The format is one of {diffusers, ckpt, vae, text_encoder, tokenizer,
unet, scheduler, safety_checker, feature_extractor}, and correspond to
items in the SDModelType enum defined in model_cache.py
The type of the model is given in the stanza key and is one of
{diffusers, ckpt, vae, text_encoder, tokenizer, unet, scheduler,
safety_checker, feature_extractor, lora, textual_inversion}; these
correspond to items in the SDModelType enum defined in model_cache.py.
The format indicates whether the model is organized as a folder with
model subdirectories, or is contained in a single checkpoint or
safetensors file.
One, but not both, of repo_id and path is provided. repo_id is the
HuggingFace repository ID of the model, and path points to the file or
@@ -74,13 +80,13 @@ the main model. These are usually named after the model type, such as
This example summarizes the two ways of getting a non-diffuser model:
clip-test-1:
format: text_encoder
text_encoder/clip-test-1:
format: folder
repo_id: openai/clip-vit-large-patch14
description: Returns standalone CLIPTextModel
clip-test-2:
format: text_encoder
text_encoder/clip-test-2:
format: folder
repo_id: stabilityai/stable-diffusion-2
subfolder: text_encoder
description: Returns the text_encoder in the subfolder of the diffusers model (just the encoder in RAM)
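A sketch of fetching either stanza under the new scheme (the subfolder
indirection is resolved from the models.yaml entry, not by the caller):
model_info = manager.get_model('clip-test-2', SDModelType.text_encoder)
with model_info.context as encoder:
... # per the description above, only the text encoder is loaded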
@@ -101,12 +107,12 @@ You may wish to use the same name for a related family of models. To
do this, disambiguate the stanza key with the model type and name
separated by "/". Example:
clip-large/tokenizer:
tokenizer/clip-large:
format: tokenizer
repo_id: openai/clip-vit-large-patch14
description: Returns standalone tokenizer
clip-large/text_encoder:
text_encoder/clip-large:
format: text_encoder
repo_id: openai/clip-vit-large-patch14
description: Returns standalone text encoder
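A sketch of the disambiguation in use; the same symbolic name now resolves
per model type:
tokenizer_info = manager.get_model('clip-large', SDModelType.tokenizer)
encoder_info = manager.get_model('clip-large', SDModelType.text_encoder)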
@@ -128,31 +134,41 @@ from __future__ import annotations
import os
import re
import textwrap
from contextlib import suppress
from dataclasses import dataclass
from enum import Enum, auto
from packaging import version
from pathlib import Path
from shutil import rmtree
from typing import Union, Callable, types
from contextlib import suppress
from typing import Callable, Optional, List, Tuple, Union, types
import safetensors
import safetensors.torch
import torch
import invokeai.backend.util.logging as logger
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path
from .model_cache import ModelCache, ModelLocker, SDModelType, ModelStatus, SilenceWarnings
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import (Globals, global_cache_dir,
global_resolve_path)
from invokeai.backend.util import download_with_resume
from ..util import CUDA_DEVICE
from .model_cache import (ModelCache, ModelLocker, ModelStatus, SDModelType,
SilenceWarnings)
# We are only starting to number the config file with release 3.
# The config file version doesn't have to track the release version, but doing
# so helps reduce confusion.
CONFIG_FILE_VERSION='3.0.0'
# wanted to use pydantic here, but Generator objects not supported
@dataclass
class SDModelInfo():
context: ModelLocker
name: str
type: SDModelType
hash: str
location: Union[Path,str]
precision: torch.dtype
@@ -208,14 +224,17 @@ class ModelManager(object):
type and precision are set up for a CUDA system running at half precision.
"""
if isinstance(config, DictConfig):
self.config = config
self.config_path = None
self.config = config
elif isinstance(config,(str,Path)):
self.config_path = config
self.config = OmegaConf.load(self.config_path)
else:
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
# check config version number and update on disk/RAM if necessary
self._update_config_file_version()
self.cache = ModelCache(
max_cache_size=max_cache_size,
execution_device = device_type,
@@ -226,8 +245,7 @@ class ModelManager(object):
self.cache_keys = dict()
self.logger = logger
# TODO: rename to smth like - is_model_exists
def valid_model(
def model_exists(
self,
model_name: str,
model_type: SDModelType = SDModelType.diffusers,
@@ -246,14 +264,14 @@ class ModelManager(object):
model_type_str, model_name = model_key.split('/', 1)
if model_type_str not in SDModelType.__members__:
# TODO:
raise Exception(f"Unkown model type: {model_type_str}")
raise Exception(f"Unknown model type: {model_type_str}")
return (model_name, SDModelType[model_type_str])
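# Sketch (assumes create_key joins the parts with '/'):
# parse_key('tokenizer/clip-large') -> ('clip-large', SDModelType.tokenizer)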
def get_model(
self,
model_name: str,
model_type: SDModelType=None,
model_type: SDModelType=SDModelType.diffusers,
submodel: SDModelType=None,
) -> SDModelInfo:
"""Given a model named identified in models.yaml, return
@@ -311,7 +329,7 @@ class ModelManager(object):
if model_type == SDModelType.diffusers:
# intercept stanzas that point to checkpoint weights and replace them
# with the equivalent diffusers model
if 'weights' in mconfig:
if mconfig.format in ["ckpt", "diffusers"]:
location = self.convert_ckpt_and_cache(mconfig)
else:
location = global_resolve_path(mconfig.get('path')) or mconfig.get('repo_id')
@@ -350,6 +368,7 @@ class ModelManager(object):
return SDModelInfo(
context = model_context,
name = model_name,
type = submodel or model_type,
hash = hash,
location = location,
revision = revision,
@@ -358,43 +377,50 @@ class ModelManager(object):
_cache = self.cache
)
def default_model(self) -> Union[str,None]:
def default_model(self) -> Union[Tuple[str, SDModelType], None]:
"""
Returns a (model name, model type) tuple for the default model, or None
if none is defined.
"""
for model_name in self.config:
if self.config[model_name].get("default"):
return model_name
return list(self.config.keys())[0] # first one
for model_name, model_type in self.model_names():
model_key = self.create_key(model_name, model_type)
if self.config[model_key].get("default"):
return (model_name, model_type)
return self.model_names()[0] # first one
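# Sketch: callers now unpack a (name, type) pair, e.g.
# name, model_type = manager.default_model() or (FALLBACK_MODEL_NAME, SDModelType.diffusers)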
def set_default_model(self, model_name: str) -> None:
def set_default_model(self, model_name: str, model_type: SDModelType=SDModelType.diffusers) -> None:
"""
Set the default model. The change will not take
effect until you call model_manager.commit()
"""
assert model_name in self.model_names(), f"unknown model '{model_name}'"
assert self.model_exists(model_name, model_type), f"unknown model '{model_name}'"
config = self.config
for model in config:
config[model].pop("default", None)
config[model_name]["default"] = True
# use separate loop variables so the requested model_name/model_type are not shadowed
for existing_name, existing_type in self.model_names():
key = self.create_key(existing_name, existing_type)
config[key].pop("default", None)
config[self.create_key(model_name, model_type)]["default"] = True
def model_info(self, model_name: str) -> dict:
def model_info(
self,
model_name: str,
model_type: SDModelType=SDModelType.diffusers
) -> dict:
"""
Given a model name returns the OmegaConf (dict-like) object describing it.
"""
if model_name not in self.config:
if not self.model_exists(model_name, model_type):
return None
return self.config[model_name]
return self.config[self.create_key(model_name,model_type)]
def model_names(self) -> list[str]:
def model_names(self) -> List[Tuple[str, SDModelType]]:
"""
Return a list consisting of all the names of models defined in models.yaml
Return a list of (str, SDModelType) corresponding to all models
known to the configuration.
"""
return list(self.config.keys())
return [self.parse_key(x) for x in self.config.keys() if isinstance(self.config[x], DictConfig)]
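# Sketch of a possible return value:
# [('stable-diffusion-1.5', SDModelType.diffusers), ('clip-large', SDModelType.tokenizer)]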
def is_legacy(self, model_name: str) -> bool:
def is_legacy(self, model_name: str, model_type: SDModelType = SDModelType.diffusers) -> bool:
"""
Return true if this is a legacy (.ckpt) model
"""
@@ -402,7 +428,7 @@ class ModelManager(object):
# there are no legacy ckpts!
if Globals.ckpt_convert:
return False
info = self.model_info(model_name)
info = self.model_info(model_name, model_type)
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
return True
return False
@@ -422,15 +448,17 @@ class ModelManager(object):
# don't include VAEs in listing (legacy style)
if "config" in stanza and "/VAE/" in stanza["config"]:
continue
if model_key=='config_file_version':
continue
model_name, model_type = self.parse_key(model_key)
models[model_name] = dict()
models[model_key] = dict()
# TODO: return all models in future
if model_type != SDModelType.diffusers:
continue
model_format = "ckpt" if "weights" in stanza else "diffusers"
model_format = stanza.get('format')
# Common Attribs
status = self.cache.status(
@@ -439,26 +467,25 @@ class ModelManager(object):
subfolder=stanza.get('subfolder')
)
description = stanza.get("description", None)
models[model_name].update(
description=description,
type=model_type,
models[model_key].update(
model_name=model_name,
model_type=model_type.name,
format=model_format,
description=description,
status=status.value
)
# Checkpoint Config Parse
if model_format == "ckpt":
models[model_name].update(
if model_format in ["ckpt","safetensors"]:
models[model_key].update(
config = str(stanza.get("config", None)),
weights = str(stanza.get("weights", None)),
vae = str(stanza.get("vae", None)),
width = str(stanza.get("width", 512)),
height = str(stanza.get("height", 512)),
)
# Diffusers Config Parse
elif model_format == "diffusers":
elif model_format == "folder":
if vae := stanza.get("vae", None):
if isinstance(vae, DictConfig):
vae = dict(
@@ -467,7 +494,7 @@ class ModelManager(object):
subfolder = str(vae.get("subfolder", None)),
)
models[model_name].update(
models[model_key].update(
vae = vae,
repo_id = str(stanza.get("repo_id", None)),
path = str(stanza.get("path", None)),
@@ -479,12 +506,9 @@ class ModelManager(object):
"""
Print a table of models, their descriptions, and load status
"""
models = self.list_models()
for name in models:
if models[name]["type"] == "vae":
continue
line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["type"]:10s} {models[name]["description"]}'
if models[name]["status"] == "active":
for model_key, model_info in self.list_models().items():
line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}'
if model_info["status"] in ["in gpu","locked in gpu"]:
line = f"\033[1m{line}\033[0m"
print(line)
@@ -511,9 +535,9 @@ class ModelManager(object):
# self.stack.remove(model_name)
if delete_files:
repo_id = conf.get("repo_id", None)
path = self._abs_path(conf.get("path", None))
weights = self._abs_path(conf.get("weights", None))
repo_id = model_cfg.get("repo_id", None)
path = self._abs_path(model_cfg.get("path", None))
weights = self._abs_path(model_cfg.get("weights", None))
if "weights" in model_cfg:
weights = self._abs_path(model_cfg["weights"])
self.logger.info(f"Deleting file {weights}")
@@ -558,7 +582,7 @@ class ModelManager(object):
), 'model must have either the "path" or "repo_id" fields defined'
elif model_format == "ckpt":
for field in ("description", "weights", "height", "width", "config"):
for field in ("description", "weights", "config"):
assert field in model_attributes, f"required field {field} is missing"
else:
@@ -579,11 +603,6 @@ class ModelManager(object):
self.cache.uncache_model(self.cache_keys[model_key])
del self.cache_keys[model_key]
def import_diffuser_model(
self,
repo_or_path: Union[str, Path],
@@ -604,7 +623,6 @@ class ModelManager(object):
models.yaml file.
"""
model_name = model_name or Path(repo_or_path).stem
model_key = f'{model_name}/diffusers'
model_description = description or f"Imported diffusers model {model_name}"
new_config = dict(
description=model_description,
@@ -616,10 +634,10 @@ class ModelManager(object):
else:
new_config.update(repo_id=repo_or_path)
self.add_model(model_key, new_config, True)
self.add_model(model_name, SDModelType.diffusers, new_config, True)
if commit_to_conf:
self.commit(commit_to_conf)
return model_key
return self.create_key(model_name, SDModelType.diffusers)
def import_lora(
self,
@@ -635,7 +653,8 @@ class ModelManager(object):
model_name = model_name or path.stem
model_description = description or f"LoRA model {model_name}"
self.add_model(
f'{model_name}/{SDModelType.lora.name}',
model_name,
SDModelType.lora,
dict(
format="lora",
weights=str(path),
@@ -663,7 +682,8 @@ class ModelManager(object):
model_name = model_name or path.stem
model_description = description or f"Textual embedding model {model_name}"
self.add_model(
f'{model_name}/{SDModelType.textual_inversion.name}',
model_name,
SDModelType.textual_inversion,
dict(
format="textual_inversion",
weights=str(weights),
@@ -1025,9 +1045,14 @@ class ModelManager(object):
description=model_description,
format="diffusers",
)
if model_name in self.config:
self.del_model(model_name)
self.add_model(model_name, new_config, True)
if self.model_exists(model_name, SDModelType.diffusers):
self.del_model(model_name, SDModelType.diffusers)
self.add_model(
model_name,
SDModelType.diffusers,
new_config,
True
)
if commit_to_conf:
self.commit(commit_to_conf)
self.logger.debug("Conversion succeeded")
@@ -1081,11 +1106,6 @@ class ModelManager(object):
"""\
# This file describes the alternative machine learning models
# available to InvokeAI script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
"""
)
@@ -1130,3 +1150,53 @@ class ModelManager(object):
source = os.path.join(Globals.root, source)
resolved_path = Path(source)
return resolved_path
def _update_config_file_version(self):
"""
This gets called at object init time and will update
from older versions of the config file to new ones
as necessary.
"""
current_version = self.config.get("config_file_version","1.0.0")
if version.parse(current_version) < version.parse(CONFIG_FILE_VERSION):
self.logger.info(f'models.yaml version {current_version} detected. Updating to {CONFIG_FILE_VERSION}')
new_config = OmegaConf.create()
new_config["config_file_version"] = CONFIG_FILE_VERSION
for model_key in self.config:
old_stanza = self.config[model_key]
# ignore old and ugly way of associating a legacy
# vae with a legacy checkpoint model
if old_stanza.get("config") and '/VAE/' in old_stanza.get("config"):
continue
# bare keys are updated to be prefixed with 'diffusers/'
if '/' not in model_key:
new_key = f'diffusers/{model_key}'
else:
new_key = model_key
if old_stanza.get('format') == 'diffusers':
model_format = 'folder'
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.ckpt':
model_format = 'ckpt'
elif old_stanza.get('weights') and Path(old_stanza.get('weights')).suffix == '.safetensors':
model_format = 'safetensors'
else:
# otherwise keep the stanza's existing format string so model_format is always bound
model_format = old_stanza.get('format')
# copy fields over manually rather than doing a copy() or deepcopy()
# in order to avoid bringing in unwanted fields.
new_config[new_key] = dict(
description = old_stanza.get('description'),
format = model_format,
)
for field in ["repo_id", "path", "weights", "config", "vae"]:
if field_value := old_stanza.get(field):
new_config[new_key].update({field: field_value})
self.config = new_config
if self.config_path:
self.commit()
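To illustrate the migration, a sketch of a typical version-1 stanza before
and after (the repo_id here is illustrative, not taken from the commit):
Before:
stable-diffusion-1.5:
  format: diffusers
  repo_id: runwayml/stable-diffusion-v1-5
  description: Stable Diffusion version 1.5
After:
config_file_version: 3.0.0
diffusers/stable-diffusion-1.5:
  format: folder
  repo_id: runwayml/stable-diffusion-v1-5
  description: Stable Diffusion version 1.5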