Merge remote-tracking branch 'origin/main' into feat/db/migrations

This commit is contained in:
psychedelicious 2023-12-12 17:22:52 +11:00
commit 2cdda1fda2
33 changed files with 1933 additions and 195 deletions

View File

@ -1,6 +1,20 @@
# simple Makefile with scripts that are otherwise hard to remember
# to use, run from the repo root `make <command>`
default: help
help:
@echo Developer commands:
@echo
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
@echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
# Runs ruff, fixing any safely-fixable errors and formatting
ruff:
ruff check . --fix
@ -19,3 +33,20 @@ mypy:
# (many files are ignored by the config, so this is useful for checking all files)
mypy-all:
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
# Build the frontend
frontend-build:
cd invokeai/frontend/web && pnpm build
# Run the frontend in dev mode
frontend-dev:
cd invokeai/frontend/web && pnpm dev
# Installer zip file
installer-zip:
cd installer && ./create_installer.sh
# Tag the release
tag-release:
cd installer && ./tag_release.sh

View File

@ -155,13 +155,15 @@ groups in `invokeia.yaml`:
### Web Server
| Setting | Default Value | Description |
|----------|----------------|--------------|
|---------------------|---------------|----------------------------------------------------------------------------------------------------------------------------|
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
| `port` | `9090` | Network port number that the web server will listen on |
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
| `ssl_certfile` | null | Path to an SSL certificate file, used to enable HTTPS. |
| `ssl_keyfile` | null | Path to an SSL keyfile, if the key is not included in the certificate file. |
The documentation for InvokeAI's API can be accessed by browsing to the following URL: [http://localhost:9090/docs].

View File

@ -13,14 +13,6 @@ function is_bin_in_path {
builtin type -P "$1" &>/dev/null
}
function does_tag_exist {
git rev-parse --quiet --verify "refs/tags/$1" >/dev/null
}
function git_show_ref {
git show-ref --dereference $1 --abbrev 7
}
function git_show {
git show -s --format='%h %s' $1
}
@ -53,50 +45,11 @@ VERSION=$(
)
PATCH=""
VERSION="v${VERSION}${PATCH}"
LATEST_TAG="v3-latest"
echo "Building installer for version $VERSION..."
echo
if does_tag_exist $VERSION; then
echo -e "${BCYAN}${VERSION}${RESET} already exists:"
git_show_ref tags/$VERSION
echo
fi
if does_tag_exist $LATEST_TAG; then
echo -e "${BCYAN}${LATEST_TAG}${RESET} already exists:"
git_show_ref tags/$LATEST_TAG
echo
fi
echo -e "${BGREEN}HEAD${RESET}:"
git_show
echo
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} and ${BCYAN}${LATEST_TAG}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on remote${RESET}? "
read -e -p 'y/n [n]: ' input
RESPONSE=${input:='n'}
if [ "$RESPONSE" == 'y' ]; then
echo
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on remote..."
git push origin :refs/tags/$VERSION
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} locally..."
if ! git tag -fa $VERSION; then
echo "Existing/invalid tag"
exit -1
fi
echo -e "Deleting ${BCYAN}${LATEST_TAG}${RESET} tag on remote..."
git push origin :refs/tags/$LATEST_TAG
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${LATEST_TAG}${RESET} locally..."
git tag -fa $LATEST_TAG
echo
echo -e "${BYELLOW}Remember to 'git push origin --tags'!${RESET}"
fi
# ---------------------- FRONTEND ----------------------
pushd ../invokeai/frontend/web >/dev/null

71
installer/tag_release.sh Executable file
View File

@ -0,0 +1,71 @@
#!/bin/bash
set -e
BCYAN="\e[1;36m"
BYELLOW="\e[1;33m"
BGREEN="\e[1;32m"
BRED="\e[1;31m"
RED="\e[31m"
RESET="\e[0m"
function does_tag_exist {
git rev-parse --quiet --verify "refs/tags/$1" >/dev/null
}
function git_show_ref {
git show-ref --dereference $1 --abbrev 7
}
function git_show {
git show -s --format='%h %s' $1
}
VERSION=$(
cd ..
python -c "from invokeai.version import __version__ as version; print(version)"
)
PATCH=""
MAJOR_VERSION=$(echo $VERSION | sed 's/\..*$//')
VERSION="v${VERSION}${PATCH}"
LATEST_TAG="v${MAJOR_VERSION}-latest"
if does_tag_exist $VERSION; then
echo -e "${BCYAN}${VERSION}${RESET} already exists:"
git_show_ref tags/$VERSION
echo
fi
if does_tag_exist $LATEST_TAG; then
echo -e "${BCYAN}${LATEST_TAG}${RESET} already exists:"
git_show_ref tags/$LATEST_TAG
echo
fi
echo -e "${BGREEN}HEAD${RESET}:"
git_show
echo
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} and ${BCYAN}${LATEST_TAG}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on remote${RESET}? "
read -e -p 'y/n [n]: ' input
RESPONSE=${input:='n'}
if [ "$RESPONSE" == 'y' ]; then
echo
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on remote..."
git push --delete origin $VERSION
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} locally..."
if ! git tag -fa $VERSION; then
echo "Existing/invalid tag"
exit -1
fi
echo -e "Deleting ${BCYAN}${LATEST_TAG}${RESET} tag on remote..."
git push --delete origin $LATEST_TAG
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${LATEST_TAG}${RESET} locally..."
git tag -fa $LATEST_TAG
echo -e "Pushing updated tags to remote..."
git push origin --tags
fi
exit 0

View File

@ -272,6 +272,8 @@ def invoke_api() -> None:
port=port,
loop="asyncio",
log_level=app_config.log_level,
ssl_certfile=app_config.ssl_certfile,
ssl_keyfile=app_config.ssl_keyfile,
)
server = uvicorn.Server(config)

View File

