Fixes, add sd variant detection

Sergey Borisov 2023-06-12 05:52:30 +03:00
parent 893f776f1d
commit 9fa78443de
5 changed files with 143 additions and 19 deletions

View File

@@ -3,7 +3,7 @@
from typing import Any
from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, SDModelInfo
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
from invokeai.app.models.exceptions import CanceledException
class EventServiceBase:
@@ -136,7 +136,7 @@ class EventServiceBase:
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
model_info: SDModelInfo,
model_info: ModelInfo,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_session_event(
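For orientation, a hedged sketch of the matching call site in the model manager service; the emitter's full signature is cut off by the hunk above, so the method name emit_model_load_completed and the session-id argument are assumptions:

context.services.events.emit_model_load_completed(  # name assumed from the docstring
    graph_execution_state_id=context.graph_execution_state_id,
    model_name=model_name,
    base_model=base_model,
    model_type=model_type,
    submodel=submodel,
    model_info=model_info,
)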

View File

@@ -13,7 +13,7 @@ from invokeai.backend.model_management.model_manager import (
BaseModelType,
ModelType,
SubModelType,
SDModelInfo,
ModelInfo,
)
from invokeai.app.models.exceptions import CanceledException
from .config import InvokeAIAppConfig
@@ -49,7 +49,7 @@ class ModelManagerServiceBase(ABC):
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> SDModelInfo:
) -> ModelInfo:
"""Retrieve the indicated model with name and type.
submodel can be used to get a part (such as the vae)
of a diffusers pipeline."""
@@ -302,7 +302,7 @@ class ModelManagerService(ModelManagerServiceBase):
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> SDModelInfo:
) -> ModelInfo:
"""
Retrieve the indicated model. submodel can be used to get a
part (such as the vae) of a diffusers model.
@@ -539,7 +539,7 @@ class ModelManagerService(ModelManagerServiceBase):
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
model_info: Optional[SDModelInfo] = None,
model_info: Optional[ModelInfo] = None,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException()
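For orientation, a hedged usage sketch from an invocation's point of view; the model name is a placeholder, and treating model_info.context as a context manager assumes it is the ModelLocker shown later in this diff:

model_info = context.services.model_manager.get_model(
    model_name="stable-diffusion-v1-5",  # placeholder
    base_model=BaseModelType.StableDiffusion1_5,
    model_type=ModelType.Pipeline,
    submodel=SubModelType.Vae,
    context=context,
)
with model_info.context as vae:   # locks the model in the cache while in use
    result = vae.decode(latents)  # hypothetical consumer code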

View File

@@ -166,22 +166,13 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
from .model_cache import ModelCache, ModelLocker
from .models import BaseModelType, SubModelType, MODEL_CLASSES
from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES
# We are only starting to number the config file with release 3.
# The config file version doesn't have to match the release version, but
# keeping them aligned helps reduce confusion.
CONFIG_FILE_VERSION='3.0.0'
# temporary forward definitions to avoid circular import errors.
class ModelLocker(object):
"Forward declaration"
pass
class ModelCache(object):
"Forward declaration"
pass
@dataclass
class ModelInfo():
context: ModelLocker
@@ -744,3 +735,37 @@ class ModelManager(object):
resolved_path = self.globals.root_dir / source
return resolved_path
def scan_models_directory(self):
loaded_files = set()
        for model_key, model_config in list(self.models.items()):
            model_name, base_model, model_type = self.parse_key(model_key)
            model_class = MODEL_CLASSES[base_model][model_type]
            if not os.path.exists(model_config.path):
                if model_class.save_to_config():
                    # the record lives in the config file; keep it but flag the missing file
                    model_config.error = ModelError.NotFound
                else:
                    self.models.pop(model_key, None)
            else:
                loaded_files.add(model_config.path)
for base_model in BaseModelType:
for model_type in ModelType:
model_class = MODEL_CLASSES[base_model][model_type]
models_dir = os.path.join(self.globals.models_path, base_model, model_type)
if not os.path.exists(models_dir):
continue # TODO: or create all folders?
for entry_name in os.listdir(models_dir):
model_path = os.path.join(models_dir, entry_name)
if model_path not in loaded_files: # TODO: check
model_name = Path(model_path).stem
model_key = self.create_key(model_name, base_model, model_type)
if model_key in self.models:
raise Exception(f"Model with key {model_key} added twice")
model_config: ModelConfigBase = model_class.build_config(
path=model_path,
)
self.models[model_key] = model_config
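The scan above implies an on-disk layout of models_path/<base_model>/<model_type>/<model_name>, where a model's name comes from the file stem or folder name. A hedged illustration (the enum string values are assumptions, not taken from this diff):

models/
├── sd-1.5/
│   ├── pipeline/
│   │   └── stable-diffusion-v1-5/       # diffusers folder; model_name is the folder name
│   └── lora/
│       └── detail-tweaker.safetensors   # model_name is the Path stem, "detail-tweaker"
└── sd-2/
    └── vae/
        └── vae-ft-mse.safetensors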

View File

@@ -5,26 +5,30 @@ from .lora import LoRAModel
#from .controlnet import ControlNetModel # TODO:
from .textual_inversion import TextualInversionModel
# TODO: temporary stub until the real ControlNetModel import above is enabled
class ControlNetModel:
    pass
MODEL_CLASSES = {
BaseModelType.StableDiffusion1_5: {
ModelType.Pipeline: StableDiffusion15Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
#ModelType.ControlNet: ControlNetModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusion2: {
ModelType.Pipeline: StableDiffusion2Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
#ModelType.ControlNet: ControlNetModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
BaseModelType.StableDiffusion2Base: {
ModelType.Pipeline: StableDiffusion2BaseModel,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
#ModelType.ControlNet: ControlNetModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
},
#BaseModelType.Kandinsky2_1: {
@@ -35,3 +39,11 @@ MODEL_CLASSES = {
# ModelType.TextualInversion: TextualInversionModel,
#},
}
# TODO: check with openapi annotation
def get_all_model_configs():
configs = []
for models in MODEL_CLASSES.values():
for model in models.values():
configs.extend(model._get_configs())
return configs
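A hedged sketch of how the registry and the helper above are meant to be consumed (illustrative only; the OpenAPI use is the TODO's stated intent, not implemented here):

# resolve the implementation class for a (base model, model type) pair
model_class = MODEL_CLASSES[BaseModelType.StableDiffusion2][ModelType.Lora]

# collect every config schema, e.g. to assemble the OpenAPI union
all_configs = get_all_model_configs()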

View File

@@ -1,5 +1,7 @@
import os
import json
import torch
import safetensors.torch
from omegaconf import OmegaConf  # used by build_config() for checkpoint configs
from pydantic import Field
from typing import Literal, Optional
from .base import (
@@ -8,10 +10,13 @@ from .base import (
BaseModelType,
ModelType,
SubModelType,
VariantType,
DiffusersModel,
)
from invokeai.app.services.config import InvokeAIAppConfig
ModelVariantType = VariantType # TODO:
# TODO: how to name properly
class StableDiffusion15Model(DiffusersModel):
@@ -20,11 +25,13 @@ class StableDiffusion15Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
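For illustration, a hedged example of the kind of record CheckpointConfig describes; the YAML shape, key format, and values are assumptions, not taken from this diff:

sd-1.5/pipeline/my-model:          # hypothetical key
  format: checkpoint
  path: ./models/sd-1.5/pipeline/my-model.safetensors
  config: ./configs/stable-diffusion/v1-inference.yaml
  variant: normal                  # filled in by build_config() below when omitted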
@@ -36,6 +43,86 @@ class StableDiffusion15Model(DiffusersModel):
model_type=ModelType.Pipeline,
)
@staticmethod
def _fast_safetensors_reader(path: str):
        checkpoint = dict()
        device = torch.device("meta")  # meta tensors keep shape/dtype but allocate no data
        with open(path, "rb") as f:
            # safetensors layout: an 8-byte little-endian header length, followed by a
            # JSON header mapping each tensor name to its dtype, shape and data offsets
            definition_len = int.from_bytes(f.read(8), 'little')
            definition_json = f.read(definition_len)
            definition = json.loads(definition_json)

            if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
                raise Exception("Only PyTorch safetensors files are supported")
            definition.pop("__metadata__", None)
for key, info in definition.items():
dtype = {
"I8": torch.int8,
"I16": torch.int16,
"I32": torch.int32,
"I64": torch.int64,
"F16": torch.float16,
"F32": torch.float32,
"F64": torch.float64,
}[info["dtype"]]
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
return checkpoint
@classmethod
def read_checkpoint_meta(cls, path: str):
if path.endswith(".safetensors"):
            try:
                checkpoint = cls._fast_safetensors_reader(path)
            except Exception:
                # reading the header alone failed; fall back to a full CPU load
                checkpoint = safetensors.torch.load_file(path, device="cpu")  # TODO: create issue for support "meta"?
else:
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
@classmethod
def build_config(cls, **kwargs):
if "format" not in kwargs:
kwargs["format"] = cls.detect_format(kwargs["path"])
if "variant" not in kwargs:
if kwargs["format"] == "checkpoint":
if "config" in kwargs:
ckpt_config = OmegaConf.load(kwargs["config"])
in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = cls.read_checkpoint_meta(kwargs["path"])
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
            elif kwargs["format"] == "diffusers":
                unet_config_path = os.path.join(kwargs["path"], "unet", "config.json")
                if os.path.exists(unet_config_path):
                    with open(unet_config_path, "r") as f:
                        unet_config = json.load(f)
                    in_channels = unet_config["in_channels"]
                else:
                    raise Exception("Unsupported stable diffusion diffusers format (possibly ONNX?)")
            else:
                raise NotImplementedError(f"Unknown stable diffusion format: {kwargs['format']}")

            # the UNet's input-channel count identifies the variant
            if in_channels == 9:
                kwargs["variant"] = ModelVariantType.Inpaint
            elif in_channels == 5:
                kwargs["variant"] = ModelVariantType.Depth
            elif in_channels == 4:
                kwargs["variant"] = ModelVariantType.Normal
            else:
                raise Exception(f"Unknown stable diffusion model variant: in_channels={in_channels}")
return super().build_config(**kwargs)
@classmethod
def save_to_config(cls) -> bool:
return True
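Taken together, build_config() infers the variant from the UNet's input-channel count: 4 means a normal model, 5 depth, 9 inpainting. A hedged usage sketch (the path is a placeholder):

config = StableDiffusion15Model.build_config(
    path="/models/sd-v1-5-inpainting.safetensors",  # placeholder path
)
# detect_format() should report "checkpoint"; the state dict's
# "model.diffusion_model.input_blocks.0.0.weight" has shape[1] == 9,
# so the returned config carries variant == ModelVariantType.Inpaint.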