2023-06-11 16:51:50 +00:00
|
|
|
import json
|
|
|
|
import torch
|
|
|
|
import safetensors.torch
|
|
|
|
|
|
|
|
from dataclasses import dataclass
|
2023-06-12 20:07:04 +00:00
|
|
|
|
2023-06-23 20:35:39 +00:00
|
|
|
from diffusers import ModelMixin, ConfigMixin
|
2023-06-11 16:51:50 +00:00
|
|
|
from pathlib import Path
|
|
|
|
from typing import Callable, Literal, Union, Dict
|
|
|
|
from picklescan.scanner import scan_file_path
|
|
|
|
|
2023-06-24 16:37:26 +00:00
|
|
|
from .models import (
|
|
|
|
BaseModelType, ModelType, ModelVariantType,
|
|
|
|
SchedulerPredictionType, SilenceWarnings,
|
|
|
|
)
|
|
|
|
from .models.base import read_checkpoint_meta
|
2023-06-11 16:51:50 +00:00
|
|
|
|
|
|
|
@dataclass
class ModelProbeInfo(object):
    """Everything ModelProbe.probe() learned about a model.

    Instances are built by ModelProbe.probe() from the answers of the
    registered probe class for the model's format and type.
    """
    model_type: ModelType
    base_type: BaseModelType
    variant_type: ModelVariantType
    prediction_type: SchedulerPredictionType
    # True only for StableDiffusion2 models using v-prediction (computed in ModelProbe.probe)
    upcast_attention: bool
    # Storage format reported by the probe; textual-inversion probes report None
    format: Union[Literal['diffusers','checkpoint', 'lycoris'], None]
    # 768 for StableDiffusion2 v-prediction models, otherwise 512 (computed in ModelProbe.probe)
    image_size: int
|
2023-06-11 16:51:50 +00:00
|
|
|
|
|
|
|
class ProbeBase:
    """Forward declaration of the probe interface.

    ModelProbe.register_probe() names ``ProbeBase`` in its signature before the
    full interface is defined later in this module, so a placeholder must exist here.
    """
