Merge remote-tracking branch 'origin/main' into feat/db/migrations

This commit is contained in:
psychedelicious
2023-12-12 17:22:52 +11:00
33 changed files with 1933 additions and 195 deletions

View File

@ -272,6 +272,8 @@ def invoke_api() -> None:
port=port,
loop="asyncio",
log_level=app_config.log_level,
ssl_certfile=app_config.ssl_certfile,
ssl_keyfile=app_config.ssl_keyfile,
)
server = uvicorn.Server(config)

View File

@ -39,6 +39,19 @@ class InvalidFieldError(TypeError):
pass
class Classification(str, Enum, metaclass=MetaEnum):
"""
The classification of an Invocation.
- `Stable`: The invocation, including its inputs/outputs and internal logic, is stable. You may build workflows with it, having confidence that they will not break because of a change in this invocation.
- `Beta`: The invocation is not yet stable, but is planned to be stable in the future. Workflows built around this invocation may break, but we are committed to supporting this invocation long-term.
- `Prototype`: The invocation is not yet stable and may be removed from the application at any time. Workflows built around this invocation may break, and we are *not* committed to supporting this invocation.
"""
Stable = "stable"
Beta = "beta"
Prototype = "prototype"
class Input(str, Enum, metaclass=MetaEnum):
"""
The type of input a field accepts.
@ -439,6 +452,7 @@ class UIConfigBase(BaseModel):
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
)
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
model_config = ConfigDict(
validate_assignment=True,
@ -607,6 +621,7 @@ class BaseInvocation(ABC, BaseModel):
schema["category"] = uiconfig.category
if uiconfig.node_pack is not None:
schema["node_pack"] = uiconfig.node_pack
schema["classification"] = uiconfig.classification
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
@ -782,6 +797,7 @@ def invocation(
category: Optional[str] = None,
version: Optional[str] = None,
use_cache: Optional[bool] = True,
classification: Classification = Classification.Stable,
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
"""
Registers an invocation.
@ -792,6 +808,7 @@ def invocation(
:param Optional[str] category: Adds a category to the invocation. Used to group the invocations in the UI. Defaults to None.
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
:param Classification classification: The classification of the invocation. Defaults to FeatureClassification.Stable. Use Beta or Prototype if the invocation is unstable.
"""
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
@ -812,6 +829,7 @@ def invocation(
cls.UIConfig.title = title
cls.UIConfig.tags = tags
cls.UIConfig.category = category
cls.UIConfig.classification = classification
# Grab the node pack's name from the module name, if it's a custom node
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"

View File

@ -1,3 +1,5 @@
from typing import Literal
import numpy as np
from PIL import Image
from pydantic import BaseModel
@ -5,6 +7,8 @@ from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
Input,
InputField,
InvocationContext,
OutputField,
@ -14,7 +18,13 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.tiles import (
calc_tiles_even_split,
calc_tiles_min_overlap,
calc_tiles_with_overlap,
merge_tiles_with_linear_blending,
merge_tiles_with_seam_blending,
)
from invokeai.backend.tiles.utils import Tile
@ -55,6 +65,79 @@ class CalculateImageTilesInvocation(BaseInvocation):
return CalculateImageTilesOutput(tiles=tiles)
@invocation(
"calculate_image_tiles_even_split",
title="Calculate Image Tiles Even Split",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
image_height: int = InputField(
ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
)
num_tiles_x: int = InputField(
default=2,
ge=1,
description="Number of tiles to divide image into on the x axis",
)
num_tiles_y: int = InputField(
default=2,
ge=1,
description="Number of tiles to divide image into on the y axis",
)
overlap_fraction: float = InputField(
default=0.25,
ge=0,
lt=1,
description="Overlap between adjacent tiles as a fraction of the tile's dimensions (0-1)",
)
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
tiles = calc_tiles_even_split(
image_height=self.image_height,
image_width=self.image_width,
num_tiles_x=self.num_tiles_x,
num_tiles_y=self.num_tiles_y,
overlap_fraction=self.overlap_fraction,
)
return CalculateImageTilesOutput(tiles=tiles)
@invocation(
"calculate_image_tiles_min_overlap",
title="Calculate Image Tiles Minimum Overlap",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
image_height: int = InputField(
ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
)
tile_width: int = InputField(ge=1, default=576, description="The tile width, in pixels.")
tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.")
min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.")
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
tiles = calc_tiles_min_overlap(
image_height=self.image_height,
image_width=self.image_width,
tile_height=self.tile_height,
tile_width=self.tile_width,
min_overlap=self.min_overlap,
)
return CalculateImageTilesOutput(tiles=tiles)
@invocation_output("tile_to_properties_output")
class TileToPropertiesOutput(BaseInvocationOutput):
coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.")
@ -76,7 +159,14 @@ class TileToPropertiesOutput(BaseInvocationOutput):
overlap_right: int = OutputField(description="Overlap between this tile and its right neighbor.")
@invocation("tile_to_properties", title="Tile to Properties", tags=["tiles"], category="tiles", version="1.0.0")
@invocation(
"tile_to_properties",
title="Tile to Properties",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
class TileToPropertiesInvocation(BaseInvocation):
"""Split a Tile into its individual properties."""
@ -102,7 +192,14 @@ class PairTileImageOutput(BaseInvocationOutput):
tile_with_image: TileWithImage = OutputField(description="A tile description with its corresponding image.")
@invocation("pair_tile_image", title="Pair Tile with Image", tags=["tiles"], category="tiles", version="1.0.0")
@invocation(
"pair_tile_image",
title="Pair Tile with Image",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
class PairTileImageInvocation(BaseInvocation):
"""Pair an image with its tile properties."""
@ -121,13 +218,29 @@ class PairTileImageInvocation(BaseInvocation):
)
@invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.1.0")
BLEND_MODES = Literal["Linear", "Seam"]
@invocation(
"merge_tiles_to_image",
title="Merge Tiles to Image",
tags=["tiles"],
category="tiles",
version="1.1.0",
classification=Classification.Beta,
)
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
"""Merge multiple tile images into a single image."""
# Inputs
tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.")
blend_mode: BLEND_MODES = InputField(
default="Seam",
description="blending type Linear or Seam",
input=Input.Direct,
)
blend_amount: int = InputField(
default=32,
ge=0,
description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.",
)
@ -157,10 +270,18 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
channels = tile_np_images[0].shape[-1]
dtype = tile_np_images[0].dtype
np_image = np.zeros(shape=(height, width, channels), dtype=dtype)
if self.blend_mode == "Linear":
merge_tiles_with_linear_blending(
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
)
elif self.blend_mode == "Seam":
merge_tiles_with_seam_blending(
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
)
else:
raise ValueError(f"Unsupported blend mode: '{self.blend_mode}'.")
merge_tiles_with_linear_blending(
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
)
# Convert into a PIL image and save
pil_image = Image.fromarray(np_image)
image_dto = context.services.images.create(

View File

@ -221,6 +221,9 @@ class InvokeAIAppConfig(InvokeAISettings):
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
# SSL options correspond to https://www.uvicorn.org/settings/#https
ssl_certfile : Optional[Path] = Field(default=None, description="SSL certificate file (for HTTPS)", json_schema_extra=Categories.WebServer)
ssl_keyfile : Optional[Path] = Field(default=None, description="SSL key file", json_schema_extra=Categories.WebServer)
# FEATURES
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)

View File

@ -85,7 +85,7 @@ class ModelInstallService(ModelInstallServiceBase):
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus
def stop(self) -> None:
def stop(self, *args, **kwargs) -> None:
"""Stop the install thread; after this the object can be deleted and garbage collected."""
self._install_queue.put(STOP_JOB)

View File

@ -95,21 +95,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""--sql
INSERT INTO model_config (
id,
base,
type,
name,
path,
original_hash,
config
)
VALUES (?,?,?,?,?,?,?);
VALUES (?,?,?);
""",
(
key,
record.base,
record.type,
record.name,
record.path,
record.original_hash,
json_serialized,
),
@ -173,14 +165,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._cursor.execute(
"""--sql
UPDATE model_config
SET base=?,
type=?,
name=?,
path=?,
SET
config=?
WHERE id=?;
""",
(record.base, record.type, record.name, record.path, json_serialized, key),
(json_serialized, key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
@ -278,7 +267,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE model_path=?;
WHERE path=?;
""",
(str(path),),
)

View File

@ -22,6 +22,7 @@ def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
_drop_old_workflow_tables(cursor)
_add_workflow_library(cursor)
_drop_model_manager_metadata(cursor)
_recreate_model_config(cursor)
_migrate_embedded_workflows(cursor, logger, image_files)
@ -101,6 +102,41 @@ def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _recreate_model_config(cursor: sqlite3.Cursor) -> None:
"""
Drops the `model_config` table, recreating it.
In 3.4.0, this table used explicit columns but was changed to use json_extract 3.5.0.
Because this table is not used in production, we are able to simply drop it and recreate it.
"""
cursor.execute("DROP TABLE IF EXISTS model_config;")
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
def _migrate_embedded_workflows(
cursor: sqlite3.Cursor,
logger: Logger,