From a328986b43cf7c53bf2714be8e7111e0318615d5 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sat, 8 Jul 2023 04:09:10 +0300 Subject: [PATCH] Less naive model detection --- .../backend/model_management/model_manager.py | 10 +++---- .../model_management/models/__init__.py | 2 +- .../backend/model_management/models/base.py | 3 ++ .../backend/model_management/models/lora.py | 15 ++++++++-- .../models/stable_diffusion.py | 29 +++++++++++++++---- .../models/textual_inversion.py | 14 ++++++++- .../backend/model_management/models/vae.py | 15 ++++++++-- 7 files changed, 68 insertions(+), 20 deletions(-) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 03514cfeff..a8d43a6888 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -250,8 +250,8 @@ from .model_cache import ModelCache, ModelLocker from .models import ( BaseModelType, ModelType, SubModelType, ModelError, SchedulerPredictionType, MODEL_CLASSES, - ModelConfigBase, ModelNotFoundException, - ) + ModelConfigBase, ModelNotFoundException, InvalidModelException, +) # We are only starting to number the config file with release 3. # The config file version doesn't have to start at release version, but it will help @@ -275,10 +275,6 @@ class ModelInfo(): def __exit__(self,*args, **kwargs): self.context.__exit__(*args, **kwargs) -class InvalidModelError(Exception): - "Raised when an invalid model is requested" - pass - class AddModelResult(BaseModel): name: str = Field(description="The name of the model after installation") model_type: ModelType = Field(description="The type of model") @@ -817,6 +813,8 @@ class ModelManager(object): model_config: ModelConfigBase = model_class.probe_config(str(model_path)) self.models[model_key] = model_config new_models_found = True + except InvalidModelException: + self.logger.warning(f"Not a valid model: {model_path}") except NotImplementedError as e: self.logger.warning(e) diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index 1b381cd2a8..b02d85471d 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -2,7 +2,7 @@ import inspect from enum import Enum from pydantic import BaseModel from typing import Literal, get_origin -from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException +from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException, InvalidModelException from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .vae import VaeModel from .lora import LoRAModel diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index 57c02bce76..ddbc401e5b 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -15,6 +15,9 @@ from contextlib import suppress from pydantic import BaseModel, Field from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union +class InvalidModelException(Exception): + pass + class ModelNotFoundException(Exception): pass diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 59feacde06..5387ade0e5 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -9,6 +9,7 @@ from .base import ( ModelType, SubModelType, classproperty, + InvalidModelException, ) # TODO: naming from ..lora import LoRAModel as LoRAModelRaw @@ -56,10 +57,18 @@ class LoRAModel(ModelBase): @classmethod def detect_format(cls, path: str): + if not os.path.exists(path): + raise ModelNotFoundException() + if os.path.isdir(path): - return LoRAModelFormat.Diffusers - else: - return LoRAModelFormat.LyCORIS + if os.path.exists(os.path.join(path, "pytorch_lora_weights.bin")): + return LoRAModelFormat.Diffusers + + if os.path.isfile(path): + if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]): + return LoRAModelFormat.LyCORIS + + raise InvalidModelException(f"Not a valid model: {path}") @classmethod def convert_if_required( diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index c98d5a0ae8..74751a40dd 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -16,6 +16,7 @@ from .base import ( SilenceWarnings, read_checkpoint_meta, classproperty, + InvalidModelException, ) from invokeai.app.services.config import InvokeAIAppConfig from omegaconf import OmegaConf @@ -98,10 +99,18 @@ class StableDiffusion1Model(DiffusersModel): @classmethod def detect_format(cls, model_path: str): + if not os.path.exists(model_path): + raise ModelNotFoundException() + if os.path.isdir(model_path): - return StableDiffusion1ModelFormat.Diffusers - else: - return StableDiffusion1ModelFormat.Checkpoint + if os.path.exists(os.path.join(model_path, "model_index.json")): + return StableDiffusion1ModelFormat.Diffusers + + if os.path.isfile(model_path): + if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]): + return StableDiffusion1ModelFormat.Checkpoint + + raise InvalidModelException(f"Not a valid model: {model_path}") @classmethod def convert_if_required( @@ -200,10 +209,18 @@ class StableDiffusion2Model(DiffusersModel): @classmethod def detect_format(cls, model_path: str): + if not os.path.exists(model_path): + raise ModelNotFoundException() + if os.path.isdir(model_path): - return StableDiffusion2ModelFormat.Diffusers - else: - return StableDiffusion2ModelFormat.Checkpoint + if os.path.exists(os.path.join(model_path, "model_index.json")): + return StableDiffusion2ModelFormat.Diffusers + + if os.path.isfile(model_path): + if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]): + return StableDiffusion2ModelFormat.Checkpoint + + raise InvalidModelException(f"Not a valid model: {model_path}") @classmethod def convert_if_required( diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_management/models/textual_inversion.py index 4dcdbb24ba..9cd62bb417 100644 --- a/invokeai/backend/model_management/models/textual_inversion.py +++ b/invokeai/backend/model_management/models/textual_inversion.py @@ -9,6 +9,7 @@ from .base import ( SubModelType, classproperty, ModelNotFoundException, + InvalidModelException, ) # TODO: naming from ..lora import TextualInversionModel as TextualInversionModelRaw @@ -59,7 +60,18 @@ class TextualInversionModel(ModelBase): @classmethod def detect_format(cls, path: str): - return None + if not os.path.exists(path): + raise ModelNotFoundException() + + if os.path.isdir(path): + if os.path.exists(os.path.join(path, "learned_embeds.bin")): + return None # diffusers-ti + + if os.path.isfile(path): + if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]): + return None + + raise InvalidModelException(f"Not a valid model: {path}") @classmethod def convert_if_required( diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py index 3f0d226687..2a5b7cff24 100644 --- a/invokeai/backend/model_management/models/vae.py +++ b/invokeai/backend/model_management/models/vae.py @@ -15,6 +15,7 @@ from .base import ( calc_model_size_by_fs, calc_model_size_by_data, classproperty, + InvalidModelException, ) from invokeai.app.services.config import InvokeAIAppConfig from diffusers.utils import is_safetensors_available @@ -75,10 +76,18 @@ class VaeModel(ModelBase): @classmethod def detect_format(cls, path: str): + if not os.path.exists(path): + raise ModelNotFoundException() + if os.path.isdir(path): - return VaeModelFormat.Diffusers - else: - return VaeModelFormat.Checkpoint + if os.path.exists(os.path.join(path, "config.json")): + return VaeModelFormat.Diffusers + + if os.path.isfile(path): + if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]): + return VaeModelFormat.Checkpoint + + raise InvalidModelException(f"Not a valid model: {path}") @classmethod def convert_if_required(