|
|
|
|
|
|
|
|
class ModelProbe(object):
    """Identify a model's type, base, variant and scheduler settings.

    Probe classes are registered per storage format ('diffusers' folders or
    single-file 'checkpoint's) and per ModelType via :meth:`register_probe`.
    :meth:`probe` picks the applicable probe and assembles a ModelProbeInfo.
    """

    # format -> {ModelType: probe class}; filled in by register_probe()
    PROBES = {
        'diffusers': { },
        'checkpoint': { },
    }

    # diffusers '_class_name' (from model_index.json / config.json, or the
    # class name of an already-loaded object) -> ModelType
    CLASS2TYPE = {
        'StableDiffusionPipeline' : ModelType.Main,
        'AutoencoderKL' : ModelType.Vae,
        'ControlNetModel' : ModelType.ControlNet,
    }

    @classmethod
    def register_probe(cls,
                       format: Literal['diffusers','checkpoint'],
                       model_type: ModelType,
                       probe_class: ProbeBase):
        """Register ``probe_class`` (a ProbeBase subclass, not an instance) as the
        probe for ``model_type`` models stored in ``format``."""
        cls.PROBES[format][model_type] = probe_class

    @classmethod
    def heuristic_probe(cls,
                        model: Union[Dict, ModelMixin, Path],
                        prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
                        )->ModelProbeInfo:
        """Probe either a path on disk or an already-loaded model.

        :param model: a Path, a checkpoint dict, or a diffusers ModelMixin/ConfigMixin object
        :param prediction_type_helper: forwarded unchanged to :meth:`probe`
        :returns: a ModelProbeInfo, or None if the model could not be identified
        :raises Exception: if ``model`` is none of the accepted kinds
        """
        if isinstance(model, Path):
            return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
        if isinstance(model, (dict, ModelMixin, ConfigMixin)):
            return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
        # f-string so the offending value actually appears in the message
        raise Exception(f"model parameter {model} is neither a Path, nor a model")

    @classmethod
    def probe(cls,
              model_path: Path,
              model: Union[Dict, ModelMixin] = None,
              prediction_type_helper: Callable[[Path],SchedulerPredictionType] = None)->ModelProbeInfo:
        '''
        Probe the model at model_path and return sufficient information about it
        to place it somewhere in the models directory hierarchy. If the model is
        already loaded into memory, you may provide it as model in order to avoid
        opening it a second time. The prediction_type_helper callable is a function that
        receives the path to the model and returns its SchedulerPredictionType. It is
        called to distinguish between V2-Base and V2-768 SD models.

        Returns None when no probe is registered for the detected model type, or
        when any step of probing raises (probing is best-effort).
        '''
        if model_path:
            format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
        else:
            format_type = 'diffusers' if isinstance(model, (ConfigMixin, ModelMixin)) else 'checkpoint'

        try:
            if format_type == 'diffusers':
                model_type = cls.get_model_type_from_folder(model_path, model)
            else:
                model_type = cls.get_model_type_from_checkpoint(model_path, model)
            probe_class = cls.PROBES[format_type].get(model_type)
            if not probe_class:
                return None
            probe = probe_class(model_path, model, prediction_type_helper)
            base_type = probe.get_base_type()
            variant_type = probe.get_variant_type()
            prediction_type = probe.get_scheduler_prediction_type()
            model_format = probe.get_format()
        except Exception:
            # Deliberate best-effort: an unrecognizable model yields None rather than an error.
            return None

        # SD-2 v-prediction models get upcast attention and a 768px default size
        is_sd2_vpred = (base_type == BaseModelType.StableDiffusion2
                        and prediction_type == SchedulerPredictionType.VPrediction)
        return ModelProbeInfo(
            model_type = model_type,
            base_type = base_type,
            variant_type = variant_type,
            prediction_type = prediction_type,
            upcast_attention = is_sd2_vpred,
            format = model_format,
            image_size = 768 if is_sd2_vpred else 512,
        )

    @classmethod
    def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
        """Guess the ModelType of a single-file checkpoint from its keys.

        :param model_path: path of the file, or None when only ``checkpoint`` is available
        :param checkpoint: the loaded state dict; if falsy it is read from ``model_path``
            via read_checkpoint_meta(..., scan=True)
        :returns: the ModelType, or None if ``model_path`` has an unsupported suffix
        :raises ValueError: if no key identifies the model type
        """
        # Filename checks only make sense when we actually have a path; an
        # in-memory checkpoint (model_path=None) goes straight to key inspection.
        if model_path is not None:
            if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors','.pth'):
                return None
            if model_path.name == "learned_embeds.bin":
                return ModelType.TextualInversion

        ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
        ckpt = ckpt.get("state_dict", ckpt)

        # diffusers-ti: a small dict whose values are all tensors. This does not
        # depend on the current key, so evaluate it once instead of per key.
        looks_like_diffusers_ti = len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values())

        for key in ckpt.keys():
            if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
                return ModelType.Main
            elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
                return ModelType.Vae
            elif key.startswith(("lora_te_", "lora_unet_")):
                return ModelType.Lora
            elif key.startswith(("control_model", "input_blocks")):
                return ModelType.ControlNet
            elif key in {"emb_params", "string_to_param"}:
                return ModelType.TextualInversion
            elif looks_like_diffusers_ti:
                return ModelType.TextualInversion

        raise ValueError("Unable to determine model type")

    @classmethod
    def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
        '''
        Get the model type of a hugging-face style folder.

        If ``model`` is given, its class name is looked up in CLASS2TYPE. Otherwise
        the folder is inspected for TI/LoRA weight files and then for a
        model_index.json or config.json whose '_class_name' is looked up.

        :raises ValueError: if the type cannot be determined
        '''
        class_name = None
        if model is not None:
            class_name = model.__class__.__name__
        else:
            if (folder_path / 'learned_embeds.bin').exists():
                return ModelType.TextualInversion

            # LoRAFolderProbe accepts either weight file format, so detect both here
            if any((folder_path / f'pytorch_lora_weights.{suffix}').exists()
                   for suffix in ('bin', 'safetensors')):
                return ModelType.Lora

            # model_index.json (full pipeline) takes precedence over config.json (single component)
            for config_name in ('model_index.json', 'config.json'):
                config_path = folder_path / config_name
                if config_path.exists():
                    with open(config_path,'r') as file:
                        class_name = json.load(file).get('_class_name')
                    break

        if class_name and (model_type := cls.CLASS2TYPE.get(class_name)):
            return model_type

        # give up
        raise ValueError("Unable to determine model type")

    @classmethod
    def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
        """Load a checkpoint file into a dict of tensors.

        Pickle-based formats are malware-scanned before torch.load; anything
        else is treated as safetensors. Tensors are mapped to the CPU because
        probing only inspects shapes, and this avoids failures when a checkpoint
        was saved from a GPU that is not present.
        """
        with SilenceWarnings():
            # '.pth' is accepted by get_model_type_from_checkpoint and is a pickle format too
            if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
                cls._scan_model(model_path, model_path)
                return torch.load(model_path, map_location="cpu")
            else:
                return safetensors.torch.load_file(model_path)

    @classmethod
    def _scan_model(cls, model_name, checkpoint):
        """
        Apply picklescanner to the indicated checkpoint and raise an exception
        if an infected file is identified.
        """
        # scan model
        scan_result = scan_file_path(checkpoint)
        if scan_result.infected_files != 0:
            # Raising a bare string is a TypeError in Python 3; raise a real exception.
            raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
|
|
|
|
|
|
|
|
####################################################
# Checkpoint probing
####################################################
|
|
|
|
class ProbeBase(object):
    """Interface implemented by every model probe.

    Each query answers None here; concrete probes override the ones they can
    actually determine for their model format and type.
    """

    def get_base_type(self)->BaseModelType:
        """Return the base model family, or None when not determinable."""
        return None

    def get_variant_type(self)->ModelVariantType:
        """Return the model variant, or None when not determinable."""
        return None

    def get_scheduler_prediction_type(self)->SchedulerPredictionType:
        """Return the scheduler prediction type, or None when not determinable."""
        return None

    def get_format(self)->str:
        """Return the storage-format label, or None when not applicable."""
        return None
|
|
|
|
|
2023-06-11 16:51:50 +00:00
|
|
|
class CheckpointProbeBase(ProbeBase):
    """Common base for probes that inspect a single-file checkpoint."""

    def __init__(self,
                 checkpoint_path: Path,
                 checkpoint: dict,
                 helper: Callable[[Path],SchedulerPredictionType] = None
                 ) -> None:
        """
        :param checkpoint_path: path of the checkpoint file (may be None if ``checkpoint`` is given)
        :param checkpoint: already-loaded state dict; if falsy it is loaded from ``checkpoint_path``
        :param helper: optional callback used by subclasses to resolve ambiguous cases
        """
        # Reuse an in-memory checkpoint when provided, otherwise load (and scan) it from disk.
        self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
        self.checkpoint_path = checkpoint_path
        self.helper = helper

    def get_base_type(self)->BaseModelType:
        return None

    def get_format(self)->str:
        return 'checkpoint'

    def get_variant_type(self)-> ModelVariantType:
        """Return the variant implied by the UNet's input channel count.

        Non-Main models are always Normal. For Main models, 9 input channels
        mean Inpaint, 5 mean Depth and 4 mean Normal.

        :raises Exception: if a Main model has any other channel count
        """
        model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
        if model_type != ModelType.Main:
            return ModelVariantType.Normal
        state_dict = self.checkpoint.get('state_dict') or self.checkpoint
        in_channels = state_dict[
            "model.diffusion_model.input_blocks.0.0.weight"
        ].shape[1]
        if in_channels == 9:
            return ModelVariantType.Inpaint
        elif in_channels == 5:
            return ModelVariantType.Depth
        elif in_channels == 4:
            return ModelVariantType.Normal
        else:
            raise Exception(f"Cannot determine variant type (unexpected in_channels={in_channels})")
