mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
5f498e10bd
* feat(ui): add axios client generator and simple example * fix(ui): update client & nodes test code w/ new Edge type * chore(ui): organize generated files * chore(ui): update .eslintignore, .prettierignore * chore(ui): update openapi.json * feat(backend): fixes for nodes/generator * feat(ui): generate object args for api client * feat(ui): more nodes api prototyping * feat(ui): nodes cancel * chore(ui): regenerate api client * fix(ui): disable OG web server socket connection * fix(ui): fix scrollbar styles typing and prop just noticed the typo, and made the types stronger. * feat(ui): add socketio types * feat(ui): wip nodes - extract api client method arg types instead of manually declaring them - update example to display images - general tidy up * start building out node translations from frontend state and add notes about missing features * use reference to sampler_name * use reference to sampler_name * add optional apiUrl prop * feat(ui): start hooking up dynamic txt2img node generation, create middleware for session invocation * feat(ui): write separate nodes socket layer, txt2img generating and rendering w single node * feat(ui): img2img implementation * feat(ui): get intermediate images working but types are stubbed out * chore(ui): add support for package mode * feat(ui): add nodes mode script * feat(ui): handle random seeds * fix(ui): fix middleware types * feat(ui): add rtk action type guard * feat(ui): disable NodeAPITest This was polluting the network/socket logs. * feat(ui): fix parameters panel border color This commit should be elsewhere but I don't want to break my flow * feat(ui): make thunk types more consistent * feat(ui): add type guards for outputs * feat(ui): load images on socket connect Rudimentary * chore(ui): bump redux-toolkit * docs(ui): update readme * chore(ui): regenerate api client * chore(ui): add typescript as dev dependency I am having trouble with TS versions after vscode updated and now uses TS 5. 
`madge` has installed 3.9.10 and for whatever reason my vscode wants to use that. Manually specifying 4.9.5 and then setting vscode to use that as the workspace TS fixes the issue. * feat(ui): begin migrating gallery to nodes Along the way, migrate to use RTK `createEntityAdapter` for gallery images, and separate `results` and `uploads` into separate slices. Much cleaner this way. * feat(ui): clean up & comment results slice * fix(ui): separate thunk for initial gallery load so it properly gets index 0 * feat(ui): POST upload working * fix(ui): restore removed type * feat(ui): patch api generation for headers access * chore(ui): regenerate api * feat(ui): wip gallery migration * feat(ui): wip gallery migration * chore(ui): regenerate api * feat(ui): wip refactor socket events * feat(ui): disable panels based on app props * feat(ui): invert logic to be disabled * disable panels when app mounts * feat(ui): add support to disableTabs * docs(ui): organise and update docs * lang(ui): add toast strings * feat(ui): wip events, comments, and general refactoring * feat(ui): add optional token for auth * feat(ui): export StatusIndicator and ModelSelect for header use * feat(ui) working on making socket URL dynamic * feat(ui): dynamic middleware loading * feat(ui): prep for socket jwt * feat(ui): migrate cancelation also updated action names to be event-like instead of declaration-like sorry, i was scattered and this commit has a lot of unrelated stuff in it. 
* fix(ui): fix img2img type * chore(ui): regenerate api client * feat(ui): improve InvocationCompleteEvent types * feat(ui): increase StatusIndicator font size * fix(ui): fix middleware order for multi-node graphs * feat(ui): add exampleGraphs object w/ iterations example * feat(ui): generate iterations graph * feat(ui): update ModelSelect for nodes API * feat(ui): add hi-res functionality for txt2img generations * feat(ui): "subscribe" to particular nodes feels like a dirty hack but oh well it works * feat(ui): first steps to node editor ui * fix(ui): disable event subscription it is not fully baked just yet * feat(ui): wip node editor * feat(ui): remove extraneous field types * feat(ui): nodes before deleting stuff * feat(ui): cleanup nodes ui stuff * feat(ui): hook up nodes to redux * fix(ui): fix handle * fix(ui): add basic node edges & connection validation * feat(ui): add connection validation styling * feat(ui): increase edge width * feat(ui): it blends * feat(ui): wip model handling and graph topology validation * feat(ui): validation connections w/ graphlib * docs(ui): update nodes doc * feat(ui): wip node editor * chore(ui): rebuild api, update types * add redux-dynamic-middlewares as a dependency * feat(ui): add url host transformation * feat(ui): handle already-connected fields * feat(ui): rewrite SqliteItemStore in sqlalchemy * fix(ui): fix sqlalchemy dynamic model instantiation * feat(ui, nodes): metadata wip * feat(ui, nodes): models * feat(ui, nodes): more metadata wip * feat(ui): wip range/iterate * fix(nodes): fix sqlite typing * feat(ui): export new type for invoke component * tests(nodes): fix test instantiation of ImageField * feat(nodes): fix LoadImageInvocation * feat(nodes): add `title` ui hint * feat(nodes): make ImageField attrs optional * feat(ui): wip nodes etc * feat(nodes): roll back sqlalchemy * fix(nodes): partially address feedback * fix(backend): roll back changes to pngwriter * feat(nodes): wip address metadata feedback * 
feat(nodes): add seeded rng to RandomRange * feat(nodes): address feedback * feat(nodes): move GET images error handling to DiskImageStorage * feat(nodes): move GET images error handling to DiskImageStorage * fix(nodes): fix image output schema customization * feat(ui): img2img/txt2img -> linear - remove txt2img and img2img tabs - add linear tab - add initial image selection to linear parameters accordion * feat(ui): tidy graph builders * feat(ui): tidy misc * feat(ui): improve invocation union types * feat(ui): wip metadata viewer recall * feat(ui): move fonts to normal deps * feat(nodes): fix broken upload * feat(nodes): add metadata module + tests, thumbnails - `MetadataModule` is stateless and needed in places where the `InvocationContext` is not available, so have not made it a `service` - Handles loading/parsing/building metadata, and creating png info objects - added tests for MetadataModule - Lifted thumbnail stuff to util * fix(nodes): revert change to RandomRangeInvocation * feat(nodes): address feedback - make metadata a service - rip out pydantic validation, implement metadata parsing as simple functions - update tests - address other minor feedback items * fix(nodes): fix other tests * fix(nodes): add metadata service to cli * fix(nodes): fix latents/image field parsing * feat(nodes): customise LatentsField schema * feat(nodes): move metadata parsing to frontend * fix(nodes): fix metadata test --------- Co-authored-by: maryhipp <maryhipp@gmail.com> Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
370 lines
12 KiB
Python
370 lines
12 KiB
Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
from typing import Literal, Optional
|
|
|
|
import numpy
|
|
from PIL import Image, ImageFilter, ImageOps
|
|
from pydantic import BaseModel, Field
|
|
|
|
from ..models.image import ImageField, ImageType
|
|
from .baseinvocation import (
|
|
BaseInvocation,
|
|
BaseInvocationOutput,
|
|
InvocationContext,
|
|
InvocationConfig,
|
|
)
|
|
|
|
|
|
class PILInvocationConfig(BaseModel):
    """Helper class to provide all PIL invocations with additional config"""

    # Mixin: invocations that also inherit this class pick up the nested
    # Config below, which adds a "ui" entry with the tags "PIL" and "image"
    # to the model's schema_extra (presumably read by the frontend to group
    # these nodes — confirm against the UI code).
    class Config(InvocationConfig):
        schema_extra = {"ui": {"tags": ["PIL", "image"]}}
