mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fixes, add sd variant detection
This commit is contained in:
parent
893f776f1d
commit
9fa78443de
@ -3,7 +3,7 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from invokeai.app.models.image import ProgressImage
|
from invokeai.app.models.image import ProgressImage
|
||||||
from invokeai.app.util.misc import get_timestamp
|
from invokeai.app.util.misc import get_timestamp
|
||||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, SDModelInfo
|
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
|
||||||
from invokeai.app.models.exceptions import CanceledException
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
|
|
||||||
class EventServiceBase:
|
class EventServiceBase:
|
||||||
@ -136,7 +136,7 @@ class EventServiceBase:
|
|||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel: SubModelType,
|
submodel: SubModelType,
|
||||||
model_info: SDModelInfo,
|
model_info: ModelInfo,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||||
self.__emit_session_event(
|
self.__emit_session_event(
|
||||||
|
@ -13,7 +13,7 @@ from invokeai.backend.model_management.model_manager import (
|
|||||||
BaseModelType,
|
BaseModelType,
|
||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
SDModelInfo,
|
ModelInfo,
|
||||||
)
|
)
|
||||||
from invokeai.app.models.exceptions import CanceledException
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
from .config import InvokeAIAppConfig
|
from .config import InvokeAIAppConfig
|
||||||
@ -49,7 +49,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
submodel: Optional[SubModelType] = None,
|
submodel: Optional[SubModelType] = None,
|
||||||
node: Optional[BaseInvocation] = None,
|
node: Optional[BaseInvocation] = None,
|
||||||
context: Optional[InvocationContext] = None,
|
context: Optional[InvocationContext] = None,
|
||||||
) -> SDModelInfo:
|
) -> ModelInfo:
|
||||||
"""Retrieve the indicated model with name and type.
|
"""Retrieve the indicated model with name and type.
|
||||||
submodel can be used to get a part (such as the vae)
|
submodel can be used to get a part (such as the vae)
|
||||||
of a diffusers pipeline."""
|
of a diffusers pipeline."""
|
||||||
@ -302,7 +302,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
submodel: Optional[SubModelType] = None,
|
submodel: Optional[SubModelType] = None,
|
||||||
node: Optional[BaseInvocation] = None,
|
node: Optional[BaseInvocation] = None,
|
||||||
context: Optional[InvocationContext] = None,
|
context: Optional[InvocationContext] = None,
|
||||||
) -> SDModelInfo:
|
) -> ModelInfo:
|
||||||
"""
|
"""
|
||||||
Retrieve the indicated model. submodel can be used to get a
|
Retrieve the indicated model. submodel can be used to get a
|
||||||
part (such as the vae) of a diffusers mode.
|
part (such as the vae) of a diffusers mode.
|
||||||
@ -539,7 +539,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel: SubModelType,
|
submodel: SubModelType,
|
||||||
model_info: Optional[SDModelInfo] = None,
|
model_info: Optional[ModelInfo] = None,
|
||||||
):
|
):
|
||||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||||
raise CanceledException()
|
raise CanceledException()
|
||||||
|
@ -166,22 +166,13 @@ import invokeai.backend.util.logging as logger
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
|
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
|
||||||
from .model_cache import ModelCache, ModelLocker
|
from .model_cache import ModelCache, ModelLocker
|
||||||
from .models import BaseModelType, SubModelType, MODEL_CLASSES
|
from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES
|
||||||
|
|
||||||
# We are only starting to number the config file with release 3.
|
# We are only starting to number the config file with release 3.
|
||||||
# The config file version doesn't have to start at release version, but it will help
|
# The config file version doesn't have to start at release version, but it will help
|
||||||
# reduce confusion.
|
# reduce confusion.
|
||||||
CONFIG_FILE_VERSION='3.0.0'
|
CONFIG_FILE_VERSION='3.0.0'
|
||||||
|
|
||||||
# temporary forward definitions to avoid circular import errors.
|
|
||||||
class ModelLocker(object):
|
|
||||||
"Forward declaration"
|
|
||||||
pass
|
|
||||||
|
|
||||||
class ModelCache(object):
|
|
||||||
"Forward declaration"
|
|
||||||
pass
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelInfo():
|
class ModelInfo():
|
||||||
context: ModelLocker
|
context: ModelLocker
|
||||||
@ -744,3 +735,37 @@ class ModelManager(object):
|
|||||||
resolved_path = self.globals.root_dir / source
|
resolved_path = self.globals.root_dir / source
|
||||||
return resolved_path
|
return resolved_path
|
||||||
|
|
||||||
|
def scan_models_directory(self):
|
||||||
|
loaded_files = set()
|
||||||
|
|
||||||
|
for model_key, model_config in list(self.models.items()):
|
||||||
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
|
if not os.path.exists(model_config.path):
|
||||||
|
if model_class.save_to_config:
|
||||||
|
model_config.error = ModelError.NotFound
|
||||||
|
else:
|
||||||
|
self.models.pop(model_key, None)
|
||||||
|
else:
|
||||||
|
loaded_files.add(model_config.path)
|
||||||
|
|
||||||
|
for base_model in BaseModelType:
|
||||||
|
for model_type in ModelType:
|
||||||
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
|
models_dir = os.path.join(self.globals.models_path, base_model, model_type)
|
||||||
|
|
||||||
|
if not os.path.exists(models_dir):
|
||||||
|
continue # TODO: or create all folders?
|
||||||
|
|
||||||
|
for entry_name in os.listdir(models_dir):
|
||||||
|
model_path = os.path.join(models_dir, entry_name)
|
||||||
|
if model_path not in loaded_files: # TODO: check
|
||||||
|
model_name = Path(model_path).stem
|
||||||
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
|
|
||||||
|
if model_key in self.models:
|
||||||
|
raise Exception(f"Model with key {model_key} added twice")
|
||||||
|
|
||||||
|
model_config: ModelConfigBase = model_class.build_config(
|
||||||
|
path=model_path,
|
||||||
|
)
|
||||||
|
self.models[model_key] = model_config
|
||||||
|
@ -5,26 +5,30 @@ from .lora import LoRAModel
|
|||||||
#from .controlnet import ControlNetModel # TODO:
|
#from .controlnet import ControlNetModel # TODO:
|
||||||
from .textual_inversion import TextualInversionModel
|
from .textual_inversion import TextualInversionModel
|
||||||
|
|
||||||
|
# TODO:
|
||||||
|
class ControlNetModel:
|
||||||
|
pass
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
BaseModelType.StableDiffusion1_5: {
|
BaseModelType.StableDiffusion1_5: {
|
||||||
ModelType.Pipeline: StableDiffusion15Model,
|
ModelType.Pipeline: StableDiffusion15Model,
|
||||||
ModelType.Vae: VaeModel,
|
ModelType.Vae: VaeModel,
|
||||||
ModelType.Lora: LoRAModel,
|
ModelType.Lora: LoRAModel,
|
||||||
#ModelType.ControlNet: ControlNetModel,
|
ModelType.ControlNet: ControlNetModel,
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
},
|
},
|
||||||
BaseModelType.StableDiffusion2: {
|
BaseModelType.StableDiffusion2: {
|
||||||
ModelType.Pipeline: StableDiffusion2Model,
|
ModelType.Pipeline: StableDiffusion2Model,
|
||||||
ModelType.Vae: VaeModel,
|
ModelType.Vae: VaeModel,
|
||||||
ModelType.Lora: LoRAModel,
|
ModelType.Lora: LoRAModel,
|
||||||
#ModelType.ControlNet: ControlNetModel,
|
ModelType.ControlNet: ControlNetModel,
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
},
|
},
|
||||||
BaseModelType.StableDiffusion2Base: {
|
BaseModelType.StableDiffusion2Base: {
|
||||||
ModelType.Pipeline: StableDiffusion2BaseModel,
|
ModelType.Pipeline: StableDiffusion2BaseModel,
|
||||||
ModelType.Vae: VaeModel,
|
ModelType.Vae: VaeModel,
|
||||||
ModelType.Lora: LoRAModel,
|
ModelType.Lora: LoRAModel,
|
||||||
#ModelType.ControlNet: ControlNetModel,
|
ModelType.ControlNet: ControlNetModel,
|
||||||
ModelType.TextualInversion: TextualInversionModel,
|
ModelType.TextualInversion: TextualInversionModel,
|
||||||
},
|
},
|
||||||
#BaseModelType.Kandinsky2_1: {
|
#BaseModelType.Kandinsky2_1: {
|
||||||
@ -35,3 +39,11 @@ MODEL_CLASSES = {
|
|||||||
# ModelType.TextualInversion: TextualInversionModel,
|
# ModelType.TextualInversion: TextualInversionModel,
|
||||||
#},
|
#},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# TODO: check with openapi annotation
|
||||||
|
def get_all_model_configs():
|
||||||
|
configs = []
|
||||||
|
for models in MODEL_CLASSES.values():
|
||||||
|
for model in models.values():
|
||||||
|
configs.extend(model._get_configs())
|
||||||
|
return configs
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
import json
|
||||||
import torch
|
import torch
|
||||||
|
import safetensors.torch
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
from .base import (
|
from .base import (
|
||||||
@ -8,10 +10,13 @@ from .base import (
|
|||||||
BaseModelType,
|
BaseModelType,
|
||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
|
VariantType,
|
||||||
DiffusersModel,
|
DiffusersModel,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
|
ModelVariantType = VariantType # TODO:
|
||||||
|
|
||||||
|
|
||||||
# TODO: how to name properly
|
# TODO: how to name properly
|
||||||
class StableDiffusion15Model(DiffusersModel):
|
class StableDiffusion15Model(DiffusersModel):
|
||||||
@ -20,11 +25,13 @@ class StableDiffusion15Model(DiffusersModel):
|
|||||||
class DiffusersConfig(ModelConfigBase):
|
class DiffusersConfig(ModelConfigBase):
|
||||||
format: Literal["diffusers"]
|
format: Literal["diffusers"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
|
variant: ModelVariantType
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
format: Literal["checkpoint"]
|
format: Literal["checkpoint"]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: Optional[str] = Field(None)
|
||||||
|
variant: ModelVariantType
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
@ -36,6 +43,86 @@ class StableDiffusion15Model(DiffusersModel):
|
|||||||
model_type=ModelType.Pipeline,
|
model_type=ModelType.Pipeline,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fast_safetensors_reader(path: str):
|
||||||
|
checkpoint = dict()
|
||||||
|
device = torch.device("meta")
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
definition_len = int.from_bytes(f.read(8), 'little')
|
||||||
|
definition_json = f.read(definition_len)
|
||||||
|
definition = json.loads(definition_json)
|
||||||
|
|
||||||
|
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
|
||||||
|
raise Exception("Supported only pytorch safetensors files")
|
||||||
|
definition.pop("__metadata__", None)
|
||||||
|
|
||||||
|
for key, info in definition.items():
|
||||||
|
dtype = {
|
||||||
|
"I8": torch.int8,
|
||||||
|
"I16": torch.int16,
|
||||||
|
"I32": torch.int32,
|
||||||
|
"I64": torch.int64,
|
||||||
|
"F16": torch.float16,
|
||||||
|
"F32": torch.float32,
|
||||||
|
"F64": torch.float64,
|
||||||
|
}[info["dtype"]]
|
||||||
|
|
||||||
|
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
|
||||||
|
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def read_checkpoint_meta(cls, path: str):
|
||||||
|
if path.endswith(".safetensors"):
|
||||||
|
try:
|
||||||
|
checkpoint = cls._fast_safetensors_reader(path)
|
||||||
|
except:
|
||||||
|
checkpoint = safetensors.torch.load_file(path, device="cpu") # TODO: create issue for support "meta"?
|
||||||
|
else:
|
||||||
|
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_config(cls, **kwargs):
|
||||||
|
if "format" not in kwargs:
|
||||||
|
kwargs["format"] = cls.detect_format(kwargs["path"])
|
||||||
|
|
||||||
|
if "variant" not in kwargs:
|
||||||
|
if kwargs["format"] == "checkpoint":
|
||||||
|
if "config" in kwargs:
|
||||||
|
ckpt_config = OmegaConf.load(kwargs["config"])
|
||||||
|
in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||||
|
|
||||||
|
else:
|
||||||
|
checkpoint = cls.read_checkpoint_meta(kwargs["path"])
|
||||||
|
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||||
|
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||||
|
|
||||||
|
elif kwargs["format"] == "diffusers":
|
||||||
|
unet_config_path = os.path.join(kwargs["path"], "unet", "config.json")
|
||||||
|
if os.path.exists(unet_config_path):
|
||||||
|
unet_config = json.loads(unet_config_path)
|
||||||
|
in_channels = unet_config['in_channels']
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unknown stable diffusion format: {kwargs['format']}")
|
||||||
|
|
||||||
|
if in_channels == 9:
|
||||||
|
kwargs["variant"] = ModelVariantType.Inpaint
|
||||||
|
elif in_channels == 5:
|
||||||
|
kwargs["variant"] = ModelVariantType.Depth
|
||||||
|
elif in_channels == 4:
|
||||||
|
kwargs["variant"] = ModelVariantType.Normal
|
||||||
|
else:
|
||||||
|
raise Exception("Unkown stable diffusion model format")
|
||||||
|
|
||||||
|
|
||||||
|
return super().build_config(**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def save_to_config(cls) -> bool:
|
def save_to_config(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
Loading…
x
Reference in New Issue
Block a user