move all installation code out of model_manager

This commit is contained in:
Lincoln Stein 2023-06-11 12:51:50 -04:00
parent 74b43c9bdf
commit 000626ab2e
8 changed files with 663 additions and 152 deletions

View File

@ -9,5 +9,5 @@ from .generator import (
Img2Img,
Inpaint
)
from .model_management import ModelManager, ModelCache, SDModelType, SDModelInfo
from .model_management import ModelManager, ModelCache, ModelType, ModelInfo
from .safety_checker import SafetyChecker

View File

@ -1,5 +1,6 @@
"""
Initialization file for invokeai.backend.model_management
"""
from .model_manager import ModelManager, SDModelInfo
from .model_cache import ModelCache, SDModelType
from .model_manager import ModelManager, ModelInfo
from .model_cache import ModelCache
from .models import ModelType

View File

@ -29,11 +29,8 @@ import torch
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
import invokeai.backend.util.logging as logger
from .model_manager import SDModelInfo, ModelType, SubModelType, ModelBase
from .models import ModelType, SubModelType, ModelBase
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
@ -50,6 +47,10 @@ class ModelCache(object):
"Forward declaration"
pass
class SDModelInfo(object):
"""Forward declaration"""
pass
class _CacheRecord:
size: int
model: Any

View File

@ -0,0 +1,201 @@
"""
Routines for downloading and installing models.
"""
import json
import safetensors
import safetensors.torch
import torch
import traceback
from dataclasses import dataclass
from diffusers import ModelMixin
from enum import Enum
from typing import Callable
from pathlib import Path
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, VariantType
class CheckpointProbe(object):
PROBES = dict() # see below for redefinition
def __init__(self,
checkpoint_path: Path,
checkpoint: dict = None,
helper: Callable[[Path], BaseModelType]=None
):
self.checkpoint_path = checkpoint_path
self.checkpoint = checkpoint or self._scan_and_load_checkpoint(checkpoint_path)
self.helper = helper
def probe(self) -> ModelVariantInfo:
'''
Probes the checkpoint at path `checkpoint_path` and returns
a ModelVariantInfo object indicating the base model, model type and
model variant of the checkpoint.
'''
checkpoint = self.checkpoint
state_dict = checkpoint.get("state_dict") or checkpoint
model_info = None
try:
model_type = self.get_checkpoint_type(state_dict)
if not model_type:
if self.checkpoint_path.name == "learned_embeds.bin":
model_type = ModelType.TextualInversion
else:
return None # we give up
probe = self.PROBES[model_type]()
base_type = probe.get_base_type(checkpoint, self.checkpoint_path, self.helper)
variant_type = probe.get_variant_type(model_type, checkpoint)
model_info = ModelVariantInfo(
model_type = model_type,
base_type = base_type,
variant_type = variant_type,
)
except (KeyError, ValueError) as e:
logger.error(f'An error occurred while probing {self.checkpoint_path}: {str(e)}')
logger.error(traceback.format_exc())
return model_info
class CheckpointProbeBase(object):
def get_base_type(self,
checkpoint: dict,
checkpoint_path: Path = None,
helper: Callable[[Path],BaseModelType] = None
)->BaseModelType:
pass
def get_variant_type(self,
model_type: ModelType,
checkpoint: dict,
)-> VariantType:
if model_type != ModelType.Pipeline:
return None
state_dict = checkpoint.get('state_dict') or checkpoint
in_channels = state_dict[
"model.diffusion_model.input_blocks.0.0.weight"
].shape[1]
if in_channels == 9:
return VariantType.Inpaint
elif in_channels == 5:
return VariantType.Depth
else:
return None
class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self,
checkpoint: dict,
checkpoint_path: Path = None,
helper: Callable[[Path],BaseModelType] = None
)->BaseModelType:
state_dict = checkpoint.get('state_dict') or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1_5
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if 'global_step' in checkpoint:
if checkpoint['global_step'] == 220000:
return BaseModelType.StableDiffusion2Base
elif checkpoint["global_step"] == 110000:
return BaseModelType.StableDiffusion2
if checkpoint_path and helper:
return helper(checkpoint_path)
else:
return None
class VaeProbe(CheckpointProbeBase):
def get_base_type(self,
checkpoint: dict,
checkpoint_path: Path = None,
helper: Callable[[Path],BaseModelType] = None
)->BaseModelType:
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1_5
class LoRAProbe(CheckpointProbeBase):
def get_base_type(self,
checkpoint: dict,
checkpoint_path: Path = None,
helper: Callable[[Path],BaseModelType] = None
)->BaseModelType:
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
lora_token_vector_length = (
checkpoint[key1].shape[1]
if key1 in checkpoint
else checkpoint[key2].shape[0]
if key2 in checkpoint
else 768
)
if lora_token_vector_length == 768:
return BaseModelType.StableDiffusion1_5
elif lora_token_vector_length == 1024:
return BaseModelType.StableDiffusion2
else:
return None
class TextualInversionProbe(CheckpointProbeBase):
def get_base_type(self,
checkpoint: dict,
checkpoint_path: Path = None,
helper: Callable[[Path],BaseModelType] = None
)->BaseModelType:
if 'string_to_token' in checkpoint:
token_dim = list(checkpoint['string_to_param'].values())[0].shape[-1]
elif 'emb_params' in checkpoint:
token_dim = checkpoint['emb_params'].shape[-1]
else:
token_dim = list(checkpoint.values())[0].shape[0]
if token_dim == 768:
return BaseModelType.StableDiffusion1_5
elif token_dim == 1024:
return BaseModelType.StableDiffusion2Base
else:
return None
class ControlNetProbe(CheckpointProbeBase):
def get_base_type(self,
checkpoint: dict,
checkpoint_path: Path = None,
helper: Callable[[Path],BaseModelType] = None
)->BaseModelType:
for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight',
'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight'
):
if key_name not in checkpoint:
continue
if checkpoint[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1_5
elif checkpoint_path and helper:
return helper(checkpoint_path)
PROBES = {
ModelType.Pipeline: PipelineCheckpointProbe,
ModelType.Vae: VaeProbe,
ModelType.Lora: LoRAProbe,
ModelType.TextualInversion: TextualInversionProbe,
ModelType.ControlNet: ControlNetProbe,
}
@classmethod
def get_checkpoint_type(cls, state_dict: dict) -> ModelType:
if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]):
return ModelType.Pipeline
if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]):
return ModelType.Vae
if "string_to_token" in state_dict or "emb_params" in state_dict:
return ModelType.TextualInversion
if any([x.startswith("lora") for x in state_dict.keys()]):
return ModelType.Lora
if any([x.startswith("control_model") for x in state_dict.keys()]):
return ModelType.ControlNet
if any([x.startswith("input_blocks") for x in state_dict.keys()]):
return ModelType.ControlNet
return None # give up
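For illustration, a minimal usage sketch of the checkpoint probe defined above; the import paths, checkpoint path, and helper callback are assumptions rather than part of the commit:

from pathlib import Path

# import paths are assumptions; adjust to wherever these classes live in this commit
from invokeai.backend.model_management.model_install import CheckpointProbe
from invokeai.backend.model_management.models import BaseModelType

def pick_sd2_base(path: Path) -> BaseModelType:
    # hypothetical helper, consulted only when SD2-base vs SD2-768 cannot be inferred
    return BaseModelType.StableDiffusion2Base

info = CheckpointProbe(
    checkpoint_path=Path("/tmp/example.safetensors"),  # hypothetical checkpoint
    helper=pick_sd2_base,
).probe()
if info is not None:
    print(info.base_type, info.model_type, info.variant_type)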

View File

@ -3,13 +3,13 @@ symbolic diffusers model names to the paths and repo_ids used
by the underlying `from_pretrained()` call.
For fetching models, use manager.get_model('symbolic name'). This will
return a SDModelInfo object that contains the following attributes:
return a ModelInfo object that contains the following attributes:
* context -- a context manager Generator that loads and locks the
model into GPU VRAM and returns the model for use.
See below for usage.
* name -- symbolic name of the model
* type -- SDModelType of the model
* type -- ModelType of the model
* hash -- unique hash for the model
* location -- path or repo_id of the model
* revision -- revision of the model if coming from a repo id,
@ -25,7 +25,7 @@ Typical usage:
max_cache_size=8
) # gigabytes
model_info = manager.get_model('stable-diffusion-1.5', SDModelType.Diffusers)
model_info = manager.get_model('stable-diffusion-1.5', SubModelType.Diffusers)
with model_info.context as my_model:
my_model.latents_from_embeddings(...)
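To make the attribute list above concrete, a short sketch continuing the usage example; field names are taken from the docstring and the ModelInfo dataclass later in this file, and the values shown are hypothetical:

print(model_info.name)      # 'stable-diffusion-1.5'
print(model_info.type)      # the model's type (see the attribute list above)
print(model_info.hash)      # unique hash, e.g. '<NO_HASH>' until hashing is implemented
print(model_info.location)  # path or repo_id the model was loaded from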
@ -43,7 +43,7 @@ parameter:
model_info = manager.get_model(
'clip-tokenizer',
model_type=SDModelType.Tokenizer
model_type=SubModelType.Tokenizer
)
This will raise an InvalidModelError if the format defined in the
@ -63,7 +63,7 @@ The general format of a models.yaml section is:
The type of model is given in the stanza key, and is one of
{diffusers, ckpt, vae, text_encoder, tokenizer, unet, scheduler,
safety_checker, feature_extractor, lora, textual_inversion,
controlnet}, and correspond to items in the SDModelType enum defined
controlnet}, and correspond to items in the SubModelType enum defined
in model_cache.py
The format indicates whether the model is organized as a folder with
@ -96,7 +96,7 @@ SUBMODELS:
It is also possible to fetch an isolated submodel from a diffusers
model. Use the `submodel` parameter to select which part:
vae = manager.get_model('stable-diffusion-1.5',submodel=SDModelType.Vae)
vae = manager.get_model('stable-diffusion-1.5',submodel=SubModelType.Vae)
with vae.context as my_vae:
print(type(my_vae))
# "AutoencoderKL"
@ -128,8 +128,8 @@ separated by "/". Example:
You can now use the `model_type` argument to indicate which model you
want:
tokenizer = mgr.get('clip-large',model_type=SDModelType.Tokenizer)
encoder = mgr.get('clip-large',model_type=SDModelType.TextEncoder)
tokenizer = mgr.get('clip-large',model_type=SubModelType.Tokenizer)
encoder = mgr.get('clip-large',model_type=SubModelType.TextEncoder)
OTHER FUNCTIONS:
@ -164,15 +164,24 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
from .model_cache import ModelCache, ModelLocker
from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES
from .models import BaseModelType, SubModelType, MODEL_CLASSES
# We are only starting to number the config file with release 3.
# The config file version doesn't have to match the release version, but starting
# them together will help reduce confusion.
CONFIG_FILE_VERSION='3.0.0'
# wanted to use pydantic here, but Generator objects not supported
# temporary forward definitions to avoid circular import errors.
class ModelLocker(object):
"Forward declaration"
pass
class ModelCache(object):
"Forward declaration"
pass
@dataclass
class SDModelInfo():
class ModelInfo():
context: ModelLocker
name: str
type: ModelType
@ -303,7 +312,7 @@ class ModelManager(object):
submodel_type: Optional[SubModelType] = None
):
"""Given a model named identified in models.yaml, return
an SDModelInfo object describing it.
a ModelInfo object describing it.
:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param submodel_type: a SubModelType enum indicating the portion of
@ -389,7 +398,7 @@ class ModelManager(object):
hash = "<NO_HASH>" # TODO:
return SDModelInfo(
return ModelInfo(
context = model_context,
name = model_name,
base_model = base_model,
@ -746,62 +755,3 @@ class ModelManager(object):
if self.config_path:
self.commit()
def _delete_defunct_models(self):
'''
Remove models no longer on disk.
'''
config = self.config
to_delete = set()
for key in config:
if 'path' not in config[key]:
continue
path = self.globals.root_dir / config[key].path
if path.exists():
continue
to_delete.add(key)
for key in to_delete:
self.logger.warn(f'Removing model {key} from in-memory config because its path is no longer on disk')
config.pop(key)
def scan_models_directory(self, include_diffusers:bool=False):
'''
Scan the models directory for loras, textual_inversions and controlnets
and create appropriate entries in the in-memory omegaconf. Diffusers
will not be added unless include_diffusers is true.
'''
self._delete_defunct_models()
model_directory = self.globals.models_path
config = self.config
for root, dirs, files in os.walk(model_directory):
parents = root.split('/')
subpaths = parents[parents.index('models')+1:]
if len(subpaths) < 2:
continue
base, model_type, *_ = subpaths
if model_type == "diffusers" and not include_diffusers:
continue
for d in dirs:
config[f'{model_type}/{d}'] = dict(
path = os.path.join(root,d),
description = f'{model_type} model {d}',
format = 'folder',
base = base,
)
for f in files:
basename = Path(f).stem
format = Path(f).suffix[1:]
config[f'{model_type}/{basename}'] = dict(
path = os.path.join(root,f),
description = f'{model_type} model {basename}',
format = format,
base = base,
)

View File

@ -0,0 +1,333 @@
import json
import traceback
import torch
import safetensors.torch
from dataclasses import dataclass
from diffusers import ModelMixin
from pathlib import Path
from typing import Callable, Literal, Union, Dict
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, VariantType
@dataclass
class ModelVariantInfo(object):
model_type: ModelType
base_type: BaseModelType
variant_type: VariantType
class ProbeBase(object):
'''forward declaration'''
pass
class ModelProbe(object):
PROBES = {
'folder': { },
'file': { },
}
CLASS2TYPE = {
"StableDiffusionPipeline" : ModelType.Pipeline,
"AutoencoderKL": ModelType.Vae,
"ControlNetModel" : ModelType.ControlNet,
}
@classmethod
def register_probe(cls,
format: Literal['folder','file'],
model_type: ModelType,
probe_class: ProbeBase):
cls.PROBES[format][model_type] = probe_class
@classmethod
def probe(cls,
model_path: Path,
model: Union[Dict, ModelMixin] = None,
base_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo:
'''
Probe the model at model_path and return sufficient information about it
to place it somewhere in the models directory hierarchy. If the model is
already loaded into memory, you may provide it as model in order to avoid
opening it a second time. The base_helper callable is a function that receives
the path to the model and returns the BaseModelType. It is called to distinguish
between V2-Base and V2-768 SD models.
'''
format = 'folder' if model_path.is_dir() else 'file'
model_info = None
try:
model_type = cls.get_model_type_from_folder(model_path, model) \
if format == 'folder' \
else cls.get_model_type_from_checkpoint(model_path, model)
probe_class = cls.PROBES[format].get(model_type)
if not probe_class:
return None
probe = probe_class(model_path, model, base_helper)
base_type = probe.get_base_type()
variant_type = probe.get_variant_type()
model_info = ModelVariantInfo(
model_type = model_type,
base_type = base_type,
variant_type = variant_type,
)
except (KeyError, ValueError) as e:
logger.error(f'An error occurred while probing {model_path}: {str(e)}')
logger.error(traceback.format_exc())
return model_info
@classmethod
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict)->ModelType:
checkpoint = checkpoint or cls._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]):
return ModelType.Pipeline
if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]):
return ModelType.Vae
if "string_to_token" in state_dict or "emb_params" in state_dict:
return ModelType.TextualInversion
if any([x.startswith("lora") for x in state_dict.keys()]):
return ModelType.Lora
if any([x.startswith("control_model") for x in state_dict.keys()]):
return ModelType.ControlNet
if any([x.startswith("input_blocks") for x in state_dict.keys()]):
return ModelType.ControlNet
return None # give up
@classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
'''
Get the model type of a hugging-face style folder.
'''
if (folder_path / 'learned_embeds.bin').exists():
return ModelType.TextualInversion
if (folder_path / 'pytorch_lora_weights.bin').exists():
return ModelType.Lora
i = folder_path / 'model_index.json'
c = folder_path / 'config.json'
config_path = i if i.exists() else c if c.exists() else None
if config_path:
conf = json.loads(config_path.read_text())
class_name = conf['_class_name']
if type := cls.CLASS2TYPE.get(class_name):
return type
# give up
raise ValueError("Unable to determine model type of {model_path}")
@classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model_path, model_path)
return torch.load(model_path)
else:
return safetensors.torch.load_file(model_path)
@classmethod
def _scan_model(cls, model_name, checkpoint):
"""
Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified.
"""
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise "The model {model_name} is potentially infected by malware. Aborting import."
####################################################
# Checkpoint probing
####################################################
class ProbeBase(object):
def get_base_type(self)->BaseModelType:
pass
def get_variant_type(self)->VariantType:
pass
class CheckpointProbeBase(ProbeBase):
def __init__(self,
checkpoint_path: Path,
checkpoint: dict,
helper: Callable[[Path],BaseModelType] = None
):
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
self.checkpoint_path = checkpoint_path
self.helper = helper
def get_base_type(self)->BaseModelType:
pass
def get_variant_type(self)-> VariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
if model_type != ModelType.Pipeline:
return VariantType.Normal
state_dict = self.checkpoint.get('state_dict') or self.checkpoint
in_channels = state_dict[
"model.diffusion_model.input_blocks.0.0.weight"
].shape[1]
if in_channels == 9:
return VariantType.Inpaint
elif in_channels == 5:
return VariantType.Depth
else:
return None
class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
helper = self.helper
state_dict = self.checkpoint.get('state_dict') or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1_5
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if 'global_step' in checkpoint:
if checkpoint['global_step'] == 220000:
return BaseModelType.StableDiffusion2Base
elif checkpoint["global_step"] == 110000:
return BaseModelType.StableDiffusion2
if self.checkpoint_path and helper:
return helper(self.checkpoint_path)
else:
return None
class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1_5
class LoRACheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
lora_token_vector_length = (
checkpoint[key1].shape[1]
if key1 in checkpoint
else checkpoint[key2].shape[0]
if key2 in checkpoint
else 768
)
if lora_token_vector_length == 768:
return BaseModelType.StableDiffusion1_5
elif lora_token_vector_length == 1024:
return BaseModelType.StableDiffusion2
else:
return None
class TextualInversionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
if 'string_to_token' in checkpoint:
token_dim = list(checkpoint['string_to_param'].values())[0].shape[-1]
elif 'emb_params' in checkpoint:
token_dim = checkpoint['emb_params'].shape[-1]
else:
token_dim = list(checkpoint.values())[0].shape[0]
if token_dim == 768:
return BaseModelType.StableDiffusion1_5
elif token_dim == 1024:
return BaseModelType.StableDiffusion2Base
else:
return None
class ControlNetCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight',
'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight'
):
if key_name not in checkpoint:
continue
if checkpoint[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1_5
elif self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path)
########################################################
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
def __init__(self,
model: ModelMixin,
folder_path: Path,
helper: Callable=None # not used
):
self.model = model
self.folder_path = folder_path
def get_variant_type(self)->VariantType:
# only works for pipelines
config_file = self.folder_path / 'unet' / 'config.json'
if not config_file.exists():
return VariantType.Normal
conf = json.loads(config_file.read_text())
channels = conf['in_channels']
if channels == 9:
return VariantType.Inpaint
elif channels == 5:
return VariantType.Depth
elif channels == 4:
return VariantType.Normal
else:
return VariantType.Normal
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
config_file = self.folder_path / 'scheduler' / 'scheduler_config.json'
if not config_file.exists():
return None
conf = json.loads(config_file.read_text())
if conf['prediction_type'] == "v_prediction":
return BaseModelType.StableDiffusion2
elif conf['prediction_type'] == 'epsilon':
return BaseModelType.StableDiffusion2Base
else:
return BaseModelType.StableDiffusion2
class VaeFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
return BaseModelType.StableDiffusion1_5
class TextualInversionFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
path = self.folder_path / 'learned_embeds.bin'
if not path.exists():
return None
checkpoint = ModelProbe._scan_and_load_checkpoint(path)
return TextualInversionCheckpointProbe(path, checkpoint).get_base_type()
class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
config_file = self.folder_path / 'scheduler_config.json'
if not config_file.exists():
return None
config = json.loads(config_file.read_text())
# no obvious way to distinguish between sd2-base and sd2-768
return BaseModelType.StableDiffusion1_5 \
if config['cross_attention_dim']==768 \
else BaseModelType.StableDiffusion2
class LoRAFolderProbe(FolderProbeBase):
# I've never seen one of these in the wild, so this is a noop
pass
############## register probe classes ######
ModelProbe.register_probe('folder', ModelType.Pipeline, PipelineFolderProbe)
ModelProbe.register_probe('folder', ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe('folder', ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe('folder', ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe('folder', ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe('file', ModelType.Pipeline, PipelineCheckpointProbe)
ModelProbe.register_probe('file', ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe('file', ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe('file', ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe('file', ModelType.ControlNet, ControlNetCheckpointProbe)
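For illustration, a minimal sketch of calling the ModelProbe entry point registered above; the import paths, model path, and base-type callback are assumptions, and probe() returns None when it cannot classify the input:

from pathlib import Path

# import paths are assumptions; adjust to the actual module locations
from invokeai.backend.model_management.model_probe import ModelProbe
from invokeai.backend.model_management.models import BaseModelType

def guess_v2_base(path: Path) -> BaseModelType:
    # hypothetical callback used to disambiguate SD2-base from SD2-768 checkpoints
    return BaseModelType.StableDiffusion2Base

model_info = ModelProbe.probe(
    Path("/data/models/example-model"),  # hypothetical path; a single checkpoint file also works
    base_helper=guess_v2_base,
)
if model_info is not None:
    print(model_info.model_type, model_info.base_type, model_info.variant_type)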

View File

@ -1,8 +1,14 @@
import sys
from dataclasses import dataclass
from enum import Enum
import torch
import safetensors.torch
from diffusers import ConfigMixin
from diffusers.utils import is_safetensors_available
from omegaconf import DictConfig
from pathlib import Path
from pydantic import BaseModel, Field, root_validator
from typing import Union, List, Type, Optional
class BaseModelType(str, Enum):
# TODO: maybe then add sample size(512/768)?
@ -26,42 +32,13 @@ class SubModelType:
Vae = "vae"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
FeatureExtractor = "feature_extractor"
#MoVQ = "movq"
MODEL_CLASSES = {
BaseModel.StableDiffusion1_5: {
ModelType.Pipeline: StableDiffusionModel,
ModelType.Classifier: ClassifierModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoraModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModel.StableDiffusion2: {
ModelType.Pipeline: StableDiffusionModel,
ModelType.Classifier: ClassifierModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoraModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModel.StableDiffusion2Base: {
ModelType.Pipeline: StableDiffusionModel,
ModelType.Classifier: ClassifierModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoraModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
#BaseModel.Kandinsky2_1: {
# ModelType.Pipeline: Kandinsky2_1Model,
# ModelType.Classifier: ClassifierModel,
# ModelType.MoVQ: MoVQModel,
# ModelType.Lora: LoraModel,
# ModelType.ControlNet: ControlNetModel,
# ModelType.TextualInversion: TextualInversionModel,
#},
}
class VariantType(str, Enum):
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class EmptyConfigLoader(ConfigMixin):
@classmethod
@ -323,7 +300,7 @@ class ClassifierModel(ModelBase):
#child_sizes: Dict[str, int]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == SDModelType.Classifier
assert model_type == ModelType.Classifier
super().__init__(model_path, base_model, model_type)
self.child_types: Dict[str, Type] = dict()
@ -354,8 +331,8 @@ class ClassifierModel(ModelBase):
else:
raise Exception("Invalid classifier model! (Failed to detect tokenizer type)")
self.child_types[SDModelType.Tokenizer] = self._hf_definition_to_type(["transformers", tokenizer_class_name])
self.child_sizes[SDModelType.Tokenizer] = 0
self.child_types[SubModelType.Tokenizer] = self._hf_definition_to_type(["transformers", tokenizer_class_name])
self.child_sizes[SubModelType.Tokenizer] = 0
def _load_text_encoder(self, main_config: dict):
@ -366,12 +343,12 @@ class ClassifierModel(ModelBase):
else:
raise Exception("Invalid classifier model! (Failed to detect text_encoder type)")
self.child_types[SDModelType.TextEncoder] = self._hf_definition_to_type(["transformers", text_encoder_class_name])
self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.model_path)
self.child_types[SubModelType.TextEncoder] = self._hf_definition_to_type(["transformers", text_encoder_class_name])
self.child_sizes[SubModelType.TextEncoder] = calc_model_size_by_fs(self.model_path)
def _load_feature_extractor(self, main_config: dict):
self.child_sizes[SDModelType.FeatureExtractor] = 0
self.child_sizes[SubModelType.FeatureExtractor] = 0
try:
feature_extractor_config = EmptyConfigLoader.load_config(self.model_path, config_name="preprocessor_config.json")
except:
@ -379,12 +356,12 @@ class ClassifierModel(ModelBase):
try:
feature_extractor_class_name = feature_extractor_config["feature_extractor_type"]
self.child_types[SDModelType.FeatureExtractor] = self._hf_definition_to_type(["transformers", feature_extractor_class_name])
self.child_types[SubModelType.FeatureExtractor] = self._hf_definition_to_type(["transformers", feature_extractor_class_name])
except:
raise Exception("Invalid classifier model! (Unknown feature_extrator type)")
def get_size(self, child_type: Optional[SDModelType] = None):
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is None:
return sum(self.child_sizes.values())
else:
@ -394,7 +371,7 @@ class ClassifierModel(ModelBase):
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SDModelType] = None,
child_type: Optional[SubModelType] = None,
):
if child_type is None:
raise Exception("Child model type can't be null on classififer model")
@ -437,7 +414,7 @@ class VaeModel(ModelBase):
except:
raise Exception("Invalid vae model! (Unkown vae type)")
def get_size(self, child_type: Optional[SDModelType] = None):
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in vae model")
return self.model_size
@ -445,7 +422,7 @@ class VaeModel(ModelBase):
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SDModelType] = None,
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in vae model")
@ -476,7 +453,7 @@ class LoRAModel(ModelBase):
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SDModelType] = None):
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in lora")
return self.model_size
@ -484,7 +461,7 @@ class LoRAModel(ModelBase):
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SDModelType] = None,
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in lora")
@ -505,7 +482,6 @@ class LoRAModel(ModelBase):
# TODO: add diffusers lora when it stabilizes a bit
return model_path
class TextualInversionModel(ModelBase):
#model_size: int
@ -515,7 +491,7 @@ class TextualInversionModel(ModelBase):
self.model_size = os.path.getsize(self.model_path)
def get_size(self, child_type: Optional[SDModelType] = None):
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
return self.model_size
@ -523,7 +499,7 @@ class TextualInversionModel(ModelBase):
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SDModelType] = None,
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in textual inversion")
@ -542,6 +518,10 @@ class TextualInversionModel(ModelBase):
model_path = Path(model_path)
return model_path
class ControlNetModel(ModelBase):
"""requires implementation"""
pass
def calc_model_size_by_fs(
model_path: str,
subfolder: Optional[str] = None,
@ -710,3 +690,39 @@ def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
safe_serialization=is_safetensors_available()
)
return diffusers_path
MODEL_CLASSES = {
BaseModelType.StableDiffusion1_5: {
ModelType.Pipeline: StableDiffusionModel,
ModelType.Classifier: ClassifierModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusion2: {
ModelType.Pipeline: StableDiffusionModel,
ModelType.Classifier: ClassifierModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusion2Base: {
ModelType.Pipeline: StableDiffusionModel,
ModelType.Classifier: ClassifierModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
#BaseModel.Kandinsky2_1: {
# ModelType.Pipeline: Kandinsky2_1Model,
# ModelType.Classifier: ClassifierModel,
# ModelType.MoVQ: MoVQModel,
# ModelType.Lora: LoraModel,
# ModelType.ControlNet: ControlNetModel,
# ModelType.TextualInversion: TextualInversionModel,
#},
}

scripts/scan_models_directory.py (55 changed lines, Normal file → Executable file)
View File

@ -1,3 +1,5 @@
#!/usr/bin/env python
'''
Scan the models directory and print out a new models.yaml
'''
@ -12,6 +14,12 @@ from omegaconf import OmegaConf
def main():
parser = argparse.ArgumentParser(description="Model directory scanner")
parser.add_argument('models_directory')
parser.add_argument('--all-models',
default=False,
action='store_true',
help='If set, generate stanzas for all models; otherwise only diffusers'
)
args = parser.parse_args()
directory = args.models_directory
@ -19,29 +27,30 @@ def main():
conf['_version'] = '3.0.0'
for root, dirs, files in os.walk(directory):
for d in dirs:
parents = root.split('/')
subpaths = parents[parents.index('models')+1:]
if len(subpaths) < 2:
continue
base, model_type, *_ = subpaths
conf[f'{model_type}/{d}'] = dict(
path = os.path.join(root,d),
description = f'{model_type} model {d}',
format = 'folder',
base = base,
)
for f in files:
basename = Path(f).stem
format = Path(f).suffix[1:]
conf[f'{model_type}/{basename}'] = dict(
path = os.path.join(root,f),
description = f'{model_type} model {basename}',
format = format,
base = base,
)
parents = root.split('/')
subpaths = parents[parents.index('models')+1:]
if len(subpaths) < 2:
continue
base, model_type, *_ = subpaths
if args.all_models or model_type=='diffusers':
for d in dirs:
conf[f'{base}/{model_type}/{d}'] = dict(
path = os.path.join(root,d),
description = f'{model_type} model {d}',
format = 'folder',
base = base,
)
for f in files:
basename = Path(f).stem
format = Path(f).suffix[1:]
conf[f'{base}/{model_type}/{basename}'] = dict(
path = os.path.join(root,f),
description = f'{model_type} model {basename}',
format = format,
base = base,
)
OmegaConf.save(config=dict(sorted(conf.items())), f=sys.stdout)
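As a usage note, the rewritten loop derives base and model_type from the path components that follow the models directory, so it assumes a layout along the lines of models/<base>/<model_type>/<model>. A short sketch of that path handling, using a hypothetical directory:

# hypothetical walk root: /opt/invokeai/models/sd-1.5/lora
root = "/opt/invokeai/models/sd-1.5/lora"
parents = root.split('/')                          # ['', 'opt', 'invokeai', 'models', 'sd-1.5', 'lora']
subpaths = parents[parents.index('models') + 1:]   # ['sd-1.5', 'lora']
base, model_type, *_ = subpaths                    # base='sd-1.5', model_type='lora'
# a file my_lora.safetensors in this directory becomes stanza 'sd-1.5/lora/my_lora'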