InvokeAI/invokeai/backend/model_management/model_install.py

"""
Routines for downloading and installing models.
"""
import json
import safetensors
import safetensors.torch
import torch
import traceback
from dataclasses import dataclass
from diffusers import ModelMixin
from enum import Enum
from typing import Callable
from pathlib import Path
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType
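

# The two definitions below are assumed for the purposes of this excerpt:
# probe() constructs a ModelVariantInfo and the probes return VariantType
# members, but neither declaration appears in the code shown here.
class VariantType(str, Enum):
    Normal = "normal"
    Inpaint = "inpaint"
    Depth = "depth"


@dataclass
class ModelVariantInfo:
    model_type: ModelType
    base_type: BaseModelType
    variant_type: VariantType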


class CheckpointProbe(object):
    PROBES = dict()  # populated at the bottom of this module, once the probe classes exist

    def __init__(
        self,
        checkpoint_path: Path,
        checkpoint: Optional[dict] = None,
        helper: Optional[Callable[[Path], BaseModelType]] = None,
    ):
        self.checkpoint_path = checkpoint_path
        self.checkpoint = checkpoint or self._scan_and_load_checkpoint(checkpoint_path)
        self.helper = helper
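
    def _scan_and_load_checkpoint(self, checkpoint_path: Path) -> dict:
        """Load a checkpoint, scanning pickle-format files for malware first.

        A minimal sketch: the original implementation is not shown in this
        excerpt, but the imports above suggest picklescan plus
        torch/safetensors loading along these lines.
        """
        if checkpoint_path.suffix == ".safetensors":
            return safetensors.torch.load_file(checkpoint_path)
        scan_result = scan_file_path(str(checkpoint_path))
        if scan_result.infected_files != 0:
            raise Exception(f"The model at {checkpoint_path} is potentially infected by malware. Not loading it.")
        return torch.load(checkpoint_path, map_location="cpu")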

    def probe(self) -> Optional[ModelVariantInfo]:
        """
        Probe the checkpoint at `self.checkpoint_path` and return a
        ModelVariantInfo object identifying the model base, model type
        and model variant, or None if the checkpoint is unrecognized.
        """
        checkpoint = self.checkpoint
        state_dict = checkpoint.get("state_dict") or checkpoint
        model_info = None
        try:
            model_type = self.get_checkpoint_type(state_dict)
            if not model_type:
                if self.checkpoint_path.name == "learned_embeds.bin":
                    model_type = ModelType.TextualInversion
                else:
                    return None  # we give up
            probe = self.PROBES[model_type]()
            base_type = probe.get_base_type(checkpoint, self.checkpoint_path, self.helper)
            variant_type = probe.get_variant_type(model_type, checkpoint)
            model_info = ModelVariantInfo(
                model_type=model_type,
                base_type=base_type,
                variant_type=variant_type,
            )
        except (KeyError, ValueError) as e:
            logger.error(f"An error occurred while probing {self.checkpoint_path}: {str(e)}")
            logger.error(traceback.format_exc())
        return model_info

    @classmethod
    def get_checkpoint_type(cls, state_dict: dict) -> Optional[ModelType]:
        """Sniff the model type from characteristic state_dict key prefixes."""
        if any(key.startswith("model.diffusion_model") for key in state_dict):
            return ModelType.Pipeline
        if any(key.startswith("encoder.conv_in") for key in state_dict):
            return ModelType.Vae
        if "string_to_token" in state_dict or "emb_params" in state_dict:
            return ModelType.TextualInversion
        if any(key.startswith("lora") for key in state_dict):
            return ModelType.Lora
        if any(key.startswith(("control_model", "input_blocks")) for key in state_dict):
            return ModelType.ControlNet
        return None  # give up
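

# Per-type probes. Each class below inspects tensor shapes in the loaded
# state dict to infer which base model the weights were trained against:
# 768-wide text-encoder projections indicate SD 1.x (CLIP ViT-L/14), while
# 1024-wide ones indicate SD 2.x (OpenCLIP ViT-H/14). A caller-supplied
# `helper` callable is used as a fallback when shapes alone cannot
# disambiguate.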
class CheckpointProbeBase(object):
    def get_base_type(
        self,
        checkpoint: dict,
        checkpoint_path: Optional[Path] = None,
        helper: Optional[Callable[[Path], BaseModelType]] = None,
    ) -> Optional[BaseModelType]:
        pass

    def get_variant_type(
        self,
        model_type: ModelType,
        checkpoint: dict,
    ) -> Optional[VariantType]:
        if model_type != ModelType.Pipeline:
            return None
        state_dict = checkpoint.get("state_dict") or checkpoint
        # The UNet's first convolution reveals the pipeline variant:
        # 4 input channels is a standard model; 4+4+1=9 adds the masked
        # image and mask of an inpainting model; 4+1=5 adds the depth
        # map of a depth-conditioned model.
        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
        if in_channels == 9:
            return VariantType.Inpaint
        elif in_channels == 5:
            return VariantType.Depth
        else:
            return None


class PipelineCheckpointProbe(CheckpointProbeBase):
    """Probe for monolithic Stable Diffusion pipeline checkpoints."""

    def get_base_type(
        self,
        checkpoint: dict,
        checkpoint_path: Optional[Path] = None,
        helper: Optional[Callable[[Path], BaseModelType]] = None,
    ) -> Optional[BaseModelType]:
        state_dict = checkpoint.get("state_dict") or checkpoint
        # Cross-attention key width distinguishes SD 1.x (768) from SD 2.x (1024).
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
            return BaseModelType.StableDiffusion1_5
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            if "global_step" in checkpoint:
                if checkpoint["global_step"] == 220000:
                    return BaseModelType.StableDiffusion2Base
                elif checkpoint["global_step"] == 110000:
                    return BaseModelType.StableDiffusion2
        # Fall back to the interactive helper, if any, to disambiguate.
        if checkpoint_path and helper:
            return helper(checkpoint_path)
        else:
            return None


class VaeProbe(CheckpointProbeBase):
    def get_base_type(
        self,
        checkpoint: dict,
        checkpoint_path: Optional[Path] = None,
        helper: Optional[Callable[[Path], BaseModelType]] = None,
    ) -> Optional[BaseModelType]:
        # I can't find any standalone 2.X VAEs to test with!
        return BaseModelType.StableDiffusion1_5


class LoRAProbe(CheckpointProbeBase):
    def get_base_type(
        self,
        checkpoint: dict,
        checkpoint_path: Optional[Path] = None,
        helper: Optional[Callable[[Path], BaseModelType]] = None,
    ) -> Optional[BaseModelType]:
        key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"  # LoRA layout
        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"  # LoHA layout
        if key1 in checkpoint:
            lora_token_vector_length = checkpoint[key1].shape[1]
        elif key2 in checkpoint:
            lora_token_vector_length = checkpoint[key2].shape[0]
        else:
            lora_token_vector_length = 768
        if lora_token_vector_length == 768:
            return BaseModelType.StableDiffusion1_5
        elif lora_token_vector_length == 1024:
            return BaseModelType.StableDiffusion2
        else:
            return None


class TextualInversionProbe(CheckpointProbeBase):
    def get_base_type(
        self,
        checkpoint: dict,
        checkpoint_path: Optional[Path] = None,
        helper: Optional[Callable[[Path], BaseModelType]] = None,
    ) -> Optional[BaseModelType]:
        # Embeddings saved by the original textual-inversion scripts keep
        # their weights under 'string_to_param'; some trainers use
        # 'emb_params'; otherwise assume the first value in the file is
        # the embedding tensor itself.
        if "string_to_param" in checkpoint:
            token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
        elif "emb_params" in checkpoint:
            token_dim = checkpoint["emb_params"].shape[-1]
        else:
            token_dim = list(checkpoint.values())[0].shape[0]
        if token_dim == 768:
            return BaseModelType.StableDiffusion1_5
        elif token_dim == 1024:
            return BaseModelType.StableDiffusion2Base
        else:
            return None


class ControlNetProbe(CheckpointProbeBase):
    def get_base_type(
        self,
        checkpoint: dict,
        checkpoint_path: Optional[Path] = None,
        helper: Optional[Callable[[Path], BaseModelType]] = None,
    ) -> Optional[BaseModelType]:
        for key_name in (
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
            "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
        ):
            if key_name not in checkpoint:
                continue
            if checkpoint[key_name].shape[-1] == 768:
                return BaseModelType.StableDiffusion1_5
            elif checkpoint_path and helper:
                # 1024-wide keys need the helper to identify the SD 2.x base.
                return helper(checkpoint_path)
        return None


# The probe registry promised at the top of CheckpointProbe is populated
# here, after all of the individual probe classes have been defined.
CheckpointProbe.PROBES = {
    ModelType.Pipeline: PipelineCheckpointProbe,
    ModelType.Vae: VaeProbe,
    ModelType.Lora: LoRAProbe,
    ModelType.TextualInversion: TextualInversionProbe,
    ModelType.ControlNet: ControlNetProbe,
}
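

# A minimal usage sketch (hypothetical path; CheckpointProbe.probe() is
# the intended entry point):
#
#   info = CheckpointProbe(Path("/path/to/model.safetensors")).probe()
#   if info:
#       print(info.model_type, info.base_type, info.variant_type)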