InvokeAI/invokeai/app/invocations/tiles.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

163 lines
5.8 KiB
Python
Raw Normal View History

import numpy as np
from PIL import Image
from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InputField,
InvocationContext,
OutputField,
WithMetadata,
WithWorkflow,
invocation,
invocation_output,
)
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.backend.tiles.tiles import calc_tiles, merge_tiles_with_linear_blending
from invokeai.backend.tiles.utils import Tile
# Granularity constraint applied to CalcTiles' tile_height, tile_width and overlap inputs: each of them must be a
# multiple of this value (enforced through the `multiple_of` field constraint).
# NOTE(review): presumably 8 mirrors the pixel-to-latent downscale factor of the image model - confirm.
# TODO(ryand): Is this important?
_DIMENSION_MULTIPLE_OF = 8
class TileWithImage(BaseModel):
    """Bundle of a ``Tile`` and the ``ImageField`` associated with that tile."""

    # Tile geometry (coordinates / overlap) as described by the Tile model.
    tile: Tile
    # Reference to the stored image belonging to this tile.
    image: ImageField
@invocation_output("calc_tiles_output")
class CalcTilesOutput(BaseInvocationOutput):
    """Output of the ``CalcTiles`` invocation: the list of tiles it computed."""

    # TODO(ryand): Add description from FieldDescriptions.
    tiles: list[Tile] = OutputField(description="")
@invocation("calculate_tiles", title="Calculate Tiles", tags=["tiles"], category="tiles", version="1.0.0")
class CalcTiles(BaseInvocation):
    """Compute a list of tiles for an image of the given dimensions.

    The image size, tile size and overlap inputs are forwarded unchanged to ``calc_tiles``; the tiles it returns
    are wrapped in a ``CalcTilesOutput``.
    """

    # Inputs
    image_height: int = InputField(ge=1)
    image_width: int = InputField(ge=1)
    # Tile dimensions and overlap must be multiples of _DIMENSION_MULTIPLE_OF.
    tile_height: int = InputField(default=576, ge=1, multiple_of=_DIMENSION_MULTIPLE_OF)
    tile_width: int = InputField(default=576, ge=1, multiple_of=_DIMENSION_MULTIPLE_OF)
    overlap: int = InputField(default=64, ge=0, multiple_of=_DIMENSION_MULTIPLE_OF)

    def invoke(self, context: InvocationContext) -> CalcTilesOutput:
        """Run ``calc_tiles`` on this node's inputs and return the resulting tiles."""
        return CalcTilesOutput(
            tiles=calc_tiles(
                image_height=self.image_height,
                image_width=self.image_width,
                tile_height=self.tile_height,
                tile_width=self.tile_width,
                overlap=self.overlap,
            )
        )
@invocation_output("tile_to_properties_output")
class TileToPropertiesOutput(BaseInvocationOutput):
    """Flattened view of a ``Tile``: one integer output per coordinate and per overlap edge."""

    # TODO(ryand): Add descriptions.

    # Values copied from `tile.coords`.
    coords_top: int = OutputField(description="")
    coords_bottom: int = OutputField(description="")
    coords_left: int = OutputField(description="")
    coords_right: int = OutputField(description="")

    # Values copied from `tile.overlap`.
    overlap_top: int = OutputField(description="")
    overlap_bottom: int = OutputField(description="")
    overlap_left: int = OutputField(description="")
    overlap_right: int = OutputField(description="")
# Declare title/tags/category/version explicitly so this node is registered consistently with the other tile
# invocations in this module (CalcTiles, PairTileImage, MergeTilesToImage).
@invocation("tile_to_properties", title="Tile to Properties", tags=["tiles"], category="tiles", version="1.0.0")
class TileToProperties(BaseInvocation):
    """Split a Tile into its individual properties."""

    tile: Tile = InputField()

    def invoke(self, context: InvocationContext) -> TileToPropertiesOutput:
        """Copy each coordinate and overlap value of ``self.tile`` into its own output field."""
        coords = self.tile.coords
        overlap = self.tile.overlap
        return TileToPropertiesOutput(
            coords_top=coords.top,
            coords_bottom=coords.bottom,
            coords_left=coords.left,
            coords_right=coords.right,
            overlap_top=overlap.top,
            overlap_bottom=overlap.bottom,
            overlap_left=overlap.left,
            overlap_right=overlap.right,
        )
# HACK(ryand): The only reason that PairTileImage is needed is because the iterate/collect nodes don't preserve order.
# Can this be fixed?


@invocation_output("pair_tile_image_output")
class PairTileImageOutput(BaseInvocationOutput):
    """Output of the ``PairTileImage`` invocation: a single tile bundled with its image."""

    tile_with_image: TileWithImage = OutputField(description="")
@invocation("pair_tile_image", title="Pair Tile with Image", tags=["tiles"], category="tiles", version="1.0.0")
class PairTileImage(BaseInvocation):
    """Bundle an image and a tile into a single ``TileWithImage`` value."""

    image: ImageField = InputField()
    tile: Tile = InputField()

    def invoke(self, context: InvocationContext) -> PairTileImageOutput:
        """Wrap this node's ``tile`` and ``image`` inputs in a ``TileWithImage`` and return it."""
        paired = TileWithImage(tile=self.tile, image=self.image)
        return PairTileImageOutput(tile_with_image=paired)
@invocation("merge_tiles_to_image", title="Merge Tiles To Image", tags=["tiles"], category="tiles", version="1.0.0")
class MergeTilesToImage(BaseInvocation, WithMetadata, WithWorkflow):
    """Merge multiple tile images into a single image.

    Every tile image is loaded through the image service, converted to RGB and handed, together with its tile,
    to ``merge_tiles_with_linear_blending``, which fills a zero-initialised buffer of
    ``image_height`` x ``image_width``. The merged buffer is saved as a new image and returned as an
    ``ImageOutput``.

    Raises:
        ValueError: If ``tiles_with_images`` is empty.
    """

    # Inputs
    image_height: int = InputField(ge=1)
    image_width: int = InputField(ge=1)
    tiles_with_images: list[TileWithImage] = InputField()
    blend_amount: int = InputField(ge=0)

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Load all tile images, blend them into one image, store it and return a reference to it."""
        # Fail early with an actionable message. Without this guard the first-tile lookup below would raise a
        # bare IndexError that gives no hint about which input was wrong.
        if not self.tiles_with_images:
            raise ValueError("MergeTilesToImage requires at least one tile, but 'tiles_with_images' is empty.")

        tiles = [twi.tile for twi in self.tiles_with_images]

        # Get all tile images for processing.
        # TODO(ryand): It pains me that we spend time PNG decoding each tile from disk when they almost certainly
        # existed in memory at an earlier point in the graph.
        tile_np_images: list[np.ndarray] = [
            np.array(context.services.images.get_pil_image(twi.image.image_name).convert("RGB"))
            for twi in self.tiles_with_images
        ]

        # Prepare the output image buffer.
        # Check the first tile to determine how many image channels (and which dtype) are expected in the output.
        first_tile_image = tile_np_images[0]
        channels = first_tile_image.shape[-1]
        np_image = np.zeros(shape=(self.image_height, self.image_width, channels), dtype=first_tile_image.dtype)

        # The return value is not used; the merged pixels are read back from `np_image` afterwards.
        merge_tiles_with_linear_blending(
            dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
        )
        pil_image = Image.fromarray(np_image)

        image_dto = context.services.images.create(
            image=pil_image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
            metadata=self.metadata,
            workflow=self.workflow,
        )
        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )