Less naive model detection

Sergey Borisov 2023-07-08 04:09:10 +03:00 committed by Kent Keirsey
parent af239fa122
commit a328986b43
7 changed files with 68 additions and 20 deletions

View File

@@ -250,8 +250,8 @@ from .model_cache import ModelCache, ModelLocker
 from .models import (
     BaseModelType, ModelType, SubModelType,
     ModelError, SchedulerPredictionType, MODEL_CLASSES,
-    ModelConfigBase, ModelNotFoundException,
+    ModelConfigBase, ModelNotFoundException, InvalidModelException,
 )

 # We are only starting to number the config file with release 3.
 # The config file version doesn't have to start at release version, but it will help
@@ -275,10 +275,6 @@ class ModelInfo():
     def __exit__(self,*args, **kwargs):
         self.context.__exit__(*args, **kwargs)

-class InvalidModelError(Exception):
-    "Raised when an invalid model is requested"
-    pass
-
 class AddModelResult(BaseModel):
     name: str = Field(description="The name of the model after installation")
     model_type: ModelType = Field(description="The type of model")
@@ -817,6 +813,8 @@ class ModelManager(object):
                     model_config: ModelConfigBase = model_class.probe_config(str(model_path))
                     self.models[model_key] = model_config
                     new_models_found = True
+                except InvalidModelException:
+                    self.logger.warning(f"Not a valid model: {model_path}")
                 except NotImplementedError as e:
                     self.logger.warning(e)
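With this change a directory scan no longer aborts when it hits a stray file: probing continues and the offending path is only logged. A minimal, self-contained sketch of that pattern (the scan helper and its arguments are illustrative, not the actual ModelManager code):

import logging
from pathlib import Path

logger = logging.getLogger(__name__)

class InvalidModelException(Exception):
    # Stand-in for the exception added in models/base.py below.
    pass

def scan_for_models(models_dir: Path, model_class, models: dict) -> bool:
    # Probe every entry; skip paths that are not recognizable models
    # instead of letting one bad file abort the whole scan.
    new_models_found = False
    for model_path in models_dir.iterdir():
        try:
            model_config = model_class.probe_config(str(model_path))
            models[str(model_path)] = model_config
            new_models_found = True
        except InvalidModelException:
            logger.warning(f"Not a valid model: {model_path}")
        except NotImplementedError as e:
            logger.warning(e)
    return new_models_found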

View File

@@ -2,7 +2,7 @@ import inspect
 from enum import Enum
 from pydantic import BaseModel
 from typing import Literal, get_origin
-from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException
+from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException, InvalidModelException
 from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
 from .vae import VaeModel
 from .lora import LoRAModel

View File

@@ -15,6 +15,9 @@ from contextlib import suppress
 from pydantic import BaseModel, Field
 from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union

+class InvalidModelException(Exception):
+    pass
+
 class ModelNotFoundException(Exception):
     pass
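The new exception lets callers separate "nothing at that path" (ModelNotFoundException) from "path exists but is not a recognizable model" (InvalidModelException). A small sketch of a caller making that distinction, using the two classes defined above; safe_detect is a hypothetical helper, not part of the commit:

def safe_detect(detect_format, path: str):
    # detect_format is any of the classmethods changed in this commit,
    # e.g. LoRAModel.detect_format or VaeModel.detect_format.
    try:
        return detect_format(path)
    except ModelNotFoundException:
        print(f"missing path: {path}")        # nothing on disk at all
        return None
    except InvalidModelException as e:
        print(f"unrecognized model: {e}")     # exists, but wrong layout/extension
        return None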

View File

@@ -9,6 +9,7 @@ from .base import (
     ModelType,
     SubModelType,
     classproperty,
+    InvalidModelException,
 )
 # TODO: naming
 from ..lora import LoRAModel as LoRAModelRaw
@@ -56,10 +57,18 @@ class LoRAModel(ModelBase):
     @classmethod
     def detect_format(cls, path: str):
+        if not os.path.exists(path):
+            raise ModelNotFoundException()
+
         if os.path.isdir(path):
-            return LoRAModelFormat.Diffusers
-        else:
-            return LoRAModelFormat.LyCORIS
+            if os.path.exists(os.path.join(path, "pytorch_lora_weights.bin")):
+                return LoRAModelFormat.Diffusers
+
+        if os.path.isfile(path):
+            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
+                return LoRAModelFormat.LyCORIS
+
+        raise InvalidModelException(f"Not a valid model: {path}")

     @classmethod
     def convert_if_required(
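Each detect_format override in this commit follows the same shape: a marker file identifies a diffusers-style directory, a known extension identifies a single-file checkpoint, and anything else is rejected. A generic sketch of that shape, using hypothetical names (detect_format_generic, CHECKPOINT_EXTENSIONS) that are not in the commit:

import os

class ModelNotFoundException(Exception):
    pass

class InvalidModelException(Exception):
    pass

CHECKPOINT_EXTENSIONS = (".safetensors", ".ckpt", ".pt")

def detect_format_generic(path: str, marker_file: str, dir_format, file_format):
    # Mirrors the per-model detect_format methods: marker file -> diffusers
    # folder, known extension -> single-file checkpoint, otherwise invalid.
    if not os.path.exists(path):
        raise ModelNotFoundException()
    if os.path.isdir(path) and os.path.exists(os.path.join(path, marker_file)):
        return dir_format
    if os.path.isfile(path) and path.endswith(CHECKPOINT_EXTENSIONS):
        return file_format
    raise InvalidModelException(f"Not a valid model: {path}")

# Example (assuming the file exists on disk):
#   detect_format_generic("some_lora.safetensors", "pytorch_lora_weights.bin",
#                         "diffusers", "lycoris")  -> "lycoris"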

View File

@@ -16,6 +16,7 @@ from .base import (
     SilenceWarnings,
     read_checkpoint_meta,
     classproperty,
+    InvalidModelException,
 )
 from invokeai.app.services.config import InvokeAIAppConfig
 from omegaconf import OmegaConf
@@ -98,10 +99,18 @@ class StableDiffusion1Model(DiffusersModel):
     @classmethod
     def detect_format(cls, model_path: str):
+        if not os.path.exists(model_path):
+            raise ModelNotFoundException()
+
         if os.path.isdir(model_path):
-            return StableDiffusion1ModelFormat.Diffusers
-        else:
-            return StableDiffusion1ModelFormat.Checkpoint
+            if os.path.exists(os.path.join(model_path, "model_index.json")):
+                return StableDiffusion1ModelFormat.Diffusers
+
+        if os.path.isfile(model_path):
+            if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
+                return StableDiffusion1ModelFormat.Checkpoint
+
+        raise InvalidModelException(f"Not a valid model: {model_path}")

     @classmethod
     def convert_if_required(
@@ -200,10 +209,18 @@ class StableDiffusion2Model(DiffusersModel):
     @classmethod
     def detect_format(cls, model_path: str):
+        if not os.path.exists(model_path):
+            raise ModelNotFoundException()
+
         if os.path.isdir(model_path):
-            return StableDiffusion2ModelFormat.Diffusers
-        else:
-            return StableDiffusion2ModelFormat.Checkpoint
+            if os.path.exists(os.path.join(model_path, "model_index.json")):
+                return StableDiffusion2ModelFormat.Diffusers
+
+        if os.path.isfile(model_path):
+            if any([model_path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
+                return StableDiffusion2ModelFormat.Checkpoint
+
+        raise InvalidModelException(f"Not a valid model: {model_path}")

     @classmethod
     def convert_if_required(
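A short usage sketch of the new behavior for Stable Diffusion paths. The import location and the example paths are assumptions for illustration, not part of the diff:

# Assumed import location (InvokeAI 3.x layout); adjust if the package differs.
from invokeai.backend.model_management.models.stable_diffusion import (
    StableDiffusion1Model,
    StableDiffusion1ModelFormat,
)

# A directory is only reported as Diffusers if it contains model_index.json.
fmt = StableDiffusion1Model.detect_format("/models/sd-1.5")
assert fmt == StableDiffusion1ModelFormat.Diffusers

# A single .safetensors/.ckpt/.pt file is reported as Checkpoint.
fmt = StableDiffusion1Model.detect_format("/models/sd-1.5.safetensors")
assert fmt == StableDiffusion1ModelFormat.Checkpoint

# Anything else (e.g. a stray .txt file) now raises InvalidModelException
# instead of being misclassified as a checkpoint.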

View File

@@ -9,6 +9,7 @@ from .base import (
     SubModelType,
     classproperty,
     ModelNotFoundException,
+    InvalidModelException,
 )
 # TODO: naming
 from ..lora import TextualInversionModel as TextualInversionModelRaw
@@ -59,7 +60,18 @@ class TextualInversionModel(ModelBase):
     @classmethod
     def detect_format(cls, path: str):
-        return None
+        if not os.path.exists(path):
+            raise ModelNotFoundException()
+
+        if os.path.isdir(path):
+            if os.path.exists(os.path.join(path, "learned_embeds.bin")):
+                return None  # diffusers-ti
+
+        if os.path.isfile(path):
+            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
+                return None
+
+        raise InvalidModelException(f"Not a valid model: {path}")

     @classmethod
     def convert_if_required(

View File

@@ -15,6 +15,7 @@ from .base import (
     calc_model_size_by_fs,
     calc_model_size_by_data,
     classproperty,
+    InvalidModelException,
 )
 from invokeai.app.services.config import InvokeAIAppConfig
 from diffusers.utils import is_safetensors_available
@@ -75,10 +76,18 @@ class VaeModel(ModelBase):
     @classmethod
     def detect_format(cls, path: str):
+        if not os.path.exists(path):
+            raise ModelNotFoundException()
+
         if os.path.isdir(path):
-            return VaeModelFormat.Diffusers
-        else:
-            return VaeModelFormat.Checkpoint
+            if os.path.exists(os.path.join(path, "config.json")):
+                return VaeModelFormat.Diffusers
+
+        if os.path.isfile(path):
+            if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]]):
+                return VaeModelFormat.Checkpoint
+
+        raise InvalidModelException(f"Not a valid model: {path}")

     @classmethod
     def convert_if_required(
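For completeness, a hedged pytest-style sketch of the behavior this commit establishes for VAEs; the import paths and tmp_path fixture usage are assumptions, not tests shipped with the commit:

import pytest

def test_vae_detect_format(tmp_path):
    # Assumed import locations (InvokeAI 3.x layout).
    from invokeai.backend.model_management.models.vae import VaeModel, VaeModelFormat
    from invokeai.backend.model_management.models.base import InvalidModelException, ModelNotFoundException

    # A known checkpoint extension is accepted.
    ckpt = tmp_path / "vae.safetensors"
    ckpt.write_bytes(b"\0")
    assert VaeModel.detect_format(str(ckpt)) == VaeModelFormat.Checkpoint

    # An unknown file is rejected instead of being guessed at.
    bogus = tmp_path / "not_a_model.txt"
    bogus.write_text("hello")
    with pytest.raises(InvalidModelException):
        VaeModel.detect_format(str(bogus))

    # A missing path raises ModelNotFoundException.
    with pytest.raises(ModelNotFoundException):
        VaeModel.detect_format(str(tmp_path / "missing"))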