add a new method to model_manager that retrieves individual pipeline parts

- New method is ModelManager.get_sub_model(model_name:str,model_part:SDModelComponent)

To use:

```
from invokeai.backend import ModelManager, SDModelComponent as sdmc
manager = ModelManager('/path/to/models.yaml')
vae = manager.get_sub_model('stable-diffusion-1.5', sdmc.vae)
```
This commit is contained in:
Lincoln Stein 2023-04-05 17:25:42 -04:00
parent c4e6511a59
commit d44151d6ff
3 changed files with 136 additions and 75 deletions

View File

@ -10,7 +10,7 @@ from .generator import (
Img2Img,
Inpaint
)
from .model_management import ModelManager
from .model_management import ModelManager, SDModelComponent
from .safety_checker import SafetyChecker
from .args import Args
from .globals import Globals

View File

@ -5,5 +5,6 @@ from .convert_ckpt_to_diffusers import (
convert_ckpt_to_diffusers,
load_pipeline_from_original_stable_diffusion_ckpt,
)
from .model_manager import ModelManager
from .model_manager import ModelManager,SDModelComponent

View File

@ -1,4 +1,4 @@
"""
"""enum
Manage a cache of Stable Diffusion model files for fast switching.
They are moved between GPU and CPU as necessary. If CPU memory falls
below a preset minimum, the least recently used model will be
@ -15,7 +15,7 @@ import sys
import textwrap
import time
import warnings
from enum import Enum
from enum import Enum, auto
from pathlib import Path
from shutil import move, rmtree
from typing import Any, Optional, Union, Callable
@ -24,8 +24,11 @@ import safetensors
import safetensors.torch
import torch
import transformers
from diffusers import AutoencoderKL
from diffusers import logging as dlogging
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
logging as dlogging,
)
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
@ -33,37 +36,58 @@ from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from transformers import (
CLIPTextModel,
CLIPTokenizer,
CLIPFeatureExtractor,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from ..stable_diffusion import (
StableDiffusionGeneratorPipeline,
)
from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
V1 = 1
V1_INPAINT = 2
V2 = 3
V2_e = 4
V2_v = 5
UNKNOWN = 99
class SDLegacyType(Enum):
V1 = auto()
V1_INPAINT = auto()
V2 = auto()
V2_e = auto()
V2_v = auto()
UNKNOWN = auto()
class SDModelComponent(Enum):
vae="vae"
text_encoder="text_encoder"
tokenizer="tokenizer"
unet="unet"
scheduler="scheduler"
safety_checker="safety_checker"
feature_extractor="feature_extractor"
DEFAULT_MAX_MODELS = 2
class ModelManager(object):
'''
"""
Model manager handles loading, caching, importing, deleting, converting, and editing models.
'''
"""
def __init__(
self,
config: OmegaConf|Path,
device_type: torch.device = CUDA_DEVICE,
precision: str = "float16",
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload=False,
embedding_path: Path=None,
self,
config: OmegaConf | Path,
device_type: torch.device = CUDA_DEVICE,
precision: str = "float16",
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload=False,
embedding_path: Path = None,
):
"""
Initialize with the path to the models.yaml config file or
an initialized OmegaConf dictionary. Optional parameters
are the torch device type, precision, max_loaded_models,
and sequential_offload boolean. Note that the default device
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
# prevent nasty-looking CLIP log message
@ -87,15 +111,25 @@ class ModelManager(object):
"""
return model_name in self.config
def get_model(self, model_name: str=None)->dict:
"""
Given a model named identified in models.yaml, return
the model object. If in RAM will load into GPU VRAM.
If on disk, will load from there.
def get_model(self, model_name: str = None) -> dict:
"""Given a model named identified in models.yaml, return a dict
containing the model object and some of its key features. If
in RAM will load into GPU VRAM. If on disk, will load from
there.
The dict has the following keys:
'model': The StableDiffusionGeneratorPipeline object
'model_name': The name of the model in models.yaml
'width': The width of images trained by this model
'height': The height of images trained by this model
'hash': A unique hash of this model's files on disk.
"""
if not model_name:
return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model())
return (
self.get_model(self.current_model)
if self.current_model
else self.get_model(self.default_model())
)
if not self.valid_model(model_name):
print(
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
@ -135,6 +169,33 @@ class ModelManager(object):
"hash": hash,
}
def get_sub_model(
self,
model_name: str=None,
model_part: SDModelComponent=SDModelComponent.vae,
) -> Union[
AutoencoderKL,
CLIPTokenizer,
CLIPFeatureExtractor,
UNet2DConditionModel,
CLIPTextModel,
StableDiffusionSafetyChecker,
]:
"""Given a model named identified in models.yaml, and the part of the
model you wish to retrieve, return that part. Parts are in an Enum
class named SDModelComponent, and consist of:
SDModelComponent.vae
SDModelComponent.text_encoder
SDModelComponent.tokenizer
SDModelComponent.unet
SDModelComponent.scheduler
SDModelComponent.safety_checker
SDModelComponent.feature_etractor
"""
model_dict = self.get_model(model_name)
model = model_dict["model"]
return getattr(model, model_part.value)
def default_model(self) -> str | None:
"""
Returns the name of the default model, or None
@ -360,7 +421,7 @@ class ModelManager(object):
f"Unknown model format {model_name}: {model_format}"
)
self._add_embeddings_to_model(model)
# usage statistics
toc = time.time()
print(">> Model loaded in", "%4.2fs" % (toc - tic))
@ -433,7 +494,7 @@ class ModelManager(object):
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
height = width
print(f" | Default image dimensions = {width} x {height}")
return pipeline, width, height, model_hash
def _load_ckpt_model(self, model_name, mconfig):
@ -454,14 +515,18 @@ class ModelManager(object):
from . import load_pipeline_from_original_stable_diffusion_ckpt
try:
if self.list_models()[self.current_model]['status'] == 'active':
if self.list_models()[self.current_model]["status"] == "active":
self.offload_model(self.current_model)
except Exception as e:
pass
vae_path = None
if vae:
vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
vae_path = (
vae
if os.path.isabs(vae)
else os.path.normpath(os.path.join(Globals.root, vae))
)
if self._has_cuda():
torch.cuda.empty_cache()
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
@ -571,9 +636,7 @@ class ModelManager(object):
models.yaml file.
"""
model_name = model_name or Path(repo_or_path).stem
model_description = (
description or f"Imported diffusers model {model_name}"
)
model_description = description or f"Imported diffusers model {model_name}"
new_config = dict(
description=model_description,
vae=vae,
@ -602,7 +665,7 @@ class ModelManager(object):
SDLegacyType.V2_v (V2 using 'v_prediction' prediction type)
SDLegacyType.UNKNOWN
"""
global_step = checkpoint.get('global_step')
global_step = checkpoint.get("global_step")
state_dict = checkpoint.get("state_dict") or checkpoint
try:
@ -628,13 +691,13 @@ class ModelManager(object):
return SDLegacyType.UNKNOWN
def heuristic_import(
self,
path_url_or_repo: str,
model_name: str = None,
description: str = None,
model_config_file: Path = None,
commit_to_conf: Path = None,
config_file_callback: Callable[[Path], Path] = None,
self,
path_url_or_repo: str,
model_name: str = None,
description: str = None,
model_config_file: Path = None,
commit_to_conf: Path = None,
config_file_callback: Callable[[Path], Path] = None,
) -> str:
"""Accept a string which could be:
- a HF diffusers repo_id
@ -738,8 +801,8 @@ class ModelManager(object):
# another round of heuristics to guess the correct config file.
checkpoint = None
if model_path.suffix in [".ckpt",".pt"]:
self.scan_model(model_path,model_path)
if model_path.suffix in [".ckpt", ".pt"]:
self.scan_model(model_path, model_path)
checkpoint = torch.load(model_path)
else:
checkpoint = safetensors.torch.load_file(model_path)
@ -761,19 +824,16 @@ class ModelManager(object):
elif model_type == SDLegacyType.V1_INPAINT:
print(" | SD-v1 inpainting model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
Globals.root,
"configs/stable-diffusion/v1-inpainting-inference.yaml",
)
elif model_type == SDLegacyType.V2_v:
print(
" | SD-v2-v model detected"
)
print(" | SD-v2-v model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
)
elif model_type == SDLegacyType.V2_e:
print(
" | SD-v2-e model detected"
)
print(" | SD-v2-e model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
)
@ -820,16 +880,16 @@ class ModelManager(object):
return model_name
def convert_and_import(
self,
ckpt_path: Path,
diffusers_path: Path,
model_name=None,
model_description=None,
vae:dict=None,
vae_path:Path=None,
original_config_file: Path = None,
commit_to_conf: Path = None,
scan_needed: bool=True,
self,
ckpt_path: Path,
diffusers_path: Path,
model_name=None,
model_description=None,
vae: dict = None,
vae_path: Path = None,
original_config_file: Path = None,
commit_to_conf: Path = None,
scan_needed: bool = True,
) -> str:
"""
Convert a legacy ckpt weights file to diffuser model and import
@ -857,10 +917,10 @@ class ModelManager(object):
try:
# By passing the specified VAE to the conversion function, the autoencoder
# will be built into the model rather than tacked on afterward via the config file
vae_model=None
vae_model = None
if vae:
vae_model=self._load_vae(vae)
vae_path=None
vae_model = self._load_vae(vae)
vae_path = None
convert_ckpt_to_diffusers(
ckpt_path,
diffusers_path,
@ -976,16 +1036,16 @@ class ModelManager(object):
legacy_locations = [
Path(
models_dir,
"CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker"
"CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker",
),
Path(models_dir, "bert-base-uncased/models--bert-base-uncased"),
Path(
models_dir,
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14"
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
),
]
legacy_locations.extend(list(global_cache_dir("diffusers").glob('*')))
legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
legacy_layout = False
for model in legacy_locations:
legacy_layout = legacy_layout or model.exists()
@ -1003,7 +1063,7 @@ class ModelManager(object):
>> make adjustments, please press ctrl-C now to abort and relaunch InvokeAI when you are ready.
>> Otherwise press <enter> to continue."""
)
input('continue> ')
input("continue> ")
# transformer files get moved into the hub directory
if cls._is_huggingface_hub_directory_present():
@ -1090,7 +1150,7 @@ class ModelManager(object):
print(
f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
)
def _has_cuda(self) -> bool:
return self.device.type == "cuda"