Merge branch 'main' into lstein/installer-for-new-model-layout

This commit is contained in:
Lincoln Stein
2023-06-23 01:45:05 +01:00
committed by GitHub
136 changed files with 4085 additions and 1013 deletions

View File

@@ -266,6 +268,8 @@ class ModelManager(object):
for model_key, model_config in config.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary
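A minimal sketch of the aliasing this hunk introduces, using a hypothetical DemoConfig rather than the repo's real config classes: on load the YAML key format is popped into the pydantic field model_format, and commit_models() below reverses the rename so models.yaml keeps its original key.

    from typing import Optional
    from pydantic import BaseModel

    class DemoConfig(BaseModel):  # hypothetical stand-in for a model Config class
        path: str
        model_format: Optional[str] = None

    stanza = {"path": "/models/foo", "format": "diffusers"}  # as read from models.yaml

    # load: alias the file key "format" to the field name "model_format"
    stanza["model_format"] = stanza.pop("format")
    cfg = DemoConfig(**stanza)

    # save: reverse the alias so the file keeps its original key
    data = cfg.dict(exclude_defaults=True)
    data["format"] = data.pop("model_format")
    assert data == {"path": "/models/foo", "format": "diffusers"}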
@@ -446,38 +448,6 @@ class ModelManager(object):
_cache = self.cache,
)
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name of the default model, or None
if none is defined.
"""
for model_key, model_config in self.models.items():
if model_config.default:
return self.parse_key(model_key)
for model_key, _ in self.models.items():
return self.parse_key(model_key)
else:
return None # TODO: or redo as (None, None, None)
def set_default_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> None:
"""
Set the default model. The change will not take
effect until you call model_manager.commit()
"""
model_key = self.model_key(model_name, base_model, model_type)
if model_key not in self.models:
raise Exception(f"Unknown model: {model_key}")
for cur_model_key, config in self.models.items():
config.default = cur_model_key == model_key
def model_info(
self,
model_name: str,
@@ -504,9 +474,9 @@ class ModelManager(object):
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> Dict[str, Dict[str, str]]:
) -> list[dict]:
"""
Return a dict of models, in format [base_model][model_type][model_name]
Return a list of models.
Please use model_manager.models() to get all the model names,
model_manager.model_info('model-name') to get the stanza for the model
@@ -514,7 +484,7 @@ class ModelManager(object):
object derived from models.yaml
"""
models = dict()
models = []
for model_key in sorted(self.models, key=str.casefold):
model_config = self.models[model_key]
@@ -524,18 +494,16 @@ class ModelManager(object):
if model_type is not None and cur_model_type != model_type:
continue
if cur_base_model not in models:
models[cur_base_model] = dict()
if cur_model_type not in models[cur_base_model]:
models[cur_base_model][cur_model_type] = dict()
models[cur_base_model][cur_model_type][cur_model_name] = dict(
model_dict = dict(
**model_config.dict(exclude_defaults=True),
# OpenAPIModelInfoBase
name=cur_model_name,
base_model=cur_base_model,
type=cur_model_type,
)
models.append(model_dict)
return models
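The net effect of this hunk: list_models() now returns a flat list of self-describing dicts instead of the old [base_model][model_type][model_name] nesting. A hedged sketch of consuming the new shape (only the name/base_model/type keys are set by the code above; the example values are made up):

    # Old shape required three nested lookups:
    #   models["sd-1"]["main"]["my-model"]
    # New shape is one flat pass:
    models = [
        {"name": "my-model", "base_model": "sd-1", "type": "main"},
        {"name": "my-vae", "base_model": "sd-1", "type": "vae"},
    ]
    vae_names = [m["name"] for m in models if m["type"] == "vae"]
    assert vae_names == ["my-vae"]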
def print_models(self) -> None:
@@ -647,7 +615,9 @@ class ModelManager(object):
model_class = MODEL_CLASSES[base_model][model_type]
if model_class.save_to_config:
# TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True)
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
# alias for config file
data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
yaml_str = OmegaConf.to_yaml(data_to_save)
config_file_path = conf_file or self.config_path

View File

@@ -1,3 +1,7 @@
import inspect
from enum import Enum
from pydantic import BaseModel
from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel
@@ -29,10 +33,63 @@ MODEL_CLASSES = {
#},
}
def get_all_model_configs():
configs = set()
for models in MODEL_CLASSES.values():
for _, model in models.items():
configs.update(model._get_configs().values())
configs.discard(None)
return list(configs) # TODO: set, list or tuple
MODEL_CONFIGS = list()
OPENAPI_MODEL_CONFIGS = list()
class OpenAPIModelInfoBase(BaseModel):
name: str
base_model: BaseModelType
type: ModelType
for base_model, models in MODEL_CLASSES.items():
for model_type, model_class in models.items():
model_configs = set(model_class._get_configs().values())
model_configs.discard(None)
MODEL_CONFIGS.extend(model_configs)
for cfg in model_configs:
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
openapi_cfg_name = model_name + cfg_name
if openapi_cfg_name in vars():
continue
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
__annotations__ = dict(
type=Literal[model_type.value],
),
))
#globals()[openapi_cfg_name] = api_wrapper
vars()[openapi_cfg_name] = api_wrapper
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
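The loop above manufactures one pydantic class per concrete config with a three-argument type() call, then narrows its type annotation to a Literal so the OpenAPI schema can discriminate the union. A minimal sketch of the same trick; InfoBase, DiffusersConfig, and the literal value are hypothetical stand-ins:

    from typing import Literal
    from pydantic import BaseModel

    class InfoBase(BaseModel):           # plays the role of OpenAPIModelInfoBase
        name: str

    class DiffusersConfig(BaseModel):    # plays the role of a concrete cfg
        model_format: str

    # equivalent of: type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), ...)
    wrapper = type(
        "MainDiffusersConfig",                            # generated class name
        (DiffusersConfig, InfoBase),                      # cfg first, info base second
        {"__annotations__": {"type": Literal["main"]}},   # narrowed discriminator field
    )

    obj = wrapper(name="my-model", model_format="diffusers", type="main")
    assert obj.type == "main"

Note that type() delegates class creation to the bases' most derived metaclass, so the generated wrapper is a real pydantic model even though type itself is called.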
def get_model_config_enums():
enums = list()
for model_config in MODEL_CONFIGS:
fields = inspect.get_annotations(model_config)
try:
field = fields["model_format"]
except KeyError:
raise Exception("model_format field not found")
# model_format: None
# model_format: SomeModelFormat
# model_format: Literal[SomeModelFormat.Diffusers]
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
enums.append(field)
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
enums.append(type(field.__args__[0]))
elif field is None:
pass
else:
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
return enums
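get_model_config_enums() has to cope with three annotation shapes for model_format; the sketch below (with a hypothetical Fmt enum) shows how each branch fires:

    import inspect
    from enum import Enum
    from typing import Literal, get_origin

    class Fmt(str, Enum):              # hypothetical format enum
        A = "a"
        B = "b"

    class EnumCfg:
        model_format: Fmt              # whole enum -> collect it directly

    class LiteralCfg:
        model_format: Literal[Fmt.A]   # literal of members -> recover type(args[0])

    class NoFormatCfg:
        model_format: None             # no format variants

    for cfg in (EnumCfg, LiteralCfg, NoFormatCfg):
        field = inspect.get_annotations(cfg)["model_format"]
        if isinstance(field, type) and issubclass(field, Enum):
            print(cfg.__name__, "->", [m.value for m in field])
        elif get_origin(field) is Literal:
            print(cfg.__name__, "->", type(field.__args__[0]).__name__)
        else:
            print(cfg.__name__, "-> no format")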

View File

@@ -48,12 +48,10 @@ class ModelError(str, Enum):
class ModelConfigBase(BaseModel):
path: str # or Path
#name: str # not included as present in model key
description: Optional[str] = Field(None)
format: Optional[str] = Field(None)
default: Optional[bool] = Field(False)
model_format: Optional[str] = Field(None)
# do not save to config
error: Optional[ModelError] = Field(None, exclude=True)
error: Optional[ModelError] = Field(None)
class Config:
use_enum_values = True
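Two details here carry the rest of the migration. Dropping exclude=True from the error field pairs with the explicit exclude={"error"} in commit_models() above, and use_enum_values = True makes pydantic store enum-typed fields as their plain string values. A minimal pydantic-v1 sketch with a hypothetical Fmt enum:

    from enum import Enum
    from typing import Optional
    from pydantic import BaseModel

    class Fmt(str, Enum):
        Diffusers = "diffusers"

    class Demo(BaseModel):                 # hypothetical config
        model_format: Fmt
        error: Optional[str] = None

        class Config:
            use_enum_values = True

    d = Demo(model_format=Fmt.Diffusers, error="boom")
    assert d.model_format == "diffusers"              # stored as the value
    assert "error" not in d.dict(exclude={"error"})   # excluded at the call site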
@@ -94,6 +92,11 @@ class ModelBase(metaclass=ABCMeta):
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!")
if all(t is None for t in subtypes):
return None
elif any(t is None for t in subtypes):
raise Exception(f"Unsupported definition: {subtypes}")
if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:]
@@ -122,46 +125,41 @@ class ModelBase(metaclass=ABCMeta):
continue
fields = inspect.get_annotations(value)
if "format" not in fields:
raise Exception("Invalid config definition - format field not found")
try:
field = fields["model_format"]
except KeyError:
raise Exception(f"Invalid config definition - model_format field not found ({cls.__qualname__})")
format_type = typing.get_origin(fields["format"])
if format_type not in {None, Literal, Union}:
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
for model_format in field:
configs[model_format.value] = value
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__):
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
for model_format in field.__args__:
configs[model_format.value] = value
elif field is None:
configs[None] = value
if format_type == Union:
f_fields = fields["format"].__args__
else:
f_fields = (fields["format"],)
for field in f_fields:
if field is None:
format_name = None
else:
format_name = field.__args__[0]
configs[format_name] = value # TODO: error when override(multiple)?
raise Exception(f"Unsupported format definition in {cls.__qualname__}")
cls.__configs = configs
return cls.__configs
@classmethod
def create_config(cls, **kwargs) -> ModelConfigBase:
if "format" not in kwargs:
raise Exception("Field 'format' not found in model config")
if "model_format" not in kwargs:
raise Exception("Field 'model_format' not found in model config")
configs = cls._get_configs()
return configs[kwargs["format"]](**kwargs)
return configs[kwargs["model_format"]](**kwargs)
@classmethod
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
return cls.create_config(
path=path,
format=cls.detect_format(path),
model_format=cls.detect_format(path),
)
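Taken together, _get_configs() builds a registry from format value to Config class, and create_config()/probe_config() dispatch through it, with detect_format() choosing the key. A simplified sketch under those assumptions (Fmt and the two config classes are hypothetical):

    import os
    from enum import Enum
    from pydantic import BaseModel

    class Fmt(str, Enum):
        Checkpoint = "checkpoint"
        Diffusers = "diffusers"

    class DiffusersConfig(BaseModel):
        path: str
        model_format: Fmt

    class CheckpointConfig(BaseModel):
        path: str
        model_format: Fmt

    # the registry _get_configs() effectively builds: format value -> config class
    CONFIGS = {Fmt.Diffusers.value: DiffusersConfig, Fmt.Checkpoint.value: CheckpointConfig}

    def detect_format(path: str) -> Fmt:
        return Fmt.Diffusers if os.path.isdir(path) else Fmt.Checkpoint

    def probe_config(path: str) -> BaseModel:
        fmt = detect_format(path)
        # Fmt subclasses str, so the enum member indexes the string-keyed dict
        return CONFIGS[fmt](path=path, model_format=fmt)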
@classmethod

View File

@@ -1,5 +1,6 @@
import os
import torch
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from .base import (
@@ -14,12 +15,16 @@ from .base import (
classproperty,
)
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int
class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]]
model_format: ControlNetModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet
@@ -69,9 +74,9 @@ class ControlNetModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return "diffusers"
return ControlNetModelFormat.Diffusers
else:
return "checkpoint"
return ControlNetModelFormat.Checkpoint
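Because ControlNetModelFormat subclasses str, its members compare and hash equal to the old literal strings, so existing call sites that still pass "diffusers" or "checkpoint" keep working during the migration. A quick check:

    from enum import Enum

    class ControlNetModelFormat(str, Enum):
        Checkpoint = "checkpoint"
        Diffusers = "diffusers"

    assert ControlNetModelFormat.Diffusers == "diffusers"
    assert "checkpoint" == ControlNetModelFormat.Checkpoint
    assert ControlNetModelFormat("diffusers") is ControlNetModelFormat.Diffusers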
@classmethod
def convert_if_required(
@@ -81,7 +86,7 @@ class ControlNetModel(ModelBase):
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != "diffusers":
raise NotImlemetedError("Checkpoint controlnet models currently unsupported")
if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
else:
return model_path

View File

@@ -1,5 +1,6 @@
import os
import torch
from enum import Enum
from typing import Optional, Union, Literal
from .base import (
ModelBase,
@@ -12,11 +13,15 @@ from .base import (
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
Diffusers = "diffusers"
class LoRAModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
format: Union[Literal["lycoris"], Literal["diffusers"]]
model_format: LoRAModelFormat # TODO:
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora
@@ -52,9 +57,9 @@ class LoRAModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return "diffusers"
return LoRAModelFormat.Diffusers
else:
return "lycoris"
return LoRAModelFormat.LyCORIS
@classmethod
def convert_if_required(
@@ -64,7 +69,7 @@ class LoRAModel(ModelBase):
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == "diffusers":
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
# TODO: add diffusers lora when it stabilizes a bit
raise NotImplementedError("Diffusers lora not supported")
else:

View File

@@ -1,5 +1,6 @@
import os
import json
from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
@@ -19,16 +20,19 @@ from .base import (
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"]
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
@@ -47,7 +51,7 @@ class StableDiffusion1Model(DiffusersModel):
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == "checkpoint":
if model_format == StableDiffusion1ModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
@@ -57,7 +61,7 @@ class StableDiffusion1Model(DiffusersModel):
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == "diffusers":
elif model_format == StableDiffusion1ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
@@ -80,7 +84,7 @@ class StableDiffusion1Model(DiffusersModel):
return cls.create_config(
path=path,
format=model_format,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
@@ -93,9 +97,9 @@ class StableDiffusion1Model(DiffusersModel):
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return "diffusers"
return StableDiffusion1ModelFormat.Diffusers
else:
return "checkpoint"
return StableDiffusion1ModelFormat.Checkpoint
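Each Literal[StableDiffusion1ModelFormat.X] annotation pins one enum member to one Config class, so pydantic itself rejects a stanza whose model_format does not match the class being constructed. A hedged sketch (SD1Fmt and DiffusersConfig here are bare stand-ins, not the repo's classes):

    from enum import Enum
    from typing import Literal
    from pydantic import BaseModel, ValidationError

    class SD1Fmt(str, Enum):       # hypothetical analogue of StableDiffusion1ModelFormat
        Checkpoint = "checkpoint"
        Diffusers = "diffusers"

    class DiffusersConfig(BaseModel):
        model_format: Literal[SD1Fmt.Diffusers]

    DiffusersConfig(model_format=SD1Fmt.Diffusers)       # accepted
    try:
        DiffusersConfig(model_format=SD1Fmt.Checkpoint)  # wrong format
    except ValidationError:
        print("rejected, as expected")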
@classmethod
def convert_if_required(
@@ -116,19 +120,22 @@ class StableDiffusion1Model(DiffusersModel):
else:
return model_path
class StableDiffusion2ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs are overwritten properly
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"]
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
@@ -149,7 +156,7 @@ class StableDiffusion2Model(DiffusersModel):
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == "checkpoint":
if model_format == StableDiffusion2ModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
@@ -159,7 +166,7 @@ class StableDiffusion2Model(DiffusersModel):
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == "diffusers":
elif model_format == StableDiffusion2ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
@@ -191,7 +198,7 @@ class StableDiffusion2Model(DiffusersModel):
return cls.create_config(
path=path,
format=model_format,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
@@ -206,9 +213,9 @@ class StableDiffusion2Model(DiffusersModel):
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return "diffusers"
return StableDiffusion2ModelFormat.Diffusers
else:
return "checkpoint"
return StableDiffusion2ModelFormat.Checkpoint
@classmethod
def convert_if_required(
@@ -281,8 +288,8 @@ def _convert_ckpt_and_cache(
prediction_type = SchedulerPredictionType.Epsilon
elif version == BaseModelType.StableDiffusion2:
upcast_attention = config.upcast_attention
prediction_type = config.prediction_type
upcast_attention = model_config.upcast_attention
prediction_type = model_config.prediction_type
else:
raise Exception(f"Unknown model provided: {version}")

View File

@@ -16,7 +16,7 @@ class TextualInversionModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
format: None
model_format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.TextualInversion

View File

@@ -1,5 +1,7 @@
import os
import torch
import safetensors
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from .base import (
@@ -18,12 +20,16 @@ from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
class VaeModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class VaeModel(ModelBase):
#vae_class: Type
#model_size: int
class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]]
model_format: VaeModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae
@@ -70,9 +76,9 @@ class VaeModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return "diffusers"
return VaeModelFormat.Diffusers
else:
return "checkpoint"
return VaeModelFormat.Checkpoint
@classmethod
def convert_if_required(
@@ -82,7 +88,7 @@ class VaeModel(ModelBase):
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != "diffusers":
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
return _convert_vae_ckpt_and_cache(
weights_path=model_path,
output_path=output_path,