add a new method to model_manager that retrieves individual pipeline components (#3120)

This PR introduces a set of ModelManager methods that let you retrieve
the individual components of a Stable Diffusion pipeline model,
including the vae, text_encoder, unet, tokenizer, and so on.

To use:

```
from invokeai.backend import ModelManager

manager = ModelManager('/path/to/models.yaml')

# get the VAE
vae = manager.get_model_vae('stable-diffusion-1.5')

# get the unet
unet = manager.get_model_unet('stable-diffusion-1.5')

# get the tokenizer
tokenizer = manager.get_model_tokenizer('stable-diffusion-1.5')

# ...and so on for the remaining components
feature_extractor = manager.get_model_feature_extractor('stable-diffusion-1.5')
scheduler = manager.get_model_scheduler('stable-diffusion-1.5')
text_encoder = manager.get_model_text_encoder('stable-diffusion-1.5')

# if no model name is provided, defaults to the model currently in the GPU, if any
vae = manager.get_model_vae()
```
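
To show these getters doing something end-to-end, here is a minimal sketch that encodes a prompt with the retrieved tokenizer and text encoder. It assumes a models.yaml entry named `stable-diffusion-1.5`, and the printed shape assumes an SD-1.x text encoder:

```
import torch
from invokeai.backend import ModelManager

manager = ModelManager('/path/to/models.yaml')
tokenizer = manager.get_model_tokenizer('stable-diffusion-1.5')
text_encoder = manager.get_model_text_encoder('stable-diffusion-1.5')

# tokenize to CLIP's fixed context length, then embed
tokens = tokenizer(
    "a photograph of an astronaut riding a horse",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    embeddings = text_encoder(tokens.input_ids.to(text_encoder.device))[0]

print(embeddings.shape)  # torch.Size([1, 77, 768]) for an SD-1.x encoder
```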
Commit e5f8b22a43 by Lincoln Stein, 2023-04-07 01:39:57 -04:00 (committed by GitHub)
2 changed files with 183 additions and 73 deletions

File 1:

```
@@ -7,3 +7,4 @@ from .convert_ckpt_to_diffusers import (
 )
 from .model_manager import ModelManager
```

File 2 (model_manager.py):

```
@@ -1,4 +1,4 @@
-"""
+"""enum
 Manage a cache of Stable Diffusion model files for fast switching.
 They are moved between GPU and CPU as necessary. If CPU memory falls
 below a preset minimum, the least recently used model will be
@@ -15,7 +15,7 @@ import sys
 import textwrap
 import time
 import warnings
-from enum import Enum
+from enum import Enum, auto
 from pathlib import Path
 from shutil import move, rmtree
 from typing import Any, Optional, Union, Callable
@@ -24,8 +24,12 @@ import safetensors
 import safetensors.torch
 import torch
 import transformers
-from diffusers import AutoencoderKL
-from diffusers import logging as dlogging
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    SchedulerMixin,
+    logging as dlogging,
+)
 from huggingface_hub import scan_cache_dir
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
@@ -33,23 +37,44 @@ from picklescan.scanner import scan_file_path
 from invokeai.backend.globals import Globals, global_cache_dir
 
-from ..stable_diffusion import StableDiffusionGeneratorPipeline
+from transformers import (
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPFeatureExtractor,
+)
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+    StableDiffusionSafetyChecker,
+)
+from ..stable_diffusion import (
+    StableDiffusionGeneratorPipeline,
+)
 from ..util import CUDA_DEVICE, ask_user, download_with_resume
 
 
 class SDLegacyType(Enum):
-    V1 = 1
-    V1_INPAINT = 2
-    V2 = 3
-    V2_e = 4
-    V2_v = 5
-    UNKNOWN = 99
+    V1 = auto()
+    V1_INPAINT = auto()
+    V2 = auto()
+    V2_e = auto()
+    V2_v = auto()
+    UNKNOWN = auto()
+
+
+class SDModelComponent(Enum):
+    vae="vae"
+    text_encoder="text_encoder"
+    tokenizer="tokenizer"
+    unet="unet"
+    scheduler="scheduler"
+    safety_checker="safety_checker"
+    feature_extractor="feature_extractor"
+
 
 DEFAULT_MAX_MODELS = 2
 
 
 class ModelManager(object):
-    '''
+    """
     Model manager handles loading, caching, importing, deleting, converting, and editing models.
-    '''
+    """
+
     def __init__(
         self,
         config: OmegaConf | Path,
```
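
Two details in this hunk are worth calling out. `SDLegacyType` moves to `auto()` because its members only need to be distinct; the old hand-assigned numbers carried no meaning. And each `SDModelComponent` member's value is the attribute name of that component on the pipeline object, which is what lets a single `getattr` serve every getter later in the diff. A self-contained sketch of the latter (`FakePipeline` is a stand-in for illustration, not InvokeAI code):

```
from enum import Enum

class SDModelComponent(Enum):
    vae = "vae"
    unet = "unet"

class FakePipeline:
    """Stand-in for StableDiffusionGeneratorPipeline, illustration only."""
    vae = "<AutoencoderKL>"
    unet = "<UNet2DConditionModel>"

# the enum value doubles as the attribute name, so one getattr dispatches all
print(getattr(FakePipeline(), SDModelComponent.unet.value))  # <UNet2DConditionModel>
```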
```
@@ -88,13 +113,23 @@ class ModelManager(object):
         return model_name in self.config
 
     def get_model(self, model_name: str = None) -> dict:
-        """
-        Given a model named identified in models.yaml, return
-        the model object. If in RAM will load into GPU VRAM.
-        If on disk, will load from there.
+        """Given a model named identified in models.yaml, return a dict
+        containing the model object and some of its key features. If
+        in RAM will load into GPU VRAM. If on disk, will load from
+        there.
+
+        The dict has the following keys:
+        'model': The StableDiffusionGeneratorPipeline object
+        'model_name': The name of the model in models.yaml
+        'width': The width of images trained by this model
+        'height': The height of images trained by this model
+        'hash': A unique hash of this model's files on disk.
         """
         if not model_name:
-            return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model())
+            return (
+                self.get_model(self.current_model)
+                if self.current_model
+                else self.get_model(self.default_model())
+            )
 
         if not self.valid_model(model_name):
             print(
```
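
For callers, the new return contract of `get_model()` looks like this (a sketch reusing the assumed `stable-diffusion-1.5` entry):

```
from invokeai.backend import ModelManager

manager = ModelManager('/path/to/models.yaml')
model_info = manager.get_model('stable-diffusion-1.5')

pipeline = model_info["model"]    # StableDiffusionGeneratorPipeline
name = model_info["model_name"]   # 'stable-diffusion-1.5'
width = model_info["width"]       # trained image width, e.g. 512
height = model_info["height"]     # trained image height, e.g. 512
digest = model_info["hash"]       # unique hash of the model files on disk
```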
```
@@ -135,6 +170,81 @@ class ModelManager(object):
             "hash": hash,
         }
 
+    def get_model_vae(self, model_name: str=None)->AutoencoderKL:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned VAE as an
+        AutoencoderKL object. If no model name is provided, return the
+        vae from the model currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.vae)
+
+    def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned CLIPTokenizer. If no
+        model name is provided, return the tokenizer from the model
+        currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.tokenizer)
+
+    def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned UNet2DConditionModel. If no
+        model name is provided, return the UNet from the model
+        currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.unet)
+
+    def get_model_text_encoder(self, model_name: str=None)->CLIPTextModel:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned CLIPTextModel. If no
+        model name is provided, return the text encoder from the model
+        currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.text_encoder)
+
+    def get_model_feature_extractor(self, model_name: str=None)->CLIPFeatureExtractor:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned CLIPFeatureExtractor. If no
+        model name is provided, return the feature extractor from the model
+        currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.feature_extractor)
+
+    def get_model_scheduler(self, model_name: str=None)->SchedulerMixin:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned scheduler. If no
+        model name is provided, return the scheduler from the model
+        currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.scheduler)
+
+    def _get_sub_model(
+        self,
+        model_name: str=None,
+        model_part: SDModelComponent=SDModelComponent.vae,
+    ) -> Union[
+        AutoencoderKL,
+        CLIPTokenizer,
+        CLIPFeatureExtractor,
+        UNet2DConditionModel,
+        CLIPTextModel,
+        StableDiffusionSafetyChecker,
+    ]:
+        """Given a model name identified in models.yaml, and the part of the
+        model you wish to retrieve, return that part. Parts are in an Enum
+        class named SDModelComponent, and consist of:
+        SDModelComponent.vae
+        SDModelComponent.text_encoder
+        SDModelComponent.tokenizer
+        SDModelComponent.unet
+        SDModelComponent.scheduler
+        SDModelComponent.safety_checker
+        SDModelComponent.feature_extractor
+        """
+        model_dict = self.get_model(model_name)
+        model = model_dict["model"]
+        return getattr(model, model_part.value)
+
     def default_model(self) -> str | None:
         """
         Returns the name of the default model, or None
```
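
As a concrete use of one of these getters, the sketch below decodes a random latent batch with the VAE fetched through the new API. The `0.18215` latent scaling factor is the standard SD-1.x value, and the model name is again an assumed models.yaml entry:

```
import torch
from invokeai.backend import ModelManager

manager = ModelManager('/path/to/models.yaml')
vae = manager.get_model_vae('stable-diffusion-1.5')

# one random 64x64 latent; SD-1.x latents have 4 channels
latents = torch.randn(1, 4, 64, 64, device=vae.device, dtype=vae.dtype)
with torch.no_grad():
    image = vae.decode(latents / 0.18215).sample  # undo SD latent scaling

print(image.shape)  # torch.Size([1, 3, 512, 512])
```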
```
@@ -454,14 +564,18 @@ class ModelManager(object):
         from . import load_pipeline_from_original_stable_diffusion_ckpt
 
         try:
-            if self.list_models()[self.current_model]['status'] == 'active':
+            if self.list_models()[self.current_model]["status"] == "active":
                 self.offload_model(self.current_model)
         except Exception as e:
             pass
 
         vae_path = None
         if vae:
-            vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
+            vae_path = (
+                vae
+                if os.path.isabs(vae)
+                else os.path.normpath(os.path.join(Globals.root, vae))
+            )
         if self._has_cuda():
             torch.cuda.empty_cache()
         pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
```
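
The reformatted `vae_path` expression encodes a small rule worth stating: absolute VAE paths pass through untouched, while relative ones are resolved against `Globals.root`. A standalone sketch with an assumed root:

```
import os

root = '/opt/invokeai'  # stand-in for Globals.root

for vae in ('/abs/path/vae.pt', 'models/vae/vae.ckpt'):
    vae_path = (
        vae
        if os.path.isabs(vae)
        else os.path.normpath(os.path.join(root, vae))
    )
    print(vae_path)
# /abs/path/vae.pt
# /opt/invokeai/models/vae/vae.ckpt
```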
```
@@ -571,9 +685,7 @@ class ModelManager(object):
         models.yaml file.
         """
         model_name = model_name or Path(repo_or_path).stem
-        model_description = (
-            description or f"Imported diffusers model {model_name}"
-        )
+        model_description = description or f"Imported diffusers model {model_name}"
         new_config = dict(
             description=model_description,
             vae=vae,
@@ -602,7 +714,7 @@ class ModelManager(object):
         SDLegacyType.V2_v (V2 using 'v_prediction' prediction type)
         SDLegacyType.UNKNOWN
         """
-        global_step = checkpoint.get('global_step')
+        global_step = checkpoint.get("global_step")
         state_dict = checkpoint.get("state_dict") or checkpoint
 
         try:
```
```
@@ -761,19 +873,16 @@ class ModelManager(object):
         elif model_type == SDLegacyType.V1_INPAINT:
             print(" | SD-v1 inpainting model detected")
             model_config_file = Path(
-                Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
+                Globals.root,
+                "configs/stable-diffusion/v1-inpainting-inference.yaml",
             )
 
         elif model_type == SDLegacyType.V2_v:
-            print(
-                " | SD-v2-v model detected"
-            )
+            print(" | SD-v2-v model detected")
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
             )
         elif model_type == SDLegacyType.V2_e:
-            print(
-                " | SD-v2-e model detected"
-            )
+            print(" | SD-v2-e model detected")
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v2-inference.yaml"
             )
```
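
The elif chain being reformatted here is effectively a lookup from detected legacy type to a bundled inference config; written as a table it reads as follows (an illustrative restructuring, not the committed code):

```
from pathlib import Path

# mapping implied by the elif chain above (illustrative only)
LEGACY_CONFIGS = {
    "V1_INPAINT": "configs/stable-diffusion/v1-inpainting-inference.yaml",
    "V2_v": "configs/stable-diffusion/v2-inference-v.yaml",
    "V2_e": "configs/stable-diffusion/v2-inference.yaml",
}

root = Path('/opt/invokeai')  # stand-in for Globals.root
print(root / LEGACY_CONFIGS["V2_v"])
```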
```
@@ -976,15 +1085,15 @@ class ModelManager(object):
         legacy_locations = [
             Path(
                 models_dir,
-                "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker"
+                "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker",
             ),
             Path(models_dir, "bert-base-uncased/models--bert-base-uncased"),
             Path(
                 models_dir,
-                "openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14"
+                "openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
             ),
         ]
-        legacy_locations.extend(list(global_cache_dir("diffusers").glob('*')))
+        legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
 
         legacy_layout = False
         for model in legacy_locations:
@@ -1003,7 +1112,7 @@ class ModelManager(object):
         >> make adjustments, please press ctrl-C now to abort and relaunch InvokeAI when you are ready.
         >> Otherwise press <enter> to continue."""
         )
-        input('continue> ')
+        input("continue> ")
 
         # transformer files get moved into the hub directory
         if cls._is_huggingface_hub_directory_present():
```