mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

move all installation code out of model_manager

commit 000626ab2e (parent 74b43c9bdf)
@@ -9,5 +9,5 @@ from .generator import (
     Img2Img,
     Inpaint
 )
-from .model_management import ModelManager, ModelCache, SDModelType, SDModelInfo
+from .model_management import ModelManager, ModelCache, ModelType, ModelInfo
 from .safety_checker import SafetyChecker
@@ -1,5 +1,6 @@
 """
 Initialization file for invokeai.backend.model_management
 """
-from .model_manager import ModelManager, SDModelInfo
-from .model_cache import ModelCache, SDModelType
+from .model_manager import ModelManager, ModelInfo
+from .model_cache import ModelCache
+from .models import ModelType
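This hunk completes the rename at the package boundary: SDModelInfo becomes ModelInfo, and the old SDModelType enum is split into ModelType (now defined in the new models module) and SubModelType. A minimal sketch of a downstream import after the rename (the consuming code is illustrative):

    # sketch: importing the renamed public names; usage is illustrative
    from invokeai.backend.model_management import (
        ModelManager,   # unchanged
        ModelInfo,      # formerly SDModelInfo
        ModelType,      # split out of the former SDModelType
    )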
@@ -29,11 +29,8 @@ import torch
 
 from diffusers import logging as diffusers_logging
 from transformers import logging as transformers_logging
 
 import invokeai.backend.util.logging as logger
-from .model_manager import SDModelInfo, ModelType, SubModelType, ModelBase
+from .models import ModelType, SubModelType, ModelBase
 
 
 # Maximum size of the cache, in gigs
 # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
@@ -50,6 +47,10 @@ class ModelCache(object):
     "Forward declaration"
     pass
 
+class SDModelInfo(object):
+    """Forward declaration"""
+    pass
+
 class _CacheRecord:
     size: int
     model: Any
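The SDModelInfo stub added above follows the same pattern as the existing ModelCache/ModelLocker stubs: model_cache.py and model_manager.py refer to each other's types, so each module declares an empty placeholder rather than importing the real class at load time. For comparison only, a minimal sketch of a common alternative; the TYPE_CHECKING import and the describe() helper are assumptions, not part of this commit:

    # sketch: breaking the same import cycle with postponed annotation
    # evaluation (PEP 563) instead of stub classes; assumes Python 3.7+
    from __future__ import annotations
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # evaluated only by type checkers, so no circular import at runtime
        from .model_manager import SDModelInfo

    def describe(info: SDModelInfo) -> str:   # hypothetical helper
        return f"{info.name} ({info.hash})"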
invokeai/backend/model_management/model_install.py (new file, 201 lines)
@@ -0,0 +1,201 @@
+"""
+Routines for downloading and installing models.
+"""
+import json
+import safetensors
+import safetensors.torch
+import torch
+import traceback
+from dataclasses import dataclass
+from diffusers import ModelMixin
+from enum import Enum
+from typing import Callable
+from pathlib import Path
+from picklescan.scanner import scan_file_path
+
+import invokeai.backend.util.logging as logger
+from .models import BaseModelType, ModelType, VariantType
+from .model_probe import ModelProbe, ModelVariantInfo
+
+class CheckpointProbe(object):
+    PROBES = dict() # see below for redefinition
+
+    def __init__(self,
+                 checkpoint_path: Path,
+                 checkpoint: dict = None,
+                 helper: Callable[[Path], BaseModelType] = None
+                 ):
+        checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
+        self.checkpoint = checkpoint
+        self.checkpoint_path = checkpoint_path
+        self.helper = helper
+
+    def probe(self) -> ModelVariantInfo:
+        '''
+        Probe the checkpoint at path `checkpoint_path` and return
+        a ModelVariantInfo object indicating the model base, model type and
+        model variant for the checkpoint.
+        '''
+        checkpoint = self.checkpoint
+        state_dict = checkpoint.get("state_dict") or checkpoint
+
+        model_info = None
+
+        try:
+            model_type = get_checkpoint_type(state_dict)
+            if not model_type:
+                if self.checkpoint_path.name == "learned_embeds.bin":
+                    model_type = ModelType.TextualInversion
+                else:
+                    return None # we give up
+            probe = PROBES[model_type]()
+            base_type = probe.get_base_type(checkpoint, self.checkpoint_path, self.helper)
+            variant_type = probe.get_variant_type(model_type, checkpoint)
+
+            model_info = ModelVariantInfo(
+                model_type = model_type,
+                base_type = base_type,
+                variant_type = variant_type,
+            )
+        except (KeyError, ValueError) as e:
+            logger.error(f'An error occurred while probing {self.checkpoint_path}: {str(e)}')
+            logger.error(traceback.format_exc())
+
+        return model_info
+
+class CheckpointProbeBase(object):
+    def get_base_type(self,
+                      checkpoint: dict,
+                      checkpoint_path: Path = None,
+                      helper: Callable[[Path], BaseModelType] = None
+                      ) -> BaseModelType:
+        pass
+
+    def get_variant_type(self,
+                         model_type: ModelType,
+                         checkpoint: dict,
+                         ) -> VariantType:
+        if model_type != ModelType.Pipeline:
+            return None
+        state_dict = checkpoint.get('state_dict') or checkpoint
+        in_channels = state_dict[
+            "model.diffusion_model.input_blocks.0.0.weight"
+        ].shape[1]
+        if in_channels == 9:
+            return VariantType.Inpaint
+        elif in_channels == 5:
+            return VariantType.Depth
+        else:
+            return None
+
+
+class PipelineProbe(CheckpointProbeBase):
+    def get_base_type(self,
+                      checkpoint: dict,
+                      checkpoint_path: Path = None,
+                      helper: Callable[[Path], BaseModelType] = None
+                      ) -> BaseModelType:
+        state_dict = checkpoint.get('state_dict') or checkpoint
+        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+        if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
+            return BaseModelType.StableDiffusion1_5
+        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
+            if 'global_step' in checkpoint:
+                if checkpoint['global_step'] == 220000:
+                    return BaseModelType.StableDiffusion2Base
+                elif checkpoint["global_step"] == 110000:
+                    return BaseModelType.StableDiffusion2
+            if checkpoint_path and helper:
+                return helper(checkpoint_path)
+            else:
+                return None
+
+class VaeProbe(CheckpointProbeBase):
+    def get_base_type(self,
+                      checkpoint: dict,
+                      checkpoint_path: Path = None,
+                      helper: Callable[[Path], BaseModelType] = None
+                      ) -> BaseModelType:
+        # I can't find any standalone 2.X VAEs to test with!
+        return BaseModelType.StableDiffusion1_5
+
+class LoRAProbe(CheckpointProbeBase):
+    def get_base_type(self,
+                      checkpoint: dict,
+                      checkpoint_path: Path = None,
+                      helper: Callable[[Path], BaseModelType] = None
+                      ) -> BaseModelType:
+        key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
+        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
+        lora_token_vector_length = (
+            checkpoint[key1].shape[1]
+            if key1 in checkpoint
+            else checkpoint[key2].shape[0]
+            if key2 in checkpoint
+            else 768
+        )
+        if lora_token_vector_length == 768:
+            return BaseModelType.StableDiffusion1_5
+        elif lora_token_vector_length == 1024:
+            return BaseModelType.StableDiffusion2
+        else:
+            return None
+
+class TextualInversionProbe(CheckpointProbeBase):
+    def get_base_type(self,
+                      checkpoint: dict,
+                      checkpoint_path: Path = None,
+                      helper: Callable[[Path], BaseModelType] = None
+                      ) -> BaseModelType:
+
+        if 'string_to_token' in checkpoint:
+            token_dim = list(checkpoint['string_to_param'].values())[0].shape[-1]
+        elif 'emb_params' in checkpoint:
+            token_dim = checkpoint['emb_params'].shape[-1]
+        else:
+            token_dim = list(checkpoint.values())[0].shape[0]
+        if token_dim == 768:
+            return BaseModelType.StableDiffusion1_5
+        elif token_dim == 1024:
+            return BaseModelType.StableDiffusion2Base
+        else:
+            return None
+
+class ControlNetProbe(CheckpointProbeBase):
+    def get_base_type(self,
+                      checkpoint: dict,
+                      checkpoint_path: Path = None,
+                      helper: Callable[[Path], BaseModelType] = None
+                      ) -> BaseModelType:
+        for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight',
+                         'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight'
+                         ):
+            if key_name not in checkpoint:
+                continue
+            if checkpoint[key_name].shape[-1] == 768:
+                return BaseModelType.StableDiffusion1_5
+            elif checkpoint_path and helper:
+                return helper(checkpoint_path)
+
+PROBES = {
+    ModelType.Pipeline: PipelineProbe,
+    ModelType.Vae: VaeProbe,
+    ModelType.Lora: LoRAProbe,
+    ModelType.TextualInversion: TextualInversionProbe,
+    ModelType.ControlNet: ControlNetProbe,
+}
+
+def get_checkpoint_type(state_dict: dict) -> ModelType:
+    if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]):
+        return ModelType.Pipeline
+    if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]):
+        return ModelType.Vae
+    if "string_to_token" in state_dict or "emb_params" in state_dict:
+        return ModelType.TextualInversion
+    if any([x.startswith("lora") for x in state_dict.keys()]):
+        return ModelType.Lora
+    if any([x.startswith("control_model") for x in state_dict.keys()]):
+        return ModelType.ControlNet
+    if any([x.startswith("input_blocks") for x in state_dict.keys()]):
+        return ModelType.ControlNet
+    return None # give up
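The get_checkpoint_type() heuristics above dispatch purely on characteristic state-dict key prefixes ("model.diffusion_model" for pipelines, "lora" for LoRAs, and so on). A quick sketch for inspecting those prefixes in a checkpoint of your own; the file path is illustrative:

    # sketch: listing the key prefixes that the type heuristics rely on
    import safetensors.torch

    sd = safetensors.torch.load_file('/data/models/some-model.safetensors')
    print(sorted({key.split('.')[0] for key in sd.keys()}))
    # a Stable Diffusion pipeline checkpoint typically shows prefixes
    # such as 'cond_stage_model', 'first_stage_model' and 'model'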
@@ -3,13 +3,13 @@ symbolic diffusers model names to the paths and repo_ids used
 by the underlying `from_pretrained()` call.
 
 For fetching models, use manager.get_model('symbolic name'). This will
-return a SDModelInfo object that contains the following attributes:
+return a ModelInfo object that contains the following attributes:
 
    * context -- a context manager Generator that loads and locks the
                 model into GPU VRAM and returns the model for use.
                 See below for usage.
    * name -- symbolic name of the model
-   * type -- SDModelType of the model
+   * type -- SubModelType of the model
    * hash -- unique hash for the model
    * location -- path or repo_id of the model
    * revision -- revision of the model if coming from a repo id,
@@ -25,7 +25,7 @@ Typical usage:
                   max_cache_size=8
               ) # gigabytes
 
-model_info = manager.get_model('stable-diffusion-1.5', SDModelType.Diffusers)
+model_info = manager.get_model('stable-diffusion-1.5', SubModelType.Diffusers)
 with model_info.context as my_model:
     my_model.latents_from_embeddings(...)
 
@@ -43,7 +43,7 @@ parameter:
 
 model_info = manager.get_model(
     'clip-tokenizer',
-    model_type=SDModelType.Tokenizer
+    model_type=SubModelType.Tokenizer
 )
 
 This will raise an InvalidModelError if the format defined in the
@@ -63,7 +63,7 @@ The general format of a models.yaml section is:
 The type of model is given in the stanza key, and is one of
 {diffusers, ckpt, vae, text_encoder, tokenizer, unet, scheduler,
 safety_checker, feature_extractor, lora, textual_inversion,
-controlnet}, and correspond to items in the SDModelType enum defined
+controlnet}, and correspond to items in the SubModelType enum defined
 in model_cache.py
 
 The format indicates whether the model is organized as a folder with
@@ -96,7 +96,7 @@ SUBMODELS:
 It is also possible to fetch an isolated submodel from a diffusers
 model. Use the `submodel` parameter to select which part:
 
-vae = manager.get_model('stable-diffusion-1.5',submodel=SDModelType.Vae)
+vae = manager.get_model('stable-diffusion-1.5',submodel=SubModelType.Vae)
 with vae.context as my_vae:
     print(type(my_vae))
     # "AutoencoderKL"
@@ -128,8 +128,8 @@ separated by "/". Example:
 You can now use the `model_type` argument to indicate which model you
 want:
 
-tokenizer = mgr.get('clip-large',model_type=SDModelType.Tokenizer)
-encoder = mgr.get('clip-large',model_type=SDModelType.TextEncoder)
+tokenizer = mgr.get('clip-large',model_type=SubModelType.Tokenizer)
+encoder = mgr.get('clip-large',model_type=SubModelType.TextEncoder)
 
 OTHER FUNCTIONS:
 
@@ -164,15 +164,24 @@ import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.util import CUDA_DEVICE, download_with_resume
 from .model_cache import ModelCache, ModelLocker
-from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES
+from .models import BaseModelType, SubModelType, MODEL_CLASSES
 
 # We are only starting to number the config file with release 3.
 # The config file version doesn't have to start at release version, but it will help
 # reduce confusion.
 CONFIG_FILE_VERSION='3.0.0'
 
-# wanted to use pydantic here, but Generator objects not supported
+# temporary forward definitions to avoid circular import errors.
+class ModelLocker(object):
+    "Forward declaration"
+    pass
+
+class ModelCache(object):
+    "Forward declaration"
+    pass
+
 @dataclass
-class SDModelInfo():
+class ModelInfo():
     context: ModelLocker
     name: str
     type: ModelType
@@ -303,7 +312,7 @@ class ModelManager(object):
         submodel_type: Optional[SubModelType] = None
     ):
         """Given a model name identified in models.yaml, return
-        an SDModelInfo object describing it.
+        a ModelInfo object describing it.
         :param model_name: symbolic name of the model in models.yaml
         :param model_type: ModelType enum indicating the type of model to return
         :param submodel_type: a SubModelType enum indicating the portion of
@@ -389,7 +398,7 @@ class ModelManager(object):
 
         hash = "<NO_HASH>" # TODO:
 
-        return SDModelInfo(
+        return ModelInfo(
             context = model_context,
             name = model_name,
             base_model = base_model,
@@ -746,62 +755,3 @@ class ModelManager(object):
         if self.config_path:
             self.commit()
 
-    def _delete_defunct_models(self):
-        '''
-        Remove models no longer on disk.
-        '''
-        config = self.config
-
-        to_delete = set()
-        for key in config:
-            if 'path' not in config[key]:
-                continue
-            path = self.globals.root_dir / config[key].path
-            if path.exists():
-                continue
-            to_delete.add(key)
-
-        for key in to_delete:
-            self.logger.warn(f'Removing model {key} from in-memory config because its path is no longer on disk')
-            config.pop(key)
-
-    def scan_models_directory(self, include_diffusers:bool=False):
-        '''
-        Scan the models directory for loras, textual_inversions and controlnets
-        and create appropriate entries in the in-memory omegaconf. Diffusers
-        will not be added unless include_diffusers is true.
-        '''
-        self._delete_defunct_models()
-
-        model_directory = self.globals.models_path
-        config = self.config
-
-        for root, dirs, files in os.walk(model_directory):
-            parents = root.split('/')
-            subpaths = parents[parents.index('models')+1:]
-            if len(subpaths) < 2:
-                continue
-            base, model_type, *_ = subpaths
-
-            if model_type == "diffusers" and not include_diffusers:
-                continue
-
-            for d in dirs:
-                config[f'{model_type}/{d}'] = dict(
-                    path = os.path.join(root,d),
-                    description = f'{model_type} model {d}',
-                    format = 'folder',
-                    base = base,
-                )
-
-            for f in files:
-                basename = Path(f).stem
-                format = Path(f).suffix[1:]
-                config[f'{model_type}/{basename}'] = dict(
-                    path = os.path.join(root,f),
-                    description = f'{model_type} model {basename}',
-                    format = format,
-                    base = base,
-                )
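After the rename, the fetch-and-lock pattern documented in the docstring reads as follows. This is a sketch assembled from the fragments above; the config path and model name are illustrative:

    # sketch: fetching a model and locking it into VRAM for use
    from invokeai.backend.model_management import ModelManager, ModelInfo

    manager = ModelManager('models.yaml', max_cache_size=8)   # gigabytes
    info: ModelInfo = manager.get_model('stable-diffusion-1.5')
    with info.context as model:   # loads and locks the model in GPU VRAM
        ...                       # e.g. model.latents_from_embeddings(...)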
invokeai/backend/model_management/model_probe.py (new file, 333 lines)
@@ -0,0 +1,333 @@
+import json
+import traceback
+import torch
+import safetensors.torch
+
+from dataclasses import dataclass
+from diffusers import ModelMixin
+from pathlib import Path
+from typing import Callable, Literal, Union, Dict
+from picklescan.scanner import scan_file_path
+
+import invokeai.backend.util.logging as logger
+from .models import BaseModelType, ModelType, VariantType
+
+@dataclass
+class ModelVariantInfo(object):
+    model_type: ModelType
+    base_type: BaseModelType
+    variant_type: VariantType
+
+class ProbeBase(object):
+    '''forward declaration'''
+    pass
+
+class ModelProbe(object):
+
+    PROBES = {
+        'folder': { },
+        'file': { },
+    }
+
+    CLASS2TYPE = {
+        "StableDiffusionPipeline" : ModelType.Pipeline,
+        "AutoencoderKL": ModelType.Vae,
+        "ControlNetModel" : ModelType.ControlNet,
+    }
+
+    @classmethod
+    def register_probe(cls,
+                       format: Literal['folder','file'],
+                       model_type: ModelType,
+                       probe_class: ProbeBase):
+        cls.PROBES[format][model_type] = probe_class
+
+    @classmethod
+    def probe(cls,
+              model_path: Path,
+              model: Union[Dict, ModelMixin] = None,
+              base_helper: Callable[[Path], BaseModelType] = None) -> ModelVariantInfo:
+        '''
+        Probe the model at model_path and return sufficient information about it
+        to place it somewhere in the models directory hierarchy. If the model is
+        already loaded into memory, you may provide it as model in order to avoid
+        opening it a second time. The base_helper callable is a function that receives
+        the path to the model and returns the BaseModelType. It is called to distinguish
+        between V2-Base and V2-768 SD models.
+        '''
+        format = 'folder' if model_path.is_dir() else 'file'
+        model_info = None
+        try:
+            model_type = cls.get_model_type_from_folder(model_path, model) \
+                if format == 'folder' \
+                else cls.get_model_type_from_checkpoint(model_path, model)
+            probe_class = cls.PROBES[format].get(model_type)
+            if not probe_class:
+                return None
+            probe = probe_class(model_path, model, base_helper)
+            base_type = probe.get_base_type()
+            variant_type = probe.get_variant_type()
+            model_info = ModelVariantInfo(
+                model_type = model_type,
+                base_type = base_type,
+                variant_type = variant_type,
+            )
+        except (KeyError, ValueError) as e:
+            logger.error(f'An error occurred while probing {model_path}: {str(e)}')
+            logger.error(traceback.format_exc())
+
+        return model_info
+
+    @classmethod
+    def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
+        checkpoint = checkpoint or cls._scan_and_load_checkpoint(model_path)
+        state_dict = checkpoint.get("state_dict") or checkpoint
+        if any([x.startswith("model.diffusion_model") for x in state_dict.keys()]):
+            return ModelType.Pipeline
+        if any([x.startswith("encoder.conv_in") for x in state_dict.keys()]):
+            return ModelType.Vae
+        if "string_to_token" in state_dict or "emb_params" in state_dict:
+            return ModelType.TextualInversion
+        if any([x.startswith("lora") for x in state_dict.keys()]):
+            return ModelType.Lora
+        if any([x.startswith("control_model") for x in state_dict.keys()]):
+            return ModelType.ControlNet
+        if any([x.startswith("input_blocks") for x in state_dict.keys()]):
+            return ModelType.ControlNet
+        return None # give up
+
+    @classmethod
+    def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
+        '''
+        Get the model type of a hugging-face style folder.
+        '''
+        if (folder_path / 'learned_embeds.bin').exists():
+            return ModelType.TextualInversion
+
+        if (folder_path / 'pytorch_lora_weights.bin').exists():
+            return ModelType.Lora
+
+        i = folder_path / 'model_index.json'
+        c = folder_path / 'config.json'
+        config_path = i if i.exists() else c if c.exists() else None
+
+        if config_path:
+            conf = json.loads(config_path.read_text())
+            class_name = conf['_class_name']
+            if type := cls.CLASS2TYPE.get(class_name):
+                return type
+
+        # give up
+        raise ValueError(f"Unable to determine model type of {folder_path}")
+
+    @classmethod
+    def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
+        if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
+            cls._scan_model(model_path, model_path)
+            return torch.load(model_path)
+        else:
+            return safetensors.torch.load_file(model_path)
+
+    @classmethod
+    def _scan_model(cls, model_name, checkpoint):
+        """
+        Apply picklescanner to the indicated checkpoint and issue a warning
+        and option to exit if an infected file is identified.
+        """
+        # scan model
+        scan_result = scan_file_path(checkpoint)
+        if scan_result.infected_files != 0:
+            raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
+
+####################################################
+# Checkpoint probing
+####################################################
+class ProbeBase(object):
+    def get_base_type(self) -> BaseModelType:
+        pass
+
+    def get_variant_type(self) -> VariantType:
+        pass
+
+class CheckpointProbeBase(ProbeBase):
+    def __init__(self,
+                 checkpoint_path: Path,
+                 checkpoint: dict,
+                 helper: Callable[[Path], BaseModelType] = None
+                 ):
+        self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
+        self.checkpoint_path = checkpoint_path
+        self.helper = helper
+
+    def get_base_type(self) -> BaseModelType:
+        pass
+
+    def get_variant_type(self) -> VariantType:
+        model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
+        if model_type != ModelType.Pipeline:
+            return VariantType.Normal
+        state_dict = self.checkpoint.get('state_dict') or self.checkpoint
+        in_channels = state_dict[
+            "model.diffusion_model.input_blocks.0.0.weight"
+        ].shape[1]
+        if in_channels == 9:
+            return VariantType.Inpaint
+        elif in_channels == 5:
+            return VariantType.Depth
+        else:
+            return None
+
+class PipelineCheckpointProbe(CheckpointProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        checkpoint = self.checkpoint
+        helper = self.helper
+        state_dict = self.checkpoint.get('state_dict') or checkpoint
+        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+        if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
+            return BaseModelType.StableDiffusion1_5
+        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
+            if 'global_step' in checkpoint:
+                if checkpoint['global_step'] == 220000:
+                    return BaseModelType.StableDiffusion2Base
+                elif checkpoint["global_step"] == 110000:
+                    return BaseModelType.StableDiffusion2
+            if self.checkpoint_path and helper:
+                return helper(self.checkpoint_path)
+            else:
+                return None
+
+class VaeCheckpointProbe(CheckpointProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        # I can't find any standalone 2.X VAEs to test with!
+        return BaseModelType.StableDiffusion1_5
+
+class LoRACheckpointProbe(CheckpointProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        checkpoint = self.checkpoint
+        key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
+        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
+        lora_token_vector_length = (
+            checkpoint[key1].shape[1]
+            if key1 in checkpoint
+            else checkpoint[key2].shape[0]
+            if key2 in checkpoint
+            else 768
+        )
+        if lora_token_vector_length == 768:
+            return BaseModelType.StableDiffusion1_5
+        elif lora_token_vector_length == 1024:
+            return BaseModelType.StableDiffusion2
+        else:
+            return None
+
+class TextualInversionCheckpointProbe(CheckpointProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        checkpoint = self.checkpoint
+        if 'string_to_token' in checkpoint:
+            token_dim = list(checkpoint['string_to_param'].values())[0].shape[-1]
+        elif 'emb_params' in checkpoint:
+            token_dim = checkpoint['emb_params'].shape[-1]
+        else:
+            token_dim = list(checkpoint.values())[0].shape[0]
+        if token_dim == 768:
+            return BaseModelType.StableDiffusion1_5
+        elif token_dim == 1024:
+            return BaseModelType.StableDiffusion2Base
+        else:
+            return None
+
+class ControlNetCheckpointProbe(CheckpointProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        checkpoint = self.checkpoint
+        for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight',
+                         'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight'
+                         ):
+            if key_name not in checkpoint:
+                continue
+            if checkpoint[key_name].shape[-1] == 768:
+                return BaseModelType.StableDiffusion1_5
+            elif self.checkpoint_path and self.helper:
+                return self.helper(self.checkpoint_path)
+
+########################################################
+# classes for probing folders
+########################################################
+class FolderProbeBase(ProbeBase):
+    def __init__(self,
+                 folder_path: Path,
+                 model: ModelMixin,
+                 helper: Callable = None  # not used
+                 ):
+        self.model = model
+        self.folder_path = folder_path
+
+    def get_variant_type(self) -> VariantType:
+        # only works for pipelines
+        config_file = self.folder_path / 'unet' / 'config.json'
+        if not config_file.exists():
+            return VariantType.Normal
+
+        conf = json.loads(config_file.read_text())
+        channels = conf['in_channels']
+        if channels == 9:
+            return VariantType.Inpaint
+        elif channels == 5:
+            return VariantType.Depth
+        elif channels == 4:
+            return VariantType.Normal
+        else:
+            return VariantType.Normal
+
+class PipelineFolderProbe(FolderProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        config_file = self.folder_path / 'scheduler' / 'scheduler_config.json'
+        if not config_file.exists():
+            return None
+        conf = json.loads(config_file.read_text())
+        if conf['prediction_type'] == "v_prediction":
+            return BaseModelType.StableDiffusion2
+        elif conf['prediction_type'] == 'epsilon':
+            return BaseModelType.StableDiffusion2Base
+        else:
+            return BaseModelType.StableDiffusion2
+
+class VaeFolderProbe(FolderProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        return BaseModelType.StableDiffusion1_5
+
+class TextualInversionFolderProbe(FolderProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        path = self.folder_path / 'learned_embeds.bin'
+        if not path.exists():
+            return None
+        checkpoint = ModelProbe._scan_and_load_checkpoint(path)
+        return TextualInversionCheckpointProbe(path, checkpoint).get_base_type()
+
+class ControlNetFolderProbe(FolderProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        config_file = self.folder_path / 'scheduler_config.json'
+        if not config_file.exists():
+            return None
+        config = json.loads(config_file.read_text())
+        # no obvious way to distinguish between sd2-base and sd2-768
+        return BaseModelType.StableDiffusion1_5 \
+            if config['cross_attention_dim']==768 \
+            else BaseModelType.StableDiffusion2
+
+class LoRAFolderProbe(FolderProbeBase):
+    # I've never seen one of these in the wild, so this is a noop
+    pass
+
+############## register probe classes ######
+ModelProbe.register_probe('folder', ModelType.Pipeline, PipelineFolderProbe)
+ModelProbe.register_probe('folder', ModelType.Vae, VaeFolderProbe)
+ModelProbe.register_probe('folder', ModelType.Lora, LoRAFolderProbe)
+ModelProbe.register_probe('folder', ModelType.TextualInversion, TextualInversionFolderProbe)
+ModelProbe.register_probe('folder', ModelType.ControlNet, ControlNetFolderProbe)
+ModelProbe.register_probe('file', ModelType.Pipeline, PipelineCheckpointProbe)
+ModelProbe.register_probe('file', ModelType.Vae, VaeCheckpointProbe)
+ModelProbe.register_probe('file', ModelType.Lora, LoRACheckpointProbe)
+ModelProbe.register_probe('file', ModelType.TextualInversion, TextualInversionCheckpointProbe)
+ModelProbe.register_probe('file', ModelType.ControlNet, ControlNetCheckpointProbe)
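ModelProbe is deliberately registry-driven: each (format, model type) pair self-registers its probe class, so supporting a new model type means adding one probe class and one register_probe() call without touching ModelProbe itself. A minimal usage sketch; the model path is illustrative:

    # sketch: probing an arbitrary checkpoint file or diffusers folder
    from pathlib import Path
    from invokeai.backend.model_management.model_probe import ModelProbe

    info = ModelProbe.probe(Path('/data/models/some-model.safetensors'))
    if info is not None:
        print(info.base_type, info.model_type, info.variant_type)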
@@ -1,8 +1,14 @@
 import sys
+from dataclasses import dataclass
 from enum import Enum
 import torch
 import safetensors.torch
+from diffusers import ConfigMixin
 from diffusers.utils import is_safetensors_available
+from omegaconf import DictConfig
+from pathlib import Path
+from pydantic import BaseModel, Field, root_validator
+from typing import Union, List, Type, Optional
 
 class BaseModelType(str, Enum):
     # TODO: maybe then add sample size(512/768)?
@@ -26,42 +32,13 @@ class SubModelType:
     Vae = "vae"
     Scheduler = "scheduler"
     SafetyChecker = "safety_checker"
+    FeatureExtractor = "feature_extractor"
     #MoVQ = "movq"
 
-MODEL_CLASSES = {
-    BaseModel.StableDiffusion1_5: {
-        ModelType.Pipeline: StableDiffusionModel,
-        ModelType.Classifier: ClassifierModel,
-        ModelType.Vae: VaeModel,
-        ModelType.Lora: LoraModel,
-        ModelType.ControlNet: ControlNetModel,
-        ModelType.TextualInversion: TextualInversionModel,
-    },
-    BaseModel.StableDiffusion2: {
-        ModelType.Pipeline: StableDiffusionModel,
-        ModelType.Classifier: ClassifierModel,
-        ModelType.Vae: VaeModel,
-        ModelType.Lora: LoraModel,
-        ModelType.ControlNet: ControlNetModel,
-        ModelType.TextualInversion: TextualInversionModel,
-    },
-    BaseModel.StableDiffusion2Base: {
-        ModelType.Pipeline: StableDiffusionModel,
-        ModelType.Classifier: ClassifierModel,
-        ModelType.Vae: VaeModel,
-        ModelType.Lora: LoraModel,
-        ModelType.ControlNet: ControlNetModel,
-        ModelType.TextualInversion: TextualInversionModel,
-    },
-    #BaseModel.Kandinsky2_1: {
-    #    ModelType.Pipeline: Kandinsky2_1Model,
-    #    ModelType.Classifier: ClassifierModel,
-    #    ModelType.MoVQ: MoVQModel,
-    #    ModelType.Lora: LoraModel,
-    #    ModelType.ControlNet: ControlNetModel,
-    #    ModelType.TextualInversion: TextualInversionModel,
-    #},
-}
+class VariantType(str, Enum):
+    Normal = "normal"
+    Inpaint = "inpaint"
+    Depth = "depth"
 
 class EmptyConfigLoader(ConfigMixin):
     @classmethod
@@ -323,7 +300,7 @@ class ClassifierModel(ModelBase):
     #child_sizes: Dict[str, int]
 
     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
-        assert model_type == SDModelType.Classifier
+        assert model_type == ModelType.Classifier
         super().__init__(model_path, base_model, model_type)
 
         self.child_types: Dict[str, Type] = dict()
@@ -354,8 +331,8 @@ class ClassifierModel(ModelBase):
         else:
             raise Exception("Invalid classifier model! (Failed to detect tokenizer type)")
 
-        self.child_types[SDModelType.Tokenizer] = self._hf_definition_to_type(["transformers", tokenizer_class_name])
-        self.child_sizes[SDModelType.Tokenizer] = 0
+        self.child_types[SubModelType.Tokenizer] = self._hf_definition_to_type(["transformers", tokenizer_class_name])
+        self.child_sizes[SubModelType.Tokenizer] = 0
 
 
     def _load_text_encoder(self, main_config: dict):
@@ -366,12 +343,12 @@ class ClassifierModel(ModelBase):
         else:
             raise Exception("Invalid classifier model! (Failed to detect text_encoder type)")
 
-        self.child_types[SDModelType.TextEncoder] = self._hf_definition_to_type(["transformers", text_encoder_class_name])
-        self.child_sizes[SDModelType.TextEncoder] = calc_model_size_by_fs(self.model_path)
+        self.child_types[SubModelType.TextEncoder] = self._hf_definition_to_type(["transformers", text_encoder_class_name])
+        self.child_sizes[SubModelType.TextEncoder] = calc_model_size_by_fs(self.model_path)
 
 
     def _load_feature_extractor(self, main_config: dict):
-        self.child_sizes[SDModelType.FeatureExtractor] = 0
+        self.child_sizes[SubModelType.FeatureExtractor] = 0
         try:
             feature_extractor_config = EmptyConfigLoader.load_config(self.model_path, config_name="preprocessor_config.json")
         except:
@@ -379,12 +356,12 @@ class ClassifierModel(ModelBase):
 
         try:
             feature_extractor_class_name = feature_extractor_config["feature_extractor_type"]
-            self.child_types[SDModelType.FeatureExtractor] = self._hf_definition_to_type(["transformers", feature_extractor_class_name])
+            self.child_types[SubModelType.FeatureExtractor] = self._hf_definition_to_type(["transformers", feature_extractor_class_name])
         except:
             raise Exception("Invalid classifier model! (Unknown feature_extractor type)")
 
 
-    def get_size(self, child_type: Optional[SDModelType] = None):
+    def get_size(self, child_type: Optional[SubModelType] = None):
         if child_type is None:
             return sum(self.child_sizes.values())
         else:
@@ -394,7 +371,7 @@ class ClassifierModel(ModelBase):
     def get_model(
         self,
         torch_dtype: Optional[torch.dtype],
-        child_type: Optional[SDModelType] = None,
+        child_type: Optional[SubModelType] = None,
     ):
         if child_type is None:
             raise Exception("Child model type can't be null on classifier model")
@@ -437,7 +414,7 @@ class VaeModel(ModelBase):
         except:
             raise Exception("Invalid vae model! (Unknown vae type)")
 
-    def get_size(self, child_type: Optional[SDModelType] = None):
+    def get_size(self, child_type: Optional[SubModelType] = None):
         if child_type is not None:
             raise Exception("There are no child models in vae model")
         return self.model_size
@@ -445,7 +422,7 @@ class VaeModel(ModelBase):
     def get_model(
         self,
         torch_dtype: Optional[torch.dtype],
-        child_type: Optional[SDModelType] = None,
+        child_type: Optional[SubModelType] = None,
     ):
         if child_type is not None:
             raise Exception("There are no child models in vae model")
@@ -476,7 +453,7 @@ class LoRAModel(ModelBase):
 
         self.model_size = os.path.getsize(self.model_path)
 
-    def get_size(self, child_type: Optional[SDModelType] = None):
+    def get_size(self, child_type: Optional[SubModelType] = None):
         if child_type is not None:
             raise Exception("There are no child models in lora")
         return self.model_size
@@ -484,7 +461,7 @@ class LoRAModel(ModelBase):
     def get_model(
         self,
        torch_dtype: Optional[torch.dtype],
-        child_type: Optional[SDModelType] = None,
+        child_type: Optional[SubModelType] = None,
     ):
         if child_type is not None:
             raise Exception("There are no child models in lora")
@@ -505,7 +482,6 @@ class LoRAModel(ModelBase):
         # TODO: add diffusers lora when it stabilizes a bit
         return model_path
 
-
 class TextualInversionModel(ModelBase):
     #model_size: int
 
@@ -515,7 +491,7 @@ class TextualInversionModel(ModelBase):
 
         self.model_size = os.path.getsize(self.model_path)
 
-    def get_size(self, child_type: Optional[SDModelType] = None):
+    def get_size(self, child_type: Optional[SubModelType] = None):
         if child_type is not None:
             raise Exception("There are no child models in textual inversion")
         return self.model_size
@@ -523,7 +499,7 @@ class TextualInversionModel(ModelBase):
     def get_model(
         self,
         torch_dtype: Optional[torch.dtype],
-        child_type: Optional[SDModelType] = None,
+        child_type: Optional[SubModelType] = None,
     ):
         if child_type is not None:
             raise Exception("There are no child models in textual inversion")
@@ -542,6 +518,10 @@ class TextualInversionModel(ModelBase):
         model_path = Path(model_path)
         return model_path
 
+class ControlNetModel(ModelBase):
+    """requires implementation"""
+    pass
+
 def calc_model_size_by_fs(
     model_path: str,
     subfolder: Optional[str] = None,
@@ -710,3 +690,39 @@ def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
         safe_serialization=is_safetensors_available()
     )
     return diffusers_path
+
+MODEL_CLASSES = {
+    BaseModelType.StableDiffusion1_5: {
+        ModelType.Pipeline: StableDiffusionModel,
+        ModelType.Classifier: ClassifierModel,
+        ModelType.Vae: VaeModel,
+        ModelType.Lora: LoRAModel,
+        ModelType.ControlNet: ControlNetModel,
+        ModelType.TextualInversion: TextualInversionModel,
+    },
+    BaseModelType.StableDiffusion2: {
+        ModelType.Pipeline: StableDiffusionModel,
+        ModelType.Classifier: ClassifierModel,
+        ModelType.Vae: VaeModel,
+        ModelType.Lora: LoRAModel,
+        ModelType.ControlNet: ControlNetModel,
+        ModelType.TextualInversion: TextualInversionModel,
+    },
+    BaseModelType.StableDiffusion2Base: {
+        ModelType.Pipeline: StableDiffusionModel,
+        ModelType.Classifier: ClassifierModel,
+        ModelType.Vae: VaeModel,
+        ModelType.Lora: LoRAModel,
+        ModelType.ControlNet: ControlNetModel,
+        ModelType.TextualInversion: TextualInversionModel,
+    },
+    #BaseModel.Kandinsky2_1: {
+    #    ModelType.Pipeline: Kandinsky2_1Model,
+    #    ModelType.Classifier: ClassifierModel,
+    #    ModelType.MoVQ: MoVQModel,
+    #    ModelType.Lora: LoraModel,
+    #    ModelType.ControlNet: ControlNetModel,
+    #    ModelType.TextualInversion: TextualInversionModel,
+    #},
+}
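Moving MODEL_CLASSES below the class definitions lets the table reference the implementation classes directly instead of names that do not exist yet at import time. A lookup sketch grounded in the table above; the print is illustrative:

    # sketch: resolving an implementation class from the two-level registry
    from invokeai.backend.model_management.models import (
        MODEL_CLASSES, BaseModelType, ModelType,
    )

    model_class = MODEL_CLASSES[BaseModelType.StableDiffusion1_5][ModelType.Lora]
    print(model_class.__name__)   # "LoRAModel"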
scripts/scan_models_directory.py (55 lines, Normal file → Executable file)
@@ -1,3 +1,5 @@
+#!/usr/bin/env python
+
 '''
 Scan the models directory and print out a new models.yaml
 '''
@@ -12,6 +14,12 @@ from omegaconf import OmegaConf
 def main():
     parser = argparse.ArgumentParser(description="Model directory scanner")
     parser.add_argument('models_directory')
+    parser.add_argument('--all-models',
+                        default=False,
+                        action='store_true',
+                        help='If true, then generates stanzas for all models; otherwise just diffusers'
+                        )
+
     args = parser.parse_args()
     directory = args.models_directory
 
@@ -19,29 +27,30 @@ def main():
     conf['_version'] = '3.0.0'
 
     for root, dirs, files in os.walk(directory):
-        for d in dirs:
-            parents = root.split('/')
-            subpaths = parents[parents.index('models')+1:]
-            if len(subpaths) < 2:
-                continue
-            base, model_type, *_ = subpaths
-
-            conf[f'{model_type}/{d}'] = dict(
-                path = os.path.join(root,d),
-                description = f'{model_type} model {d}',
-                format = 'folder',
-                base = base,
-            )
-
-        for f in files:
-            basename = Path(f).stem
-            format = Path(f).suffix[1:]
-            conf[f'{model_type}/{basename}'] = dict(
-                path = os.path.join(root,f),
-                description = f'{model_type} model {basename}',
-                format = format,
-                base = base,
-            )
+        parents = root.split('/')
+        subpaths = parents[parents.index('models')+1:]
+        if len(subpaths) < 2:
+            continue
+        base, model_type, *_ = subpaths
+
+        if args.all_models or model_type=='diffusers':
+            for d in dirs:
+                conf[f'{base}/{model_type}/{d}'] = dict(
+                    path = os.path.join(root,d),
+                    description = f'{model_type} model {d}',
+                    format = 'folder',
+                    base = base,
+                )
+            for f in files:
+                basename = Path(f).stem
+                format = Path(f).suffix[1:]
+                conf[f'{base}/{model_type}/{basename}'] = dict(
+                    path = os.path.join(root,f),
+                    description = f'{model_type} model {basename}',
+                    format = format,
+                    base = base,
+                )
 
     OmegaConf.save(config=dict(sorted(conf.items())), f=sys.stdout)
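With the new flag, a typical run (path illustrative) is `python scripts/scan_models_directory.py ~/invokeai/models > models.yaml` to emit diffusers-only stanzas, or add `--all-models` to also generate stanzas for loras, textual inversions and controlnets. The added shebang plus the new executable bit also allow invoking the script directly as `./scripts/scan_models_directory.py`. The YAML is written to stdout, sorted by stanza key.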