Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
add a new method to model_manager that retrieves individual pipeline components (#3120)
This PR introduces a new set of ModelManager methods that enables you to retrieve the individual parts of a Stable Diffusion pipeline model, including the vae, text_encoder, unet, tokenizer, etc. To use:

```python
from invokeai.backend import ModelManager

manager = ModelManager('/path/to/models.yaml')

# get the VAE
vae = manager.get_model_vae('stable-diffusion-1.5')

# get the unet
unet = manager.get_model_unet('stable-diffusion-1.5')

# get the tokenizer
tokenizer = manager.get_model_tokenizer('stable-diffusion-1.5')

# etc etc
feature_extractor = manager.get_model_feature_extractor('stable-diffusion-1.5')
scheduler = manager.get_model_scheduler('stable-diffusion-1.5')
text_encoder = manager.get_model_text_encoder('stable-diffusion-1.5')

# if no model name is provided, the methods default to the model currently in the GPU, if any
vae = manager.get_model_vae()
```
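The retrieved components compose like any other `transformers`/`diffusers` objects. Below is a minimal sketch (not from this PR) of embedding a prompt with the retrieved tokenizer and text encoder; the model name and paths are illustrative:

```python
import torch

from invokeai.backend import ModelManager

manager = ModelManager('/path/to/models.yaml')
tokenizer = manager.get_model_tokenizer('stable-diffusion-1.5')
text_encoder = manager.get_model_text_encoder('stable-diffusion-1.5')

# standard CLIP text-embedding flow: tokenize, then run the encoder
tokens = tokenizer(
    "a watercolor painting of a lighthouse",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
)
with torch.no_grad():
    # move the token ids to whatever device the encoder was loaded on
    embeddings = text_encoder(tokens.input_ids.to(text_encoder.device))[0]
print(embeddings.shape)  # e.g. torch.Size([1, 77, 768]) for an SD-1.x model
```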
This commit is contained in commit e5f8b22a43.
```diff
@@ -7,3 +7,4 @@ from .convert_ckpt_to_diffusers import (
 )
 from .model_manager import ModelManager
 
+
@@ -1,4 +1,4 @@
-"""
+"""enum
 Manage a cache of Stable Diffusion model files for fast switching.
 They are moved between GPU and CPU as necessary. If CPU memory falls
 below a preset minimum, the least recently used model will be
@@ -15,7 +15,7 @@ import sys
 import textwrap
 import time
 import warnings
-from enum import Enum
+from enum import Enum, auto
 from pathlib import Path
 from shutil import move, rmtree
 from typing import Any, Optional, Union, Callable
@@ -24,8 +24,12 @@ import safetensors
 import safetensors.torch
 import torch
 import transformers
-from diffusers import AutoencoderKL
-from diffusers import logging as dlogging
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    SchedulerMixin,
+    logging as dlogging,
+)
 from huggingface_hub import scan_cache_dir
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
@@ -33,31 +37,52 @@ from picklescan.scanner import scan_file_path
 
 from invokeai.backend.globals import Globals, global_cache_dir
 
-from ..stable_diffusion import StableDiffusionGeneratorPipeline
+from transformers import (
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPFeatureExtractor,
+)
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+    StableDiffusionSafetyChecker,
+)
+from ..stable_diffusion import (
+    StableDiffusionGeneratorPipeline,
+)
 from ..util import CUDA_DEVICE, ask_user, download_with_resume
 
 
 class SDLegacyType(Enum):
-    V1 = 1
-    V1_INPAINT = 2
-    V2 = 3
-    V2_e = 4
-    V2_v = 5
-    UNKNOWN = 99
+    V1 = auto()
+    V1_INPAINT = auto()
+    V2 = auto()
+    V2_e = auto()
+    V2_v = auto()
+    UNKNOWN = auto()
 
+
+class SDModelComponent(Enum):
+    vae="vae"
+    text_encoder="text_encoder"
+    tokenizer="tokenizer"
+    unet="unet"
+    scheduler="scheduler"
+    safety_checker="safety_checker"
+    feature_extractor="feature_extractor"
+
 
 DEFAULT_MAX_MODELS = 2
 
 
 class ModelManager(object):
-    '''
+    """
     Model manager handles loading, caching, importing, deleting, converting, and editing models.
-    '''
+    """
 
     def __init__(
         self,
-        config: OmegaConf|Path,
+        config: OmegaConf | Path,
         device_type: torch.device = CUDA_DEVICE,
         precision: str = "float16",
         max_loaded_models=DEFAULT_MAX_MODELS,
         sequential_offload=False,
-        embedding_path: Path=None,
+        embedding_path: Path = None,
     ):
         """
         Initialize with the path to the models.yaml config file or
```
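Each `SDModelComponent` value is the literal attribute name of the corresponding sub-module on the diffusers pipeline object, which is what lets the component lookup later in this diff reduce to a single `getattr` call. A small illustration of that correspondence (the checkpoint id is only an example, not taken from this PR):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
for name in ("vae", "text_encoder", "tokenizer", "unet", "scheduler"):
    # prints e.g. "vae AutoencoderKL", "unet UNet2DConditionModel", ...
    print(name, type(getattr(pipe, name)).__name__)
```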
```diff
@@ -87,14 +112,24 @@ class ModelManager(object):
         """
         return model_name in self.config
 
-    def get_model(self, model_name: str=None)->dict:
-        """
-        Given a model named identified in models.yaml, return
-        the model object. If in RAM will load into GPU VRAM.
-        If on disk, will load from there.
+    def get_model(self, model_name: str = None) -> dict:
+        """Given a model name identified in models.yaml, return a dict
+        containing the model object and some of its key features. If
+        in RAM will load into GPU VRAM. If on disk, will load from
+        there.
+        The dict has the following keys:
+        'model': The StableDiffusionGeneratorPipeline object
+        'model_name': The name of the model in models.yaml
+        'width': The width of images trained by this model
+        'height': The height of images trained by this model
+        'hash': A unique hash of this model's files on disk.
         """
         if not model_name:
-            return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model())
+            return (
+                self.get_model(self.current_model)
+                if self.current_model
+                else self.get_model(self.default_model())
+            )
 
         if not self.valid_model(model_name):
             print(
```
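A short sketch of consuming the dict documented above (the `manager` object and model name are carried over from the PR description):

```python
model_info = manager.get_model("stable-diffusion-1.5")
pipeline = model_info["model"]  # the StableDiffusionGeneratorPipeline
print(model_info["model_name"])                   # name as given in models.yaml
print(model_info["width"], model_info["height"])  # native training resolution
print(model_info["hash"])                         # hash of the model's files on disk
```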
```diff
@@ -135,6 +170,81 @@ class ModelManager(object):
             "hash": hash,
         }
 
+    def get_model_vae(self, model_name: str=None)->AutoencoderKL:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned VAE as an
+        AutoencoderKL object. If no model name is provided, return the
+        vae from the model currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.vae)
+
+    def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned CLIPTokenizer. If no
+        model name is provided, return the tokenizer from the model
+        currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.tokenizer)
+
+    def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned UNet2DConditionModel. If no model
+        name is provided, return the UNet from the model
+        currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.unet)
+
+    def get_model_text_encoder(self, model_name: str=None)->CLIPTextModel:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned CLIPTextModel. If no
+        model name is provided, return the text encoder from the model
+        currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.text_encoder)
+
+    def get_model_feature_extractor(self, model_name: str=None)->CLIPFeatureExtractor:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned CLIPFeatureExtractor. If no
+        model name is provided, return the feature extractor from the model
+        currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.feature_extractor)
+
+    def get_model_scheduler(self, model_name: str=None)->SchedulerMixin:
+        """Given a model name identified in models.yaml, load the model into
+        GPU if necessary and return its assigned scheduler. If no
+        model name is provided, return the scheduler from the model
+        currently in the GPU.
+        """
+        return self._get_sub_model(model_name, SDModelComponent.scheduler)
+
+    def _get_sub_model(
+        self,
+        model_name: str=None,
+        model_part: SDModelComponent=SDModelComponent.vae,
+    ) -> Union[
+        AutoencoderKL,
+        CLIPTokenizer,
+        CLIPFeatureExtractor,
+        UNet2DConditionModel,
+        CLIPTextModel,
+        StableDiffusionSafetyChecker,
+    ]:
+        """Given a model name identified in models.yaml, and the part of the
+        model you wish to retrieve, return that part. Parts are in an Enum
+        class named SDModelComponent, and consist of:
+        SDModelComponent.vae
+        SDModelComponent.text_encoder
+        SDModelComponent.tokenizer
+        SDModelComponent.unet
+        SDModelComponent.scheduler
+        SDModelComponent.safety_checker
+        SDModelComponent.feature_extractor
+        """
+        model_dict = self.get_model(model_name)
+        model = model_dict["model"]
+        return getattr(model, model_part.value)
+
     def default_model(self) -> str | None:
         """
         Returns the name of the default model, or None
```
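All of the public getters above are thin wrappers around `_get_sub_model`, so the two calls below return the same object (the private call is shown commented out, since this diff does not show a public import path for `SDModelComponent`). Calling a getter with no name falls back to the model currently in the GPU:

```python
scheduler = manager.get_model_scheduler("stable-diffusion-1.5")
# equivalent, going through the private helper directly:
# scheduler = manager._get_sub_model("stable-diffusion-1.5", SDModelComponent.scheduler)

# no model name: use whichever model is currently loaded on the GPU
vae = manager.get_model_vae()
```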
```diff
@@ -454,14 +564,18 @@ class ModelManager(object):
         from . import load_pipeline_from_original_stable_diffusion_ckpt
 
         try:
-            if self.list_models()[self.current_model]['status'] == 'active':
+            if self.list_models()[self.current_model]["status"] == "active":
                 self.offload_model(self.current_model)
         except Exception as e:
             pass
 
         vae_path = None
         if vae:
-            vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
+            vae_path = (
+                vae
+                if os.path.isabs(vae)
+                else os.path.normpath(os.path.join(Globals.root, vae))
+            )
         if self._has_cuda():
             torch.cuda.empty_cache()
         pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
@@ -571,9 +685,7 @@ class ModelManager(object):
         models.yaml file.
         """
         model_name = model_name or Path(repo_or_path).stem
-        model_description = (
-            description or f"Imported diffusers model {model_name}"
-        )
+        model_description = description or f"Imported diffusers model {model_name}"
         new_config = dict(
             description=model_description,
             vae=vae,
```
```diff
@@ -602,7 +714,7 @@ class ModelManager(object):
         SDLegacyType.V2_v (V2 using 'v_prediction' prediction type)
         SDLegacyType.UNKNOWN
         """
-        global_step = checkpoint.get('global_step')
+        global_step = checkpoint.get("global_step")
         state_dict = checkpoint.get("state_dict") or checkpoint
 
         try:
```
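The `global_step` value read here feeds the V2 heuristics listed in the docstring. As a hedged sketch of the idea only, the two SD-2 flavors can be told apart by the training step recorded in the checkpoint; the exact constants below are assumptions for illustration, not taken from this diff:

```python
def guess_v2_variant(global_step: int) -> SDLegacyType:
    # assumed step counts for the released SD-2 checkpoints; illustrative only
    if global_step == 110000:
        return SDLegacyType.V2_v  # 768-px, 'v_prediction' checkpoints
    if global_step == 220000:
        return SDLegacyType.V2_e  # 'epsilon'-prediction base checkpoints
    return SDLegacyType.UNKNOWN
```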
```diff
@@ -738,8 +850,8 @@ class ModelManager(object):
 
         # another round of heuristics to guess the correct config file.
         checkpoint = None
-        if model_path.suffix in [".ckpt",".pt"]:
-            self.scan_model(model_path,model_path)
+        if model_path.suffix in [".ckpt", ".pt"]:
+            self.scan_model(model_path, model_path)
             checkpoint = torch.load(model_path)
         else:
             checkpoint = safetensors.torch.load_file(model_path)
@@ -761,19 +873,16 @@ class ModelManager(object):
         elif model_type == SDLegacyType.V1_INPAINT:
             print(" | SD-v1 inpainting model detected")
             model_config_file = Path(
-                Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
+                Globals.root,
+                "configs/stable-diffusion/v1-inpainting-inference.yaml",
             )
         elif model_type == SDLegacyType.V2_v:
-            print(
-                " | SD-v2-v model detected"
-            )
+            print(" | SD-v2-v model detected")
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
             )
         elif model_type == SDLegacyType.V2_e:
-            print(
-                " | SD-v2-e model detected"
-            )
+            print(" | SD-v2-e model detected")
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v2-inference.yaml"
             )
```
```diff
@@ -825,11 +934,11 @@ class ModelManager(object):
         diffusers_path: Path,
         model_name=None,
         model_description=None,
-        vae:dict=None,
-        vae_path:Path=None,
+        vae: dict = None,
+        vae_path: Path = None,
         original_config_file: Path = None,
         commit_to_conf: Path = None,
-        scan_needed: bool=True,
+        scan_needed: bool = True,
     ) -> str:
         """
         Convert a legacy ckpt weights file to diffuser model and import
```
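The parameters above belong to the ckpt-to-diffusers conversion entry point. A hedged usage sketch follows; the method name `convert_and_import` and all paths are assumptions based on the hunk context, not lines shown in this diff:

```python
from pathlib import Path

new_model_name = manager.convert_and_import(  # assumed name of the enclosing method
    ckpt_path=Path("/path/to/v1-5-pruned.ckpt"),
    diffusers_path=Path("/path/to/models/converted/stable-diffusion-1.5"),
    model_name="stable-diffusion-1.5",
    scan_needed=True,
)
```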
```diff
@@ -857,10 +966,10 @@ class ModelManager(object):
         try:
             # By passing the specified VAE to the conversion function, the autoencoder
             # will be built into the model rather than tacked on afterward via the config file
-            vae_model=None
+            vae_model = None
             if vae:
-                vae_model=self._load_vae(vae)
-                vae_path=None
+                vae_model = self._load_vae(vae)
+                vae_path = None
             convert_ckpt_to_diffusers(
                 ckpt_path,
                 diffusers_path,
@@ -976,15 +1085,15 @@ class ModelManager(object):
         legacy_locations = [
             Path(
                 models_dir,
-                "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker"
+                "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker",
             ),
             Path(models_dir, "bert-base-uncased/models--bert-base-uncased"),
             Path(
                 models_dir,
-                "openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14"
+                "openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
             ),
         ]
-        legacy_locations.extend(list(global_cache_dir("diffusers").glob('*')))
+        legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
 
         legacy_layout = False
         for model in legacy_locations:
@@ -1003,7 +1112,7 @@ class ModelManager(object):
         >> make adjustments, please press ctrl-C now to abort and relaunch InvokeAI when you are ready.
         >> Otherwise press <enter> to continue."""
             )
-            input('continue> ')
+            input("continue> ")
 
         # transformer files get moved into the hub directory
         if cls._is_huggingface_hub_directory_present():
```