move all installation code out of model_manager

Lincoln Stein 2023-06-11 12:51:50 -04:00
parent 74b43c9bdf
commit 000626ab2e
8 changed files with 663 additions and 152 deletions


@ -9,5 +9,5 @@ from .generator import (
    Img2Img,
    Inpaint
)
-from .model_management import ModelManager, ModelCache, SDModelType, SDModelInfo
+from .model_management import ModelManager, ModelCache, ModelType, ModelInfo
from .safety_checker import SafetyChecker


@ -1,5 +1,6 @@
"""
Initialization file for invokeai.backend.model_management
"""
-from .model_manager import ModelManager, SDModelInfo
-from .model_cache import ModelCache, SDModelType
+from .model_manager import ModelManager, ModelInfo
+from .model_cache import ModelCache
+from .models import ModelType


@ -29,11 +29,8 @@ import torch
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
import invokeai.backend.util.logging as logger
-from .model_manager import SDModelInfo, ModelType, SubModelType, ModelBase
+from .models import ModelType, SubModelType, ModelBase

# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
@ -50,6 +47,10 @@ class ModelCache(object):
        "Forward declaration"
        pass

+class SDModelInfo(object):
+    """Forward declaration"""
+    pass

class _CacheRecord:
    size: int
    model: Any


@ -0,0 +1,201 @@
"""
Routines for downloading and installing models.
"""
import json
import safetensors
import safetensors.torch
import torch
import traceback
from dataclasses import dataclass
from diffusers import ModelMixin
from enum import Enum
from typing import Callable
from pathlib import Path
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, VariantType
class CheckpointProbe(object):
    PROBES = dict()  # see below for redefinition

    def __init__(self,
                 checkpoint_path: Path,
                 checkpoint: dict = None,
                 helper: Callable[[Path], BaseModelType] = None
                 ):
        self.checkpoint_path = checkpoint_path
        self.checkpoint = checkpoint or self._scan_and_load_checkpoint(checkpoint_path)
        self.helper = helper
    def probe(self) -> ModelVariantInfo:
        '''
        Probe the checkpoint at `checkpoint_path` and return a
        ModelVariantInfo object indicating the model base, model type
        and model variant for the checkpoint.
        '''
        checkpoint = self.checkpoint
        state_dict = checkpoint.get("state_dict") or checkpoint
        model_info = None
        try:
            model_type = self.get_checkpoint_type(state_dict)
            if not model_type:
                if self.checkpoint_path.name == "learned_embeds.bin":
                    model_type = ModelType.TextualInversion
                else:
                    return None  # we give up
            probe = self.PROBES[model_type]()
            base_type = probe.get_base_type(checkpoint, self.checkpoint_path, self.helper)
            variant_type = probe.get_variant_type(model_type, checkpoint)
            model_info = ModelVariantInfo(
                model_type = model_type,
                base_type = base_type,
                variant_type = variant_type,
            )
        except (KeyError, ValueError) as e:
            logger.error(f'An error occurred while probing {self.checkpoint_path}: {str(e)}')
            logger.error(traceback.format_exc())
        return model_info
class CheckpointProbeBase(object):
    def get_base_type(self,
                      checkpoint: dict,
                      checkpoint_path: Path = None,
                      helper: Callable[[Path], BaseModelType] = None
                      ) -> BaseModelType:
        pass

    def get_variant_type(self,
                         model_type: ModelType,
                         checkpoint: dict,
                         ) -> VariantType:
        if model_type != ModelType.Pipeline:
            return None
        state_dict = checkpoint.get('state_dict') or checkpoint
        in_channels = state_dict[
            "model.diffusion_model.input_blocks.0.0.weight"
        ].shape[1]
        if in_channels == 9:
            return VariantType.Inpaint
        elif in_channels == 5:
            return VariantType.Depth
        else:
            return None
class PipelineProbe(CheckpointProbeBase):
    def get_base_type(self,
                      checkpoint: dict,
                      checkpoint_path: Path = None,
                      helper: Callable[[Path], BaseModelType] = None
                      ) -> BaseModelType:
        state_dict = checkpoint.get('state_dict') or checkpoint
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
            return BaseModelType.StableDiffusion1_5
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            if 'global_step' in checkpoint:
                if checkpoint['global_step'] == 220000:
                    return BaseModelType.StableDiffusion2Base
                elif checkpoint["global_step"] == 110000:
                    return BaseModelType.StableDiffusion2
        if checkpoint_path and helper:
            return helper(checkpoint_path)
        else:
            return None
class VaeProbe(CheckpointProbeBase):
    def get_base_type(self,
                      checkpoint: dict,
                      checkpoint_path: Path = None,
                      helper: Callable[[Path], BaseModelType] = None
                      ) -> BaseModelType:
        # I can't find any standalone 2.X VAEs to test with!
        return BaseModelType.StableDiffusion1_5

class LoRAProbe(CheckpointProbeBase):
    def get_base_type(self,
                      checkpoint: dict,
                      checkpoint_path: Path = None,
                      helper: Callable[[Path], BaseModelType] = None
                      ) -> BaseModelType:
        key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
        lora_token_vector_length = (
            checkpoint[key1].shape[1]
            if key1 in checkpoint
            else checkpoint[key2].shape[0]
            if key2 in checkpoint
            else 768
        )
        if lora_token_vector_length == 768:
            return BaseModelType.StableDiffusion1_5
        elif lora_token_vector_length == 1024:
            return BaseModelType.StableDiffusion2
        else:
            return None
class TextualInversionProbe(CheckpointProbeBase):
    def get_base_type(self,
                      checkpoint: dict,
                      checkpoint_path: Path = None,
                      helper: Callable[[Path], BaseModelType] = None
                      ) -> BaseModelType:
        if 'string_to_token' in checkpoint:
            token_dim = list(checkpoint['string_to_param'].values())[0].shape[-1]
        elif 'emb_params' in checkpoint:
            token_dim = checkpoint['emb_params'].shape[-1]
        else:
            token_dim = list(checkpoint.values())[0].shape[0]
        if token_dim == 768:
            return BaseModelType.StableDiffusion1_5
        elif token_dim == 1024:
            return BaseModelType.StableDiffusion2Base
        else:
            return None

class ControlNetProbe(CheckpointProbeBase):
    def get_base_type(self,
                      checkpoint: dict,
                      checkpoint_path: Path = None,
                      helper: Callable[[Path], BaseModelType] = None
                      ) -> BaseModelType:
        for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight',
                         'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight'
                         ):
            if key_name not in checkpoint:
                continue
            if checkpoint[key_name].shape[-1] == 768:
                return BaseModelType.StableDiffusion1_5
            elif checkpoint_path and helper:
                return helper(checkpoint_path)
CheckpointProbe.PROBES = {
    ModelType.Pipeline: PipelineProbe,
    ModelType.Vae: VaeProbe,
    ModelType.Lora: LoRAProbe,
    ModelType.TextualInversion: TextualInversionProbe,
    ModelType.ControlNet: ControlNetProbe,
}
@classmethod
def get_checkpoint_type(cls, state_dict: dict) -> ModelType:
if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]):
return ModelType.Pipeline
if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]):
return ModelType.Vae
if "string_to_token" in state_dict or "emb_params" in state_dict:
return ModelType.TextualInversion
if any([x.startswith("lora") for x in state_dict.keys()]):
return ModelType.Lora
if any([x.startswith("control_model") for x in state_dict.keys()]):
return ModelType.ControlNet
if any([x.startswith("input_blocks") for x in state_dict.keys()]):
return ModelType.ControlNet
return None # give up
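The get_checkpoint_type() heuristic above classifies a checkpoint purely by which key prefixes appear in its state dict. A minimal, self-contained sketch of the same idea follows; the classify_state_dict() helper and the sample key are illustrative, not part of this commit.

# Illustrative only -- mirrors the key-prefix checks in get_checkpoint_type().
def classify_state_dict(state_dict: dict) -> str:
    if any(k.startswith("model.diffusion_model") for k in state_dict):
        return "pipeline"
    if any(k.startswith("encoder.conv_in") for k in state_dict):
        return "vae"
    if "string_to_token" in state_dict or "emb_params" in state_dict:
        return "textual_inversion"
    if any(k.startswith("lora") for k in state_dict):
        return "lora"
    if any(k.startswith(("control_model", "input_blocks")) for k in state_dict):
        return "controlnet"
    return "unknown"

# A LoRA-style key is enough to reach the LoRA branch:
print(classify_state_dict({"lora_unet_down_blocks_0.lora_down.weight": None}))  # -> lora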


@ -3,13 +3,13 @@ symbolic diffusers model names to the paths and repo_ids used
by the underlying `from_pretrained()` call.
For fetching models, use manager.get_model('symbolic name'). This will
-return a SDModelInfo object that contains the following attributes:
+return a ModelInfo object that contains the following attributes:
   * context -- a context manager Generator that loads and locks the
                model into GPU VRAM and returns the model for use.
                See below for usage.
   * name -- symbolic name of the model
-  * type -- SDModelType of the model
+  * type -- SubModelType of the model
   * hash -- unique hash for the model
   * location -- path or repo_id of the model
   * revision -- revision of the model if coming from a repo id,
@ -25,7 +25,7 @@ Typical usage:
        max_cache_size=8
    ) # gigabytes

-model_info = manager.get_model('stable-diffusion-1.5', SDModelType.Diffusers)
+model_info = manager.get_model('stable-diffusion-1.5', SubModelType.Diffusers)
with model_info.context as my_model:
    my_model.latents_from_embeddings(...)
@ -43,7 +43,7 @@ parameter:
model_info = manager.get_model(
    'clip-tokenizer',
-   model_type=SDModelType.Tokenizer
+   model_type=SubModelType.Tokenizer
)

This will raise an InvalidModelError if the format defined in the
@ -63,7 +63,7 @@ The general format of a models.yaml section is:
The type of model is given in the stanza key, and is one of
{diffusers, ckpt, vae, text_encoder, tokenizer, unet, scheduler,
safety_checker, feature_extractor, lora, textual_inversion,
-controlnet}, and correspond to items in the SDModelType enum defined
+controlnet}, and correspond to items in the SubModelType enum defined
in model_cache.py

The format indicates whether the model is organized as a folder with
@ -96,7 +96,7 @@ SUBMODELS:
It is also possible to fetch an isolated submodel from a diffusers
model. Use the `submodel` parameter to select which part:

-vae = manager.get_model('stable-diffusion-1.5',submodel=SDModelType.Vae)
+vae = manager.get_model('stable-diffusion-1.5',submodel=SubModelType.Vae)
with vae.context as my_vae:
    print(type(my_vae))
    # "AutoencoderKL"
@ -128,8 +128,8 @@ separated by "/". Example:
You can now use the `model_type` argument to indicate which model you
want:

-tokenizer = mgr.get('clip-large',model_type=SDModelType.Tokenizer)
-encoder = mgr.get('clip-large',model_type=SDModelType.TextEncoder)
+tokenizer = mgr.get('clip-large',model_type=SubModelType.Tokenizer)
+encoder = mgr.get('clip-large',model_type=SubModelType.TextEncoder)

OTHER FUNCTIONS:
@ -164,15 +164,24 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
from .model_cache import ModelCache, ModelLocker
-from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES
+from .models import BaseModelType, SubModelType, MODEL_CLASSES

# We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help
# reduce confusion.
CONFIG_FILE_VERSION='3.0.0'

-# wanted to use pydantic here, but Generator objects not supported
+# temporary forward definitions to avoid circular import errors.
+class ModelLocker(object):
+    "Forward declaration"
+    pass
+
+class ModelCache(object):
+    "Forward declaration"
+    pass

@dataclass
-class SDModelInfo():
+class ModelInfo():
    context: ModelLocker
    name: str
    type: ModelType
@ -303,7 +312,7 @@ class ModelManager(object):
        submodel_type: Optional[SubModelType] = None
    ):
        """Given a model named identified in models.yaml, return
-       an SDModelInfo object describing it.
+       an ModelInfo object describing it.
        :param model_name: symbolic name of the model in models.yaml
        :param model_type: ModelType enum indicating the type of model to return
        :param submode_typel: an ModelType enum indicating the portion of
@ -389,7 +398,7 @@ class ModelManager(object):
        hash = "<NO_HASH>" # TODO:

-       return SDModelInfo(
+       return ModelInfo(
            context = model_context,
            name = model_name,
            base_model = base_model,
@ -746,62 +755,3 @@ class ModelManager(object):
        if self.config_path:
            self.commit()

-    def _delete_defunct_models(self):
-        '''
-        Remove models no longer on disk.
-        '''
-        config = self.config
-
-        to_delete = set()
-        for key in config:
-            if 'path' not in config[key]:
-                continue
-            path = self.globals.root_dir / config[key].path
-            if path.exists():
-                continue
-            to_delete.add(key)
-
-        for key in to_delete:
-            self.logger.warn(f'Removing model {key} from in-memory config because its path is no longer on disk')
-            config.pop(key)
-
-    def scan_models_directory(self, include_diffusers:bool=False):
-        '''
-        Scan the models directory for loras, textual_inversions and controlnets
-        and create appropriate entries in the in-memory omegaconf. Diffusers
-        will not be added unless include_diffusers is true.
-        '''
-        self._delete_defunct_models()
-
-        model_directory = self.globals.models_path
-        config = self.config
-
-        for root, dirs, files in os.walk(model_directory):
-            parents = root.split('/')
-            subpaths = parents[parents.index('models')+1:]
-            if len(subpaths) < 2:
-                continue
-            base, model_type, *_ = subpaths
-            if model_type == "diffusers" and not include_diffusers:
-                continue
-            for d in dirs:
-                config[f'{model_type}/{d}'] = dict(
-                    path = os.path.join(root,d),
-                    description = f'{model_type} model {d}',
-                    format = 'folder',
-                    base = base,
-                )
-            for f in files:
-                basename = Path(f).stem
-                format = Path(f).suffix[1:]
-                config[f'{model_type}/{basename}'] = dict(
-                    path = os.path.join(root,f),
-                    description = f'{model_type} model {basename}',
-                    format = format,
-                    base = base,
-                )


@ -0,0 +1,333 @@
import json
import traceback
import torch
import safetensors.torch
from dataclasses import dataclass
from diffusers import ModelMixin
from pathlib import Path
from typing import Callable, Literal, Union, Dict
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, VariantType
@dataclass
class ModelVariantInfo(object):
model_type: ModelType
base_type: BaseModelType
variant_type: VariantType
class ProbeBase(object):
'''forward declaration'''
pass
class ModelProbe(object):
PROBES = {
'folder': { },
'file': { },
}
CLASS2TYPE = {
"StableDiffusionPipeline" : ModelType.Pipeline,
"AutoencoderKL": ModelType.Vae,
"ControlNetModel" : ModelType.ControlNet,
}
@classmethod
def register_probe(cls,
format: Literal['folder','file'],
model_type: ModelType,
probe_class: ProbeBase):
cls.PROBES[format][model_type] = probe_class
@classmethod
def probe(cls,
model_path: Path,
model: Union[Dict, ModelMixin] = None,
base_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo:
'''
Probe the model at model_path and return sufficient information about it
to place it somewhere in the models directory hierarchy. If the model is
already loaded into memory, you may provide it as model in order to avoid
opening it a second time. The base_helper callable is a function that receives
the path to the model and returns the BaseModelType. It is called to distinguish
between V2-Base and V2-768 SD models.
'''
format = 'folder' if model_path.is_dir() else 'file'
model_info = None
try:
model_type = cls.get_model_type_from_folder(model_path, model) \
if format == 'folder' \
else cls.get_model_type_from_checkpoint(model_path, model)
probe_class = cls.PROBES[format].get(model_type)
if not probe_class:
return None
probe = probe_class(model_path, model, base_helper)
base_type = probe.get_base_type()
variant_type = probe.get_variant_type()
model_info = ModelVariantInfo(
model_type = model_type,
base_type = base_type,
variant_type = variant_type,
)
except (KeyError, ValueError) as e:
logger.error(f'An error occurred while probing {model_path}: {str(e)}')
logger.error(traceback.format_exc())
return model_info
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict)->ModelType:
checkpoint = checkpoint or cls._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]):
return ModelType.Pipeline
if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]):
return ModelType.Vae
if "string_to_token" in state_dict or "emb_params" in state_dict:
return ModelType.TextualInversion
if any([x.startswith("lora") for x in state_dict.keys()]):
return ModelType.Lora
if any([x.startswith("control_model") for x in state_dict.keys()]):
return ModelType.ControlNet
if any([x.startswith("input_blocks") for x in state_dict.keys()]):
return ModelType.ControlNet
return None # give up
@classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
'''
Get the model type of a hugging-face style folder.
'''
if (folder_path / 'learned_embeds.bin').exists():
return ModelType.TextualInversion
if (folder_path / 'pytorch_lora_weights.bin').exists():
return ModelType.Lora
        i = folder_path / 'model_index.json'
        c = folder_path / 'config.json'
        config_path = i if i.exists() else c if c.exists() else None

        if config_path:
            conf = json.loads(config_path.read_text())
            class_name = conf['_class_name']
            if type := cls.CLASS2TYPE.get(class_name):
                return type

        # give up
        raise ValueError(f"Unable to determine model type of {folder_path}")
@classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model_path, model_path)
return torch.load(model_path)
else:
return safetensors.torch.load_file(model_path)
    @classmethod
    def _scan_model(cls, model_name, checkpoint):
        """
        Apply picklescanner to the indicated checkpoint and issue a warning
        and option to exit if an infected file is identified.
        """
        # scan model
        scan_result = scan_file_path(checkpoint)
        if scan_result.infected_files != 0:
            raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
###################################################3
# Checkpoint probing
###################################################3
class ProbeBase(object):
def get_base_type(self)->BaseModelType:
pass
def get_variant_type(self)->VariantType:
pass
class CheckpointProbeBase(ProbeBase):
    def __init__(self,
                 checkpoint_path: Path,
                 checkpoint: dict,
                 helper: Callable[[Path], BaseModelType] = None
                 ):
        self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
        self.checkpoint_path = checkpoint_path
        self.helper = helper
def get_base_type(self)->BaseModelType:
pass
def get_variant_type(self)-> VariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
if model_type != ModelType.Pipeline:
return VariantType.Normal
state_dict = self.checkpoint.get('state_dict') or self.checkpoint
in_channels = state_dict[
"model.diffusion_model.input_blocks.0.0.weight"
].shape[1]
if in_channels == 9:
return VariantType.Inpaint
elif in_channels == 5:
return VariantType.Depth
else:
return None
class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
helper = self.helper
state_dict = self.checkpoint.get('state_dict') or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1_5
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if 'global_step' in checkpoint:
if checkpoint['global_step'] == 220000:
return BaseModelType.StableDiffusion2Base
elif checkpoint["global_step"] == 110000:
return BaseModelType.StableDiffusion2
if self.checkpoint_path and helper:
return helper(self.checkpoint_path)
else:
return None
class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1_5
class LoRACheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
lora_token_vector_length = (
checkpoint[key1].shape[1]
if key1 in checkpoint
else checkpoint[key2].shape[0]
if key2 in checkpoint
else 768
)
if lora_token_vector_length == 768:
return BaseModelType.StableDiffusion1_5
elif lora_token_vector_length == 1024:
return BaseModelType.StableDiffusion2
else:
return None
class TextualInversionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
if 'string_to_token' in checkpoint:
token_dim = list(checkpoint['string_to_param'].values())[0].shape[-1]
elif 'emb_params' in checkpoint:
token_dim = checkpoint['emb_params'].shape[-1]
else:
token_dim = list(checkpoint.values())[0].shape[0]
if token_dim == 768:
return BaseModelType.StableDiffusion1_5
elif token_dim == 1024:
return BaseModelType.StableDiffusion2Base
else:
return None
class ControlNetCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight',
'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight'
):
if key_name not in checkpoint:
continue
if checkpoint[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1_5
elif self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path)
########################################################
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
    def __init__(self,
                 folder_path: Path,
                 model: ModelMixin = None,
                 helper: Callable = None  # not used
                 ):
        self.model = model
        self.folder_path = folder_path

    def get_variant_type(self) -> VariantType:
        # only works for pipelines
        config_file = self.folder_path / 'unet' / 'config.json'
        if not config_file.exists():
            return VariantType.Normal
        conf = json.loads(config_file.read_text())
        channels = conf['in_channels']
        if channels == 9:
            return VariantType.Inpaint
        elif channels == 5:
            return VariantType.Depth
        elif channels == 4:
            return VariantType.Normal
        else:
            return VariantType.Normal
class PipelineFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        config_file = self.folder_path / 'scheduler' / 'scheduler_config.json'
        if not config_file.exists():
            return None
        conf = json.loads(config_file.read_text())
        if conf['prediction_type'] == "v_prediction":
            return BaseModelType.StableDiffusion2
        elif conf['prediction_type'] == 'epsilon':
            return BaseModelType.StableDiffusion2Base
        else:
            return BaseModelType.StableDiffusion2
class VaeFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
return BaseModelType.StableDiffusion1_5
class TextualInversionFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        path = self.folder_path / 'learned_embeds.bin'
        if not path.exists():
            return None
        checkpoint = ModelProbe._scan_and_load_checkpoint(path)
        return TextualInversionCheckpointProbe(path, checkpoint).get_base_type()
class ControlNetFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        config_file = self.folder_path / 'config.json'
        if not config_file.exists():
            return None
        config = json.loads(config_file.read_text())
        # no obvious way to distinguish between sd2-base and sd2-768
        return BaseModelType.StableDiffusion1_5 \
            if config['cross_attention_dim'] == 768 \
            else BaseModelType.StableDiffusion2
class LoRAFolderProbe(FolderProbeBase):
# I've never seen one of these in the wild, so this is a noop
pass
############## register probe classes ######
ModelProbe.register_probe('folder', ModelType.Pipeline, PipelineFolderProbe)
ModelProbe.register_probe('folder', ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe('folder', ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe('folder', ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe('folder', ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe('file', ModelType.Pipeline, PipelineCheckpointProbe)
ModelProbe.register_probe('file', ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe('file', ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe('file', ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe('file', ModelType.ControlNet, ControlNetCheckpointProbe)
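With the probe classes registered, a caller hands ModelProbe.probe() a path and, optionally, an already-loaded model plus a helper for disambiguating SD-2 checkpoints. A hedged usage sketch, assuming it runs where ModelProbe and BaseModelType from the module above are importable; the checkpoint path and the helper are placeholders, not part of this commit:

# Illustrative usage of the probe API above; the path and base_helper are hypothetical.
from pathlib import Path

def assume_sd2_base(model_path: Path):
    # Placeholder: a real helper would ask the user (or consult other metadata)
    # to tell SD-2 base checkpoints apart from SD-2 768 checkpoints.
    return BaseModelType.StableDiffusion2Base

info = ModelProbe.probe(Path('/path/to/model.safetensors'), base_helper=assume_sd2_base)
if info is not None:
    print(info.model_type, info.base_type, info.variant_type)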


@ -1,8 +1,14 @@
import sys
+from dataclasses import dataclass
from enum import Enum

import torch
import safetensors.torch
+from diffusers import ConfigMixin
from diffusers.utils import is_safetensors_available
+from omegaconf import DictConfig
+from pathlib import Path
+from pydantic import BaseModel, Field, root_validator
+from typing import Union, List, Type, Optional

class BaseModelType(str, Enum):
    # TODO: maybe then add sample size(512/768)?
@ -26,42 +32,13 @@ class SubModelType:
    Vae = "vae"
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"
+    FeatureExtractor = "feature_extractor"
    #MoVQ = "movq"

-MODEL_CLASSES = {
-    BaseModel.StableDiffusion1_5: {
-        ModelType.Pipeline: StableDiffusionModel,
-        ModelType.Classifier: ClassifierModel,
-        ModelType.Vae: VaeModel,
-        ModelType.Lora: LoraModel,
-        ModelType.ControlNet: ControlNetModel,
-        ModelType.TextualInversion: TextualInversionModel,
-    },
-    BaseModel.StableDiffusion2: {
-        ModelType.Pipeline: StableDiffusionModel,
-        ModelType.Classifier: ClassifierModel,
-        ModelType.Vae: VaeModel,
-        ModelType.Lora: LoraModel,
-        ModelType.ControlNet: ControlNetModel,
-        ModelType.TextualInversion: TextualInversionModel,
-    },
-    BaseModel.StableDiffusion2Base: {
-        ModelType.Pipeline: StableDiffusionModel,
-        ModelType.Classifier: ClassifierModel,
-        ModelType.Vae: VaeModel,
-        ModelType.Lora: LoraModel,
-        ModelType.ControlNet: ControlNetModel,
-        ModelType.TextualInversion: TextualInversionModel,
-    },
-    #BaseModel.Kandinsky2_1: {
-    #    ModelType.Pipeline: Kandinsky2_1Model,
-    #    ModelType.Classifier: ClassifierModel,
-    #    ModelType.MoVQ: MoVQModel,
-    #    ModelType.Lora: LoraModel,
-    #    ModelType.ControlNet: ControlNetModel,
-    #    ModelType.TextualInversion: TextualInversionModel,
-    #},
-}
+class VariantType(str, Enum):
+    Normal = "normal"
+    Inpaint = "inpaint"
+    Depth = "depth"

class EmptyConfigLoader(ConfigMixin):
    @classmethod
@ -323,7 +300,7 @@ class ClassifierModel(ModelBase):
    #child_sizes: Dict[str, int]

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
-       assert model_type == SDModelType.Classifier
+       assert model_type == ModelType.Classifier
        super().__init__(model_path, base_model, model_type)

        self.child_types: Dict[str, Type] = dict()
@ -354,8 +331,8 @@ class ClassifierModel(ModelBase):
        else:
            raise Exception("Invalid classifier model! (Failed to detect tokenizer type)")

-       self.child_types[SDModelType.Tokenizer] = self._hf_definition_to_type(["transformers", tokenizer_class_name])
-       self.child_sizes[SDModelType.Tokenizer] = 0
+       self.child_types[SubModelType.Tokenizer] = self._hf_definition_to_type(["transformers", tokenizer_class_name])
+       self.child_sizes[SubModelType.Tokenizer] = 0

    def _load_text_encoder(self, main_config: dict):
@ -366,12 +343,12 @@ class ClassifierModel(ModelBase):
        else:
            raise Exception("Invalid classifier model! (Failed to detect text_encoder type)")

-       self.child_types[SDModelType.TextEncoder] = self._hf_definition_to_type(["transformers", text_encoder_class_name])
-       self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.model_path)
+       self.child_types[SubModelType.TextEncoder] = self._hf_definition_to_type(["transformers", text_encoder_class_name])
+       self.child_sizes[SubModelType.TextEncoder] = calc_model_size_by_fs(self.model_path)

    def _load_feature_extractor(self, main_config: dict):
-       self.child_sizes[SDModelType.FeatureExtractor] = 0
+       self.child_sizes[SubModelType.FeatureExtractor] = 0
        try:
            feature_extractor_config = EmptyConfigLoader.load_config(self.model_path, config_name="preprocessor_config.json")
        except:
@ -379,12 +356,12 @@ class ClassifierModel(ModelBase):
        try:
            feature_extractor_class_name = feature_extractor_config["feature_extractor_type"]
-           self.child_types[SDModelType.FeatureExtractor] = self._hf_definition_to_type(["transformers", feature_extractor_class_name])
+           self.child_types[SubModelType.FeatureExtractor] = self._hf_definition_to_type(["transformers", feature_extractor_class_name])
        except:
            raise Exception("Invalid classifier model! (Unknown feature_extrator type)")

-   def get_size(self, child_type: Optional[SDModelType] = None):
+   def get_size(self, child_type: Optional[SubModelType] = None):
        if child_type is None:
            return sum(self.child_sizes.values())
        else:
@ -394,7 +371,7 @@ class ClassifierModel(ModelBase):
    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
-       child_type: Optional[SDModelType] = None,
+       child_type: Optional[SubModelType] = None,
    ):
        if child_type is None:
            raise Exception("Child model type can't be null on classififer model")
@ -437,7 +414,7 @@ class VaeModel(ModelBase):
        except:
            raise Exception("Invalid vae model! (Unkown vae type)")

-   def get_size(self, child_type: Optional[SDModelType] = None):
+   def get_size(self, child_type: Optional[SubModelType] = None):
        if child_type is not None:
            raise Exception("There is no child models in vae model")
        return self.model_size
@ -445,7 +422,7 @@ class VaeModel(ModelBase):
    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
-       child_type: Optional[SDModelType] = None,
+       child_type: Optional[SubModelType] = None,
    ):
        if child_type is not None:
            raise Exception("There is no child models in vae model")
@ -476,7 +453,7 @@ class LoRAModel(ModelBase):
        self.model_size = os.path.getsize(self.model_path)

-   def get_size(self, child_type: Optional[SDModelType] = None):
+   def get_size(self, child_type: Optional[SubModelType] = None):
        if child_type is not None:
            raise Exception("There is no child models in lora")
        return self.model_size
@ -484,7 +461,7 @@ class LoRAModel(ModelBase):
    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
-       child_type: Optional[SDModelType] = None,
+       child_type: Optional[SubModelType] = None,
    ):
        if child_type is not None:
            raise Exception("There is no child models in lora")
@ -505,7 +482,6 @@ class LoRAModel(ModelBase):
        # TODO: add diffusers lora when it stabilizes a bit
        return model_path

class TextualInversionModel(ModelBase):
    #model_size: int
@ -515,7 +491,7 @@ class TextualInversionModel(ModelBase):
        self.model_size = os.path.getsize(self.model_path)

-   def get_size(self, child_type: Optional[SDModelType] = None):
+   def get_size(self, child_type: Optional[SubModelType] = None):
        if child_type is not None:
            raise Exception("There is no child models in textual inversion")
        return self.model_size
@ -523,7 +499,7 @@ class TextualInversionModel(ModelBase):
    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
-       child_type: Optional[SDModelType] = None,
+       child_type: Optional[SubModelType] = None,
    ):
        if child_type is not None:
            raise Exception("There is no child models in textual inversion")
@ -542,6 +518,10 @@ class TextualInversionModel(ModelBase):
        model_path = Path(model_path)
        return model_path

+class ControlNetModel(ModelBase):
+    """requires implementation"""
+    pass

def calc_model_size_by_fs(
    model_path: str,
    subfolder: Optional[str] = None,
@ -710,3 +690,39 @@ def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
        safe_serialization=is_safetensors_available()
    )
    return diffusers_path
+MODEL_CLASSES = {
+    BaseModelType.StableDiffusion1_5: {
+        ModelType.Pipeline: StableDiffusionModel,
+        ModelType.Classifier: ClassifierModel,
+        ModelType.Vae: VaeModel,
+        ModelType.Lora: LoRAModel,
+        ModelType.ControlNet: ControlNetModel,
+        ModelType.TextualInversion: TextualInversionModel,
+    },
+    BaseModelType.StableDiffusion2: {
+        ModelType.Pipeline: StableDiffusionModel,
+        ModelType.Classifier: ClassifierModel,
+        ModelType.Vae: VaeModel,
+        ModelType.Lora: LoRAModel,
+        ModelType.ControlNet: ControlNetModel,
+        ModelType.TextualInversion: TextualInversionModel,
+    },
+    BaseModelType.StableDiffusion2Base: {
+        ModelType.Pipeline: StableDiffusionModel,
+        ModelType.Classifier: ClassifierModel,
+        ModelType.Vae: VaeModel,
+        ModelType.Lora: LoRAModel,
+        ModelType.ControlNet: ControlNetModel,
+        ModelType.TextualInversion: TextualInversionModel,
+    },
+    #BaseModel.Kandinsky2_1: {
+    #    ModelType.Pipeline: Kandinsky2_1Model,
+    #    ModelType.Classifier: ClassifierModel,
+    #    ModelType.MoVQ: MoVQModel,
+    #    ModelType.Lora: LoraModel,
+    #    ModelType.ControlNet: ControlNetModel,
+    #    ModelType.TextualInversion: TextualInversionModel,
+    #},
+}
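MODEL_CLASSES is a two-level registry: the outer key is the base model, the inner key the model type, and the value the implementation class, so resolving a loader is a plain nested lookup. A small sketch; the lookup target chosen here is only illustrative:

# Illustrative nested lookup into the MODEL_CLASSES registry defined above.
model_class = MODEL_CLASSES[BaseModelType.StableDiffusion1_5][ModelType.Vae]
print(model_class.__name__)  # VaeModel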

scripts/scan_models_directory.py Normal file → Executable file

@ -1,3 +1,5 @@
+#!/usr/bin/env python
'''
Scan the models directory and print out a new models.yaml
'''
@ -12,6 +14,12 @@ from omegaconf import OmegaConf
def main():
    parser = argparse.ArgumentParser(description="Model directory scanner")
    parser.add_argument('models_directory')
+   parser.add_argument('--all-models',
+                       default=False,
+                       action='store_true',
+                       help='If true, then generates stanzas for all models; otherwise just diffusers'
+                       )
    args = parser.parse_args()

    directory = args.models_directory
@ -19,29 +27,30 @@ def main():
    conf['_version'] = '3.0.0'

    for root, dirs, files in os.walk(directory):
-       for d in dirs:
-           parents = root.split('/')
-           subpaths = parents[parents.index('models')+1:]
-           if len(subpaths) < 2:
-               continue
-           base, model_type, *_ = subpaths
-
-           conf[f'{model_type}/{d}'] = dict(
-               path = os.path.join(root,d),
-               description = f'{model_type} model {d}',
-               format = 'folder',
-               base = base,
-           )
-
-       for f in files:
-           basename = Path(f).stem
-           format = Path(f).suffix[1:]
-           conf[f'{model_type}/{basename}'] = dict(
-               path = os.path.join(root,f),
-               description = f'{model_type} model {basename}',
-               format = format,
-               base = base,
-           )
+       parents = root.split('/')
+       subpaths = parents[parents.index('models')+1:]
+       if len(subpaths) < 2:
+           continue
+       base, model_type, *_ = subpaths
+       if args.all_models or model_type=='diffusers':
+           for d in dirs:
+               conf[f'{base}/{model_type}/{d}'] = dict(
+                   path = os.path.join(root,d),
+                   description = f'{model_type} model {d}',
+                   format = 'folder',
+                   base = base,
+               )
+           for f in files:
+               basename = Path(f).stem
+               format = Path(f).suffix[1:]
+               conf[f'{base}/{model_type}/{basename}'] = dict(
+                   path = os.path.join(root,f),
+                   description = f'{model_type} model {basename}',
+                   format = format,
+                   base = base,
+               )

    OmegaConf.save(config=dict(sorted(conf.items())), f=sys.stdout)
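Assuming the standard InvokeAI layout under the models directory, a typical invocation of the updated script might be `python scripts/scan_models_directory.py /path/to/invokeai/models --all-models > models.yaml` (the models path is illustrative). The generated stanzas are written to stdout; without --all-models, only diffusers models are included.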