@ -39,6 +39,19 @@ class InvalidFieldError(TypeError):
pass
class Classification(str, Enum, metaclass=MetaEnum):
"""
The classification of an Invocation.
- `Stable`: The invocation, including its inputs/outputs and internal logic, is stable. You may build workflows with it, having confidence that they will not break because of a change in this invocation.
- `Beta`: The invocation is not yet stable, but is planned to be stable in the future. Workflows built around this invocation may break, but we are committed to supporting this invocation long-term.
- `Prototype`: The invocation is not yet stable and may be removed from the application at any time. Workflows built around this invocation may break, and we are *not* committed to supporting this invocation.
"""
Stable = "stable"
Beta = "beta"
Prototype = "prototype"
class Input(str, Enum, metaclass=MetaEnum):
"""
The type of input a field accepts.
@ -439,6 +452,7 @@ class UIConfigBase(BaseModel):
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
)
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
model_config = ConfigDict(
validate_assignment=True,
@ -607,6 +621,7 @@ class BaseInvocation(ABC, BaseModel):
schema["category"] = uiconfig.category
if uiconfig.node_pack is not None:
schema["node_pack"] = uiconfig.node_pack
schema["classification"] = uiconfig.classification
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
@ -782,6 +797,7 @@ def invocation(
category: Optional[str] = None,
version: Optional[str] = None,
use_cache: Optional[bool] = True,
classification: Classification = Classification.Stable,
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
"""
Registers an invocation.
@ -792,6 +808,7 @@ def invocation(
:param Optional[str] category: Adds a category to the invocation. Used to group the invocations in the UI. Defaults to None.
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
:param Classification classification: The classification of the invocation. Defaults to FeatureClassification.Stable. Use Beta or Prototype if the invocation is unstable.
"""
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
@ -812,6 +829,7 @@ def invocation(
cls.UIConfig.title = title
cls.UIConfig.tags = tags
cls.UIConfig.category = category
cls.UIConfig.classification = classification
# Grab the node pack's name from the module name, if it's a custom node
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"

View File

@ -1,3 +1,5 @@
from typing import Literal
import numpy as np
from PIL import Image
from pydantic import BaseModel
@ -5,6 +7,8 @@ from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
Input,
InputField,
InvocationContext,
OutputField,
@ -14,7 +18,13 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.tiles import (
calc_tiles_even_split,
calc_tiles_min_overlap,
calc_tiles_with_overlap,
merge_tiles_with_linear_blending,
merge_tiles_with_seam_blending,
)
from invokeai.backend.tiles.utils import Tile
@ -55,6 +65,79 @@ class CalculateImageTilesInvocation(BaseInvocation):
return CalculateImageTilesOutput(tiles=tiles)
@invocation(
"calculate_image_tiles_even_split",
title="Calculate Image Tiles Even Split",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
image_height: int = InputField(
ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
)
num_tiles_x: int = InputField(
default=2,
ge=1,
description="Number of tiles to divide image into on the x axis",
)
num_tiles_y: int = InputField(
default=2,
ge=1,
description="Number of tiles to divide image into on the y axis",
)
overlap_fraction: float = InputField(
default=0.25,
ge=0,
lt=1,
description="Overlap between adjacent tiles as a fraction of the tile's dimensions (0-1)",
)
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
tiles = calc_tiles_even_split(
image_height=self.image_height,
image_width=self.image_width,
num_tiles_x=self.num_tiles_x,
num_tiles_y=self.num_tiles_y,
overlap_fraction=self.overlap_fraction,
)
return CalculateImageTilesOutput(tiles=tiles)
@invocation(
"calculate_image_tiles_min_overlap",
title="Calculate Image Tiles Minimum Overlap",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
image_height: int = InputField(
ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
)
tile_width: int = InputField(ge=1, default=576, description="The tile width, in pixels.")
tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.")
min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.")
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
tiles = calc_tiles_min_overlap(
image_height=self.image_height,
image_width=self.image_width,
tile_height=self.tile_height,
tile_width=self.tile_width,
min_overlap=self.min_overlap,
)
return CalculateImageTilesOutput(tiles=tiles)
@invocation_output("tile_to_properties_output")
class TileToPropertiesOutput(BaseInvocationOutput):
coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.")
@ -76,7 +159,14 @@ class TileToPropertiesOutput(BaseInvocationOutput):
overlap_right: int = OutputField(description="Overlap between this tile and its right neighbor.")
@invocation("tile_to_properties", title="Tile to Properties", tags=["tiles"], category="tiles", version="1.0.0")
@invocation(
"tile_to_properties",
title="Tile to Properties",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
class TileToPropertiesInvocation(BaseInvocation):
"""Split a Tile into its individual properties."""
@ -102,7 +192,14 @@ class PairTileImageOutput(BaseInvocationOutput):
tile_with_image: TileWithImage = OutputField(description="A tile description with its corresponding image.")
@invocation("pair_tile_image", title="Pair Tile with Image", tags=["tiles"], category="tiles", version="1.0.0")
@invocation(
"pair_tile_image",
title="Pair Tile with Image",
tags=["tiles"],
category="tiles",
version="1.0.0",
classification=Classification.Beta,
)
class PairTileImageInvocation(BaseInvocation):
"""Pair an image with its tile properties."""
@ -121,13 +218,29 @@ class PairTileImageInvocation(BaseInvocation):
)
@invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.1.0")
BLEND_MODES = Literal["Linear", "Seam"]
@invocation(
"merge_tiles_to_image",
title="Merge Tiles to Image",
tags=["tiles"],
category="tiles",
version="1.1.0",
classification=Classification.Beta,
)
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
"""Merge multiple tile images into a single image."""
# Inputs
tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.")
blend_mode: BLEND_MODES = InputField(
default="Seam",
description="blending type Linear or Seam",
input=Input.Direct,
)
blend_amount: int = InputField(
default=32,
ge=0,
description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.",
)
@ -157,10 +270,18 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
channels = tile_np_images[0].shape[-1]
dtype = tile_np_images[0].dtype
np_image = np.zeros(shape=(height, width, channels), dtype=dtype)
if self.blend_mode == "Linear":
merge_tiles_with_linear_blending(
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
)
elif self.blend_mode == "Seam":
merge_tiles_with_seam_blending(
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
)
else:
raise ValueError(f"Unsupported blend mode: '{self.blend_mode}'.")
# Convert into a PIL image and save
pil_image = Image.fromarray(np_image)
image_dto = context.services.images.create(

View File

@ -221,6 +221,9 @@ class InvokeAIAppConfig(InvokeAISettings):
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
# SSL options correspond to https://www.uvicorn.org/settings/#https
ssl_certfile : Optional[Path] = Field(default=None, description="SSL certificate file (for HTTPS)", json_schema_extra=Categories.WebServer)
ssl_keyfile : Optional[Path] = Field(default=None, description="SSL key file", json_schema_extra=Categories.WebServer)
# FEATURES
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)

View File

@ -85,7 +85,7 @@ class ModelInstallService(ModelInstallServiceBase):
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus
def stop(self) -> None:
def stop(self, *args, **kwargs) -> None:
"""Stop the install thread; after this the object can be deleted and garbage collected."""
self._install_queue.put(STOP_JOB)

View File

@ -95,21 +95,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""--sql
INSERT INTO model_config (
id,
base,
type,
name,
path,
original_hash,
config
)
VALUES (?,?,?,?,?,?,?);
VALUES (?,?,?);
""",
(
key,
record.base,
record.type,
record.name,
record.path,
record.original_hash,
json_serialized,
),
@ -173,14 +165,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._cursor.execute(
"""--sql
UPDATE model_config
SET base=?,
type=?,
name=?,
path=?,
SET
config=?
WHERE id=?;
""",
(record.base, record.type, record.name, record.path, json_serialized, key),
(json_serialized, key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
@ -278,7 +267,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE model_path=?;
WHERE path=?;
""",
(str(path),),
)

View File

@ -22,6 +22,7 @@ def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
_drop_old_workflow_tables(cursor)
_add_workflow_library(cursor)
_drop_model_manager_metadata(cursor)
_recreate_model_config(cursor)
_migrate_embedded_workflows(cursor, logger, image_files)
@ -101,6 +102,41 @@ def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
def _recreate_model_config(cursor: sqlite3.Cursor) -> None:
"""
Drops the `model_config` table, recreating it.
In 3.4.0, this table used explicit columns but was changed to use json_extract 3.5.0.
Because this table is not used in production, we are able to simply drop it and recreate it.
"""
cursor.execute("DROP TABLE IF EXISTS model_config;")
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
)
def _migrate_embedded_workflows(
cursor: sqlite3.Cursor,
logger: Logger,

View File

@ -2,6 +2,7 @@
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
from hashlib import sha1
from logging import Logger
from omegaconf import DictConfig, OmegaConf
from pydantic import TypeAdapter
@ -10,6 +11,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordServiceSQL,
UnknownModelException,
)
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import (
@ -38,9 +40,9 @@ class MigrateModelYamlToDb:
"""
config: InvokeAIAppConfig
logger: InvokeAILogger
logger: Logger
def __init__(self):
def __init__(self) -> None:
self.config = InvokeAIAppConfig.get_config()
self.config.parse_args()
self.logger = InvokeAILogger.get_logger()
@ -54,9 +56,11 @@ class MigrateModelYamlToDb:
def get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
yaml_path = self.config.model_conf_path
return OmegaConf.load(yaml_path)
omegaconf = OmegaConf.load(yaml_path)
assert isinstance(omegaconf, DictConfig)
return omegaconf
def migrate(self):
def migrate(self) -> None:
"""Do the migration from models.yaml to invokeai.db."""
db = self.get_db()
yaml = self.get_yaml()
@ -70,6 +74,7 @@ class MigrateModelYamlToDb:
base_type, model_type, model_name = str(model_key).split("/")
hash = FastModelHash.hash(self.config.models_path / stanza.path)
assert isinstance(model_key, str)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type)
@ -78,12 +83,20 @@ class MigrateModelYamlToDb:
stanza["original_hash"] = hash
stanza["current_hash"] = hash
new_config = ModelsValidator.validate_python(stanza)
self.logger.info(f"Adding model {model_name} with key {model_key}")
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
try:
if original_record := db.search_by_path(stanza.path):
key = original_record[0].key
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
db.update_model(key, new_config)
else:
self.logger.info(f"Adding model {model_name} with key {model_key}")
db.add_model(new_key, new_config)
except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database")
except UnknownModelException:
self.logger.warning(f"Model at {stanza.path} could not be found in database")
def main():

View File

@ -3,7 +3,42 @@ from typing import Union
import numpy as np
from invokeai.backend.tiles.utils import TBLR, Tile, paste
from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR
from invokeai.backend.tiles.utils import TBLR, Tile, paste, seam_blend
def calc_overlap(tiles: list[Tile], num_tiles_x: int, num_tiles_y: int) -> list[Tile]:
"""Calculate and update the overlap of a list of tiles.
Args:
tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`.
num_tiles_x: the number of tiles on the x axis.
num_tiles_y: the number of tiles on the y axis.
"""
def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]:
if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x:
return None
return tiles[idx_y * num_tiles_x + idx_x]
for tile_idx_y in range(num_tiles_y):
for tile_idx_x in range(num_tiles_x):
cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x)
top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x)
left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1)
assert cur_tile is not None
# Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap.
if top_neighbor_tile is not None:
cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top)
top_neighbor_tile.overlap.bottom = cur_tile.overlap.top
# Update cur_tile left-overlap and corresponding left-neighbor right-overlap.
if left_neighbor_tile is not None:
cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left)
left_neighbor_tile.overlap.right = cur_tile.overlap.left
return tiles
def calc_tiles_with_overlap(
@ -63,31 +98,125 @@ def calc_tiles_with_overlap(
tiles.append(tile)
def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]:
if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x:
return None
return tiles[idx_y * num_tiles_x + idx_x]
return calc_overlap(tiles, num_tiles_x, num_tiles_y)
# Iterate over tiles again and calculate overlaps.
def calc_tiles_even_split(
image_height: int, image_width: int, num_tiles_x: int, num_tiles_y: int, overlap_fraction: float = 0
) -> list[Tile]:
"""Calculate the tile coordinates for a given image shape with the number of tiles requested.
Args:
image_height (int): The image height in px.
image_width (int): The image width in px.
num_x_tiles (int): The number of tile to split the image into on the X-axis.
num_y_tiles (int): The number of tile to split the image into on the Y-axis.
overlap_fraction (float, optional): The target overlap as fraction of the tiles size. Defaults to 0.
Returns:
list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom.
"""
# Ensure tile size is divisible by 8
if image_width % LATENT_SCALE_FACTOR != 0 or image_height % LATENT_SCALE_FACTOR != 0:
raise ValueError(f"image size (({image_width}, {image_height})) must be divisible by {LATENT_SCALE_FACTOR}")
# Calculate the overlap size based on the percentage and adjust it to be divisible by 8 (rounding up)
overlap_x = LATENT_SCALE_FACTOR * math.ceil(
int((image_width / num_tiles_x) * overlap_fraction) / LATENT_SCALE_FACTOR
)
overlap_y = LATENT_SCALE_FACTOR * math.ceil(
int((image_height / num_tiles_y) * overlap_fraction) / LATENT_SCALE_FACTOR
)
# Calculate the tile size based on the number of tiles and overlap, and ensure it's divisible by 8 (rounding down)
tile_size_x = LATENT_SCALE_FACTOR * math.floor(
((image_width + overlap_x * (num_tiles_x - 1)) // num_tiles_x) / LATENT_SCALE_FACTOR
)
tile_size_y = LATENT_SCALE_FACTOR * math.floor(
((image_height + overlap_y * (num_tiles_y - 1)) // num_tiles_y) / LATENT_SCALE_FACTOR
)
# tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
tiles: list[Tile] = []
# Calculate tile coordinates. (Ignore overlap values for now.)
for tile_idx_y in range(num_tiles_y):
# Calculate the top and bottom of the row
top = tile_idx_y * (tile_size_y - overlap_y)
bottom = min(top + tile_size_y, image_height)
# For the last row adjust bottom to be the height of the image
if tile_idx_y == num_tiles_y - 1:
bottom = image_height
for tile_idx_x in range(num_tiles_x):
cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x)
top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x)
left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1)
# Calculate the left & right coordinate of each tile
left = tile_idx_x * (tile_size_x - overlap_x)
right = min(left + tile_size_x, image_width)
# For the last tile in the row adjust right to be the width of the image
if tile_idx_x == num_tiles_x - 1:
right = image_width
assert cur_tile is not None
tile = Tile(
coords=TBLR(top=top, bottom=bottom, left=left, right=right),
overlap=TBLR(top=0, bottom=0, left=0, right=0),
)
# Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap.
if top_neighbor_tile is not None:
cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top)
top_neighbor_tile.overlap.bottom = cur_tile.overlap.top
tiles.append(tile)
# Update cur_tile left-overlap and corresponding left-neighbor right-overlap.
if left_neighbor_tile is not None:
cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left)
left_neighbor_tile.overlap.right = cur_tile.overlap.left
return calc_overlap(tiles, num_tiles_x, num_tiles_y)
return tiles
def calc_tiles_min_overlap(
image_height: int,
image_width: int,
tile_height: int,
tile_width: int,
min_overlap: int = 0,
) -> list[Tile]:
"""Calculate the tile coordinates for a given image shape under a simple tiling scheme with overlaps.
Args:
image_height (int): The image height in px.
image_width (int): The image width in px.
tile_height (int): The tile height in px. All tiles will have this height.
tile_width (int): The tile width in px. All tiles will have this width.
min_overlap (int): The target minimum overlap between adjacent tiles. If the tiles do not evenly cover the image
shape, then the overlap will be spread between the tiles.
Returns:
list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom.
"""
assert min_overlap < tile_height
assert min_overlap < tile_width
# The If Else catches the case when the tile size is larger than the images size and just clips the number of tiles to 1
num_tiles_x = math.ceil((image_width - min_overlap) / (tile_width - min_overlap)) if tile_width < image_width else 1
num_tiles_y = (
math.ceil((image_height - min_overlap) / (tile_height - min_overlap)) if tile_height < image_height else 1
)
# tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
tiles: list[Tile] = []
# Calculate tile coordinates. (Ignore overlap values for now.)
for tile_idx_y in range(num_tiles_y):
top = (tile_idx_y * (image_height - tile_height)) // (num_tiles_y - 1) if num_tiles_y > 1 else 0
bottom = top + tile_height
for tile_idx_x in range(num_tiles_x):
left = (tile_idx_x * (image_width - tile_width)) // (num_tiles_x - 1) if num_tiles_x > 1 else 0
right = left + tile_width
tile = Tile(
coords=TBLR(top=top, bottom=bottom, left=left, right=right),
overlap=TBLR(top=0, bottom=0, left=0, right=0),
)
tiles.append(tile)
return calc_overlap(tiles, num_tiles_x, num_tiles_y)
def merge_tiles_with_linear_blending(
@ -199,3 +328,91 @@ def merge_tiles_with_linear_blending(
),
mask=mask,
)
def merge_tiles_with_seam_blending(
dst_image: np.ndarray, tiles: list[Tile], tile_images: list[np.ndarray], blend_amount: int
):
"""Merge a set of image tiles into `dst_image` with seam blending between the tiles.
We expect every tile edge to either:
1) have an overlap of 0, because it is aligned with the image edge, or
2) have an overlap >= blend_amount.
If neither of these conditions are satisfied, we raise an exception.
The seam blending is centered on a seam of least energy of the overlap between adjacent tiles.
Args:
dst_image (np.ndarray): The destination image. Shape: (H, W, C).
tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`.
tile_images (list[np.ndarray]): The tile images to merge into `dst_image`.
blend_amount (int): The amount of blending (in px) between adjacent overlapping tiles.
"""
# Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to
# iterate over tiles left-to-right, top-to-bottom.
tiles_and_images = list(zip(tiles, tile_images, strict=True))
tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.left)
tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.top)
# Organize tiles into rows.
tile_and_image_rows: list[list[tuple[Tile, np.ndarray]]] = []
cur_tile_and_image_row: list[tuple[Tile, np.ndarray]] = []
first_tile_in_cur_row, _ = tiles_and_images[0]
for tile_and_image in tiles_and_images:
tile, _ = tile_and_image
if not (
tile.coords.top == first_tile_in_cur_row.coords.top
and tile.coords.bottom == first_tile_in_cur_row.coords.bottom
):
# Store the previous row, and start a new one.
tile_and_image_rows.append(cur_tile_and_image_row)
cur_tile_and_image_row = []
first_tile_in_cur_row, _ = tile_and_image
cur_tile_and_image_row.append(tile_and_image)
tile_and_image_rows.append(cur_tile_and_image_row)
for tile_and_image_row in tile_and_image_rows:
first_tile_in_row, _ = tile_and_image_row[0]
row_height = first_tile_in_row.coords.bottom - first_tile_in_row.coords.top
row_image = np.zeros((row_height, dst_image.shape[1], dst_image.shape[2]), dtype=dst_image.dtype)
# Blend the tiles in the row horizontally.
for tile, tile_image in tile_and_image_row:
# We expect the tiles to be ordered left-to-right.
# For each tile:
# - extract the overlap regions and pass to seam_blend()
# - apply blended region to the row_image
# - apply the un-blended region to the row_image
tile_height, tile_width, _ = tile_image.shape
overlap_size = tile.overlap.left
# Left blending:
if overlap_size > 0:
assert overlap_size >= blend_amount
overlap_coord_right = tile.coords.left + overlap_size
src_overlap = row_image[:, tile.coords.left : overlap_coord_right]
dst_overlap = tile_image[:, :overlap_size]
blended_overlap = seam_blend(src_overlap, dst_overlap, blend_amount, x_seam=False)
row_image[:, tile.coords.left : overlap_coord_right] = blended_overlap
row_image[:, overlap_coord_right : tile.coords.right] = tile_image[:, overlap_size:]
else:
# no overlap just paste the tile
row_image[:, tile.coords.left : tile.coords.right] = tile_image
# Blend the row into the dst_image
# We assume that the entire row has the same vertical overlaps as the first_tile_in_row.
# Rows are processed in the same way as tiles (extract overlap, blend, apply)
row_overlap_size = first_tile_in_row.overlap.top
if row_overlap_size > 0:
assert row_overlap_size >= blend_amount
overlap_coords_bottom = first_tile_in_row.coords.top + row_overlap_size
src_overlap = dst_image[first_tile_in_row.coords.top : overlap_coords_bottom, :]
dst_overlap = row_image[:row_overlap_size, :]
blended_overlap = seam_blend(src_overlap, dst_overlap, blend_amount, x_seam=True)
dst_image[first_tile_in_row.coords.top : overlap_coords_bottom, :] = blended_overlap
dst_image[overlap_coords_bottom : first_tile_in_row.coords.bottom, :] = row_image[row_overlap_size:, :]
else:
# no overlap just paste the row
dst_image[first_tile_in_row.coords.top : first_tile_in_row.coords.bottom, :] = row_image

