Fixes, add sd variant detection

This commit is contained in:
Sergey Borisov 2023-06-12 05:52:30 +03:00
parent 893f776f1d
commit 9fa78443de
5 changed files with 143 additions and 19 deletions

View File

@@ -3,7 +3,7 @@
 from typing import Any
 from invokeai.app.models.image import ProgressImage
 from invokeai.app.util.misc import get_timestamp
-from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, SDModelInfo
+from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
 from invokeai.app.models.exceptions import CanceledException

 class EventServiceBase:
@@ -136,7 +136,7 @@ class EventServiceBase:
         base_model: BaseModelType,
         model_type: ModelType,
         submodel: SubModelType,
-        model_info: SDModelInfo,
+        model_info: ModelInfo,
     ) -> None:
         """Emitted when a model is correctly loaded (returns model info)"""
         self.__emit_session_event(

View File

@@ -13,7 +13,7 @@ from invokeai.backend.model_management.model_manager import (
     BaseModelType,
     ModelType,
     SubModelType,
-    SDModelInfo,
+    ModelInfo,
 )
 from invokeai.app.models.exceptions import CanceledException
 from .config import InvokeAIAppConfig
@@ -49,7 +49,7 @@ class ModelManagerServiceBase(ABC):
         submodel: Optional[SubModelType] = None,
         node: Optional[BaseInvocation] = None,
         context: Optional[InvocationContext] = None,
-    ) -> SDModelInfo:
+    ) -> ModelInfo:
        """Retrieve the indicated model with name and type.
        submodel can be used to get a part (such as the vae)
        of a diffusers pipeline."""
@@ -302,7 +302,7 @@ class ModelManagerService(ModelManagerServiceBase):
         submodel: Optional[SubModelType] = None,
         node: Optional[BaseInvocation] = None,
         context: Optional[InvocationContext] = None,
-    ) -> SDModelInfo:
+    ) -> ModelInfo:
         """
         Retrieve the indicated model. submodel can be used to get a
         part (such as the vae) of a diffusers model.
@@ -539,7 +539,7 @@ class ModelManagerService(ModelManagerServiceBase):
         base_model: BaseModelType,
         model_type: ModelType,
         submodel: SubModelType,
-        model_info: Optional[SDModelInfo] = None,
+        model_info: Optional[ModelInfo] = None,
     ):
         if context.services.queue.is_canceled(context.graph_execution_state_id):
             raise CanceledException()
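
A minimal usage sketch for the renamed API, assuming the leading parameters of get_model() are model_name/base_model/model_type as the docstring implies, that SubModelType has a Vae member, and that ModelInfo.context (a ModelLocker) acts as a context manager; the model name is illustrative:

# Hedged sketch, not part of this commit: fetch the VAE of a loaded pipeline.
model_info = context.services.model_manager.get_model(
    model_name="stable-diffusion-v1-5",          # illustrative
    base_model=BaseModelType.StableDiffusion1_5,
    model_type=ModelType.Pipeline,
    submodel=SubModelType.Vae,                   # a part of the diffusers pipeline
    context=context,
)
with model_info.context as vae:                  # ModelLocker assumed to manage load/offload
    ...                                          # use the loaded VAE here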

View File

@@ -166,22 +166,13 @@ import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.util import CUDA_DEVICE, download_with_resume
 from .model_cache import ModelCache, ModelLocker
-from .models import BaseModelType, SubModelType, MODEL_CLASSES
+from .models import BaseModelType, ModelType, SubModelType, ModelError, ModelConfigBase, MODEL_CLASSES

 # We are only starting to number the config file with release 3.
 # The config file version doesn't have to start at release version, but it will help
 # reduce confusion.
 CONFIG_FILE_VERSION='3.0.0'

-# temporary forward definitions to avoid circular import errors.
-class ModelLocker(object):
-    "Forward declaration"
-    pass
-
-class ModelCache(object):
-    "Forward declaration"
-    pass

 @dataclass
 class ModelInfo():
     context: ModelLocker
@@ -744,3 +735,37 @@ class ModelManager(object):
         resolved_path = self.globals.root_dir / source
         return resolved_path

+    def scan_models_directory(self):
+        loaded_files = set()
+
+        # Pass 1: check that every registered model still exists on disk.
+        for model_key, model_config in list(self.models.items()):
+            model_name, base_model, model_type = self.parse_key(model_key)
+            if not os.path.exists(model_config.path):
+                model_class = MODEL_CLASSES[base_model][model_type]
+                if model_class.save_to_config:
+                    model_config.error = ModelError.NotFound
+                else:
+                    self.models.pop(model_key, None)
+            else:
+                loaded_files.add(model_config.path)
+
+        # Pass 2: register any models found on disk that are not yet known.
+        for base_model in BaseModelType:
+            for model_type in ModelType:
+                model_class = MODEL_CLASSES[base_model][model_type]
+                models_dir = os.path.join(self.globals.models_path, base_model, model_type)
+
+                if not os.path.exists(models_dir):
+                    continue  # TODO: or create all folders?
+
+                for entry_name in os.listdir(models_dir):
+                    model_path = os.path.join(models_dir, entry_name)
+                    if model_path not in loaded_files:  # TODO: check
+                        model_name = Path(model_path).stem
+                        model_key = self.create_key(model_name, base_model, model_type)
+
+                        if model_key in self.models:
+                            raise Exception(f"Model with key {model_key} added twice")
+
+                        model_config: ModelConfigBase = model_class.build_config(
+                            path=model_path,
+                        )
+                        self.models[model_key] = model_config
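
The scanner fixes the on-disk contract: one folder per base model, one subfolder per model type, with each entry's stem becoming the model name. A hedged illustration (actual folder names come from the enum string values, which this diff does not show):

models_path/
    <base_model>/              # e.g. the BaseModelType.StableDiffusion1_5 value
        <model_type>/          # e.g. the ModelType.Vae value
            my-vae.pt          # -> model_name "my-vae" (Path(...).stem)
            some-pipeline/     # directories work too; the stem is the folder name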

View File

@@ -5,26 +5,30 @@ from .lora import LoRAModel
 #from .controlnet import ControlNetModel # TODO:
 from .textual_inversion import TextualInversionModel

+# TODO:
+class ControlNetModel:
+    pass
+
 MODEL_CLASSES = {
     BaseModelType.StableDiffusion1_5: {
         ModelType.Pipeline: StableDiffusion15Model,
         ModelType.Vae: VaeModel,
         ModelType.Lora: LoRAModel,
-        #ModelType.ControlNet: ControlNetModel,
+        ModelType.ControlNet: ControlNetModel,
         ModelType.TextualInversion: TextualInversionModel,
     },
     BaseModelType.StableDiffusion2: {
         ModelType.Pipeline: StableDiffusion2Model,
         ModelType.Vae: VaeModel,
         ModelType.Lora: LoRAModel,
-        #ModelType.ControlNet: ControlNetModel,
+        ModelType.ControlNet: ControlNetModel,
         ModelType.TextualInversion: TextualInversionModel,
     },
     BaseModelType.StableDiffusion2Base: {
         ModelType.Pipeline: StableDiffusion2BaseModel,
         ModelType.Vae: VaeModel,
         ModelType.Lora: LoRAModel,
-        #ModelType.ControlNet: ControlNetModel,
+        ModelType.ControlNet: ControlNetModel,
         ModelType.TextualInversion: TextualInversionModel,
     },
     #BaseModelType.Kandinsky2_1: {
@@ -35,3 +39,11 @@ MODEL_CLASSES = {
     #    ModelType.TextualInversion: TextualInversionModel,
     #},
 }
+
+# TODO: check with openapi annotation
+def get_all_model_configs():
+    configs = []
+    for models in MODEL_CLASSES.values():
+        for model in models.values():
+            configs.extend(model._get_configs())
+    return configs
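
The "openapi annotation" TODO suggests where this helper is headed: the collected config classes can be folded into one typing.Union so pydantic/FastAPI can expose every model format in the schema. A hedged sketch of that consumer (AnyModelConfig is an illustrative name, not from the repo):

from typing import Union

# typing.Union accepts a tuple subscript at runtime, so the list of config
# classes gathered by get_all_model_configs() becomes a single annotation.
AnyModelConfig = Union[tuple(get_all_model_configs())]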

View File

@@ -1,5 +1,7 @@
 import os
+import json
 import torch
+import safetensors.torch
 from pydantic import Field
 from typing import Literal, Optional
 from .base import (
@@ -8,10 +10,13 @@ from .base import (
     BaseModelType,
     ModelType,
     SubModelType,
+    VariantType,
     DiffusersModel,
 )
+from omegaconf import OmegaConf
 from invokeai.app.services.config import InvokeAIAppConfig

+ModelVariantType = VariantType  # TODO:
+
 # TODO: how to name properly
 class StableDiffusion15Model(DiffusersModel):
@@ -20,11 +25,13 @@ class StableDiffusion15Model(DiffusersModel):
     class DiffusersConfig(ModelConfigBase):
         format: Literal["diffusers"]
         vae: Optional[str] = Field(None)
+        variant: ModelVariantType

     class CheckpointConfig(ModelConfigBase):
         format: Literal["checkpoint"]
         vae: Optional[str] = Field(None)
         config: Optional[str] = Field(None)
+        variant: ModelVariantType

     def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
@@ -36,6 +43,86 @@ class StableDiffusion15Model(DiffusersModel):
             model_type=ModelType.Pipeline,
         )

+    @staticmethod
+    def _fast_safetensors_reader(path: str):
+        checkpoint = dict()
+        # Allocate tensors on the "meta" device: shapes/dtypes only, no data is read.
+        device = torch.device("meta")
+        with open(path, "rb") as f:
+            definition_len = int.from_bytes(f.read(8), 'little')
+            definition_json = f.read(definition_len)
+            definition = json.loads(definition_json)
+
+            if "__metadata__" in definition and \
+               definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
+                raise Exception("Only PyTorch safetensors files are supported")
+
+            definition.pop("__metadata__", None)
+
+            for key, info in definition.items():
+                dtype = {
+                    "I8": torch.int8,
+                    "I16": torch.int16,
+                    "I32": torch.int32,
+                    "I64": torch.int64,
+                    "F16": torch.float16,
+                    "F32": torch.float32,
+                    "F64": torch.float64,
+                }[info["dtype"]]
+
+                checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
+
+        return checkpoint
+    @classmethod
+    def read_checkpoint_meta(cls, path: str):
+        if path.endswith(".safetensors"):
+            try:
+                checkpoint = cls._fast_safetensors_reader(path)
+            except Exception:
+                # Fall back to the reference loader if the fast path fails.
+                checkpoint = safetensors.torch.load_file(path, device="cpu")  # TODO: create issue for support "meta"?
+        else:
+            checkpoint = torch.load(path, map_location=torch.device("meta"))
+        return checkpoint
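
Both readers above return shape/dtype metadata without paying for weight I/O. The fast path relies on the safetensors container layout: the first 8 bytes are a little-endian unsigned 64-bit length N, followed by N bytes of JSON mapping each tensor name to its dtype, shape, and byte offsets; the raw tensor data that follows is never touched, which is why tensors can live on torch's zero-memory "meta" device. A self-contained sketch that inspects a header the same way (file name illustrative):

import json

with open("model.safetensors", "rb") as f:
    header_len = int.from_bytes(f.read(8), "little")    # u64, little-endian
    header = json.loads(f.read(header_len))             # name -> {"dtype", "shape", "data_offsets"}

header.pop("__metadata__", None)                        # optional free-form metadata entry
print({name: (t["dtype"], t["shape"]) for name, t in header.items()})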
+    @classmethod
+    def build_config(cls, **kwargs):
+        if "format" not in kwargs:
+            kwargs["format"] = cls.detect_format(kwargs["path"])
+
+        if "variant" not in kwargs:
+            if kwargs["format"] == "checkpoint":
+                if "config" in kwargs:
+                    ckpt_config = OmegaConf.load(kwargs["config"])
+                    in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
+                else:
+                    checkpoint = cls.read_checkpoint_meta(kwargs["path"])
+                    checkpoint = checkpoint.get('state_dict', checkpoint)
+                    in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
+
+            elif kwargs["format"] == "diffusers":
+                unet_config_path = os.path.join(kwargs["path"], "unet", "config.json")
+                if os.path.exists(unet_config_path):
+                    with open(unet_config_path, "r") as f:
+                        unet_config = json.load(f)
+                    in_channels = unet_config['in_channels']
+                else:
+                    raise Exception("Unsupported stable diffusion diffusers format (possibly ONNX?)")
+
+            else:
+                raise NotImplementedError(f"Unknown stable diffusion format: {kwargs['format']}")
+
+            if in_channels == 9:
+                kwargs["variant"] = ModelVariantType.Inpaint
+            elif in_channels == 5:
+                kwargs["variant"] = ModelVariantType.Depth
+            elif in_channels == 4:
+                kwargs["variant"] = ModelVariantType.Normal
+            else:
+                raise Exception("Unknown stable diffusion model format")
+
+        return super().build_config(**kwargs)
+
     @classmethod
     def save_to_config(cls) -> bool:
         return True
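
The variant detection rests on one property of Stable Diffusion UNets: the input channel count of the first conv layer (model.diffusion_model.input_blocks.0.0.weight, dimension 1) identifies the variant. Standard txt2img models take 4 latent channels; depth-conditioned models concatenate one extra depth channel for 5; inpainting models concatenate a mask and the masked-image latents for 9. A hedged standalone sketch of the same check against a diffusers folder (function name illustrative):

import json
import os

def detect_sd_variant(model_dir: str) -> str:
    # Assumes a standard diffusers layout with unet/config.json present.
    with open(os.path.join(model_dir, "unet", "config.json")) as f:
        in_channels = json.load(f)["in_channels"]
    return {4: "normal", 5: "depth", 9: "inpaint"}.get(in_channels, "unknown")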