model_probe working; model_install incomplete

This commit is contained in:
Lincoln Stein 2023-06-11 19:51:53 -04:00
parent 085ab54124
commit 893f776f1d
6 changed files with 122 additions and 77 deletions

View File

@ -3,4 +3,4 @@ Initialization file for invokeai.backend.model_management
""" """
from .model_manager import ModelManager, ModelInfo from .model_manager import ModelManager, ModelInfo
from .model_cache import ModelCache from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType from .models import BaseModelType, ModelType, SubModelType, VariantType

View File

@ -29,6 +29,7 @@ import torch
from diffusers import logging as diffusers_logging from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging from transformers import logging as transformers_logging
import logging
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import get_invokeai_config
from .lora import LoRAModel, TextualInversionModel from .lora import LoRAModel, TextualInversionModel

View File

@ -4,6 +4,8 @@ Routines for downloading and installing models.
import json import json
import safetensors import safetensors
import safetensors.torch import safetensors.torch
import shutil
import tempfile
import torch import torch
import traceback import traceback
from dataclasses import dataclass from dataclasses import dataclass
@ -14,8 +16,10 @@ from pathlib import Path
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from .models import BaseModelType, ModelType from . import ModelManager
from .models import BaseModelType, ModelType, VariantType
from .model_probe import ModelProbe, ModelVariantInfo from .model_probe import ModelProbe, ModelVariantInfo
from .model_cache import SilenceWarnings
class ModelInstall(object): class ModelInstall(object):
''' '''
@ -54,32 +58,49 @@ class ModelInstall(object):
if not model_info: if not model_info:
raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}") raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}")
key = ModelManager.create_key(
model_name = checkpoint.stem,
base_model = model_info.base_type,
model_type = model_info.model_type,
)
destination_path = self._dest_path(model_info) / checkpoint
destination_path.parent.mkdir(parents=True, exist_ok=True)
self._check_for_collision(destination_path)
stanza = {
key: dict(
name = checkpoint.stem,
description = f'{model_info.model_type} model {checkpoint.stem}',
base = model_info.base_model.value,
type = model_info.model_type.value,
variant = model_info.variant_type.value,
path = str(destination_path),
)
}
# non-pipeline; no conversion needed, just copy into right place # non-pipeline; no conversion needed, just copy into right place
if model_info.model_type != ModelType.Pipeline: if model_info.model_type != ModelType.Pipeline:
destination_path = self._dest_path(model_info) / checkpoint.name
self._check_for_collision(destination_path)
shutil.copyfile(checkpoint, destination_path) shutil.copyfile(checkpoint, destination_path)
key = ModelManager.create_key( stanza[key].update({'format': 'checkpoint'})
model_name = checkpoint.stem,
base_model = model_info.base_type # pipeline - conversion needed here
model_type = model_info.model_type else:
) destination_path = self._dest_path(model_info) / checkpoint.stem
return { config_file = self._pipeline_type_to_config_file(model_info.model_type)
key: dict(
name = model_name, from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
description = f'{model_info.model_type} model {model_name}', with SilenceWarnings:
path = str(destination_path), convert_ckpt_to_diffusers(
format = 'checkpoint', checkpoint,
base = str(base_model), destination_path,
type = str(model_type), extract_ema=True,
variant = str(model_info.variant_type), original_config_file=config_file,
scan_needed=False,
) )
} stanza[key].update({'format': 'folder',
'path': destination_path, # no suffix on this
})
destination_path = self._dest_path(model_info) / checkpoint.stem
return stanza
def _check_for_collision(self, path: Path): def _check_for_collision(self, path: Path):

View File