|
2023-06-11 16:51:50 +00:00
|
|
|
|
|
|
|
class PipelineCheckpointProbe(CheckpointProbeBase):
    """Probe for single-file Stable Diffusion pipeline checkpoints."""

    # UNet cross-attention projection; its last dimension is the text-embedding width
    _CROSS_ATTN_KEY = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"

    def _state_dict(self) -> dict:
        """The weights, whether or not they are nested under 'state_dict'."""
        return self.checkpoint.get('state_dict') or self.checkpoint

    def _cross_attention_dim(self):
        """Width of the cross-attention key projection, or None if the key is absent."""
        weights = self._state_dict()
        if self._CROSS_ATTN_KEY not in weights:
            return None
        return weights[self._CROSS_ATTN_KEY].shape[-1]

    def get_base_type(self)->BaseModelType:
        """768-wide cross attention means SD-1, 1024-wide means SD-2; anything else raises."""
        width = self._cross_attention_dim()
        if width == 768:
            return BaseModelType.StableDiffusion1
        if width == 1024:
            return BaseModelType.StableDiffusion2
        raise Exception("Cannot determine base type")

    def get_scheduler_prediction_type(self)->SchedulerPredictionType:
        """SD-1 is always epsilon. For SD-2, use the known 'global_step' values if
        present, then the helper (only when no sibling .yaml config exists), else None."""
        if self.get_base_type() == BaseModelType.StableDiffusion1:
            return SchedulerPredictionType.Epsilon

        if self._cross_attention_dim() != 1024:
            return None

        if 'global_step' in self.checkpoint:
            step = self.checkpoint['global_step']
            if step == 220000:
                return SchedulerPredictionType.Epsilon
            if step == 110000:
                return SchedulerPredictionType.VPrediction

        # if a .yaml config file exists, then the helper is not needed
        if not (self.checkpoint_path and self.helper):
            return None
        if self.checkpoint_path.with_suffix('.yaml').exists():
            return None
        return self.helper(self.checkpoint_path)
|
|
|
|
|
|
|
|
class VaeCheckpointProbe(CheckpointProbeBase):
    """Probe for single-file VAE checkpoints."""

    def get_base_type(self)->BaseModelType:
        # Always SD-1: no standalone 2.X VAEs were available to test against.
        base = BaseModelType.StableDiffusion1
        return base
|
2023-06-11 16:51:50 +00:00
|
|
|
|
|
|
|
class LoRACheckpointProbe(CheckpointProbeBase):
    """Probe for single-file LoRA / LyCORIS checkpoints."""

    def get_format(self)->str:
        return 'lycoris'

    def get_base_type(self)->BaseModelType:
        """Infer the base model from the text-encoder projection width.

        Looks at a classic LoRA down-projection key first, then at a
        ``hada_w1_a`` key (LoHa-style naming). Falls back to 768 when neither
        is present. Width 768 -> SD-1, 1024 -> SD-2, anything else -> None.
        """
        weights = self.checkpoint
        lora_down_key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
        hada_key = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"

        if lora_down_key in weights:
            width = weights[lora_down_key].shape[1]
        elif hada_key in weights:
            width = weights[hada_key].shape[0]
        else:
            width = 768

        width_to_base = {
            768: BaseModelType.StableDiffusion1,
            1024: BaseModelType.StableDiffusion2,
        }
        return width_to_base.get(width)
|
|
|
|
|
|
|
|
class TextualInversionCheckpointProbe(CheckpointProbeBase):
    """Probe for textual-inversion embedding checkpoints."""

    def get_format(self)->str:
        # Embeddings are not given a storage-format label.
        return None

    def get_base_type(self)->BaseModelType:
        """Infer the base model from the embedding vector width.

        Handles three layouts: a 'string_to_param' dict of embedding tensors,
        a single 'emb_params' tensor, or a bare {token: tensor} mapping
        (diffusers-style). Width 768 -> SD-1, 1024 -> SD-2, otherwise None.
        """
        checkpoint = self.checkpoint
        # Test for the key that is actually read (previously 'string_to_token' was
        # tested but 'string_to_param' read). This is also the key
        # ModelProbe.get_model_type_from_checkpoint uses to classify TI files.
        if 'string_to_param' in checkpoint:
            token_dim = list(checkpoint['string_to_param'].values())[0].shape[-1]
        elif 'emb_params' in checkpoint:
            token_dim = checkpoint['emb_params'].shape[-1]
        else:
            token_dim = list(checkpoint.values())[0].shape[0]
        if token_dim == 768:
            return BaseModelType.StableDiffusion1
        elif token_dim == 1024:
            return BaseModelType.StableDiffusion2
        else:
            return None
