model_probe working; model_install incomplete

This commit is contained in:
Lincoln Stein 2023-06-11 19:51:53 -04:00
parent 085ab54124
commit 893f776f1d
6 changed files with 122 additions and 77 deletions

View File

@ -3,4 +3,4 @@ Initialization file for invokeai.backend.model_management
"""
from .model_manager import ModelManager, ModelInfo
from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType
from .models import BaseModelType, ModelType, SubModelType, VariantType

View File

@ -29,6 +29,7 @@ import torch
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
import logging
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from .lora import LoRAModel, TextualInversionModel

View File

@ -4,6 +4,8 @@ Routines for downloading and installing models.
import json
import safetensors
import safetensors.torch
import shutil
import tempfile
import torch
import traceback
from dataclasses import dataclass
@ -14,8 +16,10 @@ from pathlib import Path
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from .models import BaseModelType, ModelType
from . import ModelManager
from .models import BaseModelType, ModelType, VariantType
from .model_probe import ModelProbe, ModelVariantInfo
from .model_cache import SilenceWarnings
class ModelInstall(object):
'''
@ -54,32 +58,49 @@ class ModelInstall(object):
if not model_info:
raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}")
# non-pipeline; no conversion needed, just copy into right place
if model_info.model_type != ModelType.Pipeline:
destination_path = self._dest_path(model_info) / checkpoint.name
self._check_for_collision(destination_path)
shutil.copyfile(checkpoint, destination_path)
key = ModelManager.create_key(
model_name = checkpoint.stem,
base_model = model_info.base_type
model_type = model_info.model_type
base_model = model_info.base_type,
model_type = model_info.model_type,
)
return {
destination_path = self._dest_path(model_info) / checkpoint
destination_path.parent.mkdir(parents=True, exist_ok=True)
self._check_for_collision(destination_path)
stanza = {
key: dict(
name = model_name,
description = f'{model_info.model_type} model {model_name}',
name = checkpoint.stem,
description = f'{model_info.model_type} model {checkpoint.stem}',
base = model_info.base_model.value,
type = model_info.model_type.value,
variant = model_info.variant_type.value,
path = str(destination_path),
format = 'checkpoint',
base = str(base_model),
type = str(model_type),
variant = str(model_info.variant_type),
)
}
# non-pipeline; no conversion needed, just copy into right place
if model_info.model_type != ModelType.Pipeline:
shutil.copyfile(checkpoint, destination_path)
stanza[key].update({'format': 'checkpoint'})
# pipeline - conversion needed here
else:
destination_path = self._dest_path(model_info) / checkpoint.stem
config_file = self._pipeline_type_to_config_file(model_info.model_type)
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings:
convert_ckpt_to_diffusers(
checkpoint,
destination_path,
extract_ema=True,
original_config_file=config_file,
scan_needed=False,
)
stanza[key].update({'format': 'folder',
'path': destination_path, # no suffix on this
})
return stanza
def _check_for_collision(self, path: Path):

View File

@ -4,13 +4,14 @@ import torch
import safetensors.torch
from dataclasses import dataclass
from diffusers import ModelMixin
from diffusers import ModelMixin, ConfigMixin, StableDiffusionPipeline, AutoencoderKL, ControlNetModel
from pathlib import Path
from typing import Callable, Literal, Union, Dict
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, VariantType
from .model_cache import SilenceWarnings
@dataclass
class ModelVariantInfo(object):
@ -30,10 +31,9 @@ class ModelProbe(object):
}
CLASS2TYPE = {
"StableDiffusionPipeline" : ModelType.Pipeline,
"AutoencoderKL": ModelType.Vae,
"ControlNetModel" : ModelType.ControlNet,
'StableDiffusionPipeline' : ModelType.Pipeline,
'AutoencoderKL' : ModelType.Vae,
'ControlNetModel' : ModelType.ControlNet,
}
@classmethod
@ -56,7 +56,11 @@ class ModelProbe(object):
the path to the model and returns the BaseModelType. It is called to distinguish
between V2-Base and V2-768 SD models.
'''
format = 'folder' if model_path.is_dir() else 'file'
if model_path:
format = 'folder' if model_path.is_dir() else 'checkpoint'
else:
format = 'folder' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
model_info = None
try:
model_type = cls.get_model_type_from_folder(model_path, model) \
@ -102,6 +106,9 @@ class ModelProbe(object):
'''
Get the model type of a hugging-face style folder.
'''
if model:
class_name = model.__class__.__name__
else:
if (folder_path / 'learned_embeds.bin').exists():
return ModelType.TextualInversion
@ -113,16 +120,18 @@ class ModelProbe(object):
config_path = i if i.exists() else c if c.exists() else None
if config_path:
conf = json.loads(config_path)
conf = json.load(open(config_path,'r'))
class_name = conf['_class_name']
if type := cls.CLASS2TYPE.get(class_name):
return type
# give up
raise ValueError("Unable to determine model type of {model_path}")
raise ValueError("Unable to determine model type")
@classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
with SilenceWarnings():
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
cls._scan_model(model_path, model_path)
return torch.load(model_path)
@ -255,43 +264,58 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
#######################################################
class FolderProbeBase(ProbeBase):
def __init__(self,
model: ModelMixin,
folder_path: Path,
model: ModelMixin = None,
helper: Callable=None # not used
):
self.model = model
self.folder_path = folder_path
def get_variant_type(self)->VariantType:
# only works for pipelines
config_file = self.folder_path / 'unet' / 'config.json'
if not config_file.exists():
return VariantType.Normal
conf = json.loads(config_file)
channels = conf['in_channels']
if channels == 9:
return VariantType.Inpainting
elif channels == 5:
return VariantType.Depth
elif channels == 4:
return VariantType.Normal
else:
return VariantType.Normal
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
config_file = self.folder_path / 'scheduler' / 'scheduler_config.json'
if not config_file.exists():
return None
conf = json.load(config_file)
if conf['prediction_type'] == "v_prediction":
if self.model:
unet_conf = self.model.unet.config
scheduler_conf = self.model.scheduler.config
else:
unet_conf = json.load(open(self.folder_path / 'unet' / 'config.json','r'))
scheduler_conf = json.load(open(self.folder_path / 'scheduler' / 'scheduler_config.json','r'))
if unet_conf['cross_attention_dim'] == 768:
return BaseModelType.StableDiffusion1_5
elif unet_conf['cross_attention_dim'] == 1024:
if scheduler_conf['prediction_type'] == "v_prediction":
return BaseModelType.StableDiffusion2
elif conf['prediction_type'] == 'epsilon':
elif scheduler_conf['prediction_type'] == 'epsilon':
return BaseModelType.StableDiffusion2Base
else:
return BaseModelType.StableDiffusion2
else:
raise ValueError(f'Unknown base model for {self.folder_path}')
def get_variant_type(self)->VariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the
# "normal" variant type
try:
if self.model:
conf = self.model.unet.config
else:
config_file = self.folder_path / 'unet' / 'config.json'
conf = json.load(open(config_file,'r'))
in_channels = conf['in_channels']
if in_channels == 9:
return VariantType.Inpainting
elif in_channels == 5:
return VariantType.Depth
elif in_channels == 4:
return VariantType.Normal
except:
pass
return VariantType.Normal
class VaeFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:

View File

@ -1,4 +1,4 @@
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, VariantType
from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel
from .vae import VaeModel
from .lora import LoRAModel

View File

@ -21,9 +21,8 @@ class BaseModelType(str, Enum):
class ModelType(str, Enum):
Pipeline = "pipeline"
Vae = "vae"
Lora = "lora"
#ControlNet = "controlnet"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
class SubModelType(str, Enum):