View File

@ -1,5 +1,7 @@
import math
from typing import Optional
import cv2
import numpy as np
from pydantic import BaseModel, Field
@ -31,10 +33,10 @@ def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optiona
"""Paste a source image into a destination image.
Args:
dst_image (torch.Tensor): The destination image to paste into. Shape: (H, W, C).
src_image (torch.Tensor): The source image to paste. Shape: (H, W, C). H and W must be compatible with 'box'.
dst_image (np.array): The destination image to paste into. Shape: (H, W, C).
src_image (np.array): The source image to paste. Shape: (H, W, C). H and W must be compatible with 'box'.
box (TBLR): Box defining the region in the 'dst_image' where 'src_image' will be pasted.
mask (Optional[torch.Tensor]): A mask that defines the blending between 'src_image' and 'dst_image'.
mask (Optional[np.array]): A mask that defines the blending between 'src_image' and 'dst_image'.
Range: [0.0, 1.0], Shape: (H, W). The output is calculate per-pixel according to
`src * mask + dst * (1 - mask)`.
"""
@ -45,3 +47,106 @@ def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optiona
mask = np.expand_dims(mask, -1)
dst_image_box = dst_image[box.top : box.bottom, box.left : box.right]
dst_image[box.top : box.bottom, box.left : box.right] = src_image * mask + dst_image_box * (1.0 - mask)
def seam_blend(ia1: np.ndarray, ia2: np.ndarray, blend_amount: int, x_seam: bool) -> np.ndarray:
"""Blend two overlapping tile sections using a seams to find a path.
It is assumed that input images will be RGB np arrays and are the same size.
Args:
ia1 (np.array): Image array 1 Shape: (H, W, C).
ia2 (np.array): Image array 2 Shape: (H, W, C).
x_seam (bool): If the images should be blended on the x axis or not.
blend_amount (int): The size of the blur to use on the seam. Half of this value will be used to avoid the edges of the image.
"""
assert ia1.shape == ia2.shape
assert ia2.size == ia2.size
def shift(arr, num, fill_value=255.0):
result = np.full_like(arr, fill_value)
if num > 0:
result[num:] = arr[:-num]
elif num < 0:
result[:num] = arr[-num:]
else:
result[:] = arr
return result
# Assume RGB and convert to grey
# Could offer other options for the luminance conversion
# BT.709 [0.2126, 0.7152, 0.0722], BT.2020 [0.2627, 0.6780, 0.0593])
# it might not have a huge impact due to the blur that is applied over the seam
iag1 = np.dot(ia1, [0.2989, 0.5870, 0.1140]) # BT.601 perceived brightness
iag2 = np.dot(ia2, [0.2989, 0.5870, 0.1140])
# Calc Difference between the images
ia = iag2 - iag1
# If the seam is on the X-axis rotate the array so we can treat it like a vertical seam
if x_seam:
ia = np.rot90(ia, 1)
# Calc max and min X & Y limits
# gutter is used to avoid the blur hitting the edge of the image
gutter = math.ceil(blend_amount / 2) if blend_amount > 0 else 0
max_y, max_x = ia.shape
max_x -= gutter
min_x = gutter
# Calc the energy in the difference
# Could offer different energy calculations e.g. Sobel or Scharr
energy = np.abs(np.gradient(ia, axis=0)) + np.abs(np.gradient(ia, axis=1))
# Find the starting position of the seam
res = np.copy(energy)
for y in range(1, max_y):
row = res[y, :]
rowl = shift(row, -1)
rowr = shift(row, 1)
res[y, :] = res[y - 1, :] + np.min([row, rowl, rowr], axis=0)
# create an array max_y long
lowest_energy_line = np.empty([max_y], dtype="uint16")
lowest_energy_line[max_y - 1] = np.argmin(res[max_y - 1, min_x : max_x - 1])
# Calc the path of the seam
# could offer options for larger search than just 1 pixel by adjusting lpos and rpos
for ypos in range(max_y - 2, -1, -1):
lowest_pos = lowest_energy_line[ypos + 1]
lpos = lowest_pos - 1
rpos = lowest_pos + 1
lpos = np.clip(lpos, min_x, max_x - 1)
rpos = np.clip(rpos, min_x, max_x - 1)
lowest_energy_line[ypos] = np.argmin(energy[ypos, lpos : rpos + 1]) + lpos
# Draw the mask
mask = np.zeros_like(ia)
for ypos in range(0, max_y):
to_fill = lowest_energy_line[ypos]
mask[ypos, :to_fill] = 1
# If the seam is on the X-axis rotate the array back
if x_seam:
mask = np.rot90(mask, 3)
# blur the seam mask if required
if blend_amount > 0:
mask = cv2.blur(mask, (blend_amount, blend_amount))
# for visual debugging
# from PIL import Image
# m_image = Image.fromarray((mask * 255.0).astype("uint8"))
# copy ia2 over ia1 while applying the seam mask
mask = np.expand_dims(mask, -1)
blended_image = ia1 * mask + ia2 * (1.0 - mask)
# for visual debugging
# i1 = Image.fromarray(ia1.astype("uint8"))
# i2 = Image.fromarray(ia2.astype("uint8"))
# b_image = Image.fromarray(blended_image.astype("uint8"))
# print(f"{ia1.shape}, {ia2.shape}, {mask.shape}, {blended_image.shape}")
# print(f"{i1.size}, {i2.size}, {m_image.size}, {b_image.size}")
return blended_image