|
|
|
|
|
|
|
|
class ControlNetCheckpointProbe(CheckpointProbeBase):
    """Probe for single-file ControlNet checkpoints."""

    def get_base_type(self)->BaseModelType:
        """Infer the base model from the cross-attention key projection width.

        Both key prefixes seen in ControlNet checkpoints are tried. 768 -> SD-1,
        1024 -> SD-2; other widths defer to the helper when one is available.

        :raises Exception: if no candidate key yields an answer
        """
        checkpoint = self.checkpoint
        for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight',
                         'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight',
                         ):
            if key_name not in checkpoint:
                continue
            width = checkpoint[key_name].shape[-1]
            if width == 768:
                return BaseModelType.StableDiffusion1
            elif width == 1024:
                return BaseModelType.StableDiffusion2
            elif self.checkpoint_path and self.helper:
                # NOTE(review): helper is typed as returning a SchedulerPredictionType,
                # yet its result is used as the base type here -- confirm intent.
                return self.helper(self.checkpoint_path)
        # f-string so the path actually appears in the message
        raise Exception(f"Unable to determine base type for {self.checkpoint_path}")
|
2023-06-11 16:51:50 +00:00
|
|
|
|
|
|
|
########################################################
|
|
|
|
# classes for probing folders
|
|
|
|
#######################################################
|
|
|
|
class FolderProbeBase(ProbeBase):
    """Common base for probes that inspect a diffusers-style folder
    (or an equivalent already-loaded model object)."""

    def __init__(self,
                 folder_path: Path,
                 model: ModelMixin = None,
                 helper: Callable=None # not used
                 ):
        self.folder_path = folder_path
        self.model = model

    def get_variant_type(self)->ModelVariantType:
        # Folder models default to the normal variant unless a subclass knows better.
        variant = ModelVariantType.Normal
        return variant

    def get_format(self)->str:
        label = 'diffusers'
        return label
|
|
|
|
|
2023-06-11 16:51:50 +00:00
|
|
|
class PipelineFolderProbe(FolderProbeBase):
    """Probe for diffusers pipeline folders or already-loaded pipeline objects."""

    def _unet_config(self):
        """Return the UNet config: from the loaded model if present, else unet/config.json."""
        if self.model:
            return self.model.unet.config
        with open(self.folder_path / 'unet' / 'config.json','r') as file:
            return json.load(file)

    def get_base_type(self)->BaseModelType:
        """768-wide UNet cross attention means SD-1, 1024-wide means SD-2.

        :raises ValueError: for any other width
        """
        unet_conf = self._unet_config()
        if unet_conf['cross_attention_dim'] == 768:
            return BaseModelType.StableDiffusion1
        elif unet_conf['cross_attention_dim'] == 1024:
            return BaseModelType.StableDiffusion2
        else:
            raise ValueError(f'Unknown base model for {self.folder_path}')

    def get_scheduler_prediction_type(self)->SchedulerPredictionType:
        """Read 'prediction_type' from the scheduler config; unknown values give None."""
        if self.model:
            scheduler_conf = self.model.scheduler.config
        else:
            with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file:
                scheduler_conf = json.load(file)
        if scheduler_conf['prediction_type'] == "v_prediction":
            return SchedulerPredictionType.VPrediction
        elif scheduler_conf['prediction_type'] == 'epsilon':
            return SchedulerPredictionType.Epsilon
        else:
            return None

    def get_variant_type(self)->ModelVariantType:
        """Return the variant implied by the UNet's input channel count.

        This only works for pipelines. If the UNet config cannot be read, the
        result is the "normal" variant.
        """
        try:
            in_channels = self._unet_config()['in_channels']
        except Exception:
            # Best-effort: unreadable or incomplete config -> normal variant
            return ModelVariantType.Normal
        if in_channels == 9:
            # Same member CheckpointProbeBase uses. The former 'Inpainting' name
            # raised AttributeError, which the old bare except turned into Normal.
            return ModelVariantType.Inpaint
        elif in_channels == 5:
            return ModelVariantType.Depth
        return ModelVariantType.Normal
