From e70bedba7d16b4c928220286ec09d38f058d742c Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 28 Jul 2023 21:03:27 -0700 Subject: [PATCH 01/12] refactor(ModelManager): factor out _get_implementation method --- .../backend/model_management/model_manager.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 2a82061a97..fbabd2fece 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -258,7 +258,7 @@ from .models import ( ModelConfigBase, ModelNotFoundException, InvalidModelException, - DuplicateModelException, + DuplicateModelException, ModelBase, ) # We are only starting to number the config file with release 3. @@ -361,7 +361,7 @@ class ModelManager(object): if model_key.startswith("_"): continue model_name, base_model, model_type = self.parse_key(model_key) - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) # alias for config file model_config["model_format"] = model_config.pop("format") self.models[model_key] = model_class.create_config(**model_config) @@ -446,7 +446,7 @@ class ModelManager(object): :param submode_typel: an ModelType enum indicating the portion of the model to retrieve (e.g. ModelType.Vae) """ - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) model_key = self.create_key(model_name, base_model, model_type) # if model not found try to find it (maybe file just pasted) @@ -475,7 +475,7 @@ class ModelManager(object): model_path = self.app_config.root_path / override_path model_type = submodel_type submodel_type = None - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) # TODO: path # TODO: is it accurate to use path as id @@ -513,6 +513,10 @@ class ModelManager(object): _cache=self.cache, ) + def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]: + model_class = MODEL_CLASSES[base_model][model_type] + return model_class + def model_info( self, model_name: str, @@ -659,7 +663,7 @@ class ModelManager(object): if Path(path).is_relative_to(self.app_config.root_path): model_attributes["path"] = str(Path(path).relative_to(self.app_config.root_path)) - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) model_config = model_class.create_config(**model_attributes) model_key = self.create_key(model_name, base_model, model_type) @@ -837,7 +841,7 @@ class ModelManager(object): for model_key, model_config in self.models.items(): model_name, base_model, model_type = self.parse_key(model_key) - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) if model_class.save_to_config: # TODO: or exclude_unset better fits here? data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"}) @@ -888,7 +892,7 @@ class ModelManager(object): model_name, cur_base_model, cur_model_type = self.parse_key(model_key) model_path = self.app_config.root_path.absolute() / model_config.path if not model_path.exists(): - model_class = MODEL_CLASSES[cur_base_model][cur_model_type] + model_class = self._get_implementation(cur_base_model, cur_model_type) if model_class.save_to_config: model_config.error = ModelError.NotFound self.models.pop(model_key, None) @@ -904,7 +908,7 @@ class ModelManager(object): for cur_model_type in ModelType: if model_type is not None and cur_model_type != model_type: continue - model_class = MODEL_CLASSES[cur_base_model][cur_model_type] + model_class = self._get_implementation(cur_base_model, cur_model_type) models_dir = self.app_config.models_path / cur_base_model.value / cur_model_type.value if not models_dir.exists(): From dca685ac252055dff6dd38cfe6575994cc7385a0 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 28 Jul 2023 21:11:00 -0700 Subject: [PATCH 02/12] refactor(ModelManager): refactor rescan-on-miss to exists() method --- .../backend/model_management/model_manager.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index fbabd2fece..5da4344b89 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -386,13 +386,21 @@ class ModelManager(object): model_name: str, base_model: BaseModelType, model_type: ModelType, + rescan = False ) -> bool: """ Given a model name, returns True if it is a valid identifier. """ model_key = self.create_key(model_name, base_model, model_type) - return model_key in self.models + exists = model_key in self.models + + # if model not found try to find it (maybe file just pasted) + if rescan and not exists: + self.scan_models_directory(base_model=base_model, model_type=model_type) + exists = model_key in self.models + + return exists @classmethod def create_key( @@ -449,11 +457,8 @@ class ModelManager(object): model_class = self._get_implementation(base_model, model_type) model_key = self.create_key(model_name, base_model, model_type) - # if model not found try to find it (maybe file just pasted) - if model_key not in self.models: - self.scan_models_directory(base_model=base_model, model_type=model_type) - if model_key not in self.models: - raise ModelNotFoundException(f"Model not found - {model_key}") + if not self.model_exists(model_name, base_model, model_type, rescan=True): + raise ModelNotFoundException(f"Model not found - {model_key}") model_config = self.models[model_key] model_path = self.app_config.root_path / model_config.path From b163ae6a4dcb9b5024888c33620db8ddcd01b037 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 28 Jul 2023 21:30:20 -0700 Subject: [PATCH 03/12] refactor(ModelManager): factor out get_model_config --- invokeai/backend/model_management/model_manager.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 5da4344b89..51053f92cc 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -460,7 +460,7 @@ class ModelManager(object): if not self.model_exists(model_name, base_model, model_type, rescan=True): raise ModelNotFoundException(f"Model not found - {model_key}") - model_config = self.models[model_key] + model_config = self._get_model_config(base_model, model_name, model_type) model_path = self.app_config.root_path / model_config.path if not model_path.exists(): @@ -518,6 +518,14 @@ class ModelManager(object): _cache=self.cache, ) + def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase: + model_key = self.create_key(model_name, base_model, model_type) + try: + model_config = self.models[model_key] + except KeyError: + raise ModelNotFoundException(f"Model not found - {model_key}") + return model_config + def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]: model_class = MODEL_CLASSES[base_model][model_type] return model_class From bc9a5038fdbb0c1f9bc88c5fa5ea86cdf994024b Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 28 Jul 2023 22:01:28 -0700 Subject: [PATCH 04/12] refactor(ModelManager): factor out get_model_path --- .../backend/model_management/model_manager.py | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 51053f92cc..6c79b07959 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -451,37 +451,33 @@ class ModelManager(object): :param model_name: symbolic name of the model in models.yaml :param model_type: ModelType enum indicating the type of model to return :param base_model: BaseModelType enum indicating the base model used by this model - :param submode_typel: an ModelType enum indicating the portion of + :param submodel_type: an ModelType enum indicating the portion of the model to retrieve (e.g. ModelType.Vae) """ - model_class = self._get_implementation(base_model, model_type) model_key = self.create_key(model_name, base_model, model_type) if not self.model_exists(model_name, base_model, model_type, rescan=True): raise ModelNotFoundException(f"Model not found - {model_key}") model_config = self._get_model_config(base_model, model_name, model_type) - model_path = self.app_config.root_path / model_config.path + + model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) + + if is_submodel_override: + model_type = submodel_type + submodel_type = None + + model_class = self._get_implementation(base_model, model_type) if not model_path.exists(): if model_class.save_to_config: self.models[model_key].error = ModelError.NotFound - raise Exception(f'Files for model "{model_key}" not found') + raise Exception(f'Files for model "{model_key}" not found at {model_path}') else: self.models.pop(model_key, None) raise ModelNotFoundException(f"Model not found - {model_key}") - # vae/movq override - # TODO: - if submodel_type is not None and hasattr(model_config, submodel_type): - override_path = getattr(model_config, submodel_type) - if override_path: - model_path = self.app_config.root_path / override_path - model_type = submodel_type - submodel_type = None - model_class = self._get_implementation(base_model, model_type) - # TODO: path # TODO: is it accurate to use path as id dst_convert_path = self._get_model_cache_path(model_path) @@ -518,6 +514,20 @@ class ModelManager(object): _cache=self.cache, ) + def _get_model_path(self, model_config: ModelConfigBase, submodel_type: SubModelType = None) -> (Path, bool): + model_path = model_config.path + is_submodel_override = False + + # Does the config explicitly override the submodel? + if submodel_type is not None and hasattr(model_config, submodel_type): + submodel_path = getattr(model_config, submodel_type) + if submodel_path is not None: + model_path = getattr(model_config, submodel_type) + is_submodel_override = True + + model_path = self.app_config.root_path / model_path + return model_path, is_submodel_override + def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase: model_key = self.create_key(model_name, base_model, model_type) try: From 86b8b69e889cf0e972b75dc50ff119e86c315649 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 28 Jul 2023 22:30:25 -0700 Subject: [PATCH 05/12] internal(ModelManager): add instantiate method --- invokeai/backend/model_management/model_manager.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 6c79b07959..954b86283f 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -540,6 +540,15 @@ class ModelManager(object): model_class = MODEL_CLASSES[base_model][model_type] return model_class + def _instantiate(self, model_name: str, base_model: BaseModelType, model_type: ModelType, + submodel_type: SubModelType = None) -> ModelBase: + model_config = self._get_model_config(base_model, model_name, model_type) + model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) + # FIXME: do non-overriden submodels get the right class? + constructor = self._get_implementation(base_model, model_type) + instance = constructor(model_path, base_model, model_type) + return instance + def model_info( self, model_name: str, From ccceb32a859c6a3fa46a4181e0c31555175de4b9 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sat, 29 Jul 2023 11:50:04 -0700 Subject: [PATCH 06/12] lint: formatting --- invokeai/backend/model_management/model_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 954b86283f..c8a1405428 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -228,19 +228,19 @@ the root is the InvokeAI ROOTDIR. """ from __future__ import annotations -import os import hashlib +import os import textwrap -import yaml +import types from dataclasses import dataclass from pathlib import Path -from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types from shutil import rmtree, move +from typing import Optional, List, Tuple, Union, Dict, Set, Callable import torch +import yaml from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig - from pydantic import BaseModel, Field import invokeai.backend.util.logging as logger From ff1c40747e96e01429650c69e3756e2e58a714d8 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sat, 29 Jul 2023 20:02:31 -0700 Subject: [PATCH 07/12] lint: formatting --- .../backend/model_management/model_manager.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index c8a1405428..8bade9cb04 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -258,7 +258,8 @@ from .models import ( ModelConfigBase, ModelNotFoundException, InvalidModelException, - DuplicateModelException, ModelBase, + DuplicateModelException, + ModelBase, ) # We are only starting to number the config file with release 3. @@ -381,13 +382,7 @@ class ModelManager(object): # causing otherwise unreferenced models to be removed from memory self._read_models() - def model_exists( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - rescan = False - ) -> bool: + def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, rescan=False) -> bool: """ Given a model name, returns True if it is a valid identifier. @@ -540,8 +535,9 @@ class ModelManager(object): model_class = MODEL_CLASSES[base_model][model_type] return model_class - def _instantiate(self, model_name: str, base_model: BaseModelType, model_type: ModelType, - submodel_type: SubModelType = None) -> ModelBase: + def _instantiate( + self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel_type: SubModelType = None + ) -> ModelBase: model_config = self._get_model_config(base_model, model_name, model_type) model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) # FIXME: do non-overriden submodels get the right class? From adfd1e52f4e99956533f8dbfd6e9fffde4f3c521 Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sun, 30 Jul 2023 11:53:12 -0700 Subject: [PATCH 08/12] refactor(model_manager): avoid copy/paste logic --- invokeai/backend/model_management/model_manager.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 930ce119fd..ac71e2d2f8 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -382,10 +382,9 @@ class ModelManager(object): # causing otherwise unreferenced models to be removed from memory self._read_models() - def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, rescan=False) -> bool: + def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool: """ - Given a model name, returns True if it is a valid - identifier. + Given a model name, returns True if it is a valid identifier. """ model_key = self.create_key(model_name, base_model, model_type) exists = model_key in self.models @@ -393,7 +392,7 @@ class ModelManager(object): # if model not found try to find it (maybe file just pasted) if rescan and not exists: self.scan_models_directory(base_model=base_model, model_type=model_type) - exists = model_key in self.models + exists = self.model_exists(model_name, base_model, model_type, rescan=False) return exists From bacdf985f1738f52816f88f708c25899279e517f Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Mon, 31 Jul 2023 09:08:46 -0700 Subject: [PATCH 09/12] doc(model_manager): docstrings --- .../backend/model_management/model_manager.py | 24 ++++++++++++++++--- .../backend/model_management/models/vae.py | 16 +++++++------ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index ac71e2d2f8..139e65ac93 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -385,6 +385,11 @@ class ModelManager(object): def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool: """ Given a model name, returns True if it is a valid identifier. + + :param model_name: symbolic name of the model in models.yaml + :param model_type: ModelType enum indicating the type of model to return + :param base_model: BaseModelType enum indicating the base model used by this model + :param rescan: if True, scan_models_directory """ model_key = self.create_key(model_name, base_model, model_type) exists = model_key in self.models @@ -470,7 +475,7 @@ class ModelManager(object): else: self.models.pop(model_key, None) - raise ModelNotFoundException(f"Model not found - {model_key}") + raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}') # TODO: path # TODO: is it accurate to use path as id @@ -508,7 +513,13 @@ class ModelManager(object): _cache=self.cache, ) - def _get_model_path(self, model_config: ModelConfigBase, submodel_type: SubModelType = None) -> (Path, bool): + def _get_model_path( + self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None + ) -> (Path, bool): + """Extract a model's filesystem path from its config. + + :return: The fully qualified Path of the module (or submodule). + """ model_path = model_config.path is_submodel_override = False @@ -523,6 +534,7 @@ class ModelManager(object): return model_path, is_submodel_override def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase: + """Get a model's config object.""" model_key = self.create_key(model_name, base_model, model_type) try: model_config = self.models[model_key] @@ -531,12 +543,18 @@ class ModelManager(object): return model_config def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]: + """Get the concrete implementation class for a specific model type.""" model_class = MODEL_CLASSES[base_model][model_type] return model_class def _instantiate( - self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel_type: SubModelType = None + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel_type: Optional[SubModelType] = None, ) -> ModelBase: + """Make a new instance of this model, without loading it.""" model_config = self._get_model_config(base_model, model_name, model_type) model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) # FIXME: do non-overriden submodels get the right class? diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py index b15844bcf8..957a102ffb 100644 --- a/invokeai/backend/model_management/models/vae.py +++ b/invokeai/backend/model_management/models/vae.py @@ -1,9 +1,14 @@ import os -import torch -import safetensors from enum import Enum from pathlib import Path -from typing import Optional, Union, Literal +from typing import Optional + +import safetensors +import torch +from diffusers.utils import is_safetensors_available +from omegaconf import OmegaConf + +from invokeai.app.services.config import InvokeAIAppConfig from .base import ( ModelBase, ModelConfigBase, @@ -18,9 +23,6 @@ from .base import ( InvalidModelException, ModelNotFoundException, ) -from invokeai.app.services.config import InvokeAIAppConfig -from diffusers.utils import is_safetensors_available -from omegaconf import OmegaConf class VaeModelFormat(str, Enum): @@ -80,7 +82,7 @@ class VaeModel(ModelBase): @classmethod def detect_format(cls, path: str): if not os.path.exists(path): - raise ModelNotFoundException() + raise ModelNotFoundException(f"Does not exist as local file: {path}") if os.path.isdir(path): if os.path.exists(os.path.join(path, "config.json")): From 44bf308192629a6e91a26d38e6814a2a2f6b068e Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sat, 5 Aug 2023 15:22:23 -0700 Subject: [PATCH 10/12] test(model_management): add a couple tests for _get_model_path --- .../backend/model_management/model_manager.py | 2 +- pyproject.toml | 2 +- tests/test_model_manager.py | 36 +++++++++++++++++++ .../configs/relative_sub.models.yaml | 15 ++++++++ .../sdxl/main/SDXL base 1_0/model_index.json | 0 .../sdxl/vae/sdxl-vae-fp16-fix/config.json | 0 6 files changed, 53 insertions(+), 2 deletions(-) create mode 100644 tests/test_model_manager.py create mode 100644 tests/test_model_manager/configs/relative_sub.models.yaml create mode 100644 tests/test_model_manager/models/sdxl/main/SDXL base 1_0/model_index.json create mode 100644 tests/test_model_manager/models/sdxl/vae/sdxl-vae-fp16-fix/config.json diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 3e8888be24..ebe7ffbbd0 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -533,7 +533,7 @@ class ModelManager(object): model_path = self.resolve_model_path(model_path) return model_path, is_submodel_override - def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase: + def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase: """Get a model's config object.""" model_key = self.create_key(model_name, base_model, model_type) try: diff --git a/pyproject.toml b/pyproject.toml index b3f12481a8..2ae297a6da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,7 +100,7 @@ dependencies = [ "dev" = [ "pudb", ] -"test" = ["pytest>6.0.0", "pytest-cov", "black"] +"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"] "xformers" = [ "xformers~=0.0.19; sys_platform!='darwin'", "triton; sys_platform=='linux'", diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py new file mode 100644 index 0000000000..af0394eac2 --- /dev/null +++ b/tests/test_model_manager.py @@ -0,0 +1,36 @@ +from pathlib import Path + +import pytest + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType + + +@pytest.fixture +def model_manager(datadir) -> ModelManager: + InvokeAIAppConfig.get_config(root=datadir) + return ModelManager(datadir / "configs" / "relative_sub.models.yaml") + + +def test_get_model_names(model_manager: ModelManager): + names = model_manager.model_names() + assert names[:2] == [ + ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main), + ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main), + ] + + +def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path): + model_config = model_manager._get_model_config(BaseModelType.StableDiffusionXL, "SDXL base", ModelType.Main) + top_model_path, is_override = model_manager._get_model_path(model_config) + expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0" + assert top_model_path == expected_model_path + assert not is_override + + +def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path): + model_config = model_manager._get_model_config(BaseModelType.StableDiffusionXL, "SDXL with VAE", ModelType.Main) + vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) + expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix" + assert vae_model_path == expected_vae_path + assert is_override diff --git a/tests/test_model_manager/configs/relative_sub.models.yaml b/tests/test_model_manager/configs/relative_sub.models.yaml new file mode 100644 index 0000000000..757c50e3b5 --- /dev/null +++ b/tests/test_model_manager/configs/relative_sub.models.yaml @@ -0,0 +1,15 @@ +__metadata__: + version: 3.0.0 + +sdxl/main/SDXL base: + path: sdxl/main/SDXL base 1_0 + description: SDXL base v1.0 + variant: normal + format: diffusers + +sdxl/main/SDXL with VAE: + path: sdxl/main/SDXL base 1_0 + description: SDXL base v1.0 + vae: sdxl/vae/sdxl-vae-fp16-fix/ + variant: normal + format: diffusers diff --git a/tests/test_model_manager/models/sdxl/main/SDXL base 1_0/model_index.json b/tests/test_model_manager/models/sdxl/main/SDXL base 1_0/model_index.json new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_model_manager/models/sdxl/vae/sdxl-vae-fp16-fix/config.json b/tests/test_model_manager/models/sdxl/vae/sdxl-vae-fp16-fix/config.json new file mode 100644 index 0000000000..e69de29bb2 From 7f4c3870808541746e4f4bc51fdc601b8e19056f Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Sat, 5 Aug 2023 15:46:46 -0700 Subject: [PATCH 11/12] test(model_management): factor out name strings --- tests/test_model_manager.py | 14 ++++++++------ .../configs/relative_sub.models.yaml | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py index af0394eac2..4314bad595 100644 --- a/tests/test_model_manager.py +++ b/tests/test_model_manager.py @@ -5,6 +5,9 @@ import pytest from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType +BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main) +VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main) + @pytest.fixture def model_manager(datadir) -> ModelManager: @@ -14,14 +17,11 @@ def model_manager(datadir) -> ModelManager: def test_get_model_names(model_manager: ModelManager): names = model_manager.model_names() - assert names[:2] == [ - ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main), - ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main), - ] + assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME] def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path): - model_config = model_manager._get_model_config(BaseModelType.StableDiffusionXL, "SDXL base", ModelType.Main) + model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2]) top_model_path, is_override = model_manager._get_model_path(model_config) expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0" assert top_model_path == expected_model_path @@ -29,7 +29,9 @@ def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path): - model_config = model_manager._get_model_config(BaseModelType.StableDiffusionXL, "SDXL with VAE", ModelType.Main) + model_config = model_manager._get_model_config( + VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2] + ) vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae) expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix" assert vae_model_path == expected_vae_path diff --git a/tests/test_model_manager/configs/relative_sub.models.yaml b/tests/test_model_manager/configs/relative_sub.models.yaml index 757c50e3b5..3ec7a3adff 100644 --- a/tests/test_model_manager/configs/relative_sub.models.yaml +++ b/tests/test_model_manager/configs/relative_sub.models.yaml @@ -9,7 +9,7 @@ sdxl/main/SDXL base: sdxl/main/SDXL with VAE: path: sdxl/main/SDXL base 1_0 - description: SDXL base v1.0 + description: SDXL with customized VAE vae: sdxl/vae/sdxl-vae-fp16-fix/ variant: normal format: diffusers From ae17d01e1d2b4f2ef85a5df6c6e4d7ce0f378ca9 Mon Sep 17 00:00:00 2001 From: Jonathan <34005131+JPPhoto@users.noreply.github.com> Date: Sun, 6 Aug 2023 18:23:51 -0500 Subject: [PATCH 12/12] Fix hue adjustment (#4182) * Fix hue adjustment Hue adjustment wasn't working correctly because color channels got swapped. This has now been fixed and we're using PIL rather than cv2 to do the RGBA->HSV->RGBA conversion. The range of hue adjustment is also the more typical 0..360 degrees. --- invokeai/app/invocations/image.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 9f60cb620a..c0aa7ad5ef 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -661,27 +661,23 @@ class ImageHueAdjustmentInvocation(BaseInvocation): # Inputs image: ImageField = Field(default=None, description="The image to adjust") - hue: int = Field(default=0, description="The degrees by which to rotate the hue") + hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360") # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.services.images.get_pil_image(self.image.image_name) - # Convert PIL image to OpenCV format (numpy array), note color channel - # ordering is changed from RGB to BGR - image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1] - # Convert image to HSV color space - hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + hsv_image = numpy.array(pil_image.convert("HSV")) - # Adjust the hue - hsv_image[:, :, 0] = (hsv_image[:, :, 0] + self.hue) % 180 + # Convert hue from 0..360 to 0..256 + hue = int(256 * ((self.hue % 360) / 360)) - # Convert image back to BGR color space - image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR) + # Increment each hue and wrap around at 255 + hsv_image[:, :, 0] = (hsv_image[:, :, 0] + hue) % 256 # Convert back to PIL format and to original color mode - pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA") + pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA") image_dto = context.services.images.create( image=pil_image,