@ -4,13 +4,14 @@ import torch
import safetensors.torch import safetensors.torch
from dataclasses import dataclass from dataclasses import dataclass
from diffusers import ModelMixin from diffusers import ModelMixin, ConfigMixin, StableDiffusionPipeline, AutoencoderKL, ControlNetModel
from pathlib import Path from pathlib import Path
from typing import Callable, Literal, Union, Dict from typing import Callable, Literal, Union, Dict
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, VariantType from .models import BaseModelType, ModelType, VariantType
from .model_cache import SilenceWarnings
@dataclass @dataclass
class ModelVariantInfo(object): class ModelVariantInfo(object):
@ -30,10 +31,9 @@ class ModelProbe(object):
} }
CLASS2TYPE = { CLASS2TYPE = {
"StableDiffusionPipeline" : ModelType.Pipeline, 'StableDiffusionPipeline' : ModelType.Pipeline,
"AutoencoderKL": ModelType.Vae, 'AutoencoderKL' : ModelType.Vae,
"ControlNetModel" : ModelType.ControlNet, 'ControlNetModel' : ModelType.ControlNet,
} }
@classmethod @classmethod
@ -56,7 +56,11 @@ class ModelProbe(object):
the path to the model and returns the BaseModelType. It is called to distinguish the path to the model and returns the BaseModelType. It is called to distinguish
between V2-Base and V2-768 SD models. between V2-Base and V2-768 SD models.
''' '''
format = 'folder' if model_path.is_dir() else 'file' if model_path:
format = 'folder' if model_path.is_dir() else 'checkpoint'
else:
format = 'folder' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
model_info = None model_info = None
try: try:
model_type = cls.get_model_type_from_folder(model_path, model) \ model_type = cls.get_model_type_from_folder(model_path, model) \
@ -102,32 +106,37 @@ class ModelProbe(object):
''' '''
Get the model type of a hugging-face style folder. Get the model type of a hugging-face style folder.
''' '''
if (folder_path / 'learned_embeds.bin').exists(): if model:
return ModelType.TextualInversion class_name = model.__class__.__name__
else:
if (folder_path / 'learned_embeds.bin').exists():
return ModelType.TextualInversion
if (folder_path / 'pytorch_lora_weights.bin').exists(): if (folder_path / 'pytorch_lora_weights.bin').exists():
return ModelType.Lora return ModelType.Lora
i = folder_path / 'model_index.json' i = folder_path / 'model_index.json'
c = folder_path / 'config.json' c = folder_path / 'config.json'
config_path = i if i.exists() else c if c.exists() else None config_path = i if i.exists() else c if c.exists() else None
if config_path: if config_path:
conf = json.loads(config_path) conf = json.load(open(config_path,'r'))
class_name = conf['_class_name'] class_name = conf['_class_name']
if type := cls.CLASS2TYPE.get(class_name):
return type if type := cls.CLASS2TYPE.get(class_name):
return type
# give up # give up
raise ValueError("Unable to determine model type of {model_path}") raise ValueError("Unable to determine model type")
@classmethod @classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict: def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")): with SilenceWarnings():
cls._scan_model(model_path, model_path) if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
return torch.load(model_path) cls._scan_model(model_path, model_path)
else: return torch.load(model_path)
return safetensors.torch.load_file(model_path) else:
return safetensors.torch.load_file(model_path)
@classmethod @classmethod
def _scan_model(cls, model_name, checkpoint): def _scan_model(cls, model_name, checkpoint):
@ -255,43 +264,58 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
####################################################### #######################################################
class FolderProbeBase(ProbeBase): class FolderProbeBase(ProbeBase):
def __init__(self, def __init__(self,
model: ModelMixin,
folder_path: Path, folder_path: Path,
model: ModelMixin = None,
helper: Callable=None # not used helper: Callable=None # not used
): ):
self.model = model self.model = model
self.folder_path = folder_path self.folder_path = folder_path
def get_variant_type(self)->VariantType: def get_variant_type(self)->VariantType:
return VariantType.Normal
# only works for pipelines
config_file = self.folder_path / 'unet' / 'config.json'
if not config_file.exists():
return VariantType.Normal
conf = json.loads(config_file)
channels = conf['in_channels']
if channels == 9:
return VariantType.Inpainting
elif channels == 5:
return VariantType.Depth
elif channels == 4:
return VariantType.Normal
else:
return VariantType.Normal
class PipelineFolderProbe(FolderProbeBase): class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType: def get_base_type(self)->BaseModelType:
config_file = self.folder_path / 'scheduler' / 'scheduler_config.json' if self.model:
if not config_file.exists(): unet_conf = self.model.unet.config
return None scheduler_conf = self.model.scheduler.config
conf = json.load(config_file)
if conf['prediction_type'] == "v_prediction":
return BaseModelType.StableDiffusion2
elif conf['prediction_type'] == 'epsilon':
return BaseModelType.StableDiffusion2Base
else: else:
return BaseModelType.StableDiffusion2 unet_conf = json.load(open(self.folder_path / 'unet' / 'config.json','r'))
scheduler_conf = json.load(open(self.folder_path / 'scheduler' / 'scheduler_config.json','r'))
if unet_conf['cross_attention_dim'] == 768:
return BaseModelType.StableDiffusion1_5
elif unet_conf['cross_attention_dim'] == 1024:
if scheduler_conf['prediction_type'] == "v_prediction":
return BaseModelType.StableDiffusion2
elif scheduler_conf['prediction_type'] == 'epsilon':
return BaseModelType.StableDiffusion2Base
else:
return BaseModelType.StableDiffusion2
else:
raise ValueError(f'Unknown base model for {self.folder_path}')
def get_variant_type(self)->VariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the
# "normal" variant type
try:
if self.model:
conf = self.model.unet.config
else:
config_file = self.folder_path / 'unet' / 'config.json'
conf = json.load(open(config_file,'r'))
in_channels = conf['in_channels']
if in_channels == 9:
return VariantType.Inpainting
elif in_channels == 5:
return VariantType.Depth
elif in_channels == 4:
return VariantType.Normal
except:
pass
return VariantType.Normal
class VaeFolderProbe(FolderProbeBase): class VaeFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType: def get_base_type(self)->BaseModelType:

View File

@ -1,4 +1,4 @@
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, VariantType
from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel
from .vae import VaeModel from .vae import VaeModel
from .lora import LoRAModel from .lora import LoRAModel

View File

@ -21,9 +21,8 @@ class BaseModelType(str, Enum):
class ModelType(str, Enum): class ModelType(str, Enum):
Pipeline = "pipeline" Pipeline = "pipeline"
Vae = "vae" Vae = "vae"
Lora = "lora" Lora = "lora"
#ControlNet = "controlnet" ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding" TextualInversion = "embedding"
class SubModelType(str, Enum): class SubModelType(str, Enum):