Merge remote-tracking branch 'origin/main' into refactor/remove_unused_pipeline_methods

This commit is contained in:
Kevin Turner 2023-08-07 13:03:10 -07:00
commit 25c669b1d6
8 changed files with 158 additions and 58 deletions

View File

@ -661,27 +661,23 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to adjust") image: ImageField = Field(default=None, description="The image to adjust")
hue: int = Field(default=0, description="The degrees by which to rotate the hue") hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360")
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name) pil_image = context.services.images.get_pil_image(self.image.image_name)
# Convert PIL image to OpenCV format (numpy array), note color channel
# ordering is changed from RGB to BGR
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
# Convert image to HSV color space # Convert image to HSV color space
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) hsv_image = numpy.array(pil_image.convert("HSV"))
# Adjust the hue # Convert hue from 0..360 to 0..256
hsv_image[:, :, 0] = (hsv_image[:, :, 0] + self.hue) % 180 hue = int(256 * ((self.hue % 360) / 360))
# Convert image back to BGR color space # Increment each hue and wrap around at 255
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR) hsv_image[:, :, 0] = (hsv_image[:, :, 0] + hue) % 256
# Convert back to PIL format and to original color mode # Convert back to PIL format and to original color mode
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA") pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA")
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=pil_image, image=pil_image,

View File

