mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add controlnet to model manager, fixes
This commit is contained in:
parent
740c05a0bb
commit
6c5954f9d1
@ -2,13 +2,9 @@ from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfig
|
|||||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||||
from .vae import VaeModel
|
from .vae import VaeModel
|
||||||
from .lora import LoRAModel
|
from .lora import LoRAModel
|
||||||
#from .controlnet import ControlNetModel # TODO:
|
from .controlnet import ControlNetModel # TODO:
|
||||||
from .textual_inversion import TextualInversionModel
|
from .textual_inversion import TextualInversionModel
|
||||||
|
|
||||||
# TODO:
|
|
||||||
class ControlNetModel:
|
|
||||||
pass
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
BaseModelType.StableDiffusion1: {
|
BaseModelType.StableDiffusion1: {
|
||||||
ModelType.Pipeline: StableDiffusion1Model,
|
ModelType.Pipeline: StableDiffusion1Model,
|
||||||
@ -36,9 +32,7 @@ MODEL_CLASSES = {
|
|||||||
def get_all_model_configs():
|
def get_all_model_configs():
|
||||||
configs = set()
|
configs = set()
|
||||||
for models in MODEL_CLASSES.values():
|
for models in MODEL_CLASSES.values():
|
||||||
for type, model in models.items():
|
for _, model in models.items():
|
||||||
if type == ModelType.ControlNet:
|
|
||||||
continue # TODO:
|
|
||||||
configs.update(model._get_configs().values())
|
configs.update(model._get_configs().values())
|
||||||
configs.discard(None)
|
configs.discard(None)
|
||||||
return list(configs) # TODO: set, list or tuple
|
return list(configs) # TODO: set, list or tuple
|
||||||
|
@ -8,8 +8,9 @@ import torch
|
|||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
from diffusers import DiffusionPipeline, ConfigMixin
|
from diffusers import DiffusionPipeline, ConfigMixin
|
||||||
|
|
||||||
|
from contextlib import suppress
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any
|
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||||
|
|
||||||
class BaseModelType(str, Enum):
|
class BaseModelType(str, Enum):
|
||||||
StableDiffusion1 = "sd-1"
|
StableDiffusion1 = "sd-1"
|
||||||
@ -108,7 +109,9 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_configs(cls):
|
def _get_configs(cls):
|
||||||
if not hasattr(cls, "__configs"):
|
with suppress(Exception):
|
||||||
|
return cls.__configs
|
||||||
|
|
||||||
configs = dict()
|
configs = dict()
|
||||||
for name in dir(cls):
|
for name in dir(cls):
|
||||||
if name.startswith("__"):
|
if name.startswith("__"):
|
||||||
@ -123,17 +126,28 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
raise Exception("Invalid config definition - format field not found")
|
raise Exception("Invalid config definition - format field not found")
|
||||||
|
|
||||||
format_type = typing.get_origin(fields["format"])
|
format_type = typing.get_origin(fields["format"])
|
||||||
if format_type not in {None, Literal}:
|
if format_type not in {None, Literal, Union}:
|
||||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
||||||
|
|
||||||
if format_type is Literal:
|
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__):
|
||||||
format = fields["format"].__args__[0]
|
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
||||||
|
|
||||||
|
if format_type == Union:
|
||||||
|
f_fields = fields["format"].__args__
|
||||||
else:
|
else:
|
||||||
format = None
|
f_fields = (fields["format"],)
|
||||||
configs[format] = value # TODO: error when override(multiple)?
|
|
||||||
|
|
||||||
|
for field in f_fields:
|
||||||
|
if field is None:
|
||||||
|
format_name = None
|
||||||
|
else:
|
||||||
|
format_name = field.__args__[0]
|
||||||
|
|
||||||
|
configs[format_name] = value # TODO: error when override(multiple)?
|
||||||
|
|
||||||
|
|
||||||
cls.__configs = configs
|
cls.__configs = configs
|
||||||
|
|
||||||
return cls.__configs
|
return cls.__configs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -237,8 +251,11 @@ class DiffusersModel(ModelBase):
|
|||||||
)
|
)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("====ERR LOAD====")
|
#print("====ERR LOAD====")
|
||||||
print(f"{variant}: {e}")
|
#print(f"{variant}: {e}")
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
|
||||||
|
|
||||||
# calc more accurate size
|
# calc more accurate size
|
||||||
self.child_sizes[child_type] = calc_model_size_by_data(model)
|
self.child_sizes[child_type] = calc_model_size_by_data(model)
|
||||||
|
87
invokeai/backend/model_management/models/controlnet.py
Normal file
87
invokeai/backend/model_management/models/controlnet.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union, Literal
|
||||||
|
from .base import (
|
||||||
|
ModelBase,
|
||||||
|
ModelConfigBase,
|
||||||
|
BaseModelType,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
EmptyConfigLoader,
|
||||||
|
calc_model_size_by_fs,
|
||||||
|
calc_model_size_by_data,
|
||||||
|
classproperty,
|
||||||
|
)
|
||||||
|
|
||||||
|
class ControlNetModel(ModelBase):
|
||||||
|
#model_class: Type
|
||||||
|
#model_size: int
|
||||||
|
|
||||||
|
class Config(ModelConfigBase):
|
||||||
|
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||||
|
|
||||||
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
|
assert model_type == ModelType.ControlNet
|
||||||
|
super().__init__(model_path, base_model, model_type)
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
|
||||||
|
#config = json.loads(os.path.join(self.model_path, "config.json"))
|
||||||
|
except:
|
||||||
|
raise Exception("Invalid controlnet model! (config.json not found or invalid)")
|
||||||
|
|
||||||
|
model_class_name = config.get("_class_name", None)
|
||||||
|
if model_class_name not in {"ControlNetModel"}:
|
||||||
|
raise Exception(f"Invalid ControlNet model! Unknown _class_name: {model_class_name}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
|
||||||
|
self.model_size = calc_model_size_by_fs(self.model_path)
|
||||||
|
except:
|
||||||
|
raise Exception("Invalid ControlNet model!")
|
||||||
|
|
||||||
|
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||||
|
if child_type is not None:
|
||||||
|
raise Exception("There is no child models in controlnet model")
|
||||||
|
return self.model_size
|
||||||
|
|
||||||
|
def get_model(
|
||||||
|
self,
|
||||||
|
torch_dtype: Optional[torch.dtype],
|
||||||
|
child_type: Optional[SubModelType] = None,
|
||||||
|
):
|
||||||
|
if child_type is not None:
|
||||||
|
raise Exception("There is no child models in controlnet model")
|
||||||
|
|
||||||
|
model = self.model_class.from_pretrained(
|
||||||
|
self.model_path,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
)
|
||||||
|
# calc more accurate size
|
||||||
|
self.model_size = calc_model_size_by_data(model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
@classproperty
|
||||||
|
def save_to_config(cls) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def detect_format(cls, path: str):
|
||||||
|
if os.path.isdir(path):
|
||||||
|
return "diffusers"
|
||||||
|
else:
|
||||||
|
return "checkpoint"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_if_required(
|
||||||
|
cls,
|
||||||
|
model_path: str,
|
||||||
|
output_path: str,
|
||||||
|
config: ModelConfigBase, # empty config or config of parent model
|
||||||
|
base_model: BaseModelType,
|
||||||
|
) -> str:
|
||||||
|
if cls.detect_format(model_path) != "diffusers":
|
||||||
|
raise NotImlemetedError("Checkpoint controlnet models currently unsupported")
|
||||||
|
else:
|
||||||
|
return model_path
|
@ -1,5 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
from typing import Optional
|
from typing import Optional, Union, Literal
|
||||||
from .base import (
|
from .base import (
|
||||||
ModelBase,
|
ModelBase,
|
||||||
ModelConfigBase,
|
ModelConfigBase,
|
||||||
@ -15,7 +15,7 @@ class LoRAModel(ModelBase):
|
|||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: None
|
format: Union[Literal["lycoris"], Literal["diffusers"]]
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.Lora
|
assert model_type == ModelType.Lora
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional, Union, Literal
|
||||||
from .base import (
|
from .base import (
|
||||||
ModelBase,
|
ModelBase,
|
||||||
ModelConfigBase,
|
ModelConfigBase,
|
||||||
@ -23,7 +23,7 @@ class VaeModel(ModelBase):
|
|||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: None
|
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.Vae
|
assert model_type == ModelType.Vae
|
||||||
|
Loading…
Reference in New Issue
Block a user