|
|
|
|
|
|
class ImageOutput(BaseInvocationOutput):
    """Base class for invocations that output an image"""

    # Fields mirror what build_image_output fills in: a reference to the
    # stored image plus its pixel dimensions and PIL mode. `mode` must be a
    # declared field because build_image_output passes it and the schema's
    # "required" list below advertises it.

    # fmt: off
    type: Literal["image"] = "image"
    image: ImageField = Field(default=None, description="The output image")
    width: Optional[int] = Field(default=None, description="The width of the image in pixels")
    height: Optional[int] = Field(default=None, description="The height of the image in pixels")
    mode: Optional[str] = Field(default=None, description="The mode of the image")
    # fmt: on

    class Config:
        # Every key listed here is a declared field of this model.
        schema_extra = {
            "required": ["type", "image", "width", "height", "mode"]
        }
|
|
|
|
|
|
def build_image_output(
    image_type: ImageType, image_name: str, image: Image.Image
) -> ImageOutput:
    """Builds an ImageOutput and its ImageField.

    Args:
        image_type: Storage category under which the image is referenced.
        image_name: Name under which the image is referenced.
        image: The PIL image whose width, height and mode are copied onto
            the output.

    Returns:
        An ImageOutput whose ``image`` field points at (image_type, image_name).
    """
    reference = ImageField(image_name=image_name, image_type=image_type)
    return ImageOutput(
        image=reference,
        width=image.width,
        height=image.height,
        mode=image.mode,
    )
|
|
|
|
|
|
class MaskOutput(BaseInvocationOutput):
    """Base class for invocations that output a mask"""

    # A single reference to the stored mask image.
    # fmt: off
    type: Literal["mask"] = "mask"
    mask: ImageField = Field(default=None, description="The output mask")
    # fmt: on

    class Config:
        schema_extra = {"required": ["type", "mask"]}
