Add controlnet to model manager, fixes

This commit is contained in:
Sergey Borisov 2023-06-14 04:26:21 +03:00
parent 740c05a0bb
commit 6c5954f9d1
5 changed files with 132 additions and 34 deletions

View File

@ -2,13 +2,9 @@ from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfig
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel from .vae import VaeModel
from .lora import LoRAModel from .lora import LoRAModel
#from .controlnet import ControlNetModel # TODO: from .controlnet import ControlNetModel # TODO:
from .textual_inversion import TextualInversionModel from .textual_inversion import TextualInversionModel
# TODO:
class ControlNetModel:
pass
MODEL_CLASSES = { MODEL_CLASSES = {
BaseModelType.StableDiffusion1: { BaseModelType.StableDiffusion1: {
ModelType.Pipeline: StableDiffusion1Model, ModelType.Pipeline: StableDiffusion1Model,
@ -36,9 +32,7 @@ MODEL_CLASSES = {
def get_all_model_configs(): def get_all_model_configs():
configs = set() configs = set()
for models in MODEL_CLASSES.values(): for models in MODEL_CLASSES.values():
for type, model in models.items(): for _, model in models.items():
if type == ModelType.ControlNet:
continue # TODO:
configs.update(model._get_configs().values()) configs.update(model._get_configs().values())
configs.discard(None) configs.discard(None)
return list(configs) # TODO: set, list or tuple return list(configs) # TODO: set, list or tuple

View File

@ -8,8 +8,9 @@ import torch
import safetensors.torch import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin from diffusers import DiffusionPipeline, ConfigMixin
from contextlib import suppress
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
class BaseModelType(str, Enum): class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1" StableDiffusion1 = "sd-1"
@ -108,7 +109,9 @@ class ModelBase(metaclass=ABCMeta):
@classmethod @classmethod
def _get_configs(cls): def _get_configs(cls):
if not hasattr(cls, "__configs"): with suppress(Exception):
return cls.__configs
configs = dict() configs = dict()
for name in dir(cls): for name in dir(cls):
if name.startswith("__"): if name.startswith("__"):
@ -123,17 +126,28 @@ class ModelBase(metaclass=ABCMeta):
raise Exception("Invalid config definition - format field not found") raise Exception("Invalid config definition - format field not found")
format_type = typing.get_origin(fields["format"]) format_type = typing.get_origin(fields["format"])
if format_type not in {None, Literal}: if format_type not in {None, Literal, Union}:
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}") raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
if format_type is Literal: if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__):
format = fields["format"].__args__[0] raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
if format_type == Union:
f_fields = fields["format"].__args__
else: else:
format = None f_fields = (fields["format"],)
configs[format] = value # TODO: error when override(multiple)?
for field in f_fields:
if field is None:
format_name = None
else:
format_name = field.__args__[0]
configs[format_name] = value # TODO: error when override(multiple)?
cls.__configs = configs cls.__configs = configs
return cls.__configs return cls.__configs
@classmethod @classmethod
@ -237,8 +251,11 @@ class DiffusersModel(ModelBase):
) )
break break
except Exception as e: except Exception as e:
print("====ERR LOAD====") #print("====ERR LOAD====")
print(f"{variant}: {e}") #print(f"{variant}: {e}")
pass
else:
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
# calc more accurate size # calc more accurate size
self.child_sizes[child_type] = calc_model_size_by_data(model) self.child_sizes[child_type] = calc_model_size_by_data(model)

View File

@ -0,0 +1,87 @@
import os
import torch
from pathlib import Path
from typing import Optional, Union, Literal
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
EmptyConfigLoader,
calc_model_size_by_fs,
calc_model_size_by_data,
classproperty,
)
class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int
class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet
super().__init__(model_path, base_model, model_type)
try:
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
#config = json.loads(os.path.join(self.model_path, "config.json"))
except:
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
model_class_name = config.get("_class_name", None)
if model_class_name not in {"ControlNetModel"}:
raise Exception(f"Invalid ControlNet model! Unknown _class_name: {model_class_name}")
try:
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
except:
raise Exception("Invalid ControlNet model!")
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise Exception("There is no child models in controlnet model")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
):
if child_type is not None:
raise Exception("There is no child models in controlnet model")
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
)
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return "diffusers"
else:
return "checkpoint"
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != "diffusers":
raise NotImlemetedError("Checkpoint controlnet models currently unsupported")
else:
return model_path

View File

@ -1,5 +1,5 @@
import torch import torch
from typing import Optional from typing import Optional, Union, Literal
from .base import ( from .base import (
ModelBase, ModelBase,
ModelConfigBase, ModelConfigBase,
@ -15,7 +15,7 @@ class LoRAModel(ModelBase):
#model_size: int #model_size: int
class Config(ModelConfigBase): class Config(ModelConfigBase):
format: None format: Union[Literal["lycoris"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora assert model_type == ModelType.Lora

View File

@ -1,7 +1,7 @@
import os import os
import torch import torch
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, Union, Literal
from .base import ( from .base import (
ModelBase, ModelBase,
ModelConfigBase, ModelConfigBase,
@ -23,7 +23,7 @@ class VaeModel(ModelBase):
#model_size: int #model_size: int
class Config(ModelConfigBase): class Config(ModelConfigBase):
format: None format: Union[Literal["checkpoint"], Literal["diffusers"]]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae assert model_type == ModelType.Vae