@ -228,19 +228,19 @@ the root is the InvokeAI ROOTDIR.
""" """
from __future__ import annotations from __future__ import annotations
import os
import hashlib import hashlib
import os
import textwrap import textwrap
import yaml import types
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Literal, Optional, List, Tuple, Union, Dict, Set, Callable, types
from shutil import rmtree, move from shutil import rmtree, move
from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable
import torch import torch
import yaml
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig from omegaconf.dictconfig import DictConfig
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
@ -259,6 +259,7 @@ from .models import (
ModelNotFoundException, ModelNotFoundException,
InvalidModelException, InvalidModelException,
DuplicateModelException, DuplicateModelException,
ModelBase,
) )
# We are only starting to number the config file with release 3. # We are only starting to number the config file with release 3.
@ -361,7 +362,7 @@ class ModelManager(object):
if model_key.startswith("_"): if model_key.startswith("_"):
continue continue
model_name, base_model, model_type = self.parse_key(model_key) model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type] model_class = self._get_implementation(base_model, model_type)
# alias for config file # alias for config file
model_config["model_format"] = model_config.pop("format") model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config) self.models[model_key] = model_class.create_config(**model_config)
@ -381,18 +382,24 @@ class ModelManager(object):
# causing otherwise unreferenced models to be removed from memory # causing otherwise unreferenced models to be removed from memory
self._read_models() self._read_models()
def model_exists( def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
""" """
Given a model name, returns True if it is a valid Given a model name, returns True if it is a valid identifier.
identifier.
:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param rescan: if True, scan_models_directory
""" """
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
return model_key in self.models exists = model_key in self.models
# if model not found try to find it (maybe file just pasted)
if rescan and not exists:
self.scan_models_directory(base_model=base_model, model_type=model_type)
exists = self.model_exists(model_name, base_model, model_type, rescan=False)
return exists
@classmethod @classmethod
def create_key( def create_key(
@ -443,39 +450,32 @@ class ModelManager(object):
:param model_name: symbolic name of the model in models.yaml :param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return :param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model :param base_model: BaseModelType enum indicating the base model used by this model
:param submode_typel: an ModelType enum indicating the portion of :param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae) the model to retrieve (e.g. ModelType.Vae)
""" """
model_class = MODEL_CLASSES[base_model][model_type]
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
# if model not found try to find it (maybe file just pasted) if not self.model_exists(model_name, base_model, model_type, rescan=True):
if model_key not in self.models:
self.scan_models_directory(base_model=base_model, model_type=model_type)
if model_key not in self.models:
raise ModelNotFoundException(f"Model not found - {model_key}") raise ModelNotFoundException(f"Model not found - {model_key}")
model_config = self.models[model_key] model_config = self._get_model_config(base_model, model_name, model_type)
model_path = self.resolve_model_path(model_config.path)
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
if is_submodel_override:
model_type = submodel_type
submodel_type = None
model_class = self._get_implementation(base_model, model_type)
if not model_path.exists(): if not model_path.exists():
if model_class.save_to_config: if model_class.save_to_config:
self.models[model_key].error = ModelError.NotFound self.models[model_key].error = ModelError.NotFound
raise Exception(f'Files for model "{model_key}" not found') raise Exception(f'Files for model "{model_key}" not found at {model_path}')
else: else:
self.models.pop(model_key, None) self.models.pop(model_key, None)
raise ModelNotFoundException(f"Model not found - {model_key}") raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}')
# vae/movq override
# TODO:
if submodel_type is not None and hasattr(model_config, submodel_type):
override_path = getattr(model_config, submodel_type)
if override_path:
model_path = self.resolve_path(override_path)
model_type = submodel_type
submodel_type = None
model_class = MODEL_CLASSES[base_model][model_type]
# TODO: path # TODO: path
# TODO: is it accurate to use path as id # TODO: is it accurate to use path as id
@ -513,6 +513,55 @@ class ModelManager(object):
_cache=self.cache, _cache=self.cache,
) )
def _get_model_path(
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
) -> (Path, bool):
"""Extract a model's filesystem path from its config.
:return: The fully qualified Path of the module (or submodule).
"""
model_path = model_config.path
is_submodel_override = False
# Does the config explicitly override the submodel?
if submodel_type is not None and hasattr(model_config, submodel_type):
submodel_path = getattr(model_config, submodel_type)
if submodel_path is not None:
model_path = getattr(model_config, submodel_type)
is_submodel_override = True
model_path = self.resolve_model_path(model_path)
return model_path, is_submodel_override
def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
"""Get a model's config object."""
model_key = self.create_key(model_name, base_model, model_type)
try:
model_config = self.models[model_key]
except KeyError:
raise ModelNotFoundException(f"Model not found - {model_key}")
return model_config
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
"""Get the concrete implementation class for a specific model type."""
model_class = MODEL_CLASSES[base_model][model_type]
return model_class
def _instantiate(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
) -> ModelBase:
"""Make a new instance of this model, without loading it."""
model_config = self._get_model_config(base_model, model_name, model_type)
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
# FIXME: do non-overriden submodels get the right class?
constructor = self._get_implementation(base_model, model_type)
instance = constructor(model_path, base_model, model_type)
return instance
def model_info( def model_info(
self, self,
model_name: str, model_name: str,
@ -660,7 +709,7 @@ class ModelManager(object):
if path := model_attributes.get("path"): if path := model_attributes.get("path"):
model_attributes["path"] = str(self.relative_model_path(Path(path))) model_attributes["path"] = str(self.relative_model_path(Path(path)))
model_class = MODEL_CLASSES[base_model][model_type] model_class = self._get_implementation(base_model, model_type)
model_config = model_class.create_config(**model_attributes) model_config = model_class.create_config(**model_attributes)
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
@ -851,7 +900,7 @@ class ModelManager(object):
for model_key, model_config in self.models.items(): for model_key, model_config in self.models.items():
model_name, base_model, model_type = self.parse_key(model_key) model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type] model_class = self._get_implementation(base_model, model_type)
if model_class.save_to_config: if model_class.save_to_config:
# TODO: or exclude_unset better fits here? # TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"}) data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
@ -909,7 +958,7 @@ class ModelManager(object):
model_path = self.resolve_model_path(model_config.path).absolute() model_path = self.resolve_model_path(model_config.path).absolute()
if not model_path.exists(): if not model_path.exists():
model_class = MODEL_CLASSES[cur_base_model][cur_model_type] model_class = self._get_implementation(cur_base_model, cur_model_type)
if model_class.save_to_config: if model_class.save_to_config:
model_config.error = ModelError.NotFound model_config.error = ModelError.NotFound
self.models.pop(model_key, None) self.models.pop(model_key, None)
@ -925,7 +974,7 @@ class ModelManager(object):
for cur_model_type in ModelType: for cur_model_type in ModelType:
if model_type is not None and cur_model_type != model_type: if model_type is not None and cur_model_type != model_type:
continue continue
model_class = MODEL_CLASSES[cur_base_model][cur_model_type] model_class = self._get_implementation(cur_base_model, cur_model_type)
models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value)) models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
if not models_dir.exists(): if not models_dir.exists():