|
|
|
|
|
|
class LoadImageInvocation(BaseInvocation):
    """Load an image and provide it as output."""

    # fmt: off
    type: Literal["load_image"] = "load_image"

    # Inputs
    image_type: ImageType = Field(description="The type of the image")
    image_name: str = Field(description="The name of the image")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Fetch the named image from the image service and emit it.

        Nothing new is saved: the output references the same
        ``image_type``/``image_name`` that were given as inputs.
        """
        loaded = context.services.images.get(self.image_type, self.image_name)
        return build_image_output(
            image_type=self.image_type, image_name=self.image_name, image=loaded
        )
|
|
|
|
|
|
class ShowImageInvocation(BaseInvocation):
    """Displays a provided image, and passes it forward in the pipeline."""

    type: Literal["show_image"] = "show_image"

    # Inputs
    image: ImageField = Field(default=None, description="The image to show")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Call PIL's ``show()`` on the image (if one was returned) and forward it.

        The output references the input image unchanged; nothing is saved.
        """
        ref = self.image
        pil_image = context.services.images.get(ref.image_type, ref.image_name)
        if pil_image:
            pil_image.show()

        # TODO: how to handle failure?
        # NOTE(review): if the service returns a falsy value here,
        # build_image_output will fail on `.width` — confirm the service raises
        # on a missing image instead.
        return build_image_output(
            image_type=ref.image_type, image_name=ref.image_name, image=pil_image
        )
|
|
|
|
|
|
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
    """Crops an image to a specified box. The box can be outside of the image."""

    # fmt: off
    type: Literal["crop"] = "crop"

    # Inputs
    image: ImageField = Field(default=None, description="The image to crop")
    x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
    y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
    width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
    height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Crop the box (x, y, width, height) and save it as an intermediate image."""
        source = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        # Paste the source onto a fully transparent RGBA canvas of the box size,
        # shifted so (x, y) lands at the canvas origin. Any part of the box that
        # lies outside the source stays transparent.
        cropped = Image.new(
            mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
        )
        offset = (-self.x, -self.y)
        cropped.paste(source, offset)

        session_id = context.graph_execution_state_id
        out_type = ImageType.INTERMEDIATE
        out_name = context.services.images.create_name(session_id, self.id)
        metadata = context.services.metadata.build_metadata(
            session_id=session_id, node=self
        )
        context.services.images.save(out_type, out_name, cropped, metadata)

        return build_image_output(
            image_type=out_type, image_name=out_name, image=cropped
        )
|
|
|
|
|
|
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
    """Pastes an image into another image."""

    # fmt: off
    type: Literal["paste"] = "paste"

    # Inputs
    base_image: ImageField = Field(default=None, description="The base image")
    image: ImageField = Field(default=None, description="The image to paste")
    mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
    x: int = Field(default=0, description="The left x coordinate at which to paste the image")
    y: int = Field(default=0, description="The top y coordinate at which to paste the image")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Composite ``image`` over ``base_image`` at offset (x, y) and save it.

        The canvas is enlarged as needed so neither input is clipped. Areas
        covered by neither image are transparent. When a mask is given, it is
        inverted before being passed to PIL's ``paste``, so its black areas
        take the pasted image and its white areas keep the base. The result
        is saved as ``ImageType.RESULT``.
        """
        store = context.services.images
        base = store.get(self.base_image.image_type, self.base_image.image_name)
        overlay = store.get(self.image.image_type, self.image.image_name)

        # TODO: probably shouldn't invert mask here... should user be required to do it?
        if self.mask is None:
            paste_mask = None
        else:
            raw_mask = store.get(self.mask.image_type, self.mask.image_name)
            paste_mask = ImageOps.invert(raw_mask)

        # Bounding box in base-image coordinates covering both images. The base
        # sits at the origin, so left/top are always <= 0 and only become
        # negative when the overlay is placed up or to the left of it.
        left = min(0, self.x)
        top = min(0, self.y)
        right = max(base.width, overlay.width + self.x)
        bottom = max(base.height, overlay.height + self.y)

        canvas = Image.new(
            mode="RGBA", size=(right - left, bottom - top), color=(0, 0, 0, 0)
        )
        canvas.paste(base, (-left, -top))
        canvas.paste(overlay, (max(0, self.x), max(0, self.y)), mask=paste_mask)

        out_type = ImageType.RESULT
        out_name = store.create_name(context.graph_execution_state_id, self.id)
        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )
        store.save(out_type, out_name, canvas, metadata)

        return build_image_output(
            image_type=out_type, image_name=out_name, image=canvas
        )
