mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge remote-tracking branch 'origin/main' into feat/db/migrations
This commit is contained in:
commit
2cdda1fda2
33
Makefile
33
Makefile
@ -1,6 +1,20 @@
|
|||||||
# simple Makefile with scripts that are otherwise hard to remember
|
# simple Makefile with scripts that are otherwise hard to remember
|
||||||
# to use, run from the repo root `make <command>`
|
# to use, run from the repo root `make <command>`
|
||||||
|
|
||||||
|
default: help
|
||||||
|
|
||||||
|
help:
|
||||||
|
@echo Developer commands:
|
||||||
|
@echo
|
||||||
|
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
|
||||||
|
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
|
||||||
|
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
|
||||||
|
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
|
||||||
|
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||||
|
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||||
|
@echo "installer-zip Build the installer .zip file for the current version"
|
||||||
|
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||||
|
|
||||||
# Runs ruff, fixing any safely-fixable errors and formatting
|
# Runs ruff, fixing any safely-fixable errors and formatting
|
||||||
ruff:
|
ruff:
|
||||||
ruff check . --fix
|
ruff check . --fix
|
||||||
@ -18,4 +32,21 @@ mypy:
|
|||||||
# Runs mypy, ignoring the config in pyproject.toml but still ignoring missing (untyped) imports
|
# Runs mypy, ignoring the config in pyproject.toml but still ignoring missing (untyped) imports
|
||||||
# (many files are ignored by the config, so this is useful for checking all files)
|
# (many files are ignored by the config, so this is useful for checking all files)
|
||||||
mypy-all:
|
mypy-all:
|
||||||
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
|
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
|
||||||
|
|
||||||
|
# Build the frontend
|
||||||
|
frontend-build:
|
||||||
|
cd invokeai/frontend/web && pnpm build
|
||||||
|
|
||||||
|
# Run the frontend in dev mode
|
||||||
|
frontend-dev:
|
||||||
|
cd invokeai/frontend/web && pnpm dev
|
||||||
|
|
||||||
|
# Installer zip file
|
||||||
|
installer-zip:
|
||||||
|
cd installer && ./create_installer.sh
|
||||||
|
|
||||||
|
# Tag the release
|
||||||
|
tag-release:
|
||||||
|
cd installer && ./tag_release.sh
|
||||||
|
|
||||||
|
@ -154,14 +154,16 @@ groups in `invokeia.yaml`:
|
|||||||
|
|
||||||
### Web Server
|
### Web Server
|
||||||
|
|
||||||
| Setting | Default Value | Description |
|
| Setting | Default Value | Description |
|
||||||
|----------|----------------|--------------|
|
|---------------------|---------------|----------------------------------------------------------------------------------------------------------------------------|
|
||||||
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
|
| `host` | `localhost` | Name or IP address of the network interface that the web server will listen on |
|
||||||
| `port` | `9090` | Network port number that the web server will listen on |
|
| `port` | `9090` | Network port number that the web server will listen on |
|
||||||
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
|
| `allow_origins` | `[]` | A list of host names or IP addresses that are allowed to connect to the InvokeAI API in the format `['host1','host2',...]` |
|
||||||
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
|
| `allow_credentials` | `true` | Require credentials for a foreign host to access the InvokeAI API (don't change this) |
|
||||||
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
|
| `allow_methods` | `*` | List of HTTP methods ("GET", "POST") that the web server is allowed to use when accessing the API |
|
||||||
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
|
| `allow_headers` | `*` | List of HTTP headers that the web server will accept when accessing the API |
|
||||||
|
| `ssl_certfile` | null | Path to an SSL certificate file, used to enable HTTPS. |
|
||||||
|
| `ssl_keyfile` | null | Path to an SSL keyfile, if the key is not included in the certificate file. |
|
||||||
|
|
||||||
The documentation for InvokeAI's API can be accessed by browsing to the following URL: [http://localhost:9090/docs].
|
The documentation for InvokeAI's API can be accessed by browsing to the following URL: [http://localhost:9090/docs].
|
||||||
|
|
||||||
|
@ -13,14 +13,6 @@ function is_bin_in_path {
|
|||||||
builtin type -P "$1" &>/dev/null
|
builtin type -P "$1" &>/dev/null
|
||||||
}
|
}
|
||||||
|
|
||||||
function does_tag_exist {
|
|
||||||
git rev-parse --quiet --verify "refs/tags/$1" >/dev/null
|
|
||||||
}
|
|
||||||
|
|
||||||
function git_show_ref {
|
|
||||||
git show-ref --dereference $1 --abbrev 7
|
|
||||||
}
|
|
||||||
|
|
||||||
function git_show {
|
function git_show {
|
||||||
git show -s --format='%h %s' $1
|
git show -s --format='%h %s' $1
|
||||||
}
|
}
|
||||||
@ -53,50 +45,11 @@ VERSION=$(
|
|||||||
)
|
)
|
||||||
PATCH=""
|
PATCH=""
|
||||||
VERSION="v${VERSION}${PATCH}"
|
VERSION="v${VERSION}${PATCH}"
|
||||||
LATEST_TAG="v3-latest"
|
|
||||||
|
|
||||||
echo "Building installer for version $VERSION..."
|
|
||||||
echo
|
|
||||||
|
|
||||||
if does_tag_exist $VERSION; then
|
|
||||||
echo -e "${BCYAN}${VERSION}${RESET} already exists:"
|
|
||||||
git_show_ref tags/$VERSION
|
|
||||||
echo
|
|
||||||
fi
|
|
||||||
if does_tag_exist $LATEST_TAG; then
|
|
||||||
echo -e "${BCYAN}${LATEST_TAG}${RESET} already exists:"
|
|
||||||
git_show_ref tags/$LATEST_TAG
|
|
||||||
echo
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo -e "${BGREEN}HEAD${RESET}:"
|
echo -e "${BGREEN}HEAD${RESET}:"
|
||||||
git_show
|
git_show
|
||||||
echo
|
echo
|
||||||
|
|
||||||
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} and ${BCYAN}${LATEST_TAG}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on remote${RESET}? "
|
|
||||||
read -e -p 'y/n [n]: ' input
|
|
||||||
RESPONSE=${input:='n'}
|
|
||||||
if [ "$RESPONSE" == 'y' ]; then
|
|
||||||
echo
|
|
||||||
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on remote..."
|
|
||||||
git push origin :refs/tags/$VERSION
|
|
||||||
|
|
||||||
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} locally..."
|
|
||||||
if ! git tag -fa $VERSION; then
|
|
||||||
echo "Existing/invalid tag"
|
|
||||||
exit -1
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo -e "Deleting ${BCYAN}${LATEST_TAG}${RESET} tag on remote..."
|
|
||||||
git push origin :refs/tags/$LATEST_TAG
|
|
||||||
|
|
||||||
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${LATEST_TAG}${RESET} locally..."
|
|
||||||
git tag -fa $LATEST_TAG
|
|
||||||
|
|
||||||
echo
|
|
||||||
echo -e "${BYELLOW}Remember to 'git push origin --tags'!${RESET}"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# ---------------------- FRONTEND ----------------------
|
# ---------------------- FRONTEND ----------------------
|
||||||
|
|
||||||
pushd ../invokeai/frontend/web >/dev/null
|
pushd ../invokeai/frontend/web >/dev/null
|
||||||
|
71
installer/tag_release.sh
Executable file
71
installer/tag_release.sh
Executable file
@ -0,0 +1,71 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
BCYAN="\e[1;36m"
|
||||||
|
BYELLOW="\e[1;33m"
|
||||||
|
BGREEN="\e[1;32m"
|
||||||
|
BRED="\e[1;31m"
|
||||||
|
RED="\e[31m"
|
||||||
|
RESET="\e[0m"
|
||||||
|
|
||||||
|
function does_tag_exist {
|
||||||
|
git rev-parse --quiet --verify "refs/tags/$1" >/dev/null
|
||||||
|
}
|
||||||
|
|
||||||
|
function git_show_ref {
|
||||||
|
git show-ref --dereference $1 --abbrev 7
|
||||||
|
}
|
||||||
|
|
||||||
|
function git_show {
|
||||||
|
git show -s --format='%h %s' $1
|
||||||
|
}
|
||||||
|
|
||||||
|
VERSION=$(
|
||||||
|
cd ..
|
||||||
|
python -c "from invokeai.version import __version__ as version; print(version)"
|
||||||
|
)
|
||||||
|
PATCH=""
|
||||||
|
MAJOR_VERSION=$(echo $VERSION | sed 's/\..*$//')
|
||||||
|
VERSION="v${VERSION}${PATCH}"
|
||||||
|
LATEST_TAG="v${MAJOR_VERSION}-latest"
|
||||||
|
|
||||||
|
if does_tag_exist $VERSION; then
|
||||||
|
echo -e "${BCYAN}${VERSION}${RESET} already exists:"
|
||||||
|
git_show_ref tags/$VERSION
|
||||||
|
echo
|
||||||
|
fi
|
||||||
|
if does_tag_exist $LATEST_TAG; then
|
||||||
|
echo -e "${BCYAN}${LATEST_TAG}${RESET} already exists:"
|
||||||
|
git_show_ref tags/$LATEST_TAG
|
||||||
|
echo
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo -e "${BGREEN}HEAD${RESET}:"
|
||||||
|
git_show
|
||||||
|
echo
|
||||||
|
|
||||||
|
echo -e -n "Create tags ${BCYAN}${VERSION}${RESET} and ${BCYAN}${LATEST_TAG}${RESET} @ ${BGREEN}HEAD${RESET}, ${RED}deleting existing tags on remote${RESET}? "
|
||||||
|
read -e -p 'y/n [n]: ' input
|
||||||
|
RESPONSE=${input:='n'}
|
||||||
|
if [ "$RESPONSE" == 'y' ]; then
|
||||||
|
echo
|
||||||
|
echo -e "Deleting ${BCYAN}${VERSION}${RESET} tag on remote..."
|
||||||
|
git push --delete origin $VERSION
|
||||||
|
|
||||||
|
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${VERSION}${RESET} locally..."
|
||||||
|
if ! git tag -fa $VERSION; then
|
||||||
|
echo "Existing/invalid tag"
|
||||||
|
exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo -e "Deleting ${BCYAN}${LATEST_TAG}${RESET} tag on remote..."
|
||||||
|
git push --delete origin $LATEST_TAG
|
||||||
|
|
||||||
|
echo -e "Tagging ${BGREEN}HEAD${RESET} with ${BCYAN}${LATEST_TAG}${RESET} locally..."
|
||||||
|
git tag -fa $LATEST_TAG
|
||||||
|
|
||||||
|
echo -e "Pushing updated tags to remote..."
|
||||||
|
git push origin --tags
|
||||||
|
fi
|
||||||
|
exit 0
|
@ -272,6 +272,8 @@ def invoke_api() -> None:
|
|||||||
port=port,
|
port=port,
|
||||||
loop="asyncio",
|
loop="asyncio",
|
||||||
log_level=app_config.log_level,
|
log_level=app_config.log_level,
|
||||||
|
ssl_certfile=app_config.ssl_certfile,
|
||||||
|
ssl_keyfile=app_config.ssl_keyfile,
|
||||||
)
|
)
|
||||||
server = uvicorn.Server(config)
|
server = uvicorn.Server(config)
|
||||||
|
|
||||||
|
@ -39,6 +39,19 @@ class InvalidFieldError(TypeError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Classification(str, Enum, metaclass=MetaEnum):
|
||||||
|
"""
|
||||||
|
The classification of an Invocation.
|
||||||
|
- `Stable`: The invocation, including its inputs/outputs and internal logic, is stable. You may build workflows with it, having confidence that they will not break because of a change in this invocation.
|
||||||
|
- `Beta`: The invocation is not yet stable, but is planned to be stable in the future. Workflows built around this invocation may break, but we are committed to supporting this invocation long-term.
|
||||||
|
- `Prototype`: The invocation is not yet stable and may be removed from the application at any time. Workflows built around this invocation may break, and we are *not* committed to supporting this invocation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
Stable = "stable"
|
||||||
|
Beta = "beta"
|
||||||
|
Prototype = "prototype"
|
||||||
|
|
||||||
|
|
||||||
class Input(str, Enum, metaclass=MetaEnum):
|
class Input(str, Enum, metaclass=MetaEnum):
|
||||||
"""
|
"""
|
||||||
The type of input a field accepts.
|
The type of input a field accepts.
|
||||||
@ -439,6 +452,7 @@ class UIConfigBase(BaseModel):
|
|||||||
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
|
description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
|
||||||
)
|
)
|
||||||
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
|
node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")
|
||||||
|
classification: Classification = Field(default=Classification.Stable, description="The node's classification")
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
validate_assignment=True,
|
validate_assignment=True,
|
||||||
@ -607,6 +621,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
schema["category"] = uiconfig.category
|
schema["category"] = uiconfig.category
|
||||||
if uiconfig.node_pack is not None:
|
if uiconfig.node_pack is not None:
|
||||||
schema["node_pack"] = uiconfig.node_pack
|
schema["node_pack"] = uiconfig.node_pack
|
||||||
|
schema["classification"] = uiconfig.classification
|
||||||
schema["version"] = uiconfig.version
|
schema["version"] = uiconfig.version
|
||||||
if "required" not in schema or not isinstance(schema["required"], list):
|
if "required" not in schema or not isinstance(schema["required"], list):
|
||||||
schema["required"] = []
|
schema["required"] = []
|
||||||
@ -782,6 +797,7 @@ def invocation(
|
|||||||
category: Optional[str] = None,
|
category: Optional[str] = None,
|
||||||
version: Optional[str] = None,
|
version: Optional[str] = None,
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
|
classification: Classification = Classification.Stable,
|
||||||
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
|
) -> Callable[[Type[TBaseInvocation]], Type[TBaseInvocation]]:
|
||||||
"""
|
"""
|
||||||
Registers an invocation.
|
Registers an invocation.
|
||||||
@ -792,6 +808,7 @@ def invocation(
|
|||||||
:param Optional[str] category: Adds a category to the invocation. Used to group the invocations in the UI. Defaults to None.
|
:param Optional[str] category: Adds a category to the invocation. Used to group the invocations in the UI. Defaults to None.
|
||||||
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
|
:param Optional[str] version: Adds a version to the invocation. Must be a valid semver string. Defaults to None.
|
||||||
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
|
:param Optional[bool] use_cache: Whether or not to use the invocation cache. Defaults to True. The user may override this in the workflow editor.
|
||||||
|
:param Classification classification: The classification of the invocation. Defaults to FeatureClassification.Stable. Use Beta or Prototype if the invocation is unstable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
|
def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
|
||||||
@ -812,6 +829,7 @@ def invocation(
|
|||||||
cls.UIConfig.title = title
|
cls.UIConfig.title = title
|
||||||
cls.UIConfig.tags = tags
|
cls.UIConfig.tags = tags
|
||||||
cls.UIConfig.category = category
|
cls.UIConfig.category = category
|
||||||
|
cls.UIConfig.classification = classification
|
||||||
|
|
||||||
# Grab the node pack's name from the module name, if it's a custom node
|
# Grab the node pack's name from the module name, if it's a custom node
|
||||||
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
|
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -5,6 +7,8 @@ from pydantic import BaseModel
|
|||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
|
Classification,
|
||||||
|
Input,
|
||||||
InputField,
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
OutputField,
|
||||||
@ -14,7 +18,13 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
|
from invokeai.backend.tiles.tiles import (
|
||||||
|
calc_tiles_even_split,
|
||||||
|
calc_tiles_min_overlap,
|
||||||
|
calc_tiles_with_overlap,
|
||||||
|
merge_tiles_with_linear_blending,
|
||||||
|
merge_tiles_with_seam_blending,
|
||||||
|
)
|
||||||
from invokeai.backend.tiles.utils import Tile
|
from invokeai.backend.tiles.utils import Tile
|
||||||
|
|
||||||
|
|
||||||
@ -55,6 +65,79 @@ class CalculateImageTilesInvocation(BaseInvocation):
|
|||||||
return CalculateImageTilesOutput(tiles=tiles)
|
return CalculateImageTilesOutput(tiles=tiles)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"calculate_image_tiles_even_split",
|
||||||
|
title="Calculate Image Tiles Even Split",
|
||||||
|
tags=["tiles"],
|
||||||
|
category="tiles",
|
||||||
|
version="1.0.0",
|
||||||
|
classification=Classification.Beta,
|
||||||
|
)
|
||||||
|
class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
|
||||||
|
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
|
||||||
|
|
||||||
|
image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
|
||||||
|
image_height: int = InputField(
|
||||||
|
ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
|
||||||
|
)
|
||||||
|
num_tiles_x: int = InputField(
|
||||||
|
default=2,
|
||||||
|
ge=1,
|
||||||
|
description="Number of tiles to divide image into on the x axis",
|
||||||
|
)
|
||||||
|
num_tiles_y: int = InputField(
|
||||||
|
default=2,
|
||||||
|
ge=1,
|
||||||
|
description="Number of tiles to divide image into on the y axis",
|
||||||
|
)
|
||||||
|
overlap_fraction: float = InputField(
|
||||||
|
default=0.25,
|
||||||
|
ge=0,
|
||||||
|
lt=1,
|
||||||
|
description="Overlap between adjacent tiles as a fraction of the tile's dimensions (0-1)",
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
|
||||||
|
tiles = calc_tiles_even_split(
|
||||||
|
image_height=self.image_height,
|
||||||
|
image_width=self.image_width,
|
||||||
|
num_tiles_x=self.num_tiles_x,
|
||||||
|
num_tiles_y=self.num_tiles_y,
|
||||||
|
overlap_fraction=self.overlap_fraction,
|
||||||
|
)
|
||||||
|
return CalculateImageTilesOutput(tiles=tiles)
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"calculate_image_tiles_min_overlap",
|
||||||
|
title="Calculate Image Tiles Minimum Overlap",
|
||||||
|
tags=["tiles"],
|
||||||
|
category="tiles",
|
||||||
|
version="1.0.0",
|
||||||
|
classification=Classification.Beta,
|
||||||
|
)
|
||||||
|
class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
|
||||||
|
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
|
||||||
|
|
||||||
|
image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
|
||||||
|
image_height: int = InputField(
|
||||||
|
ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
|
||||||
|
)
|
||||||
|
tile_width: int = InputField(ge=1, default=576, description="The tile width, in pixels.")
|
||||||
|
tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.")
|
||||||
|
min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
|
||||||
|
tiles = calc_tiles_min_overlap(
|
||||||
|
image_height=self.image_height,
|
||||||
|
image_width=self.image_width,
|
||||||
|
tile_height=self.tile_height,
|
||||||
|
tile_width=self.tile_width,
|
||||||
|
min_overlap=self.min_overlap,
|
||||||
|
)
|
||||||
|
return CalculateImageTilesOutput(tiles=tiles)
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("tile_to_properties_output")
|
@invocation_output("tile_to_properties_output")
|
||||||
class TileToPropertiesOutput(BaseInvocationOutput):
|
class TileToPropertiesOutput(BaseInvocationOutput):
|
||||||
coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.")
|
coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.")
|
||||||
@ -76,7 +159,14 @@ class TileToPropertiesOutput(BaseInvocationOutput):
|
|||||||
overlap_right: int = OutputField(description="Overlap between this tile and its right neighbor.")
|
overlap_right: int = OutputField(description="Overlap between this tile and its right neighbor.")
|
||||||
|
|
||||||
|
|
||||||
@invocation("tile_to_properties", title="Tile to Properties", tags=["tiles"], category="tiles", version="1.0.0")
|
@invocation(
|
||||||
|
"tile_to_properties",
|
||||||
|
title="Tile to Properties",
|
||||||
|
tags=["tiles"],
|
||||||
|
category="tiles",
|
||||||
|
version="1.0.0",
|
||||||
|
classification=Classification.Beta,
|
||||||
|
)
|
||||||
class TileToPropertiesInvocation(BaseInvocation):
|
class TileToPropertiesInvocation(BaseInvocation):
|
||||||
"""Split a Tile into its individual properties."""
|
"""Split a Tile into its individual properties."""
|
||||||
|
|
||||||
@ -102,7 +192,14 @@ class PairTileImageOutput(BaseInvocationOutput):
|
|||||||
tile_with_image: TileWithImage = OutputField(description="A tile description with its corresponding image.")
|
tile_with_image: TileWithImage = OutputField(description="A tile description with its corresponding image.")
|
||||||
|
|
||||||
|
|
||||||
@invocation("pair_tile_image", title="Pair Tile with Image", tags=["tiles"], category="tiles", version="1.0.0")
|
@invocation(
|
||||||
|
"pair_tile_image",
|
||||||
|
title="Pair Tile with Image",
|
||||||
|
tags=["tiles"],
|
||||||
|
category="tiles",
|
||||||
|
version="1.0.0",
|
||||||
|
classification=Classification.Beta,
|
||||||
|
)
|
||||||
class PairTileImageInvocation(BaseInvocation):
|
class PairTileImageInvocation(BaseInvocation):
|
||||||
"""Pair an image with its tile properties."""
|
"""Pair an image with its tile properties."""
|
||||||
|
|
||||||
@ -121,13 +218,29 @@ class PairTileImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.1.0")
|
BLEND_MODES = Literal["Linear", "Seam"]
|
||||||
|
|
||||||
|
|
||||||
|
@invocation(
|
||||||
|
"merge_tiles_to_image",
|
||||||
|
title="Merge Tiles to Image",
|
||||||
|
tags=["tiles"],
|
||||||
|
category="tiles",
|
||||||
|
version="1.1.0",
|
||||||
|
classification=Classification.Beta,
|
||||||
|
)
|
||||||
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
|
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
|
||||||
"""Merge multiple tile images into a single image."""
|
"""Merge multiple tile images into a single image."""
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.")
|
tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.")
|
||||||
|
blend_mode: BLEND_MODES = InputField(
|
||||||
|
default="Seam",
|
||||||
|
description="blending type Linear or Seam",
|
||||||
|
input=Input.Direct,
|
||||||
|
)
|
||||||
blend_amount: int = InputField(
|
blend_amount: int = InputField(
|
||||||
|
default=32,
|
||||||
ge=0,
|
ge=0,
|
||||||
description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.",
|
description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.",
|
||||||
)
|
)
|
||||||
@ -157,10 +270,18 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
|
|||||||
channels = tile_np_images[0].shape[-1]
|
channels = tile_np_images[0].shape[-1]
|
||||||
dtype = tile_np_images[0].dtype
|
dtype = tile_np_images[0].dtype
|
||||||
np_image = np.zeros(shape=(height, width, channels), dtype=dtype)
|
np_image = np.zeros(shape=(height, width, channels), dtype=dtype)
|
||||||
|
if self.blend_mode == "Linear":
|
||||||
|
merge_tiles_with_linear_blending(
|
||||||
|
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
|
||||||
|
)
|
||||||
|
elif self.blend_mode == "Seam":
|
||||||
|
merge_tiles_with_seam_blending(
|
||||||
|
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported blend mode: '{self.blend_mode}'.")
|
||||||
|
|
||||||
merge_tiles_with_linear_blending(
|
# Convert into a PIL image and save
|
||||||
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
|
|
||||||
)
|
|
||||||
pil_image = Image.fromarray(np_image)
|
pil_image = Image.fromarray(np_image)
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
|
@ -221,6 +221,9 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
|
allow_credentials : bool = Field(default=True, description="Allow CORS credentials", json_schema_extra=Categories.WebServer)
|
||||||
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
|
allow_methods : List[str] = Field(default=["*"], description="Methods allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||||
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
|
allow_headers : List[str] = Field(default=["*"], description="Headers allowed for CORS", json_schema_extra=Categories.WebServer)
|
||||||
|
# SSL options correspond to https://www.uvicorn.org/settings/#https
|
||||||
|
ssl_certfile : Optional[Path] = Field(default=None, description="SSL certificate file (for HTTPS)", json_schema_extra=Categories.WebServer)
|
||||||
|
ssl_keyfile : Optional[Path] = Field(default=None, description="SSL key file", json_schema_extra=Categories.WebServer)
|
||||||
|
|
||||||
# FEATURES
|
# FEATURES
|
||||||
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)
|
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", json_schema_extra=Categories.Features)
|
||||||
|
@ -85,7 +85,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
||||||
return self._event_bus
|
return self._event_bus
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self, *args, **kwargs) -> None:
|
||||||
"""Stop the install thread; after this the object can be deleted and garbage collected."""
|
"""Stop the install thread; after this the object can be deleted and garbage collected."""
|
||||||
self._install_queue.put(STOP_JOB)
|
self._install_queue.put(STOP_JOB)
|
||||||
|
|
||||||
|
@ -95,21 +95,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
"""--sql
|
"""--sql
|
||||||
INSERT INTO model_config (
|
INSERT INTO model_config (
|
||||||
id,
|
id,
|
||||||
base,
|
|
||||||
type,
|
|
||||||
name,
|
|
||||||
path,
|
|
||||||
original_hash,
|
original_hash,
|
||||||
config
|
config
|
||||||
)
|
)
|
||||||
VALUES (?,?,?,?,?,?,?);
|
VALUES (?,?,?);
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
key,
|
key,
|
||||||
record.base,
|
|
||||||
record.type,
|
|
||||||
record.name,
|
|
||||||
record.path,
|
|
||||||
record.original_hash,
|
record.original_hash,
|
||||||
json_serialized,
|
json_serialized,
|
||||||
),
|
),
|
||||||
@ -173,14 +165,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
UPDATE model_config
|
UPDATE model_config
|
||||||
SET base=?,
|
SET
|
||||||
type=?,
|
|
||||||
name=?,
|
|
||||||
path=?,
|
|
||||||
config=?
|
config=?
|
||||||
WHERE id=?;
|
WHERE id=?;
|
||||||
""",
|
""",
|
||||||
(record.base, record.type, record.name, record.path, json_serialized, key),
|
(json_serialized, key),
|
||||||
)
|
)
|
||||||
if self._cursor.rowcount == 0:
|
if self._cursor.rowcount == 0:
|
||||||
raise UnknownModelException("model not found")
|
raise UnknownModelException("model not found")
|
||||||
@ -278,7 +267,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
SELECT config FROM model_config
|
SELECT config FROM model_config
|
||||||
WHERE model_path=?;
|
WHERE path=?;
|
||||||
""",
|
""",
|
||||||
(str(path),),
|
(str(path),),
|
||||||
)
|
)
|
||||||
|
@ -22,6 +22,7 @@ def migrate_callback(cursor: sqlite3.Cursor, **kwargs) -> None:
|
|||||||
_drop_old_workflow_tables(cursor)
|
_drop_old_workflow_tables(cursor)
|
||||||
_add_workflow_library(cursor)
|
_add_workflow_library(cursor)
|
||||||
_drop_model_manager_metadata(cursor)
|
_drop_model_manager_metadata(cursor)
|
||||||
|
_recreate_model_config(cursor)
|
||||||
_migrate_embedded_workflows(cursor, logger, image_files)
|
_migrate_embedded_workflows(cursor, logger, image_files)
|
||||||
|
|
||||||
|
|
||||||
@ -101,6 +102,41 @@ def _drop_model_manager_metadata(cursor: sqlite3.Cursor) -> None:
|
|||||||
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
|
cursor.execute("DROP TABLE IF EXISTS model_manager_metadata;")
|
||||||
|
|
||||||
|
|
||||||
|
def _recreate_model_config(cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""
|
||||||
|
Drops the `model_config` table, recreating it.
|
||||||
|
|
||||||
|
In 3.4.0, this table used explicit columns but was changed to use json_extract 3.5.0.
|
||||||
|
|
||||||
|
Because this table is not used in production, we are able to simply drop it and recreate it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cursor.execute("DROP TABLE IF EXISTS model_config;")
|
||||||
|
|
||||||
|
cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS model_config (
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
-- The next 3 fields are enums in python, unrestricted string here
|
||||||
|
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
||||||
|
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
||||||
|
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
||||||
|
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
||||||
|
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
||||||
|
original_hash TEXT, -- could be null
|
||||||
|
-- Serialized JSON representation of the whole config object,
|
||||||
|
-- which will contain additional fields from subclasses
|
||||||
|
config TEXT NOT NULL,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- unique constraint on combo of name, base and type
|
||||||
|
UNIQUE(name, base, type)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _migrate_embedded_workflows(
|
def _migrate_embedded_workflows(
|
||||||
cursor: sqlite3.Cursor,
|
cursor: sqlite3.Cursor,
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
|
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
|
||||||
|
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
|
from logging import Logger
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
@ -10,6 +11,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
|||||||
from invokeai.app.services.model_records import (
|
from invokeai.app.services.model_records import (
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
ModelRecordServiceSQL,
|
ModelRecordServiceSQL,
|
||||||
|
UnknownModelException,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
@ -38,9 +40,9 @@ class MigrateModelYamlToDb:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
config: InvokeAIAppConfig
|
config: InvokeAIAppConfig
|
||||||
logger: InvokeAILogger
|
logger: Logger
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.config = InvokeAIAppConfig.get_config()
|
self.config = InvokeAIAppConfig.get_config()
|
||||||
self.config.parse_args()
|
self.config.parse_args()
|
||||||
self.logger = InvokeAILogger.get_logger()
|
self.logger = InvokeAILogger.get_logger()
|
||||||
@ -54,9 +56,11 @@ class MigrateModelYamlToDb:
|
|||||||
def get_yaml(self) -> DictConfig:
|
def get_yaml(self) -> DictConfig:
|
||||||
"""Fetch the models.yaml DictConfig for this installation."""
|
"""Fetch the models.yaml DictConfig for this installation."""
|
||||||
yaml_path = self.config.model_conf_path
|
yaml_path = self.config.model_conf_path
|
||||||
return OmegaConf.load(yaml_path)
|
omegaconf = OmegaConf.load(yaml_path)
|
||||||
|
assert isinstance(omegaconf, DictConfig)
|
||||||
|
return omegaconf
|
||||||
|
|
||||||
def migrate(self):
|
def migrate(self) -> None:
|
||||||
"""Do the migration from models.yaml to invokeai.db."""
|
"""Do the migration from models.yaml to invokeai.db."""
|
||||||
db = self.get_db()
|
db = self.get_db()
|
||||||
yaml = self.get_yaml()
|
yaml = self.get_yaml()
|
||||||
@ -70,6 +74,7 @@ class MigrateModelYamlToDb:
|
|||||||
|
|
||||||
base_type, model_type, model_name = str(model_key).split("/")
|
base_type, model_type, model_name = str(model_key).split("/")
|
||||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||||
|
assert isinstance(model_key, str)
|
||||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
stanza["base"] = BaseModelType(base_type)
|
stanza["base"] = BaseModelType(base_type)
|
||||||
@ -78,12 +83,20 @@ class MigrateModelYamlToDb:
|
|||||||
stanza["original_hash"] = hash
|
stanza["original_hash"] = hash
|
||||||
stanza["current_hash"] = hash
|
stanza["current_hash"] = hash
|
||||||
|
|
||||||
new_config = ModelsValidator.validate_python(stanza)
|
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
|
||||||
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
|
||||||
try:
|
try:
|
||||||
db.add_model(new_key, new_config)
|
if original_record := db.search_by_path(stanza.path):
|
||||||
|
key = original_record[0].key
|
||||||
|
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
|
||||||
|
db.update_model(key, new_config)
|
||||||
|
else:
|
||||||
|
self.logger.info(f"Adding model {model_name} with key {model_key}")
|
||||||
|
db.add_model(new_key, new_config)
|
||||||
except DuplicateModelException:
|
except DuplicateModelException:
|
||||||
self.logger.warning(f"Model {model_name} is already in the database")
|
self.logger.warning(f"Model {model_name} is already in the database")
|
||||||
|
except UnknownModelException:
|
||||||
|
self.logger.warning(f"Model at {stanza.path} could not be found in database")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -3,7 +3,42 @@ from typing import Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from invokeai.backend.tiles.utils import TBLR, Tile, paste
|
from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR
|
||||||
|
from invokeai.backend.tiles.utils import TBLR, Tile, paste, seam_blend
|
||||||
|
|
||||||
|
|
||||||
|
def calc_overlap(tiles: list[Tile], num_tiles_x: int, num_tiles_y: int) -> list[Tile]:
|
||||||
|
"""Calculate and update the overlap of a list of tiles.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`.
|
||||||
|
num_tiles_x: the number of tiles on the x axis.
|
||||||
|
num_tiles_y: the number of tiles on the y axis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]:
|
||||||
|
if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x:
|
||||||
|
return None
|
||||||
|
return tiles[idx_y * num_tiles_x + idx_x]
|
||||||
|
|
||||||
|
for tile_idx_y in range(num_tiles_y):
|
||||||
|
for tile_idx_x in range(num_tiles_x):
|
||||||
|
cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x)
|
||||||
|
top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x)
|
||||||
|
left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1)
|
||||||
|
|
||||||
|
assert cur_tile is not None
|
||||||
|
|
||||||
|
# Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap.
|
||||||
|
if top_neighbor_tile is not None:
|
||||||
|
cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top)
|
||||||
|
top_neighbor_tile.overlap.bottom = cur_tile.overlap.top
|
||||||
|
|
||||||
|
# Update cur_tile left-overlap and corresponding left-neighbor right-overlap.
|
||||||
|
if left_neighbor_tile is not None:
|
||||||
|
cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left)
|
||||||
|
left_neighbor_tile.overlap.right = cur_tile.overlap.left
|
||||||
|
return tiles
|
||||||
|
|
||||||
|
|
||||||
def calc_tiles_with_overlap(
|
def calc_tiles_with_overlap(
|
||||||
@ -63,31 +98,125 @@ def calc_tiles_with_overlap(
|
|||||||
|
|
||||||
tiles.append(tile)
|
tiles.append(tile)
|
||||||
|
|
||||||
def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]:
|
return calc_overlap(tiles, num_tiles_x, num_tiles_y)
|
||||||
if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x:
|
|
||||||
return None
|
|
||||||
return tiles[idx_y * num_tiles_x + idx_x]
|
|
||||||
|
|
||||||
# Iterate over tiles again and calculate overlaps.
|
|
||||||
|
def calc_tiles_even_split(
|
||||||
|
image_height: int, image_width: int, num_tiles_x: int, num_tiles_y: int, overlap_fraction: float = 0
|
||||||
|
) -> list[Tile]:
|
||||||
|
"""Calculate the tile coordinates for a given image shape with the number of tiles requested.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_height (int): The image height in px.
|
||||||
|
image_width (int): The image width in px.
|
||||||
|
num_x_tiles (int): The number of tile to split the image into on the X-axis.
|
||||||
|
num_y_tiles (int): The number of tile to split the image into on the Y-axis.
|
||||||
|
overlap_fraction (float, optional): The target overlap as fraction of the tiles size. Defaults to 0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Ensure tile size is divisible by 8
|
||||||
|
if image_width % LATENT_SCALE_FACTOR != 0 or image_height % LATENT_SCALE_FACTOR != 0:
|
||||||
|
raise ValueError(f"image size (({image_width}, {image_height})) must be divisible by {LATENT_SCALE_FACTOR}")
|
||||||
|
|
||||||
|
# Calculate the overlap size based on the percentage and adjust it to be divisible by 8 (rounding up)
|
||||||
|
overlap_x = LATENT_SCALE_FACTOR * math.ceil(
|
||||||
|
int((image_width / num_tiles_x) * overlap_fraction) / LATENT_SCALE_FACTOR
|
||||||
|
)
|
||||||
|
overlap_y = LATENT_SCALE_FACTOR * math.ceil(
|
||||||
|
int((image_height / num_tiles_y) * overlap_fraction) / LATENT_SCALE_FACTOR
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate the tile size based on the number of tiles and overlap, and ensure it's divisible by 8 (rounding down)
|
||||||
|
tile_size_x = LATENT_SCALE_FACTOR * math.floor(
|
||||||
|
((image_width + overlap_x * (num_tiles_x - 1)) // num_tiles_x) / LATENT_SCALE_FACTOR
|
||||||
|
)
|
||||||
|
tile_size_y = LATENT_SCALE_FACTOR * math.floor(
|
||||||
|
((image_height + overlap_y * (num_tiles_y - 1)) // num_tiles_y) / LATENT_SCALE_FACTOR
|
||||||
|
)
|
||||||
|
|
||||||
|
# tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
|
||||||
|
tiles: list[Tile] = []
|
||||||
|
|
||||||
|
# Calculate tile coordinates. (Ignore overlap values for now.)
|
||||||
for tile_idx_y in range(num_tiles_y):
|
for tile_idx_y in range(num_tiles_y):
|
||||||
|
# Calculate the top and bottom of the row
|
||||||
|
top = tile_idx_y * (tile_size_y - overlap_y)
|
||||||
|
bottom = min(top + tile_size_y, image_height)
|
||||||
|
# For the last row adjust bottom to be the height of the image
|
||||||
|
if tile_idx_y == num_tiles_y - 1:
|
||||||
|
bottom = image_height
|
||||||
|
|
||||||
for tile_idx_x in range(num_tiles_x):
|
for tile_idx_x in range(num_tiles_x):
|
||||||
cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x)
|
# Calculate the left & right coordinate of each tile
|
||||||
top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x)
|
left = tile_idx_x * (tile_size_x - overlap_x)
|
||||||
left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1)
|
right = min(left + tile_size_x, image_width)
|
||||||
|
# For the last tile in the row adjust right to be the width of the image
|
||||||
|
if tile_idx_x == num_tiles_x - 1:
|
||||||
|
right = image_width
|
||||||
|
|
||||||
assert cur_tile is not None
|
tile = Tile(
|
||||||
|
coords=TBLR(top=top, bottom=bottom, left=left, right=right),
|
||||||
|
overlap=TBLR(top=0, bottom=0, left=0, right=0),
|
||||||
|
)
|
||||||
|
|
||||||
# Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap.
|
tiles.append(tile)
|
||||||
if top_neighbor_tile is not None:
|
|
||||||
cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top)
|
|
||||||
top_neighbor_tile.overlap.bottom = cur_tile.overlap.top
|
|
||||||
|
|
||||||
# Update cur_tile left-overlap and corresponding left-neighbor right-overlap.
|
return calc_overlap(tiles, num_tiles_x, num_tiles_y)
|
||||||
if left_neighbor_tile is not None:
|
|
||||||
cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left)
|
|
||||||
left_neighbor_tile.overlap.right = cur_tile.overlap.left
|
|
||||||
|
|
||||||
return tiles
|
|
||||||
|
def calc_tiles_min_overlap(
|
||||||
|
image_height: int,
|
||||||
|
image_width: int,
|
||||||
|
tile_height: int,
|
||||||
|
tile_width: int,
|
||||||
|
min_overlap: int = 0,
|
||||||
|
) -> list[Tile]:
|
||||||
|
"""Calculate the tile coordinates for a given image shape under a simple tiling scheme with overlaps.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_height (int): The image height in px.
|
||||||
|
image_width (int): The image width in px.
|
||||||
|
tile_height (int): The tile height in px. All tiles will have this height.
|
||||||
|
tile_width (int): The tile width in px. All tiles will have this width.
|
||||||
|
min_overlap (int): The target minimum overlap between adjacent tiles. If the tiles do not evenly cover the image
|
||||||
|
shape, then the overlap will be spread between the tiles.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert min_overlap < tile_height
|
||||||
|
assert min_overlap < tile_width
|
||||||
|
|
||||||
|
# The If Else catches the case when the tile size is larger than the images size and just clips the number of tiles to 1
|
||||||
|
num_tiles_x = math.ceil((image_width - min_overlap) / (tile_width - min_overlap)) if tile_width < image_width else 1
|
||||||
|
num_tiles_y = (
|
||||||
|
math.ceil((image_height - min_overlap) / (tile_height - min_overlap)) if tile_height < image_height else 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
|
||||||
|
tiles: list[Tile] = []
|
||||||
|
|
||||||
|
# Calculate tile coordinates. (Ignore overlap values for now.)
|
||||||
|
for tile_idx_y in range(num_tiles_y):
|
||||||
|
top = (tile_idx_y * (image_height - tile_height)) // (num_tiles_y - 1) if num_tiles_y > 1 else 0
|
||||||
|
bottom = top + tile_height
|
||||||
|
|
||||||
|
for tile_idx_x in range(num_tiles_x):
|
||||||
|
left = (tile_idx_x * (image_width - tile_width)) // (num_tiles_x - 1) if num_tiles_x > 1 else 0
|
||||||
|
right = left + tile_width
|
||||||
|
|
||||||
|
tile = Tile(
|
||||||
|
coords=TBLR(top=top, bottom=bottom, left=left, right=right),
|
||||||
|
overlap=TBLR(top=0, bottom=0, left=0, right=0),
|
||||||
|
)
|
||||||
|
|
||||||
|
tiles.append(tile)
|
||||||
|
|
||||||
|
return calc_overlap(tiles, num_tiles_x, num_tiles_y)
|
||||||
|
|
||||||
|
|
||||||
def merge_tiles_with_linear_blending(
|
def merge_tiles_with_linear_blending(
|
||||||
@ -199,3 +328,91 @@ def merge_tiles_with_linear_blending(
|
|||||||
),
|
),
|
||||||
mask=mask,
|
mask=mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_tiles_with_seam_blending(
|
||||||
|
dst_image: np.ndarray, tiles: list[Tile], tile_images: list[np.ndarray], blend_amount: int
|
||||||
|
):
|
||||||
|
"""Merge a set of image tiles into `dst_image` with seam blending between the tiles.
|
||||||
|
|
||||||
|
We expect every tile edge to either:
|
||||||
|
1) have an overlap of 0, because it is aligned with the image edge, or
|
||||||
|
2) have an overlap >= blend_amount.
|
||||||
|
If neither of these conditions are satisfied, we raise an exception.
|
||||||
|
|
||||||
|
The seam blending is centered on a seam of least energy of the overlap between adjacent tiles.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dst_image (np.ndarray): The destination image. Shape: (H, W, C).
|
||||||
|
tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`.
|
||||||
|
tile_images (list[np.ndarray]): The tile images to merge into `dst_image`.
|
||||||
|
blend_amount (int): The amount of blending (in px) between adjacent overlapping tiles.
|
||||||
|
"""
|
||||||
|
# Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to
|
||||||
|
# iterate over tiles left-to-right, top-to-bottom.
|
||||||
|
tiles_and_images = list(zip(tiles, tile_images, strict=True))
|
||||||
|
tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.left)
|
||||||
|
tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.top)
|
||||||
|
|
||||||
|
# Organize tiles into rows.
|
||||||
|
tile_and_image_rows: list[list[tuple[Tile, np.ndarray]]] = []
|
||||||
|
cur_tile_and_image_row: list[tuple[Tile, np.ndarray]] = []
|
||||||
|
first_tile_in_cur_row, _ = tiles_and_images[0]
|
||||||
|
for tile_and_image in tiles_and_images:
|
||||||
|
tile, _ = tile_and_image
|
||||||
|
if not (
|
||||||
|
tile.coords.top == first_tile_in_cur_row.coords.top
|
||||||
|
and tile.coords.bottom == first_tile_in_cur_row.coords.bottom
|
||||||
|
):
|
||||||
|
# Store the previous row, and start a new one.
|
||||||
|
tile_and_image_rows.append(cur_tile_and_image_row)
|
||||||
|
cur_tile_and_image_row = []
|
||||||
|
first_tile_in_cur_row, _ = tile_and_image
|
||||||
|
|
||||||
|
cur_tile_and_image_row.append(tile_and_image)
|
||||||
|
tile_and_image_rows.append(cur_tile_and_image_row)
|
||||||
|
|
||||||
|
for tile_and_image_row in tile_and_image_rows:
|
||||||
|
first_tile_in_row, _ = tile_and_image_row[0]
|
||||||
|
row_height = first_tile_in_row.coords.bottom - first_tile_in_row.coords.top
|
||||||
|
row_image = np.zeros((row_height, dst_image.shape[1], dst_image.shape[2]), dtype=dst_image.dtype)
|
||||||
|
|
||||||
|
# Blend the tiles in the row horizontally.
|
||||||
|
for tile, tile_image in tile_and_image_row:
|
||||||
|
# We expect the tiles to be ordered left-to-right.
|
||||||
|
# For each tile:
|
||||||
|
# - extract the overlap regions and pass to seam_blend()
|
||||||
|
# - apply blended region to the row_image
|
||||||
|
# - apply the un-blended region to the row_image
|
||||||
|
tile_height, tile_width, _ = tile_image.shape
|
||||||
|
overlap_size = tile.overlap.left
|
||||||
|
# Left blending:
|
||||||
|
if overlap_size > 0:
|
||||||
|
assert overlap_size >= blend_amount
|
||||||
|
|
||||||
|
overlap_coord_right = tile.coords.left + overlap_size
|
||||||
|
src_overlap = row_image[:, tile.coords.left : overlap_coord_right]
|
||||||
|
dst_overlap = tile_image[:, :overlap_size]
|
||||||
|
blended_overlap = seam_blend(src_overlap, dst_overlap, blend_amount, x_seam=False)
|
||||||
|
row_image[:, tile.coords.left : overlap_coord_right] = blended_overlap
|
||||||
|
row_image[:, overlap_coord_right : tile.coords.right] = tile_image[:, overlap_size:]
|
||||||
|
else:
|
||||||
|
# no overlap just paste the tile
|
||||||
|
row_image[:, tile.coords.left : tile.coords.right] = tile_image
|
||||||
|
|
||||||
|
# Blend the row into the dst_image
|
||||||
|
# We assume that the entire row has the same vertical overlaps as the first_tile_in_row.
|
||||||
|
# Rows are processed in the same way as tiles (extract overlap, blend, apply)
|
||||||
|
row_overlap_size = first_tile_in_row.overlap.top
|
||||||
|
if row_overlap_size > 0:
|
||||||
|
assert row_overlap_size >= blend_amount
|
||||||
|
|
||||||
|
overlap_coords_bottom = first_tile_in_row.coords.top + row_overlap_size
|
||||||
|
src_overlap = dst_image[first_tile_in_row.coords.top : overlap_coords_bottom, :]
|
||||||
|
dst_overlap = row_image[:row_overlap_size, :]
|
||||||
|
blended_overlap = seam_blend(src_overlap, dst_overlap, blend_amount, x_seam=True)
|
||||||
|
dst_image[first_tile_in_row.coords.top : overlap_coords_bottom, :] = blended_overlap
|
||||||
|
dst_image[overlap_coords_bottom : first_tile_in_row.coords.bottom, :] = row_image[row_overlap_size:, :]
|
||||||
|
else:
|
||||||
|
# no overlap just paste the row
|
||||||
|
dst_image[first_tile_in_row.coords.top : first_tile_in_row.coords.bottom, :] = row_image
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
|
import math
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -31,10 +33,10 @@ def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optiona
|
|||||||
"""Paste a source image into a destination image.
|
"""Paste a source image into a destination image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dst_image (torch.Tensor): The destination image to paste into. Shape: (H, W, C).
|
dst_image (np.array): The destination image to paste into. Shape: (H, W, C).
|
||||||
src_image (torch.Tensor): The source image to paste. Shape: (H, W, C). H and W must be compatible with 'box'.
|
src_image (np.array): The source image to paste. Shape: (H, W, C). H and W must be compatible with 'box'.
|
||||||
box (TBLR): Box defining the region in the 'dst_image' where 'src_image' will be pasted.
|
box (TBLR): Box defining the region in the 'dst_image' where 'src_image' will be pasted.
|
||||||
mask (Optional[torch.Tensor]): A mask that defines the blending between 'src_image' and 'dst_image'.
|
mask (Optional[np.array]): A mask that defines the blending between 'src_image' and 'dst_image'.
|
||||||
Range: [0.0, 1.0], Shape: (H, W). The output is calculate per-pixel according to
|
Range: [0.0, 1.0], Shape: (H, W). The output is calculate per-pixel according to
|
||||||
`src * mask + dst * (1 - mask)`.
|
`src * mask + dst * (1 - mask)`.
|
||||||
"""
|
"""
|
||||||
@ -45,3 +47,106 @@ def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optiona
|
|||||||
mask = np.expand_dims(mask, -1)
|
mask = np.expand_dims(mask, -1)
|
||||||
dst_image_box = dst_image[box.top : box.bottom, box.left : box.right]
|
dst_image_box = dst_image[box.top : box.bottom, box.left : box.right]
|
||||||
dst_image[box.top : box.bottom, box.left : box.right] = src_image * mask + dst_image_box * (1.0 - mask)
|
dst_image[box.top : box.bottom, box.left : box.right] = src_image * mask + dst_image_box * (1.0 - mask)
|
||||||
|
|
||||||
|
|
||||||
|
def seam_blend(ia1: np.ndarray, ia2: np.ndarray, blend_amount: int, x_seam: bool) -> np.ndarray:
|
||||||
|
"""Blend two overlapping tile sections using a seams to find a path.
|
||||||
|
|
||||||
|
It is assumed that input images will be RGB np arrays and are the same size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ia1 (np.array): Image array 1 Shape: (H, W, C).
|
||||||
|
ia2 (np.array): Image array 2 Shape: (H, W, C).
|
||||||
|
x_seam (bool): If the images should be blended on the x axis or not.
|
||||||
|
blend_amount (int): The size of the blur to use on the seam. Half of this value will be used to avoid the edges of the image.
|
||||||
|
"""
|
||||||
|
assert ia1.shape == ia2.shape
|
||||||
|
assert ia2.size == ia2.size
|
||||||
|
|
||||||
|
def shift(arr, num, fill_value=255.0):
|
||||||
|
result = np.full_like(arr, fill_value)
|
||||||
|
if num > 0:
|
||||||
|
result[num:] = arr[:-num]
|
||||||
|
elif num < 0:
|
||||||
|
result[:num] = arr[-num:]
|
||||||
|
else:
|
||||||
|
result[:] = arr
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Assume RGB and convert to grey
|
||||||
|
# Could offer other options for the luminance conversion
|
||||||
|
# BT.709 [0.2126, 0.7152, 0.0722], BT.2020 [0.2627, 0.6780, 0.0593])
|
||||||
|
# it might not have a huge impact due to the blur that is applied over the seam
|
||||||
|
iag1 = np.dot(ia1, [0.2989, 0.5870, 0.1140]) # BT.601 perceived brightness
|
||||||
|
iag2 = np.dot(ia2, [0.2989, 0.5870, 0.1140])
|
||||||
|
|
||||||
|
# Calc Difference between the images
|
||||||
|
ia = iag2 - iag1
|
||||||
|
|
||||||
|
# If the seam is on the X-axis rotate the array so we can treat it like a vertical seam
|
||||||
|
if x_seam:
|
||||||
|
ia = np.rot90(ia, 1)
|
||||||
|
|
||||||
|
# Calc max and min X & Y limits
|
||||||
|
# gutter is used to avoid the blur hitting the edge of the image
|
||||||
|
gutter = math.ceil(blend_amount / 2) if blend_amount > 0 else 0
|
||||||
|
max_y, max_x = ia.shape
|
||||||
|
max_x -= gutter
|
||||||
|
min_x = gutter
|
||||||
|
|
||||||
|
# Calc the energy in the difference
|
||||||
|
# Could offer different energy calculations e.g. Sobel or Scharr
|
||||||
|
energy = np.abs(np.gradient(ia, axis=0)) + np.abs(np.gradient(ia, axis=1))
|
||||||
|
|
||||||
|
# Find the starting position of the seam
|
||||||
|
res = np.copy(energy)
|
||||||
|
for y in range(1, max_y):
|
||||||
|
row = res[y, :]
|
||||||
|
rowl = shift(row, -1)
|
||||||
|
rowr = shift(row, 1)
|
||||||
|
res[y, :] = res[y - 1, :] + np.min([row, rowl, rowr], axis=0)
|
||||||
|
|
||||||
|
# create an array max_y long
|
||||||
|
lowest_energy_line = np.empty([max_y], dtype="uint16")
|
||||||
|
lowest_energy_line[max_y - 1] = np.argmin(res[max_y - 1, min_x : max_x - 1])
|
||||||
|
|
||||||
|
# Calc the path of the seam
|
||||||
|
# could offer options for larger search than just 1 pixel by adjusting lpos and rpos
|
||||||
|
for ypos in range(max_y - 2, -1, -1):
|
||||||
|
lowest_pos = lowest_energy_line[ypos + 1]
|
||||||
|
lpos = lowest_pos - 1
|
||||||
|
rpos = lowest_pos + 1
|
||||||
|
lpos = np.clip(lpos, min_x, max_x - 1)
|
||||||
|
rpos = np.clip(rpos, min_x, max_x - 1)
|
||||||
|
lowest_energy_line[ypos] = np.argmin(energy[ypos, lpos : rpos + 1]) + lpos
|
||||||
|
|
||||||
|
# Draw the mask
|
||||||
|
mask = np.zeros_like(ia)
|
||||||
|
for ypos in range(0, max_y):
|
||||||
|
to_fill = lowest_energy_line[ypos]
|
||||||
|
mask[ypos, :to_fill] = 1
|
||||||
|
|
||||||
|
# If the seam is on the X-axis rotate the array back
|
||||||
|
if x_seam:
|
||||||
|
mask = np.rot90(mask, 3)
|
||||||
|
|
||||||
|
# blur the seam mask if required
|
||||||
|
if blend_amount > 0:
|
||||||
|
mask = cv2.blur(mask, (blend_amount, blend_amount))
|
||||||
|
|
||||||
|
# for visual debugging
|
||||||
|
# from PIL import Image
|
||||||
|
# m_image = Image.fromarray((mask * 255.0).astype("uint8"))
|
||||||
|
|
||||||
|
# copy ia2 over ia1 while applying the seam mask
|
||||||
|
mask = np.expand_dims(mask, -1)
|
||||||
|
blended_image = ia1 * mask + ia2 * (1.0 - mask)
|
||||||
|
|
||||||
|
# for visual debugging
|
||||||
|
# i1 = Image.fromarray(ia1.astype("uint8"))
|
||||||
|
# i2 = Image.fromarray(ia2.astype("uint8"))
|
||||||
|
# b_image = Image.fromarray(blended_image.astype("uint8"))
|
||||||
|
# print(f"{ia1.shape}, {ia2.shape}, {mask.shape}, {blended_image.shape}")
|
||||||
|
# print(f"{i1.size}, {i2.size}, {m_image.size}, {b_image.size}")
|
||||||
|
|
||||||
|
return blended_image
|
||||||
|
@ -1032,7 +1032,9 @@
|
|||||||
"workflowValidation": "Workflow Validation Error",
|
"workflowValidation": "Workflow Validation Error",
|
||||||
"workflowVersion": "Version",
|
"workflowVersion": "Version",
|
||||||
"zoomInNodes": "Zoom In",
|
"zoomInNodes": "Zoom In",
|
||||||
"zoomOutNodes": "Zoom Out"
|
"zoomOutNodes": "Zoom Out",
|
||||||
|
"betaDesc": "This invocation is in beta. Until it is stable, it may have breaking changes during app updates. We plan to support this invocation long-term.",
|
||||||
|
"prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time."
|
||||||
},
|
},
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"aspectRatio": "Aspect Ratio",
|
"aspectRatio": "Aspect Ratio",
|
||||||
|
@ -109,7 +109,8 @@
|
|||||||
"somethingWentWrong": "出了点问题",
|
"somethingWentWrong": "出了点问题",
|
||||||
"copyError": "$t(gallery.copy) 错误",
|
"copyError": "$t(gallery.copy) 错误",
|
||||||
"input": "输入",
|
"input": "输入",
|
||||||
"notInstalled": "非 $t(common.installed)"
|
"notInstalled": "非 $t(common.installed)",
|
||||||
|
"delete": "删除"
|
||||||
},
|
},
|
||||||
"gallery": {
|
"gallery": {
|
||||||
"generations": "生成的图像",
|
"generations": "生成的图像",
|
||||||
|
@ -3,6 +3,7 @@ import { useStore } from '@nanostores/react';
|
|||||||
import { $customStarUI } from 'app/store/nanostores/customStarUI';
|
import { $customStarUI } from 'app/store/nanostores/customStarUI';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
|
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
||||||
import IAIFillSkeleton from 'common/components/IAIFillSkeleton';
|
import IAIFillSkeleton from 'common/components/IAIFillSkeleton';
|
||||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||||
import {
|
import {
|
||||||
@ -10,7 +11,9 @@ import {
|
|||||||
ImageDraggableData,
|
ImageDraggableData,
|
||||||
TypesafeDraggableData,
|
TypesafeDraggableData,
|
||||||
} from 'features/dnd/types';
|
} from 'features/dnd/types';
|
||||||
|
import { VirtuosoGalleryContext } from 'features/gallery/components/ImageGrid/types';
|
||||||
import { useMultiselect } from 'features/gallery/hooks/useMultiselect';
|
import { useMultiselect } from 'features/gallery/hooks/useMultiselect';
|
||||||
|
import { useScrollToVisible } from 'features/gallery/hooks/useScrollToVisible';
|
||||||
import { MouseEvent, memo, useCallback, useMemo, useState } from 'react';
|
import { MouseEvent, memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaTrash } from 'react-icons/fa';
|
import { FaTrash } from 'react-icons/fa';
|
||||||
@ -20,15 +23,16 @@ import {
|
|||||||
useStarImagesMutation,
|
useStarImagesMutation,
|
||||||
useUnstarImagesMutation,
|
useUnstarImagesMutation,
|
||||||
} from 'services/api/endpoints/images';
|
} from 'services/api/endpoints/images';
|
||||||
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
|
||||||
|
|
||||||
interface HoverableImageProps {
|
interface HoverableImageProps {
|
||||||
imageName: string;
|
imageName: string;
|
||||||
|
index: number;
|
||||||
|
virtuosoContext: VirtuosoGalleryContext;
|
||||||
}
|
}
|
||||||
|
|
||||||
const GalleryImage = (props: HoverableImageProps) => {
|
const GalleryImage = (props: HoverableImageProps) => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { imageName } = props;
|
const { imageName, virtuosoContext } = props;
|
||||||
const { currentData: imageDTO } = useGetImageDTOQuery(imageName);
|
const { currentData: imageDTO } = useGetImageDTOQuery(imageName);
|
||||||
const shift = useAppSelector((state) => state.hotkeys.shift);
|
const shift = useAppSelector((state) => state.hotkeys.shift);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
@ -38,6 +42,13 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
|
|
||||||
const customStarUi = useStore($customStarUI);
|
const customStarUi = useStore($customStarUI);
|
||||||
|
|
||||||
|
const imageContainerRef = useScrollToVisible(
|
||||||
|
isSelected,
|
||||||
|
props.index,
|
||||||
|
selectionCount,
|
||||||
|
virtuosoContext
|
||||||
|
);
|
||||||
|
|
||||||
const handleDelete = useCallback(
|
const handleDelete = useCallback(
|
||||||
(e: MouseEvent<HTMLButtonElement>) => {
|
(e: MouseEvent<HTMLButtonElement>) => {
|
||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
@ -122,6 +133,7 @@ const GalleryImage = (props: HoverableImageProps) => {
|
|||||||
data-testid={`image-${imageDTO.image_name}`}
|
data-testid={`image-${imageDTO.image_name}`}
|
||||||
>
|
>
|
||||||
<Flex
|
<Flex
|
||||||
|
ref={imageContainerRef}
|
||||||
userSelect="none"
|
userSelect="none"
|
||||||
sx={{
|
sx={{
|
||||||
position: 'relative',
|
position: 'relative',
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
import { Box, Flex } from '@chakra-ui/react';
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
|
import { EntityId } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import { VirtuosoGalleryContext } from 'features/gallery/components/ImageGrid/types';
|
||||||
|
import { $useNextPrevImageState } from 'features/gallery/hooks/useNextPrevImage';
|
||||||
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||||
import { IMAGE_LIMIT } from 'features/gallery/store/types';
|
import { IMAGE_LIMIT } from 'features/gallery/store/types';
|
||||||
import {
|
import {
|
||||||
@ -11,7 +14,12 @@ import {
|
|||||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaExclamationCircle, FaImage } from 'react-icons/fa';
|
import { FaExclamationCircle, FaImage } from 'react-icons/fa';
|
||||||
import { VirtuosoGrid } from 'react-virtuoso';
|
import {
|
||||||
|
ItemContent,
|
||||||
|
ListRange,
|
||||||
|
VirtuosoGrid,
|
||||||
|
VirtuosoGridHandle,
|
||||||
|
} from 'react-virtuoso';
|
||||||
import {
|
import {
|
||||||
useLazyListImagesQuery,
|
useLazyListImagesQuery,
|
||||||
useListImagesQuery,
|
useListImagesQuery,
|
||||||
@ -20,7 +28,6 @@ import { useBoardTotal } from 'services/api/hooks/useBoardTotal';
|
|||||||
import GalleryImage from './GalleryImage';
|
import GalleryImage from './GalleryImage';
|
||||||
import ImageGridItemContainer from './ImageGridItemContainer';
|
import ImageGridItemContainer from './ImageGridItemContainer';
|
||||||
import ImageGridListContainer from './ImageGridListContainer';
|
import ImageGridListContainer from './ImageGridListContainer';
|
||||||
import { EntityId } from '@reduxjs/toolkit';
|
|
||||||
|
|
||||||
const overlayScrollbarsConfig: UseOverlayScrollbarsParams = {
|
const overlayScrollbarsConfig: UseOverlayScrollbarsParams = {
|
||||||
defer: true,
|
defer: true,
|
||||||
@ -48,6 +55,10 @@ const GalleryImageGrid = () => {
|
|||||||
const { currentViewTotal } = useBoardTotal(selectedBoardId);
|
const { currentViewTotal } = useBoardTotal(selectedBoardId);
|
||||||
const queryArgs = useAppSelector(selectListImagesBaseQueryArgs);
|
const queryArgs = useAppSelector(selectListImagesBaseQueryArgs);
|
||||||
|
|
||||||
|
const virtuosoRangeRef = useRef<ListRange | null>(null);
|
||||||
|
|
||||||
|
const virtuosoRef = useRef<VirtuosoGridHandle>(null);
|
||||||
|
|
||||||
const { currentData, isFetching, isSuccess, isError } =
|
const { currentData, isFetching, isSuccess, isError } =
|
||||||
useListImagesQuery(queryArgs);
|
useListImagesQuery(queryArgs);
|
||||||
|
|
||||||
@ -72,12 +83,26 @@ const GalleryImageGrid = () => {
|
|||||||
});
|
});
|
||||||
}, [areMoreAvailable, listImages, queryArgs, currentData?.ids.length]);
|
}, [areMoreAvailable, listImages, queryArgs, currentData?.ids.length]);
|
||||||
|
|
||||||
const itemContentFunc = useCallback(
|
const virtuosoContext = useMemo<VirtuosoGalleryContext>(() => {
|
||||||
(index: number, imageName: EntityId) => (
|
return {
|
||||||
<GalleryImage key={imageName} imageName={imageName as string} />
|
virtuosoRef,
|
||||||
),
|
rootRef,
|
||||||
[]
|
virtuosoRangeRef,
|
||||||
);
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const itemContentFunc: ItemContent<EntityId, VirtuosoGalleryContext> =
|
||||||
|
useCallback(
|
||||||
|
(index, imageName, virtuosoContext) => (
|
||||||
|
<GalleryImage
|
||||||
|
key={imageName}
|
||||||
|
index={index}
|
||||||
|
imageName={imageName as string}
|
||||||
|
virtuosoContext={virtuosoContext}
|
||||||
|
/>
|
||||||
|
),
|
||||||
|
[]
|
||||||
|
);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// Initialize the gallery's custom scrollbar
|
// Initialize the gallery's custom scrollbar
|
||||||
@ -93,6 +118,15 @@ const GalleryImageGrid = () => {
|
|||||||
return () => osInstance()?.destroy();
|
return () => osInstance()?.destroy();
|
||||||
}, [scroller, initialize, osInstance]);
|
}, [scroller, initialize, osInstance]);
|
||||||
|
|
||||||
|
const onRangeChanged = useCallback((range: ListRange) => {
|
||||||
|
virtuosoRangeRef.current = range;
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
$useNextPrevImageState.setKey('virtuosoRef', virtuosoRef);
|
||||||
|
$useNextPrevImageState.setKey('virtuosoRangeRef', virtuosoRangeRef);
|
||||||
|
}, []);
|
||||||
|
|
||||||
if (!currentData) {
|
if (!currentData) {
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
@ -140,6 +174,10 @@ const GalleryImageGrid = () => {
|
|||||||
}}
|
}}
|
||||||
scrollerRef={setScroller}
|
scrollerRef={setScroller}
|
||||||
itemContent={itemContentFunc}
|
itemContent={itemContentFunc}
|
||||||
|
ref={virtuosoRef}
|
||||||
|
rangeChanged={onRangeChanged}
|
||||||
|
context={virtuosoContext}
|
||||||
|
overscan={10}
|
||||||
/>
|
/>
|
||||||
</Box>
|
</Box>
|
||||||
<IAIButton
|
<IAIButton
|
||||||
|
@ -0,0 +1,8 @@
|
|||||||
|
import { RefObject } from 'react';
|
||||||
|
import { ListRange, VirtuosoGridHandle } from 'react-virtuoso';
|
||||||
|
|
||||||
|
export type VirtuosoGalleryContext = {
|
||||||
|
virtuosoRef: RefObject<VirtuosoGridHandle>;
|
||||||
|
rootRef: RefObject<HTMLDivElement>;
|
||||||
|
virtuosoRangeRef: RefObject<ListRange>;
|
||||||
|
};
|
@ -1,7 +1,7 @@
|
|||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetImageWorkflowQuery } from 'services/api/endpoints/images';
|
import { useDebouncedImageWorkflow } from 'services/api/hooks/useDebouncedImageWorkflow';
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import DataViewer from './DataViewer';
|
import DataViewer from './DataViewer';
|
||||||
|
|
||||||
@ -11,7 +11,7 @@ type Props = {
|
|||||||
|
|
||||||
const ImageMetadataWorkflowTabContent = ({ image }: Props) => {
|
const ImageMetadataWorkflowTabContent = ({ image }: Props) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { currentData: workflow } = useGetImageWorkflowQuery(image.image_name);
|
const { workflow } = useDebouncedImageWorkflow(image);
|
||||||
|
|
||||||
if (!workflow) {
|
if (!workflow) {
|
||||||
return <IAINoContentFallback label={t('nodes.noWorkflow')} />;
|
return <IAINoContentFallback label={t('nodes.noWorkflow')} />;
|
||||||
|
@ -4,8 +4,11 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
import { IMAGE_LIMIT } from 'features/gallery/store/types';
|
import { IMAGE_LIMIT } from 'features/gallery/store/types';
|
||||||
|
import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAlign';
|
||||||
import { clamp } from 'lodash-es';
|
import { clamp } from 'lodash-es';
|
||||||
import { useCallback } from 'react';
|
import { map } from 'nanostores';
|
||||||
|
import { RefObject, useCallback } from 'react';
|
||||||
|
import { ListRange, VirtuosoGridHandle } from 'react-virtuoso';
|
||||||
import { boardsApi } from 'services/api/endpoints/boards';
|
import { boardsApi } from 'services/api/endpoints/boards';
|
||||||
import {
|
import {
|
||||||
imagesApi,
|
imagesApi,
|
||||||
@ -14,6 +17,16 @@ import {
|
|||||||
import { ListImagesArgs } from 'services/api/types';
|
import { ListImagesArgs } from 'services/api/types';
|
||||||
import { imagesAdapter } from 'services/api/util';
|
import { imagesAdapter } from 'services/api/util';
|
||||||
|
|
||||||
|
export type UseNextPrevImageState = {
|
||||||
|
virtuosoRef: RefObject<VirtuosoGridHandle> | undefined;
|
||||||
|
virtuosoRangeRef: RefObject<ListRange> | undefined;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const $useNextPrevImageState = map<UseNextPrevImageState>({
|
||||||
|
virtuosoRef: undefined,
|
||||||
|
virtuosoRangeRef: undefined,
|
||||||
|
});
|
||||||
|
|
||||||
export const nextPrevImageButtonsSelector = createMemoizedSelector(
|
export const nextPrevImageButtonsSelector = createMemoizedSelector(
|
||||||
[stateSelector, selectListImagesBaseQueryArgs],
|
[stateSelector, selectListImagesBaseQueryArgs],
|
||||||
(state, baseQueryArgs) => {
|
(state, baseQueryArgs) => {
|
||||||
@ -78,6 +91,8 @@ export const nextPrevImageButtonsSelector = createMemoizedSelector(
|
|||||||
isFetching: status === 'pending',
|
isFetching: status === 'pending',
|
||||||
nextImage,
|
nextImage,
|
||||||
prevImage,
|
prevImage,
|
||||||
|
nextImageIndex,
|
||||||
|
prevImageIndex,
|
||||||
queryArgs,
|
queryArgs,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -88,7 +103,9 @@ export const useNextPrevImage = () => {
|
|||||||
|
|
||||||
const {
|
const {
|
||||||
nextImage,
|
nextImage,
|
||||||
|
nextImageIndex,
|
||||||
prevImage,
|
prevImage,
|
||||||
|
prevImageIndex,
|
||||||
areMoreImagesAvailable,
|
areMoreImagesAvailable,
|
||||||
isFetching,
|
isFetching,
|
||||||
queryArgs,
|
queryArgs,
|
||||||
@ -98,11 +115,43 @@ export const useNextPrevImage = () => {
|
|||||||
|
|
||||||
const handlePrevImage = useCallback(() => {
|
const handlePrevImage = useCallback(() => {
|
||||||
prevImage && dispatch(imageSelected(prevImage));
|
prevImage && dispatch(imageSelected(prevImage));
|
||||||
}, [dispatch, prevImage]);
|
const range = $useNextPrevImageState.get().virtuosoRangeRef?.current;
|
||||||
|
const virtuoso = $useNextPrevImageState.get().virtuosoRef?.current;
|
||||||
|
if (!range || !virtuoso) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
prevImageIndex !== undefined &&
|
||||||
|
(prevImageIndex < range.startIndex || prevImageIndex > range.endIndex)
|
||||||
|
) {
|
||||||
|
virtuoso.scrollToIndex({
|
||||||
|
index: prevImageIndex,
|
||||||
|
behavior: 'smooth',
|
||||||
|
align: getScrollToIndexAlign(prevImageIndex, range),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [dispatch, prevImage, prevImageIndex]);
|
||||||
|
|
||||||
const handleNextImage = useCallback(() => {
|
const handleNextImage = useCallback(() => {
|
||||||
nextImage && dispatch(imageSelected(nextImage));
|
nextImage && dispatch(imageSelected(nextImage));
|
||||||
}, [dispatch, nextImage]);
|
const range = $useNextPrevImageState.get().virtuosoRangeRef?.current;
|
||||||
|
const virtuoso = $useNextPrevImageState.get().virtuosoRef?.current;
|
||||||
|
if (!range || !virtuoso) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
nextImageIndex !== undefined &&
|
||||||
|
(nextImageIndex < range.startIndex || nextImageIndex > range.endIndex)
|
||||||
|
) {
|
||||||
|
virtuoso.scrollToIndex({
|
||||||
|
index: nextImageIndex,
|
||||||
|
behavior: 'smooth',
|
||||||
|
align: getScrollToIndexAlign(nextImageIndex, range),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [dispatch, nextImage, nextImageIndex]);
|
||||||
|
|
||||||
const [listImages] = useLazyListImagesQuery();
|
const [listImages] = useLazyListImagesQuery();
|
||||||
|
|
||||||
|
@ -0,0 +1,46 @@
|
|||||||
|
import { VirtuosoGalleryContext } from 'features/gallery/components/ImageGrid/types';
|
||||||
|
import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAlign';
|
||||||
|
import { useEffect, useRef } from 'react';
|
||||||
|
|
||||||
|
export const useScrollToVisible = (
|
||||||
|
isSelected: boolean,
|
||||||
|
index: number,
|
||||||
|
selectionCount: number,
|
||||||
|
virtuosoContext: VirtuosoGalleryContext
|
||||||
|
) => {
|
||||||
|
const imageContainerRef = useRef<HTMLDivElement>(null);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (
|
||||||
|
!isSelected ||
|
||||||
|
selectionCount !== 1 ||
|
||||||
|
!virtuosoContext.rootRef.current ||
|
||||||
|
!virtuosoContext.virtuosoRef.current ||
|
||||||
|
!virtuosoContext.virtuosoRangeRef.current ||
|
||||||
|
!imageContainerRef.current
|
||||||
|
) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const itemRect = imageContainerRef.current.getBoundingClientRect();
|
||||||
|
const rootRect = virtuosoContext.rootRef.current.getBoundingClientRect();
|
||||||
|
const itemIsVisible =
|
||||||
|
itemRect.top >= rootRect.top &&
|
||||||
|
itemRect.bottom <= rootRect.bottom &&
|
||||||
|
itemRect.left >= rootRect.left &&
|
||||||
|
itemRect.right <= rootRect.right;
|
||||||
|
|
||||||
|
if (!itemIsVisible) {
|
||||||
|
virtuosoContext.virtuosoRef.current.scrollToIndex({
|
||||||
|
index,
|
||||||
|
behavior: 'smooth',
|
||||||
|
align: getScrollToIndexAlign(
|
||||||
|
index,
|
||||||
|
virtuosoContext.virtuosoRangeRef.current
|
||||||
|
),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [isSelected, index, selectionCount, virtuosoContext]);
|
||||||
|
|
||||||
|
return imageContainerRef;
|
||||||
|
};
|
@ -0,0 +1,17 @@
|
|||||||
|
import { ListRange } from 'react-virtuoso';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the alignment for react-virtuoso's scrollToIndex function.
|
||||||
|
* @param index The index of the item.
|
||||||
|
* @param range The range of items currently visible.
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
export const getScrollToIndexAlign = (
|
||||||
|
index: number,
|
||||||
|
range: ListRange
|
||||||
|
): 'start' | 'end' => {
|
||||||
|
if (index > (range.endIndex - range.startIndex) / 2 + range.startIndex) {
|
||||||
|
return 'end';
|
||||||
|
}
|
||||||
|
return 'start';
|
||||||
|
};
|
@ -0,0 +1,68 @@
|
|||||||
|
import { Icon, Tooltip } from '@chakra-ui/react';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { FaFlask } from 'react-icons/fa';
|
||||||
|
import { useNodeClassification } from 'features/nodes/hooks/useNodeClassification';
|
||||||
|
import { Classification } from 'features/nodes/types/common';
|
||||||
|
import { FaHammer } from 'react-icons/fa6';
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
nodeId: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const InvocationNodeClassificationIcon = ({ nodeId }: Props) => {
|
||||||
|
const classification = useNodeClassification(nodeId);
|
||||||
|
|
||||||
|
if (!classification || classification === 'stable') {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Tooltip
|
||||||
|
label={<ClassificationTooltipContent classification={classification} />}
|
||||||
|
placement="top"
|
||||||
|
shouldWrapChildren
|
||||||
|
>
|
||||||
|
<Icon
|
||||||
|
as={getIcon(classification)}
|
||||||
|
sx={{
|
||||||
|
display: 'block',
|
||||||
|
boxSize: 4,
|
||||||
|
color: 'base.400',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(InvocationNodeClassificationIcon);
|
||||||
|
|
||||||
|
const ClassificationTooltipContent = memo(
|
||||||
|
({ classification }: { classification: Classification }) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
if (classification === 'beta') {
|
||||||
|
return t('nodes.betaDesc');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (classification === 'prototype') {
|
||||||
|
return t('nodes.prototypeDesc');
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
ClassificationTooltipContent.displayName = 'ClassificationTooltipContent';
|
||||||
|
|
||||||
|
const getIcon = (classification: Classification) => {
|
||||||
|
if (classification === 'beta') {
|
||||||
|
return FaHammer;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (classification === 'prototype') {
|
||||||
|
return FaFlask;
|
||||||
|
}
|
||||||
|
|
||||||
|
return undefined;
|
||||||
|
};
|
@ -5,6 +5,7 @@ import NodeTitle from 'features/nodes/components/flow/nodes/common/NodeTitle';
|
|||||||
import InvocationNodeCollapsedHandles from './InvocationNodeCollapsedHandles';
|
import InvocationNodeCollapsedHandles from './InvocationNodeCollapsedHandles';
|
||||||
import InvocationNodeInfoIcon from './InvocationNodeInfoIcon';
|
import InvocationNodeInfoIcon from './InvocationNodeInfoIcon';
|
||||||
import InvocationNodeStatusIndicator from './InvocationNodeStatusIndicator';
|
import InvocationNodeStatusIndicator from './InvocationNodeStatusIndicator';
|
||||||
|
import InvocationNodeClassificationIcon from 'features/nodes/components/flow/nodes/Invocation/InvocationNodeClassificationIcon';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -31,6 +32,7 @@ const InvocationNodeHeader = ({ nodeId, isOpen }: Props) => {
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
|
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
|
||||||
|
<InvocationNodeClassificationIcon nodeId={nodeId} />
|
||||||
<NodeTitle nodeId={nodeId} />
|
<NodeTitle nodeId={nodeId} />
|
||||||
<Flex alignItems="center">
|
<Flex alignItems="center">
|
||||||
<InvocationNodeStatusIndicator nodeId={nodeId} />
|
<InvocationNodeStatusIndicator nodeId={nodeId} />
|
||||||
|
@ -0,0 +1,23 @@
|
|||||||
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
|
||||||
|
export const useNodeClassification = (nodeId: string) => {
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createMemoizedSelector(stateSelector, ({ nodes }) => {
|
||||||
|
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||||
|
if (!isInvocationNode(node)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
|
||||||
|
return nodeTemplate?.classification;
|
||||||
|
}),
|
||||||
|
[nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
const title = useAppSelector(selector);
|
||||||
|
return title;
|
||||||
|
};
|
@ -19,6 +19,9 @@ export const zColorField = z.object({
|
|||||||
});
|
});
|
||||||
export type ColorField = z.infer<typeof zColorField>;
|
export type ColorField = z.infer<typeof zColorField>;
|
||||||
|
|
||||||
|
export const zClassification = z.enum(['stable', 'beta', 'prototype']);
|
||||||
|
export type Classification = z.infer<typeof zClassification>;
|
||||||
|
|
||||||
export const zSchedulerField = z.enum([
|
export const zSchedulerField = z.enum([
|
||||||
'euler',
|
'euler',
|
||||||
'deis',
|
'deis',
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import { Edge, Node } from 'reactflow';
|
import { Edge, Node } from 'reactflow';
|
||||||
import { z } from 'zod';
|
import { z } from 'zod';
|
||||||
import { zProgressImage } from './common';
|
import { zClassification, zProgressImage } from './common';
|
||||||
import {
|
import {
|
||||||
zFieldInputInstance,
|
zFieldInputInstance,
|
||||||
zFieldInputTemplate,
|
zFieldInputTemplate,
|
||||||
@ -21,6 +21,7 @@ export const zInvocationTemplate = z.object({
|
|||||||
version: zSemVer,
|
version: zSemVer,
|
||||||
useCache: z.boolean(),
|
useCache: z.boolean(),
|
||||||
nodePack: z.string().min(1).nullish(),
|
nodePack: z.string().min(1).nullish(),
|
||||||
|
classification: zClassification,
|
||||||
});
|
});
|
||||||
export type InvocationTemplate = z.infer<typeof zInvocationTemplate>;
|
export type InvocationTemplate = z.infer<typeof zInvocationTemplate>;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
@ -83,6 +83,7 @@ export const parseSchema = (
|
|||||||
const description = schema.description ?? '';
|
const description = schema.description ?? '';
|
||||||
const version = schema.version;
|
const version = schema.version;
|
||||||
const nodePack = schema.node_pack;
|
const nodePack = schema.node_pack;
|
||||||
|
const classification = schema.classification;
|
||||||
|
|
||||||
const inputs = reduce(
|
const inputs = reduce(
|
||||||
schema.properties,
|
schema.properties,
|
||||||
@ -245,6 +246,7 @@ export const parseSchema = (
|
|||||||
outputs,
|
outputs,
|
||||||
useCache,
|
useCache,
|
||||||
nodePack,
|
nodePack,
|
||||||
|
classification,
|
||||||
};
|
};
|
||||||
|
|
||||||
Object.assign(invocationsAccumulator, { [type]: invocation });
|
Object.assign(invocationsAccumulator, { [type]: invocation });
|
||||||
|
@ -0,0 +1,22 @@
|
|||||||
|
import { skipToken } from '@reduxjs/toolkit/query';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { useGetImageWorkflowQuery } from 'services/api/endpoints/images';
|
||||||
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
import { useDebounce } from 'use-debounce';
|
||||||
|
|
||||||
|
export const useDebouncedImageWorkflow = (imageDTO?: ImageDTO | null) => {
|
||||||
|
const workflowFetchDebounce = useAppSelector(
|
||||||
|
(state) => state.config.workflowFetchDebounce ?? 300
|
||||||
|
);
|
||||||
|
|
||||||
|
const [debouncedImageName] = useDebounce(
|
||||||
|
imageDTO?.has_workflow ? imageDTO.image_name : null,
|
||||||
|
workflowFetchDebounce
|
||||||
|
);
|
||||||
|
|
||||||
|
const { data: workflow, isLoading } = useGetImageWorkflowQuery(
|
||||||
|
debouncedImageName ?? skipToken
|
||||||
|
);
|
||||||
|
|
||||||
|
return { workflow, isLoading };
|
||||||
|
};
|
@ -1,17 +1,14 @@
|
|||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
import { skipToken } from '@reduxjs/toolkit/query';
|
||||||
import { useDebounce } from 'use-debounce';
|
|
||||||
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
|
||||||
|
import { useDebounce } from 'use-debounce';
|
||||||
|
|
||||||
export const useDebouncedMetadata = (imageName?: string | null) => {
|
export const useDebouncedMetadata = (imageName?: string | null) => {
|
||||||
const metadataFetchDebounce = useAppSelector(
|
const metadataFetchDebounce = useAppSelector(
|
||||||
(state) => state.config.metadataFetchDebounce
|
(state) => state.config.metadataFetchDebounce ?? 300
|
||||||
);
|
);
|
||||||
|
|
||||||
const [debouncedImageName] = useDebounce(
|
const [debouncedImageName] = useDebounce(imageName, metadataFetchDebounce);
|
||||||
imageName,
|
|
||||||
metadataFetchDebounce ?? 0
|
|
||||||
);
|
|
||||||
|
|
||||||
const { data: metadata, isLoading } = useGetImageMetadataQuery(
|
const { data: metadata, isLoading } = useGetImageMetadataQuery(
|
||||||
debouncedImageName ?? skipToken
|
debouncedImageName ?? skipToken
|
||||||
|
569
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
569
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
File diff suppressed because one or more lines are too long
@ -1,7 +1,12 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
|
from invokeai.backend.tiles.tiles import (
|
||||||
|
calc_tiles_even_split,
|
||||||
|
calc_tiles_min_overlap,
|
||||||
|
calc_tiles_with_overlap,
|
||||||
|
merge_tiles_with_linear_blending,
|
||||||
|
)
|
||||||
from invokeai.backend.tiles.utils import TBLR, Tile
|
from invokeai.backend.tiles.utils import TBLR, Tile
|
||||||
|
|
||||||
####################################
|
####################################
|
||||||
@ -14,7 +19,10 @@ def test_calc_tiles_with_overlap_single_tile():
|
|||||||
tiles = calc_tiles_with_overlap(image_height=512, image_width=1024, tile_height=512, tile_width=1024, overlap=64)
|
tiles = calc_tiles_with_overlap(image_height=512, image_width=1024, tile_height=512, tile_width=1024, overlap=64)
|
||||||
|
|
||||||
expected_tiles = [
|
expected_tiles = [
|
||||||
Tile(coords=TBLR(top=0, bottom=512, left=0, right=1024), overlap=TBLR(top=0, bottom=0, left=0, right=0))
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=512, left=0, right=1024),
|
||||||
|
overlap=TBLR(top=0, bottom=0, left=0, right=0),
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
assert tiles == expected_tiles
|
assert tiles == expected_tiles
|
||||||
@ -27,13 +35,31 @@ def test_calc_tiles_with_overlap_evenly_divisible():
|
|||||||
|
|
||||||
expected_tiles = [
|
expected_tiles = [
|
||||||
# Row 0
|
# Row 0
|
||||||
Tile(coords=TBLR(top=0, bottom=320, left=0, right=576), overlap=TBLR(top=0, bottom=64, left=0, right=64)),
|
Tile(
|
||||||
Tile(coords=TBLR(top=0, bottom=320, left=512, right=1088), overlap=TBLR(top=0, bottom=64, left=64, right=64)),
|
coords=TBLR(top=0, bottom=320, left=0, right=576),
|
||||||
Tile(coords=TBLR(top=0, bottom=320, left=1024, right=1600), overlap=TBLR(top=0, bottom=64, left=64, right=0)),
|
overlap=TBLR(top=0, bottom=64, left=0, right=64),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=320, left=512, right=1088),
|
||||||
|
overlap=TBLR(top=0, bottom=64, left=64, right=64),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=320, left=1024, right=1600),
|
||||||
|
overlap=TBLR(top=0, bottom=64, left=64, right=0),
|
||||||
|
),
|
||||||
# Row 1
|
# Row 1
|
||||||
Tile(coords=TBLR(top=256, bottom=576, left=0, right=576), overlap=TBLR(top=64, bottom=0, left=0, right=64)),
|
Tile(
|
||||||
Tile(coords=TBLR(top=256, bottom=576, left=512, right=1088), overlap=TBLR(top=64, bottom=0, left=64, right=64)),
|
coords=TBLR(top=256, bottom=576, left=0, right=576),
|
||||||
Tile(coords=TBLR(top=256, bottom=576, left=1024, right=1600), overlap=TBLR(top=64, bottom=0, left=64, right=0)),
|
overlap=TBLR(top=64, bottom=0, left=0, right=64),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=256, bottom=576, left=512, right=1088),
|
||||||
|
overlap=TBLR(top=64, bottom=0, left=64, right=64),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=256, bottom=576, left=1024, right=1600),
|
||||||
|
overlap=TBLR(top=64, bottom=0, left=64, right=0),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
assert tiles == expected_tiles
|
assert tiles == expected_tiles
|
||||||
@ -46,16 +72,30 @@ def test_calc_tiles_with_overlap_not_evenly_divisible():
|
|||||||
|
|
||||||
expected_tiles = [
|
expected_tiles = [
|
||||||
# Row 0
|
# Row 0
|
||||||
Tile(coords=TBLR(top=0, bottom=256, left=0, right=512), overlap=TBLR(top=0, bottom=112, left=0, right=64)),
|
|
||||||
Tile(coords=TBLR(top=0, bottom=256, left=448, right=960), overlap=TBLR(top=0, bottom=112, left=64, right=272)),
|
|
||||||
Tile(coords=TBLR(top=0, bottom=256, left=688, right=1200), overlap=TBLR(top=0, bottom=112, left=272, right=0)),
|
|
||||||
# Row 1
|
|
||||||
Tile(coords=TBLR(top=144, bottom=400, left=0, right=512), overlap=TBLR(top=112, bottom=0, left=0, right=64)),
|
|
||||||
Tile(
|
Tile(
|
||||||
coords=TBLR(top=144, bottom=400, left=448, right=960), overlap=TBLR(top=112, bottom=0, left=64, right=272)
|
coords=TBLR(top=0, bottom=256, left=0, right=512),
|
||||||
|
overlap=TBLR(top=0, bottom=112, left=0, right=64),
|
||||||
),
|
),
|
||||||
Tile(
|
Tile(
|
||||||
coords=TBLR(top=144, bottom=400, left=688, right=1200), overlap=TBLR(top=112, bottom=0, left=272, right=0)
|
coords=TBLR(top=0, bottom=256, left=448, right=960),
|
||||||
|
overlap=TBLR(top=0, bottom=112, left=64, right=272),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=256, left=688, right=1200),
|
||||||
|
overlap=TBLR(top=0, bottom=112, left=272, right=0),
|
||||||
|
),
|
||||||
|
# Row 1
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=144, bottom=400, left=0, right=512),
|
||||||
|
overlap=TBLR(top=112, bottom=0, left=0, right=64),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=144, bottom=400, left=448, right=960),
|
||||||
|
overlap=TBLR(top=112, bottom=0, left=64, right=272),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=144, bottom=400, left=688, right=1200),
|
||||||
|
overlap=TBLR(top=112, bottom=0, left=272, right=0),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -75,7 +115,12 @@ def test_calc_tiles_with_overlap_not_evenly_divisible():
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_calc_tiles_with_overlap_input_validation(
|
def test_calc_tiles_with_overlap_input_validation(
|
||||||
image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int, raises: bool
|
image_height: int,
|
||||||
|
image_width: int,
|
||||||
|
tile_height: int,
|
||||||
|
tile_width: int,
|
||||||
|
overlap: int,
|
||||||
|
raises: bool,
|
||||||
):
|
):
|
||||||
"""Test that calc_tiles_with_overlap() raises an exception if the inputs are invalid."""
|
"""Test that calc_tiles_with_overlap() raises an exception if the inputs are invalid."""
|
||||||
if raises:
|
if raises:
|
||||||
@ -85,6 +130,306 @@ def test_calc_tiles_with_overlap_input_validation(
|
|||||||
calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap)
|
calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap)
|
||||||
|
|
||||||
|
|
||||||
|
####################################
|
||||||
|
# Test calc_tiles_min_overlap(...)
|
||||||
|
####################################
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_tiles_min_overlap_single_tile():
|
||||||
|
"""Test calc_tiles_min_overlap() behavior when a single tile covers the image."""
|
||||||
|
tiles = calc_tiles_min_overlap(
|
||||||
|
image_height=512,
|
||||||
|
image_width=1024,
|
||||||
|
tile_height=512,
|
||||||
|
tile_width=1024,
|
||||||
|
min_overlap=64,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_tiles = [
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=512, left=0, right=1024),
|
||||||
|
overlap=TBLR(top=0, bottom=0, left=0, right=0),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert tiles == expected_tiles
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_tiles_min_overlap_evenly_divisible():
|
||||||
|
"""Test calc_tiles_min_overlap() behavior when the image is evenly covered by multiple tiles."""
|
||||||
|
# Parameters mimic roughly the same output as the original tile generations of the same test name
|
||||||
|
tiles = calc_tiles_min_overlap(
|
||||||
|
image_height=576,
|
||||||
|
image_width=1600,
|
||||||
|
tile_height=320,
|
||||||
|
tile_width=576,
|
||||||
|
min_overlap=64,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_tiles = [
|
||||||
|
# Row 0
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=320, left=0, right=576),
|
||||||
|
overlap=TBLR(top=0, bottom=64, left=0, right=64),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=320, left=512, right=1088),
|
||||||
|
overlap=TBLR(top=0, bottom=64, left=64, right=64),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=320, left=1024, right=1600),
|
||||||
|
overlap=TBLR(top=0, bottom=64, left=64, right=0),
|
||||||
|
),
|
||||||
|
# Row 1
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=256, bottom=576, left=0, right=576),
|
||||||
|
overlap=TBLR(top=64, bottom=0, left=0, right=64),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=256, bottom=576, left=512, right=1088),
|
||||||
|
overlap=TBLR(top=64, bottom=0, left=64, right=64),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=256, bottom=576, left=1024, right=1600),
|
||||||
|
overlap=TBLR(top=64, bottom=0, left=64, right=0),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert tiles == expected_tiles
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_tiles_min_overlap_not_evenly_divisible():
|
||||||
|
"""Test calc_tiles_min_overlap() behavior when the image requires 'uneven' overlaps to achieve proper coverage."""
|
||||||
|
# Parameters mimic roughly the same output as the original tile generations of the same test name
|
||||||
|
tiles = calc_tiles_min_overlap(
|
||||||
|
image_height=400,
|
||||||
|
image_width=1200,
|
||||||
|
tile_height=256,
|
||||||
|
tile_width=512,
|
||||||
|
min_overlap=64,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_tiles = [
|
||||||
|
# Row 0
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=256, left=0, right=512),
|
||||||
|
overlap=TBLR(top=0, bottom=112, left=0, right=168),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=256, left=344, right=856),
|
||||||
|
overlap=TBLR(top=0, bottom=112, left=168, right=168),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=256, left=688, right=1200),
|
||||||
|
overlap=TBLR(top=0, bottom=112, left=168, right=0),
|
||||||
|
),
|
||||||
|
# Row 1
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=144, bottom=400, left=0, right=512),
|
||||||
|
overlap=TBLR(top=112, bottom=0, left=0, right=168),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=144, bottom=400, left=344, right=856),
|
||||||
|
overlap=TBLR(top=112, bottom=0, left=168, right=168),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=144, bottom=400, left=688, right=1200),
|
||||||
|
overlap=TBLR(top=112, bottom=0, left=168, right=0),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert tiles == expected_tiles
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
[
|
||||||
|
"image_height",
|
||||||
|
"image_width",
|
||||||
|
"tile_height",
|
||||||
|
"tile_width",
|
||||||
|
"min_overlap",
|
||||||
|
"raises",
|
||||||
|
],
|
||||||
|
[
|
||||||
|
(128, 128, 128, 128, 127, False), # OK
|
||||||
|
(128, 128, 128, 128, 0, False), # OK
|
||||||
|
(128, 128, 64, 64, 0, False), # OK
|
||||||
|
(128, 128, 129, 128, 0, False), # tile_height exceeds image_height defaults to 1 tile.
|
||||||
|
(128, 128, 128, 129, 0, False), # tile_width exceeds image_width defaults to 1 tile.
|
||||||
|
(128, 128, 64, 128, 64, True), # overlap equals tile_height.
|
||||||
|
(128, 128, 128, 64, 64, True), # overlap equals tile_width.
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_calc_tiles_min_overlap_input_validation(
|
||||||
|
image_height: int,
|
||||||
|
image_width: int,
|
||||||
|
tile_height: int,
|
||||||
|
tile_width: int,
|
||||||
|
min_overlap: int,
|
||||||
|
raises: bool,
|
||||||
|
):
|
||||||
|
"""Test that calc_tiles_min_overlap() raises an exception if the inputs are invalid."""
|
||||||
|
if raises:
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
calc_tiles_min_overlap(image_height, image_width, tile_height, tile_width, min_overlap)
|
||||||
|
else:
|
||||||
|
calc_tiles_min_overlap(image_height, image_width, tile_height, tile_width, min_overlap)
|
||||||
|
|
||||||
|
|
||||||
|
####################################
|
||||||
|
# Test calc_tiles_even_split(...)
|
||||||
|
####################################
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_tiles_even_split_single_tile():
|
||||||
|
"""Test calc_tiles_even_split() behavior when a single tile covers the image."""
|
||||||
|
tiles = calc_tiles_even_split(
|
||||||
|
image_height=512, image_width=1024, num_tiles_x=1, num_tiles_y=1, overlap_fraction=0.25
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_tiles = [
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=512, left=0, right=1024),
|
||||||
|
overlap=TBLR(top=0, bottom=0, left=0, right=0),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
assert tiles == expected_tiles
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_tiles_even_split_evenly_divisible():
|
||||||
|
"""Test calc_tiles_even_split() behavior when the image is evenly covered by multiple tiles."""
|
||||||
|
# Parameters mimic roughly the same output as the original tile generations of the same test name
|
||||||
|
tiles = calc_tiles_even_split(
|
||||||
|
image_height=576, image_width=1600, num_tiles_x=3, num_tiles_y=2, overlap_fraction=0.25
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_tiles = [
|
||||||
|
# Row 0
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=320, left=0, right=624),
|
||||||
|
overlap=TBLR(top=0, bottom=72, left=0, right=136),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=320, left=488, right=1112),
|
||||||
|
overlap=TBLR(top=0, bottom=72, left=136, right=136),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=320, left=976, right=1600),
|
||||||
|
overlap=TBLR(top=0, bottom=72, left=136, right=0),
|
||||||
|
),
|
||||||
|
# Row 1
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=248, bottom=576, left=0, right=624),
|
||||||
|
overlap=TBLR(top=72, bottom=0, left=0, right=136),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=248, bottom=576, left=488, right=1112),
|
||||||
|
overlap=TBLR(top=72, bottom=0, left=136, right=136),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=248, bottom=576, left=976, right=1600),
|
||||||
|
overlap=TBLR(top=72, bottom=0, left=136, right=0),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
assert tiles == expected_tiles
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_tiles_even_split_not_evenly_divisible():
|
||||||
|
"""Test calc_tiles_even_split() behavior when the image requires 'uneven' overlaps to achieve proper coverage."""
|
||||||
|
# Parameters mimic roughly the same output as the original tile generations of the same test name
|
||||||
|
tiles = calc_tiles_even_split(
|
||||||
|
image_height=400, image_width=1200, num_tiles_x=3, num_tiles_y=2, overlap_fraction=0.25
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_tiles = [
|
||||||
|
# Row 0
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=224, left=0, right=464),
|
||||||
|
overlap=TBLR(top=0, bottom=56, left=0, right=104),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=224, left=360, right=824),
|
||||||
|
overlap=TBLR(top=0, bottom=56, left=104, right=104),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=224, left=720, right=1200),
|
||||||
|
overlap=TBLR(top=0, bottom=56, left=104, right=0),
|
||||||
|
),
|
||||||
|
# Row 1
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=168, bottom=400, left=0, right=464),
|
||||||
|
overlap=TBLR(top=56, bottom=0, left=0, right=104),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=168, bottom=400, left=360, right=824),
|
||||||
|
overlap=TBLR(top=56, bottom=0, left=104, right=104),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=168, bottom=400, left=720, right=1200),
|
||||||
|
overlap=TBLR(top=56, bottom=0, left=104, right=0),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert tiles == expected_tiles
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_tiles_even_split_difficult_size():
|
||||||
|
"""Test calc_tiles_even_split() behavior when the image is a difficult size to spilt evenly and keep div8."""
|
||||||
|
# Parameters are a difficult size for other tile gen routines to calculate
|
||||||
|
tiles = calc_tiles_even_split(
|
||||||
|
image_height=1000, image_width=1000, num_tiles_x=2, num_tiles_y=2, overlap_fraction=0.25
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_tiles = [
|
||||||
|
# Row 0
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=560, left=0, right=560),
|
||||||
|
overlap=TBLR(top=0, bottom=128, left=0, right=128),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=560, left=432, right=1000),
|
||||||
|
overlap=TBLR(top=0, bottom=128, left=128, right=0),
|
||||||
|
),
|
||||||
|
# Row 1
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=432, bottom=1000, left=0, right=560),
|
||||||
|
overlap=TBLR(top=128, bottom=0, left=0, right=128),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=432, bottom=1000, left=432, right=1000),
|
||||||
|
overlap=TBLR(top=128, bottom=0, left=128, right=0),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert tiles == expected_tiles
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
["image_height", "image_width", "num_tiles_x", "num_tiles_y", "overlap_fraction", "raises"],
|
||||||
|
[
|
||||||
|
(128, 128, 1, 1, 0.25, False), # OK
|
||||||
|
(128, 128, 1, 1, 0, False), # OK
|
||||||
|
(128, 128, 2, 1, 0, False), # OK
|
||||||
|
(127, 127, 1, 1, 0, True), # image size must be dividable by 8
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_calc_tiles_even_split_input_validation(
|
||||||
|
image_height: int,
|
||||||
|
image_width: int,
|
||||||
|
num_tiles_x: int,
|
||||||
|
num_tiles_y: int,
|
||||||
|
overlap_fraction: float,
|
||||||
|
raises: bool,
|
||||||
|
):
|
||||||
|
"""Test that calc_tiles_even_split() raises an exception if the inputs are invalid."""
|
||||||
|
if raises:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
calc_tiles_even_split(image_height, image_width, num_tiles_x, num_tiles_y, overlap_fraction)
|
||||||
|
else:
|
||||||
|
calc_tiles_even_split(image_height, image_width, num_tiles_x, num_tiles_y, overlap_fraction)
|
||||||
|
|
||||||
|
|
||||||
#############################################
|
#############################################
|
||||||
# Test merge_tiles_with_linear_blending(...)
|
# Test merge_tiles_with_linear_blending(...)
|
||||||
#############################################
|
#############################################
|
||||||
@ -95,8 +440,14 @@ def test_merge_tiles_with_linear_blending_horizontal(blend_amount: int):
|
|||||||
"""Test merge_tiles_with_linear_blending(...) behavior when merging horizontally."""
|
"""Test merge_tiles_with_linear_blending(...) behavior when merging horizontally."""
|
||||||
# Initialize 2 tiles side-by-side.
|
# Initialize 2 tiles side-by-side.
|
||||||
tiles = [
|
tiles = [
|
||||||
Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=64)),
|
Tile(
|
||||||
Tile(coords=TBLR(top=0, bottom=512, left=448, right=960), overlap=TBLR(top=0, bottom=0, left=64, right=0)),
|
coords=TBLR(top=0, bottom=512, left=0, right=512),
|
||||||
|
overlap=TBLR(top=0, bottom=0, left=0, right=64),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=512, left=448, right=960),
|
||||||
|
overlap=TBLR(top=0, bottom=0, left=64, right=0),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
dst_image = np.zeros((512, 960, 3), dtype=np.uint8)
|
dst_image = np.zeros((512, 960, 3), dtype=np.uint8)
|
||||||
@ -116,7 +467,10 @@ def test_merge_tiles_with_linear_blending_horizontal(blend_amount: int):
|
|||||||
expected_output[:, 480 + (blend_amount // 2) :, :] = 128
|
expected_output[:, 480 + (blend_amount // 2) :, :] = 128
|
||||||
|
|
||||||
merge_tiles_with_linear_blending(
|
merge_tiles_with_linear_blending(
|
||||||
dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=blend_amount
|
dst_image=dst_image,
|
||||||
|
tiles=tiles,
|
||||||
|
tile_images=tile_images,
|
||||||
|
blend_amount=blend_amount,
|
||||||
)
|
)
|
||||||
|
|
||||||
np.testing.assert_array_equal(dst_image, expected_output, strict=True)
|
np.testing.assert_array_equal(dst_image, expected_output, strict=True)
|
||||||
@ -127,8 +481,14 @@ def test_merge_tiles_with_linear_blending_vertical(blend_amount: int):
|
|||||||
"""Test merge_tiles_with_linear_blending(...) behavior when merging vertically."""
|
"""Test merge_tiles_with_linear_blending(...) behavior when merging vertically."""
|
||||||
# Initialize 2 tiles stacked vertically.
|
# Initialize 2 tiles stacked vertically.
|
||||||
tiles = [
|
tiles = [
|
||||||
Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=64, left=0, right=0)),
|
Tile(
|
||||||
Tile(coords=TBLR(top=448, bottom=960, left=0, right=512), overlap=TBLR(top=64, bottom=0, left=0, right=0)),
|
coords=TBLR(top=0, bottom=512, left=0, right=512),
|
||||||
|
overlap=TBLR(top=0, bottom=64, left=0, right=0),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=448, bottom=960, left=0, right=512),
|
||||||
|
overlap=TBLR(top=64, bottom=0, left=0, right=0),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
dst_image = np.zeros((960, 512, 3), dtype=np.uint8)
|
dst_image = np.zeros((960, 512, 3), dtype=np.uint8)
|
||||||
@ -148,7 +508,10 @@ def test_merge_tiles_with_linear_blending_vertical(blend_amount: int):
|
|||||||
expected_output[480 + (blend_amount // 2) :, :, :] = 128
|
expected_output[480 + (blend_amount // 2) :, :, :] = 128
|
||||||
|
|
||||||
merge_tiles_with_linear_blending(
|
merge_tiles_with_linear_blending(
|
||||||
dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=blend_amount
|
dst_image=dst_image,
|
||||||
|
tiles=tiles,
|
||||||
|
tile_images=tile_images,
|
||||||
|
blend_amount=blend_amount,
|
||||||
)
|
)
|
||||||
|
|
||||||
np.testing.assert_array_equal(dst_image, expected_output, strict=True)
|
np.testing.assert_array_equal(dst_image, expected_output, strict=True)
|
||||||
@ -160,8 +523,14 @@ def test_merge_tiles_with_linear_blending_blend_amount_exceeds_vertical_overlap(
|
|||||||
"""
|
"""
|
||||||
# Initialize 2 tiles stacked vertically.
|
# Initialize 2 tiles stacked vertically.
|
||||||
tiles = [
|
tiles = [
|
||||||
Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=64, left=0, right=0)),
|
Tile(
|
||||||
Tile(coords=TBLR(top=448, bottom=960, left=0, right=512), overlap=TBLR(top=64, bottom=0, left=0, right=0)),
|
coords=TBLR(top=0, bottom=512, left=0, right=512),
|
||||||
|
overlap=TBLR(top=0, bottom=64, left=0, right=0),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=448, bottom=960, left=0, right=512),
|
||||||
|
overlap=TBLR(top=64, bottom=0, left=0, right=0),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
dst_image = np.zeros((960, 512, 3), dtype=np.uint8)
|
dst_image = np.zeros((960, 512, 3), dtype=np.uint8)
|
||||||
@ -180,8 +549,14 @@ def test_merge_tiles_with_linear_blending_blend_amount_exceeds_horizontal_overla
|
|||||||
"""
|
"""
|
||||||
# Initialize 2 tiles side-by-side.
|
# Initialize 2 tiles side-by-side.
|
||||||
tiles = [
|
tiles = [
|
||||||
Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=64)),
|
Tile(
|
||||||
Tile(coords=TBLR(top=0, bottom=512, left=448, right=960), overlap=TBLR(top=0, bottom=0, left=64, right=0)),
|
coords=TBLR(top=0, bottom=512, left=0, right=512),
|
||||||
|
overlap=TBLR(top=0, bottom=0, left=0, right=64),
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=512, left=448, right=960),
|
||||||
|
overlap=TBLR(top=0, bottom=0, left=64, right=0),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
dst_image = np.zeros((512, 960, 3), dtype=np.uint8)
|
dst_image = np.zeros((512, 960, 3), dtype=np.uint8)
|
||||||
@ -198,7 +573,12 @@ def test_merge_tiles_with_linear_blending_tiles_overflow_dst_image():
|
|||||||
"""Test that merge_tiles_with_linear_blending(...) raises an exception if any of the tiles overflows the
|
"""Test that merge_tiles_with_linear_blending(...) raises an exception if any of the tiles overflows the
|
||||||
dst_image.
|
dst_image.
|
||||||
"""
|
"""
|
||||||
tiles = [Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=0))]
|
tiles = [
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=512, left=0, right=512),
|
||||||
|
overlap=TBLR(top=0, bottom=0, left=0, right=0),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
dst_image = np.zeros((256, 512, 3), dtype=np.uint8)
|
dst_image = np.zeros((256, 512, 3), dtype=np.uint8)
|
||||||
|
|
||||||
@ -213,7 +593,12 @@ def test_merge_tiles_with_linear_blending_mismatched_list_lengths():
|
|||||||
"""Test that merge_tiles_with_linear_blending(...) raises an exception if the lengths of 'tiles' and 'tile_images'
|
"""Test that merge_tiles_with_linear_blending(...) raises an exception if the lengths of 'tiles' and 'tile_images'
|
||||||
do not match.
|
do not match.
|
||||||
"""
|
"""
|
||||||
tiles = [Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=0))]
|
tiles = [
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=0, bottom=512, left=0, right=512),
|
||||||
|
overlap=TBLR(top=0, bottom=0, left=0, right=0),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
dst_image = np.zeros((256, 512, 3), dtype=np.uint8)
|
dst_image = np.zeros((256, 512, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user