mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add SchedulerPredictionType and ModelVariantType enums
This commit is contained in:
parent
36eb1bd893
commit
1439dc7712
@ -3,4 +3,4 @@ Initialization file for invokeai.backend.model_management
|
||||
"""
|
||||
from .model_manager import ModelManager, ModelInfo
|
||||
from .model_cache import ModelCache
|
||||
from .models import BaseModelType, ModelType, SubModelType, VariantType
|
||||
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
|
||||
|
@ -200,20 +200,27 @@ MAX_CACHE_SIZE = 6.0 # GB
|
||||
|
||||
# layout of the models directory:
|
||||
# models
|
||||
# ├── SD-1
|
||||
# ├── sd-1
|
||||
# │ ├── controlnet
|
||||
# │ ├── lora
|
||||
# │ ├── diffusers
|
||||
# │ └── textual_inversion
|
||||
# ├── SD-2
|
||||
# ├── sd-2
|
||||
# │ ├── controlnet
|
||||
# │ ├── lora
|
||||
# │ ├── diffusers
|
||||
# │ └── textual_inversion
|
||||
# └── support
|
||||
# ├── codeformer
|
||||
# ├── gfpgan
|
||||
# └── realesrgan
|
||||
# │ └── textual_inversion
|
||||
# └── core
|
||||
# ├── face_reconstruction
|
||||
# │ ├── codeformer
|
||||
# │ └── gfpgan
|
||||
# ├── sd-conversion
|
||||
# │ ├── clip-vit-large-patch14 - tokenizer, text_encoder subdirs
|
||||
# │ ├── stable-diffusion-2 - tokenizer, text_encoder subdirs
|
||||
# │ └── stable-diffusion-safety-checker
|
||||
# └── upscaling
|
||||
# └─── esrgan
|
||||
|
||||
|
||||
|
||||
class ConfigMeta(BaseModel):
|
||||
|
@ -4,20 +4,24 @@ import torch
|
||||
import safetensors.torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from diffusers import ModelMixin, ConfigMixin, StableDiffusionPipeline, AutoencoderKL, ControlNetModel
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal, Union, Dict
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .models import BaseModelType, ModelType, VariantType
|
||||
from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType
|
||||
from .model_cache import SilenceWarnings
|
||||
|
||||
@dataclass
|
||||
class ModelVariantInfo(object):
|
||||
model_type: ModelType
|
||||
base_type: BaseModelType
|
||||
variant_type: VariantType
|
||||
variant_type: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
image_size: int
|
||||
|
||||
class ProbeBase(object):
|
||||
'''forward declaration'''
|
||||
@ -27,7 +31,7 @@ class ModelProbe(object):
|
||||
|
||||
PROBES = {
|
||||
'folder': { },
|
||||
'file': { },
|
||||
'checkpoint': { },
|
||||
}
|
||||
|
||||
CLASS2TYPE = {
|
||||
@ -43,16 +47,28 @@ class ModelProbe(object):
|
||||
probe_class: ProbeBase):
|
||||
cls.PROBES[format][model_type] = probe_class
|
||||
|
||||
@classmethod
|
||||
def heuristic_probe(cls,
|
||||
model: Union[Dict, ModelMixin, Path],
|
||||
prediction_type_helper: Callable[[Path],BaseModelType]=None,
|
||||
)->ModelVariantInfo:
|
||||
if isinstance(model,Path):
|
||||
return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper)
|
||||
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
|
||||
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
|
||||
else:
|
||||
raise Exception("model parameter {model} is neither a Path, nor a model")
|
||||
|
||||
@classmethod
|
||||
def probe(cls,
|
||||
model_path: Path,
|
||||
model: Union[Dict, ModelMixin] = None,
|
||||
base_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo:
|
||||
prediction_type_helper: Callable[[Path],BaseModelType] = None)->ModelVariantInfo:
|
||||
'''
|
||||
Probe the model at model_path and return sufficient information about it
|
||||
to place it somewhere in the models directory hierarchy. If the model is
|
||||
already loaded into memory, you may provide it as model in order to avoid
|
||||
opening it a second time. The base_helper callable is a function that receives
|
||||
opening it a second time. The prediction_type_helper callable is a function that receives
|
||||
the path to the model and returns the BaseModelType. It is called to distinguish
|
||||
between V2-Base and V2-768 SD models.
|
||||
'''
|
||||
@ -69,13 +85,18 @@ class ModelProbe(object):
|
||||
probe_class = cls.PROBES[format].get(model_type)
|
||||
if not probe_class:
|
||||
return None
|
||||
probe = probe_class(model_path, model, base_helper)
|
||||
probe = probe_class(model_path, model, prediction_type_helper)
|
||||
base_type = probe.get_base_type()
|
||||
variant_type = probe.get_variant_type()
|
||||
prediction_type = probe.get_scheduler_prediction_type()
|
||||
model_info = ModelVariantInfo(
|
||||
model_type = model_type,
|
||||
base_type = base_type,
|
||||
variant_type = variant_type,
|
||||
prediction_type = prediction_type,
|
||||
image_size = 768 if (base_type==BaseModelType.StableDiffusion2 \
|
||||
and prediction_type==SchedulerPredictionType.VPrediction \
|
||||
) else 512
|
||||
)
|
||||
except (KeyError, ValueError) as e:
|
||||
logger.error(f'An error occurred while probing {model_path}: {str(e)}')
|
||||
@ -120,7 +141,8 @@ class ModelProbe(object):
|
||||
config_path = i if i.exists() else c if c.exists() else None
|
||||
|
||||
if config_path:
|
||||
conf = json.load(open(config_path,'r'))
|
||||
with open(config_path,'r') as file:
|
||||
conf = json.load(file)
|
||||
class_name = conf['_class_name']
|
||||
|
||||
if type := cls.CLASS2TYPE.get(class_name):
|
||||
@ -156,9 +178,12 @@ class ProbeBase(object):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
pass
|
||||
|
||||
def get_variant_type(self)->VariantType:
|
||||
def get_variant_type(self)->ModelVariantType:
|
||||
pass
|
||||
|
||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
||||
pass
|
||||
|
||||
class CheckpointProbeBase(ProbeBase):
|
||||
def __init__(self,
|
||||
checkpoint_path: Path,
|
||||
@ -172,44 +197,54 @@ class CheckpointProbeBase(ProbeBase):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
pass
|
||||
|
||||
def get_variant_type(self)-> VariantType:
|
||||
def get_variant_type(self)-> ModelVariantType:
|
||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
|
||||
if model_type != ModelType.Pipeline:
|
||||
return VariantType.Normal
|
||||
return ModelVariantType.Normal
|
||||
state_dict = self.checkpoint.get('state_dict') or self.checkpoint
|
||||
in_channels = state_dict[
|
||||
"model.diffusion_model.input_blocks.0.0.weight"
|
||||
].shape[1]
|
||||
if in_channels == 9:
|
||||
return VariantType.Inpaint
|
||||
return ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
return VariantType.Depth
|
||||
return ModelVariantType.Depth
|
||||
else:
|
||||
return None
|
||||
|
||||
class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
helper = self.helper
|
||||
state_dict = self.checkpoint.get('state_dict') or checkpoint
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
||||
return BaseModelType.StableDiffusion1_5
|
||||
return BaseModelType.StableDiffusion1
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
raise Exception("Cannot determine base type")
|
||||
|
||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
||||
type = self.get_base_type()
|
||||
if type == BaseModelType.StableDiffusion1:
|
||||
return SchedulerPredictionType.Epsilon
|
||||
checkpoint = self.checkpoint
|
||||
state_dict = self.checkpoint.get('state_dict') or checkpoint
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||
if 'global_step' in checkpoint:
|
||||
if checkpoint['global_step'] == 220000:
|
||||
return BaseModelType.StableDiffusion2Base
|
||||
return SchedulerPredictionType.Epsilon
|
||||
elif checkpoint["global_step"] == 110000:
|
||||
return BaseModelType.StableDiffusion2
|
||||
if self.checkpoint_path and helper:
|
||||
return helper(self.checkpoint_path)
|
||||
return SchedulerPredictionType.VPrediction
|
||||
if self.checkpoint_path and self.helper:
|
||||
return self.helper(self.checkpoint_path)
|
||||
else:
|
||||
return None
|
||||
|
||||
class VaeCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
# I can't find any standalone 2.X VAEs to test with!
|
||||
return BaseModelType.StableDiffusion1_5
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
@ -224,7 +259,7 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
else 768
|
||||
)
|
||||
if lora_token_vector_length == 768:
|
||||
return BaseModelType.StableDiffusion1_5
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif lora_token_vector_length == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
else:
|
||||
@ -240,9 +275,9 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||
else:
|
||||
token_dim = list(checkpoint.values())[0].shape[0]
|
||||
if token_dim == 768:
|
||||
return BaseModelType.StableDiffusion1_5
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_dim == 1024:
|
||||
return BaseModelType.StableDiffusion2Base
|
||||
return BaseModelType.StableDiffusion2
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -255,7 +290,7 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
if key_name not in checkpoint:
|
||||
continue
|
||||
if checkpoint[key_name].shape[-1] == 768:
|
||||
return BaseModelType.StableDiffusion1_5
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif self.checkpoint_path and self.helper:
|
||||
return self.helper(self.checkpoint_path)
|
||||
|
||||
@ -271,8 +306,8 @@ class FolderProbeBase(ProbeBase):
|
||||
self.model = model
|
||||
self.folder_path = folder_path
|
||||
|
||||
def get_variant_type(self)->VariantType:
|
||||
return VariantType.Normal
|
||||
def get_variant_type(self)->ModelVariantType:
|
||||
return ModelVariantType.Normal
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
@ -280,22 +315,32 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
unet_conf = self.model.unet.config
|
||||
scheduler_conf = self.model.scheduler.config
|
||||
else:
|
||||
unet_conf = json.load(open(self.folder_path / 'unet' / 'config.json','r'))
|
||||
scheduler_conf = json.load(open(self.folder_path / 'scheduler' / 'scheduler_config.json','r'))
|
||||
with open(self.folder_path / 'unet' / 'config.json','r') as file:
|
||||
unet_conf = json.load(file)
|
||||
with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
|
||||
scheduler_conf = json.load(file)
|
||||
|
||||
if unet_conf['cross_attention_dim'] == 768:
|
||||
return BaseModelType.StableDiffusion1_5
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif unet_conf['cross_attention_dim'] == 1024:
|
||||
if scheduler_conf['prediction_type'] == "v_prediction":
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif scheduler_conf['prediction_type'] == 'epsilon':
|
||||
return BaseModelType.StableDiffusion2Base
|
||||
else:
|
||||
return BaseModelType.StableDiffusion2
|
||||
return BaseModelType.StableDiffusion2
|
||||
else:
|
||||
raise ValueError(f'Unknown base model for {self.folder_path}')
|
||||
|
||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
||||
if self.model:
|
||||
scheduler_conf = self.model.scheduler.config
|
||||
else:
|
||||
with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
|
||||
scheduler_conf = json.load(file)
|
||||
if scheduler_conf['prediction_type'] == "v_prediction":
|
||||
return SchedulerPredictionType.VPrediction
|
||||
elif scheduler_conf['prediction_type'] == 'epsilon':
|
||||
return SchedulerPredictionType.Epsilon
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_variant_type(self)->VariantType:
|
||||
def get_variant_type(self)->ModelVariantType:
|
||||
# This only works for pipelines! Any kind of
|
||||
# exception results in our returning the
|
||||
# "normal" variant type
|
||||
@ -304,22 +349,23 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
conf = self.model.unet.config
|
||||
else:
|
||||
config_file = self.folder_path / 'unet' / 'config.json'
|
||||
conf = json.load(open(config_file,'r'))
|
||||
with open(config_file,'r') as file:
|
||||
conf = json.load(file)
|
||||
|
||||
in_channels = conf['in_channels']
|
||||
if in_channels == 9:
|
||||
return VariantType.Inpainting
|
||||
return ModelVariantType.Inpainting
|
||||
elif in_channels == 5:
|
||||
return VariantType.Depth
|
||||
return ModelVariantType.Depth
|
||||
elif in_channels == 4:
|
||||
return VariantType.Normal
|
||||
return ModelVariantType.Normal
|
||||
except:
|
||||
pass
|
||||
return VariantType.Normal
|
||||
return ModelVariantType.Normal
|
||||
|
||||
class VaeFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
return BaseModelType.StableDiffusion1_5
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
class TextualInversionFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
@ -336,7 +382,7 @@ class ControlNetFolderProbe(FolderProbeBase):
|
||||
return None
|
||||
config = json.load(config_file)
|
||||
# no obvious way to distinguish between sd2-base and sd2-768
|
||||
return BaseModelType.StableDiffusion1_5 \
|
||||
return BaseModelType.StableDiffusion1 \
|
||||
if config['cross_attention_dim']==768 \
|
||||
else BaseModelType.StableDiffusion2
|
||||
|
||||
@ -350,8 +396,8 @@ ModelProbe.register_probe('folder', ModelType.Vae, VaeFolderProbe)
|
||||
ModelProbe.register_probe('folder', ModelType.Lora, LoRAFolderProbe)
|
||||
ModelProbe.register_probe('folder', ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||
ModelProbe.register_probe('folder', ModelType.ControlNet, ControlNetFolderProbe)
|
||||
ModelProbe.register_probe('file', ModelType.Pipeline, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe('file', ModelType.Vae, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe('file', ModelType.Lora, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe('file', ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||
ModelProbe.register_probe('file', ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
ModelProbe.register_probe('checkpoint', ModelType.Pipeline, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe('checkpoint', ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||
ModelProbe.register_probe('checkpoint', ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, VariantType
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType
|
||||
from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel
|
||||
from .vae import VaeModel
|
||||
from .lora import LoRAModel
|
||||
@ -10,7 +10,7 @@ class ControlNetModel:
|
||||
pass
|
||||
|
||||
MODEL_CLASSES = {
|
||||
BaseModelType.StableDiffusion1_5: {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelType.Pipeline: StableDiffusion15Model,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
@ -24,13 +24,6 @@ MODEL_CLASSES = {
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
},
|
||||
BaseModelType.StableDiffusion2Base: {
|
||||
ModelType.Pipeline: StableDiffusion2BaseModel,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
},
|
||||
#BaseModelType.Kandinsky2_1: {
|
||||
# ModelType.Pipeline: Kandinsky2_1Model,
|
||||
# ModelType.MoVQ: MoVQModel,
|
||||
|
@ -14,8 +14,7 @@ class BaseModelType(str, Enum):
|
||||
#StableDiffusion2 = "stable_diffusion_2"
|
||||
#StableDiffusion2Base = "stable_diffusion_2_base"
|
||||
# TODO: maybe then add sample size(512/768)?
|
||||
StableDiffusion1_5 = "sd-1.5"
|
||||
StableDiffusion2Base = "sd-2-base" # 512 pixels; this will have epsilon parameterization
|
||||
StableDiffusion1 = "sd-1"
|
||||
StableDiffusion2 = "sd-2" # 768 pixels; this will have v-prediction parameterization
|
||||
#Kandinsky2_1 = "kandinsky_2_1"
|
||||
|
||||
@ -35,10 +34,15 @@ class SubModelType(str, Enum):
|
||||
SafetyChecker = "safety_checker"
|
||||
#MoVQ = "movq"
|
||||
|
||||
class VariantType(str, Enum):
|
||||
class ModelVariantType(str, Enum):
|
||||
Normal = "normal"
|
||||
Inpaint = "inpaint"
|
||||
Depth = "depth"
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
|
||||
Epsilon = "epsilon"
|
||||
VPrediction = "v_prediction"
|
||||
Sample = "sample"
|
||||
|
||||
class ModelError(str, Enum):
|
||||
NotFound = "not_found"
|
||||
|
@ -10,14 +10,11 @@ from .base import (
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
VariantType,
|
||||
ModelVariantType,
|
||||
DiffusersModel,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
ModelVariantType = VariantType # TODO:
|
||||
|
||||
|
||||
# TODO: how to name properly
|
||||
class StableDiffusion15Model(DiffusersModel):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user