Add controlnet to model manager, fixes

This commit is contained in:
Sergey Borisov 2023-06-14 04:26:21 +03:00
parent 740c05a0bb
commit 6c5954f9d1
5 changed files with 132 additions and 34 deletions

View File

@ -2,13 +2,9 @@ from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfig
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel
from .lora import LoRAModel
#from .controlnet import ControlNetModel # TODO:
from .controlnet import ControlNetModel # TODO:
from .textual_inversion import TextualInversionModel
# TODO:
class ControlNetModel:
pass
MODEL_CLASSES = {
BaseModelType.StableDiffusion1: {
ModelType.Pipeline: StableDiffusion1Model,
@ -36,9 +32,7 @@ MODEL_CLASSES = {
def get_all_model_configs():
configs = set()
for models in MODEL_CLASSES.values():
for type, model in models.items():
if type == ModelType.ControlNet:
continue # TODO:
for _, model in models.items():
configs.update(model._get_configs().values())
configs.discard(None)
return list(configs) # TODO: set, list or tuple

View File

@ -8,8 +8,9 @@ import torch
import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin
from contextlib import suppress
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1"
@ -108,32 +109,45 @@ class ModelBase(metaclass=ABCMeta):
@classmethod
def _get_configs(cls):
if not hasattr(cls, "__configs"):
configs = dict()
for name in dir(cls):
if name.startswith("__"):
continue
with suppress(Exception):
return cls.__configs
configs = dict()
for name in dir(cls):
if name.startswith("__"):
continue
value = getattr(cls, name)
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
continue
value = getattr(cls, name)
if not isinstance(value, type) or not issubclass(value, ModelConfigBase):
continue
fields = inspect.get_annotations(value)
if "format" not in fields:
raise Exception("Invalid config definition - format field not found")
fields = inspect.get_annotations(value)
if "format" not in fields:
raise Exception("Invalid config definition - format field not found")
format_type = typing.get_origin(fields["format"])
if format_type not in {None, Literal}:
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
format_type = typing.get_origin(fields["format"])
if format_type not in {None, Literal, Union}:
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
if format_type is Literal:
format = fields["format"].__args__[0]
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__):
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
if format_type == Union:
f_fields = fields["format"].__args__
else:
f_fields = (fields["format"],)
for field in f_fields:
if field is None:
format_name = None
else:
format = None
configs[format] = value # TODO: error when override(multiple)?
format_name = field.__args__[0]
cls.__configs = configs
configs[format_name] = value # TODO: error when override(multiple)?
cls.__configs = configs
return cls.__configs
@classmethod
@ -237,8 +251,11 @@ class DiffusersModel(ModelBase):
)
break
except Exception as e:
print("====ERR LOAD====")
print(f"{variant}: {e}")
#print("====ERR LOAD====")
#print(f"{variant}: {e}")
pass
else:
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
# calc more accurate size
self.child_sizes[child_type] = calc_model_size_by_data(model)

View File

@ -0,0 +1,87 @@
import os
import torch
from pathlib import Path
from typing import Optional, Union, Literal
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
EmptyConfigLoader,
calc_model_size_by_fs,
calc_model_size_by_data,
classproperty,
)
class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int
class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet
super().__init__(model_path, base_model, model_type)
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
#config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
model_class_name = config.get("_class_name", None)
if model_class_name not in {"ControlNetModel"}:
raise Exception(f"Invalid ControlNet model! Unknown _class_name: {model_class_name}")
try:
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except:
raise Exception("Invalid ControlNet model!")
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in controlnet model")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in controlnet model")
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return "diffusers"
else:
return "checkpoint"
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != "diffusers":
raise NotImlemetedError("Checkpoint controlnet models currently unsupported")
else:
return model_path

View File

@ -1,5 +1,5 @@
import torch
from typing import Optional
from typing import Optional, Union, Literal
from .base import (
ModelBase,
ModelConfigBase,
@ -15,7 +15,7 @@ class LoRAModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
format: None
format: Union[Literal["lycoris"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora

View File

@ -1,7 +1,7 @@
import os
import torch
from pathlib import Path
from typing import Optional
from typing import Optional, Union, Literal
from .base import (
ModelBase,
ModelConfigBase,
@ -23,7 +23,7 @@ class VaeModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
format: None
format: Union[Literal["checkpoint"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae