mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ModelManager): fix overridden VAE with relative path (#4059)
This commit is contained in:
commit
4367061b19
@ -228,19 +228,19 @@ the root is the InvokeAI ROOTDIR.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
import os
|
||||
import textwrap
|
||||
import yaml
|
||||
import types
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, List, Tuple, Union, Dict, Set, Callable, types
|
||||
from shutil import rmtree, move
|
||||
from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
@ -259,6 +259,7 @@ from .models import (
|
||||
ModelNotFoundException,
|
||||
InvalidModelException,
|
||||
DuplicateModelException,
|
||||
ModelBase,
|
||||
)
|
||||
|
||||
# We are only starting to number the config file with release 3.
|
||||
@ -361,7 +362,7 @@ class ModelManager(object):
|
||||
if model_key.startswith("_"):
|
||||
continue
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
# alias for config file
|
||||
model_config["model_format"] = model_config.pop("format")
|
||||
self.models[model_key] = model_class.create_config(**model_config)
|
||||
@ -381,18 +382,24 @@ class ModelManager(object):
|
||||
# causing otherwise unreferenced models to be removed from memory
|
||||
self._read_models()
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
|
||||
"""
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
Given a model name, returns True if it is a valid identifier.
|
||||
|
||||
:param model_name: symbolic name of the model in models.yaml
|
||||
:param model_type: ModelType enum indicating the type of model to return
|
||||
:param base_model: BaseModelType enum indicating the base model used by this model
|
||||
:param rescan: if True, scan_models_directory
|
||||
"""
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
return model_key in self.models
|
||||
exists = model_key in self.models
|
||||
|
||||
# if model not found try to find it (maybe file just pasted)
|
||||
if rescan and not exists:
|
||||
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
||||
exists = self.model_exists(model_name, base_model, model_type, rescan=False)
|
||||
|
||||
return exists
|
||||
|
||||
@classmethod
|
||||
def create_key(
|
||||
@ -443,39 +450,32 @@ class ModelManager(object):
|
||||
:param model_name: symbolic name of the model in models.yaml
|
||||
:param model_type: ModelType enum indicating the type of model to return
|
||||
:param base_model: BaseModelType enum indicating the base model used by this model
|
||||
:param submode_typel: an ModelType enum indicating the portion of
|
||||
:param submodel_type: an ModelType enum indicating the portion of
|
||||
the model to retrieve (e.g. ModelType.Vae)
|
||||
"""
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
|
||||
# if model not found try to find it (maybe file just pasted)
|
||||
if model_key not in self.models:
|
||||
self.scan_models_directory(base_model=base_model, model_type=model_type)
|
||||
if model_key not in self.models:
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
if not self.model_exists(model_name, base_model, model_type, rescan=True):
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
|
||||
model_config = self.models[model_key]
|
||||
model_path = self.resolve_model_path(model_config.path)
|
||||
model_config = self._get_model_config(base_model, model_name, model_type)
|
||||
|
||||
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
|
||||
|
||||
if is_submodel_override:
|
||||
model_type = submodel_type
|
||||
submodel_type = None
|
||||
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
|
||||
if not model_path.exists():
|
||||
if model_class.save_to_config:
|
||||
self.models[model_key].error = ModelError.NotFound
|
||||
raise Exception(f'Files for model "{model_key}" not found')
|
||||
raise Exception(f'Files for model "{model_key}" not found at {model_path}')
|
||||
|
||||
else:
|
||||
self.models.pop(model_key, None)
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
|
||||
# vae/movq override
|
||||
# TODO:
|
||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||
override_path = getattr(model_config, submodel_type)
|
||||
if override_path:
|
||||
model_path = self.resolve_path(override_path)
|
||||
model_type = submodel_type
|
||||
submodel_type = None
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}')
|
||||
|
||||
# TODO: path
|
||||
# TODO: is it accurate to use path as id
|
||||
@ -513,6 +513,55 @@ class ModelManager(object):
|
||||
_cache=self.cache,
|
||||
)
|
||||
|
||||
def _get_model_path(
|
||||
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
|
||||
) -> (Path, bool):
|
||||
"""Extract a model's filesystem path from its config.
|
||||
|
||||
:return: The fully qualified Path of the module (or submodule).
|
||||
"""
|
||||
model_path = model_config.path
|
||||
is_submodel_override = False
|
||||
|
||||
# Does the config explicitly override the submodel?
|
||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||
submodel_path = getattr(model_config, submodel_type)
|
||||
if submodel_path is not None:
|
||||
model_path = getattr(model_config, submodel_type)
|
||||
is_submodel_override = True
|
||||
|
||||
model_path = self.resolve_model_path(model_path)
|
||||
return model_path, is_submodel_override
|
||||
|
||||
def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
|
||||
"""Get a model's config object."""
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
try:
|
||||
model_config = self.models[model_key]
|
||||
except KeyError:
|
||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||
return model_config
|
||||
|
||||
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
|
||||
"""Get the concrete implementation class for a specific model type."""
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
return model_class
|
||||
|
||||
def _instantiate(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> ModelBase:
|
||||
"""Make a new instance of this model, without loading it."""
|
||||
model_config = self._get_model_config(base_model, model_name, model_type)
|
||||
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
|
||||
# FIXME: do non-overriden submodels get the right class?
|
||||
constructor = self._get_implementation(base_model, model_type)
|
||||
instance = constructor(model_path, base_model, model_type)
|
||||
return instance
|
||||
|
||||
def model_info(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -660,7 +709,7 @@ class ModelManager(object):
|
||||
if path := model_attributes.get("path"):
|
||||
model_attributes["path"] = str(self.relative_model_path(Path(path)))
|
||||
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
model_config = model_class.create_config(**model_attributes)
|
||||
model_key = self.create_key(model_name, base_model, model_type)
|
||||
|
||||
@ -851,7 +900,7 @@ class ModelManager(object):
|
||||
|
||||
for model_key, model_config in self.models.items():
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
model_class = self._get_implementation(base_model, model_type)
|
||||
if model_class.save_to_config:
|
||||
# TODO: or exclude_unset better fits here?
|
||||
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
|
||||
@ -909,7 +958,7 @@ class ModelManager(object):
|
||||
|
||||
model_path = self.resolve_model_path(model_config.path).absolute()
|
||||
if not model_path.exists():
|
||||
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||
model_class = self._get_implementation(cur_base_model, cur_model_type)
|
||||
if model_class.save_to_config:
|
||||
model_config.error = ModelError.NotFound
|
||||
self.models.pop(model_key, None)
|
||||
@ -925,7 +974,7 @@ class ModelManager(object):
|
||||
for cur_model_type in ModelType:
|
||||
if model_type is not None and cur_model_type != model_type:
|
||||
continue
|
||||
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||
model_class = self._get_implementation(cur_base_model, cur_model_type)
|
||||
models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
|
||||
|
||||
if not models_dir.exists():
|
||||
|
@ -1,9 +1,14 @@
|
||||
import os
|
||||
import torch
|
||||
import safetensors
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Literal
|
||||
from typing import Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from diffusers.utils import is_safetensors_available
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from .base import (
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
@ -18,9 +23,6 @@ from .base import (
|
||||
InvalidModelException,
|
||||
ModelNotFoundException,
|
||||
)
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from diffusers.utils import is_safetensors_available
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
class VaeModelFormat(str, Enum):
|
||||
@ -80,7 +82,7 @@ class VaeModel(ModelBase):
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
raise ModelNotFoundException(f"Does not exist as local file: {path}")
|
||||
|
||||
if os.path.isdir(path):
|
||||
if os.path.exists(os.path.join(path, "config.json")):
|
||||
|
@ -100,7 +100,7 @@ dependencies = [
|
||||
"dev" = [
|
||||
"pudb",
|
||||
]
|
||||
"test" = ["pytest>6.0.0", "pytest-cov", "black"]
|
||||
"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"]
|
||||
"xformers" = [
|
||||
"xformers~=0.0.19; sys_platform!='darwin'",
|
||||
"triton; sys_platform=='linux'",
|
||||
|
38
tests/test_model_manager.py
Normal file
38
tests/test_model_manager.py
Normal file
@ -0,0 +1,38 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType
|
||||
|
||||
BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
|
||||
VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_manager(datadir) -> ModelManager:
|
||||
InvokeAIAppConfig.get_config(root=datadir)
|
||||
return ModelManager(datadir / "configs" / "relative_sub.models.yaml")
|
||||
|
||||
|
||||
def test_get_model_names(model_manager: ModelManager):
|
||||
names = model_manager.model_names()
|
||||
assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME]
|
||||
|
||||
|
||||
def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path):
|
||||
model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2])
|
||||
top_model_path, is_override = model_manager._get_model_path(model_config)
|
||||
expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0"
|
||||
assert top_model_path == expected_model_path
|
||||
assert not is_override
|
||||
|
||||
|
||||
def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path):
|
||||
model_config = model_manager._get_model_config(
|
||||
VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2]
|
||||
)
|
||||
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
|
||||
expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
|
||||
assert vae_model_path == expected_vae_path
|
||||
assert is_override
|
15
tests/test_model_manager/configs/relative_sub.models.yaml
Normal file
15
tests/test_model_manager/configs/relative_sub.models.yaml
Normal file
@ -0,0 +1,15 @@
|
||||
__metadata__:
|
||||
version: 3.0.0
|
||||
|
||||
sdxl/main/SDXL base:
|
||||
path: sdxl/main/SDXL base 1_0
|
||||
description: SDXL base v1.0
|
||||
variant: normal
|
||||
format: diffusers
|
||||
|
||||
sdxl/main/SDXL with VAE:
|
||||
path: sdxl/main/SDXL base 1_0
|
||||
description: SDXL with customized VAE
|
||||
vae: sdxl/vae/sdxl-vae-fp16-fix/
|
||||
variant: normal
|
||||
format: diffusers
|
Loading…
Reference in New Issue
Block a user