Compare commits

...

93 Commits

Author SHA1 Message Date
5189fd70a4 feat(ui, nodes): metadata wip 2023-04-14 12:52:07 +10:00
c79e7d3bc1 fix(ui): fix sqlalchemy dynamic model instantiation 2023-04-13 09:27:46 +10:00
f0268235ed feat(ui): rewrite SqliteItemStore in sqlalchemy 2023-04-12 22:34:32 +10:00
a428c473ae feat(ui): handle already-connected fields 2023-04-12 12:32:58 +10:00
832e5f66e1 feat(ui): add url host transformation 2023-04-12 12:29:17 +10:00
43e820c98c add redux-dynamic-middlewares as a dependency 2023-04-12 12:29:17 +10:00
708b236769 chore(ui): rebuild api, update types 2023-04-12 12:29:17 +10:00
c182f64620 feat(ui): wip node editor 2023-04-12 12:29:17 +10:00
6e4c3a7127 docs(ui): update nodes doc 2023-04-12 12:29:17 +10:00
7b690a8127 feat(ui): validation connections w/ graphlib 2023-04-12 12:29:17 +10:00
fe7cf16547 feat(ui): wip model handling and graph topology validation 2023-04-12 12:29:17 +10:00
451fe7abcd feat(ui): it blends 2023-04-12 12:29:17 +10:00
ebc76a4785 feat(ui): increase edge width 2023-04-12 12:29:17 +10:00
2d1b818824 feat(ui): add connection validation styling 2023-04-12 12:29:17 +10:00
d93473eaae fix(ui): add basic node edges & connection validation 2023-04-12 12:29:17 +10:00
f6714d74be fix(ui): fix handle 2023-04-12 12:29:17 +10:00
d26a414560 feat(ui): hook up nodes to redux 2023-04-12 12:29:17 +10:00
5aec29b25f feat(ui): cleanup nodes ui stuff 2023-04-12 12:29:17 +10:00
abde52573e feat(ui): nodes before deleting stuff 2023-04-12 12:29:17 +10:00
49ea838a3c feat(ui): remove extraneous field types 2023-04-12 12:29:17 +10:00
9bd79c04a6 feat(ui): wip node editor 2023-04-12 12:29:17 +10:00
b55b2a8947 fix(ui): disable event subscription
it is not fully baked just yet
2023-04-12 12:29:17 +10:00
e3a8fceb5d feat(ui): first steps to node editor ui 2023-04-12 12:29:17 +10:00
1e09fdc8be feat(ui): "subscribe" to particular nodes
feels like a dirty hack but oh well it works
2023-04-12 12:29:17 +10:00
d0e9ec267c feat(ui): add hi-res functionality for txt2img generations 2023-04-12 12:29:17 +10:00
880e1743ac feat(ui): update ModelSelect for nodes API 2023-04-12 12:29:17 +10:00
f59d4a0015 feat(ui): generate iterations graph 2023-04-12 12:29:17 +10:00
152d4e76aa feat(ui): add exampleGraphs object w/ iterations example 2023-04-12 12:29:17 +10:00
b829af7410 fix(ui): fix middleware order for multi-node graphs 2023-04-12 12:29:16 +10:00
dce604b567 feat(ui): increase StatusIndicator font size 2023-04-12 12:29:16 +10:00
1ed4354753 feat(ui): improve InvocationCompleteEvent types 2023-04-12 12:29:16 +10:00
db8ba8b0bf chore(ui): regenerate api client 2023-04-12 12:29:16 +10:00
3cd2695676 fix(ui): fix img2img type 2023-04-12 12:29:16 +10:00
2787d32881 feat(ui): migrate cancelation
also updated action names to be event-like instead of declaration-like

sorry, i was scattered and this commit has a lot of unrelated stuff in it.
2023-04-12 12:29:16 +10:00
96768078fa feat(ui): prep for socket jwt 2023-04-12 12:29:16 +10:00
13c9639d7b feat(ui): dynamic middleware loading 2023-04-12 12:29:16 +10:00
f104f0a390 feat(ui) working on making socket URL dynamic 2023-04-12 12:29:16 +10:00
c49d2accb7 feat(ui): export StatusIndicator and ModelSelect for header use 2023-04-12 12:29:16 +10:00
749a0912c8 feat(ui): add optional token for auth 2023-04-12 12:29:16 +10:00
759e5613cd feat(ui): wip events, comments, and general refactoring 2023-04-12 12:29:16 +10:00
ac9b83722e lang(ui): add toast strings 2023-04-12 12:29:16 +10:00
439a35e064 docs(ui): organise and update docs 2023-04-12 12:29:16 +10:00
7286843698 feat(ui): add support to disableTabs 2023-04-12 12:29:16 +10:00
77ba1b77d7 disable panels when app mounts 2023-04-12 12:29:16 +10:00
e749e7e915 feat(ui): invert logic to be disabled 2023-04-12 12:29:16 +10:00
e486559d8f feat(ui): disable panels based on app props 2023-04-12 12:29:16 +10:00
2d8982c23d feat(ui): wip refactor socket events 2023-04-12 12:29:16 +10:00
02d510ba17 chore(ui): regenerate api 2023-04-12 12:29:16 +10:00
84d9ccb014 feat(ui): wip gallery migration 2023-04-12 12:29:16 +10:00
b9fc136f25 feat(ui): wip gallery migration 2023-04-12 12:29:16 +10:00
f6691dbf3b chore(ui): regenerate api 2023-04-12 12:29:16 +10:00
cb11717b9c feat(ui): patch api generation for headers access 2023-04-12 12:29:16 +10:00
35c950c50d fix(ui): restore removed type 2023-04-12 12:29:16 +10:00
afb0b564e9 feat(ui): POST upload working 2023-04-12 12:29:16 +10:00
657efadffa fix(ui): separate thunk for initial gallery load so it properly gets index 0 2023-04-12 12:29:16 +10:00
5b1ffc292f feat(ui): clean up & comment results slice 2023-04-12 12:29:16 +10:00
cad289dfe5 feat(ui): begin migrating gallery to nodes
Along the way, migrate to use RTK `createEntityAdapter` for gallery images, and separate `results` and `uploads` into separate slices. Much cleaner this way.
2023-04-12 12:29:16 +10:00
1df999d082 chore(ui): add typescript as dev dependency
I am having trouble with TS versions after vscode updated and now uses TS 5. `madge` has installed 3.9.10 and for whatever reason my vscode wants to use that. Manually specifying 4.9.5 and then setting vscode to use that as the workspace TS fixes the issue.
2023-04-12 12:29:16 +10:00
1372536728 chore(ui): regenerate api client 2023-04-12 12:29:16 +10:00
23a69ea7bf docs(ui): update readme 2023-04-12 12:29:16 +10:00
5cff28aaf3 chore(ui): bump redux-toolkit 2023-04-12 12:29:16 +10:00
21fba1aac6 feat(ui): load images on socket connect
Rudimentary
2023-04-12 12:29:16 +10:00
c992c2fe7d feat(ui): add type guards for outputs 2023-04-12 12:29:16 +10:00
3e76c1a3cd feat(ui): make thunk types more consistent 2023-04-12 12:29:16 +10:00
5eb077accc feat(ui): fix parameters panel border color
This commit should be elsewhere but I don't want to break my flow
2023-04-12 12:29:16 +10:00
007794f48b feat(ui): disable NodeAPITest
This was polluting the network/socket logs.
2023-04-12 12:29:16 +10:00
95a336c26a feat(ui): add rtk action type guard 2023-04-12 12:29:16 +10:00
6ca0798303 fix(ui): fix middleware types 2023-04-12 12:29:16 +10:00
bb9986bfd2 feat(ui): handle random seeds 2023-04-12 12:29:16 +10:00
11f34e0388 feat(ui): add nodes mode script 2023-04-12 12:29:16 +10:00
dea27f451a chore(ui): add support for package mode 2023-04-12 12:29:16 +10:00
be32f5639b feat(ui): get intermediate images working but types are stubbed out 2023-04-12 12:29:16 +10:00
6fd9840608 feat(ui): img2img implementation 2023-04-12 12:29:16 +10:00
158528cf12 feat(ui): write separate nodes socket layer, txt2img generating and rendering w single node 2023-04-12 12:29:16 +10:00
1401a26a41 feat(ui): start hooking up dynamic txt2img node generation, create middleware for session invocation 2023-04-12 12:29:16 +10:00
213a2dcdc8 add optional apiUrl prop 2023-04-12 12:29:16 +10:00
85019ab1b0 use reference to sampler_name 2023-04-12 12:29:16 +10:00
683f8b324e use reference to sampler_name 2023-04-12 12:29:16 +10:00
8a45efbaf3 start building out node translations from frontend state and add notes about missing features 2023-04-12 12:29:16 +10:00
14a1871087 feat(ui): wip nodes
- extract api client method arg types instead of manually declaring them
- update example to display images
- general tidy up
2023-04-12 12:29:16 +10:00
3e3ac329c8 feat(ui): add socketio types 2023-04-12 12:29:16 +10:00
1db0940c67 fix(ui): fix scrollbar styles typing and prop
just noticed the typo, and made the types stronger.
2023-04-12 12:29:16 +10:00
b7fa23be64 fix(ui): disable OG web server socket connection 2023-04-12 12:29:16 +10:00
9be2c02d5e chore(ui): regenerate api client 2023-04-12 12:29:16 +10:00
686f03d2cc feat(ui): nodes cancel 2023-04-12 12:29:16 +10:00
2b6ca72b36 feat(ui): more nodes api prototyping 2023-04-12 12:29:16 +10:00
bfc0c0b3f6 feat(ui): generate object args for api client 2023-04-12 12:29:16 +10:00
e3c3ddc45b feat(backend): fixes for nodes/generator 2023-04-12 12:29:16 +10:00
9436f8e81e chore(ui): update openapi.json 2023-04-12 12:29:16 +10:00
68d1c35b6f chore(ui): update .eslintignore, .prettierignore 2023-04-12 12:29:16 +10:00
46f54c81ed chore(ui): organize generated files 2023-04-12 12:29:16 +10:00
860d495732 fix(ui): update client & nodes test code w/ new Edge type 2023-04-12 12:29:16 +10:00
7c24706778 feat(ui): add axios client generator and simple example 2023-04-12 12:29:16 +10:00
247 changed files with 10312 additions and 687 deletions

View File

@ -1,6 +1,8 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
import json
import os
import uuid
from fastapi import Path, Query, Request, UploadFile
@ -8,6 +10,7 @@ from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType
@ -40,32 +43,47 @@ async def get_thumbnail(
"/uploads/",
operation_id="upload_image",
responses={
201: {"description": "The image was uploaded successfully"},
201: {"description": "The image was uploaded successfully", "model": ImageResponse},
404: {"description": "Session not found"},
},
status_code=201
)
async def upload_image(file: UploadFile, request: Request):
async def upload_image(file: UploadFile, request: Request, response: Response) -> ImageResponse:
if not file.content_type.startswith("image"):
return Response(status_code=415)
contents = await file.read()
try:
im = Image.open(contents)
img = Image.open(io.BytesIO(contents))
except:
# Error opening the image
return Response(status_code=415)
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
image_path = ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, img)
invokeai_metadata = json.loads(img.info.get("invokeai", "{}"))
return Response(
status_code=201,
headers={
"Location": request.url_for(
res = ImageResponse(
image_type=ImageType.UPLOAD,
image_name=filename,
# TODO: DiskImageStorage should not be building URLs...?
image_url=f"api/v1/images/{ImageType.UPLOAD.value}/{filename}",
thumbnail_url=f"api/v1/images/{ImageType.UPLOAD.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
metadata=ImageMetadata(
created=int(os.path.getctime(image_path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata
),
)
response.status_code = 201
response.headers["Location"] = request.url_for(
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
)
},
)
return res
@images_router.get(
"/",

View File

@ -9,7 +9,7 @@ from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
class CvInvocationConfig(BaseModel):
@ -56,7 +56,9 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_inpainted)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, image_inpainted, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image_inpainted,
)

View File

@ -9,9 +9,9 @@ from torch import Tensor
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.invocations.util.get_model import choose_model
from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..models.exceptions import CanceledException
@ -76,6 +76,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
self.model_name = model["model_name"]
outputs = Txt2Img(model).generate(
prompt=self.prompt,
@ -95,9 +96,22 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, generate_output.image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_id = graph_execution_state.prepared_source_mapping[self.id]
invocation = graph_execution_state.execution_graph.get_node(self.id)
metadata = {
"session": context.graph_execution_state_id,
"source_id": source_id,
"invocation": invocation.dict()
}
context.services.images.save(image_type, image_name, generate_output.image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=generate_output.image
)
@ -144,6 +158,7 @@ class ImageToImageInvocation(TextToImageInvocation):
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
self.model = model["model_name"]
outputs = Img2Img(model).generate(
prompt=self.prompt,
@ -168,9 +183,11 @@ class ImageToImageInvocation(TextToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, result_image, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image
)
class InpaintInvocation(ImageToImageInvocation):
@ -219,6 +236,7 @@ class InpaintInvocation(ImageToImageInvocation):
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
self.model = model["model_name"]
outputs = Inpaint(model).generate(
prompt=self.prompt,
@ -243,7 +261,9 @@ class InpaintInvocation(ImageToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, result_image, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image
)

View File

@ -9,7 +9,12 @@ from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
class PILInvocationConfig(BaseModel):
@ -22,51 +27,70 @@ class PILInvocationConfig(BaseModel):
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
#fmt: off
# fmt: off
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
#fmt: on
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {
'required': [
'type',
'image',
"required": [
"type",
"image",
"width",
"height",
]
}
def build_image_output(
image_type: ImageType, image_name: str, image: Image.Image
) -> ImageOutput:
image_field = ImageField(image_name=image_name, image_type=image_type)
return ImageOutput(image=image_field, width=image.width, height=image.height)
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
#fmt: off
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
#fmt: on
# fmt: on
class Config:
schema_extra = {
'required': [
'type',
'mask',
"required": [
"type",
"mask",
]
}
# TODO: this isn't really necessary anymore
class LoadImageInvocation(BaseInvocation):
"""Load an image from a filename and provide it as output."""
#fmt: off
type: Literal["load_image"] = "load_image"
# Inputs
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
#fmt: on
# # TODO: this isn't really necessary anymore
# class LoadImageInvocation(BaseInvocation):
# """Load an image from a filename and provide it as output."""
# #fmt: off
# type: Literal["load_image"] = "load_image"
def invoke(self, context: InvocationContext) -> ImageOutput:
return ImageOutput(
image=ImageField(image_type=self.image_type, image_name=self.image_name)
)
# # Inputs
# image_type: ImageType = Field(description="The type of the image")
# image_name: str = Field(description="The name of the image")
# #fmt: on
# def invoke(self, context: InvocationContext) -> ImageOutput:
# return ImageOutput(
# image_type=self.image_type,
# image_name=self.image_name,
# image=result_image
# )
class ShowImageInvocation(BaseInvocation):
@ -86,16 +110,17 @@ class ShowImageInvocation(BaseInvocation):
# TODO: how to handle failure?
return ImageOutput(
image=ImageField(
image_type=self.image.image_type, image_name=self.image.image_name
)
return build_image_output(
image_type=self.image.image_type,
image_name=self.image.image_name,
image=image,
)
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
"""Crops an image to a specified box. The box can be outside of the image."""
#fmt: off
# fmt: off
type: Literal["crop"] = "crop"
# Inputs
@ -104,7 +129,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@ -120,15 +145,16 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_crop)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, image_crop, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=image_crop
)
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
"""Pastes an image into another image."""
#fmt: off
# fmt: off
type: Literal["paste"] = "paste"
# Inputs
@ -137,7 +163,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get(
@ -170,21 +196,22 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, new_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, new_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=new_image
)
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
"""Extracts the alpha channel of an image as a mask."""
#fmt: off
# fmt: off
type: Literal["tomask"] = "tomask"
# Inputs
image: ImageField = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get(
@ -199,21 +226,21 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_mask)
context.services.images.save(image_type, image_name, image_mask, self.dict())
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
class BlurInvocation(BaseInvocation, PILInvocationConfig):
"""Blurs an image"""
#fmt: off
# fmt: off
type: Literal["blur"] = "blur"
# Inputs
image: ImageField = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@ -231,22 +258,23 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, blur_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, blur_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=blur_image
)
class LerpInvocation(BaseInvocation, PILInvocationConfig):
"""Linear interpolation of all pixels of an image"""
#fmt: off
# fmt: off
type: Literal["lerp"] = "lerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@ -262,22 +290,23 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, lerp_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, lerp_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=lerp_image
)
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
"""Inverse linear interpolation of all pixels of an image"""
#fmt: off
# fmt: off
type: Literal["ilerp"] = "ilerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@ -298,7 +327,7 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, ilerp_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, ilerp_image, self.dict())
return build_image_output(
image_type=image_type, image_name=image_name, image=ilerp_image
)

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
import torch
from invokeai.app.models.exceptions import CanceledException
from invokeai.app.invocations.util.get_model import choose_model
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.util.step_callback import diffusers_step_callback_adapter
from ...backend.model_management.model_manager import ModelManager
@ -18,7 +18,7 @@ from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationCont
import numpy as np
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from .image import ImageField, ImageOutput, build_image_output
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
@ -350,7 +350,9 @@ class LatentsToImageInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, image, self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image
)

View File

@ -6,7 +6,7 @@ from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
@ -44,7 +44,9 @@ class RestoreFaceInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, results[0][0], self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@ -8,7 +8,7 @@ from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
class UpscaleInvocation(BaseInvocation):
@ -49,7 +49,9 @@ class UpscaleInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
context.services.images.save(image_type, image_name, results[0][0], self.dict())
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@ -1,11 +1,14 @@
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.backend.model_management.model_manager import ModelManager
def choose_model(model_manager: ModelManager, model_name: str):
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
if model_manager.valid_model(model_name):
return model_manager.get_model(model_name)
model = model_manager.get_model(model_name)
else:
print(f"* Warning: '{model_name}' is not a valid model name. Using default model instead.")
return model_manager.get_model()
model = model_manager.get_model()
print(
f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
)
return model

View File

@ -1,11 +1,26 @@
from typing import Optional
from typing import Any, Optional, Dict
from pydantic import BaseModel, Field
class ImageMetadata(BaseModel):
"""An image's metadata"""
timestamp: float = Field(description="The creation timestamp of the image")
class InvokeAIMetadata(BaseModel):
"""An image's InvokeAI-specific metadata"""
session: Optional[str] = Field(description="The session that generated this image")
source_id: Optional[str] = Field(
description="The source id of the invocation that generated this image"
)
# TODO: figure out metadata
invocation: Optional[Dict[str, Any]] = Field(
default={}, description="The prepared invocation that generated this image"
)
class ImageMetadata(BaseModel):
"""An image's general metadata"""
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# TODO: figure out metadata
sd_metadata: Optional[dict] = Field(default={}, description="The image's SD-specific metadata")
invokeai: Optional[InvokeAIMetadata] = Field(
default={}, description="The image's InvokeAI-specific metadata"
)

View File

@ -25,7 +25,8 @@ class EventServiceBase:
def emit_generator_progress(
self,
graph_execution_state_id: str,
invocation_id: str,
invocation_dict: dict,
source_id: str,
progress_image: ProgressImage | None,
step: int,
total_steps: int,
@ -35,7 +36,8 @@ class EventServiceBase:
event_name="generator_progress",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
invocation=invocation_dict,
source_id=source_id,
progress_image=progress_image,
step=step,
total_steps=total_steps,
@ -43,40 +45,43 @@ class EventServiceBase:
)
def emit_invocation_complete(
self, graph_execution_state_id: str, invocation_id: str, result: Dict
self, graph_execution_state_id: str, result: Dict, invocation_dict: Dict, source_id: str,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_complete",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
invocation=invocation_dict,
source_id=source_id,
result=result,
),
)
def emit_invocation_error(
self, graph_execution_state_id: str, invocation_id: str, error: str
self, graph_execution_state_id: str, invocation_dict: Dict, source_id: str, error: str
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_error",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
invocation=invocation_dict,
source_id=source_id,
error=error,
),
)
def emit_invocation_started(
self, graph_execution_state_id: str, invocation_id: str
self, graph_execution_state_id: str, invocation_dict: Dict, source_id: str
) -> None:
"""Emitted when an invocation has started"""
self.__emit_session_event(
event_name="invocation_started",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
invocation=invocation_dict,
source_id=source_id,
),
)

View File

