Merge remote-tracking branch 'origin/main' into refactor/model_manager_instantiate

# Conflicts:
#	invokeai/backend/model_management/model_manager.py
This commit is contained in:
Kevin Turner 2023-08-05 22:02:28 -07:00
commit 5bfd6cb66f
6 changed files with 185 additions and 32 deletions

View File

@ -55,7 +55,7 @@ logger = InvokeAILogger.getLogger()
class ApiDependencies: class ApiDependencies:
"""Contains and initializes all dependencies for the API""" """Contains and initializes all dependencies for the API"""
invoker: Optional[Invoker] = None invoker: Invoker
@staticmethod @staticmethod
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger): def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
@ -68,8 +68,9 @@ class ApiDependencies:
output_folder = config.output_path output_folder = config.output_path
# TODO: build a file/path manager? # TODO: build a file/path manager?
db_location = config.db_path db_path = config.db_path
db_location.parent.mkdir(parents=True, exist_ok=True) db_path.parent.mkdir(parents=True, exist_ok=True)
db_location = str(db_path)
graph_execution_manager = SqliteItemStorage[GraphExecutionState]( graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions" filename=db_location, table_name="graph_executions"

View File

@ -3,6 +3,7 @@
from typing import Literal, Optional from typing import Literal, Optional
import numpy import numpy
import cv2
from PIL import Image, ImageFilter, ImageOps, ImageChops from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import Field from pydantic import Field
from pathlib import Path from pathlib import Path
@ -650,3 +651,147 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
class ImageHueAdjustmentInvocation(BaseInvocation):
"""Adjusts the Hue of an image."""
# fmt: off
type: Literal["img_hue_adjust"] = "img_hue_adjust"
# Inputs
image: ImageField = Field(default=None, description="The image to adjust")
hue: int = Field(default=0, description="The degrees by which to rotate the hue")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
# Convert PIL image to OpenCV format (numpy array), note color channel
# ordering is changed from RGB to BGR
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
# Convert image to HSV color space
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
# Adjust the hue
hsv_image[:, :, 0] = (hsv_image[:, :, 0] + self.hue) % 180
# Convert image back to BGR color space
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
# Convert back to PIL format and to original color mode
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
image_dto = context.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
),
width=image_dto.width,
height=image_dto.height,
)
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
"""Adjusts the Luminosity (Value) of an image."""
# fmt: off
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
# Inputs
image: ImageField = Field(default=None, description="The image to adjust")
luminosity: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
# Convert PIL image to OpenCV format (numpy array), note color channel
# ordering is changed from RGB to BGR
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
# Convert image to HSV color space
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
# Adjust the luminosity (value)
hsv_image[:, :, 2] = numpy.clip(hsv_image[:, :, 2] * self.luminosity, 0, 255)
# Convert image back to BGR color space
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
# Convert back to PIL format and to original color mode
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
image_dto = context.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
),
width=image_dto.width,
height=image_dto.height,
)
class ImageSaturationAdjustmentInvocation(BaseInvocation):
"""Adjusts the Saturation of an image."""
# fmt: off
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
# Inputs
image: ImageField = Field(default=None, description="The image to adjust")
saturation: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
# Convert PIL image to OpenCV format (numpy array), note color channel
# ordering is changed from RGB to BGR
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
# Convert image to HSV color space
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
# Adjust the saturation
hsv_image[:, :, 1] = numpy.clip(hsv_image[:, :, 1] * self.saturation, 0, 255)
# Convert image back to BGR color space
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
# Convert back to PIL format and to original color mode
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
image_dto = context.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -3,9 +3,10 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path from pathlib import Path
from pydantic import Field from pydantic import Field
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING from typing import Literal, Optional, Union, Callable, List, Tuple, TYPE_CHECKING
from types import ModuleType from types import ModuleType
from invokeai.backend.model_management import ( from invokeai.backend.model_management import (
@ -193,7 +194,7 @@ class ModelManagerServiceBase(ABC):
self, self,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: Union[ModelType.Main, ModelType.Vae], model_type: Literal[ModelType.Main, ModelType.Vae],
) -> AddModelResult: ) -> AddModelResult:
""" """
Convert a checkpoint file into a diffusers folder, deleting the cached Convert a checkpoint file into a diffusers folder, deleting the cached
@ -292,7 +293,7 @@ class ModelManagerService(ModelManagerServiceBase):
def __init__( def __init__(
self, self,
config: InvokeAIAppConfig, config: InvokeAIAppConfig,
logger: ModuleType, logger: Logger,
): ):
""" """
Initialize with the path to the models.yaml config file. Initialize with the path to the models.yaml config file.
@ -396,7 +397,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_type, model_type,
) )
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
""" """
Given a model name returns a dict-like (OmegaConf) object describing it. Given a model name returns a dict-like (OmegaConf) object describing it.
""" """
@ -416,7 +417,7 @@ class ModelManagerService(ModelManagerServiceBase):
""" """
return self.mgr.list_models(base_model, model_type) return self.mgr.list_models(base_model, model_type)
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
""" """
Return information about the model using the same format as list_models() Return information about the model using the same format as list_models()
""" """
@ -429,7 +430,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_type: ModelType, model_type: ModelType,
model_attributes: dict, model_attributes: dict,
clobber: bool = False, clobber: bool = False,
) -> None: ) -> AddModelResult:
""" """
Update the named model with a dictionary of attributes. Will fail with an Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite. assertion error if the name already exists. Pass clobber=True to overwrite.
@ -478,7 +479,7 @@ class ModelManagerService(ModelManagerServiceBase):
self, self,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: Union[ModelType.Main, ModelType.Vae], model_type: Literal[ModelType.Main, ModelType.Vae],
convert_dest_directory: Optional[Path] = Field( convert_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model" default=None, description="Optional directory location for merged model"
), ),
@ -573,9 +574,9 @@ class ModelManagerService(ModelManagerServiceBase):
default=None, description="Base model shared by all models to be merged" default=None, description="Base model shared by all models to be merged"
), ),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"), merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5, alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None, interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False, force: bool = False,
merge_dest_directory: Optional[Path] = Field( merge_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model" default=None, description="Optional directory location for merged model"
), ),
@ -633,8 +634,8 @@ class ModelManagerService(ModelManagerServiceBase):
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
new_name: str = None, new_name: Optional[str] = None,
new_base: BaseModelType = None, new_base: Optional[BaseModelType] = None,
): ):
""" """
Rename the indicated model. Can provide a new name and/or a new base. Rename the indicated model. Can provide a new name and/or a new base.

