mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add controlnet to model manager, fixes
This commit is contained in:
parent
740c05a0bb
commit
6c5954f9d1
@ -2,13 +2,9 @@ from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfig
|
||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||
from .vae import VaeModel
|
||||
from .lora import LoRAModel
|
||||
#from .controlnet import ControlNetModel # TODO:
|
||||
from .controlnet import ControlNetModel # TODO:
|
||||
from .textual_inversion import TextualInversionModel
|
||||
|
||||
# TODO:
|
||||
class ControlNetModel:
|
||||
pass
|
||||
|
||||
MODEL_CLASSES = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelType.Pipeline: StableDiffusion1Model,
|
||||
@ -36,9 +32,7 @@ MODEL_CLASSES = {
|
||||
def get_all_model_configs():
|
||||
configs = set()
|
||||
for models in MODEL_CLASSES.values():
|
||||
for type, model in models.items():
|
||||
if type == ModelType.ControlNet:
|
||||
continue # TODO:
|
||||
for _, model in models.items():
|
||||
configs.update(model._get_configs().values())
|
||||
configs.discard(None)
|
||||
return list(configs) # TODO: set, list or tuple
|
||||
|
@ -8,8 +8,9 @@ import torch
|
||||
import safetensors.torch
|
||||
from diffusers import DiffusionPipeline, ConfigMixin
|
||||
|
||||
from contextlib import suppress
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any
|
||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
StableDiffusion1 = "sd-1"
|
||||
@ -108,32 +109,45 @@ class ModelBase(metaclass=ABCMeta):
|
||||
|
||||
@classmethod
|
||||
def _get_configs(cls):
|
||||
if not hasattr(cls, "__configs"):
|
||||
configs = dict()
|
||||
for name in dir(cls):
|
||||
if name.startswith("__"):
|
||||
continue
|
||||
with suppress(Exception):
|
||||
return cls.__configs
|
||||
|
||||
configs = dict()
|
||||
for name in dir(cls):
|
||||
if name.startswith("__"):
|
||||
continue
|
||||
|
||||
value = getattr(cls, name)
|
||||
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
|
||||
continue
|
||||
value = getattr(cls, name)
|
||||
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
|
||||
continue
|
||||
|
||||
fields = inspect.get_annotations(value)
|
||||
if "format" not in fields:
|
||||
raise Exception("Invalid config definition - format field not found")
|
||||
fields = inspect.get_annotations(value)
|
||||
if "format" not in fields:
|
||||
raise Exception("Invalid config definition - format field not found")
|
||||
|
||||
format_type = typing.get_origin(fields["format"])
|
||||
if format_type not in {None, Literal}:
|
||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
||||
format_type = typing.get_origin(fields["format"])
|
||||
if format_type not in {None, Literal, Union}:
|
||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
||||
|
||||
if format_type is Literal:
|
||||
format = fields["format"].__args__[0]
|
||||
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__):
|
||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
||||
|
||||
if format_type == Union:
|
||||
f_fields = fields["format"].__args__
|
||||
else:
|
||||
f_fields = (fields["format"],)
|
||||
|
||||
|
||||
for field in f_fields:
|
||||
if field is None:
|
||||
format_name = None
|
||||
else:
|
||||
format = None
|
||||
configs[format] = value # TODO: error when override(multiple)?
|
||||
format_name = field.__args__[0]
|
||||
|
||||
cls.__configs = configs
|
||||
configs[format_name] = value # TODO: error when override(multiple)?
|
||||
|
||||
|
||||
cls.__configs = configs
|
||||
return cls.__configs
|
||||
|
||||
@classmethod
|
||||
@ -237,8 +251,11 @@ class DiffusersModel(ModelBase):
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
print("====ERR LOAD====")
|
||||
print(f"{variant}: {e}")
|
||||
#print("====ERR LOAD====")
|
||||
#print(f"{variant}: {e}")
|
||||
pass
|
||||
else:
|
||||
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
|
||||
|
||||
# calc more accurate size
|
||||
self.child_sizes[child_type] = calc_model_size_by_data(model)
|
||||
|
87
invokeai/backend/model_management/models/controlnet.py
Normal file
87
invokeai/backend/model_management/models/controlnet.py
Normal file
@ -0,0 +1,87 @@
|
||||
import os
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Literal
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
EmptyConfigLoader,
|
||||
calc_model_size_by_fs,
|
||||
calc_model_size_by_data,
|
||||
classproperty,
|
||||
)
|
||||
|
||||
class ControlNetModel(ModelBase):
|
||||
#model_class: Type
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.ControlNet
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
try:
|
||||
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
||||
#config = json.loads(os.path.join(self.model_path, "config.json"))
|
||||
except:
|
||||
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
|
||||
|
||||
model_class_name = config.get("_class_name", None)
|
||||
if model_class_name not in {"ControlNetModel"}:
|
||||
raise Exception(f"Invalid ControlNet model! Unknown _class_name: {model_class_name}")
|
||||
|
||||
try:
|
||||
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
|
||||
self.model_size = calc_model_size_by_fs(self.model_path)
|
||||
except:
|
||||
raise Exception("Invalid ControlNet model!")
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in controlnet model")
|
||||
return self.model_size
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in controlnet model")
|
||||
|
||||
model = self.model_class.from_pretrained(
|
||||
self.model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
# calc more accurate size
|
||||
self.model_size = calc_model_size_by_data(model)
|
||||
return model
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if os.path.isdir(path):
|
||||
return "diffusers"
|
||||
else:
|
||||
return "checkpoint"
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase, # empty config or config of parent model
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) != "diffusers":
|
||||
raise NotImlemetedError("Checkpoint controlnet models currently unsupported")
|
||||
else:
|
||||
return model_path
|
@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from typing import Optional, Union, Literal
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
@ -15,7 +15,7 @@ class LoRAModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: None
|
||||
format: Union[Literal["lycoris"], Literal["diffusers"]]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Lora
|
||||
|
@ -1,7 +1,7 @@
|
||||
import os
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, Union, Literal
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
@ -23,7 +23,7 @@ class VaeModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: None
|
||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Vae
|
||||
|
Loading…
Reference in New Issue
Block a user