Add SchedulerPredictionType and ModelVariantType enums

This commit is contained in:
Lincoln Stein 2023-06-12 16:07:04 -04:00
parent 36eb1bd893
commit 1439dc7712
6 changed files with 118 additions and 71 deletions

View File

@ -3,4 +3,4 @@ Initialization file for invokeai.backend.model_management
"""
from .model_manager import ModelManager, ModelInfo
from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, VariantType
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType

View File

@ -200,20 +200,27 @@ MAX_CACHE_SIZE = 6.0 # GB
# layout of the models directory:
# models
# ├── SD-1
# ├── sd-1
# │   ├── controlnet
# │   ├── lora
# │   ├── diffusers
# │   └── textual_inversion
# ├── SD-2
# ├── sd-2
# │   ├── controlnet
# │   ├── lora
# │   ├── diffusers
# │   └── textual_inversion
# └── support
# ├── codeformer
# ├── gfpgan
# └── realesrgan
# │ └── textual_inversion
# └── core
# ├── face_reconstruction
# │ ├── codeformer
# │ └── gfpgan
# ├── sd-conversion
# │ ├── clip-vit-large-patch14 - tokenizer, text_encoder subdirs
# │ ├── stable-diffusion-2 - tokenizer, text_encoder subdirs
# │ └── stable-diffusion-safety-checker
# └── upscaling
# └─── esrgan
class ConfigMeta(BaseModel):

View File

@ -4,20 +4,24 @@ import torch
import safetensors.torch
from dataclasses import dataclass
from enum import Enum
from diffusers import ModelMixin, ConfigMixin, StableDiffusionPipeline, AutoencoderKL, ControlNetModel
from pathlib import Path
from typing import Callable, Literal, Union, Dict
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, VariantType
from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType
from .model_cache import SilenceWarnings
@dataclass
class ModelVariantInfo(object):
model_type: ModelType
base_type: BaseModelType
variant_type: VariantType
variant_type: ModelVariantType
prediction_type: SchedulerPredictionType
image_size: int
class ProbeBase(object):
'''forward declaration'''
@ -27,7 +31,7 @@ class ModelProbe(object):
PROBES = {
'folder': { },
'file': { },
'checkpoint': { },
}
CLASS2TYPE = {
@ -43,16 +47,28 @@ class ModelProbe(object):
probe_class: ProbeBase):
cls.PROBES[format][model_type] = probe_class
@classmethod
def heuristic_probe(cls,
model: Union[Dict, ModelMixin, Path],
prediction_type_helper: Callable[[Path],BaseModelType]=None,
)->ModelVariantInfo:
if isinstance(model,Path):
return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper)
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
else:
raise Exception("model parameter {model} is neither a Path, nor a model")
@classmethod
def probe(cls,
model_path: Path,
model: Union[Dict, ModelMixin] = None,
base_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo:
prediction_type_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo:
'''
Probe the model at model_path and return sufficient information about it
to place it somewhere in the models directory hierarchy. If the model is
already loaded into memory, you may provide it as model in order to avoid
opening it a second time. The base_helper callable is a function that receives
opening it a second time. The prediction_type_helper callable is a function that receives
the path to the model and returns the BaseModelType. It is called to distinguish
between V2-Base and V2-768 SD models.
'''
@ -69,13 +85,18 @@ class ModelProbe(object):
probe_class = cls.PROBES[format].get(model_type)
if not probe_class:
return None
probe = probe_class(model_path, model, base_helper)
probe = probe_class(model_path, model, prediction_type_helper)
base_type = probe.get_base_type()
variant_type = probe.get_variant_type()
prediction_type = probe.get_scheduler_prediction_type()
model_info = ModelVariantInfo(
model_type = model_type,
base_type = base_type,
variant_type = variant_type,
prediction_type = prediction_type,
image_size = 768 if (base_type==BaseModelType.StableDiffusion2 \
and prediction_type==SchedulerPredictionType.VPrediction \
) else 512
)
except (KeyError, ValueError) as e:
logger.error(f'An error occurred while probing {model_path}: {str(e)}')
@ -120,7 +141,8 @@ class ModelProbe(object):
config_path = i if i.exists() else c if c.exists() else None
if config_path:
conf = json.load(open(config_path,'r'))
with open(config_path,'r') as file:
conf = json.load(file)
class_name = conf['_class_name']
if type := cls.CLASS2TYPE.get(class_name):
@ -156,9 +178,12 @@ class ProbeBase(object):
def get_base_type(self)->BaseModelType:
pass
def get_variant_type(self)->VariantType:
def get_variant_type(self)->ModelVariantType:
pass
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
pass
class CheckpointProbeBase(ProbeBase):
def __init__(self,
checkpoint_path: Path,
@ -172,44 +197,54 @@ class CheckpointProbeBase(ProbeBase):
def get_base_type(self)->BaseModelType:
pass
def get_variant_type(self)-> VariantType:
def get_variant_type(self)-> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
if model_type != ModelType.Pipeline:
return VariantType.Normal
return ModelVariantType.Normal
state_dict = self.checkpoint.get('state_dict') or self.checkpoint
in_channels = state_dict[
"model.diffusion_model.input_blocks.0.0.weight"
].shape[1]
if in_channels == 9:
return VariantType.Inpaint
return ModelVariantType.Inpaint
elif in_channels == 5:
return VariantType.Depth
return ModelVariantType.Depth
else:
return None
class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint
helper = self.helper
state_dict = self.checkpoint.get('state_dict') or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1_5
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
raise Exception("Cannot determine base type")
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
type = self.get_base_type()
if type == BaseModelType.StableDiffusion1:
return SchedulerPredictionType.Epsilon
checkpoint = self.checkpoint
state_dict = self.checkpoint.get('state_dict') or checkpoint
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
if 'global_step' in checkpoint:
if checkpoint['global_step'] == 220000:
return BaseModelType.StableDiffusion2Base
return SchedulerPredictionType.Epsilon
elif checkpoint["global_step"] == 110000:
return BaseModelType.StableDiffusion2
if self.checkpoint_path and helper:
return helper(self.checkpoint_path)
return SchedulerPredictionType.VPrediction
if self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path)
else:
return None
class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1_5
return BaseModelType.StableDiffusion1
class LoRACheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
@ -224,7 +259,7 @@ class LoRACheckpointProbe(CheckpointProbeBase):
else 768
)
if lora_token_vector_length == 768:
return BaseModelType.StableDiffusion1_5
return BaseModelType.StableDiffusion1
elif lora_token_vector_length == 1024:
return BaseModelType.StableDiffusion2
else:
@ -240,9 +275,9 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
else:
token_dim = list(checkpoint.values())[0].shape[0]
if token_dim == 768:
return BaseModelType.StableDiffusion1_5
return BaseModelType.StableDiffusion1
elif token_dim == 1024:
return BaseModelType.StableDiffusion2Base
return BaseModelType.StableDiffusion2
else:
return None
@ -255,7 +290,7 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
if key_name not in checkpoint:
continue
if checkpoint[key_name].shape[-1] == 768:
return BaseModelType.StableDiffusion1_5
return BaseModelType.StableDiffusion1
elif self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path)
@ -271,8 +306,8 @@ class FolderProbeBase(ProbeBase):
self.model = model
self.folder_path = folder_path
def get_variant_type(self)->VariantType:
return VariantType.Normal
def get_variant_type(self)->ModelVariantType:
return ModelVariantType.Normal
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
@ -280,22 +315,32 @@ class PipelineFolderProbe(FolderProbeBase):
unet_conf = self.model.unet.config
scheduler_conf = self.model.scheduler.config
else:
unet_conf = json.load(open(self.folder_path / 'unet' / 'config.json','r'))
scheduler_conf = json.load(open(self.folder_path / 'scheduler' / 'scheduler_config.json','r'))
with open(self.folder_path / 'unet' / 'config.json','r') as file:
unet_conf = json.load(file)
with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
scheduler_conf = json.load(file)
if unet_conf['cross_attention_dim'] == 768:
return BaseModelType.StableDiffusion1_5
return BaseModelType.StableDiffusion1
elif unet_conf['cross_attention_dim'] == 1024:
if scheduler_conf['prediction_type'] == "v_prediction":
return BaseModelType.StableDiffusion2
elif scheduler_conf['prediction_type'] == 'epsilon':
return BaseModelType.StableDiffusion2Base
else:
return BaseModelType.StableDiffusion2
return BaseModelType.StableDiffusion2
else:
raise ValueError(f'Unknown base model for {self.folder_path}')
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
if self.model:
scheduler_conf = self.model.scheduler.config
else:
with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
scheduler_conf = json.load(file)
if scheduler_conf['prediction_type'] == "v_prediction":
return SchedulerPredictionType.VPrediction
elif scheduler_conf['prediction_type'] == 'epsilon':
return SchedulerPredictionType.Epsilon
else:
return None
def get_variant_type(self)->VariantType:
def get_variant_type(self)->ModelVariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the
# "normal" variant type
@ -304,22 +349,23 @@ class PipelineFolderProbe(FolderProbeBase):
conf = self.model.unet.config
else:
config_file = self.folder_path / 'unet' / 'config.json'
conf = json.load(open(config_file,'r'))
with open(config_file,'r') as file:
conf = json.load(file)
in_channels = conf['in_channels']
if in_channels == 9:
return VariantType.Inpainting
return ModelVariantType.Inpainting
elif in_channels == 5:
return VariantType.Depth
return ModelVariantType.Depth
elif in_channels == 4:
return VariantType.Normal
return ModelVariantType.Normal
except:
pass
return VariantType.Normal
return ModelVariantType.Normal
class VaeFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
return BaseModelType.StableDiffusion1_5
return BaseModelType.StableDiffusion1
class TextualInversionFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
@ -336,7 +382,7 @@ class ControlNetFolderProbe(FolderProbeBase):
return None
config = json.load(config_file)
# no obvious way to distinguish between sd2-base and sd2-768
return BaseModelType.StableDiffusion1_5 \
return BaseModelType.StableDiffusion1 \
if config['cross_attention_dim']==768 \
else BaseModelType.StableDiffusion2
@ -350,8 +396,8 @@ ModelProbe.register_probe('folder', ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe('folder', ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe('folder', ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe('folder', ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe('file', ModelType.Pipeline, PipelineCheckpointProbe)
ModelProbe.register_probe('file', ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe('file', ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe('file', ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe('file', ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.Pipeline, PipelineCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe('checkpoint', ModelType.ControlNet, ControlNetCheckpointProbe)

View File

@ -1,4 +1,4 @@
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, VariantType
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType
from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel
from .vae import VaeModel
from .lora import LoRAModel
@ -10,7 +10,7 @@ class ControlNetModel:
pass
MODEL_CLASSES = {
BaseModelType.StableDiffusion1_5: {
BaseModelType.StableDiffusion1: {
ModelType.Pipeline: StableDiffusion15Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
@ -24,13 +24,6 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusion2Base: {
ModelType.Pipeline: StableDiffusion2BaseModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
#BaseModelType.Kandinsky2_1: {
# ModelType.Pipeline: Kandinsky2_1Model,
# ModelType.MoVQ: MoVQModel,

View File

@ -14,8 +14,7 @@ class BaseModelType(str, Enum):
#StableDiffusion2 = "stable_diffusion_2"
#StableDiffusion2Base = "stable_diffusion_2_base"
# TODO: maybe then add sample size(512/768)?
StableDiffusion1_5 = "sd-1.5"
StableDiffusion2Base = "sd-2-base" # 512 pixels; this will have epsilon parameterization
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2" # 768 pixels; this will have v-prediction parameterization
#Kandinsky2_1 = "kandinsky_2_1"
@ -35,10 +34,15 @@ class SubModelType(str, Enum):
SafetyChecker = "safety_checker"
#MoVQ = "movq"
class VariantType(str, Enum):
class ModelVariantType(str, Enum):
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class SchedulerPredictionType(str, Enum):
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelError(str, Enum):
NotFound = "not_found"

View File

@ -10,14 +10,11 @@ from .base import (
BaseModelType,
ModelType,
SubModelType,
VariantType,
ModelVariantType,
DiffusersModel,
)
from invokeai.app.services.config import InvokeAIAppConfig
ModelVariantType = VariantType # TODO:
# TODO: how to name properly
class StableDiffusion15Model(DiffusersModel):