View File

@ -1,9 +1,14 @@
import os import os
import torch
import safetensors
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Optional, Union, Literal from typing import Optional
import safetensors
import torch
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
from invokeai.app.services.config import InvokeAIAppConfig
from .base import ( from .base import (
ModelBase, ModelBase,
ModelConfigBase, ModelConfigBase,
@ -18,9 +23,6 @@ from .base import (
InvalidModelException, InvalidModelException,
ModelNotFoundException, ModelNotFoundException,
) )
from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
class VaeModelFormat(str, Enum): class VaeModelFormat(str, Enum):
@ -80,7 +82,7 @@ class VaeModel(ModelBase):
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):
if not os.path.exists(path): if not os.path.exists(path):
raise ModelNotFoundException() raise ModelNotFoundException(f"Does not exist as local file: {path}")
if os.path.isdir(path): if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")): if os.path.exists(os.path.join(path, "config.json")):

View File

@ -100,7 +100,7 @@ dependencies = [
"dev" = [ "dev" = [
"pudb", "pudb",
] ]
"test" = ["pytest>6.0.0", "pytest-cov", "black"] "test" = ["pytest>6.0.0", "pytest-cov", "pytest-datadir", "black"]
"xformers" = [ "xformers" = [
"xformers~=0.0.19; sys_platform!='darwin'", "xformers~=0.0.19; sys_platform!='darwin'",
"triton; sys_platform=='linux'", "triton; sys_platform=='linux'",

View File

@ -0,0 +1,38 @@
from pathlib import Path
import pytest
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend import ModelManager, BaseModelType, ModelType, SubModelType
BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)
VAE_OVERRIDE_MODEL_NAME = ("SDXL with VAE", BaseModelType.StableDiffusionXL, ModelType.Main)
@pytest.fixture
def model_manager(datadir) -> ModelManager:
InvokeAIAppConfig.get_config(root=datadir)
return ModelManager(datadir / "configs" / "relative_sub.models.yaml")
def test_get_model_names(model_manager: ModelManager):
names = model_manager.model_names()
assert names[:2] == [BASIC_MODEL_NAME, VAE_OVERRIDE_MODEL_NAME]
def test_get_model_path_for_diffusers(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(BASIC_MODEL_NAME[1], BASIC_MODEL_NAME[0], BASIC_MODEL_NAME[2])
top_model_path, is_override = model_manager._get_model_path(model_config)
expected_model_path = datadir / "models" / "sdxl" / "main" / "SDXL base 1_0"
assert top_model_path == expected_model_path
assert not is_override
def test_get_model_path_for_overridden_vae(model_manager: ModelManager, datadir: Path):
model_config = model_manager._get_model_config(
VAE_OVERRIDE_MODEL_NAME[1], VAE_OVERRIDE_MODEL_NAME[0], VAE_OVERRIDE_MODEL_NAME[2]
)
vae_model_path, is_override = model_manager._get_model_path(model_config, SubModelType.Vae)
expected_vae_path = datadir / "models" / "sdxl" / "vae" / "sdxl-vae-fp16-fix"
assert vae_model_path == expected_vae_path
assert is_override

View File

@ -0,0 +1,15 @@
__metadata__:
version: 3.0.0
sdxl/main/SDXL base:
path: sdxl/main/SDXL base 1_0
description: SDXL base v1.0
variant: normal
format: diffusers
sdxl/main/SDXL with VAE:
path: sdxl/main/SDXL base 1_0
description: SDXL with customized VAE
vae: sdxl/vae/sdxl-vae-fp16-fix/
variant: normal
format: diffusers