diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py
index 2e830a3c05..40995498bf 100644
--- a/invokeai/backend/model_management/models/__init__.py
+++ b/invokeai/backend/model_management/models/__init__.py
@@ -2,13 +2,9 @@ from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfig
 from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
 from .vae import VaeModel
 from .lora import LoRAModel
-#from .controlnet import ControlNetModel # TODO:
+from .controlnet import ControlNetModel
 from .textual_inversion import TextualInversionModel
 
-# TODO:
-class ControlNetModel:
-    pass
-
 MODEL_CLASSES = {
     BaseModelType.StableDiffusion1: {
         ModelType.Pipeline: StableDiffusion1Model,
@@ -36,9 +32,7 @@ MODEL_CLASSES = {
 def get_all_model_configs():
     configs = set()
     for models in MODEL_CLASSES.values():
-        for type, model in models.items():
-            if type == ModelType.ControlNet:
-                continue # TODO:
+        for _, model in models.items():
             configs.update(model._get_configs().values())
     configs.discard(None)
     return list(configs) # TODO: set, list or tuple
diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py
index f4563d0aea..3bf0045918 100644
--- a/invokeai/backend/model_management/models/base.py
+++ b/invokeai/backend/model_management/models/base.py
@@ -8,8 +8,9 @@ import torch
 import safetensors.torch
 from diffusers import DiffusionPipeline, ConfigMixin
 
+from contextlib import suppress
 from pydantic import BaseModel, Field
-from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any
+from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
 
 class BaseModelType(str, Enum):
     StableDiffusion1 = "sd-1"
@@ -108,32 +109,45 @@ class ModelBase(metaclass=ABCMeta):
 
     @classmethod
     def _get_configs(cls):
-        if not hasattr(cls, "__configs"):
-            configs = dict()
-            for name in dir(cls):
-                if name.startswith("__"):
-                    continue
+        with suppress(Exception):
+            return cls.__configs
+
+        configs = dict()
+        for name in dir(cls):
+            if name.startswith("__"):
+                continue
 
-                value = getattr(cls, name)
-                if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
-                    continue
+            value = getattr(cls, name)
+            if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
+                continue
 
-                fields = inspect.get_annotations(value)
-                if "format" not in fields:
-                    raise Exception("Invalid config definition - format field not found")
+            fields = inspect.get_annotations(value)
+            if "format" not in fields:
+                raise Exception("Invalid config definition - format field not found")
 
-                format_type = typing.get_origin(fields["format"])
-                if format_type not in {None, Literal}:
-                    raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
+            format_type = typing.get_origin(fields["format"])
+            if format_type not in {None, Literal, Union}:
+                raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
 
-                if format_type is Literal:
-                    format = fields["format"].__args__[0]
+            if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__):
+                raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
+
+            if format_type == Union:
+                f_fields = fields["format"].__args__
+            else:
+                f_fields = (fields["format"],)
+
+            for field in f_fields:
+                if field is None:
+                    format_name = None
                 else:
-                    format = None
-                configs[format] = value # TODO: error when override(multiple)?
+                    format_name = field.__args__[0]
 
-            cls.__configs = configs
+                configs[format_name] = value # TODO: error when override(multiple)?
+
+        cls.__configs = configs
         return cls.__configs
 
     @classmethod
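Note on the _get_configs hunk above: a Config's "format" annotation may now be None, a single Literal, or a Union of Literals, and the config class is registered once per format name. The sketch below is illustrative only and not part of the patch; formats_of is a hypothetical helper that isolates the mapping logic, assuming Python 3.8+ typing semantics.

    import typing
    from typing import Literal, Union

    def formats_of(annotation):
        """Return the format names covered by a Config's 'format' annotation."""
        if typing.get_origin(annotation) is Union:
            members = annotation.__args__          # e.g. (Literal["checkpoint"], Literal["diffusers"])
        else:
            members = (annotation,)
        names = []
        for member in members:
            if member is None or member is type(None):
                names.append(None)                 # config that applies regardless of format
            else:
                names.append(member.__args__[0])   # the Literal's value
        return names

    print(formats_of(Union[Literal["checkpoint"], Literal["diffusers"]]))  # ['checkpoint', 'diffusers']
    print(formats_of(Literal["lycoris"]))                                  # ['lycoris']
    print(formats_of(None))                                                # [None]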
@@ -237,8 +251,11 @@ class DiffusersModel(ModelBase):
                 )
                 break
             except Exception as e:
-                print("====ERR LOAD====")
-                print(f"{variant}: {e}")
+                #print("====ERR LOAD====")
+                #print(f"{variant}: {e}")
+                pass
+        else:
+            raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
 
         # calc more accurate size
         self.child_sizes[child_type] = calc_model_size_by_data(model)
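The for/else added in the DiffusersModel hunk above replaces the debug prints: if every variant fails to load, the loop now falls through to the else branch and raises, instead of silently continuing with no model loaded. A minimal sketch of the pattern, using hypothetical stand-ins (variants, try_load) rather than the real loader API:

    def load_first_working(variants, try_load):
        """Return the first variant that loads successfully; raise if none does."""
        for variant in variants:
            try:
                model = try_load(variant)
                break                  # success: stop trying further variants
            except Exception:
                pass                   # this variant failed, try the next one
        else:
            # reached only if the loop finished without hitting 'break'
            raise Exception(f"Failed to load any variant out of {variants}")
        return model

    # usage sketch (loader and path are hypothetical):
    # model = load_first_working(["fp16", None], lambda v: loader.from_pretrained(path, variant=v))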
diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management/models/controlnet.py
new file mode 100644
index 0000000000..d75c55010a
--- /dev/null
+++ b/invokeai/backend/model_management/models/controlnet.py
@@ -0,0 +1,87 @@
+import os
+import torch
+from pathlib import Path
+from typing import Optional, Union, Literal
+from .base import (
+    ModelBase,
+    ModelConfigBase,
+    BaseModelType,
+    ModelType,
+    SubModelType,
+    EmptyConfigLoader,
+    calc_model_size_by_fs,
+    calc_model_size_by_data,
+    classproperty,
+)
+
+class ControlNetModel(ModelBase):
+    #model_class: Type
+    #model_size: int
+
+    class Config(ModelConfigBase):
+        format: Union[Literal["checkpoint"], Literal["diffusers"]]
+
+    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
+        assert model_type == ModelType.ControlNet
+        super().__init__(model_path, base_model, model_type)
+
+        try:
+            config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
+            #config = json.loads(os.path.join(self.model_path, "config.json"))
+        except:
+            raise Exception("Invalid ControlNet model! (config.json not found or invalid)")
+
+        model_class_name = config.get("_class_name", None)
+        if model_class_name not in {"ControlNetModel"}:
+            raise Exception(f"Invalid ControlNet model! Unknown _class_name: {model_class_name}")
+
+        try:
+            self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
+            self.model_size = calc_model_size_by_fs(self.model_path)
+        except:
+            raise Exception("Invalid ControlNet model!")
+
+    def get_size(self, child_type: Optional[SubModelType] = None):
+        if child_type is not None:
+            raise Exception("There are no child models in a ControlNet model")
+        return self.model_size
+
+    def get_model(
+        self,
+        torch_dtype: Optional[torch.dtype],
+        child_type: Optional[SubModelType] = None,
+    ):
+        if child_type is not None:
+            raise Exception("There are no child models in a ControlNet model")
+
+        model = self.model_class.from_pretrained(
+            self.model_path,
+            torch_dtype=torch_dtype,
+        )
+        # calc more accurate size
+        self.model_size = calc_model_size_by_data(model)
+        return model
+
+    @classproperty
+    def save_to_config(cls) -> bool:
+        return False
+
+    @classmethod
+    def detect_format(cls, path: str):
+        if os.path.isdir(path):
+            return "diffusers"
+        else:
+            return "checkpoint"
+
+    @classmethod
+    def convert_if_required(
+        cls,
+        model_path: str,
+        output_path: str,
+        config: ModelConfigBase, # empty config or config of parent model
+        base_model: BaseModelType,
+    ) -> str:
+        if cls.detect_format(model_path) != "diffusers":
+            raise NotImplementedError("Checkpoint ControlNet models are not currently supported")
+        else:
+            return model_path
diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py
index aea769ac80..c69677fd0c 100644
--- a/invokeai/backend/model_management/models/lora.py
+++ b/invokeai/backend/model_management/models/lora.py
@@ -1,5 +1,5 @@
 import torch
-from typing import Optional
+from typing import Optional, Union, Literal
 from .base import (
     ModelBase,
     ModelConfigBase,
@@ -15,7 +15,7 @@ class LoRAModel(ModelBase):
     #model_size: int
 
     class Config(ModelConfigBase):
-        format: None
+        format: Union[Literal["lycoris"], Literal["diffusers"]]
 
     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
         assert model_type == ModelType.Lora
diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py
index a8255d88c5..1edb57ccc4 100644
--- a/invokeai/backend/model_management/models/vae.py
+++ b/invokeai/backend/model_management/models/vae.py
@@ -1,7 +1,7 @@
 import os
 import torch
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union, Literal
 from .base import (
     ModelBase,
     ModelConfigBase,
@@ -23,7 +23,7 @@ class VaeModel(ModelBase):
     #model_size: int
 
     class Config(ModelConfigBase):
-        format: None
+        format: Union[Literal["checkpoint"], Literal["diffusers"]]
 
     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
         assert model_type == ModelType.Vae
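Usage sketch for the new ControlNetModel wrapper added above. The directory path is hypothetical and assumed to be a diffusers-format ControlNet folder containing config.json; in the application this class is normally instantiated by the model manager rather than called directly.

    import torch

    from invokeai.backend.model_management.models import BaseModelType, ModelType
    from invokeai.backend.model_management.models.controlnet import ControlNetModel

    model_path = "/path/to/models/controlnet/canny"  # hypothetical diffusers-format folder

    assert ControlNetModel.detect_format(model_path) == "diffusers"

    info = ControlNetModel(model_path, BaseModelType.StableDiffusion1, ModelType.ControlNet)
    print(info.get_size())                           # size estimated from files on disk

    controlnet = info.get_model(torch_dtype=torch.float16)
    print(info.get_size())                           # refined size after the weights are loaded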