add a new method to model_manager that retrieves individual pipeline parts
- New method is `ModelManager.get_sub_model(model_name: str, model_part: SDModelComponent)`. To use:

  ```python
  from invokeai.backend import ModelManager, SDModelComponent as sdmc

  manager = ModelManager('/path/to/models.yaml')
  vae = manager.get_sub_model('stable-diffusion-1.5', sdmc.vae)
  ```
Parent: c4e6511a59
Commit: d44151d6ff
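By way of illustration only (not part of the commit), the same call pattern retrieves any of the other components; the model name and models.yaml path are the placeholders used in the commit message:

```python
from invokeai.backend import ModelManager, SDModelComponent as sdmc

# Illustrative path and model name; use your own models.yaml entry.
manager = ModelManager('/path/to/models.yaml')

# Each call returns the corresponding attribute of the underlying
# StableDiffusionGeneratorPipeline, selected by the enum's string value.
tokenizer = manager.get_sub_model('stable-diffusion-1.5', sdmc.tokenizer)
text_encoder = manager.get_sub_model('stable-diffusion-1.5', sdmc.text_encoder)
unet = manager.get_sub_model('stable-diffusion-1.5', sdmc.unet)
```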
```diff
@@ -10,7 +10,7 @@ from .generator import (
     Img2Img,
     Inpaint
 )
-from .model_management import ModelManager
+from .model_management import ModelManager, SDModelComponent
 from .safety_checker import SafetyChecker
 from .args import Args
 from .globals import Globals
```
```diff
@@ -5,5 +5,6 @@ from .convert_ckpt_to_diffusers import (
     convert_ckpt_to_diffusers,
     load_pipeline_from_original_stable_diffusion_ckpt,
 )
-from .model_manager import ModelManager
+from .model_manager import ModelManager,SDModelComponent
+
```
```diff
@@ -1,4 +1,4 @@
-"""
+"""enum
 Manage a cache of Stable Diffusion model files for fast switching.
 They are moved between GPU and CPU as necessary. If CPU memory falls
 below a preset minimum, the least recently used model will be
```
```diff
@@ -15,7 +15,7 @@ import sys
 import textwrap
 import time
 import warnings
-from enum import Enum
+from enum import Enum, auto
 from pathlib import Path
 from shutil import move, rmtree
 from typing import Any, Optional, Union, Callable
```
```diff
@@ -24,8 +24,11 @@ import safetensors
 import safetensors.torch
 import torch
 import transformers
-from diffusers import AutoencoderKL
-from diffusers import logging as dlogging
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    logging as dlogging,
+)
 from huggingface_hub import scan_cache_dir
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
```
```diff
@@ -33,37 +36,58 @@ from picklescan.scanner import scan_file_path

 from invokeai.backend.globals import Globals, global_cache_dir

-from ..stable_diffusion import StableDiffusionGeneratorPipeline
+from transformers import (
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPFeatureExtractor,
+)
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+    StableDiffusionSafetyChecker,
+)
+from ..stable_diffusion import (
+    StableDiffusionGeneratorPipeline,
+)
 from ..util import CUDA_DEVICE, ask_user, download_with_resume

-class SDLegacyType(Enum):
-    V1 = 1
-    V1_INPAINT = 2
-    V2 = 3
-    V2_e = 4
-    V2_v = 5
-    UNKNOWN = 99
+class SDLegacyType(Enum):
+    V1 = auto()
+    V1_INPAINT = auto()
+    V2 = auto()
+    V2_e = auto()
+    V2_v = auto()
+    UNKNOWN = auto()
+
+class SDModelComponent(Enum):
+    vae="vae"
+    text_encoder="text_encoder"
+    tokenizer="tokenizer"
+    unet="unet"
+    scheduler="scheduler"
+    safety_checker="safety_checker"
+    feature_extractor="feature_extractor"

 DEFAULT_MAX_MODELS = 2

 class ModelManager(object):
-    '''
+    """
     Model manager handles loading, caching, importing, deleting, converting, and editing models.
-    '''
+    """

     def __init__(
-        self,
-        config: OmegaConf|Path,
-        device_type: torch.device = CUDA_DEVICE,
-        precision: str = "float16",
-        max_loaded_models=DEFAULT_MAX_MODELS,
-        sequential_offload=False,
-        embedding_path: Path=None,
+        self,
+        config: OmegaConf | Path,
+        device_type: torch.device = CUDA_DEVICE,
+        precision: str = "float16",
+        max_loaded_models=DEFAULT_MAX_MODELS,
+        sequential_offload=False,
+        embedding_path: Path = None,
     ):
         """
         Initialize with the path to the models.yaml config file or
         an initialized OmegaConf dictionary. Optional parameters
         are the torch device type, precision, max_loaded_models,
-        and sequential_offload boolean. Note that the default device
+        and sequential_offload boolean. Note that the default device
         type and precision are set up for a CUDA system running at half precision.
         """
         # prevent nasty-looking CLIP log message
```
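A note on the two enum changes above. `SDModelComponent`'s values are the literal attribute names on the underlying diffusers pipeline, which is what lets `get_sub_model()` (added below) resolve a part with a single `getattr`. Also worth flagging: switching `SDLegacyType` to `auto()` renumbers `UNKNOWN` from 99 to 6, which is safe only because the code compares members by identity rather than persisting the numbers. A quick self-contained sketch of both behaviors:

```python
from enum import Enum, auto

class SDModelComponent(Enum):
    vae = "vae"
    unet = "unet"

class SDLegacyType(Enum):
    V1 = auto()
    UNKNOWN = auto()

# The enum value doubles as a pipeline attribute name:
assert SDModelComponent.vae.value == "vae"

# auto() assigns 1, 2, ... in declaration order, so UNKNOWN is no
# longer a sentinel like 99 -- fine for identity checks, not for
# anything that stores the numeric value.
assert SDLegacyType.UNKNOWN.value == 2
```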
```diff
@@ -87,15 +111,25 @@ class ModelManager(object):
         """
         return model_name in self.config

-    def get_model(self, model_name: str=None)->dict:
-        """
-        Given a model named identified in models.yaml, return
-        the model object. If in RAM will load into GPU VRAM.
-        If on disk, will load from there.
+    def get_model(self, model_name: str = None) -> dict:
+        """Given the name of a model identified in models.yaml, return a dict
+        containing the model object and some of its key features. If
+        in RAM will load into GPU VRAM. If on disk, will load from
+        there.
+
+        The dict has the following keys:
+        'model': The StableDiffusionGeneratorPipeline object
+        'model_name': The name of the model in models.yaml
+        'width': The width of images trained by this model
+        'height': The height of images trained by this model
+        'hash': A unique hash of this model's files on disk.
         """
         if not model_name:
-            return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model())
+            return (
+                self.get_model(self.current_model)
+                if self.current_model
+                else self.get_model(self.default_model())
+            )

         if not self.valid_model(model_name):
             print(
                 f'** "{model_name}" is not a known model name. Please check your models.yaml file'
```
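The expanded docstring makes the return contract explicit. A minimal consumption sketch, continuing the illustrative `manager` from the commit message (the model name is a placeholder):

```python
# Hypothetical use of the documented return dict.
model_dict = manager.get_model("stable-diffusion-1.5")

pipeline = model_dict["model"]   # StableDiffusionGeneratorPipeline
width = model_dict["width"]      # native training width in pixels
height = model_dict["height"]    # native training height in pixels
print(f'{model_dict["model_name"]}: {width}x{height}, hash {model_dict["hash"]}')
```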
```diff
@@ -135,6 +169,33 @@ class ModelManager(object):
             "hash": hash,
         }

+    def get_sub_model(
+        self,
+        model_name: str=None,
+        model_part: SDModelComponent=SDModelComponent.vae,
+    ) -> Union[
+        AutoencoderKL,
+        CLIPTokenizer,
+        CLIPFeatureExtractor,
+        UNet2DConditionModel,
+        CLIPTextModel,
+        StableDiffusionSafetyChecker,
+    ]:
+        """Given the name of a model identified in models.yaml, and the part of
+        the model you wish to retrieve, return that part. Parts are in an Enum
+        class named SDModelComponent, and consist of:
+        SDModelComponent.vae
+        SDModelComponent.text_encoder
+        SDModelComponent.tokenizer
+        SDModelComponent.unet
+        SDModelComponent.scheduler
+        SDModelComponent.safety_checker
+        SDModelComponent.feature_extractor
+        """
+        model_dict = self.get_model(model_name)
+        model = model_dict["model"]
+        return getattr(model, model_part.value)
+
     def default_model(self) -> str | None:
         """
         Returns the name of the default model, or None
```
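As the body shows, `get_sub_model()` is a thin `getattr` wrapper over `get_model()`; note that `SDModelComponent.scheduler` is accepted even though no scheduler type appears in the `Union` return annotation. The two expressions below fetch the same object, assuming the manager returns the cached pipeline on repeated calls:

```python
# Equivalent ways to reach the VAE through the new API and by hand.
vae_a = manager.get_sub_model("stable-diffusion-1.5", SDModelComponent.vae)
vae_b = getattr(manager.get_model("stable-diffusion-1.5")["model"], "vae")
assert vae_a is vae_b
```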
```diff
@@ -360,7 +421,7 @@ class ModelManager(object):
                 f"Unknown model format {model_name}: {model_format}"
             )
         self._add_embeddings_to_model(model)
-
+
         # usage statistics
         toc = time.time()
         print(">> Model loaded in", "%4.2fs" % (toc - tic))
```
```diff
@@ -433,7 +494,7 @@ class ModelManager(object):
         width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
         height = width
         print(f" | Default image dimensions = {width} x {height}")
-
+
         return pipeline, width, height, model_hash

     def _load_ckpt_model(self, model_name, mconfig):
```
```diff
@@ -454,14 +515,18 @@ class ModelManager(object):
         from . import load_pipeline_from_original_stable_diffusion_ckpt

         try:
-            if self.list_models()[self.current_model]['status'] == 'active':
+            if self.list_models()[self.current_model]["status"] == "active":
                 self.offload_model(self.current_model)
         except Exception as e:
             pass
-
+
         vae_path = None
         if vae:
-            vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
+            vae_path = (
+                vae
+                if os.path.isabs(vae)
+                else os.path.normpath(os.path.join(Globals.root, vae))
+            )
         if self._has_cuda():
             torch.cuda.empty_cache()
         pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
```
```diff
@@ -571,9 +636,7 @@ class ModelManager(object):
         models.yaml file.
         """
         model_name = model_name or Path(repo_or_path).stem
-        model_description = (
-            description or f"Imported diffusers model {model_name}"
-        )
+        model_description = description or f"Imported diffusers model {model_name}"
         new_config = dict(
             description=model_description,
             vae=vae,
```
```diff
@@ -602,7 +665,7 @@ class ModelManager(object):
         SDLegacyType.V2_v (V2 using 'v_prediction' prediction type)
         SDLegacyType.UNKNOWN
         """
-        global_step = checkpoint.get('global_step')
+        global_step = checkpoint.get("global_step")
         state_dict = checkpoint.get("state_dict") or checkpoint

         try:
```
```diff
@@ -628,13 +691,13 @@ class ModelManager(object):
         return SDLegacyType.UNKNOWN

     def heuristic_import(
-            self,
-            path_url_or_repo: str,
-            model_name: str = None,
-            description: str = None,
-            model_config_file: Path = None,
-            commit_to_conf: Path = None,
-            config_file_callback: Callable[[Path], Path] = None,
+        self,
+        path_url_or_repo: str,
+        model_name: str = None,
+        description: str = None,
+        model_config_file: Path = None,
+        commit_to_conf: Path = None,
+        config_file_callback: Callable[[Path], Path] = None,
     ) -> str:
         """Accept a string which could be:
         - a HF diffusers repo_id
```
```diff
@@ -738,8 +801,8 @@ class ModelManager(object):

         # another round of heuristics to guess the correct config file.
         checkpoint = None
-        if model_path.suffix in [".ckpt",".pt"]:
-            self.scan_model(model_path,model_path)
+        if model_path.suffix in [".ckpt", ".pt"]:
+            self.scan_model(model_path, model_path)
             checkpoint = torch.load(model_path)
         else:
             checkpoint = safetensors.torch.load_file(model_path)
```
```diff
@@ -761,19 +824,16 @@ class ModelManager(object):
         elif model_type == SDLegacyType.V1_INPAINT:
             print(" | SD-v1 inpainting model detected")
             model_config_file = Path(
-                Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
+                Globals.root,
+                "configs/stable-diffusion/v1-inpainting-inference.yaml",
             )
         elif model_type == SDLegacyType.V2_v:
-            print(
-                " | SD-v2-v model detected"
-            )
+            print(" | SD-v2-v model detected")
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
             )
         elif model_type == SDLegacyType.V2_e:
-            print(
-                " | SD-v2-e model detected"
-            )
+            print(" | SD-v2-e model detected")
             model_config_file = Path(
                 Globals.root, "configs/stable-diffusion/v2-inference.yaml"
             )
```
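For readers following the heuristics: the branches above reduce to a legacy-type to config-file mapping, summarized here as a sketch (paths as they appear in the diff; the dict name is illustrative):

```python
# Summary of the branch logic above; each path is resolved
# relative to Globals.root, InvokeAI's runtime root directory.
LEGACY_CONFIGS = {
    "V1_INPAINT": "configs/stable-diffusion/v1-inpainting-inference.yaml",
    "V2_v": "configs/stable-diffusion/v2-inference-v.yaml",
    "V2_e": "configs/stable-diffusion/v2-inference.yaml",
}
```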
```diff
@@ -820,16 +880,16 @@ class ModelManager(object):
         return model_name

     def convert_and_import(
-            self,
-            ckpt_path: Path,
-            diffusers_path: Path,
-            model_name=None,
-            model_description=None,
-            vae:dict=None,
-            vae_path:Path=None,
-            original_config_file: Path = None,
-            commit_to_conf: Path = None,
-            scan_needed: bool=True,
+        self,
+        ckpt_path: Path,
+        diffusers_path: Path,
+        model_name=None,
+        model_description=None,
+        vae: dict = None,
+        vae_path: Path = None,
+        original_config_file: Path = None,
+        commit_to_conf: Path = None,
+        scan_needed: bool = True,
     ) -> str:
         """
         Convert a legacy ckpt weights file to diffuser model and import
```
```diff
@@ -857,10 +917,10 @@ class ModelManager(object):
         try:
             # By passing the specified VAE to the conversion function, the autoencoder
             # will be built into the model rather than tacked on afterward via the config file
-            vae_model=None
+            vae_model = None
             if vae:
-                vae_model=self._load_vae(vae)
-                vae_path=None
+                vae_model = self._load_vae(vae)
+                vae_path = None
             convert_ckpt_to_diffusers(
                 ckpt_path,
                 diffusers_path,
```
```diff
@@ -976,16 +1036,16 @@ class ModelManager(object):
         legacy_locations = [
             Path(
                 models_dir,
-                "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker"
+                "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker",
             ),
             Path(models_dir, "bert-base-uncased/models--bert-base-uncased"),
             Path(
                 models_dir,
-                "openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14"
+                "openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
             ),
         ]
-        legacy_locations.extend(list(global_cache_dir("diffusers").glob('*')))
-
+        legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
+
         legacy_layout = False
         for model in legacy_locations:
             legacy_layout = legacy_layout or model.exists()
```
```diff
@@ -1003,7 +1063,7 @@ class ModelManager(object):
        >> make adjustments, please press ctrl-C now to abort and relaunch InvokeAI when you are ready.
        >> Otherwise press <enter> to continue."""
         )
-        input('continue> ')
+        input("continue> ")

         # transformer files get moved into the hub directory
         if cls._is_huggingface_hub_directory_present():
```
```diff
@@ -1090,7 +1150,7 @@ class ModelManager(object):
             print(
                 f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
             )
-
+
     def _has_cuda(self) -> bool:
         return self.device.type == "cuda"
```