|
2023-06-11 23:51:53 +00:00
|
|
|
|
2023-06-11 16:51:50 +00:00
|
|
|
class VaeFolderProbe(FolderProbeBase):
    """Probe for diffusers-style VAE folders; always reports SD-1 as the base."""

    def get_base_type(self)->BaseModelType:
        base = BaseModelType.StableDiffusion1
        return base
|
2023-06-11 16:51:50 +00:00
|
|
|
|
|
|
|
class TextualInversionFolderProbe(FolderProbeBase):
    """Probe for folders containing a learned_embeds.bin textual-inversion file."""

    def get_format(self)->str:
        return None

    def get_base_type(self)->BaseModelType:
        """Load learned_embeds.bin and defer to the checkpoint TI probe; None if the file is absent."""
        embeds_file = self.folder_path / 'learned_embeds.bin'
        if embeds_file.exists():
            embeddings = ModelProbe._scan_and_load_checkpoint(embeds_file)
            checkpoint_probe = TextualInversionCheckpointProbe(None, checkpoint=embeddings)
            return checkpoint_probe.get_base_type()
        return None
|
2023-06-11 16:51:50 +00:00
|
|
|
|
|
|
|
class ControlNetFolderProbe(FolderProbeBase):
    """Probe for diffusers-style ControlNet folders."""

    def get_base_type(self)->BaseModelType:
        """SD-1 when config.json's cross_attention_dim is 768, otherwise SD-2.

        :raises Exception: if config.json is missing
        """
        config_file = self.folder_path / 'config.json'
        if not config_file.exists():
            raise Exception(f"Cannot determine base type for {self.folder_path}")
        with open(config_file,'r') as file:
            settings = json.load(file)
        # no obvious way to distinguish between sd2-base and sd2-768
        if settings['cross_attention_dim'] == 768:
            return BaseModelType.StableDiffusion1
        return BaseModelType.StableDiffusion2
|
|
|
|
|
|
|
|
class LoRAFolderProbe(FolderProbeBase):
    """Probe for diffusers-style LoRA folders (pytorch_lora_weights.safetensors or .bin)."""

    def get_base_type(self)->BaseModelType:
        """Locate the weights file (safetensors preferred) and defer to the checkpoint LoRA probe.

        :raises Exception: if neither weights file exists
        """
        candidates = (self.folder_path / f'pytorch_lora_weights.{ext}' for ext in ('safetensors', 'bin'))
        weights_file = next((candidate for candidate in candidates if candidate.exists()), None)
        if weights_file is None:
            raise Exception('Unknown LoRA format encountered')
        return LoRACheckpointProbe(weights_file, None).get_base_type()
|
2023-06-11 16:51:50 +00:00
|
|
|
|
|
|
|
############## register probe classes ######
# (format, model type, probe class), registered in this order
for _format, _model_type, _probe_class in (
    ('diffusers', ModelType.Main, PipelineFolderProbe),
    ('diffusers', ModelType.Vae, VaeFolderProbe),
    ('diffusers', ModelType.Lora, LoRAFolderProbe),
    ('diffusers', ModelType.TextualInversion, TextualInversionFolderProbe),
    ('diffusers', ModelType.ControlNet, ControlNetFolderProbe),
    ('checkpoint', ModelType.Main, PipelineCheckpointProbe),
    ('checkpoint', ModelType.Vae, VaeCheckpointProbe),
    ('checkpoint', ModelType.Lora, LoRACheckpointProbe),
    ('checkpoint', ModelType.TextualInversion, TextualInversionCheckpointProbe),
    ('checkpoint', ModelType.ControlNet, ControlNetCheckpointProbe),
):
    ModelProbe.register_probe(_format, _model_type, _probe_class)
# keep the module namespace free of loop temporaries
del _format, _model_type, _probe_class
|