@ -2,16 +2,17 @@
import datetime
import os
import json
from glob import glob
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from queue import Queue
from typing import Callable, Dict, List
from typing import Any, Callable, Dict, List, Union
from PIL.Image import Image
import PIL.Image as PILImage
from pydantic import BaseModel
from pydantic import BaseModel, Json
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.metadata import ImageMetadata
@ -42,7 +43,7 @@ class ImageStorageBase(ABC):
pass
@abstractmethod
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: Dict[str, Any] | None = None) -> str:
pass
@abstractmethod
@ -100,6 +101,8 @@ class DiskImageStorage(ImageStorageBase):
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
invokeai_metadata = json.loads(img.info.get("invokeai", "{}"))
page_of_images.append(
ImageResponse(
image_type=image_type.value,
@ -109,9 +112,10 @@ class DiskImageStorage(ImageStorageBase):
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
metadata=ImageMetadata(
timestamp=os.path.getctime(path),
created=int(os.path.getctime(path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata
),
)
)
@ -150,10 +154,11 @@ class DiskImageStorage(ImageStorageBase):
path = os.path.join(self.__output_folder, image_type, image_name)
return path
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
def save(self, image_type: ImageType, image_name: str, image: Image, metadata: Dict[str, Any] | None = None) -> str:
print(metadata)
image_subpath = os.path.join(image_type, image_name)
self.__pngWriter.save_image_and_prompt_to_png(
image, "", image_subpath, None
image, "", image_subpath, metadata
) # TODO: just pass full path to png writer
save_thumbnail(
image=image,
@ -162,6 +167,7 @@ class DiskImageStorage(ImageStorageBase):
)
image_path = self.get_path(image_type, image_name)
self.__set_cache(image_path, image)
return image_path
def delete(self, image_type: ImageType, image_name: str) -> None:
image_path = self.get_path(image_type, image_name)

View File

@ -43,10 +43,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
queue_item.invocation_id
)
# get the source node to provide to cliepnts (the prepared node is not as useful)
source_id = graph_execution_state.prepared_source_mapping[invocation.id]
# Send starting event
self.__invoker.services.events.emit_invocation_started(
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
invocation_dict=invocation.dict(),
source_id=source_id
)
# Invoke
@ -75,7 +79,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
invocation_dict=invocation.dict(),
source_id=source_id,
result=outputs.dict(),
)
@ -99,7 +104,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send error event
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
invocation_dict=invocation.dict(),
source_id=source_id,
error=error,
)

View File

@ -1,23 +1,25 @@
import sqlite3
from threading import Lock
from typing import Generic, TypeVar, Union, get_args
from pydantic import BaseModel, parse_raw_as
from .item_storage import ItemStorageABC, PaginatedResults
from sqlalchemy import create_engine, String, TEXT, Engine, select
from sqlalchemy.orm import DeclarativeBase, mapped_column, Session
T = TypeVar("T", bound=BaseModel)
sqlite_memory = ":memory:"
class Base(DeclarativeBase):
pass
class SqliteItemStorage(ItemStorageABC, Generic[T]):
_filename: str
_table_name: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_id_field: str
_lock: Lock
_engine: Engine
# _table: ??? # TODO: figure out how to type this
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
super().__init__()
@ -25,86 +27,79 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._filename = filename
self._table_name = table_name
self._id_field = id_field # TODO: validate that T has this field
self._lock = Lock()
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution
self._cursor = self._conn.cursor()
self._engine = create_engine(f"sqlite+pysqlite:///{self._filename}")
self._create_table()
def _create_table(self):
try:
self._lock.acquire()
self._cursor.execute(
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
item TEXT,
id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
# dynamically create the ORM model class to avoid name collisions
# cannot access `self.__orig_class__` in `__init__` or `__new__` so
# format the table name into the class name
pascal_table_name = self._table_name.replace("_", " ").title()
pascal_table_name = pascal_table_name.replace(" ", "")
table_dict = dict(
__tablename__=self._table_name,
id=mapped_column(String, primary_key=True),
item=mapped_column(TEXT, nullable=False),
)
self._cursor.execute(
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
)
finally:
self._lock.release()
self._table = type(pascal_table_name, (Base,), table_dict)
Base.metadata.create_all(self._engine)
def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0]
return parse_raw_as(item_type, item)
def set(self, item: T):
try:
self._lock.acquire()
self._cursor.execute(
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(item.json(),),
)
self._conn.commit()
finally:
self._lock.release()
session = Session(self._engine)
item_id = str(getattr(item, self._id_field))
new_item = self._table(id=item_id, item=item.json())
session.merge(new_item)
session.commit()
session.close()
self._on_changed(item)
def get(self, id: str) -> Union[T, None]:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
result = self._cursor.fetchone()
finally:
self._lock.release()
session = Session(self._engine)
if not result:
item = session.get(self._table, str(id))
session.close()
if not item:
return None
return self._parse_item(result[0])
return self._parse_item(item.item)
def delete(self, id: str):
try:
self._lock.acquire()
self._cursor.execute(
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
self._conn.commit()
finally:
self._lock.release()
session = Session(self._engine)
item = session.get(self._table, id)
session.delete(item)
session.commit()
session.close()
self._on_deleted(id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
(per_page, page * per_page),
)
result = self._cursor.fetchall()
session = Session(self._engine)
stmt = select(self._table.item).limit(per_page).offset(page * per_page)
result = session.execute(stmt)
items = list(map(lambda r: self._parse_item(r[0]), result))
count = session.query(self._table.item).count()
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
session.commit()
session.close()
pageCount = int(count / per_page) + 1
@ -115,23 +110,19 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[T]:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
(f"%{query}%", per_page, page * per_page),
)
result = self._cursor.fetchall()
session = Session(self._engine)
items = list(map(lambda r: self._parse_item(r[0]), result))
self._cursor.execute(
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
(f"%{query}%",),
stmt = (
session.query(self._table)
.where(self._table.item.like(f"%{query}%"))
.limit(per_page)
.offset(page * per_page)
)
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
result = session.execute(stmt)
items = list(map(lambda r: self._parse_item(r[0].item), result))
count = session.query(self._table.item).count()
pageCount = int(count / per_page) + 1

View File

@ -1,3 +1,4 @@
from re import S
import torch
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
@ -20,12 +21,18 @@ def fast_latents_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_id = graph_execution_state.prepared_source_mapping[id]
invocation = graph_execution_state.execution_graph.get_node(id)
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
id,
{"width": width, "height": height, "dataURL": dataURL},
step,
steps,
graph_execution_state_id=context.graph_execution_state_id,
invocation_dict=invocation.dict(),
source_id=source_id,
progress_image={"width": width, "height": height, "dataURL": dataURL},
step=step,
total_steps=steps,
)

View File

@ -41,7 +41,7 @@ class PngWriter:
info = PngImagePlugin.PngInfo()
info.add_text("Dream", dream_prompt)
if metadata:
info.add_text("sd-metadata", json.dumps(metadata))
info.add_text("invokeai", json.dumps(metadata))
image.save(path, "PNG", pnginfo=info, compress_level=compress_level)
return path

View File

