mirror of https://github.com/invoke-ai/InvokeAI

commit 25c669b1d6
Merge remote-tracking branch 'origin/main' into refactor/remove_unused_pipeline_methods

@@ -661,27 +661,23 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
     # Inputs
     image: ImageField = Field(default=None, description="The image to adjust")
-    hue: int = Field(default=0, description="The degrees by which to rotate the hue")
+    hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360")
     # fmt: on

     def invoke(self, context: InvocationContext) -> ImageOutput:
         pil_image = context.services.images.get_pil_image(self.image.image_name)

-        # Convert PIL image to OpenCV format (numpy array), note color channel
-        # ordering is changed from RGB to BGR
-        image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
-
         # Convert image to HSV color space
-        hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+        hsv_image = numpy.array(pil_image.convert("HSV"))

-        # Adjust the hue
-        hsv_image[:, :, 0] = (hsv_image[:, :, 0] + self.hue) % 180
+        # Convert hue from 0..360 to 0..256
+        hue = int(256 * ((self.hue % 360) / 360))

-        # Convert image back to BGR color space
-        image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
+        # Increment each hue and wrap around at 255
+        hsv_image[:, :, 0] = (hsv_image[:, :, 0] + hue) % 256

         # Convert back to PIL format and to original color mode
-        pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
+        pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA")

         image_dto = context.services.images.create(
             image=pil_image,

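For readers skimming the diff: the hue adjustment now stays entirely in PIL/numpy instead of round-tripping through OpenCV. A minimal standalone sketch of the same arithmetic follows; the rotate_hue helper name is ours, not part of the commit.

import numpy
from PIL import Image

def rotate_hue(pil_image: Image.Image, degrees: int) -> Image.Image:
    # PIL's "HSV" mode stores hue as a byte (0-255), so rescale degrees first
    hsv = numpy.array(pil_image.convert("HSV"))
    hue = int(256 * ((degrees % 360) / 360))
    hsv[:, :, 0] = (hsv[:, :, 0] + hue) % 256  # rotate and wrap at 255
    return Image.fromarray(hsv, mode="HSV").convert("RGBA")
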
@@ -228,19 +228,19 @@ the root is the InvokeAI ROOTDIR.
 """
 from __future__ import annotations

-import os
 import hashlib
+import os
 import textwrap
-import yaml
+import types
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Literal, Optional, List, Tuple, Union, Dict, Set, Callable, types
 from shutil import rmtree, move
+from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable

 import torch
+import yaml
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
-
 from pydantic import BaseModel, Field

 import invokeai.backend.util.logging as logger

@@ -259,6 +259,7 @@ from .models import (
     ModelNotFoundException,
     InvalidModelException,
     DuplicateModelException,
+    ModelBase,
 )

 # We are only starting to number the config file with release 3.

@@ -361,7 +362,7 @@ class ModelManager(object):
         if model_key.startswith("_"):
             continue
         model_name, base_model, model_type = self.parse_key(model_key)
-        model_class = MODEL_CLASSES[base_model][model_type]
+        model_class = self._get_implementation(base_model, model_type)
         # alias for config file
         model_config["model_format"] = model_config.pop("format")
         self.models[model_key] = model_class.create_config(**model_config)

@@ -381,18 +382,24 @@ class ModelManager(object):
         # causing otherwise unreferenced models to be removed from memory
         self._read_models()

-    def model_exists(
-        self,
-        model_name: str,
-        base_model: BaseModelType,
-        model_type: ModelType,
-    ) -> bool:
+    def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
         """
-        Given a model name, returns True if it is a valid
-        identifier.
+        Given a model name, returns True if it is a valid identifier.
+
+        :param model_name: symbolic name of the model in models.yaml
+        :param model_type: ModelType enum indicating the type of model to return
+        :param base_model: BaseModelType enum indicating the base model used by this model
+        :param rescan: if True, scan_models_directory
         """
         model_key = self.create_key(model_name, base_model, model_type)
-        return model_key in self.models
+        exists = model_key in self.models
+
+        # if model not found try to find it (maybe file just pasted)
+        if rescan and not exists:
+            self.scan_models_directory(base_model=base_model, model_type=model_type)
+            exists = self.model_exists(model_name, base_model, model_type, rescan=False)
+
+        return exists

     @classmethod
     def create_key(

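Sketch of the new call pattern, assuming an already constructed ModelManager named mgr and the SDXL model registered in models.yaml (both assumptions of the example): a miss with rescan=True triggers one scan of the models directory, and the recursive lookup passes rescan=False so the scan cannot repeat.

# Illustrative only; `mgr` and the identifiers are assumed for the example.
found = mgr.model_exists(
    "SDXL base",
    BaseModelType.StableDiffusionXL,
    ModelType.Main,
    rescan=True,  # scan the models directory once if the key is not yet registered
)
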
@@ -443,39 +450,32 @@ class ModelManager(object):
         :param model_name: symbolic name of the model in models.yaml
         :param model_type: ModelType enum indicating the type of model to return
         :param base_model: BaseModelType enum indicating the base model used by this model
-        :param submode_typel: an ModelType enum indicating the portion of
+        :param submodel_type: an ModelType enum indicating the portion of
               the model to retrieve (e.g. ModelType.Vae)
         """
-        model_class = MODEL_CLASSES[base_model][model_type]
         model_key = self.create_key(model_name, base_model, model_type)

-        # if model not found try to find it (maybe file just pasted)
-        if model_key not in self.models:
-            self.scan_models_directory(base_model=base_model, model_type=model_type)
-            if model_key not in self.models:
-                raise ModelNotFoundException(f"Model not found - {model_key}")
+        if not self.model_exists(model_name, base_model, model_type, rescan=True):
+            raise ModelNotFoundException(f"Model not found - {model_key}")

-        model_config = self.models[model_key]
-        model_path = self.resolve_model_path(model_config.path)
+        model_config = self._get_model_config(base_model, model_name, model_type)
+        model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
+
+        if is_submodel_override:
+            model_type = submodel_type
+            submodel_type = None
+
+        model_class = self._get_implementation(base_model, model_type)

         if not model_path.exists():
             if model_class.save_to_config:
                 self.models[model_key].error = ModelError.NotFound
-                raise Exception(f'Files for model "{model_key}" not found')
+                raise Exception(f'Files for model "{model_key}" not found at {model_path}')

             else:
                 self.models.pop(model_key, None)
-                raise ModelNotFoundException(f"Model not found - {model_key}")
-
-        # vae/movq override
-        # TODO:
-        if submodel_type is not None and hasattr(model_config, submodel_type):
-            override_path = getattr(model_config, submodel_type)
-            if override_path:
-                model_path = self.resolve_path(override_path)
-                model_type = submodel_type
-                submodel_type = None
-                model_class = MODEL_CLASSES[base_model][model_type]
+                raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}')

         # TODO: path
         # TODO: is it accurate to use path as id

@@ -513,6 +513,55 @@ class ModelManager(object):
             _cache=self.cache,
         )
+
+    def _get_model_path(
+        self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
+    ) -> (Path, bool):
+        """Extract a model's filesystem path from its config.
+
+        :return: The fully qualified Path of the module (or submodule).
+        """
+        model_path = model_config.path
+        is_submodel_override = False
+
+        # Does the config explicitly override the submodel?
+        if submodel_type is not None and hasattr(model_config, submodel_type):
+            submodel_path = getattr(model_config, submodel_type)
+            if submodel_path is not None:
+                model_path = getattr(model_config, submodel_type)
+                is_submodel_override = True
+
+        model_path = self.resolve_model_path(model_path)
+        return model_path, is_submodel_override
+
+    def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
+        """Get a model's config object."""
+        model_key = self.create_key(model_name, base_model, model_type)
+        try:
+            model_config = self.models[model_key]
+        except KeyError:
+            raise ModelNotFoundException(f"Model not found - {model_key}")
+        return model_config
+
+    def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
+        """Get the concrete implementation class for a specific model type."""
+        model_class = MODEL_CLASSES[base_model][model_type]
+        return model_class
+
+    def _instantiate(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> ModelBase:
+        """Make a new instance of this model, without loading it."""
+        model_config = self._get_model_config(base_model, model_name, model_type)
+        model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
+        # FIXME: do non-overriden submodels get the right class?
+        constructor = self._get_implementation(base_model, model_type)
+        instance = constructor(model_path, base_model, model_type)
+        return instance

     def model_info(
         self,
         model_name: str,

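For orientation, a sketch of how the three new private helpers compose; it mirrors what tests/test_model_manager.py (added below) exercises, with mgr standing in for a ModelManager instance (an assumption of the example).

# Sketch only, not part of the diff.
config = mgr._get_model_config(BaseModelType.StableDiffusionXL, "SDXL with VAE", ModelType.Main)
path, is_override = mgr._get_model_path(config, SubModelType.Vae)  # honours the vae: override
impl = mgr._get_implementation(BaseModelType.StableDiffusionXL, ModelType.Main)
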
@@ -660,7 +709,7 @@ class ModelManager(object):
         if path := model_attributes.get("path"):
             model_attributes["path"] = str(self.relative_model_path(Path(path)))

-        model_class = MODEL_CLASSES[base_model][model_type]
+        model_class = self._get_implementation(base_model, model_type)
         model_config = model_class.create_config(**model_attributes)
         model_key = self.create_key(model_name, base_model, model_type)

@@ -851,7 +900,7 @@ class ModelManager(object):

         for model_key, model_config in self.models.items():
             model_name, base_model, model_type = self.parse_key(model_key)
-            model_class = MODEL_CLASSES[base_model][model_type]
+            model_class = self._get_implementation(base_model, model_type)
             if model_class.save_to_config:
                 # TODO: or exclude_unset better fits here?
                 data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})

@@ -909,7 +958,7 @@ class ModelManager(object):

         model_path = self.resolve_model_path(model_config.path).absolute()
         if not model_path.exists():
-            model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
+            model_class = self._get_implementation(cur_base_model, cur_model_type)
             if model_class.save_to_config:
                 model_config.error = ModelError.NotFound
                 self.models.pop(model_key, None)

@@ -925,7 +974,7 @@ class ModelManager(object):
         for cur_model_type in ModelType:
             if model_type is not None and cur_model_type != model_type:
                 continue
-            model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
+            model_class = self._get_implementation(cur_base_model, cur_model_type)
             models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))

             if not models_dir.exists():

@@ -1,9 +1,14 @@
 import os
-import torch
-import safetensors
 from enum import Enum
 from pathlib import Path
-from typing import Optional, Union, Literal
+from typing import Optional
+
+import safetensors
+import torch
+from diffusers.utils import is_safetensors_available
+from omegaconf import OmegaConf
+
+from invokeai.app.services.config import InvokeAIAppConfig
 from .base import (
     ModelBase,
     ModelConfigBase,

@@ -18,9 +23,6 @@ from .base import (
     InvalidModelException,
     ModelNotFoundException,
 )
-from invokeai.app.services.config import InvokeAIAppConfig
-from diffusers.utils import is_safetensors_available
-from omegaconf import OmegaConf


 class VaeModelFormat(str, Enum):

@@ -80,7 +82,7 @@ class VaeModel(ModelBase):
     @classmethod
     def detect_format(cls, path: str):
         if not os.path.exists(path):
-            raise ModelNotFoundException()
+            raise ModelNotFoundException(f"Does not exist as local file: {path}")

         if os.path.isdir(path):
             if os.path.exists(os.path.join(path, "config.json")):

@@ -100,7 +100,7 @@ dependencies = [
 "dev" = [
   "pudb",
 ]
-"test" = ["pytest>6.0.0", "pytest-cov", "black"]
+"test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"]
 "xformers" = [
   "xformers~=0.0.19; sys_platform!='darwin'",
   "triton; sys_platform=='linux'",

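The new pytest-datadir test dependency supplies the datadir fixture used by the test module added below; as we understand the plugin, datadir is a per-test temporary copy of the directory named after the test module (here tests/test_model_manager/). A hedged sketch of that assumption:

# Assumed fixture behaviour, not taken from the diff.
def test_fixture_files_are_available(datadir):
    assert (datadir / "configs" / "relative_sub.models.yaml").exists()
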
tests/test_model_manager.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+from pathlib import Path
+
+import pytest
+
+from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType
+
+BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
+VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
+
+
+@pytest.fixture
+def model_manager(datadir) -> ModelManager:
+    InvokeAIAppConfig.get_config(root=datadir)
+    return ModelManager(datadir / "configs" / "relative_sub.models.yaml")
+
+
+def test_get_model_names(model_manager: ModelManager):
+    names = model_manager.model_names()
+    assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME]
+
+
+def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path):
+    model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2])
+    top_model_path, is_override = model_manager._get_model_path(model_config)
+    expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0"
+    assert top_model_path == expected_model_path
+    assert not is_override
+
+
+def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path):
+    model_config = model_manager._get_model_config(
+        VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2]
+    )
+    vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
+    expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
+    assert vae_model_path == expected_vae_path
+    assert is_override

tests/test_model_manager/configs/relative_sub.models.yaml (new file, 15 lines)
@@ -0,0 +1,15 @@
+__metadata__:
+  version: 3.0.0
+
+sdxl/main/SDXL base:
+  path: sdxl/main/SDXL base 1_0
+  description: SDXL base v1.0
+  variant: normal
+  format: diffusers
+
+sdxl/main/SDXL with VAE:
+  path: sdxl/main/SDXL base 1_0
+  description: SDXL with customized VAE
+  vae: sdxl/vae/sdxl-vae-fp16-fix/
+  variant: normal
+  format: diffusers

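The paths in this fixture are deliberately relative ("relative_sub"), and the tests above expect them to resolve against the models directory under the configured root. A sketch of that resolution; resolve_relative_entry and models_dir are illustrative names, not the manager's actual API.

from pathlib import Path

def resolve_relative_entry(models_dir: Path, entry: str) -> Path:
    # Relative models.yaml entries are anchored at the models directory;
    # absolute entries would pass through unchanged.
    p = Path(entry)
    return p if p.is_absolute() else models_dir / p

resolve_relative_entry(Path("/invokeai-root/models"), "sdxl/vae/sdxl-vae-fp16-fix/")
# -> /invokeai-root/models/sdxl/vae/sdxl-vae-fp16-fix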