mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fixes, add sd variant detection
This commit is contained in:
parent
893f776f1d
commit
9fa78443de
@ -3,7 +3,7 @@
|
||||
from typing import Any
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, SDModelInfo
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
|
||||
class EventServiceBase:
|
||||
@ -136,7 +136,7 @@ class EventServiceBase:
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: SubModelType,
|
||||
model_info: SDModelInfo,
|
||||
model_info: ModelInfo,
|
||||
) -> None:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
self.__emit_session_event(
|
||||
|
@ -13,7 +13,7 @@ from invokeai.backend.model_management.model_manager import (
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
SDModelInfo,
|
||||
ModelInfo,
|
||||
)
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from .config import InvokeAIAppConfig
|
||||
@ -49,7 +49,7 @@ class ModelManagerServiceBase(ABC):
|
||||
submodel: Optional[SubModelType] = None,
|
||||
node: Optional[BaseInvocation] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> SDModelInfo:
|
||||
) -> ModelInfo:
|
||||
"""Retrieve the indicated model with name and type.
|
||||
submodel can be used to get a part (such as the vae)
|
||||
of a diffusers pipeline."""
|
||||
@ -302,7 +302,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
submodel: Optional[SubModelType] = None,
|
||||
node: Optional[BaseInvocation] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> SDModelInfo:
|
||||
) -> ModelInfo:
|
||||
"""
|
||||
Retrieve the indicated model. submodel can be used to get a
|
||||
part (such as the vae) of a diffusers mode.
|
||||
@ -539,7 +539,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: SubModelType,
|
||||
model_info: Optional[SDModelInfo] = None,
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException()
|
||||
|
@ -166,22 +166,13 @@ import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
|
||||
from .model_cache import ModelCache, ModelLocker
|
||||
from .models import BaseModelType, SubModelType, MODEL_CLASSES
|
||||
from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES
|
||||
|
||||
# We are only starting to number the config file with release 3.
|
||||
# The config file version doesn't have to start at release version, but it will help
|
||||
# reduce confusion.
|
||||
CONFIG_FILE_VERSION='3.0.0'
|
||||
|
||||
# temporary forward definitions to avoid circular import errors.
|
||||
class ModelLocker(object):
|
||||
"Forward declaration"
|
||||
pass
|
||||
|
||||
class ModelCache(object):
|
||||
"Forward declaration"
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class ModelInfo():
|
||||
context: ModelLocker
|
||||
@ -744,3 +735,37 @@ class ModelManager(object):
|
||||
resolved_path = self.globals.root_dir / source
|
||||
return resolved_path
|
||||
|
||||
def scan_models_directory(self):
|
||||
loaded_files = set()
|
||||
|
||||
for model_key, model_config in list(self.models.items()):
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
if not os.path.exists(model_config.path):
|
||||
if model_class.save_to_config:
|
||||
model_config.error = ModelError.NotFound
|
||||
else:
|
||||
self.models.pop(model_key, None)
|
||||
else:
|
||||
loaded_files.add(model_config.path)
|
||||
|
||||
for base_model in BaseModelType:
|
||||
for model_type in ModelType:
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
models_dir = os.path.join(self.globals.models_path, base_model, model_type)
|
||||
|
||||
if not os.path.exists(models_dir):
|
||||
continue # TODO: or create all folders?
|
||||
|
||||
for entry_name in os.listdir(models_dir):
|
||||
model_path = os.path.join(models_dir, entry_name)
|
||||
if model_path not in loaded_files: # TODO: check
|
||||
model_name = Path(model_path).stem
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
|
||||
if model_key in self.models:
|
||||
raise Exception(f"Model with key {model_key} added twice")
|
||||
|
||||
model_config: ModelConfigBase = model_class.build_config(
|
||||
path=model_path,
|
||||
)
|
||||
self.models[model_key] = model_config
|
||||
|
@ -5,26 +5,30 @@ from .lora import LoRAModel
|
||||
#from .controlnet import ControlNetModel # TODO:
|
||||
from .textual_inversion import TextualInversionModel
|
||||
|
||||
# TODO:
|
||||
class ControlNetModel:
|
||||
pass
|
||||
|
||||
MODEL_CLASSES = {
|
||||
BaseModelType.StableDiffusion1_5: {
|
||||
ModelType.Pipeline: StableDiffusion15Model,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
#ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelType.Pipeline: StableDiffusion2Model,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
#ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
},
|
||||
BaseModelType.StableDiffusion2Base: {
|
||||
ModelType.Pipeline: StableDiffusion2BaseModel,
|
||||
ModelType.Vae: VaeModel,
|
||||
ModelType.Lora: LoRAModel,
|
||||
#ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.ControlNet: ControlNetModel,
|
||||
ModelType.TextualInversion: TextualInversionModel,
|
||||
},
|
||||
#BaseModelType.Kandinsky2_1: {
|
||||
@ -35,3 +39,11 @@ MODEL_CLASSES = {
|
||||
# ModelType.TextualInversion: TextualInversionModel,
|
||||
#},
|
||||
}
|
||||
|
||||
# TODO: check with openapi annotation
|
||||
def get_all_model_configs():
|
||||
configs = []
|
||||
for models in MODEL_CLASSES.values():
|
||||
for model in models.values():
|
||||
configs.extend(model._get_configs())
|
||||
return configs
|
||||
|
@ -1,5 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import safetensors.torch
|
||||
from pydantic import Field
|
||||
from typing import Literal, Optional
|
||||
from .base import (
|
||||
@ -8,10 +10,13 @@ from .base import (
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
VariantType,
|
||||
DiffusersModel,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
ModelVariantType = VariantType # TODO:
|
||||
|
||||
|
||||
# TODO: how to name properly
|
||||
class StableDiffusion15Model(DiffusersModel):
|
||||
@ -20,11 +25,13 @@ class StableDiffusion15Model(DiffusersModel):
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
format: Literal["diffusers"]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
format: Literal["checkpoint"]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
@ -36,6 +43,86 @@ class StableDiffusion15Model(DiffusersModel):
|
||||
model_type=ModelType.Pipeline,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _fast_safetensors_reader(path: str):
|
||||
checkpoint = dict()
|
||||
device = torch.device("meta")
|
||||
with open(path, "rb") as f:
|
||||
definition_len = int.from_bytes(f.read(8), 'little')
|
||||
definition_json = f.read(definition_len)
|
||||
definition = json.loads(definition_json)
|
||||
|
||||
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
|
||||
raise Exception("Supported only pytorch safetensors files")
|
||||
definition.pop("__metadata__", None)
|
||||
|
||||
for key, info in definition.items():
|
||||
dtype = {
|
||||
"I8": torch.int8,
|
||||
"I16": torch.int16,
|
||||
"I32": torch.int32,
|
||||
"I64": torch.int64,
|
||||
"F16": torch.float16,
|
||||
"F32": torch.float32,
|
||||
"F64": torch.float64,
|
||||
}[info["dtype"]]
|
||||
|
||||
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
@classmethod
|
||||
def read_checkpoint_meta(cls, path: str):
|
||||
if path.endswith(".safetensors"):
|
||||
try:
|
||||
checkpoint = cls._fast_safetensors_reader(path)
|
||||
except:
|
||||
checkpoint = safetensors.torch.load_file(path, device="cpu") # TODO: create issue for support "meta"?
|
||||
else:
|
||||
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
||||
return checkpoint
|
||||
|
||||
@classmethod
|
||||
def build_config(cls, **kwargs):
|
||||
if "format" not in kwargs:
|
||||
kwargs["format"] = cls.detect_format(kwargs["path"])
|
||||
|
||||
if "variant" not in kwargs:
|
||||
if kwargs["format"] == "checkpoint":
|
||||
if "config" in kwargs:
|
||||
ckpt_config = OmegaConf.load(kwargs["config"])
|
||||
in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||
|
||||
else:
|
||||
checkpoint = cls.read_checkpoint_meta(kwargs["path"])
|
||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
|
||||
elif kwargs["format"] == "diffusers":
|
||||
unet_config_path = os.path.join(kwargs["path"], "unet", "config.json")
|
||||
if os.path.exists(unet_config_path):
|
||||
unet_config = json.loads(unet_config_path)
|
||||
in_channels = unet_config['in_channels']
|
||||
|
||||
else:
|
||||
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown stable diffusion format: {kwargs['format']}")
|
||||
|
||||
if in_channels == 9:
|
||||
kwargs["variant"] = ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
kwargs["variant"] = ModelVariantType.Depth
|
||||
elif in_channels == 4:
|
||||
kwargs["variant"] = ModelVariantType.Normal
|
||||
else:
|
||||
raise Exception("Unkown stable diffusion model format")
|
||||
|
||||
|
||||
return super().build_config(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
Loading…
Reference in New Issue
Block a user