View File

@ -101,9 +101,9 @@ class ModelInstall(object):
def __init__( def __init__(
self, self,
config: InvokeAIAppConfig, config: InvokeAIAppConfig,
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None, prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
model_manager: ModelManager = None, model_manager: Optional[ModelManager] = None,
access_token: str = None, access_token: Optional[str] = None,
): ):
self.config = config self.config = config
self.mgr = model_manager or ModelManager(config.model_conf_path) self.mgr = model_manager or ModelManager(config.model_conf_path)

View File

@ -235,7 +235,7 @@ import types
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from shutil import rmtree, move from shutil import rmtree, move
from typing import Optional, List, Tuple, Union, Dict, Set, Callable from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable
import torch import torch
import yaml import yaml
@ -567,7 +567,7 @@ class ModelManager(object):
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
) -> dict: ) -> Union[dict, None]:
""" """
Given a model name returns the OmegaConf (dict-like) object describing it. Given a model name returns the OmegaConf (dict-like) object describing it.
""" """
@ -589,13 +589,15 @@ class ModelManager(object):
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
) -> dict: ) -> Union[dict, None]:
""" """
Returns a dict describing one installed model, using Returns a dict describing one installed model, using
the combined format of the list_models() method. the combined format of the list_models() method.
""" """
models = self.list_models(base_model, model_type, model_name) models = self.list_models(base_model, model_type, model_name)
return models[0] if models else None if len(models) > 1:
return models[0]
return None
def list_models( def list_models(
self, self,
@ -609,7 +611,7 @@ class ModelManager(object):
model_keys = ( model_keys = (
[self.create_key(model_name, base_model, model_type)] [self.create_key(model_name, base_model, model_type)]
if model_name if model_name and base_model and model_type
else sorted(self.models, key=str.casefold) else sorted(self.models, key=str.casefold)
) )
models = [] models = []
@ -645,7 +647,7 @@ class ModelManager(object):
Print a table of models and their descriptions. This needs to be redone Print a table of models and their descriptions. This needs to be redone
""" """
# TODO: redo # TODO: redo
for model_type, model_dict in self.list_models().items(): for model_dict in self.list_models():
for model_name, model_info in model_dict.items(): for model_name, model_info in model_dict.items():
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}' line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
print(line) print(line)
@ -748,8 +750,8 @@ class ModelManager(object):
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
new_name: str = None, new_name: Optional[str] = None,
new_base: BaseModelType = None, new_base: Optional[BaseModelType] = None,
): ):
""" """
Rename or rebase a model. Rename or rebase a model.
@ -802,7 +804,7 @@ class ModelManager(object):
self, self,
model_name: str, model_name: str,
base_model: BaseModelType, base_model: BaseModelType,
model_type: Union[ModelType.Main, ModelType.Vae], model_type: Literal[ModelType.Main, ModelType.Vae],
dest_directory: Optional[Path] = None, dest_directory: Optional[Path] = None,
) -> AddModelResult: ) -> AddModelResult:
""" """
@ -816,6 +818,10 @@ class ModelManager(object):
This will raise a ValueError unless the model is a checkpoint. This will raise a ValueError unless the model is a checkpoint.
""" """
info = self.model_info(model_name, base_model, model_type) info = self.model_info(model_name, base_model, model_type)
if info is None:
raise FileNotFoundError(f"model not found: {model_name}")
if info["model_format"] != "checkpoint": if info["model_format"] != "checkpoint":
raise ValueError(f"not a checkpoint format model: {model_name}") raise ValueError(f"not a checkpoint format model: {model_name}")
@ -885,7 +891,7 @@ class ModelManager(object):
return search_folder, found_models return search_folder, found_models
def commit(self, conf_file: Path = None) -> None: def commit(self, conf_file: Optional[Path] = None) -> None:
""" """
Write current configuration out to the indicated file. Write current configuration out to the indicated file.
""" """
@ -1032,7 +1038,7 @@ class ModelManager(object):
# LS: hacky # LS: hacky
# Patch in the SD VAE from core so that it is available for use by the UI # Patch in the SD VAE from core so that it is available for use by the UI
try: try:
self.heuristic_import({self.resolve_model_path("core/convert/sd-vae-ft-mse")}) self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))})
except: except:
pass pass
@ -1060,7 +1066,7 @@ class ModelManager(object):
def heuristic_import( def heuristic_import(
self, self,
items_to_import: Set[str], items_to_import: Set[str],
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None, prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> Dict[str, AddModelResult]: ) -> Dict[str, AddModelResult]:
"""Import a list of paths, repo_ids or URLs. Returns the set of """Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items. successfully imported items.

View File

@ -33,7 +33,7 @@ class ModelMerger(object):
self, self,
model_paths: List[Path], model_paths: List[Path],
alpha: float = 0.5, alpha: float = 0.5,
interp: MergeInterpolationMethod = None, interp: Optional[MergeInterpolationMethod] = None,
force: bool = False, force: bool = False,
**kwargs, **kwargs,
) -> DiffusionPipeline: ) -> DiffusionPipeline:
@ -73,7 +73,7 @@ class ModelMerger(object):
base_model: Union[BaseModelType, str], base_model: Union[BaseModelType, str],
merged_model_name: str, merged_model_name: str,
alpha: float = 0.5, alpha: float = 0.5,
interp: MergeInterpolationMethod = None, interp: Optional[MergeInterpolationMethod] = None,
force: bool = False, force: bool = False,
merge_dest_directory: Optional[Path] = None, merge_dest_directory: Optional[Path] = None,
**kwargs, **kwargs,
@ -122,7 +122,7 @@ class ModelMerger(object):
dump_path.mkdir(parents=True, exist_ok=True) dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained(dump_path, safe_serialization=1) merged_pipe.save_pretrained(dump_path, safe_serialization=True)
attributes = dict( attributes = dict(
path=str(dump_path), path=str(dump_path),
description=f"Merge of models {', '.join(model_names)}", description=f"Merge of models {', '.join(model_names)}",