@ -6,3 +6,5 @@ stats.html
index.html
.yarn/
*.scss
src/services/api/
src/services/fixtures/*

View File

@ -3,4 +3,8 @@ dist/
node_modules/
patches/
stats.html
index.html
.yarn/
*.scss
src/services/api/
src/services/fixtures/*

View File

@ -0,0 +1,87 @@
# Generated axios API client
- [Generated axios API client](#generated-axios-api-client)
- [Generation](#generation)
- [Generate the API client from the nodes web server](#generate-the-api-client-from-the-nodes-web-server)
- [Generate the API client from JSON](#generate-the-api-client-from-json)
- [Getting the JSON from the nodes web server](#getting-the-json-from-the-nodes-web-server)
- [Getting the JSON with a python script](#getting-the-json-with-a-python-script)
- [Generate the API client](#generate-the-api-client)
- [The generated client](#the-generated-client)
- [API client customisation](#api-client-customisation)
This API client is generated by an [openapi code generator](https://github.com/ferdikoomen/openapi-typescript-codegen).
All files in `invokeai/frontend/web/src/services/api/` are made by the generator.
## Generation
The axios client may be generated by from the OpenAPI schema from the nodes web server, or from JSON.
### Generate the API client from the nodes web server
We need to start the nodes web server, which serves the OpenAPI schema to the generator.
1. Start the nodes web server.
```bash
# from the repo root
python scripts/invoke-new.py --web
```
2. Generate the API client.
```bash
# from invokeai/frontend/web/
yarn api:web
```
### Generate the API client from JSON
The JSON can be acquired from the nodes web server, or with a python script.
#### Getting the JSON from the nodes web server
Start the nodes web server as described above, then download the file.
```bash
# from invokeai/frontend/web/
curl http://localhost:9090/openapi.json -o openapi.json
```
#### Getting the JSON with a python script
Run this python script from the repo root, so it can access the nodes server modules.
The script will output `openapi.json` in the repo root. Then we need to move it to `invokeai/frontend/web/`.
```bash
# from the repo root
python invokeai/app/util/generate_openapi_json.py
mv invokeai/app/util/openapi.json invokeai/frontend/web/services/fixtures/
```
#### Generate the API client
Now we can generate the API client from the JSON.
```bash
# from invokeai/frontend/web/
yarn api:file
```
## The generated client
The client will be written to `invokeai/frontend/web/services/api/`:
- `axios` client
- TS types
- An easily parseable schema, which we can use to generate UI
## API client customisation
The generator has a default `request.ts` file that implements a base `axios` client. The generated client uses this base client.
One shortcoming of this is base client is it does not provide response headers unless the response body is empty. To fix this, we provide our own lightly-patched `request.ts`.
To access the headers, call `getHeaders(response)` on any response from the generated api client. This function is exported from `invokeai/frontend/web/src/services/util/getHeaders.ts`.

View File

@ -0,0 +1,21 @@
# Events
Events via `socket.io`
## `actions.ts`
Redux actions for all socket events. Payloads all include a timestamp, and optionally some other data.
Any reducer (or middleware) can respond to the actions.
## `middleware.ts`
Redux middleware for events.
Handles dispatching the event actions. Only put logic here if it can't really go anywhere else.
For example, on connect we want to load images to the gallery if it's not populated. This requires dispatching a thunk, so we need to directly dispatch this in the middleware.
## `types.ts`
Hand-written types for the socket events. Cannot generate these from the server, but fortunately they are few and simple.

View File

@ -0,0 +1,17 @@
# Node Editor Design
WIP
nodes
everything in `src/features/nodes/`
have a look at `state.nodes.invocation`
- on socket connect, if no schema saved, fetch `localhost:9090/openapi.json`, save JSON to `state.nodes.schema`
- on fulfilled schema fetch, `parseSchema()` the schema. this outputs a `Record<string, Invocation>` which is saved to `state.nodes.invocations` - `Invocation` is like a template for the node
- when you add a node, the the `Invocation` template is passed to `InvocationComponent.tsx` to build the UI component for that node
- inputs/outputs have field types - and each field type gets an `FieldComponent` which includes a dispatcher to write state changes to redux `nodesSlice`
- `reactflow` sends changes to nodes/edges to redux
- to invoke, `buildNodesGraph()` state, then send this
- changed onClick Invoke button actions to build the schema, then when schema builds it dispatches the actual network request to create the session - see `session.ts`

View File

@ -0,0 +1,29 @@
# Package Scripts
WIP walkthrough of `package.json` scripts.
## `theme` & `theme:watch`
These run the Chakra CLI to generate types for the theme, or watch for code change and re-generate the types.
The CLI essentially monkeypatches Chakra's files in `node_modules`.
## `postinstall`
The `postinstall` script patches a few packages and runs the Chakra CLI to generate types for the theme.
### Patch `@chakra-ui/cli`
See: <https://github.com/chakra-ui/chakra-ui/issues/7394>
### Patch `redux-persist`
We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`.
`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it.
So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that.
### Patch `redux-deep-persist`
This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work.

View File

@ -1,10 +1,16 @@
# InvokeAI Web UI
- [InvokeAI Web UI](#invokeai-web-ui)
- [Stack](#stack)
- [Contributing](#contributing)
- [Dev Environment](#dev-environment)
- [Production builds](#production-builds)
The UI is a fairly straightforward Typescript React app. The only really fancy stuff is the Unified Canvas.
Code in `invokeai/frontend/web/` if you want to have a look.
## Details
## Stack
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a custom redux middleware to help).
@ -32,7 +38,7 @@ Start everything in dev mode:
1. Start the dev server: `yarn dev`
2. Start the InvokeAI UI per usual: `invokeai --web`
3. Point your browser to the dev server address e.g. `http://localhost:5173/`
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
### Production builds

View File

@ -1,6 +1,7 @@
import React, { PropsWithChildren } from 'react';
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
import { InvokeTabName } from 'features/ui/store/tabMap';
export {};
@ -64,9 +65,24 @@ declare module '@invoke-ai/invoke-ai-ui' {
declare class SettingsModal extends React.Component<SettingsModalProps> {
public constructor(props: SettingsModalProps);
}
declare class StatusIndicator extends React.Component<StatusIndicatorProps> {
public constructor(props: StatusIndicatorProps);
}
declare class ModelSelect extends React.Component<ModelSelectProps> {
public constructor(props: ModelSelectProps);
}
}
declare function Invoke(props: PropsWithChildren): JSX.Element;
interface InvokeProps extends PropsWithChildren {
apiUrl?: string;
disabledPanels?: string[];
disabledTabs?: InvokeTabName[];
token?: string;
}
declare function Invoke(props: InvokeProps): JSX.Element;
export {
ThemeChanger,
@ -74,5 +90,7 @@ export {
IAIPopover,
IAIIconButton,
SettingsModal,
StatusIndicator,
ModelSelect,
};
export = Invoke;

View File

@ -5,7 +5,10 @@
"scripts": {
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
"preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",
@ -41,9 +44,10 @@
"@chakra-ui/react": "^2.5.1",
"@chakra-ui/styled-system": "^2.6.1",
"@chakra-ui/theme-tools": "^2.0.16",
"@dagrejs/graphlib": "^2.1.12",
"@emotion/react": "^11.10.6",
"@emotion/styled": "^11.10.6",
"@reduxjs/toolkit": "^1.9.2",
"@reduxjs/toolkit": "^1.9.3",
"chakra-ui-contextmenu": "^1.0.5",
"dateformat": "^5.0.3",
"formik": "^2.2.9",
@ -67,7 +71,9 @@
"react-redux": "^8.0.5",
"react-transition-group": "^4.4.5",
"react-zoom-pan-pinch": "^2.6.1",
"reactflow": "^11.7.0",
"redux-deep-persist": "^1.0.7",
"redux-dynamic-middlewares": "^2.2.0",
"redux-persist": "^6.0.0",
"socket.io-client": "^4.6.0",
"use-image": "^1.1.0",
@ -83,6 +89,7 @@
"@typescript-eslint/eslint-plugin": "^5.52.0",
"@typescript-eslint/parser": "^5.52.0",
"@vitejs/plugin-react-swc": "^3.2.0",
"axios": "^1.3.4",
"babel-plugin-transform-imports": "^2.0.0",
"concurrently": "^7.6.0",
"eslint": "^8.34.0",
@ -90,13 +97,17 @@
"eslint-plugin-prettier": "^4.2.1",
"eslint-plugin-react": "^7.32.2",
"eslint-plugin-react-hooks": "^4.6.0",
"form-data": "^4.0.0",
"husky": "^8.0.3",
"lint-staged": "^13.1.2",
"madge": "^6.0.0",
"openapi-types": "^12.1.0",
"openapi-typescript-codegen": "^0.23.0",
"postinstall-postinstall": "^2.1.0",
"prettier": "^2.8.4",
"rollup-plugin-visualizer": "^5.9.0",
"terser": "^5.16.4",
"typescript": "4.9.5",
"vite": "^4.1.2",
"vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^4.0.5",

View File

@ -522,6 +522,10 @@
"resetComplete": "Web UI has been reset. Refresh the page to reload."
},
"toast": {
"serverError": "Server Error",
"disconnected": "Disconnected from Server",
"connected": "Connected to Server",
"canceled": "Processing Canceled",
"tempFoldersEmptied": "Temp Folder Emptied",
"uploadFailed": "Upload failed",
"uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",

View File

@ -13,16 +13,42 @@ import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
import Lightbox from 'features/lightbox/components/Lightbox';
import { useAppSelector } from './storeHooks';
import { useAppDispatch, useAppSelector } from './storeHooks';
import { PropsWithChildren, useEffect } from 'react';
import { setDisabledPanels, setDisabledTabs } from 'features/ui/store/uiSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { shouldTransformUrlsChanged } from 'features/system/store/systemSlice';
keepGUIAlive();
const App = (props: PropsWithChildren) => {
interface Props extends PropsWithChildren {
options: {
disabledPanels: string[];
disabledTabs: InvokeTabName[];
shouldTransformUrls?: boolean;
};
}
const App = (props: Props) => {
useToastWatcher();
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
const { setColorMode } = useColorMode();
const dispatch = useAppDispatch();
useEffect(() => {
dispatch(setDisabledPanels(props.options.disabledPanels));
}, [dispatch, props.options.disabledPanels]);
useEffect(() => {
dispatch(setDisabledTabs(props.options.disabledTabs));
}, [dispatch, props.options.disabledTabs]);
useEffect(() => {
dispatch(
shouldTransformUrlsChanged(Boolean(props.options.shouldTransformUrls))
);
}, [dispatch, props.options.shouldTransformUrls]);
useEffect(() => {
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');

View File

@ -14,6 +14,8 @@
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { ImageMetadata, ImageType } from 'services/api';
import { AnyInvocation } from 'services/events/types';
/**
* TODO:
@ -113,7 +115,7 @@ export declare type Metadata = SystemGenerationMetadata & {
};
// An Image has a UUID, url, modified timestamp, width, height and maybe metadata
export declare type Image = {
export declare type _Image = {
uuid: string;
url: string;
thumbnail: string;
@ -124,11 +126,23 @@ export declare type Image = {
category: GalleryCategory;
isBase64?: boolean;
dreamPrompt?: 'string';
name?: string;
};
/**
* ResultImage
*/
export declare type Image = {
name: string;
type: ImageType;
url: string;
thumbnail: string;
metadata: ImageMetadata;
};
// GalleryImages is an array of Image.
export declare type GalleryImages = {
images: Array<Image>;
images: Array<_Image>;
};
/**
@ -275,7 +289,7 @@ export declare type SystemStatusResponse = SystemStatus;
export declare type SystemConfigResponse = SystemConfig;
export declare type ImageResultResponse = Omit<Image, 'uuid'> & {
export declare type ImageResultResponse = Omit<_Image, 'uuid'> & {
boundingBox?: IRect;
generationMode: InvokeTabName;
};
@ -296,7 +310,7 @@ export declare type ErrorResponse = {
};
export declare type GalleryImagesResponse = {
images: Array<Omit<Image, 'uuid'>>;
images: Array<Omit<_Image, 'uuid'>>;
areMoreImagesAvailable: boolean;
category: GalleryCategory;
};

View File

@ -13,9 +13,13 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
export const generateImage = createAction<InvokeTabName>(
'socketio/generateImage'
);
export const runESRGAN = createAction<InvokeAI.Image>('socketio/runESRGAN');
export const runFacetool = createAction<InvokeAI.Image>('socketio/runFacetool');
export const deleteImage = createAction<InvokeAI.Image>('socketio/deleteImage');
export const runESRGAN = createAction<InvokeAI._Image>('socketio/runESRGAN');
export const runFacetool = createAction<InvokeAI._Image>(
'socketio/runFacetool'
);
export const deleteImage = createAction<InvokeAI._Image>(
'socketio/deleteImage'
);
export const requestImages = createAction<GalleryCategory>(
'socketio/requestImages'
);

View File

@ -91,7 +91,7 @@ const makeSocketIOEmitters = (
})
);
},
emitRunESRGAN: (imageToProcess: InvokeAI.Image) => {
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
const {
@ -119,7 +119,7 @@ const makeSocketIOEmitters = (
})
);
},
emitRunFacetool: (imageToProcess: InvokeAI.Image) => {
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
const {
@ -150,7 +150,7 @@ const makeSocketIOEmitters = (
})
);
},
emitDeleteImage: (imageToDelete: InvokeAI.Image) => {
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
const { url, uuid, category, thumbnail } = imageToDelete;
dispatch(removeImage(imageToDelete));
socketio.emit('deleteImage', url, thumbnail, uuid, category);

View File

@ -34,8 +34,9 @@ import type { RootState } from 'app/store';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import {
clearInitialImage,
initialImageSelected,
setInfillMethod,
setInitialImage,
// setInitialImage,
setMaskPath,
} from 'features/parameters/store/generationSlice';
import { tabMap } from 'features/ui/store/tabMap';
@ -146,7 +147,8 @@ const makeSocketIOListeners = (
const activeTabName = tabMap[activeTab];
switch (activeTabName) {
case 'img2img': {
dispatch(setInitialImage(newImage));
dispatch(initialImageSelected(newImage.uuid));
// dispatch(setInitialImage(newImage));
break;
}
}
@ -262,7 +264,7 @@ const makeSocketIOListeners = (
*/
// Generate a UUID for each image
const preparedImages = images.map((image): InvokeAI.Image => {
const preparedImages = images.map((image): InvokeAI._Image => {
return {
uuid: uuidv4(),
...image,
@ -334,7 +336,7 @@ const makeSocketIOListeners = (
if (
initialImage === url ||
(initialImage as InvokeAI.Image)?.url === url
(initialImage as InvokeAI._Image)?.url === url
) {
dispatch(clearInitialImage());
}

View File

@ -29,6 +29,8 @@ export const socketioMiddleware = () => {
path: `${window.location.pathname}socket.io`,
});
socketio.disconnect();
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {

View File

@ -2,18 +2,35 @@ import { combineReducers, configureStore } from '@reduxjs/toolkit';
import { persistReducer } from 'redux-persist';
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { getPersistConfig } from 'redux-deep-persist';
import canvasReducer from 'features/canvas/store/canvasSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice';
import galleryReducer, {
GalleryState,
} from 'features/gallery/store/gallerySlice';
import resultsReducer, {
resultsAdapter,
ResultsState,
} from 'features/gallery/store/resultsSlice';
import uploadsReducer from 'features/gallery/store/uploadsSlice';
import lightboxReducer, {
LightboxState,
} from 'features/lightbox/store/lightboxSlice';
import generationReducer, {
GenerationState,
} from 'features/parameters/store/generationSlice';
import postprocessingReducer, {
PostprocessingState,
} from 'features/parameters/store/postprocessingSlice';
import systemReducer, { SystemState } from 'features/system/store/systemSlice';
import uiReducer from 'features/ui/store/uiSlice';
import modelsReducer from 'features/system/store/modelSlice';
import nodesReducer, { NodesState } from 'features/nodes/store/nodesSlice';
import { socketioMiddleware } from './socketio/middleware';
import { socketMiddleware } from 'services/events/middleware';
import { CanvasState } from 'features/canvas/store/canvasTypes';
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
@ -29,13 +46,21 @@ import { socketioMiddleware } from './socketio/middleware';
* The necesssary nested persistors with blacklists are configured below.
*/
const canvasBlacklist = [
/**
* Canvas slice persist blacklist
*/
const canvasBlacklist: (keyof CanvasState)[] = [
'cursorPosition',
'isCanvasInitialized',
'doesCanvasNeedScaling',
].map((blacklistItem) => `canvas.${blacklistItem}`);
];
const systemBlacklist = [
canvasBlacklist.map((blacklistItem) => `canvas.${blacklistItem}`);
/**
* System slice persist blacklist
*/
const systemBlacklist: (keyof SystemState)[] = [
'currentIteration',
'currentStatus',
'currentStep',
@ -48,30 +73,101 @@ const systemBlacklist = [
'totalIterations',
'totalSteps',
'openModel',
'cancelOptions.cancelAfter',
].map((blacklistItem) => `system.${blacklistItem}`);
'isCancelScheduled',
'sessionId',
'progressImage',
];
const galleryBlacklist = [
systemBlacklist.map((blacklistItem) => `system.${blacklistItem}`);
/**
* Gallery slice persist blacklist
*/
const galleryBlacklist: (keyof GalleryState)[] = [
'categories',
'currentCategory',
'currentImage',
'currentImageUuid',
'shouldAutoSwitchToNewImages',
'intermediateImage',
].map((blacklistItem) => `gallery.${blacklistItem}`);
];
const lightboxBlacklist = ['isLightboxOpen'].map(
(blacklistItem) => `lightbox.${blacklistItem}`
galleryBlacklist.map((blacklistItem) => `gallery.${blacklistItem}`);
/**
* Lightbox slice persist blacklist
*/
const lightboxBlacklist: (keyof LightboxState)[] = ['isLightboxOpen'];
lightboxBlacklist.map((blacklistItem) => `lightbox.${blacklistItem}`);
/**
* Nodes slice persist blacklist
*/
const nodesBlacklist: (keyof NodesState)[] = ['schema', 'invocations'];
nodesBlacklist.map((blacklistItem) => `nodes.${blacklistItem}`);
/**
* Generation slice persist blacklist
*/
const generationBlacklist: (keyof GenerationState)[] = [];
generationBlacklist.map((blacklistItem) => `generation.${blacklistItem}`);
/**
* Postprocessing slice persist blacklist
*/
const postprocessingBlacklist: (keyof PostprocessingState)[] = [];
postprocessingBlacklist.map(
(blacklistItem) => `postprocessing.${blacklistItem}`
);
/**
* Results slice persist blacklist
*
* Currently blacklisting results slice entirely, see persist config below
*/
const resultsBlacklist: (keyof ResultsState)[] = [];
resultsBlacklist.map((blacklistItem) => `results.${blacklistItem}`);
/**
* Uploads slice persist blacklist
*
* Currently blacklisting uploads slice entirely, see persist config below
*/
const uploadsBlacklist: (keyof NodesState)[] = [];
uploadsBlacklist.map((blacklistItem) => `uploads.${blacklistItem}`);
/**
* Models slice persist blacklist
*/
const modelsBlacklist: (keyof NodesState)[] = [];
modelsBlacklist.map((blacklistItem) => `models.${blacklistItem}`);
/**
* UI slice persist blacklist
*/
const uiBlacklist: (keyof NodesState)[] = [];
uiBlacklist.map((blacklistItem) => `ui.${blacklistItem}`);
const rootReducer = combineReducers({
generation: generationReducer,
postprocessing: postprocessingReducer,
gallery: galleryReducer,
system: systemReducer,
canvas: canvasReducer,
ui: uiReducer,
gallery: galleryReducer,
generation: generationReducer,
lightbox: lightboxReducer,
models: modelsReducer,
nodes: nodesReducer,
postprocessing: postprocessingReducer,
results: resultsReducer,
system: systemReducer,
ui: uiReducer,
uploads: uploadsReducer,
});
const rootPersistConfig = getPersistConfig({
@ -80,23 +176,40 @@ const rootPersistConfig = getPersistConfig({
rootReducer,
blacklist: [
...canvasBlacklist,
...systemBlacklist,
...galleryBlacklist,
...generationBlacklist,
...lightboxBlacklist,
...modelsBlacklist,
...nodesBlacklist,
...postprocessingBlacklist,
// ...resultsBlacklist,
'results',
...systemBlacklist,
...uiBlacklist,
// ...uploadsBlacklist,
'uploads',
],
debounce: 300,
});
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
// Continue with store setup
// TODO: rip the old middleware out when nodes is complete
export function buildMiddleware() {
if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
return socketMiddleware();
} else {
return socketioMiddleware();
}
}
export const store = configureStore({
reducer: persistedReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
immutableCheck: false,
serializableCheck: false,
}).concat(socketioMiddleware()),
}).concat(dynamicMiddlewares),
devTools: {
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
actionsDenylist: [

View File

@ -0,0 +1,8 @@
import { createAsyncThunk } from '@reduxjs/toolkit';
import { AppDispatch, RootState } from './store';
// https://redux-toolkit.js.org/usage/usage-with-typescript#defining-a-pre-typed-createasyncthunk
export const createAppAsyncThunk = createAsyncThunk.withTypes<{
state: RootState;
dispatch: AppDispatch;
}>();

View File

@ -2,7 +2,6 @@ import { Box, useToast } from '@chakra-ui/react';
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import useImageUploader from 'common/hooks/useImageUploader';
import { uploadImage } from 'features/gallery/store/thunks/uploadImage';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { ResourceKey } from 'i18next';
import {
@ -15,6 +14,7 @@ import {
} from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { imageUploaded } from 'services/thunks/image';
import ImageUploadOverlay from './ImageUploadOverlay';
type ImageUploaderProps = {
@ -49,7 +49,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
const fileAcceptedCallback = useCallback(
async (file: File) => {
dispatch(uploadImage({ imageFile: file }));
dispatch(imageUploaded({ formData: { file } }));
},
[dispatch]
);
@ -124,7 +124,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
return;
}
dispatch(uploadImage({ imageFile: file }));
dispatch(imageUploaded({ formData: { file } }));
};
document.addEventListener('paste', pasteImageListener);
return () => {

View File

@ -1,27 +1,160 @@
import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import WorkInProgress from './WorkInProgress';
// import WorkInProgress from './WorkInProgress';
// import ReactFlow, {
// applyEdgeChanges,
// applyNodeChanges,
// Background,
// Controls,
// Edge,
// Handle,
// Node,
// NodeTypes,
// OnEdgesChange,
// OnNodesChange,
// Position,
// } from 'reactflow';
export default function NodesWIP() {
const { t } = useTranslation();
return (
<WorkInProgress>
<Flex
sx={{
flexDirection: 'column',
alignItems: 'center',
justifyContent: 'center',
w: '100%',
h: '100%',
gap: 4,
textAlign: 'center',
}}
>
<Heading>{t('common.nodes')}</Heading>
<VStack maxW="50rem" gap={4}>
<Text>{t('common.nodesDesc')}</Text>
</VStack>
</Flex>
</WorkInProgress>
);
}
// import 'reactflow/dist/style.css';
// import {
// Fragment,
// FunctionComponent,
// ReactNode,
// useCallback,
// useMemo,
// useState,
// } from 'react';
// import { OpenAPIV3 } from 'openapi-types';
// import { filter, map, reduce } from 'lodash';
// import {
// Box,
// Flex,
// FormControl,
// FormLabel,
// Input,
// Select,
// Switch,
// Text,
// NumberInput,
// NumberInputField,
// NumberInputStepper,
// NumberIncrementStepper,
// NumberDecrementStepper,
// Tooltip,
// chakra,
// Badge,
// Heading,
// VStack,
// HStack,
// Menu,
// MenuButton,
// MenuList,
// MenuItem,
// MenuItemOption,
// MenuGroup,
// MenuOptionGroup,
// MenuDivider,
// IconButton,
// } from '@chakra-ui/react';
// import { FaPlus } from 'react-icons/fa';
// import {
// FIELD_NAMES as FIELD_NAMES,
// FIELDS,
// INVOCATION_NAMES as INVOCATION_NAMES,
// INVOCATIONS,
// } from 'features/nodeEditor/constants';
// console.log('invocations', INVOCATIONS);
// const nodeTypes = reduce(
// INVOCATIONS,
// (acc, val, key) => {
// acc[key] = val.component;
// return acc;
// },
// {} as NodeTypes
// );
// console.log('nodeTypes', nodeTypes);
// // make initial nodes one of every node for now
// let n = 0;
// const initialNodes = map(INVOCATIONS, (i) => ({
// id: i.type,
// type: i.title,
// position: { x: (n += 20), y: (n += 20) },
// data: {},
// }));
// console.log('initialNodes', initialNodes);
// export default function NodesWIP() {
// const [nodes, setNodes] = useState<Node[]>([]);
// const [edges, setEdges] = useState<Edge[]>([]);
// const onNodesChange: OnNodesChange = useCallback(
// (changes) => setNodes((nds) => applyNodeChanges(changes, nds)),
// []
// );
// const onEdgesChange: OnEdgesChange = useCallback(
// (changes) => setEdges((eds: Edge[]) => applyEdgeChanges(changes, eds)),
// []
// );
// return (
// <Box
// sx={{
// position: 'relative',
// width: 'full',
// height: 'full',
// borderRadius: 'md',
// }}
// >
// <ReactFlow
// nodeTypes={nodeTypes}
// nodes={nodes}
// edges={edges}
// onNodesChange={onNodesChange}
// onEdgesChange={onEdgesChange}
// >
// <Background />
// <Controls />
// </ReactFlow>
// <HStack sx={{ position: 'absolute', top: 2, right: 2 }}>
// {FIELD_NAMES.map((field) => (
// <Badge
// key={field}
// colorScheme={FIELDS[field].color}
// sx={{ userSelect: 'none' }}
// >
// {field}
// </Badge>
// ))}
// </HStack>
// <Menu>
// <MenuButton
// as={IconButton}
// aria-label="Options"
// icon={<FaPlus />}
// sx={{ position: 'absolute', top: 2, left: 2 }}
// />
// <MenuList>
// {INVOCATION_NAMES.map((name) => {
// const invocation = INVOCATIONS[name];
// return (
// <Tooltip
// key={name}
// label={invocation.description}
// placement="end"
// hasArrow
// >
// <MenuItem>{invocation.title}</MenuItem>
// </Tooltip>
// );
// })}
// </MenuList>
// </Menu>
// </Box>
// );
// }
export default {};

View File

@ -14,6 +14,8 @@ const WorkInProgress = (props: WorkInProgressProps) => {
width: '100%',
height: '100%',
bg: 'base.850',
borderRadius: 'base',
position: 'relative',
}}
>
{children}

View File

@ -0,0 +1,72 @@
import { RootState } from 'app/store';
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
import { find } from 'lodash';
import {
Graph,
ImageToImageInvocation,
TextToImageInvocation,
} from 'services/api';
import { buildHiResNode, buildImg2ImgNode } from './nodes/image2Image';
import { buildIteration } from './nodes/iteration';
import { buildTxt2ImgNode } from './nodes/text2Image';
function mapTabToFunction(activeTabName: InvokeTabName) {
switch (activeTabName) {
case 'txt2img':
return buildTxt2ImgNode;
case 'img2img':
return buildImg2ImgNode;
default:
return buildTxt2ImgNode;
}
}
const buildBaseNode = (
state: RootState
): Record<string, TextToImageInvocation | ImageToImageInvocation> => {
const { activeTab } = state.ui;
const activeTabName = tabMap[activeTab];
return mapTabToFunction(activeTabName)(state);
};
type BuildGraphOutput = {
graph: Graph;
nodeIdsToSubscribe: string[];
};
export const buildGraph = (state: RootState): BuildGraphOutput => {
const { generation, postprocessing } = state;
const { iterations } = generation;
const { hiresFix, hiresStrength } = postprocessing;
const baseNode = buildBaseNode(state);
let graph: Graph = { nodes: baseNode };
const nodeIdsToSubscribe: string[] = [];
if (iterations > 1) {
graph = buildIteration({ graph, iterations });
}
if (hiresFix) {
const { node, edge } = buildHiResNode(
baseNode as Record<string, TextToImageInvocation>,
hiresStrength
);
graph = {
nodes: {
...graph.nodes,
...node,
},
edges: [...(graph.edges || []), edge],
};
nodeIdsToSubscribe.push(Object.keys(node)[0]);
}
console.log('buildGraph: ', graph);
return { graph, nodeIdsToSubscribe };
};

View File

@ -0,0 +1,6 @@
import dateFormat from 'dateformat';
/**
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
*/
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');

View File

@ -0,0 +1,28 @@
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { OpenAPI } from 'services/api';
export const getUrlAlt = (url: string, shouldTransformUrls: boolean) => {
if (OpenAPI.BASE && shouldTransformUrls) {
return [OpenAPI.BASE, url].join('/');
}
return url;
};
export const useGetUrl = () => {
const shouldTransformUrls = useAppSelector(
(state: RootState) => state.system.shouldTransformUrls
);
return {
shouldTransformUrls,
getUrl: (url: string) => {
if (OpenAPI.BASE && shouldTransformUrls) {
return [OpenAPI.BASE, url].join('/');
}
return url;
},
};
};

View File

@ -0,0 +1,98 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store';
import {
Edge,
ImageToImageInvocation,
TextToImageInvocation,
} from 'services/api';
import { _Image } from 'app/invokeai';
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
export const buildImg2ImgNode = (
state: RootState
): Record<string, ImageToImageInvocation> => {
const nodeId = uuidv4();
const { generation, system, models } = state;
const { shouldDisplayInProgressType } = system;
const { currentModel: model } = models;
const {
prompt,
seed,
steps,
width,
height,
cfgScale,
sampler,
seamless,
img2imgStrength: strength,
shouldFitToWidthHeight: fit,
shouldRandomizeSeed,
} = generation;
const initialImage = initialImageSelector(state);
if (!initialImage) {
// TODO: handle this
throw 'no initial image';
}
return {
[nodeId]: {
id: nodeId,
type: 'img2img',
prompt,
seed: shouldRandomizeSeed ? -1 : seed,
steps,
width,
height,
cfg_scale: cfgScale,
scheduler: sampler as ImageToImageInvocation['scheduler'],
seamless,
model,
progress_images: shouldDisplayInProgressType === 'full-res',
image: {
image_name: initialImage.name,
image_type: initialImage.type,
},
strength,
fit,
},
};
};
type hiresReturnType = {
node: Record<string, ImageToImageInvocation>;
edge: Edge;
};
export const buildHiResNode = (
baseNode: Record<string, TextToImageInvocation>,
strength?: number
): hiresReturnType => {
const nodeId = uuidv4();
const baseNodeId = Object.keys(baseNode)[0];
const baseNodeValues = Object.values(baseNode)[0];
return {
node: {
[nodeId]: {
...baseNodeValues,
id: nodeId,
type: 'img2img',
strength,
},
},
edge: {
source: {
field: 'image',
node_id: baseNodeId,
},
destination: {
field: 'image',
node_id: nodeId,
},
},
};
};

View File

@ -0,0 +1,81 @@
import { v4 as uuidv4 } from 'uuid';
import {
Edge,
Graph,
ImageToImageInvocation,
IterateInvocation,
RangeInvocation,
TextToImageInvocation,
} from 'services/api';
import { buildImg2ImgNode } from './image2Image';
type BuildIteration = {
graph: Graph;
iterations: number;
};
const buildRangeNode = (
iterations: number
): Record<string, RangeInvocation> => {
const nodeId = uuidv4();
return {
[nodeId]: {
id: nodeId,
type: 'range',
start: 0,
stop: iterations,
step: 1,
},
};
};
const buildIterateNode = (): Record<string, IterateInvocation> => {
const nodeId = uuidv4();
return {
[nodeId]: {
id: nodeId,
type: 'iterate',
collection: [],
index: 0,
},
};
};
export const buildIteration = ({
graph,
iterations,
}: BuildIteration): Graph => {
const rangeNode = buildRangeNode(iterations);
const iterateNode = buildIterateNode();
const baseNode: Graph['nodes'] = graph.nodes;
const edges: Edge[] = [
{
source: {
field: 'collection',
node_id: Object.keys(rangeNode)[0],
},
destination: {
field: 'collection',
node_id: Object.keys(iterateNode)[0],
},
},
{
source: {
field: 'item',
node_id: Object.keys(iterateNode)[0],
},
destination: {
field: 'seed',
node_id: Object.keys(baseNode!)[0],
},
},
];
return {
nodes: {
...rangeNode,
...iterateNode,
...graph.nodes,
},
edges,
};
};

View File

@ -0,0 +1,43 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store';
import { TextToImageInvocation } from 'services/api';
export const buildTxt2ImgNode = (
state: RootState
): Record<string, TextToImageInvocation> => {
const nodeId = uuidv4();
const { generation, system, models } = state;
const { shouldDisplayInProgressType } = system;
const { currentModel: model } = models;
const {
prompt,
seed,
steps,
width,
height,
cfgScale: cfg_scale,
sampler,
seamless,
shouldRandomizeSeed,
} = generation;
// missing fields in TextToImageInvocation: strength, hires_fix
return {
[nodeId]: {
id: nodeId,
type: 'txt2img',
prompt,
seed: shouldRandomizeSeed ? -1 : seed,
steps,
width,
height,
cfg_scale,
scheduler: sampler as TextToImageInvocation['scheduler'],
seamless,
model,
progress_images: shouldDisplayInProgressType === 'full-res',
},
};
};

View File

@ -1,8 +1,10 @@
import React, { lazy, PropsWithChildren } from 'react';
import React, { lazy, PropsWithChildren, useEffect, useState } from 'react';
import { Provider } from 'react-redux';
import { PersistGate } from 'redux-persist/integration/react';
import { store } from './app/store';
import { buildMiddleware, store } from './app/store';
import { persistor } from './persistor';
import { OpenAPI } from 'services/api';
import { InvokeTabName } from 'features/ui/store/tabMap';
import '@fontsource/inter/100.css';
import '@fontsource/inter/200.css';
import '@fontsource/inter/300.css';
@ -17,18 +19,61 @@ import Loading from './Loading';
// Localization
import './i18n';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
const App = lazy(() => import('./app/App'));
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
export default function Component(props: PropsWithChildren) {
interface Props extends PropsWithChildren {
apiUrl?: string;
disabledPanels?: string[];
disabledTabs?: InvokeTabName[];
token?: string;
shouldTransformUrls?: boolean;
}
export default function Component({
apiUrl,
disabledPanels = [],
disabledTabs = [],
token,
children,
shouldTransformUrls,
}: Props) {
useEffect(() => {
// configure API client token
if (token) {
OpenAPI.TOKEN = token;
}
// configure API client base url
if (apiUrl) {
OpenAPI.BASE = apiUrl;
}
// reset dynamically added middlewares
resetMiddlewares();
// TODO: at this point, after resetting the middleware, we really ought to clean up the socket
// stuff by calling `dispatch(socketReset())`. but we cannot dispatch from here as we are
// outside the provider. it's not needed until there is the possibility that we will change
// the `apiUrl`/`token` dynamically.
// rebuild socket middleware with token and apiUrl
addMiddleware(buildMiddleware());
}, [apiUrl, token]);
return (
<React.StrictMode>
<Provider store={store}>
<PersistGate loading={<Loading />} persistor={persistor}>
<React.Suspense fallback={<Loading showText />}>
<ThemeLocaleProvider>
<App>{props.children}</App>
<App
options={{ disabledPanels, disabledTabs, shouldTransformUrls }}
>
{children}
</App>
</ThemeLocaleProvider>
</React.Suspense>
</PersistGate>

View File

@ -5,6 +5,8 @@ import ThemeChanger from './features/system/components/ThemeChanger';
import IAIPopover from './common/components/IAIPopover';
import IAIIconButton from './common/components/IAIIconButton';
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
import StatusIndicator from './features/system/components/StatusIndicator';
import ModelSelect from 'features/system/components/ModelSelect';
export default Component;
export {
@ -13,4 +15,6 @@ export {
IAIPopover,
IAIIconButton,
SettingsModal,
StatusIndicator,
ModelSelect,
};

View File

@ -1,6 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { ImageConfig } from 'konva/lib/shapes/Image';
import { isEqual } from 'lodash';
@ -25,7 +26,7 @@ type Props = Omit<ImageConfig, 'image'>;
const IAICanvasIntermediateImage = (props: Props) => {
const { ...rest } = props;
const intermediateImage = useAppSelector(selector);
const { getUrl } = useGetUrl();
const [loadedImageElement, setLoadedImageElement] =
useState<HTMLImageElement | null>(null);
@ -36,8 +37,8 @@ const IAICanvasIntermediateImage = (props: Props) => {
tempImage.onload = () => {
setLoadedImageElement(tempImage);
};
tempImage.src = intermediateImage.url;
}, [intermediateImage]);
tempImage.src = getUrl(intermediateImage.url);
}, [intermediateImage, getUrl]);
if (!intermediateImage?.boundingBox) return null;

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { rgbaColorToString } from 'features/canvas/util/colorToString';
import { isEqual } from 'lodash';
@ -32,6 +33,7 @@ const selector = createSelector(
const IAICanvasObjectRenderer = () => {
const { objects } = useAppSelector(selector);
const { getUrl } = useGetUrl();
if (!objects) return null;
@ -40,7 +42,12 @@ const IAICanvasObjectRenderer = () => {
{objects.map((obj, i) => {
if (isCanvasBaseImage(obj)) {
return (
<IAICanvasImage key={i} x={obj.x} y={obj.y} url={obj.image.url} />
<IAICanvasImage
key={i}
x={obj.x}
y={obj.y}
url={getUrl(obj.image.url)}
/>
);
} else if (isCanvasBaseLine(obj)) {
const line = (

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { GroupConfig } from 'konva/lib/Group';
import { isEqual } from 'lodash';
@ -53,11 +54,16 @@ const IAICanvasStagingArea = (props: Props) => {
width,
height,
} = useAppSelector(selector);
const { getUrl } = useGetUrl();
return (
<Group {...rest}>
{shouldShowStagingImage && currentStagingAreaImage && (
<IAICanvasImage url={currentStagingAreaImage.image.url} x={x} y={y} />
<IAICanvasImage
url={getUrl(currentStagingAreaImage.image.url)}
x={x}
y={y}
/>
)}
{shouldShowStagingOutline && (
<Group>

View File

@ -156,7 +156,7 @@ export const canvasSlice = createSlice({
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
state.cursorPosition = action.payload;
},
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI._Image>) => {
const image = action.payload;
const { stageDimensions } = state;
@ -291,7 +291,7 @@ export const canvasSlice = createSlice({
state,
action: PayloadAction<{
boundingBox: IRect;
image: InvokeAI.Image;
image: InvokeAI._Image;
}>
) => {
const { boundingBox, image } = action.payload;

View File

@ -37,7 +37,7 @@ export type CanvasImage = {
y: number;
width: number;
height: number;
image: InvokeAI.Image;
image: InvokeAI._Image;
};
export type CanvasMaskLine = {
@ -125,7 +125,7 @@ export interface CanvasState {
cursorPosition: Vector2d | null;
doesCanvasNeedScaling: boolean;
futureLayerStates: CanvasLayerState[];
intermediateImage?: InvokeAI.Image;
intermediateImage?: InvokeAI._Image;
isCanvasInitialized: boolean;
isDrawing: boolean;
isMaskEnabled: boolean;

View File

@ -105,7 +105,7 @@ export const mergeAndUploadCanvas =
const { url, width, height } = image;
const newImage: InvokeAI.Image = {
const newImage: InvokeAI._Image = {
uuid: uuidv4(),
category: shouldSaveToGallery ? 'result' : 'user',
...image,

View File

@ -14,8 +14,9 @@ import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
import {
initialImageSelected,
setAllParameters,
setInitialImage,
// setInitialImage,
setSeed,
} from 'features/parameters/store/generationSlice';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
@ -45,11 +46,15 @@ import {
FaShareAlt,
FaTrash,
} from 'react-icons/fa';
import { gallerySelector } from '../store/gallerySelectors';
import {
gallerySelector,
selectedImageSelector,
} from '../store/gallerySelectors';
import DeleteImageModal from './DeleteImageModal';
import { useCallback } from 'react';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { useGetUrl } from 'common/util/getUrl';
const currentImageButtonsSelector = createSelector(
[
@ -59,6 +64,7 @@ const currentImageButtonsSelector = createSelector(
uiSelector,
lightboxSelector,
activeTabNameSelector,
selectedImageSelector,
],
(
system: SystemState,
@ -66,7 +72,8 @@ const currentImageButtonsSelector = createSelector(
postprocessing,
ui,
lightbox,
activeTabName
activeTabName,
selectedImage
) => {
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
system;
@ -91,6 +98,7 @@ const currentImageButtonsSelector = createSelector(
shouldShowImageDetails,
activeTabName,
isLightboxOpen,
selectedImage,
};
},
{
@ -117,26 +125,32 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
facetoolStrength,
shouldDisableToolbarButtons,
shouldShowImageDetails,
currentImage,
// currentImage,
isLightboxOpen,
activeTabName,
selectedImage,
} = useAppSelector(currentImageButtonsSelector);
const { getUrl, shouldTransformUrls } = useGetUrl();
const toast = useToast();
const { t } = useTranslation();
const setBothPrompts = useSetBothPrompts();
const handleClickUseAsInitialImage = () => {
if (!currentImage) return;
if (!selectedImage) return;
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
dispatch(setInitialImage(currentImage));
dispatch(setActiveTab('img2img'));
dispatch(initialImageSelected(selectedImage.name));
// dispatch(setInitialImage(currentImage));
// dispatch(setActiveTab('img2img'));
};
const handleCopyImage = async () => {
if (!currentImage) return;
if (!selectedImage) return;
const blob = await fetch(currentImage.url).then((res) => res.blob());
const blob = await fetch(getUrl(selectedImage.url)).then((res) =>
res.blob()
);
const data = [new ClipboardItem({ [blob.type]: blob })];
await navigator.clipboard.write(data);
@ -150,11 +164,13 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
};
const handleCopyImageLink = () => {
navigator.clipboard
.writeText(
currentImage ? window.location.toString() + currentImage.url : ''
)
.then(() => {
const url = selectedImage
? shouldTransformUrls
? getUrl(selectedImage.url)
: window.location.toString() + selectedImage.url
: '';
navigator.clipboard.writeText(url).then(() => {
toast({
title: t('toast.imageLinkCopied'),
status: 'success',
@ -167,7 +183,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys(
'shift+i',
() => {
if (currentImage) {
if (selectedImage) {
handleClickUseAsInitialImage();
toast({
title: t('toast.sentToImageToImage'),
@ -185,24 +201,27 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[currentImage]
[selectedImage]
);
const handleClickUseAllParameters = () => {
if (!currentImage) return;
currentImage.metadata && dispatch(setAllParameters(currentImage.metadata));
if (currentImage.metadata?.image.type === 'img2img') {
dispatch(setActiveTab('img2img'));
} else if (currentImage.metadata?.image.type === 'txt2img') {
dispatch(setActiveTab('txt2img'));
}
if (!selectedImage) return;
// selectedImage.metadata &&
// dispatch(setAllParameters(selectedImage.metadata));
// if (selectedImage.metadata?.image.type === 'img2img') {
// dispatch(setActiveTab('img2img'));
// } else if (selectedImage.metadata?.image.type === 'txt2img') {
// dispatch(setActiveTab('txt2img'));
// }
};
useHotkeys(
'a',
() => {
if (
['txt2img', 'img2img'].includes(currentImage?.metadata?.image?.type)
['txt2img', 'img2img'].includes(
selectedImage?.metadata?.sd_metadata?.type
)
) {
handleClickUseAllParameters();
toast({
@ -221,18 +240,18 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[currentImage]
[selectedImage]
);
const handleClickUseSeed = () => {
currentImage?.metadata &&
dispatch(setSeed(currentImage.metadata.image.seed));
selectedImage?.metadata &&
dispatch(setSeed(selectedImage.metadata.sd_metadata.seed));
};
useHotkeys(
's',
() => {
if (currentImage?.metadata?.image?.seed) {
if (selectedImage?.metadata?.sd_metadata?.seed) {
handleClickUseSeed();
toast({
title: t('toast.seedSet'),
@ -250,19 +269,19 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[currentImage]
[selectedImage]
);
const handleClickUsePrompt = useCallback(() => {
if (currentImage?.metadata?.image?.prompt) {
setBothPrompts(currentImage?.metadata?.image?.prompt);
if (selectedImage?.metadata?.sd_metadata?.prompt) {
setBothPrompts(selectedImage?.metadata?.sd_metadata?.prompt);
}
}, [currentImage?.metadata?.image?.prompt, setBothPrompts]);
}, [selectedImage?.metadata?.sd_metadata?.prompt, setBothPrompts]);
useHotkeys(
'p',
() => {
if (currentImage?.metadata?.image?.prompt) {
if (selectedImage?.metadata?.sd_metadata?.prompt) {
handleClickUsePrompt();
toast({
title: t('toast.promptSet'),
@ -280,11 +299,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[currentImage]
[selectedImage]
);
const handleClickUpscale = () => {
currentImage && dispatch(runESRGAN(currentImage));
// selectedImage && dispatch(runESRGAN(selectedImage));
};
useHotkeys(
@ -308,7 +327,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}
},
[
currentImage,
selectedImage,
isESRGANAvailable,
shouldDisableToolbarButtons,
isConnected,
@ -318,7 +337,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
);
const handleClickFixFaces = () => {
currentImage && dispatch(runFacetool(currentImage));
// selectedImage && dispatch(runFacetool(selectedImage));
};
useHotkeys(
@ -342,7 +361,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}
},
[
currentImage,
selectedImage,
isGFPGANAvailable,
shouldDisableToolbarButtons,
isConnected,
@ -355,10 +374,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
dispatch(setShouldShowImageDetails(!shouldShowImageDetails));
const handleSendToCanvas = () => {
if (!currentImage) return;
if (!selectedImage) return;
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
dispatch(setInitialCanvasImage(currentImage));
// dispatch(setInitialCanvasImage(selectedImage));
dispatch(requestCanvasRescale());
if (activeTabName !== 'unifiedCanvas') {
@ -376,7 +395,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys(
'i',
() => {
if (currentImage) {
if (selectedImage) {
handleClickShowImageDetails();
} else {
toast({
@ -387,7 +406,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[currentImage, shouldShowImageDetails]
[selectedImage, shouldShowImageDetails]
);
const handleLightBox = () => {
@ -448,7 +467,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
{t('parameters.copyImageToLink')}
</IAIButton>
<Link download={true} href={currentImage?.url}>
<Link download={true} href={getUrl(selectedImage!.url)}>
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
{t('parameters.downloadImage')}
</IAIButton>
@ -477,7 +496,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaQuoteRight />}
tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={!currentImage?.metadata?.image?.prompt}
isDisabled={!selectedImage?.metadata?.sd_metadata?.prompt}
onClick={handleClickUsePrompt}
/>
@ -485,7 +504,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaSeedling />}
tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={!currentImage?.metadata?.image?.seed}
isDisabled={!selectedImage?.metadata?.sd_metadata?.seed}
onClick={handleClickUseSeed}
/>
@ -495,7 +514,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
aria-label={`${t('parameters.useAll')} (A)`}
isDisabled={
!['txt2img', 'img2img'].includes(
currentImage?.metadata?.image?.type
selectedImage?.metadata?.sd_metadata?.type
)
}
onClick={handleClickUseAllParameters}
@ -521,7 +540,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton
isDisabled={
!isGFPGANAvailable ||
!currentImage ||
!selectedImage ||
!(isConnected && !isProcessing) ||
!facetoolStrength
}
@ -550,7 +569,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton
isDisabled={
!isESRGANAvailable ||
!currentImage ||
!selectedImage ||
!(isConnected && !isProcessing) ||
!upscalingLevel
}
@ -572,15 +591,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
/>
</ButtonGroup>
<DeleteImageModal image={currentImage}>
{/* <DeleteImageModal image={selectedImage}>
<IAIIconButton
icon={<FaTrash />}
tooltip={`${t('parameters.deleteImage')} (Del)`}
aria-label={`${t('parameters.deleteImage')} (Del)`}
isDisabled={!currentImage || !isConnected || isProcessing}
isDisabled={!selectedImage || !isConnected || isProcessing}
colorScheme="error"
/>
</DeleteImageModal>
</DeleteImageModal> */}
</Flex>
);
};

View File

@ -4,17 +4,20 @@ import { useAppSelector } from 'app/storeHooks';
import { isEqual } from 'lodash';
import { MdPhoto } from 'react-icons/md';
import { gallerySelector } from '../store/gallerySelectors';
import {
gallerySelector,
selectedImageSelector,
} from '../store/gallerySelectors';
import CurrentImageButtons from './CurrentImageButtons';
import CurrentImagePreview from './CurrentImagePreview';
export const currentImageDisplaySelector = createSelector(
[gallerySelector],
(gallery) => {
[gallerySelector, selectedImageSelector],
(gallery, selectedImage) => {
const { currentImage, intermediateImage } = gallery;
return {
hasAnImageToDisplay: currentImage || intermediateImage,
hasAnImageToDisplay: selectedImage || intermediateImage,
};
},
{

View File

@ -1,26 +1,46 @@
import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { useGetUrl } from 'common/util/getUrl';
import { systemSelector } from 'features/system/store/systemSelectors';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import { ReactEventHandler } from 'react';
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
import { gallerySelector } from '../store/gallerySelectors';
import { selectedImageSelector } from '../store/gallerySelectors';
import CurrentImageFallback from './CurrentImageFallback';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
export const imagesSelector = createSelector(
[gallerySelector, uiSelector],
(gallery: GalleryState, ui) => {
const { currentImage, intermediateImage } = gallery;
[uiSelector, selectedImageSelector, systemSelector],
(ui, selectedImage, system) => {
const { shouldShowImageDetails } = ui;
const { progressImage } = system;
// TODO: Clean this up, this is really gross
const imageToDisplay = progressImage
? {
url: progressImage.dataURL,
width: progressImage.width,
height: progressImage.height,
isProgressImage: true,
image: progressImage,
}
: selectedImage
? {
url: selectedImage.url,
width: selectedImage.metadata.width,
height: selectedImage.metadata.height,
isProgressImage: false,
image: selectedImage,
}
: null;
return {
imageToDisplay: intermediateImage ? intermediateImage : currentImage,
isIntermediate: Boolean(intermediateImage),
shouldShowImageDetails,
imageToDisplay,
};
},
{
@ -31,8 +51,9 @@ export const imagesSelector = createSelector(
);
export default function CurrentImagePreview() {
const { shouldShowImageDetails, imageToDisplay, isIntermediate } =
const { shouldShowImageDetails, imageToDisplay } =
useAppSelector(imagesSelector);
const { getUrl } = useGetUrl();
return (
<Flex
@ -46,23 +67,35 @@ export default function CurrentImagePreview() {
>
{imageToDisplay && (
<Image
src={imageToDisplay.url}
src={
imageToDisplay.isProgressImage
? imageToDisplay.url
: getUrl(imageToDisplay.url)
}
width={imageToDisplay.width}
height={imageToDisplay.height}
fallback={!isIntermediate ? <CurrentImageFallback /> : undefined}
fallback={
!imageToDisplay.isProgressImage ? (
<CurrentImageFallback />
) : undefined
}
sx={{
objectFit: 'contain',
maxWidth: '100%',
maxHeight: '100%',
height: 'auto',
position: 'absolute',
imageRendering: isIntermediate ? 'pixelated' : 'initial',
imageRendering: imageToDisplay.isProgressImage
? 'pixelated'
: 'initial',
borderRadius: 'base',
}}
/>
)}
{!shouldShowImageDetails && <NextPrevImageButtons />}
{shouldShowImageDetails && imageToDisplay && (
{shouldShowImageDetails &&
imageToDisplay &&
'metadata' in imageToDisplay.image && (
<Box
sx={{
position: 'absolute',
@ -74,7 +107,7 @@ export default function CurrentImagePreview() {
maxHeight: APP_METADATA_HEIGHT,
}}
>
<ImageMetadataViewer image={imageToDisplay} />
<ImageMetadataViewer image={imageToDisplay.image} />
</Box>
)}
</Flex>

View File

@ -52,7 +52,7 @@ interface DeleteImageModalProps {
/**
* The image to delete.
*/
image?: InvokeAI.Image;
image?: InvokeAI._Image;
}
/**

View File

@ -9,11 +9,14 @@ import {
useToast,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { setCurrentImage } from 'features/gallery/store/gallerySlice';
import {
imageSelected,
setCurrentImage,
} from 'features/gallery/store/gallerySlice';
import {
initialImageSelected,
setAllImageToImageParameters,
setAllParameters,
setInitialImage,
setSeed,
} from 'features/parameters/store/generationSlice';
import { DragEvent, memo, useState } from 'react';
@ -31,6 +34,7 @@ import { useTranslation } from 'react-i18next';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import IAIIconButton from 'common/components/IAIIconButton';
import { useGetUrl } from 'common/util/getUrl';
interface HoverableImageProps {
image: InvokeAI.Image;
@ -40,7 +44,7 @@ interface HoverableImageProps {
const memoEqualityCheck = (
prev: HoverableImageProps,
next: HoverableImageProps
) => prev.image.uuid === next.image.uuid && prev.isSelected === next.isSelected;
) => prev.image.name === next.image.name && prev.isSelected === next.isSelected;
/**
* Gallery image component with delete/use all/use seed buttons on hover.
@ -55,7 +59,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
shouldUseSingleGalleryColumn,
} = useAppSelector(hoverableImageSelector);
const { image, isSelected } = props;
const { url, thumbnail, uuid, metadata } = image;
const { url, thumbnail, name, metadata } = image;
const { getUrl } = useGetUrl();
const [isHovered, setIsHovered] = useState<boolean>(false);
@ -69,10 +74,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleMouseOut = () => setIsHovered(false);
const handleUsePrompt = () => {
if (image.metadata?.image?.prompt) {
setBothPrompts(image.metadata?.image?.prompt);
if (image.metadata?.sd_metadata?.prompt) {
setBothPrompts(image.metadata?.sd_metadata?.prompt);
}
toast({
title: t('toast.promptSet'),
status: 'success',
@ -82,7 +86,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseSeed = () => {
image.metadata && dispatch(setSeed(image.metadata.image.seed));
image.metadata.sd_metadata &&
dispatch(setSeed(image.metadata.sd_metadata.image.seed));
toast({
title: t('toast.seedSet'),
status: 'success',
@ -92,20 +97,11 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleSendToImageToImage = () => {
dispatch(setInitialImage(image));
if (activeTabName !== 'img2img') {
dispatch(setActiveTab('img2img'));
}
toast({
title: t('toast.sentToImageToImage'),
status: 'success',
duration: 2500,
isClosable: true,
});
dispatch(initialImageSelected(image.name));
};
const handleSendToCanvas = () => {
dispatch(setInitialCanvasImage(image));
// dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas());
@ -122,7 +118,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseAllParameters = () => {
metadata && dispatch(setAllParameters(metadata));
metadata.sd_metadata && dispatch(setAllParameters(metadata.sd_metadata));
toast({
title: t('toast.parametersSet'),
status: 'success',
@ -132,11 +128,13 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseInitialImage = async () => {
if (metadata?.image?.init_image_path) {
const response = await fetch(metadata.image.init_image_path);
if (metadata.sd_metadata?.image?.init_image_path) {
const response = await fetch(
metadata.sd_metadata?.image?.init_image_path
);
if (response.ok) {
dispatch(setActiveTab('img2img'));
dispatch(setAllImageToImageParameters(metadata));
dispatch(setAllImageToImageParameters(metadata?.sd_metadata));
toast({
title: t('toast.initialImageSet'),
status: 'success',
@ -155,16 +153,18 @@ const HoverableImage = memo((props: HoverableImageProps) => {
});
};
const handleSelectImage = () => dispatch(setCurrentImage(image));
const handleSelectImage = () => {
dispatch(imageSelected(image.name));
};
const handleDragStart = (e: DragEvent<HTMLDivElement>) => {
e.dataTransfer.setData('invokeai/imageUuid', uuid);
e.dataTransfer.effectAllowed = 'move';
// e.dataTransfer.setData('invokeai/imageUuid', uuid);
// e.dataTransfer.effectAllowed = 'move';
};
const handleLightBox = () => {
dispatch(setCurrentImage(image));
dispatch(setIsLightboxOpen(true));
// dispatch(setCurrentImage(image));
// dispatch(setIsLightboxOpen(true));
};
return (
@ -177,28 +177,30 @@ const HoverableImage = memo((props: HoverableImageProps) => {
</MenuItem>
<MenuItem
onClickCapture={handleUsePrompt}
isDisabled={image?.metadata?.image?.prompt === undefined}
isDisabled={image?.metadata?.sd_metadata?.prompt === undefined}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
onClickCapture={handleUseSeed}
isDisabled={image?.metadata?.image?.seed === undefined}
isDisabled={image?.metadata?.sd_metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem
onClickCapture={handleUseAllParameters}
isDisabled={
!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)
!['txt2img', 'img2img'].includes(
image?.metadata?.sd_metadata?.type
)
}
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
onClickCapture={handleUseInitialImage}
isDisabled={image?.metadata?.image?.type !== 'img2img'}
isDisabled={image?.metadata?.sd_metadata?.type !== 'img2img'}
>
{t('parameters.useInitImg')}
</MenuItem>
@ -209,9 +211,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
<MenuItem data-warning>
<DeleteImageModal image={image}>
{/* <DeleteImageModal image={image}>
<p>{t('parameters.deleteImage')}</p>
</DeleteImageModal>
</DeleteImageModal> */}
</MenuItem>
</MenuList>
)}
@ -219,7 +221,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{(ref) => (
<Box
position="relative"
key={uuid}
key={name}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
userSelect="none"
@ -244,7 +246,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
}
rounded="md"
src={thumbnail || url}
src={getUrl(thumbnail || url)}
loading="lazy"
sx={{
position: 'absolute',
@ -290,7 +292,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
insetInlineEnd: 1,
}}
>
<DeleteImageModal image={image}>
{/* <DeleteImageModal image={image}>
<IAIIconButton
aria-label={t('parameters.deleteImage')}
icon={<FaTrashAlt />}
@ -298,7 +300,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
fontSize={14}
isDisabled={!mayDeleteImage}
/>
</DeleteImageModal>
</DeleteImageModal> */}
</Box>
)}
</Box>

View File

@ -1,4 +1,4 @@
import { ButtonGroup, Flex, Grid, Icon, Text } from '@chakra-ui/react';
import { ButtonGroup, Flex, Grid, Icon, Image, Text } from '@chakra-ui/react';
import { requestImages } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAIButton from 'common/components/IAIButton';
@ -25,9 +25,44 @@ import HoverableImage from './HoverableImage';
import Scrollable from 'features/ui/components/common/Scrollable';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import {
resultsAdapter,
selectResultsAll,
selectResultsTotal,
} from '../store/resultsSlice';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { selectUploadsAll, uploadsAdapter } from '../store/uploadsSlice';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
const gallerySelector = createSelector(
[
(state: RootState) => state.uploads,
(state: RootState) => state.results,
(state: RootState) => state.gallery,
],
(uploads, results, gallery) => {
const { currentCategory } = gallery;
return currentCategory === 'result'
? {
images: resultsAdapter.getSelectors().selectAll(results),
isLoading: results.isLoading,
areMoreImagesAvailable: results.page < results.pages - 1,
}
: {
images: uploadsAdapter.getSelectors().selectAll(uploads),
isLoading: uploads.isLoading,
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
};
}
);
const ImageGalleryContent = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -35,7 +70,7 @@ const ImageGalleryContent = () => {
const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
const {
images,
// images,
currentCategory,
currentImageUuid,
shouldPinGallery,
@ -43,12 +78,24 @@ const ImageGalleryContent = () => {
galleryGridTemplateColumns,
galleryImageObjectFit,
shouldAutoSwitchToNewImages,
areMoreImagesAvailable,
// areMoreImagesAvailable,
shouldUseSingleGalleryColumn,
} = useAppSelector(imageGallerySelector);
const { images, areMoreImagesAvailable, isLoading } =
useAppSelector(gallerySelector);
// const handleClickLoadMore = () => {
// dispatch(requestImages(currentCategory));
// };
const handleClickLoadMore = () => {
dispatch(requestImages(currentCategory));
if (currentCategory === 'result') {
dispatch(receivedResultImagesPage());
}
if (currentCategory === 'user') {
dispatch(receivedUploadImagesPage());
}
};
const handleChangeGalleryImageMinimumWidth = (v: number) => {
@ -203,11 +250,11 @@ const ImageGalleryContent = () => {
style={{ gridTemplateColumns: galleryGridTemplateColumns }}
>
{images.map((image) => {
const { uuid } = image;
const isSelected = currentImageUuid === uuid;
const { name } = image;
const isSelected = currentImageUuid === name;
return (
<HoverableImage
key={uuid}
key={name}
image={image}
isSelected={isSelected}
/>
@ -217,6 +264,7 @@ const ImageGalleryContent = () => {
<IAIButton
onClick={handleClickLoadMore}
isDisabled={!areMoreImagesAvailable}
isLoading={isLoading}
flexShrink={0}
>
{areMoreImagesAvailable

View File

@ -11,6 +11,7 @@ import {
} from '@chakra-ui/react';
import * as InvokeAI from 'app/invokeai';
import { useAppDispatch } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import promptToString from 'common/util/promptToString';
import { seedWeightsToString } from 'common/util/seedWeightPairs';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
@ -18,7 +19,7 @@ import {
setCfgScale,
setHeight,
setImg2imgStrength,
setInitialImage,
// setInitialImage,
setMaskPath,
setPerlin,
setSampler,
@ -45,6 +46,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import * as png from '@stevebel/png';
type MetadataItemProps = {
isLink?: boolean;
@ -120,7 +122,7 @@ type ImageMetadataViewerProps = {
const memoEqualityCheck = (
prev: ImageMetadataViewerProps,
next: ImageMetadataViewerProps
) => prev.image.uuid === next.image.uuid;
) => prev.image.name === next.image.name;
// TODO: Show more interesting information in this component.
@ -137,8 +139,8 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
dispatch(setShouldShowImageDetails(false));
});
const metadata = image?.metadata?.image || {};
const dreamPrompt = image?.dreamPrompt;
const metadata = image?.metadata.sd_metadata || {};
const dreamPrompt = image?.metadata.sd_metadata?.dreamPrompt;
const {
cfg_scale,
@ -160,11 +162,23 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
type,
variations,
width,
model_weights,
} = metadata;
const { t } = useTranslation();
const { getUrl } = useGetUrl();
const metadataJSON = JSON.stringify(image.metadata, null, 2);
const metadataJSON = JSON.stringify(image, null, 2);
// fetch(getUrl(image.url))
// .then((r) => r.arrayBuffer())
// .then((buffer) => {
// const { text } = png.decode(buffer);
// const metadata = text?.['sd-metadata']
// ? JSON.parse(text['sd-metadata'] ?? {})
// : {};
// console.log(metadata);
// });
return (
<Flex
@ -183,18 +197,49 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
>
<Flex gap={2}>
<Text fontWeight="semibold">File:</Text>
<Link href={image.url} isExternal maxW="calc(100% - 3rem)">
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
{image.url.length > 64
? image.url.substring(0, 64).concat('...')
: image.url}
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
<Flex gap={2} direction="column">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
{Object.keys(metadata).length > 0 ? (
<>
{type && <MetadataItem label="Generation type" value={type} />}
{image.metadata?.model_weights && (
<MetadataItem label="Model" value={image.metadata.model_weights} />
{model_weights && (
<MetadataItem label="Model" value={model_weights} />
)}
{['esrgan', 'gfpgan'].includes(type) && (
<MetadataItem label="Original image" value={orig_path} />
@ -288,14 +333,14 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
onClick={() => dispatch(setHeight(height))}
/>
)}
{init_image_path && (
{/* {init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)}
)} */}
{mask_image_path && (
<MetadataItem
label="Mask image"
@ -408,37 +453,6 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
{dreamPrompt && (
<MetadataItem withCopy label="Dream Prompt" value={dreamPrompt} />
)}
<Flex gap={2} direction="column">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
</>
) : (
<Center width="100%" pt={10}>

View File

@ -7,6 +7,16 @@ import {
uiSelector,
} from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import {
selectResultsAll,
selectResultsById,
selectResultsEntities,
} from './resultsSlice';
import {
selectUploadsAll,
selectUploadsById,
selectUploadsEntities,
} from './uploadsSlice';
export const gallerySelector = (state: RootState) => state.gallery;
@ -75,3 +85,18 @@ export const hoverableImageSelector = createSelector(
},
}
);
export const selectedImageSelector = createSelector(
[gallerySelector, selectResultsEntities, selectUploadsEntities],
(gallery, allResults, allUploads) => {
const selectedImageName = gallery.selectedImageName;
if (selectedImageName in allResults) {
return allResults[selectedImageName];
}
if (selectedImageName in allUploads) {
return allUploads[selectedImageName];
}
}
);

View File

@ -1,14 +1,17 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { invocationComplete } from 'services/events/actions';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { clamp } from 'lodash';
import { isImageOutput } from 'services/types/guards';
import { imageUploaded } from 'services/thunks/image';
export type GalleryCategory = 'user' | 'result';
export type AddImagesPayload = {
images: Array<InvokeAI.Image>;
images: Array<InvokeAI._Image>;
areMoreImagesAvailable: boolean;
category: GalleryCategory;
};
@ -16,16 +19,33 @@ export type AddImagesPayload = {
type GalleryImageObjectFitType = 'contain' | 'cover';
export type Gallery = {
images: InvokeAI.Image[];
images: InvokeAI._Image[];
latest_mtime?: number;
earliest_mtime?: number;
areMoreImagesAvailable: boolean;
};
export interface GalleryState {
currentImage?: InvokeAI.Image;
/**
* The selected image's unique name
* Use `selectedImageSelector` to access the image
*/
selectedImageName: string;
/**
* The currently selected image
* @deprecated See `state.gallery.selectedImageName`
*/
currentImage?: InvokeAI._Image;
/**
* The currently selected image's uuid.
* @deprecated See `state.gallery.selectedImageName`, use `selectedImageSelector` to access the image
*/
currentImageUuid: string;
intermediateImage?: InvokeAI.Image & {
/**
* The current progress image
* @deprecated See `state.system.progressImage`
*/
intermediateImage?: InvokeAI._Image & {
boundingBox?: IRect;
generationMode?: InvokeTabName;
};
@ -42,6 +62,7 @@ export interface GalleryState {
}
const initialState: GalleryState = {
selectedImageName: '',
currentImageUuid: '',
galleryImageMinimumWidth: 64,
galleryImageObjectFit: 'cover',
@ -69,7 +90,10 @@ export const gallerySlice = createSlice({
name: 'gallery',
initialState,
reducers: {
setCurrentImage: (state, action: PayloadAction<InvokeAI.Image>) => {
imageSelected: (state, action: PayloadAction<string>) => {
state.selectedImageName = action.payload;
},
setCurrentImage: (state, action: PayloadAction<InvokeAI._Image>) => {
state.currentImage = action.payload;
state.currentImageUuid = action.payload.uuid;
},
@ -124,7 +148,7 @@ export const gallerySlice = createSlice({
addImage: (
state,
action: PayloadAction<{
image: InvokeAI.Image;
image: InvokeAI._Image;
category: GalleryCategory;
}>
) => {
@ -150,7 +174,10 @@ export const gallerySlice = createSlice({
setIntermediateImage: (
state,
action: PayloadAction<
InvokeAI.Image & { boundingBox?: IRect; generationMode?: InvokeTabName }
InvokeAI._Image & {
boundingBox?: IRect;
generationMode?: InvokeTabName;
}
>
) => {
state.intermediateImage = action.payload;
@ -252,9 +279,31 @@ export const gallerySlice = createSlice({
state.shouldUseSingleGalleryColumn = action.payload;
},
},
extraReducers(builder) {
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
if (isImageOutput(data.result)) {
state.selectedImageName = data.result.image.image_name;
state.intermediateImage = undefined;
}
});
/**
* Upload Image - FULFILLED
*/
builder.addCase(imageUploaded.fulfilled, (state, action) => {
const { location } = action.payload;
const imageName = location.split('/').pop() || '';
state.selectedImageName = imageName;
});
},
});
export const {
imageSelected,
addImage,
clearIntermediateImage,
removeImage,

View File

@ -0,0 +1,149 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/invokeai';
import { invocationComplete } from 'services/events/actions';
import { RootState } from 'app/store';
import {
receivedResultImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { isImageOutput } from 'services/types/guards';
import {
buildImageUrls,
deserializeImageField,
extractTimestampFromImageName,
} from 'services/util/deserializeImageField';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import { getUrlAlt } from 'common/util/getUrl';
import { ImageMetadata } from 'services/api';
// import { deserializeImageField } from 'services/util/deserializeImageField';
// use `createEntityAdapter` to create a slice for results images
// https://redux-toolkit.js.org/api/createEntityAdapter#overview
// the "Entity" is InvokeAI.ResultImage, while the "entities" are instances of that type
export const resultsAdapter = createEntityAdapter<Image>({
// Provide a callback to get a stable, unique identifier for each entity. This defaults to
// `(item) => item.id`, but for our result images, the `name` is the unique identifier.
selectId: (image) => image.name,
// Order all images by their time (in descending order)
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
});
// This type is intersected with the Entity type to create the shape of the state
type AdditionalResultsState = {
// these are a bit misleading; they refer to sessions, not results, but we don't have a route
// to list all images directly at this time...
page: number; // current page we are on
pages: number; // the total number of pages available
isLoading: boolean; // whether we are loading more images or not, mostly a placeholder
nextPage: number; // the next page to request
};
// export type ResultsState = ReturnType<
// typeof resultsAdapter.getInitialState<AdditionalResultsState>
// >;
export const initialResultsState =
resultsAdapter.getInitialState<AdditionalResultsState>({
// provide the additional initial state
page: 0,
pages: 0,
isLoading: false,
nextPage: 0,
});
export type ResultsState = typeof initialResultsState;
const resultsSlice = createSlice({
name: 'results',
initialState: initialResultsState,
reducers: {
// the adapter provides some helper reducers; see the docs for all of them
// can use them as helper functions within a reducer, or use the function itself as a reducer
// here we just use the function itself as the reducer. we'll call this on `invocation_complete`
// to add a single result
resultAdded: resultsAdapter.upsertOne,
},
extraReducers: (builder) => {
// here we can respond to a fulfilled call of the `getNextResultsPage` thunk
// because we pass in the fulfilled thunk action creator, everything is typed
/**
* Received Result Images Page - PENDING
*/
builder.addCase(receivedResultImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Result Images Page - FULFILLED
*/
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const resultImages = items.map((image) =>
deserializeImageResponse(image)
);
// use the adapter reducer to append all the results to state
resultsAdapter.addMany(state, resultImages);
state.page = page;
state.pages = pages;
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
const { result, invocation, graph_execution_state_id, source_id } = data;
if (isImageOutput(result)) {
const name = result.image.image_name;
const type = result.image.image_type;
const { url, thumbnail } = buildImageUrls(type, name);
const timestamp = extractTimestampFromImageName(name);
const image: Image = {
name,
type,
url,
thumbnail,
metadata: {
created: timestamp,
width: result.width, // TODO: add tese dimensions
height: result.height,
invokeai: {
session: graph_execution_state_id,
source_id,
invocation,
},
},
};
// const resultImage = deserializeImageField(result.image, invocation);
resultsAdapter.addOne(state, image);
}
});
},
});
// Create a set of memoized selectors based on the location of this entity state
// to be used as selectors in a `useAppSelector()` call
export const {
selectAll: selectResultsAll,
selectById: selectResultsById,
selectEntities: selectResultsEntities,
selectIds: selectResultsIds,
selectTotal: selectResultsTotal,
} = resultsAdapter.getSelectors<RootState>((state) => state.results);
export const { resultAdded } = resultsSlice.actions;
export default resultsSlice.reducer;

View File

@ -1,54 +0,0 @@
import { AnyAction, ThunkAction } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { RootState } from 'app/store';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { setInitialImage } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { v4 as uuidv4 } from 'uuid';
import { addImage } from '../gallerySlice';
type UploadImageConfig = {
imageFile: File;
};
export const uploadImage =
(
config: UploadImageConfig
): ThunkAction<void, RootState, unknown, AnyAction> =>
async (dispatch, getState) => {
const { imageFile } = config;
const state = getState() as RootState;
const activeTabName = activeTabNameSelector(state);
const formData = new FormData();
formData.append('file', imageFile, imageFile.name);
formData.append(
'data',
JSON.stringify({
kind: 'init',
})
);
const response = await fetch(`${window.location.origin}/upload`, {
method: 'POST',
body: formData,
});
const image = (await response.json()) as InvokeAI.ImageUploadResponse;
const newImage: InvokeAI.Image = {
uuid: uuidv4(),
category: 'user',
...image,
};
dispatch(addImage({ image: newImage, category: 'user' }));
if (activeTabName === 'unifiedCanvas') {
dispatch(setInitialCanvasImage(newImage));
} else if (activeTabName === 'img2img') {
dispatch(setInitialImage(newImage));
}
};

View File

@ -0,0 +1,95 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/invokeai';
import { RootState } from 'app/store';
import {
receivedUploadImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { imageUploaded } from 'services/thunks/image';
import { deserializeImageField } from 'services/util/deserializeImageField';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
export const uploadsAdapter = createEntityAdapter<Image>({
selectId: (image) => image.name,
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
});
type AdditionalUploadsState = {
page: number;
pages: number;
isLoading: boolean;
nextPage: number;
};
export type UploadssState = ReturnType<
typeof uploadsAdapter.getInitialState<AdditionalUploadsState>
>;
const uploadsSlice = createSlice({
name: 'uploads',
initialState: uploadsAdapter.getInitialState<AdditionalUploadsState>({
page: 0,
pages: 0,
nextPage: 0,
isLoading: false,
}),
reducers: {
uploadAdded: uploadsAdapter.addOne,
},
extraReducers: (builder) => {
/**
* Received Upload Images Page - PENDING
*/
builder.addCase(receivedUploadImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Upload Images Page - FULFILLED
*/
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const images = items.map((image) => deserializeImageResponse(image));
uploadsAdapter.addMany(state, images);
state.page = page;
state.pages = pages;
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
/**
* Upload Image - FULFILLED
*/
builder.addCase(imageUploaded.fulfilled, (state, action) => {
const { location, response } = action.payload;
const { image_name, image_url, image_type, metadata, thumbnail_url } =
response;
const uploadedImage: Image = {
name: image_name,
url: image_url,
thumbnail: thumbnail_url,
type: 'uploads',
metadata,
};
uploadsAdapter.addOne(state, uploadedImage);
});
},
});
export const {
selectAll: selectUploadsAll,
selectById: selectUploadsById,
selectEntities: selectUploadsEntities,
selectIds: selectUploadsIds,
selectTotal: selectUploadsTotal,
} = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
export const { uploadAdded } = uploadsSlice.actions;
export default uploadsSlice.reducer;

View File

@ -1,9 +1,10 @@
import * as React from 'react';
import { TransformComponent, useTransformContext } from 'react-zoom-pan-pinch';
import * as InvokeAI from 'app/invokeai';
import { useGetUrl } from 'common/util/getUrl';
type ReactPanZoomProps = {
image: InvokeAI.Image;
image: InvokeAI._Image;
styleClass?: string;
alt?: string;
ref?: React.Ref<HTMLImageElement>;
@ -22,6 +23,7 @@ export default function ReactPanZoomImage({
scaleY,
}: ReactPanZoomProps) {
const { centerView } = useTransformContext();
const { getUrl } = useGetUrl();
return (
<TransformComponent
@ -35,7 +37,7 @@ export default function ReactPanZoomImage({
transform: `rotate(${rotation}deg) scaleX(${scaleX}) scaleY(${scaleY})`,
width: '100%',
}}
src={image.url}
src={getUrl(image.url)}
alt={alt}
ref={ref}
className={styleClass ? styleClass : ''}

View File

@ -0,0 +1,47 @@
import { v4 as uuidv4 } from 'uuid';
import 'reactflow/dist/style.css';
import { useCallback } from 'react';
import {
Tooltip,
Menu,
MenuButton,
MenuList,
MenuItem,
IconButton,
} from '@chakra-ui/react';
import { FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { nodeAdded } from '../store/nodesSlice';
import { map } from 'lodash';
import { RootState } from 'app/store';
export const AddNodeMenu = () => {
const dispatch = useAppDispatch();
const invocations = useAppSelector(
(state: RootState) => state.nodes.invocations
);
const addNode = useCallback(
(nodeType: string) => {
dispatch(nodeAdded({ id: uuidv4(), invocation: invocations[nodeType] }));
},
[dispatch, invocations]
);
return (
<Menu>
<MenuButton as={IconButton} aria-label="Add Node" icon={<FaPlus />} />
<MenuList>
{map(invocations, ({ title, description, type }, key) => {
return (
<Tooltip key={key} label={description} placement="end" hasArrow>
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
</Tooltip>
);
})}
</MenuList>
</Menu>
);
};

View File

@ -0,0 +1,78 @@
import { Tooltip } from '@chakra-ui/react';
import { CSSProperties, useMemo } from 'react';
import {
Handle,
Position,
Connection,
HandleType,
useReactFlow,
} from 'reactflow';
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../constants';
// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
import { InputField, OutputField } from '../types';
const handleBaseStyles: CSSProperties = {
position: 'absolute',
width: '1rem',
height: '1rem',
opacity: 0.5,
borderWidth: 0,
};
const inputHandleStyles: CSSProperties = {
left: '-1.7rem',
};
const outputHandleStyles: CSSProperties = {
right: '-1.7rem',
};
const requiredConnectionStyles: CSSProperties = {
opacity: 1,
};
type FieldHandleProps = {
nodeId: string;
field: InputField | OutputField;
isValidConnection: (connection: Connection) => boolean;
handleType: HandleType;
styles?: CSSProperties;
};
export const FieldHandle = (props: FieldHandleProps) => {
const { nodeId, field, isValidConnection, handleType, styles } = props;
const { name, title, type, description, connectionType } = field;
// this needs to iterate over every candicate target node, calculating graph cycles
// WIP
// const connectionEventStyles = useConnectionEventStyles(
// nodeId,
// type,
// handleType
// );
return (
<Tooltip
key={name}
label={`${title} (${type})`}
placement={handleType === 'target' ? 'start' : 'end'}
hasArrow
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
>
<Handle
type={handleType}
id={name}
isValidConnection={isValidConnection}
position={handleType === 'target' ? Position.Left : Position.Right}
style={{
backgroundColor: `var(--invokeai-colors-${FIELDS[type].color}-500)`,
...styles,
...handleBaseStyles,
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
...(connectionType === 'always' ? requiredConnectionStyles : {}),
// ...connectionEventStyles,
}}
/>
</Tooltip>
);
};

View File

@ -0,0 +1,18 @@
import 'reactflow/dist/style.css';
import { Tooltip, Badge, HStack } from '@chakra-ui/react';
import { map } from 'lodash';
import { FIELDS } from '../constants';
export const FieldTypeLegend = () => {
return (
<HStack>
{map(FIELDS, ({ title, description, color }, key) => (
<Tooltip key={key} label={description}>
<Badge colorScheme={color} sx={{ userSelect: 'none' }}>
{title}
</Badge>
</Tooltip>
))}
</HStack>
);
};

View File

@ -0,0 +1,104 @@
import {
Background,
Controls,
MiniMap,
OnConnect,
OnEdgesChange,
OnNodesChange,
ReactFlow,
ConnectionLineType,
OnConnectStart,
OnConnectEnd,
Panel,
} from 'reactflow';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store';
import {
connectionEnded,
connectionMade,
connectionStarted,
edgesChanged,
nodesChanged,
} from '../store/nodesSlice';
import { useCallback } from 'react';
import { InvocationComponent } from './InvocationComponent';
import { AddNodeMenu } from './AddNodeMenu';
import { FieldTypeLegend } from './FieldTypeLegend';
import { Button } from '@chakra-ui/react';
import { nodesGraphBuilt } from 'services/thunks/session';
const nodeTypes = { invocation: InvocationComponent };
export const Flow = () => {
const dispatch = useAppDispatch();
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
const edges = useAppSelector((state: RootState) => state.nodes.edges);
const onNodesChange: OnNodesChange = useCallback(
(changes) => {
dispatch(nodesChanged(changes));
},
[dispatch]
);
const onEdgesChange: OnEdgesChange = useCallback(
(changes) => {
dispatch(edgesChanged(changes));
},
[dispatch]
);
const onConnectStart: OnConnectStart = useCallback(
(event, params) => {
dispatch(connectionStarted(params));
},
[dispatch]
);
const onConnect: OnConnect = useCallback(
(connection) => {
dispatch(connectionMade(connection));
},
[dispatch]
);
const onConnectEnd: OnConnectEnd = useCallback(
(event) => {
dispatch(connectionEnded());
},
[dispatch]
);
const handleInvoke = useCallback(() => {
dispatch(nodesGraphBuilt());
}, [dispatch]);
return (
<ReactFlow
nodeTypes={nodeTypes}
nodes={nodes}
edges={edges}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onConnectStart={onConnectStart}
onConnect={onConnect}
onConnectEnd={onConnectEnd}
defaultEdgeOptions={{
style: { strokeWidth: 2 },
}}
>
<Panel position="top-left">
<AddNodeMenu />
</Panel>
<Panel position="top-center">
<Button onClick={handleInvoke}>Will it blend?</Button>
</Panel>
<Panel position="top-right">
<FieldTypeLegend />
</Panel>
<Background />
<Controls />
<MiniMap nodeStrokeWidth={3} zoomable pannable />
</ReactFlow>
);
};

View File

@ -0,0 +1,50 @@
import { Box } from '@chakra-ui/react';
import { InputField } from '../types';
import { BooleanInputFieldComponent } from './fields/BooleanInputFieldComponent';
import { EnumInputFieldComponent } from './fields/EnumInputFieldComponent';
import { ImageInputFieldComponent } from './fields/ImageInputFieldComponent';
import { LatentsInputFieldComponent } from './fields/LatentsInputFieldComponent';
import { ModelInputFieldComponent } from './fields/ModelInputFieldComponent';
import { NumberInputFieldComponent } from './fields/NumberInputFieldComponent';
import { StringInputFieldComponent } from './fields/StringInputFieldComponent';
type InputFieldComponentProps = {
nodeId: string;
field: InputField;
};
// build an individual input element based on the schema
export const InputFieldComponent = (props: InputFieldComponentProps) => {
const { nodeId, field } = props;
const { type, value } = field;
if (type === 'string') {
return <StringInputFieldComponent nodeId={nodeId} field={field} />;
}
if (type === 'boolean') {
return <BooleanInputFieldComponent nodeId={nodeId} field={field} />;
}
if (type === 'integer' || type === 'float') {
return <NumberInputFieldComponent nodeId={nodeId} field={field} />;
}
if (type === 'enum') {
return <EnumInputFieldComponent nodeId={nodeId} field={field} />;
}
if (type === 'image') {
return <ImageInputFieldComponent nodeId={nodeId} field={field} />;
}
if (type === 'latents') {
return <LatentsInputFieldComponent nodeId={nodeId} field={field} />;
}
if (type === 'model') {
return <ModelInputFieldComponent nodeId={nodeId} field={field} />;
}
return <Box p={2}>Unknown field type: {type}</Box>;
};

View File

@ -0,0 +1,145 @@
import { NodeProps, useReactFlow } from 'reactflow';
import {
Box,
Flex,
FormControl,
FormLabel,
Heading,
HStack,
Tooltip,
Icon,
Code,
Text,
} from '@chakra-ui/react';
import { FaInfoCircle } from 'react-icons/fa';
import { Invocation } from '../types';
import { InputFieldComponent } from './InputFieldComponent';
import { FieldHandle } from './FieldHandle';
import { isEqual, map, size } from 'lodash';
import { memo, useMemo } from 'react';
import { useIsValidConnection } from '../hooks/useIsValidConnection';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
const connectedInputFieldsSelector = createSelector(
(state: RootState) => state.nodes.edges,
(edges) => {
return edges.map((e) => e.targetHandle);
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const InvocationComponent = memo((props: NodeProps<Invocation>) => {
const { id, data, selected } = props;
const { type, title, description, inputs, outputs } = data;
const isValidConnection = useIsValidConnection();
const connectedInputs = useAppSelector(connectedInputFieldsSelector);
// TODO: determine if a field/handle is connected and disable the input if so
return (
<Box
sx={{
padding: 4,
bg: 'base.800',
borderRadius: 'md',
boxShadow: 'dark-lg',
borderWidth: 2,
borderColor: selected ? 'base.400' : 'transparent',
}}
>
<Flex flexDirection="column" gap={2}>
<>
<Code>{id}</Code>
<HStack justifyContent="space-between">
<Heading size="sm" fontWeight={500} color="base.100">
{title}
</Heading>
<Tooltip
label={description}
placement="top"
hasArrow
shouldWrapChildren
>
<Icon color="base.300" as={FaInfoCircle} />
</Tooltip>
</HStack>
{map(inputs, (input, i) => {
const isConnected = connectedInputs.includes(input.name);
return (
<Box
key={i}
position="relative"
p={2}
borderWidth={1}
borderRadius="md"
sx={{
borderColor:
!isConnected && input.connectionType === 'always'
? 'warning.400'
: undefined,
}}
>
<FormControl isDisabled={isConnected}>
<HStack justifyContent="space-between" alignItems="center">
<FormLabel>{input.title}</FormLabel>
<Tooltip
label={input.description}
placement="top"
hasArrow
shouldWrapChildren
>
<Icon color="base.400" as={FaInfoCircle} />
</Tooltip>
</HStack>
<InputFieldComponent nodeId={id} field={input} />
</FormControl>
{input.connectionType !== 'never' && (
<FieldHandle
nodeId={id}
field={input}
isValidConnection={isValidConnection}
handleType="target"
/>
)}
</Box>
);
})}
{map(outputs).map((output, i) => {
// const top = `${(100 / (size(outputs) + 1)) * (i + 1)}%`;
const { name, title } = output;
return (
<Box
key={name}
position="relative"
p={2}
borderWidth={1}
borderRadius="md"
>
<FormControl>
<FormLabel textAlign="end">{title} Output</FormLabel>
</FormControl>
<FieldHandle
key={name}
nodeId={id}
field={output}
isValidConnection={isValidConnection}
handleType="source"
/>
</Box>
);
})}
</>
</Flex>
<Flex></Flex>
</Box>
);
});
InvocationComponent.displayName = 'InvocationComponent';

View File

@ -0,0 +1,46 @@
import 'reactflow/dist/style.css';
import { Box } from '@chakra-ui/react';
import { ReactFlowProvider } from 'reactflow';
import { Flow } from './Flow';
import { useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store';
import { buildNodesGraph } from '../util/buildNodesGraph';
const NodeEditor = () => {
const state = useAppSelector((state: RootState) => state);
const graph = buildNodesGraph(state);
return (
<Box
sx={{
position: 'relative',
width: 'full',
height: 'full',
borderRadius: 'md',
bg: 'base.850',
}}
>
<ReactFlowProvider>
<Flow />
</ReactFlowProvider>
<Box
as="pre"
fontFamily="monospace"
position="absolute"
top={2}
left={2}
width="full"
height="full"
userSelect="none"
pointerEvents="none"
opacity={0.7}
>
<Box w="50%">{JSON.stringify(graph, null, 2)}</Box>
</Box>
</Box>
);
};
export default NodeEditor;

View File

@ -0,0 +1,28 @@
import { Switch } from '@chakra-ui/react';
import { useAppDispatch } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { BooleanInputField } from 'features/nodes/types';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
export const BooleanInputFieldComponent = (
props: FieldComponentProps<BooleanInputField>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldId: field.name,
value: e.target.checked,
})
);
};
return (
<Switch onChange={handleValueChanged} isChecked={field.value}></Switch>
);
};

View File

@ -0,0 +1,32 @@
import { Select } from '@chakra-ui/react';
import { useAppDispatch } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { EnumInputField } from 'features/nodes/types';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
export const EnumInputFieldComponent = (
props: FieldComponentProps<EnumInputField>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldId: field.name,
value: e.target.value,
})
);
};
return (
<Select onChange={handleValueChanged} value={field.value}>
{field.options.map((option) => (
<option key={option}>{option}</option>
))}
</Select>
);
};

View File

@ -0,0 +1,11 @@
import { ImageInputField } from 'features/nodes/types';
import { FaImage } from 'react-icons/fa';
import { FieldComponentProps } from './types';
export const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputField>
) => {
const { nodeId, field } = props;
return <FaImage />;
};

View File

@ -0,0 +1,11 @@
import { LatentsInputField } from 'features/nodes/types';
import { TbBrandMatrix } from 'react-icons/tb';
import { FieldComponentProps } from './types';
export const LatentsInputFieldComponent = (
props: FieldComponentProps<LatentsInputField>
) => {
const { nodeId, field } = props;
return <TbBrandMatrix />;
};

View File

@ -0,0 +1,49 @@
import { Select } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { ModelInputField } from 'features/nodes/types';
import { isEqual, map } from 'lodash';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
const availableModelsSelector = createSelector(
(state: RootState) => state.models.modelList,
(modelList) => {
return map(modelList, (_, name) => name);
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const ModelInputFieldComponent = (
props: FieldComponentProps<ModelInputField>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const availableModels = useAppSelector(availableModelsSelector);
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldId: field.name,
value: e.target.value,
})
);
};
return (
<Select onChange={handleValueChanged} value={field.value}>
{availableModels.map((option) => (
<option key={option}>{option}</option>
))}
</Select>
);
};

View File

@ -0,0 +1,33 @@
import {
NumberDecrementStepper,
NumberIncrementStepper,
NumberInput,
NumberInputField,
NumberInputStepper,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { IntegerInputField, FloatInputField } from 'features/nodes/types';
import { FieldComponentProps } from './types';
export const NumberInputFieldComponent = (
props: FieldComponentProps<IntegerInputField | FloatInputField>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (_: string, value: number) => {
dispatch(fieldValueChanged({ nodeId, fieldId: field.name, value }));
};
return (
<NumberInput onChange={handleValueChanged} value={field.value}>
<NumberInputField />
<NumberInputStepper>
<NumberIncrementStepper />
<NumberDecrementStepper />
</NumberInputStepper>
</NumberInput>
);
};

View File

@ -0,0 +1,22 @@
import { Input } from '@chakra-ui/react';
import { useAppDispatch } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { StringInputField } from 'features/nodes/types';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
export const StringInputFieldComponent = (
props: FieldComponentProps<StringInputField>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldValueChanged({ nodeId, fieldId: field.name, value: e.target.value })
);
};
return <Input onChange={handleValueChanged} value={field.value}></Input>;
};

View File

@ -0,0 +1,6 @@
import { InputField } from 'features/nodes/types';
export type FieldComponentProps<T extends InputField> = {
nodeId: string;
field: T;
};

View File

@ -0,0 +1,57 @@
import { FieldType, FieldUIConfig } from './types';
export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
export const FIELD_TYPE_MAP: Record<string, FieldType> = {
integer: 'integer',
number: 'float',
string: 'string',
boolean: 'boolean',
enum: 'enum',
ImageField: 'image',
LatentsField: 'latents',
model: 'model',
};
export const FIELDS: Record<FieldType, FieldUIConfig> = {
integer: {
color: 'red',
title: 'Integer',
description: 'Integers are whole numbers, without a decimal point.',
},
float: {
color: 'orange',
title: 'Float',
description: 'Floats are numbers with a decimal point.',
},
string: {
color: 'yellow',
title: 'String',
description: 'Strings are text.',
},
boolean: {
color: 'green',
title: 'Boolean',
description: 'Booleans are true or false.',
},
enum: {
color: 'blue',
title: 'Enum',
description: 'Enums are values that may be one of a number of options.',
},
image: {
color: 'purple',
title: 'Image',
description: 'Images may be passed between nodes.',
},
latents: {
color: 'pink',
title: 'Latents',
description: 'Latents may be passed between nodes.',
},
model: {
color: 'teal',
title: 'Model',
description: 'Models are models.',
},
};

View File

@ -0,0 +1,67 @@
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { CSSProperties, useMemo } from 'react';
import { HandleType, useReactFlow } from 'reactflow';
import { FieldType, Invocation } from '../types';
const invalidTargetStyles: CSSProperties = {
opacity: 0.3,
};
const validTargetStyles: CSSProperties = {};
export const useConnectionEventStyles = (
nodeId: string,
fieldType: FieldType,
handleType: HandleType
) => {
const flow = useReactFlow();
const pendingConnection = useAppSelector(
(state: RootState) => state.nodes.pendingConnection
);
return useMemo(() => {
if (!pendingConnection) {
return;
}
const {
handleId,
handleType: sourceHandleType,
nodeId: sourceNodeId,
} = pendingConnection;
// default to connectable if these are not present - unsure why they ever would not be present...
if (!handleId || !sourceNodeId || !handleType) {
return validTargetStyles;
}
if (
// cannot connect a node's input to its own output
nodeId === sourceNodeId
) {
return invalidTargetStyles;
}
if (
// cannot connect inputs to inputs or outputs to outputs
handleType === sourceHandleType
) {
return invalidTargetStyles;
}
const node = flow.getNode(sourceNodeId)?.data as Invocation;
// handle field types must be the same
if (
fieldType !==
(sourceHandleType === 'target'
? node.inputs[handleId].type
: node.outputs[handleId].type)
) {
return invalidTargetStyles;
}
return validTargetStyles;
}, [pendingConnection, nodeId, flow, fieldType, handleType]);
};

View File

@ -0,0 +1,67 @@
import { useCallback } from 'react';
import { Connection, useReactFlow } from 'reactflow';
import graphlib from '@dagrejs/graphlib';
export const useIsValidConnection = () => {
const flow = useReactFlow();
// Check if an in-progress connection is valid
const isValidConnection = useCallback(
({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
const edges = flow.getEdges();
const nodes = flow.getNodes();
// Connection must have valid targets
if (!(source && sourceHandle && target && targetHandle)) {
return false;
}
// Connection is invalid if target already has a connection
if (
edges.find((edge) => {
return edge.target === target && edge.targetHandle === targetHandle;
})
) {
return false;
}
// Find the source and target nodes
const sourceNode = flow.getNode(source);
const targetNode = flow.getNode(target);
// Conditional guards against undefined nodes/handles
if (!(sourceNode && targetNode)) {
return false;
}
// Connection types must be the same for a connection
if (
sourceNode.data.outputs[sourceHandle].type !==
targetNode.data.inputs[targetHandle].type
) {
return false;
}
// Graphs much be acyclic (no loops!)
// build a graphlib graph
const g = new graphlib.Graph();
nodes.forEach((n) => {
g.setNode(n.id);
});
edges.forEach((e) => {
g.setEdge(e.source, e.target);
});
// Add the candidate edge to the graph
g.setEdge(source, target);
return graphlib.alg.isAcyclic(g);
},
[flow]
);
return isValidConnection;
};

View File

@ -0,0 +1,115 @@
import { createSlice, isAnyOf, PayloadAction } from '@reduxjs/toolkit';
import { OpenAPIV3 } from 'openapi-types';
import {
addEdge,
applyEdgeChanges,
applyNodeChanges,
Connection,
Edge,
EdgeChange,
Node,
NodeChange,
OnConnectStartParams,
} from 'reactflow';
import { Graph } from 'services/api';
import { receivedOpenAPISchema } from 'services/thunks/schema';
import {
isFulfilledAnyGraphBuilt,
linearGraphBuilt,
nodesGraphBuilt,
} from 'services/thunks/session';
import { Invocation } from '../types';
import { buildNodesGraph } from '../util/buildNodesGraph';
import { parseSchema } from '../util/parseSchema';
export type NodesState = {
nodes: Node<Invocation>[];
edges: Edge[];
schema: OpenAPIV3.Document | null;
invocations: Record<string, Invocation>;
pendingConnection: OnConnectStartParams | null;
lastGraph: Graph | null;
};
export const initialNodesState: NodesState = {
nodes: [],
edges: [],
schema: null,
invocations: {},
pendingConnection: null,
lastGraph: null,
};
const nodesSlice = createSlice({
name: 'results',
initialState: initialNodesState,
reducers: {
nodesChanged: (state, action: PayloadAction<NodeChange[]>) => {
state.nodes = applyNodeChanges(action.payload, state.nodes);
},
nodeAdded: (
state,
action: PayloadAction<{ id: string; invocation: Invocation }>
) => {
const { id, invocation } = action.payload;
const node: Node = {
id,
type: 'invocation',
position: { x: 0, y: 0 },
data: invocation,
};
state.nodes.push(node);
},
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
state.edges = applyEdgeChanges(action.payload, state.edges);
},
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
state.pendingConnection = action.payload;
},
connectionMade: (state, action: PayloadAction<Connection>) => {
state.edges = addEdge(action.payload, state.edges);
},
connectionEnded: (state) => {
state.pendingConnection = null;
},
fieldValueChanged: (
state,
action: PayloadAction<{
nodeId: string;
fieldId: string;
value: string | number | boolean | undefined;
}>
) => {
const { nodeId, fieldId, value } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
if (nodeIndex > -1) {
state.nodes[nodeIndex].data.inputs[fieldId].value = value;
}
},
},
extraReducers(builder) {
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
state.schema = action.payload;
state.invocations = parseSchema(action.payload);
});
builder.addMatcher(isFulfilledAnyGraphBuilt, (state, action) => {
state.lastGraph = action.payload;
});
},
});
export const {
nodesChanged,
edgesChanged,
nodeAdded,
fieldValueChanged,
connectionMade,
connectionStarted,
connectionEnded,
} = nodesSlice.actions;
export default nodesSlice.reducer;

View File

@ -0,0 +1,187 @@
import { OpenAPIV3 } from 'openapi-types';
export const isReferenceObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
): obj is OpenAPIV3.ReferenceObject => '$ref' in obj;
export const isSchemaObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
): obj is OpenAPIV3.SchemaObject => !('$ref' in obj);
export type Invocation = {
/**
* Unique type of the invocation
*/
type: string;
/**
* Display name of the invocation
*/
title: string;
/**
* Description of the invocation
*/
description: string;
/**
* Invocation tags
*/
tags: string[];
/**
* Array of invocation inputs
*/
inputs: Record<string, InputField>;
// inputs: InputField[];
/**
* Array of the invocation outputs
*/
outputs: Record<string, OutputField>;
// outputs: OutputField[];
};
export type FieldUIConfig = {
color:
| 'red'
| 'orange'
| 'yellow'
| 'green'
| 'blue'
| 'purple'
| 'pink'
| 'teal';
title: string;
description: string;
};
export type FieldType =
| 'integer'
| 'float'
| 'string'
| 'boolean'
| 'enum'
| 'image'
| 'latents'
| 'model';
export type InputField =
| IntegerInputField
| FloatInputField
| StringInputField
| BooleanInputField
| ImageInputField
| LatentsInputField
| EnumInputField
| ModelInputField;
export type OutputField = FieldBase;
export type ConnectionType = 'never' | 'always';
export type FieldBase = {
name: string;
title: string;
description: string;
type: FieldType;
connectionType?: ConnectionType;
};
export type NumberInvocationField = {
value?: number;
multipleOf?: number;
maximum?: number;
exclusiveMaximum?: boolean;
minimum?: number;
exclusiveMinimum?: boolean;
};
export type IntegerInputField = FieldBase &
NumberInvocationField & {
type: 'integer';
};
export type FloatInputField = FieldBase &
NumberInvocationField & {
type: 'float';
};
export type StringInputField = FieldBase & {
type: 'string';
value?: string;
maxLength?: number;
minLength?: number;
pattern?: string;
};
export type BooleanInputField = FieldBase & {
type: 'boolean';
value?: boolean;
};
export type ImageInputField = FieldBase & {
type: 'image';
// TODO: use a better value
value?: string;
};
export type LatentsInputField = FieldBase & {
type: 'latents';
// TODO: use a better value
value?: string;
};
export type EnumInputField = FieldBase & {
type: 'enum';
value?: string | number;
enumType: 'string' | 'integer' | 'number';
options: Array<string | number>;
};
export type ModelInputField = FieldBase & {
type: 'model';
value?: string;
};
/**
* JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES
*/
export type TypeHints = {
[fieldName: string]: FieldType;
};
export type InvocationSchemaExtra = {
output: OpenAPIV3.ReferenceObject; // the output of the invocation
ui?: {
tags?: string[];
type_hints?: TypeHints;
};
title: string;
properties: Omit<
NonNullable<OpenAPIV3.SchemaObject['properties']>,
'type'
> & {
type: Omit<OpenAPIV3.SchemaObject, 'default'> & { default: string };
};
};
export type InvocationSchemaType = {
default: string; // the type of the invocation
};
export type InvocationBaseSchemaObject = Omit<
OpenAPIV3.BaseSchemaObject,
'title' | 'type' | 'properties'
> &
InvocationSchemaExtra;
interface ArraySchemaObject extends InvocationBaseSchemaObject {
type: OpenAPIV3.ArraySchemaObjectType;
items: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject;
}
interface NonArraySchemaObject extends InvocationBaseSchemaObject {
type?: OpenAPIV3.NonArraySchemaObjectType;
}
export type InvocationSchemaObject = ArraySchemaObject | NonArraySchemaObject;
export const isInvocationSchemaObject = (
obj: OpenAPIV3.ReferenceObject | InvocationSchemaObject
): obj is InvocationSchemaObject => !('$ref' in obj);

View File

@ -0,0 +1,64 @@
import { Graph } from 'services/api';
import { v4 as uuidv4 } from 'uuid';
import { reduce } from 'lodash';
import { RootState } from 'app/store';
export const buildNodesGraph = (state: RootState): Graph => {
const { nodes, edges } = state.nodes;
const parsedNodes = nodes.reduce<NonNullable<Graph['nodes']>>(
(nodesAccumulator, node, nodeIndex) => {
const { id, data } = node;
const { type, inputs } = data;
const transformedInputs = reduce(
inputs,
(inputsAccumulator, input, name) => {
inputsAccumulator[name] = input.value;
return inputsAccumulator;
},
{} as Record<string, any>
);
const graphNode = {
type,
id,
...transformedInputs,
};
nodesAccumulator[id] = graphNode;
return nodesAccumulator;
},
{}
);
const parsedEdges = edges.reduce<NonNullable<Graph['edges']>>(
(edgesAccumulator, edge, edgeIndex) => {
const { source, target, sourceHandle, targetHandle } = edge;
edgesAccumulator.push({
source: {
node_id: source,
field: sourceHandle as string,
},
destination: {
node_id: target,
field: targetHandle as string,
},
});
return edgesAccumulator;
},
[]
);
const graph = {
id: uuidv4(),
nodes: parsedNodes,
edges: parsedEdges,
};
return graph;
};

View File

@ -0,0 +1,313 @@
import { reduce } from 'lodash';
import { OpenAPIV3 } from 'openapi-types';
import { FIELD_TYPE_MAP } from '../constants';
import {
BooleanInputField,
EnumInputField,
FloatInputField,
ImageInputField,
IntegerInputField,
LatentsInputField,
OutputField,
StringInputField,
isSchemaObject,
ModelInputField,
TypeHints,
FieldType,
InputField,
} from '../types';
export type BaseFieldProperties = 'name' | 'title' | 'description';
export type BuildInputFieldArg = {
schemaObject: OpenAPIV3.SchemaObject;
baseField: Pick<InputField, BaseFieldProperties>;
};
/**
* Transforms an invocation output ref object to field type.
* @param ref The ref string to transform
* @returns The field type.
*
* @example
* refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
*/
export const refObjectToFieldType = (
refObject: OpenAPIV3.ReferenceObject
): keyof typeof FIELD_TYPE_MAP => refObject.$ref.split('/').slice(-1)[0];
const buildIntegerInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IntegerInputField => {
const field: Omit<IntegerInputField, BaseFieldProperties> = {
type: 'integer',
value: schemaObject.default ?? 0,
};
if (schemaObject.multipleOf !== undefined) {
field.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
field.maximum = schemaObject.maximum;
}
if (schemaObject.exclusiveMaximum !== undefined) {
field.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
field.minimum = schemaObject.minimum;
}
if (schemaObject.exclusiveMinimum !== undefined) {
field.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return { ...baseField, ...field };
};
const buildFloatInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): FloatInputField => {
const field: Omit<FloatInputField, BaseFieldProperties> = {
type: 'float',
value: schemaObject.default ?? 0,
};
if (schemaObject.multipleOf !== undefined) {
field.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
field.maximum = schemaObject.maximum;
}
if (schemaObject.exclusiveMaximum !== undefined) {
field.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
field.minimum = schemaObject.minimum;
}
if (schemaObject.exclusiveMinimum !== undefined) {
field.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return { ...baseField, ...field };
};
const buildStringInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): StringInputField => {
const field: Omit<StringInputField, BaseFieldProperties> = {
type: 'string',
value: schemaObject.default ?? '',
};
if (schemaObject.minLength !== undefined) {
field.minLength = schemaObject.minLength;
}
if (schemaObject.maxLength !== undefined) {
field.maxLength = schemaObject.maxLength;
}
if (schemaObject.pattern !== undefined) {
field.pattern = schemaObject.pattern;
}
return { ...baseField, ...field };
};
const buildBooleanInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): BooleanInputField => {
const field: Omit<BooleanInputField, BaseFieldProperties> = {
type: 'boolean',
value: schemaObject.default ?? false,
};
return { ...baseField, ...field };
};
const buildModelInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ModelInputField => {
const field: Omit<ModelInputField, BaseFieldProperties> = {
type: 'model',
value: schemaObject.default ?? '',
connectionType: 'never',
};
return { ...baseField, ...field };
};
const buildImageInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ImageInputField => {
const field: Omit<ImageInputField, BaseFieldProperties> = {
type: 'image',
value: schemaObject.default ?? '',
connectionType: 'always',
};
return { ...baseField, ...field };
};
const buildLatentsInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): LatentsInputField => {
const field: Omit<LatentsInputField, BaseFieldProperties> = {
type: 'latents',
value: schemaObject.default ?? '',
connectionType: 'always',
};
return { ...baseField, ...field };
};
const buildEnumInputField = ({
schemaObject,
baseField,
}: BuildInputFieldArg): EnumInputField => {
const field: Omit<EnumInputField, BaseFieldProperties> = {
...baseField,
type: 'enum',
value: schemaObject.default,
enumType: (schemaObject.type as 'string' | 'number') ?? 'string', // TODO: dangerous?
options: schemaObject.enum ?? [],
};
return { ...baseField, ...field };
};
export const getFieldType = (
schemaObject: OpenAPIV3.SchemaObject,
name: string,
typeHints?: TypeHints
): FieldType | undefined => {
let rawFieldType = '';
if (typeHints && name in typeHints) {
rawFieldType = typeHints[name];
} else if (!schemaObject.type) {
rawFieldType = refObjectToFieldType(
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
);
} else if (schemaObject.enum) {
rawFieldType = 'enum';
} else if (schemaObject.type) {
rawFieldType = schemaObject.type;
}
return FIELD_TYPE_MAP[rawFieldType];
};
/**
* Builds an input field from an invocation schema property.
* @param schemaObject The schema object
* @returns An input field
*/
export const buildInputField = (
schemaObject: OpenAPIV3.SchemaObject,
name: string,
typeHints?: TypeHints
) => {
const fieldType = getFieldType(schemaObject, name, typeHints);
if (!fieldType) {
throw `Field type "${fieldType}" is unknown!`;
}
const baseField = {
name,
title: schemaObject.title ?? '',
description: schemaObject.description ?? '',
};
if (['image'].includes(fieldType)) {
return buildImageInputField({ schemaObject, baseField });
}
if (['latents'].includes(fieldType)) {
return buildLatentsInputField({ schemaObject, baseField });
}
if (['model'].includes(fieldType)) {
return buildModelInputField({ schemaObject, baseField });
}
if (['enum'].includes(fieldType)) {
return buildEnumInputField({ schemaObject, baseField });
}
if (['integer'].includes(fieldType)) {
return buildIntegerInputField({ schemaObject, baseField });
}
if (['number', 'float'].includes(fieldType)) {
return buildFloatInputField({ schemaObject, baseField });
}
if (['string'].includes(fieldType)) {
return buildStringInputField({ schemaObject, baseField });
}
if (['boolean'].includes(fieldType)) {
return buildBooleanInputField({ schemaObject, baseField });
}
return;
};
/**
* Builds invocation output fields from an invocation's output reference object.
* @param openAPI The OpenAPI schema
* @param refObject The output reference object
* @returns A record of outputs
*/
export const buildOutputFields = (
refObject: OpenAPIV3.ReferenceObject,
openAPI: OpenAPIV3.Document,
typeHints?: TypeHints
): Record<string, OutputField> => {
// extract output schema name from ref
const outputSchemaName = refObject.$ref.split('/').slice(-1)[0];
// get the output schema itself
const outputSchema = openAPI.components!.schemas![outputSchemaName];
if (isSchemaObject(outputSchema)) {
const outputFields = reduce(
outputSchema.properties as OpenAPIV3.SchemaObject,
(outputsAccumulator, property, propertyName) => {
if (
!['type', 'id'].includes(propertyName) &&
isSchemaObject(property)
) {
const fieldType = getFieldType(property, propertyName, typeHints);
if (!fieldType) {
throw `Field type "${fieldType}" is unknown!`;
}
outputsAccumulator[propertyName] = {
name: propertyName,
title: property.title ?? '',
description: property.description ?? '',
type: fieldType,
};
}
return outputsAccumulator;
},
{} as Record<string, OutputField>
);
return outputFields;
}
return {};
};

View File

@ -0,0 +1,82 @@
import { filter, reduce } from 'lodash';
import { OpenAPIV3 } from 'openapi-types';
import {
InputField,
Invocation,
InvocationSchemaObject,
isInvocationSchemaObject,
isSchemaObject,
} from '../types';
import { buildInputField, buildOutputFields } from './invocationFieldBuilders';
export const parseSchema = (openAPI: OpenAPIV3.Document) => {
// filter out non-invocation schemas, plus some tricky invocations for now
const filteredSchemas = filter(
openAPI.components!.schemas,
(schema, key) =>
key.includes('Invocation') &&
!key.includes('InvocationOutput') &&
!key.includes('Collect') &&
!key.includes('Range') &&
!key.includes('Iterate') &&
!key.includes('LoadImage') &&
!key.includes('Graph')
) as (OpenAPIV3.ReferenceObject | InvocationSchemaObject)[];
const invocations = filteredSchemas.reduce<Record<string, Invocation>>(
(acc, schema) => {
// only want SchemaObjects
if (isInvocationSchemaObject(schema)) {
const type = schema.properties.type.default;
const title = schema.title
.replace('Invocation', '')
.split(/(?=[A-Z])/) // split PascalCase into array
.join(' ');
const typeHints = schema.ui?.type_hints;
const inputs = reduce(
schema.properties,
(inputsAccumulator, property, propertyName) => {
if (
// `type` and `id` are not valid inputs/outputs
!['type', 'id'].includes(propertyName) &&
isSchemaObject(property)
) {
const field = buildInputField(property, propertyName, typeHints);
if (field) {
inputsAccumulator[propertyName] = field;
}
}
return inputsAccumulator;
},
{} as Record<string, InputField>
);
const rawOutput = (schema as InvocationSchemaObject).output;
const outputs = buildOutputFields(rawOutput, openAPI, typeHints);
const invocation: Invocation = {
title,
type,
tags: schema.ui?.tags ?? [],
description: schema.description ?? '',
inputs,
outputs,
};
acc[type] = invocation;
}
return acc;
},
{}
);
console.debug('Generated invocations: ', invocations);
return invocations;
};

View File

@ -21,9 +21,10 @@ type ParametersAccordionsType = {
const ParametersAccordion = (props: ParametersAccordionsType) => {
const { accordionInfo } = props;
const openAccordions = useAppSelector(
(state: RootState) => state.system.openAccordions
);
const { system, ui } = useAppSelector((state: RootState) => state);
const { openAccordions } = system;
const { disabledParameterPanels } = ui;
const dispatch = useAppDispatch();
@ -39,6 +40,9 @@ const ParametersAccordion = (props: ParametersAccordionsType) => {
Object.keys(accordionInfo).forEach((key) => {
const { header, feature, content, additionalHeaderComponents } =
accordionInfo[key];
// do not render if panel is disabled in global state
if (disabledParameterPanels.indexOf(key) === -1) {
accordionsToRender.push(
<InvokeAccordionItem
key={key}
@ -48,6 +52,7 @@ const ParametersAccordion = (props: ParametersAccordionsType) => {
additionalHeaderComponents={additionalHeaderComponents}
/>
);
}
});
}
return accordionsToRender;

View File

@ -1,5 +1,4 @@
import { createSelector } from '@reduxjs/toolkit';
import { cancelProcessing } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAIIconButton, {
IAIIconButtonProps,
@ -9,16 +8,36 @@ import {
SystemState,
setCancelAfter,
setCancelType,
cancelScheduled,
cancelTypeChanged,
CancelType,
} from 'features/system/store/systemSlice';
import { isEqual } from 'lodash';
import { useEffect, useCallback, memo } from 'react';
import { ButtonSpinner, ButtonGroup } from '@chakra-ui/react';
import {
ButtonSpinner,
ButtonGroup,
Menu,
MenuButton,
MenuList,
MenuOptionGroup,
MenuItemOption,
IconButton,
} from '@chakra-ui/react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { MdCancel, MdCancelScheduleSend } from 'react-icons/md';
import {
MdArrowDropDown,
MdArrowDropUp,
MdCancel,
MdCancelScheduleSend,
} from 'react-icons/md';
import IAISimpleMenu from 'common/components/IAISimpleMenu';
import { sessionCanceled } from 'services/thunks/session';
import { FaChevronDown } from 'react-icons/fa';
import { BiChevronDown } from 'react-icons/bi';
const cancelButtonSelector = createSelector(
systemSelector,
@ -29,8 +48,11 @@ const cancelButtonSelector = createSelector(
isCancelable: system.isCancelable,
currentIteration: system.currentIteration,
totalIterations: system.totalIterations,
cancelType: system.cancelOptions.cancelType,
cancelAfter: system.cancelOptions.cancelAfter,
// cancelType: system.cancelOptions.cancelType,
// cancelAfter: system.cancelOptions.cancelAfter,
sessionId: system.sessionId,
cancelType: system.cancelType,
isCancelScheduled: system.isCancelScheduled,
};
},
{
@ -56,16 +78,34 @@ const CancelButton = (
currentIteration,
totalIterations,
cancelType,
cancelAfter,
isCancelScheduled,
// cancelAfter,
sessionId,
} = useAppSelector(cancelButtonSelector);
const handleClickCancel = useCallback(() => {
dispatch(cancelProcessing());
dispatch(setCancelAfter(null));
}, [dispatch]);
if (!sessionId) {
return;
}
if (cancelType === 'scheduled') {
dispatch(cancelScheduled());
return;
}
dispatch(sessionCanceled({ sessionId }));
}, [dispatch, sessionId, cancelType]);
const { t } = useTranslation();
const isCancelScheduled = cancelAfter === null ? false : true;
const handleCancelTypeChanged = useCallback(
(value: string | string[]) => {
const newCancelType = Array.isArray(value) ? value[0] : value;
dispatch(cancelTypeChanged(newCancelType as CancelType));
},
[dispatch]
);
// const isCancelScheduled = cancelAfter === null ? false : true;
useHotkeys(
'shift+x',
@ -77,22 +117,22 @@ const CancelButton = (
[isConnected, isProcessing, isCancelable]
);
useEffect(() => {
if (cancelAfter !== null && cancelAfter < currentIteration) {
handleClickCancel();
}
}, [cancelAfter, currentIteration, handleClickCancel]);
// useEffect(() => {
// if (cancelAfter !== null && cancelAfter < currentIteration) {
// handleClickCancel();
// }
// }, [cancelAfter, currentIteration, handleClickCancel]);
const cancelMenuItems = [
{
item: t('parameters.cancel.immediate'),
onClick: () => dispatch(setCancelType('immediate')),
},
{
item: t('parameters.cancel.schedule'),
onClick: () => dispatch(setCancelType('scheduled')),
},
];
// const cancelMenuItems = [
// {
// item: t('parameters.cancel.immediate'),
// onClick: () => dispatch(cancelTypeChanged('immediate')),
// },
// {
// item: t('parameters.cancel.schedule'),
// onClick: () => dispatch(cancelTypeChanged('scheduled')),
// },
// ];
return (
<ButtonGroup isAttached width={btnGroupWidth}>
@ -121,29 +161,40 @@ const CancelButton = (
? t('parameters.cancel.isScheduled')
: t('parameters.cancel.schedule')
}
isDisabled={
!isConnected ||
!isProcessing ||
!isCancelable ||
currentIteration === totalIterations
}
onClick={() => {
// If a cancel request has already been made, and the user clicks again before the next iteration has been processed, stop the request.
if (isCancelScheduled) dispatch(setCancelAfter(null));
else dispatch(setCancelAfter(currentIteration));
}}
isDisabled={!isConnected || !isProcessing || !isCancelable}
onClick={handleClickCancel}
colorScheme="error"
{...rest}
/>
)}
<IAISimpleMenu
menuItems={cancelMenuItems}
iconTooltip={t('parameters.cancel.setType')}
menuButtonProps={{
colorScheme: 'error',
minWidth: 5,
}}
<Menu closeOnSelect={false}>
<MenuButton
as={IconButton}
tooltip={t('parameters.cancel.setType')}
aria-label={t('parameters.cancel.setType')}
icon={<BiChevronDown />}
paddingX={0}
paddingY={0}
colorScheme="error"
minWidth={5}
/>
<MenuList minWidth="240px">
<MenuOptionGroup
value={cancelType}
title="Cancel Type"
type="radio"
onChange={handleCancelTypeChanged}
>
<MenuItemOption value="immediate">
{t('parameters.cancel.immediate')}
</MenuItemOption>
<MenuItemOption value="scheduled">
{t('parameters.cancel.schedule')}
</MenuItemOption>
</MenuOptionGroup>
</MenuList>
</Menu>
</ButtonGroup>
);
};

View File

@ -11,6 +11,7 @@ import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaPlay } from 'react-icons/fa';
import { linearGraphBuilt, sessionCreated } from 'services/thunks/session';
interface InvokeButton
extends Omit<IAIButtonProps | IAIIconButtonProps, 'aria-label'> {
@ -24,7 +25,8 @@ export default function InvokeButton(props: InvokeButton) {
const activeTabName = useAppSelector(activeTabNameSelector);
const handleClickGenerate = () => {
dispatch(generateImage(activeTabName));
// dispatch(generateImage(activeTabName));
dispatch(linearGraphBuilt());
};
const { t } = useTranslation();

View File

@ -1,5 +1,11 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
import {
selectResultsById,
selectResultsEntities,
} from 'features/gallery/store/resultsSlice';
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
import { isEqual } from 'lodash';
export const generationSelector = (state: RootState) => state.generation;
@ -15,3 +21,15 @@ export const mayGenerateMultipleImagesSelector = createSelector(
},
}
);
export const initialImageSelector = createSelector(
[(state: RootState) => state, generationSelector],
(state, generation) => {
const { initialImage: initialImageName } = generation;
return (
selectResultsById(state, initialImageName as string) ??
selectUploadsById(state, initialImageName as string)
);
}
);

View File

@ -11,7 +11,7 @@ export interface GenerationState {
height: number;
img2imgStrength: number;
infillMethod: string;
initialImage?: InvokeAI.Image | string; // can be an Image or url
initialImage?: InvokeAI._Image | string; // can be an Image or url
iterations: number;
maskPath: string;
perlin: number;
@ -317,12 +317,12 @@ export const generationSlice = createSlice({
setShouldRandomizeSeed: (state, action: PayloadAction<boolean>) => {
state.shouldRandomizeSeed = action.payload;
},
setInitialImage: (
state,
action: PayloadAction<InvokeAI.Image | string>
) => {
state.initialImage = action.payload;
},
// setInitialImage: (
// state,
// action: PayloadAction<InvokeAI._Image | string>
// ) => {
// state.initialImage = action.payload;
// },
clearInitialImage: (state) => {
state.initialImage = undefined;
},
@ -353,6 +353,9 @@ export const generationSlice = createSlice({
setVerticalSymmetrySteps: (state, action: PayloadAction<number>) => {
state.verticalSymmetrySteps = action.payload;
},
initialImageSelected: (state, action: PayloadAction<string>) => {
state.initialImage = action.payload;
},
},
});
@ -368,7 +371,7 @@ export const {
setHeight,
setImg2imgStrength,
setInfillMethod,
setInitialImage,
// setInitialImage,
setIterations,
setMaskPath,
setParameter,
@ -394,6 +397,7 @@ export const {
setShouldUseSymmetry,
setHorizontalSymmetrySteps,
setVerticalSymmetrySteps,
initialImageSelected,
} = generationSlice.actions;
export default generationSlice.reducer;

View File

@ -1,20 +1,20 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { requestModelChange } from 'app/socketio/actions';
import { ChangeEvent } from 'react';
import { isEqual, map } from 'lodash';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISelect from 'common/components/IAISelect';
import { isEqual, map } from 'lodash';
import { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
import { activeModelSelector, systemSelector } from '../store/systemSelectors';
import { modelSelector } from '../store/modelSelectors';
import { setCurrentModel } from '../store/modelSlice';
const selector = createSelector(
[systemSelector],
(system) => {
const { isProcessing, model_list } = system;
const models = map(model_list, (model, key) => key);
return { models, isProcessing };
[modelSelector],
(model) => {
const { modelList, currentModel } = model;
const models = map(modelList, (model, key) => key);
return { models, currentModel, modelList };
},
{
memoizeOptions: {
@ -26,11 +26,12 @@ const selector = createSelector(
const ModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { models, isProcessing } = useAppSelector(selector);
const activeModel = useAppSelector(activeModelSelector);
const { models, currentModel, modelList } = useAppSelector(selector);
const handleChangeModel = (e: ChangeEvent<HTMLSelectElement>) => {
dispatch(requestModelChange(e.target.value));
dispatch(setCurrentModel(e.target.value));
};
const currentModelDescription =
currentModel && modelList[currentModel].description;
return (
<Flex
@ -41,9 +42,8 @@ const ModelSelect = () => {
<IAISelect
style={{ fontSize: 'sm' }}
aria-label={t('accessibility.modelSelect')}
tooltip={activeModel.description}
isDisabled={isProcessing}
value={activeModel.name}
tooltip={currentModelDescription}
value={currentModel}
validValues={models}
onChange={handleChangeModel}
/>

View File

@ -80,7 +80,7 @@ const StatusIndicator = () => {
cursor={statusIndicatorCursor}
onClick={handleClickStatusIndicator}
sx={{
fontSize: 'xs',
fontSize: 'sm',
fontWeight: '600',
color: `${statusIdentifier}.400`,
}}

View File

@ -1,9 +1,24 @@
import { useToast } from '@chakra-ui/react';
import { useToast, UseToastOptions } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { toastQueueSelector } from 'features/system/store/systemSelectors';
import { clearToastQueue } from 'features/system/store/systemSlice';
import { useEffect } from 'react';
export type MakeToastArg = string | UseToastOptions;
export const makeToast = (arg: MakeToastArg): UseToastOptions => {
if (typeof arg === 'string') {
return {
title: arg,
status: 'info',
isClosable: true,
duration: 2500,
};
}
return { status: 'info', isClosable: true, duration: 2500, ...arg };
};
const useToastWatcher = () => {
const dispatch = useAppDispatch();
const toastQueue = useAppSelector(toastQueueSelector);

View File

@ -0,0 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { reduce } from 'lodash';
export const modelSelector = (state: RootState) => state.models;

View File

@ -0,0 +1,40 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { ModelsList } from 'services/api';
import { receivedModels } from 'services/thunks/model';
export interface ModelState {
modelList: ModelsList['models'];
currentModel?: string;
}
const initialModelState: ModelState = {
modelList: {},
currentModel: undefined,
};
export const modelSlice = createSlice({
name: 'model',
initialState: initialModelState,
reducers: {
setModelList: (state, action: PayloadAction<ModelsList['models']>) => {
state.modelList = action.payload;
},
setCurrentModel: (state, action: PayloadAction<string>) => {
state.currentModel = action.payload;
},
},
extraReducers(builder) {
/**
* Received Models - FULFILLED
*/
builder.addCase(receivedModels.fulfilled, (state, action) => {
const models = action.payload.models;
state.modelList = models;
});
},
});
export const { setModelList, setCurrentModel } = modelSlice.actions;
export default modelSlice.reducer;

View File

@ -2,7 +2,23 @@ import { ExpandedIndex, UseToastOptions } from '@chakra-ui/react';
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import {
generatorProgress,
invocationComplete,
invocationError,
invocationStarted,
socketConnected,
socketDisconnected,
socketSubscribed,
socketUnsubscribed,
} from 'services/events/actions';
import i18n from 'i18n';
import { isImageOutput } from 'services/types/guards';
import { ProgressImage } from 'services/events/types';
import { initialImageSelected } from 'features/parameters/store/generationSlice';
import { makeToast } from '../hooks/useToastWatcher';
import { sessionCanceled, sessionInvoked } from 'services/thunks/session';
export type LogLevel = 'info' | 'warning' | 'error';
@ -56,6 +72,30 @@ export interface SystemState
cancelType: CancelType;
cancelAfter: number | null;
};
/**
* The current progress image
*/
progressImage: ProgressImage | null;
/**
* The current socket session id
*/
sessionId: string | null;
/**
* Cancel strategy
*/
cancelType: CancelType;
/**
* Whether or not a scheduled cancelation is pending
*/
isCancelScheduled: boolean;
/**
* Array of node IDs that we want to handle when events received
*/
subscribedNodeIds: string[];
/**
* Whether or not URLs should be transformed to use a different host
*/
shouldTransformUrls: boolean;
}
const initialSystemState: SystemState = {
@ -98,6 +138,12 @@ const initialSystemState: SystemState = {
cancelType: 'immediate',
cancelAfter: null,
},
progressImage: null,
sessionId: null,
cancelType: 'immediate',
isCancelScheduled: false,
subscribedNodeIds: [],
shouldTransformUrls: false,
};
export const systemSlice = createSlice({
@ -271,6 +317,203 @@ export const systemSlice = createSlice({
setCancelAfter: (state, action: PayloadAction<number | null>) => {
state.cancelOptions.cancelAfter = action.payload;
},
/**
* A cancel was scheduled
*/
cancelScheduled: (state) => {
state.isCancelScheduled = true;
},
/**
* The scheduled cancel was aborted
*/
scheduledCancelAborted: (state) => {
state.isCancelScheduled = false;
},
/**
* The cancel type was changed
*/
cancelTypeChanged: (state, action: PayloadAction<CancelType>) => {
state.cancelType = action.payload;
},
/**
* The array of subscribed node ids was changed
*/
subscribedNodeIdsSet: (state, action: PayloadAction<string[]>) => {
state.subscribedNodeIds = action.payload;
},
/**
* `shouldTransformUrls` was changed
*/
shouldTransformUrlsChanged: (state, action: PayloadAction<boolean>) => {
state.shouldTransformUrls = action.payload;
},
},
extraReducers(builder) {
/**
* Socket Subscribed
*/
builder.addCase(socketSubscribed, (state, action) => {
state.sessionId = action.payload.sessionId;
});
/**
* Socket Unsubscribed
*/
builder.addCase(socketUnsubscribed, (state) => {
state.sessionId = null;
});
/**
* Socket Connected
*/
builder.addCase(socketConnected, (state, action) => {
const { timestamp } = action.payload;
state.isConnected = true;
state.currentStatus = i18n.t('common.statusConnected');
state.log.push({
timestamp,
message: `Connected to server`,
level: 'info',
});
state.toastQueue.push(
makeToast({ title: i18n.t('toast.connected'), status: 'success' })
);
});
/**
* Socket Disconnected
*/
builder.addCase(socketDisconnected, (state, action) => {
const { timestamp } = action.payload;
state.isConnected = false;
state.currentStatus = i18n.t('common.statusDisconnected');
state.log.push({
timestamp,
message: `Disconnected from server`,
level: 'error',
});
state.toastQueue.push(
makeToast({ title: i18n.t('toast.disconnected'), status: 'error' })
);
});
/**
* Invocation Started
*/
builder.addCase(invocationStarted, (state) => {
state.isProcessing = true;
state.isCancelable = true;
state.currentStatusHasSteps = false;
state.currentStatus = i18n.t('common.statusGenerating');
});
/**
* Generator Progress
*/
builder.addCase(generatorProgress, (state, action) => {
const {
step,
total_steps,
progress_image,
invocation,
graph_execution_state_id,
} = action.payload.data;
state.currentStatusHasSteps = true;
state.currentStep = step + 1; // TODO: step starts at -1, think this is a bug
state.totalSteps = total_steps;
state.progressImage = progress_image ?? null;
});
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data, timestamp } = action.payload;
state.isProcessing = false;
state.currentStep = 0;
state.totalSteps = 0;
state.progressImage = null;
state.currentStatus = i18n.t('common.statusProcessingComplete');
// TODO: handle logging for other invocation types
if (isImageOutput(data.result)) {
state.log.push({
timestamp,
message: `Generated: ${data.result.image.image_name}`,
level: 'info',
});
}
});
/**
* Invocation Error
*/
builder.addCase(invocationError, (state, action) => {
const { data, timestamp } = action.payload;
state.log.push({
timestamp,
message: `Server error: ${data.error}`,
level: 'error',
});
state.wasErrorSeen = true;
state.progressImage = null;
state.isProcessing = false;
state.toastQueue.push(
makeToast({ title: i18n.t('toast.serverError'), status: 'error' })
);
state.log.push({
timestamp,
message: `Server error: ${data.error}`,
level: 'error',
});
});
/**
* Session Invoked - PENDING
*/
builder.addCase(sessionInvoked.pending, (state) => {
state.currentStatus = i18n.t('common.statusPreparing');
});
/**
* Session Canceled
*/
builder.addCase(sessionCanceled.fulfilled, (state, action) => {
const { timestamp } = action.payload;
state.isProcessing = false;
state.isCancelable = false;
state.isCancelScheduled = false;
state.currentStep = 0;
state.totalSteps = 0;
state.progressImage = null;
state.toastQueue.push(
makeToast({ title: i18n.t('toast.canceled'), status: 'warning' })
);
state.log.push({
timestamp,
message: `Processing canceled`,
level: 'warning',
});
});
/**
* Initial Image Selected
*/
builder.addCase(initialImageSelected, (state) => {
state.toastQueue.push(makeToast(i18n.t('toast.sentToImageToImage')));
});
},
});
@ -306,6 +549,11 @@ export const {
setOpenModel,
setCancelType,
setCancelAfter,
cancelScheduled,
scheduledCancelAborted,
cancelTypeChanged,
subscribedNodeIdsSet,
shouldTransformUrlsChanged,
} = systemSlice.actions;
export default systemSlice.reducer;

View File

@ -34,6 +34,7 @@ import UnifiedCanvasWorkarea from 'features/ui/components/tabs/UnifiedCanvas/Uni
import { useTranslation } from 'react-i18next';
import { ResourceKey } from 'i18next';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import NodeEditor from 'features/nodes/components/NodeEditor';
export interface InvokeTabInfo {
id: InvokeTabName;
@ -45,7 +46,8 @@ const tabIconStyles: ChakraProps['sx'] = {
boxSize: 6,
};
const tabInfo: InvokeTabInfo[] = [
const buildTabs = (disabledTabs: InvokeTabName[]): InvokeTabInfo[] => {
const tabs: InvokeTabInfo[] = [
{
id: 'txt2img',
icon: <Icon as={MdTextFields} sx={tabIconStyles} />,
@ -64,7 +66,7 @@ const tabInfo: InvokeTabInfo[] = [
{
id: 'nodes',
icon: <Icon as={MdDeviceHub} sx={tabIconStyles} />,
workarea: <NodesWIP />,
workarea: <NodeEditor />,
},
{
id: 'postprocessing',
@ -76,7 +78,9 @@ const tabInfo: InvokeTabInfo[] = [
icon: <Icon as={MdFlashOn} sx={tabIconStyles} />,
workarea: <TrainingWIP />,
},
];
];
return tabs.filter((tab) => !disabledTabs.includes(tab.id));
};
export default function InvokeTabs() {
const activeTab = useAppSelector(activeTabIndexSelector);
@ -85,13 +89,10 @@ export default function InvokeTabs() {
(state: RootState) => state.lightbox.isLightboxOpen
);
const shouldPinGallery = useAppSelector(
(state: RootState) => state.ui.shouldPinGallery
);
const { shouldPinGallery, disabledTabs, shouldPinParametersPanel } =
useAppSelector((state: RootState) => state.ui);
const shouldPinParametersPanel = useAppSelector(
(state: RootState) => state.ui.shouldPinParametersPanel
);
const activeTabs = buildTabs(disabledTabs);
const { t } = useTranslation();
@ -142,7 +143,7 @@ export default function InvokeTabs() {
const tabs = useMemo(
() =>
tabInfo.map((tab) => (
activeTabs.map((tab) => (
<Tooltip
key={tab.id}
hasArrow
@ -157,13 +158,13 @@ export default function InvokeTabs() {
</Tab>
</Tooltip>
)),
[t]
[t, activeTabs]
);
const tabPanels = useMemo(
() =>
tabInfo.map((tab) => <TabPanel key={tab.id}>{tab.workarea}</TabPanel>),
[]
activeTabs.map((tab) => <TabPanel key={tab.id}>{tab.workarea}</TabPanel>),
[activeTabs]
);
return (
@ -174,6 +175,7 @@ export default function InvokeTabs() {
dispatch(setActiveTab(index));
}}
flexGrow={1}
isLazy
>
<TabList>{tabs}</TabList>
<TabPanels>{tabPanels}</TabPanels>

View File

@ -1,7 +1,7 @@
import { Box, BoxProps, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { setInitialImage } from 'features/parameters/store/generationSlice';
import { initialImageSelected } from 'features/parameters/store/generationSlice';
import {
activeTabNameSelector,
uiSelector,
@ -47,7 +47,7 @@ const InvokeWorkarea = (props: InvokeWorkareaProps) => {
const image = getImageByUuid(uuid);
if (!image) return;
if (activeTabName === 'img2img') {
dispatch(setInitialImage(image));
dispatch(initialImageSelected(image.uuid));
} else if (activeTabName === 'unifiedCanvas') {
dispatch(setInitialCanvasImage(image));
}

View File

@ -96,7 +96,6 @@ const ParametersPanel = ({ children }: ParametersPanelProps) => {
onClose={closeParametersPanel}
isPinned={shouldPinParametersPanel || isLightboxOpen}
sx={{
borderColor: 'base.700',
p: shouldPinParametersPanel ? 0 : 4,
bg: 'base.900',
}}

Some files were not shown because too many files have changed in this diff Show More