View File

@ -1032,7 +1032,9 @@
"workflowValidation": "Workflow Validation Error",
"workflowVersion": "Version",
"zoomInNodes": "Zoom In",
"zoomOutNodes": "Zoom Out"
"zoomOutNodes": "Zoom Out",
"betaDesc": "This invocation is in beta. Until it is stable, it may have breaking changes during app updates. We plan to support this invocation long-term.",
"prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time."
},
"parameters": {
"aspectRatio": "Aspect Ratio",

View File

@ -109,7 +109,8 @@
"somethingWentWrong": "出了点问题",
"copyError": "$t(gallery.copy) 错误",
"input": "输入",
"notInstalled": "非 $t(common.installed)"
"notInstalled": "非 $t(common.installed)",
"delete": "删除"
},
"gallery": {
"generations": "生成的图像",

View File

@ -3,6 +3,7 @@ import { useStore } from '@nanostores/react';
import { $customStarUI } from 'app/store/nanostores/customStarUI';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import IAIFillSkeleton from 'common/components/IAIFillSkeleton';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import {
@ -10,7 +11,9 @@ import {
ImageDraggableData,
TypesafeDraggableData,
} from 'features/dnd/types';
import { VirtuosoGalleryContext } from 'features/gallery/components/ImageGrid/types';
import { useMultiselect } from 'features/gallery/hooks/useMultiselect';
import { useScrollToVisible } from 'features/gallery/hooks/useScrollToVisible';
import { MouseEvent, memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
@ -20,15 +23,16 @@ import {
useStarImagesMutation,
useUnstarImagesMutation,
} from 'services/api/endpoints/images';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
interface HoverableImageProps {
imageName: string;
index: number;
virtuosoContext: VirtuosoGalleryContext;
}
const GalleryImage = (props: HoverableImageProps) => {
const dispatch = useAppDispatch();
const { imageName } = props;
const { imageName, virtuosoContext } = props;
const { currentData: imageDTO } = useGetImageDTOQuery(imageName);
const shift = useAppSelector((state) => state.hotkeys.shift);
const { t } = useTranslation();
@ -38,6 +42,13 @@ const GalleryImage = (props: HoverableImageProps) => {
const customStarUi = useStore($customStarUI);
const imageContainerRef = useScrollToVisible(
isSelected,
props.index,
selectionCount,
virtuosoContext
);
const handleDelete = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
e.stopPropagation();
@ -122,6 +133,7 @@ const GalleryImage = (props: HoverableImageProps) => {
data-testid={`image-${imageDTO.image_name}`}
>
<Flex
ref={imageContainerRef}
userSelect="none"
sx={{
position: 'relative',

View File

@ -1,7 +1,10 @@
import { Box, Flex } from '@chakra-ui/react';
import { EntityId } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { VirtuosoGalleryContext } from 'features/gallery/components/ImageGrid/types';
import { $useNextPrevImageState } from 'features/gallery/hooks/useNextPrevImage';
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
import { IMAGE_LIMIT } from 'features/gallery/store/types';
import {
@ -11,7 +14,12 @@ import {
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { FaExclamationCircle, FaImage } from 'react-icons/fa';
import { VirtuosoGrid } from 'react-virtuoso';
import {
ItemContent,
ListRange,
VirtuosoGrid,
VirtuosoGridHandle,
} from 'react-virtuoso';
import {
useLazyListImagesQuery,
useListImagesQuery,
@ -20,7 +28,6 @@ import { useBoardTotal } from 'services/api/hooks/useBoardTotal';
import GalleryImage from './GalleryImage';
import ImageGridItemContainer from './ImageGridItemContainer';
import ImageGridListContainer from './ImageGridListContainer';
import { EntityId } from '@reduxjs/toolkit';
const overlayScrollbarsConfig: UseOverlayScrollbarsParams = {
defer: true,
@ -48,6 +55,10 @@ const GalleryImageGrid = () => {
const { currentViewTotal } = useBoardTotal(selectedBoardId);
const queryArgs = useAppSelector(selectListImagesBaseQueryArgs);
const virtuosoRangeRef = useRef<ListRange | null>(null);
const virtuosoRef = useRef<VirtuosoGridHandle>(null);
const { currentData, isFetching, isSuccess, isError } =
useListImagesQuery(queryArgs);
@ -72,9 +83,23 @@ const GalleryImageGrid = () => {
});
}, [areMoreAvailable, listImages, queryArgs, currentData?.ids.length]);
const itemContentFunc = useCallback(
(index: number, imageName: EntityId) => (
<GalleryImage key={imageName} imageName={imageName as string} />
const virtuosoContext = useMemo<VirtuosoGalleryContext>(() => {
return {
virtuosoRef,
rootRef,
virtuosoRangeRef,
};
}, []);
const itemContentFunc: ItemContent<EntityId, VirtuosoGalleryContext> =
useCallback(
(index, imageName, virtuosoContext) => (
<GalleryImage
key={imageName}
index={index}
imageName={imageName as string}
virtuosoContext={virtuosoContext}
/>
),
[]
);
@ -93,6 +118,15 @@ const GalleryImageGrid = () => {
return () => osInstance()?.destroy();
}, [scroller, initialize, osInstance]);
const onRangeChanged = useCallback((range: ListRange) => {
virtuosoRangeRef.current = range;
}, []);
useEffect(() => {
$useNextPrevImageState.setKey('virtuosoRef', virtuosoRef);
$useNextPrevImageState.setKey('virtuosoRangeRef', virtuosoRangeRef);
}, []);
if (!currentData) {
return (
<Flex
@ -140,6 +174,10 @@ const GalleryImageGrid = () => {
}}
scrollerRef={setScroller}
itemContent={itemContentFunc}
ref={virtuosoRef}
rangeChanged={onRangeChanged}
context={virtuosoContext}
overscan={10}
/>
</Box>
<IAIButton

View File

@ -0,0 +1,8 @@
import { RefObject } from 'react';
import { ListRange, VirtuosoGridHandle } from 'react-virtuoso';
export type VirtuosoGalleryContext = {
virtuosoRef: RefObject<VirtuosoGridHandle>;
rootRef: RefObject<HTMLDivElement>;
virtuosoRangeRef: RefObject<ListRange>;
};

View File

@ -1,7 +1,7 @@
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetImageWorkflowQuery } from 'services/api/endpoints/images';
import { useDebouncedImageWorkflow } from 'services/api/hooks/useDebouncedImageWorkflow';
import { ImageDTO } from 'services/api/types';
import DataViewer from './DataViewer';
@ -11,7 +11,7 @@ type Props = {
const ImageMetadataWorkflowTabContent = ({ image }: Props) => {
const { t } = useTranslation();
const { currentData: workflow } = useGetImageWorkflowQuery(image.image_name);
const { workflow } = useDebouncedImageWorkflow(image);
if (!workflow) {
return <IAINoContentFallback label={t('nodes.noWorkflow')} />;

View File

@ -4,8 +4,11 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { IMAGE_LIMIT } from 'features/gallery/store/types';
import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAlign';
import { clamp } from 'lodash-es';
import { useCallback } from 'react';
import { map } from 'nanostores';
import { RefObject, useCallback } from 'react';
import { ListRange, VirtuosoGridHandle } from 'react-virtuoso';
import { boardsApi } from 'services/api/endpoints/boards';
import {
imagesApi,
@ -14,6 +17,16 @@ import {
import { ListImagesArgs } from 'services/api/types';
import { imagesAdapter } from 'services/api/util';
export type UseNextPrevImageState = {
virtuosoRef: RefObject<VirtuosoGridHandle> | undefined;
virtuosoRangeRef: RefObject<ListRange> | undefined;
};
export const $useNextPrevImageState = map<UseNextPrevImageState>({
virtuosoRef: undefined,
virtuosoRangeRef: undefined,
});
export const nextPrevImageButtonsSelector = createMemoizedSelector(
[stateSelector, selectListImagesBaseQueryArgs],
(state, baseQueryArgs) => {
@ -78,6 +91,8 @@ export const nextPrevImageButtonsSelector = createMemoizedSelector(
isFetching: status === 'pending',
nextImage,
prevImage,
nextImageIndex,
prevImageIndex,
queryArgs,
};
}
@ -88,7 +103,9 @@ export const useNextPrevImage = () => {
const {
nextImage,
nextImageIndex,
prevImage,
prevImageIndex,
areMoreImagesAvailable,
isFetching,
queryArgs,
@ -98,11 +115,43 @@ export const useNextPrevImage = () => {
const handlePrevImage = useCallback(() => {
prevImage && dispatch(imageSelected(prevImage));
}, [dispatch, prevImage]);
const range = $useNextPrevImageState.get().virtuosoRangeRef?.current;
const virtuoso = $useNextPrevImageState.get().virtuosoRef?.current;
if (!range || !virtuoso) {
return;
}
if (
prevImageIndex !== undefined &&
(prevImageIndex < range.startIndex || prevImageIndex > range.endIndex)
) {
virtuoso.scrollToIndex({
index: prevImageIndex,
behavior: 'smooth',
align: getScrollToIndexAlign(prevImageIndex, range),
});
}
}, [dispatch, prevImage, prevImageIndex]);
const handleNextImage = useCallback(() => {
nextImage && dispatch(imageSelected(nextImage));
}, [dispatch, nextImage]);
const range = $useNextPrevImageState.get().virtuosoRangeRef?.current;
const virtuoso = $useNextPrevImageState.get().virtuosoRef?.current;
if (!range || !virtuoso) {
return;
}
if (
nextImageIndex !== undefined &&
(nextImageIndex < range.startIndex || nextImageIndex > range.endIndex)
) {
virtuoso.scrollToIndex({
index: nextImageIndex,
behavior: 'smooth',
align: getScrollToIndexAlign(nextImageIndex, range),
});
}
}, [dispatch, nextImage, nextImageIndex]);
const [listImages] = useLazyListImagesQuery();

View File

@ -0,0 +1,46 @@
import { VirtuosoGalleryContext } from 'features/gallery/components/ImageGrid/types';
import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAlign';
import { useEffect, useRef } from 'react';
export const useScrollToVisible = (
isSelected: boolean,
index: number,
selectionCount: number,
virtuosoContext: VirtuosoGalleryContext
) => {
const imageContainerRef = useRef<HTMLDivElement>(null);
useEffect(() => {
if (
!isSelected ||
selectionCount !== 1 ||
!virtuosoContext.rootRef.current ||
!virtuosoContext.virtuosoRef.current ||
!virtuosoContext.virtuosoRangeRef.current ||
!imageContainerRef.current
) {
return;
}
const itemRect = imageContainerRef.current.getBoundingClientRect();
const rootRect = virtuosoContext.rootRef.current.getBoundingClientRect();
const itemIsVisible =
itemRect.top >= rootRect.top &&
itemRect.bottom <= rootRect.bottom &&
itemRect.left >= rootRect.left &&
itemRect.right <= rootRect.right;
if (!itemIsVisible) {
virtuosoContext.virtuosoRef.current.scrollToIndex({
index,
behavior: 'smooth',
align: getScrollToIndexAlign(
index,
virtuosoContext.virtuosoRangeRef.current
),
});
}
}, [isSelected, index, selectionCount, virtuosoContext]);
return imageContainerRef;
};

View File

@ -0,0 +1,17 @@
import { ListRange } from 'react-virtuoso';
/**
* Gets the alignment for react-virtuoso's scrollToIndex function.
* @param index The index of the item.
* @param range The range of items currently visible.
* @returns
*/
export const getScrollToIndexAlign = (
index: number,
range: ListRange
): 'start' | 'end' => {
if (index > (range.endIndex - range.startIndex) / 2 + range.startIndex) {
return 'end';
}
return 'start';
};

View File

@ -0,0 +1,68 @@
import { Icon, Tooltip } from '@chakra-ui/react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaFlask } from 'react-icons/fa';
import { useNodeClassification } from 'features/nodes/hooks/useNodeClassification';
import { Classification } from 'features/nodes/types/common';
import { FaHammer } from 'react-icons/fa6';
interface Props {
nodeId: string;
}
const InvocationNodeClassificationIcon = ({ nodeId }: Props) => {
const classification = useNodeClassification(nodeId);
if (!classification || classification === 'stable') {
return null;
}
return (
<Tooltip
label={<ClassificationTooltipContent classification={classification} />}
placement="top"
shouldWrapChildren
>
<Icon
as={getIcon(classification)}
sx={{
display: 'block',
boxSize: 4,
color: 'base.400',
}}
/>
</Tooltip>
);
};
export default memo(InvocationNodeClassificationIcon);
const ClassificationTooltipContent = memo(
({ classification }: { classification: Classification }) => {
const { t } = useTranslation();
if (classification === 'beta') {
return t('nodes.betaDesc');
}
if (classification === 'prototype') {
return t('nodes.prototypeDesc');
}
return null;
}
);
ClassificationTooltipContent.displayName = 'ClassificationTooltipContent';
const getIcon = (classification: Classification) => {
if (classification === 'beta') {
return FaHammer;
}
if (classification === 'prototype') {
return FaFlask;
}
return undefined;
};

View File

@ -5,6 +5,7 @@ import NodeTitle from 'features/nodes/components/flow/nodes/common/NodeTitle';
import InvocationNodeCollapsedHandles from './InvocationNodeCollapsedHandles';
import InvocationNodeInfoIcon from './InvocationNodeInfoIcon';
import InvocationNodeStatusIndicator from './InvocationNodeStatusIndicator';
import InvocationNodeClassificationIcon from 'features/nodes/components/flow/nodes/Invocation/InvocationNodeClassificationIcon';
type Props = {
nodeId: string;
@ -31,6 +32,7 @@ const InvocationNodeHeader = ({ nodeId, isOpen }: Props) => {
}}
>
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<InvocationNodeClassificationIcon nodeId={nodeId} />
<NodeTitle nodeId={nodeId} />
<Flex alignItems="center">
<InvocationNodeStatusIndicator nodeId={nodeId} />

View File

@ -0,0 +1,23 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { useMemo } from 'react';
export const useNodeClassification = (nodeId: string) => {
const selector = useMemo(
() =>
createMemoizedSelector(stateSelector, ({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
return nodeTemplate?.classification;
}),
[nodeId]
);
const title = useAppSelector(selector);
return title;
};

View File

@ -19,6 +19,9 @@ export const zColorField = z.object({
});
export type ColorField = z.infer<typeof zColorField>;
export const zClassification = z.enum(['stable', 'beta', 'prototype']);
export type Classification = z.infer<typeof zClassification>;
export const zSchedulerField = z.enum([
'euler',
'deis',

View File

@ -1,6 +1,6 @@
import { Edge, Node } from 'reactflow';
import { z } from 'zod';
import { zProgressImage } from './common';
import { zClassification, zProgressImage } from './common';
import {
zFieldInputInstance,
zFieldInputTemplate,
@ -21,6 +21,7 @@ export const zInvocationTemplate = z.object({
version: zSemVer,
useCache: z.boolean(),
nodePack: z.string().min(1).nullish(),
classification: zClassification,
});
export type InvocationTemplate = z.infer<typeof zInvocationTemplate>;
// #endregion

View File

@ -83,6 +83,7 @@ export const parseSchema = (
const description = schema.description ?? '';
const version = schema.version;
const nodePack = schema.node_pack;
const classification = schema.classification;
const inputs = reduce(
schema.properties,
@ -245,6 +246,7 @@ export const parseSchema = (
outputs,
useCache,
nodePack,
classification,
};
Object.assign(invocationsAccumulator, { [type]: invocation });

View File

@ -0,0 +1,22 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { useGetImageWorkflowQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
export const useDebouncedImageWorkflow = (imageDTO?: ImageDTO | null) => {
const workflowFetchDebounce = useAppSelector(
(state) => state.config.workflowFetchDebounce ?? 300
);
const [debouncedImageName] = useDebounce(
imageDTO?.has_workflow ? imageDTO.image_name : null,
workflowFetchDebounce
);
const { data: workflow, isLoading } = useGetImageWorkflowQuery(
debouncedImageName ?? skipToken
);
return { workflow, isLoading };
};

View File

@ -1,17 +1,14 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { useDebounce } from 'use-debounce';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { useAppSelector } from 'app/store/storeHooks';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { useDebounce } from 'use-debounce';
export const useDebouncedMetadata = (imageName?: string | null) => {
const metadataFetchDebounce = useAppSelector(
(state) => state.config.metadataFetchDebounce
(state) => state.config.metadataFetchDebounce ?? 300
);
const [debouncedImageName] = useDebounce(
imageName,
metadataFetchDebounce ?? 0
);
const [debouncedImageName] = useDebounce(imageName, metadataFetchDebounce);
const { data: metadata, isLoading } = useGetImageMetadataQuery(
debouncedImageName ?? skipToken

File diff suppressed because one or more lines are too long

View File

@ -1,7 +1,12 @@
import numpy as np
import pytest
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.tiles import (
calc_tiles_even_split,
calc_tiles_min_overlap,
calc_tiles_with_overlap,
merge_tiles_with_linear_blending,
)
from invokeai.backend.tiles.utils import TBLR, Tile
####################################
@ -14,7 +19,10 @@ def test_calc_tiles_with_overlap_single_tile():
tiles = calc_tiles_with_overlap(image_height=512, image_width=1024, tile_height=512, tile_width=1024, overlap=64)
expected_tiles = [
Tile(coords=TBLR(top=0, bottom=512, left=0, right=1024), overlap=TBLR(top=0, bottom=0, left=0, right=0))
Tile(
coords=TBLR(top=0, bottom=512, left=0, right=1024),
overlap=TBLR(top=0, bottom=0, left=0, right=0),
)
]
assert tiles == expected_tiles
@ -27,13 +35,31 @@ def test_calc_tiles_with_overlap_evenly_divisible():
expected_tiles = [
# Row 0
Tile(coords=TBLR(top=0, bottom=320, left=0, right=576), overlap=TBLR(top=0, bottom=64, left=0, right=64)),
Tile(coords=TBLR(top=0, bottom=320, left=512, right=1088), overlap=TBLR(top=0, bottom=64, left=64, right=64)),
Tile(coords=TBLR(top=0, bottom=320, left=1024, right=1600), overlap=TBLR(top=0, bottom=64, left=64, right=0)),
Tile(
coords=TBLR(top=0, bottom=320, left=0, right=576),
overlap=TBLR(top=0, bottom=64, left=0, right=64),
),
Tile(
coords=TBLR(top=0, bottom=320, left=512, right=1088),
overlap=TBLR(top=0, bottom=64, left=64, right=64),
),
Tile(
coords=TBLR(top=0, bottom=320, left=1024, right=1600),
overlap=TBLR(top=0, bottom=64, left=64, right=0),
),
# Row 1
Tile(coords=TBLR(top=256, bottom=576, left=0, right=576), overlap=TBLR(top=64, bottom=0, left=0, right=64)),
Tile(coords=TBLR(top=256, bottom=576, left=512, right=1088), overlap=TBLR(top=64, bottom=0, left=64, right=64)),
Tile(coords=TBLR(top=256, bottom=576, left=1024, right=1600), overlap=TBLR(top=64, bottom=0, left=64, right=0)),
Tile(
coords=TBLR(top=256, bottom=576, left=0, right=576),
overlap=TBLR(top=64, bottom=0, left=0, right=64),
),
Tile(
coords=TBLR(top=256, bottom=576, left=512, right=1088),
overlap=TBLR(top=64, bottom=0, left=64, right=64),
),
Tile(
coords=TBLR(top=256, bottom=576, left=1024, right=1600),
overlap=TBLR(top=64, bottom=0, left=64, right=0),
),
]
assert tiles == expected_tiles
@ -46,16 +72,30 @@ def test_calc_tiles_with_overlap_not_evenly_divisible():
expected_tiles = [
# Row 0
Tile(coords=TBLR(top=0, bottom=256, left=0, right=512), overlap=TBLR(top=0, bottom=112, left=0, right=64)),
Tile(coords=TBLR(top=0, bottom=256, left=448, right=960), overlap=TBLR(top=0, bottom=112, left=64, right=272)),
Tile(coords=TBLR(top=0, bottom=256, left=688, right=1200), overlap=TBLR(top=0, bottom=112, left=272, right=0)),
# Row 1
Tile(coords=TBLR(top=144, bottom=400, left=0, right=512), overlap=TBLR(top=112, bottom=0, left=0, right=64)),
Tile(
coords=TBLR(top=144, bottom=400, left=448, right=960), overlap=TBLR(top=112, bottom=0, left=64, right=272)
coords=TBLR(top=0, bottom=256, left=0, right=512),
overlap=TBLR(top=0, bottom=112, left=0, right=64),
),
Tile(
coords=TBLR(top=144, bottom=400, left=688, right=1200), overlap=TBLR(top=112, bottom=0, left=272, right=0)
coords=TBLR(top=0, bottom=256, left=448, right=960),
overlap=TBLR(top=0, bottom=112, left=64, right=272),
),
Tile(
coords=TBLR(top=0, bottom=256, left=688, right=1200),
overlap=TBLR(top=0, bottom=112, left=272, right=0),
),
# Row 1
Tile(
coords=TBLR(top=144, bottom=400, left=0, right=512),
overlap=TBLR(top=112, bottom=0, left=0, right=64),
),
Tile(
coords=TBLR(top=144, bottom=400, left=448, right=960),
overlap=TBLR(top=112, bottom=0, left=64, right=272),
),
Tile(
coords=TBLR(top=144, bottom=400, left=688, right=1200),
overlap=TBLR(top=112, bottom=0, left=272, right=0),
),
]
@ -75,7 +115,12 @@ def test_calc_tiles_with_overlap_not_evenly_divisible():
],
)
def test_calc_tiles_with_overlap_input_validation(
image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int, raises: bool
image_height: int,
image_width: int,
tile_height: int,
tile_width: int,
overlap: int,
raises: bool,
):
"""Test that calc_tiles_with_overlap() raises an exception if the inputs are invalid."""
if raises:
@ -85,6 +130,306 @@ def test_calc_tiles_with_overlap_input_validation(
calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap)
####################################
# Test calc_tiles_min_overlap(...)
####################################
def test_calc_tiles_min_overlap_single_tile():
"""Test calc_tiles_min_overlap() behavior when a single tile covers the image."""
tiles = calc_tiles_min_overlap(
image_height=512,
image_width=1024,
tile_height=512,
tile_width=1024,
min_overlap=64,
)
expected_tiles = [
Tile(
coords=TBLR(top=0, bottom=512, left=0, right=1024),
overlap=TBLR(top=0, bottom=0, left=0, right=0),
)
]
assert tiles == expected_tiles
def test_calc_tiles_min_overlap_evenly_divisible():
"""Test calc_tiles_min_overlap() behavior when the image is evenly covered by multiple tiles."""
# Parameters mimic roughly the same output as the original tile generations of the same test name
tiles = calc_tiles_min_overlap(
image_height=576,
image_width=1600,
tile_height=320,
tile_width=576,
min_overlap=64,
)
expected_tiles = [
# Row 0
Tile(
coords=TBLR(top=0, bottom=320, left=0, right=576),
overlap=TBLR(top=0, bottom=64, left=0, right=64),
),
Tile(
coords=TBLR(top=0, bottom=320, left=512, right=1088),
overlap=TBLR(top=0, bottom=64, left=64, right=64),
),
Tile(
coords=TBLR(top=0, bottom=320, left=1024, right=1600),
overlap=TBLR(top=0, bottom=64, left=64, right=0),
),
# Row 1
Tile(
coords=TBLR(top=256, bottom=576, left=0, right=576),
overlap=TBLR(top=64, bottom=0, left=0, right=64),
),
Tile(
coords=TBLR(top=256, bottom=576, left=512, right=1088),
overlap=TBLR(top=64, bottom=0, left=64, right=64),
),
Tile(
coords=TBLR(top=256, bottom=576, left=1024, right=1600),
overlap=TBLR(top=64, bottom=0, left=64, right=0),
),
]
assert tiles == expected_tiles
def test_calc_tiles_min_overlap_not_evenly_divisible():
"""Test calc_tiles_min_overlap() behavior when the image requires 'uneven' overlaps to achieve proper coverage."""
# Parameters mimic roughly the same output as the original tile generations of the same test name
tiles = calc_tiles_min_overlap(
image_height=400,
image_width=1200,
tile_height=256,
tile_width=512,
min_overlap=64,
)
expected_tiles = [
# Row 0
Tile(
coords=TBLR(top=0, bottom=256, left=0, right=512),
overlap=TBLR(top=0, bottom=112, left=0, right=168),
),
Tile(
coords=TBLR(top=0, bottom=256, left=344, right=856),
overlap=TBLR(top=0, bottom=112, left=168, right=168),
),
Tile(
coords=TBLR(top=0, bottom=256, left=688, right=1200),
overlap=TBLR(top=0, bottom=112, left=168, right=0),
),
# Row 1
Tile(
coords=TBLR(top=144, bottom=400, left=0, right=512),
overlap=TBLR(top=112, bottom=0, left=0, right=168),
),
Tile(
coords=TBLR(top=144, bottom=400, left=344, right=856),
overlap=TBLR(top=112, bottom=0, left=168, right=168),
),
Tile(
coords=TBLR(top=144, bottom=400, left=688, right=1200),
overlap=TBLR(top=112, bottom=0, left=168, right=0),
),
]
assert tiles == expected_tiles
@pytest.mark.parametrize(
[
"image_height",
"image_width",
"tile_height",
"tile_width",
"min_overlap",
"raises",
],
[
(128, 128, 128, 128, 127, False), # OK
(128, 128, 128, 128, 0, False), # OK
(128, 128, 64, 64, 0, False), # OK
(128, 128, 129, 128, 0, False), # tile_height exceeds image_height defaults to 1 tile.
(128, 128, 128, 129, 0, False), # tile_width exceeds image_width defaults to 1 tile.
(128, 128, 64, 128, 64, True), # overlap equals tile_height.
(128, 128, 128, 64, 64, True), # overlap equals tile_width.
],
)
def test_calc_tiles_min_overlap_input_validation(
image_height: int,
image_width: int,
tile_height: int,
tile_width: int,
min_overlap: int,
raises: bool,
):
"""Test that calc_tiles_min_overlap() raises an exception if the inputs are invalid."""
if raises:
with pytest.raises(AssertionError):
calc_tiles_min_overlap(image_height, image_width, tile_height, tile_width, min_overlap)
else:
calc_tiles_min_overlap(image_height, image_width, tile_height, tile_width, min_overlap)
####################################
# Test calc_tiles_even_split(...)
####################################
def test_calc_tiles_even_split_single_tile():
"""Test calc_tiles_even_split() behavior when a single tile covers the image."""
tiles = calc_tiles_even_split(
image_height=512, image_width=1024, num_tiles_x=1, num_tiles_y=1, overlap_fraction=0.25
)
expected_tiles = [
Tile(
coords=TBLR(top=0, bottom=512, left=0, right=1024),
overlap=TBLR(top=0, bottom=0, left=0, right=0),
)
]
assert tiles == expected_tiles
def test_calc_tiles_even_split_evenly_divisible():
"""Test calc_tiles_even_split() behavior when the image is evenly covered by multiple tiles."""
# Parameters mimic roughly the same output as the original tile generations of the same test name
tiles = calc_tiles_even_split(
image_height=576, image_width=1600, num_tiles_x=3, num_tiles_y=2, overlap_fraction=0.25
)
expected_tiles = [
# Row 0
Tile(
coords=TBLR(top=0, bottom=320, left=0, right=624),
overlap=TBLR(top=0, bottom=72, left=0, right=136),
),
Tile(
coords=TBLR(top=0, bottom=320, left=488, right=1112),
overlap=TBLR(top=0, bottom=72, left=136, right=136),
),
Tile(
coords=TBLR(top=0, bottom=320, left=976, right=1600),
overlap=TBLR(top=0, bottom=72, left=136, right=0),
),
# Row 1
Tile(
coords=TBLR(top=248, bottom=576, left=0, right=624),
overlap=TBLR(top=72, bottom=0, left=0, right=136),
),
Tile(
coords=TBLR(top=248, bottom=576, left=488, right=1112),
overlap=TBLR(top=72, bottom=0, left=136, right=136),
),
Tile(
coords=TBLR(top=248, bottom=576, left=976, right=1600),
overlap=TBLR(top=72, bottom=0, left=136, right=0),
),
]
assert tiles == expected_tiles
def test_calc_tiles_even_split_not_evenly_divisible():
"""Test calc_tiles_even_split() behavior when the image requires 'uneven' overlaps to achieve proper coverage."""
# Parameters mimic roughly the same output as the original tile generations of the same test name
tiles = calc_tiles_even_split(
image_height=400, image_width=1200, num_tiles_x=3, num_tiles_y=2, overlap_fraction=0.25
)
expected_tiles = [
# Row 0
Tile(
coords=TBLR(top=0, bottom=224, left=0, right=464),
overlap=TBLR(top=0, bottom=56, left=0, right=104),
),
Tile(
coords=TBLR(top=0, bottom=224, left=360, right=824),
overlap=TBLR(top=0, bottom=56, left=104, right=104),
),
Tile(
coords=TBLR(top=0, bottom=224, left=720, right=1200),
overlap=TBLR(top=0, bottom=56, left=104, right=0),
),
# Row 1
Tile(
coords=TBLR(top=168, bottom=400, left=0, right=464),
overlap=TBLR(top=56, bottom=0, left=0, right=104),
),
Tile(
coords=TBLR(top=168, bottom=400, left=360, right=824),
overlap=TBLR(top=56, bottom=0, left=104, right=104),
),
Tile(
coords=TBLR(top=168, bottom=400, left=720, right=1200),
overlap=TBLR(top=56, bottom=0, left=104, right=0),
),
]
assert tiles == expected_tiles
def test_calc_tiles_even_split_difficult_size():
"""Test calc_tiles_even_split() behavior when the image is a difficult size to spilt evenly and keep div8."""
# Parameters are a difficult size for other tile gen routines to calculate
tiles = calc_tiles_even_split(
image_height=1000, image_width=1000, num_tiles_x=2, num_tiles_y=2, overlap_fraction=0.25
)
expected_tiles = [
# Row 0
Tile(
coords=TBLR(top=0, bottom=560, left=0, right=560),
overlap=TBLR(top=0, bottom=128, left=0, right=128),
),
Tile(
coords=TBLR(top=0, bottom=560, left=432, right=1000),
overlap=TBLR(top=0, bottom=128, left=128, right=0),
),
# Row 1
Tile(
coords=TBLR(top=432, bottom=1000, left=0, right=560),
overlap=TBLR(top=128, bottom=0, left=0, right=128),
),
Tile(
coords=TBLR(top=432, bottom=1000, left=432, right=1000),
overlap=TBLR(top=128, bottom=0, left=128, right=0),
),
]
assert tiles == expected_tiles
@pytest.mark.parametrize(
["image_height", "image_width", "num_tiles_x", "num_tiles_y", "overlap_fraction", "raises"],
[
(128, 128, 1, 1, 0.25, False), # OK
(128, 128, 1, 1, 0, False), # OK
(128, 128, 2, 1, 0, False), # OK
(127, 127, 1, 1, 0, True), # image size must be dividable by 8
],
)
def test_calc_tiles_even_split_input_validation(
image_height: int,
image_width: int,
num_tiles_x: int,
num_tiles_y: int,
overlap_fraction: float,
raises: bool,
):
"""Test that calc_tiles_even_split() raises an exception if the inputs are invalid."""
if raises:
with pytest.raises(ValueError):
calc_tiles_even_split(image_height, image_width, num_tiles_x, num_tiles_y, overlap_fraction)
else:
calc_tiles_even_split(image_height, image_width, num_tiles_x, num_tiles_y, overlap_fraction)
#############################################
# Test merge_tiles_with_linear_blending(...)
#############################################
@ -95,8 +440,14 @@ def test_merge_tiles_with_linear_blending_horizontal(blend_amount: int):
"""Test merge_tiles_with_linear_blending(...) behavior when merging horizontally."""
# Initialize 2 tiles side-by-side.
tiles = [
Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=64)),
Tile(coords=TBLR(top=0, bottom=512, left=448, right=960), overlap=TBLR(top=0, bottom=0, left=64, right=0)),
Tile(
coords=TBLR(top=0, bottom=512, left=0, right=512),
overlap=TBLR(top=0, bottom=0, left=0, right=64),
),
Tile(
coords=TBLR(top=0, bottom=512, left=448, right=960),
overlap=TBLR(top=0, bottom=0, left=64, right=0),
),
]
dst_image = np.zeros((512, 960, 3), dtype=np.uint8)
@ -116,7 +467,10 @@ def test_merge_tiles_with_linear_blending_horizontal(blend_amount: int):
expected_output[:, 480 + (blend_amount // 2) :, :] = 128
merge_tiles_with_linear_blending(
dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=blend_amount
dst_image=dst_image,
tiles=tiles,
tile_images=tile_images,
blend_amount=blend_amount,
)
np.testing.assert_array_equal(dst_image, expected_output, strict=True)
@ -127,8 +481,14 @@ def test_merge_tiles_with_linear_blending_vertical(blend_amount: int):
"""Test merge_tiles_with_linear_blending(...) behavior when merging vertically."""
# Initialize 2 tiles stacked vertically.
tiles = [
Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=64, left=0, right=0)),
Tile(coords=TBLR(top=448, bottom=960, left=0, right=512), overlap=TBLR(top=64, bottom=0, left=0, right=0)),
Tile(
coords=TBLR(top=0, bottom=512, left=0, right=512),
overlap=TBLR(top=0, bottom=64, left=0, right=0),
),
Tile(
coords=TBLR(top=448, bottom=960, left=0, right=512),
overlap=TBLR(top=64, bottom=0, left=0, right=0),
),
]
dst_image = np.zeros((960, 512, 3), dtype=np.uint8)
@ -148,7 +508,10 @@ def test_merge_tiles_with_linear_blending_vertical(blend_amount: int):
expected_output[480 + (blend_amount // 2) :, :, :] = 128
merge_tiles_with_linear_blending(
dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=blend_amount
dst_image=dst_image,
tiles=tiles,
tile_images=tile_images,
blend_amount=blend_amount,
)
np.testing.assert_array_equal(dst_image, expected_output, strict=True)
@ -160,8 +523,14 @@ def test_merge_tiles_with_linear_blending_blend_amount_exceeds_vertical_overlap(
"""
# Initialize 2 tiles stacked vertically.
tiles = [
Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=64, left=0, right=0)),
Tile(coords=TBLR(top=448, bottom=960, left=0, right=512), overlap=TBLR(top=64, bottom=0, left=0, right=0)),
Tile(
coords=TBLR(top=0, bottom=512, left=0, right=512),
overlap=TBLR(top=0, bottom=64, left=0, right=0),
),
Tile(
coords=TBLR(top=448, bottom=960, left=0, right=512),
overlap=TBLR(top=64, bottom=0, left=0, right=0),
),
]
dst_image = np.zeros((960, 512, 3), dtype=np.uint8)
@ -180,8 +549,14 @@ def test_merge_tiles_with_linear_blending_blend_amount_exceeds_horizontal_overla
"""
# Initialize 2 tiles side-by-side.
tiles = [
Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=64)),
Tile(coords=TBLR(top=0, bottom=512, left=448, right=960), overlap=TBLR(top=0, bottom=0, left=64, right=0)),
Tile(
coords=TBLR(top=0, bottom=512, left=0, right=512),
overlap=TBLR(top=0, bottom=0, left=0, right=64),
),
Tile(
coords=TBLR(top=0, bottom=512, left=448, right=960),
overlap=TBLR(top=0, bottom=0, left=64, right=0),
),
]
dst_image = np.zeros((512, 960, 3), dtype=np.uint8)
@ -198,7 +573,12 @@ def test_merge_tiles_with_linear_blending_tiles_overflow_dst_image():
"""Test that merge_tiles_with_linear_blending(...) raises an exception if any of the tiles overflows the
dst_image.
"""
tiles = [Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=0))]
tiles = [
Tile(
coords=TBLR(top=0, bottom=512, left=0, right=512),
overlap=TBLR(top=0, bottom=0, left=0, right=0),
)
]
dst_image = np.zeros((256, 512, 3), dtype=np.uint8)
@ -213,7 +593,12 @@ def test_merge_tiles_with_linear_blending_mismatched_list_lengths():
"""Test that merge_tiles_with_linear_blending(...) raises an exception if the lengths of 'tiles' and 'tile_images'
do not match.
"""
tiles = [Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=0))]
tiles = [
Tile(
coords=TBLR(top=0, bottom=512, left=0, right=512),
overlap=TBLR(top=0, bottom=0, left=0, right=0),
)
]
dst_image = np.zeros((256, 512, 3), dtype=np.uint8)