From 893f776f1d1f81e41ec0ecf9c5a7ee41a346abe0 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sun, 11 Jun 2023 19:51:53 -0400
Subject: [PATCH] model_probe working; model_install incomplete

---
 invokeai/backend/model_management/__init__.py |   2 +-
 .../backend/model_management/model_cache.py   |   1 +
 .../backend/model_management/model_install.py |  65 ++++++---
 .../backend/model_management/model_probe.py   | 126 +++++++++++-------
 .../model_management/models/__init__.py       |   2 +-
 .../backend/model_management/models/base.py   |   3 +-
 6 files changed, 122 insertions(+), 77 deletions(-)

diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py
index a5dd2093d8..29c6f6b2d3 100644
--- a/invokeai/backend/model_management/__init__.py
+++ b/invokeai/backend/model_management/__init__.py
@@ -3,4 +3,4 @@ Initialization file for invokeai.backend.model_management
 """
 from .model_manager import ModelManager, ModelInfo
 from .model_cache import ModelCache
-from .models import BaseModelType, ModelType, SubModelType
+from .models import BaseModelType, ModelType, SubModelType, VariantType
diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index f3a6dac5bc..eac400a339 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -29,6 +29,7 @@ import torch
 
 from diffusers import logging as diffusers_logging
 from transformers import logging as transformers_logging
+import logging
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import get_invokeai_config
 from .lora import LoRAModel, TextualInversionModel
diff --git a/invokeai/backend/model_management/model_install.py b/invokeai/backend/model_management/model_install.py
index 6016f5f3f5..64c52185f3 100644
--- a/invokeai/backend/model_management/model_install.py
+++ b/invokeai/backend/model_management/model_install.py
@@ -4,6 +4,8 @@ Routines for downloading and installing models.
 import json
 import safetensors
 import safetensors.torch
+import shutil
+import tempfile
 import torch
 import traceback
 from dataclasses import dataclass
@@ -14,8 +16,10 @@ from pathlib import Path
 
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig
-from .models import BaseModelType, ModelType
+from . import ModelManager
+from .models import BaseModelType, ModelType, VariantType
 from .model_probe import ModelProbe, ModelVariantInfo
+from .model_cache import SilenceWarnings
 
 class ModelInstall(object):
     '''
@@ -53,33 +57,50 @@ class ModelInstall(object):
         model_info = self.prober.probe(checkpoint, self.helper)
         if not model_info:
             raise ValueError(f"Unable to determine type of checkpoint file {checkpoint}")
+
+        key = ModelManager.create_key(
+            model_name = checkpoint.stem,
+            base_model = model_info.base_type,
+            model_type = model_info.model_type,
+        )
+        destination_path = self._dest_path(model_info) / checkpoint.name
+        destination_path.parent.mkdir(parents=True, exist_ok=True)
+        self._check_for_collision(destination_path)
+        stanza = {
+            key: dict(
+                name = checkpoint.stem,
+                description = f'{model_info.model_type} model {checkpoint.stem}',
+                base = model_info.base_type.value,
+                type = model_info.model_type.value,
+                variant = model_info.variant_type.value,
+                path = str(destination_path),
+            )
+        }
 
         # non-pipeline; no conversion needed, just copy into right place
         if model_info.model_type != ModelType.Pipeline:
-            destination_path = self._dest_path(model_info) / checkpoint.name
-            self._check_for_collision(destination_path)
             shutil.copyfile(checkpoint, destination_path)
-            key = ModelManager.create_key(
-                model_name = checkpoint.stem,
-                base_model = model_info.base_type
-                model_type = model_info.model_type
-            )
-            return {
-                key: dict(
-                    name = model_name,
-                    description = f'{model_info.model_type} model {model_name}',
-                    path = str(destination_path),
-                    format = 'checkpoint',
-                    base = str(base_model),
-                    type = str(model_type),
-                    variant = str(model_info.variant_type),
-                )
-            }
-
+            stanza[key].update({'format': 'checkpoint'})
 
-        destination_path = self._dest_path(model_info) / checkpoint.stem
-
+        # pipeline - conversion needed here
+        else:
+            destination_path = self._dest_path(model_info) / checkpoint.stem
+            config_file = self._pipeline_type_to_config_file(model_info.model_type)
+            from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
+            with SilenceWarnings():
+                convert_ckpt_to_diffusers(
+                    checkpoint,
+                    destination_path,
+                    extract_ema=True,
+                    original_config_file=config_file,
+                    scan_needed=False,
+                )
+            stanza[key].update({'format': 'folder',
+                                'path': str(destination_path), # no suffix on this
+                                })
+
+        return stanza
 
     def _check_for_collision(self, path: Path):
diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py
index bb2bbc2a85..b60d2d7358 100644
--- a/invokeai/backend/model_management/model_probe.py
+++ b/invokeai/backend/model_management/model_probe.py
@@ -4,13 +4,14 @@ import torch
 import safetensors.torch
 
 from dataclasses import dataclass
-from diffusers import ModelMixin
+from diffusers import ModelMixin, ConfigMixin, StableDiffusionPipeline, AutoencoderKL, ControlNetModel
 from pathlib import Path
 from typing import Callable, Literal, Union, Dict
 from picklescan.scanner import scan_file_path
 
 import invokeai.backend.util.logging as logger
 from .models import BaseModelType, ModelType, VariantType
+from .model_cache import SilenceWarnings
 
 @dataclass
 class ModelVariantInfo(object):
@@ -30,10 +31,9 @@ class ModelProbe(object):
     }
 
     CLASS2TYPE = {
-        "StableDiffusionPipeline" : ModelType.Pipeline,
-        "AutoencoderKL": ModelType.Vae,
-        "ControlNetModel" : ModelType.ControlNet,
-
+        'StableDiffusionPipeline' : ModelType.Pipeline,
+        'AutoencoderKL' : ModelType.Vae,
+        'ControlNetModel' : ModelType.ControlNet,
     }
 
     @classmethod
@@ -56,7 +56,11 @@ class ModelProbe(object):
         the path to the model and returns the BaseModelType. It is called
         to distinguish between V2-Base and V2-768 SD models.
         '''
-        format = 'folder' if model_path.is_dir() else 'file'
+        if model_path:
+            format = 'folder' if model_path.is_dir() else 'checkpoint'
+        else:
+            format = 'folder' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
+
         model_info = None
         try:
             model_type = cls.get_model_type_from_folder(model_path, model) \
@@ -102,32 +106,37 @@ class ModelProbe(object):
         '''
         Get the model type of a hugging-face style folder.
         '''
-        if (folder_path / 'learned_embeds.bin').exists():
-            return ModelType.TextualInversion
+        class_name = model.__class__.__name__ if model else None
+        # fall back to probing the files on disk when no model object was supplied
+        if class_name is None:
+            if (folder_path / 'learned_embeds.bin').exists():
+                return ModelType.TextualInversion
 
-        if (folder_path / 'pytorch_lora_weights.bin').exists():
-            return ModelType.Lora
+            if (folder_path / 'pytorch_lora_weights.bin').exists():
+                return ModelType.Lora
 
-        i = folder_path / 'model_index.json'
-        c = folder_path / 'config.json'
-        config_path = i if i.exists() else c if c.exists() else None
-
-        if config_path:
-            conf = json.loads(config_path)
-            class_name = conf['_class_name']
-            if type := cls.CLASS2TYPE.get(class_name):
-                return type
+            i = folder_path / 'model_index.json'
+            c = folder_path / 'config.json'
+            config_path = i if i.exists() else c if c.exists() else None
+
+            if config_path:
+                conf = json.load(open(config_path,'r'))
+                class_name = conf['_class_name']
+
+        if type := cls.CLASS2TYPE.get(class_name):
+            return type
 
         # give up
-        raise ValueError("Unable to determine model type of {model_path}")
+        raise ValueError(f"Unable to determine model type of {folder_path}")
 
     @classmethod
     def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
-        if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
-            cls._scan_model(model_path, model_path)
-            return torch.load(model_path)
-        else:
-            return safetensors.torch.load_file(model_path)
+        with SilenceWarnings():
+            if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
+                cls._scan_model(model_path, model_path)
+                return torch.load(model_path)
+            else:
+                return safetensors.torch.load_file(model_path)
 
     @classmethod
     def _scan_model(cls, model_name, checkpoint):
@@ -255,44 +264,59 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
 #######################################################
 class FolderProbeBase(ProbeBase):
     def __init__(self,
-                 model: ModelMixin,
                  folder_path: Path,
+                 model: ModelMixin = None,
                  helper: Callable=None  # not used
                  ):
         self.model = model
         self.folder_path = folder_path
 
     def get_variant_type(self)->VariantType:
-
-        # only works for pipelines
-        config_file = self.folder_path / 'unet' / 'config.json'
-        if not config_file.exists():
-            return VariantType.Normal
-
-        conf = json.loads(config_file)
-        channels = conf['in_channels']
-        if channels == 9:
-            return VariantType.Inpainting
-        elif channels == 5:
-            return VariantType.Depth
-        elif channels == 4:
-            return VariantType.Normal
-        else:
-            return VariantType.Normal
+        return VariantType.Normal
 
 class PipelineFolderProbe(FolderProbeBase):
     def get_base_type(self)->BaseModelType:
-        config_file = self.folder_path / 'scheduler' / 'scheduler_config.json'
-        if not config_file.exists():
-            return None
-        conf = json.load(config_file)
-        if conf['prediction_type'] == "v_prediction":
-            return BaseModelType.StableDiffusion2
-        elif conf['prediction_type'] == 'epsilon':
-            return BaseModelType.StableDiffusion2Base
+        if self.model:
+            unet_conf = self.model.unet.config
+            scheduler_conf = self.model.scheduler.config
         else:
-            return BaseModelType.StableDiffusion2
+            unet_conf = json.load(open(self.folder_path / 'unet' / 'config.json','r'))
+            scheduler_conf = json.load(open(self.folder_path / 'scheduler' / 'scheduler_config.json','r'))
+
+        if unet_conf['cross_attention_dim'] == 768:
+            return BaseModelType.StableDiffusion1_5
+        elif unet_conf['cross_attention_dim'] == 1024:
+            if scheduler_conf['prediction_type'] == "v_prediction":
+                return BaseModelType.StableDiffusion2
+            elif scheduler_conf['prediction_type'] == 'epsilon':
+                return BaseModelType.StableDiffusion2Base
+            else:
+                return BaseModelType.StableDiffusion2
+        else:
+            raise ValueError(f'Unknown base model for {self.folder_path}')
+
+    def get_variant_type(self)->VariantType:
+        # This only works for pipelines! Any kind of
+        # exception results in our returning the
+        # "normal" variant type
+        try:
+            if self.model:
+                conf = self.model.unet.config
+            else:
+                config_file = self.folder_path / 'unet' / 'config.json'
+                conf = json.load(open(config_file,'r'))
+
+            in_channels = conf['in_channels']
+            if in_channels == 9:
+                return VariantType.Inpainting
+            elif in_channels == 5:
+                return VariantType.Depth
+            elif in_channels == 4:
+                return VariantType.Normal
+        except Exception:
+            pass
+        return VariantType.Normal
+
 class VaeFolderProbe(FolderProbeBase):
     def get_base_type(self)->BaseModelType:
         return BaseModelType.StableDiffusion1_5
diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py
index ec8dc34973..2fa328a1f7 100644
--- a/invokeai/backend/model_management/models/__init__.py
+++ b/invokeai/backend/model_management/models/__init__.py
@@ -1,4 +1,4 @@
-from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase
+from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, VariantType
 from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel
 from .vae import VaeModel
 from .lora import LoRAModel
diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py
index 7415dcac0a..7c4283e9af 100644
--- a/invokeai/backend/model_management/models/base.py
+++ b/invokeai/backend/model_management/models/base.py
@@ -21,9 +21,8 @@ class BaseModelType(str, Enum):
 class ModelType(str, Enum):
     Pipeline = "pipeline"
     Vae = "vae"
-
     Lora = "lora"
-    #ControlNet = "controlnet"
+    ControlNet = "controlnet" # used by model_probe
     TextualInversion = "embedding"
 
 class SubModelType(str, Enum):
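
Reviewer note, not part of the patch: below is a minimal sketch of how the
reworked ModelProbe is expected to be driven, mirroring the ModelInstall call
site in the hunks above. The checkpoint path is hypothetical; calling probe()
with only a path assumes the model/helper parameters implied by the diff are
genuinely optional, and the ModelVariantInfo field names (model_type,
base_type, variant_type) are taken from the way model_install.py reads them.

    from pathlib import Path

    from invokeai.backend.model_management.model_probe import ModelProbe
    from invokeai.backend.model_management import ModelType, VariantType

    # Hypothetical path; any .safetensors/.ckpt/.pt/.bin checkpoint would do.
    checkpoint = Path('/data/models/my-model.safetensors')

    # probe() scans and loads the checkpoint (under SilenceWarnings) and
    # returns a ModelVariantInfo describing what it found.
    info = ModelProbe().probe(checkpoint)
    print(info.model_type.value, info.base_type.value, info.variant_type.value)

    # The same fields drive installer routing, e.g. the 9-channel inpainting
    # variant detected by PipelineFolderProbe.get_variant_type():
    if info.model_type == ModelType.Pipeline and info.variant_type == VariantType.Inpainting:
        print('inpainting pipeline detected')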