|
|
|
|
|
|
def _alpha_mask(image: Image.Image) -> Image.Image:
    """Return the alpha channel of ``image`` as a single-band mask.

    - Images whose last band is an alpha band ("A", or premultiplied "a")
      return that band.
    - Images without one but carrying ``transparency`` info (palette or
      colour-key transparency) are converted to RGBA so that the transparency
      becomes a real alpha band.
    - Anything else has no transparency at all, so a fully opaque (255) "L"
      mask of the same size is returned. Taking the last band instead, as
      before, would return e.g. the blue channel of an RGB image.
    """
    bands = image.getbands()
    if bands[-1] in ("A", "a"):
        return image.getchannel(bands[-1])
    if "transparency" in image.info:
        return image.convert("RGBA").getchannel("A")
    return Image.new("L", image.size, 255)


class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
    """Extracts the alpha channel of an image as a mask."""

    # fmt: off
    type: Literal["tomask"] = "tomask"

    # Inputs
    image: ImageField = Field(default=None, description="The image to create the mask from")
    invert: bool = Field(default=False, description="Whether or not to invert the mask")
    # fmt: on

    def invoke(self, context: InvocationContext) -> MaskOutput:
        """Extract (and optionally invert) the image's alpha as an intermediate mask.

        Returns:
            A MaskOutput referencing the newly saved mask image.
        """
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        image_mask = _alpha_mask(image)
        if self.invert:
            image_mask = ImageOps.invert(image_mask)

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, image_mask, metadata)
        return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
|
|
|
|
|
|
class BlurInvocation(BaseInvocation, PILInvocationConfig):
    """Blurs an image"""

    # fmt: off
    type: Literal["blur"] = "blur"

    # Inputs
    image: ImageField = Field(default=None, description="The image to blur")
    radius: float = Field(default=8.0, ge=0, description="The blur radius")
    blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Apply a Gaussian or box blur of ``radius`` and save it as an intermediate image."""
        source = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        # Any blur_type other than "gaussian" falls through to a box blur.
        if self.blur_type == "gaussian":
            blur_filter = ImageFilter.GaussianBlur(self.radius)
        else:
            blur_filter = ImageFilter.BoxBlur(self.radius)
        blurred = source.filter(blur_filter)

        session_id = context.graph_execution_state_id
        out_type = ImageType.INTERMEDIATE
        out_name = context.services.images.create_name(session_id, self.id)
        metadata = context.services.metadata.build_metadata(
            session_id=session_id, node=self
        )
        context.services.images.save(out_type, out_name, blurred, metadata)

        return build_image_output(
            image_type=out_type, image_name=out_name, image=blurred
        )
|
|
|
|
|
|
class LerpInvocation(BaseInvocation, PILInvocationConfig):
    """Linear interpolation of all pixels of an image"""

    # fmt: off
    type: Literal["lerp"] = "lerp"

    # Inputs
    image: ImageField = Field(default=None, description="The image to lerp")
    min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
    max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Remap every channel value linearly from [0, 255] to [min, max].

        The result is saved as an intermediate image. Values are truncated
        (not rounded) when cast back to uint8.
        """
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        # Normalise to [0, 1], then scale and offset by `min` so that input 0
        # maps to `min` and input 255 maps to `max`. Both bounds are
        # validated to lie in [0, 255], so the result always fits in uint8
        # (also when min > max, which simply inverts the ramp).
        image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
        image_arr = image_arr * (self.max - self.min) + self.min

        lerp_image = Image.fromarray(numpy.uint8(image_arr))

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, lerp_image, metadata)
        return build_image_output(
            image_type=image_type, image_name=image_name, image=lerp_image
        )
|
|
|
|
|
|
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
    """Inverse linear interpolation of all pixels of an image"""

    # fmt: off
    type: Literal["ilerp"] = "ilerp"

    # Inputs
    image: ImageField = Field(default=None, description="The image to lerp")
    min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
    max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Stretch channel values in [min, max] to [0, 255], clamping outside it.

        If ``max <= min``, the node degenerates to a hard threshold. Values
        greater than ``min`` become 255 and all others become 0. The result is
        saved as an intermediate image.
        """
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        # A zero or negative span would divide by zero or flip the sign,
        # yielding NaN/inf or negative floats whose uint8 cast is undefined or
        # wraps. Clamping the span to 1 turns that case into a threshold at
        # `min`. For max > min the result is unchanged.
        span = self.max - self.min
        if span < 1:
            span = 1

        image_arr = numpy.asarray(image, dtype=numpy.float32)
        image_arr = numpy.clip((image_arr - self.min) / span, 0, 1) * 255

        ilerp_image = Image.fromarray(numpy.uint8(image_arr))

        image_type = ImageType.INTERMEDIATE
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )

        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )

        context.services.images.save(image_type, image_name, ilerp_image, metadata)
        return build_image_output(
            image_type=image_type, image_name=image_name, image=ilerp_image
        )
|