Add SchedulerPredictionType and ModelVariantType enums

Author: Lincoln Stein
Date:   2023-06-12 16:07:04 -04:00
parent 36eb1bd893
commit 1439dc7712
6 changed files with 118 additions and 71 deletions

invokeai/backend/model_management/__init__.py

@@ -3,4 +3,4 @@ Initialization file for invokeai.backend.model_management
 """
 from .model_manager import ModelManager, ModelInfo
 from .model_cache import ModelCache
-from .models import BaseModelType, ModelType, SubModelType, VariantType
+from .models import BaseModelType, ModelType, SubModelType, ModelVariantType

invokeai/backend/model_management/model_manager.py

@@ -200,20 +200,27 @@ MAX_CACHE_SIZE = 6.0 # GB
 # layout of the models directory:
 # models
-# ├── SD-1
+# ├── sd-1
 # │   ├── controlnet
 # │   ├── lora
 # │   ├── diffusers
 # │   └── textual_inversion
-# ├── SD-2
+# ├── sd-2
 # │   ├── controlnet
 # │   ├── lora
 # │   ├── diffusers
 # │   └── textual_inversion
-# └── support
-#     ├── codeformer
-#     ├── gfpgan
-#     └── realesrgan
+# └── core
+#     ├── face_reconstruction
+#     │   ├── codeformer
+#     │   └── gfpgan
+#     ├── sd-conversion
+#     │   ├── clip-vit-large-patch14 - tokenizer, text_encoder subdirs
+#     │   ├── stable-diffusion-2 - tokenizer, text_encoder subdirs
+#     │   └── stable-diffusion-safety-checker
+#     └── upscaling
+#         └── esrgan

 class ConfigMeta(BaseModel):
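The reorganized tree makes a model's install location a pure function of its base model and model type. As a minimal sketch of that mapping (the helper is illustrative and assumes the ModelType values mirror the directory names shown above):

    from pathlib import Path
    from invokeai.backend.model_management.models import BaseModelType, ModelType

    def model_install_dir(models_root: Path, base: BaseModelType, model_type: ModelType) -> Path:
        # Both enums subclass str, so their values drop straight into the path,
        # e.g. models/sd-1/lora or models/sd-2/controlnet.
        return models_root / base.value / model_type.value

    model_install_dir(Path('models'), BaseModelType.StableDiffusion1, ModelType.Lora)
    # -> models/sd-1/lora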

invokeai/backend/model_management/model_probe.py

@@ -4,20 +4,24 @@ import torch
 import safetensors.torch
 from dataclasses import dataclass
+from enum import Enum
 from diffusers import ModelMixin, ConfigMixin, StableDiffusionPipeline, AutoencoderKL, ControlNetModel
 from pathlib import Path
 from typing import Callable, Literal, Union, Dict
 from picklescan.scanner import scan_file_path
 import invokeai.backend.util.logging as logger
-from .models import BaseModelType, ModelType, VariantType
+from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType
 from .model_cache import SilenceWarnings

 @dataclass
 class ModelVariantInfo(object):
     model_type: ModelType
     base_type: BaseModelType
-    variant_type: VariantType
+    variant_type: ModelVariantType
+    prediction_type: SchedulerPredictionType
+    image_size: int

 class ProbeBase(object):
     '''forward declaration'''
@@ -27,7 +31,7 @@ class ModelProbe(object):
     PROBES = {
         'folder': { },
-        'file': { },
+        'checkpoint': { },
     }

     CLASS2TYPE = {
@@ -43,16 +47,28 @@ class ModelProbe(object):
                        probe_class: ProbeBase):
         cls.PROBES[format][model_type] = probe_class

+    @classmethod
+    def heuristic_probe(cls,
+                        model: Union[Dict, ModelMixin, Path],
+                        prediction_type_helper: Callable[[Path],BaseModelType]=None,
+                        )->ModelVariantInfo:
+        if isinstance(model,Path):
+            return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper)
+        elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
+            return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
+        else:
+            raise Exception(f"model parameter {model} is neither a Path nor a model")
+
     @classmethod
     def probe(cls,
               model_path: Path,
               model: Union[Dict, ModelMixin] = None,
-              base_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo:
+              prediction_type_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo:
         '''
         Probe the model at model_path and return sufficient information about it
         to place it somewhere in the models directory hierarchy. If the model is
         already loaded into memory, you may provide it as model in order to avoid
-        opening it a second time. The base_helper callable is a function that receives
+        opening it a second time. The prediction_type_helper callable is a function that receives
         the path to the model and returns the BaseModelType. It is called to distinguish
         between V2-Base and V2-768 SD models.
         '''
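A usage sketch of the revised probe API; the interactive fallback below is hypothetical, standing in for whatever Callable the caller supplies:

    from pathlib import Path
    from invokeai.backend.model_management.model_probe import ModelProbe
    from invokeai.backend.model_management.models import SchedulerPredictionType

    def ask_user_for_prediction_type(checkpoint_path: Path):
        # Hypothetical fallback: prompt the user when the checkpoint itself
        # gives no hint whether it is a V2-Base (epsilon) or V2-768 model.
        return SchedulerPredictionType.VPrediction

    info = ModelProbe.heuristic_probe(
        Path('v2-1_768-ema-pruned.safetensors'),
        prediction_type_helper=ask_user_for_prediction_type,
    )
    print(info.base_type, info.variant_type, info.prediction_type, info.image_size)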
@@ -69,13 +85,18 @@ class ModelProbe(object):
             probe_class = cls.PROBES[format].get(model_type)
             if not probe_class:
                 return None
-            probe = probe_class(model_path, model, base_helper)
+            probe = probe_class(model_path, model, prediction_type_helper)
             base_type = probe.get_base_type()
             variant_type = probe.get_variant_type()
+            prediction_type = probe.get_scheduler_prediction_type()
             model_info = ModelVariantInfo(
                 model_type = model_type,
                 base_type = base_type,
                 variant_type = variant_type,
+                prediction_type = prediction_type,
+                image_size = 768 if (base_type==BaseModelType.StableDiffusion2 \
+                                     and prediction_type==SchedulerPredictionType.VPrediction \
+                                     ) else 512
             )
         except (KeyError, ValueError) as e:
             logger.error(f'An error occurred while probing {model_path}: {str(e)}')
@@ -120,7 +141,8 @@ class ModelProbe(object):
         config_path = i if i.exists() else c if c.exists() else None
         if config_path:
-            conf = json.load(open(config_path,'r'))
+            with open(config_path,'r') as file:
+                conf = json.load(file)
             class_name = conf['_class_name']
             if type := cls.CLASS2TYPE.get(class_name):
@@ -156,9 +178,12 @@ class ProbeBase(object):
     def get_base_type(self)->BaseModelType:
         pass

-    def get_variant_type(self)->VariantType:
+    def get_variant_type(self)->ModelVariantType:
         pass

+    def get_scheduler_prediction_type(self)->SchedulerPredictionType:
+        pass
+
 class CheckpointProbeBase(ProbeBase):
     def __init__(self,
                  checkpoint_path: Path,
@@ -172,44 +197,54 @@ class CheckpointProbeBase(ProbeBase):
     def get_base_type(self)->BaseModelType:
         pass

-    def get_variant_type(self)-> VariantType:
+    def get_variant_type(self)-> ModelVariantType:
         model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
         if model_type != ModelType.Pipeline:
-            return VariantType.Normal
+            return ModelVariantType.Normal
         state_dict = self.checkpoint.get('state_dict') or self.checkpoint
         in_channels = state_dict[
             "model.diffusion_model.input_blocks.0.0.weight"
         ].shape[1]
         if in_channels == 9:
-            return VariantType.Inpaint
+            return ModelVariantType.Inpaint
         elif in_channels == 5:
-            return VariantType.Depth
+            return ModelVariantType.Depth
         else:
             return None

 class PipelineCheckpointProbe(CheckpointProbeBase):
     def get_base_type(self)->BaseModelType:
         checkpoint = self.checkpoint
-        helper = self.helper
         state_dict = self.checkpoint.get('state_dict') or checkpoint
         key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
         if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
-            return BaseModelType.StableDiffusion1_5
+            return BaseModelType.StableDiffusion1
+        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
+            return BaseModelType.StableDiffusion2
+        raise Exception("Cannot determine base type")
+
+    def get_scheduler_prediction_type(self)->SchedulerPredictionType:
+        type = self.get_base_type()
+        if type == BaseModelType.StableDiffusion1:
+            return SchedulerPredictionType.Epsilon
+        checkpoint = self.checkpoint
+        state_dict = self.checkpoint.get('state_dict') or checkpoint
+        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
         if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
             if 'global_step' in checkpoint:
                 if checkpoint['global_step'] == 220000:
-                    return BaseModelType.StableDiffusion2Base
+                    return SchedulerPredictionType.Epsilon
                 elif checkpoint["global_step"] == 110000:
-                    return BaseModelType.StableDiffusion2
-        if self.checkpoint_path and helper:
-            return helper(self.checkpoint_path)
+                    return SchedulerPredictionType.VPrediction
+        if self.checkpoint_path and self.helper:
+            return self.helper(self.checkpoint_path)
         else:
             return None

 class VaeCheckpointProbe(CheckpointProbeBase):
     def get_base_type(self)->BaseModelType:
         # I can't find any standalone 2.X VAEs to test with!
-        return BaseModelType.StableDiffusion1_5
+        return BaseModelType.StableDiffusion1

 class LoRACheckpointProbe(CheckpointProbeBase):
     def get_base_type(self)->BaseModelType:
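The global_step values above exploit a quirk of the official Stability releases: the 512-pixel SD2 base checkpoint was exported at training step 220000 (epsilon parameterization), the 768-pixel checkpoint at step 110000 (v-prediction). The same heuristic as a standalone sketch, with an illustrative function name:

    from pathlib import Path
    from typing import Optional
    import torch
    from invokeai.backend.model_management.models import SchedulerPredictionType

    def guess_sd2_prediction_type(path: Path) -> Optional[SchedulerPredictionType]:
        ckpt = torch.load(path, map_location='cpu')
        step = ckpt.get('global_step')
        if step == 220000:
            return SchedulerPredictionType.Epsilon       # sd2-base, 512 px
        if step == 110000:
            return SchedulerPredictionType.VPrediction   # sd2-768
        return None  # unknown checkpoint; callers fall back to their helper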
@@ -224,7 +259,7 @@ class LoRACheckpointProbe(CheckpointProbeBase):
             else 768
         )
         if lora_token_vector_length == 768:
-            return BaseModelType.StableDiffusion1_5
+            return BaseModelType.StableDiffusion1
         elif lora_token_vector_length == 1024:
             return BaseModelType.StableDiffusion2
         else:
@@ -240,9 +275,9 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
         else:
             token_dim = list(checkpoint.values())[0].shape[0]
         if token_dim == 768:
-            return BaseModelType.StableDiffusion1_5
+            return BaseModelType.StableDiffusion1
         elif token_dim == 1024:
-            return BaseModelType.StableDiffusion2Base
+            return BaseModelType.StableDiffusion2
         else:
             return None
@@ -255,7 +290,7 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
             if key_name not in checkpoint:
                 continue
             if checkpoint[key_name].shape[-1] == 768:
-                return BaseModelType.StableDiffusion1_5
+                return BaseModelType.StableDiffusion1
             elif self.checkpoint_path and self.helper:
                 return self.helper(self.checkpoint_path)
@@ -271,8 +306,8 @@ class FolderProbeBase(ProbeBase):
         self.model = model
         self.folder_path = folder_path

-    def get_variant_type(self)->VariantType:
-        return VariantType.Normal
+    def get_variant_type(self)->ModelVariantType:
+        return ModelVariantType.Normal

 class PipelineFolderProbe(FolderProbeBase):
     def get_base_type(self)->BaseModelType:
@@ -280,22 +315,32 @@ class PipelineFolderProbe(FolderProbeBase):
             unet_conf = self.model.unet.config
             scheduler_conf = self.model.scheduler.config
         else:
-            unet_conf = json.load(open(self.folder_path / 'unet' / 'config.json','r'))
-            scheduler_conf = json.load(open(self.folder_path / 'scheduler' / 'scheduler_config.json','r'))
+            with open(self.folder_path / 'unet' / 'config.json','r') as file:
+                unet_conf = json.load(file)
+            with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
+                scheduler_conf = json.load(file)
         if unet_conf['cross_attention_dim'] == 768:
-            return BaseModelType.StableDiffusion1_5
+            return BaseModelType.StableDiffusion1
         elif unet_conf['cross_attention_dim'] == 1024:
-            if scheduler_conf['prediction_type'] == "v_prediction":
-                return BaseModelType.StableDiffusion2
-            elif scheduler_conf['prediction_type'] == 'epsilon':
-                return BaseModelType.StableDiffusion2Base
-            else:
-                return BaseModelType.StableDiffusion2
+            return BaseModelType.StableDiffusion2
         else:
             raise ValueError(f'Unknown base model for {self.folder_path}')

+    def get_scheduler_prediction_type(self)->SchedulerPredictionType:
+        if self.model:
+            scheduler_conf = self.model.scheduler.config
+        else:
+            with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
+                scheduler_conf = json.load(file)
+        if scheduler_conf['prediction_type'] == "v_prediction":
+            return SchedulerPredictionType.VPrediction
+        elif scheduler_conf['prediction_type'] == 'epsilon':
+            return SchedulerPredictionType.Epsilon
+        else:
+            return None
+
-    def get_variant_type(self)->VariantType:
+    def get_variant_type(self)->ModelVariantType:
         # This only works for pipelines! Any kind of
         # exception results in our returning the
         # "normal" variant type
@@ -304,22 +349,23 @@ class PipelineFolderProbe(FolderProbeBase):
                 conf = self.model.unet.config
             else:
                 config_file = self.folder_path / 'unet' / 'config.json'
-                conf = json.load(open(config_file,'r'))
+                with open(config_file,'r') as file:
+                    conf = json.load(file)
             in_channels = conf['in_channels']
             if in_channels == 9:
-                return VariantType.Inpainting
+                return ModelVariantType.Inpaint
             elif in_channels == 5:
-                return VariantType.Depth
+                return ModelVariantType.Depth
             elif in_channels == 4:
-                return VariantType.Normal
+                return ModelVariantType.Normal
         except:
             pass
-        return VariantType.Normal
+        return ModelVariantType.Normal

 class VaeFolderProbe(FolderProbeBase):
     def get_base_type(self)->BaseModelType:
-        return BaseModelType.StableDiffusion1_5
+        return BaseModelType.StableDiffusion1

 class TextualInversionFolderProbe(FolderProbeBase):
     def get_base_type(self)->BaseModelType:
@@ -336,7 +382,7 @@ class ControlNetFolderProbe(FolderProbeBase):
             return None
         config = json.load(config_file)
         # no obvious way to distinguish between sd2-base and sd2-768
-        return BaseModelType.StableDiffusion1_5 \
+        return BaseModelType.StableDiffusion1 \
             if config['cross_attention_dim']==768 \
             else BaseModelType.StableDiffusion2
@@ -350,8 +396,8 @@ ModelProbe.register_probe('folder', ModelType.Vae, VaeFolderProbe)
 ModelProbe.register_probe('folder', ModelType.Lora, LoRAFolderProbe)
 ModelProbe.register_probe('folder', ModelType.TextualInversion, TextualInversionFolderProbe)
 ModelProbe.register_probe('folder', ModelType.ControlNet, ControlNetFolderProbe)
-ModelProbe.register_probe('file', ModelType.Pipeline, PipelineCheckpointProbe)
-ModelProbe.register_probe('file', ModelType.Vae, VaeCheckpointProbe)
-ModelProbe.register_probe('file', ModelType.Lora, LoRACheckpointProbe)
-ModelProbe.register_probe('file', ModelType.TextualInversion, TextualInversionCheckpointProbe)
-ModelProbe.register_probe('file', ModelType.ControlNet, ControlNetCheckpointProbe)
+ModelProbe.register_probe('checkpoint', ModelType.Pipeline, PipelineCheckpointProbe)
+ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe)
+ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe)
+ModelProbe.register_probe('checkpoint', ModelType.TextualInversion, TextualInversionCheckpointProbe)
+ModelProbe.register_probe('checkpoint', ModelType.ControlNet, ControlNetCheckpointProbe)
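Because PROBES is populated through register_probe, additional model kinds (or replacement probes) can be wired in without editing ModelProbe itself. A hypothetical example, with an illustrative probe class:

    from invokeai.backend.model_management.model_probe import ModelProbe, FolderProbeBase
    from invokeai.backend.model_management.models import BaseModelType, ModelType

    class AlwaysSD2VaeFolderProbe(FolderProbeBase):
        # Toy probe that pins every diffusers-format VAE to SD-2.
        def get_base_type(self)->BaseModelType:
            return BaseModelType.StableDiffusion2

    # Replaces the stock VaeFolderProbe registered above.
    ModelProbe.register_probe('folder', ModelType.Vae, AlwaysSD2VaeFolderProbe)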

invokeai/backend/model_management/models/__init__.py

@@ -1,4 +1,4 @@
-from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, VariantType
+from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType
 from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel
 from .vae import VaeModel
 from .lora import LoRAModel

@@ -10,7 +10,7 @@ class ControlNetModel:
     pass

 MODEL_CLASSES = {
-    BaseModelType.StableDiffusion1_5: {
+    BaseModelType.StableDiffusion1: {
         ModelType.Pipeline: StableDiffusion15Model,
         ModelType.Vae: VaeModel,
         ModelType.Lora: LoRAModel,

@@ -24,13 +24,6 @@ MODEL_CLASSES = {
         ModelType.ControlNet: ControlNetModel,
         ModelType.TextualInversion: TextualInversionModel,
     },
-    BaseModelType.StableDiffusion2Base: {
-        ModelType.Pipeline: StableDiffusion2BaseModel,
-        ModelType.Vae: VaeModel,
-        ModelType.Lora: LoRAModel,
-        ModelType.ControlNet: ControlNetModel,
-        ModelType.TextualInversion: TextualInversionModel,
-    },
     #BaseModelType.Kandinsky2_1: {
     #    ModelType.Pipeline: Kandinsky2_1Model,
     #    ModelType.MoVQ: MoVQModel,
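MODEL_CLASSES is a two-level dispatch table, keyed first by base model and then by model type. With the StableDiffusion2Base entry gone, both SD2 parameterizations resolve through the single sd-2 key:

    from invokeai.backend.model_management.models import (
        MODEL_CLASSES, BaseModelType, ModelType,
    )

    # One implementation class now serves both epsilon and v-prediction SD2 models.
    pipeline_cls = MODEL_CLASSES[BaseModelType.StableDiffusion2][ModelType.Pipeline]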

invokeai/backend/model_management/models/base.py

@@ -14,8 +14,7 @@ class BaseModelType(str, Enum):
     #StableDiffusion2 = "stable_diffusion_2"
     #StableDiffusion2Base = "stable_diffusion_2_base"
     # TODO: maybe then add sample size(512/768)?
-    StableDiffusion1_5 = "sd-1.5"
-    StableDiffusion2Base = "sd-2-base"   # 512 pixels; this will have epsilon parameterization
+    StableDiffusion1 = "sd-1"
     StableDiffusion2 = "sd-2"            # 768 pixels; this will have v-prediction parameterization
     #Kandinsky2_1 = "kandinsky_2_1"

@@ -35,10 +34,15 @@ class SubModelType(str, Enum):
     SafetyChecker = "safety_checker"
     #MoVQ = "movq"

-class VariantType(str, Enum):
+class ModelVariantType(str, Enum):
     Normal = "normal"
     Inpaint = "inpaint"
     Depth = "depth"

+class SchedulerPredictionType(str, Enum):
+    Epsilon = "epsilon"
+    VPrediction = "v_prediction"
+    Sample = "sample"
+
 class ModelError(str, Enum):
     NotFound = "not_found"
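Because these enums subclass str, members compare equal to their raw values and serialize cleanly to JSON and path names, which is what lets them double as directory names and config strings elsewhere in this commit. For example:

    from invokeai.backend.model_management.models.base import (
        ModelVariantType, SchedulerPredictionType,
    )

    assert ModelVariantType.Inpaint == 'inpaint'     # plain string comparison works
    assert SchedulerPredictionType('v_prediction') is SchedulerPredictionType.VPrediction
    assert ModelVariantType.Depth.value == 'depth'   # clean serialization via .value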

invokeai/backend/model_management/models/stable_diffusion.py

@@ -10,14 +10,11 @@ from .base import (
     BaseModelType,
     ModelType,
     SubModelType,
-    VariantType,
+    ModelVariantType,
     DiffusersModel,
 )
 from invokeai.app.services.config import InvokeAIAppConfig

-ModelVariantType = VariantType # TODO:
-
 # TODO: how to name properly
 class StableDiffusion15Model(DiffusersModel):