Merge branch 'main' of https://github.com/invoke-ai/InvokeAI into responsive-ui

This commit is contained in:
SammCheese 2023-04-22 08:27:00 +02:00
commit 3fb433cb91
No known key found for this signature in database
GPG Key ID: 28CFE2321A140BA1
329 changed files with 13081 additions and 1437 deletions

14
.github/CODEOWNERS vendored
View File

@ -1,16 +1,16 @@
# continuous integration
/.github/workflows/ @mauwii @lstein @blessedcoolant
/.github/workflows/ @lstein @blessedcoolant
# documentation
/docs/ @lstein @mauwii @tildebyte @blessedcoolant
/mkdocs.yml @lstein @mauwii @blessedcoolant
/docs/ @lstein @tildebyte @blessedcoolant
/mkdocs.yml @lstein @blessedcoolant
# nodes
/invokeai/app/ @Kyle0654 @blessedcoolant
# installation and configuration
/pyproject.toml @mauwii @lstein @blessedcoolant
/docker/ @mauwii @lstein @blessedcoolant
/pyproject.toml @lstein @blessedcoolant
/docker/ @lstein @blessedcoolant
/scripts/ @ebr @lstein
/installer/ @lstein @ebr
/invokeai/assets @lstein @ebr
@ -22,11 +22,11 @@
/invokeai/backend @blessedcoolant @psychedelicious @lstein
# generation, model management, postprocessing
/invokeai/backend @keturn @damian0815 @lstein @blessedcoolant @jpphoto
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
# front ends
/invokeai/frontend/CLI @lstein
/invokeai/frontend/install @lstein @ebr @mauwii
/invokeai/frontend/install @lstein @ebr
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/web @psychedelicious @blessedcoolant

2
.gitignore vendored
View File

@ -9,6 +9,8 @@ models/ldm/stable-diffusion-v1/model.ckpt
configs/models.user.yaml
config/models.user.yml
invokeai.init
.version
.last_model
# ignore the Anaconda/Miniconda installer used while building Docker image
anaconda.sh

View File

@ -148,6 +148,11 @@ not supported.
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
```
_For non-GPU systems:_
```terminal
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
```
_For Macintoshes, either Intel or M1/M2:_
```sh

View File

@ -32,7 +32,7 @@ turned on and off on the command line using `--nsfw_checker` and
At installation time, InvokeAI will ask whether the checker should be
activated by default (neither argument given on the command line). The
response is stored in the InvokeAI initialization file (usually
`.invokeai` in your home directory). You can change the default at any
`invokeai.init` in your home directory). You can change the default at any
time by opening this file in a text editor and commenting or
uncommenting the line `--nsfw_checker`.

View File

@ -3,6 +3,8 @@
import os
from argparse import Namespace
from invokeai.app.services.metadata import PngMetadataService, MetadataServiceBase
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@ -60,7 +62,9 @@ class ApiDependencies:
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
images = DiskImageStorage(f'{output_folder}/images')
metadata = PngMetadataService()
images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
@ -70,6 +74,7 @@ class ApiDependencies:
events=events,
latents=latents,
images=images,
metadata=metadata,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"

View File

@ -1,7 +1,19 @@
from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.metadata import InvokeAIMetadata
class ImageResponseMetadata(BaseModel):
"""An image's metadata. Used only in HTTP responses."""
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
invokeai: Optional[InvokeAIMetadata] = Field(
description="The image's InvokeAI-specific metadata"
)
class ImageResponse(BaseModel):
@ -11,4 +23,12 @@ class ImageResponse(BaseModel):
image_name: str = Field(description="The name of the image")
image_url: str = Field(description="The url of the image")
thumbnail_url: str = Field(description="The url of the image's thumbnail")
metadata: ImageMetadata = Field(description="The image's metadata")
metadata: ImageResponseMetadata = Field(description="The image's metadata")
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")

View File

@ -1,13 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
import json
import os
from typing import Any
import uuid
from fastapi import Path, Query, Request, UploadFile
from fastapi import HTTPException, Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.api.models.images import ImageResponse, ImageResponseMetadata
from invokeai.app.services.metadata import InvokeAIMetadata
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType
@ -15,70 +19,110 @@ from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of the image to get"),
):
) -> FileResponse | Response:
"""Gets a result"""
# TODO: This is not really secure at all. At least make sure only output results are served
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
return FileResponse(filename)
@images_router.get("/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail")
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=image_name
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail"
)
async def get_thumbnail(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of the image to get"),
):
) -> FileResponse | Response:
"""Gets a thumbnail"""
# TODO: This is not really secure at all. At least make sure only output results are served
filename = ApiDependencies.invoker.services.images.get_path(image_type, 'thumbnails/' + image_name)
return FileResponse(filename)
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=image_name, is_thumbnail=True
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.post(
"/uploads/",
operation_id="upload_image",
responses={
201: {"description": "The image was uploaded successfully"},
404: {"description": "Session not found"},
201: {
"description": "The image was uploaded successfully",
"model": ImageResponse,
},
415: {"description": "Image upload failed"},
},
status_code=201,
)
async def upload_image(file: UploadFile, request: Request):
async def upload_image(
file: UploadFile, request: Request, response: Response
) -> ImageResponse:
if not file.content_type.startswith("image"):
return Response(status_code=415)
raise HTTPException(status_code=415, detail="Not an image")
contents = await file.read()
try:
im = Image.open(contents)
img = Image.open(io.BytesIO(contents))
except:
# Error opening the image
return Response(status_code=415)
raise HTTPException(status_code=415, detail="Failed to read image")
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
return Response(
status_code=201,
headers={
"Location": request.url_for(
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
)
},
(image_path, thumbnail_path, ctime) = ApiDependencies.invoker.services.images.save(
ImageType.UPLOAD, filename, img
)
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
res = ImageResponse(
image_type=ImageType.UPLOAD,
image_name=filename,
image_url=f"api/v1/images/{ImageType.UPLOAD.value}/{filename}",
thumbnail_url=f"api/v1/images/{ImageType.UPLOAD.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
metadata=ImageResponseMetadata(
created=ctime,
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
response.status_code = 201
response.headers["Location"] = request.url_for(
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
)
return res
@images_router.get(
"/",
operation_id="list_images",
responses={200: {"model": PaginatedResults[ImageResponse]}},
)
async def list_images(
image_type: ImageType = Query(default=ImageType.RESULT, description="The type of images to get"),
image_type: ImageType = Query(
default=ImageType.RESULT, description="The type of images to get"
),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
"""Gets a list of images"""
result = ApiDependencies.invoker.services.images.list(
image_type, page, per_page
)
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
return result

View File

@ -13,6 +13,8 @@ from typing import (
from pydantic import BaseModel
from pydantic.fields import Field
from invokeai.app.services.metadata import PngMetadataService
from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@ -200,6 +202,8 @@ def invoke_cli():
events = EventServiceBase()
metadata = PngMetadataService()
output_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../../outputs")
)
@ -211,7 +215,8 @@ def invoke_cli():
model_manager=model_manager,
events=events,
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
images=DiskImageStorage(f'{output_folder}/images'),
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
metadata=metadata,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"

View File

@ -95,7 +95,7 @@ class UIConfig(TypedDict, total=False):
],
]
tags: List[str]
title: str
class CustomisedSchemaExtra(TypedDict):
ui: UIConfig

View File

@ -1,16 +1,17 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
from typing import Literal, Optional
import cv2 as cv
import numpy as np
import numpy.random
from PIL import Image, ImageOps
from pydantic import Field
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
from .image import ImageField, ImageOutput
from .baseinvocation import (
BaseInvocation,
InvocationConfig,
InvocationContext,
BaseInvocationOutput,
)
class IntCollectionOutput(BaseInvocationOutput):
@ -33,7 +34,9 @@ class RangeInvocation(BaseInvocation):
step: int = Field(default=1, description="The step of the range")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
return IntCollectionOutput(
collection=list(range(self.start, self.stop, self.step))
)
class RandomRangeInvocation(BaseInvocation):
@ -43,8 +46,19 @@ class RandomRangeInvocation(BaseInvocation):
# Inputs
low: int = Field(default=0, description="The inclusive low value")
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
high: int = Field(
default=np.iinfo(np.int32).max, description="The exclusive high value"
)
size: int = Field(default=1, description="The number of values to generate")
seed: Optional[int] = Field(
ge=0,
le=np.iinfo(np.int32).max,
description="The seed for the RNG",
default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
)
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size)))
rng = np.random.default_rng(self.seed)
return IntCollectionOutput(
collection=list(rng.integers(low=self.low, high=self.high, size=self.size))
)

View File

@ -9,7 +9,7 @@ from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
class CvInvocationConfig(BaseModel):
@ -56,7 +56,14 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_inpainted)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_inpainted, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image_inpainted,
)

View File

@ -9,13 +9,12 @@ from torch import Tensor
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.invocations.util.get_model import choose_model
from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..models.exceptions import CanceledException
from ..util.step_callback import diffusers_step_callback_adapter
from ..util.step_callback import stable_diffusion_step_callback
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
@ -58,28 +57,31 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Txt2Img(model).generate(
prompt=self.prompt,
step_callback=partial(self.dispatch_progress, context),
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt"}
), # Shorthand for passing all of the parameters above manually
@ -95,9 +97,18 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, generate_output.image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(
image_type, image_name, generate_output.image, metadata
)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=generate_output.image,
)
@ -117,20 +128,17 @@ class ImageToImageInvocation(TextToImageInvocation):
)
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
@ -145,15 +153,21 @@ class ImageToImageInvocation(TextToImageInvocation):
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Img2Img(model).generate(
prompt=self.prompt,
init_image=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
prompt=self.prompt,
init_image=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
@ -168,11 +182,19 @@ class ImageToImageInvocation(TextToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, result_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image,
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
@ -188,20 +210,17 @@ class InpaintInvocation(ImageToImageInvocation):
)
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
@ -220,15 +239,21 @@ class InpaintInvocation(ImageToImageInvocation):
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Inpaint(model).generate(
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
@ -243,7 +268,14 @@ class InpaintInvocation(ImageToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, result_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image,
)

View File

@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Literal, Optional
import numpy
@ -8,8 +7,12 @@ from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
class PILInvocationConfig(BaseModel):
@ -22,50 +25,73 @@ class PILInvocationConfig(BaseModel):
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
#fmt: off
# fmt: off
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
#fmt: on
width: Optional[int] = Field(default=None, description="The width of the image in pixels")
height: Optional[int] = Field(default=None, description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {
'required': [
'type',
'image',
]
"required": ["type", "image", "width", "height", "mode"]
}
def build_image_output(
image_type: ImageType, image_name: str, image: Image.Image
) -> ImageOutput:
"""Builds an ImageOutput and its ImageField"""
image_field = ImageField(
image_name=image_name,
image_type=image_type,
)
return ImageOutput(
image=image_field,
width=image.width,
height=image.height,
mode=image.mode,
)
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
#fmt: off
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
#fmt: on
# fmt: on
class Config:
schema_extra = {
'required': [
'type',
'mask',
"required": [
"type",
"mask",
]
}
# TODO: this isn't really necessary anymore
class LoadImageInvocation(BaseInvocation):
"""Load an image from a filename and provide it as output."""
#fmt: off
"""Load an image and provide it as output."""
# fmt: off
type: Literal["load_image"] = "load_image"
# Inputs
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
return ImageOutput(
image=ImageField(image_type=self.image_type, image_name=self.image_name)
image = context.services.images.get(self.image_type, self.image_name)
return build_image_output(
image_type=self.image_type,
image_name=self.image_name,
image=image,
)
@ -86,16 +112,17 @@ class ShowImageInvocation(BaseInvocation):
# TODO: how to handle failure?
return ImageOutput(
image=ImageField(
image_type=self.image.image_type, image_name=self.image.image_name
)
return build_image_output(
image_type=self.image.image_type,
image_name=self.image.image_name,
image=image,
)
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
"""Crops an image to a specified box. The box can be outside of the image."""
#fmt: off
# fmt: off
type: Literal["crop"] = "crop"
# Inputs
@ -104,7 +131,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@ -120,15 +147,23 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_crop)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_crop, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image_crop,
)
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
"""Pastes an image into another image."""
#fmt: off
# fmt: off
type: Literal["paste"] = "paste"
# Inputs
@ -137,7 +172,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get(
@ -170,21 +205,29 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, new_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, new_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=new_image,
)
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
"""Extracts the alpha channel of an image as a mask."""
#fmt: off
# fmt: off
type: Literal["tomask"] = "tomask"
# Inputs
image: ImageField = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get(
@ -199,21 +242,26 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_mask)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_mask, metadata)
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
class BlurInvocation(BaseInvocation, PILInvocationConfig):
"""Blurs an image"""
#fmt: off
# fmt: off
type: Literal["blur"] = "blur"
# Inputs
image: ImageField = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@ -231,22 +279,28 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, blur_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, blur_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=blur_image
)
class LerpInvocation(BaseInvocation, PILInvocationConfig):
"""Linear interpolation of all pixels of an image"""
#fmt: off
# fmt: off
type: Literal["lerp"] = "lerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@ -262,22 +316,28 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, lerp_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, lerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=lerp_image
)
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
"""Inverse linear interpolation of all pixels of an image"""
#fmt: off
# fmt: off
type: Literal["ilerp"] = "ilerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@ -298,7 +358,12 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, ilerp_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, ilerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=ilerp_image
)

View File

@ -5,9 +5,9 @@ from typing import Literal, Optional
from pydantic import BaseModel, Field
import torch
from invokeai.app.models.exceptions import CanceledException
from invokeai.app.invocations.util.get_model import choose_model
from invokeai.app.util.step_callback import diffusers_step_callback_adapter
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype
@ -19,7 +19,7 @@ from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationCont
import numpy as np
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from .image import ImageField, ImageOutput, build_image_output
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
@ -31,6 +31,8 @@ class LatentsField(BaseModel):
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
class Config:
schema_extra = {"required": ["latents_name"]}
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
@ -170,21 +172,14 @@ class TextToLatentsInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
model_info = choose_model(model_manager, self.model)
@ -231,8 +226,12 @@ class TextToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state)
self.dispatch_progress(context, source_node_id, state)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
@ -281,58 +280,12 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=model.device, dtype=latent.dtype
)
timesteps, _ = model.get_img2img_timesteps(
self.steps,
self.strength,
device=model.device,
)
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image."""
type: Literal["l2l"] = "l2l"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use")
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state)
self.dispatch_progress(context, source_node_id, state)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
@ -405,7 +358,14 @@ class LatentsToImageInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image
)

View File

@ -1,12 +1,11 @@
from datetime import datetime, timezone
from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
@ -44,7 +43,14 @@ class RestoreFaceInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@ -1,14 +1,12 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
class UpscaleInvocation(BaseInvocation):
@ -49,7 +47,14 @@ class UpscaleInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@ -1,11 +1,14 @@
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.backend.model_management.model_manager import ModelManager
def choose_model(model_manager: ModelManager, model_name: str):
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
if model_manager.valid_model(model_name):
return model_manager.get_model(model_name)
model = model_manager.get_model(model_name)
else:
print(f"* Warning: '{model_name}' is not a valid model name. Using default model instead.")
return model_manager.get_model()
model = model_manager.get_model()
print(
f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
)
return model

View File

@ -9,6 +9,14 @@ class ImageType(str, Enum):
UPLOAD = "uploads"
def is_image_type(obj):
try:
ImageType(obj)
except ValueError:
return False
return True
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
@ -18,9 +26,4 @@ class ImageField(BaseModel):
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {
"required": [
"image_type",
"image_name",
]
}
schema_extra = {"required": ["image_type", "image_name"]}

View File

@ -1,11 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
class ImageMetadata(BaseModel):
"""An image's metadata"""
timestamp: float = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# TODO: figure out metadata
sd_metadata: Optional[dict] = Field(default={}, description="The image's SD-specific metadata")

View File

@ -1,10 +1,9 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict, TypedDict
from typing import Any
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.util.misc import get_timestamp
ProgressImage = TypedDict(
"ProgressImage", {"dataURL": str, "width": int, "height": int}
)
class EventServiceBase:
session_event: str = "session_event"
@ -14,7 +13,8 @@ class EventServiceBase:
def dispatch(self, event_name: str, payload: Any) -> None:
pass
def __emit_session_event(self, event_name: str, payload: Dict) -> None:
def __emit_session_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.session_event,
payload=dict(event=event_name, data=payload),
@ -25,7 +25,8 @@ class EventServiceBase:
def emit_generator_progress(
self,
graph_execution_state_id: str,
invocation_id: str,
node: dict,
source_node_id: str,
progress_image: ProgressImage | None,
step: int,
total_steps: int,
@ -35,48 +36,60 @@ class EventServiceBase:
event_name="generator_progress",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
progress_image=progress_image,
node=node,
source_node_id=source_node_id,
progress_image=progress_image.dict() if progress_image is not None else None,
step=step,
total_steps=total_steps,
),
)
def emit_invocation_complete(
self, graph_execution_state_id: str, invocation_id: str, result: Dict
self,
graph_execution_state_id: str,
result: dict,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_complete",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
node=node,
source_node_id=source_node_id,
result=result,
),
)
def emit_invocation_error(
self, graph_execution_state_id: str, invocation_id: str, error: str
self,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
error: str,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_error",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
node=node,
source_node_id=source_node_id,
error=error,
),
)
def emit_invocation_started(
self, graph_execution_state_id: str, invocation_id: str
self, graph_execution_state_id: str, node: dict, source_node_id: str
) -> None:
"""Emitted when an invocation has started"""
self.__emit_session_event(
event_name="invocation_started",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
node=node,
source_node_id=source_node_id,
),
)
@ -84,5 +97,7 @@ class EventServiceBase:
"""Emitted when a session has completed all invocations"""
self.__emit_session_event(
event_name="graph_execution_state_complete",
payload=dict(graph_execution_state_id=graph_execution_state_id),
payload=dict(
graph_execution_state_id=graph_execution_state_id,
),
)

View File

@ -2,7 +2,6 @@
import copy
import itertools
import traceback
import uuid
from types import NoneType
from typing import (
@ -26,7 +25,6 @@ from ..invocations.baseinvocation import (
BaseInvocationOutput,
InvocationContext,
)
from .invocation_services import InvocationServices
class EdgeConnection(BaseModel):
@ -215,7 +213,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid.uuid4)
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
description="The nodes in this graph", default_factory=dict
@ -750,9 +748,7 @@ class Graph(BaseModel):
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(
description="The id of the execution state", default_factory=uuid.uuid4
)
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
# TODO: Store a reference to the graph instead of the actual graph?
graph: Graph = Field(description="The graph being executed")

View File

@ -1,24 +1,24 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import datetime
import os
from glob import glob
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from queue import Queue
from typing import Callable, Dict, List
from typing import Dict, List, Tuple
from PIL.Image import Image
import PIL.Image as PILImage
from pydantic import BaseModel
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.api.models.images import ImageResponse, ImageResponseMetadata
from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import (
InvokeAIMetadata,
MetadataServiceBase,
build_invokeai_metadata_pnginfo,
)
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.save_thumbnail import save_thumbnail
from invokeai.backend.image_util import PngWriter
from invokeai.app.util.misc import get_timestamp
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
class ImageStorageBase(ABC):
@ -26,12 +26,14 @@ class ImageStorageBase(ABC):
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> Image:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
"""Gets a paginated list of images."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@ -39,35 +41,51 @@ class ImageStorageBase(ABC):
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
"""Gets the path to an image or its thumbnail."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates an image path."""
pass
@abstractmethod
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
def save(
self,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> Tuple[str, str, int]:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image path, thumbnail path, and created timestamp."""
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
def create_name(self, context_id: str, node_id: str) -> str:
return f"{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png"
"""Creates a unique contextual image filename."""
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
class DiskImageStorage(ImageStorageBase):
"""Stores images on disk"""
__output_folder: str
__pngWriter: PngWriter
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, Image]
__max_cache_size: int
__metadata_service: MetadataServiceBase
def __init__(self, output_folder: str):
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
self.__output_folder = output_folder
self.__pngWriter = PngWriter(output_folder)
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
self.__metadata_service = metadata_service
Path(output_folder).mkdir(parents=True, exist_ok=True)
@ -100,6 +118,9 @@ class DiskImageStorage(ImageStorageBase):
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
invokeai_metadata = self.__metadata_service.get_metadata(img)
page_of_images.append(
ImageResponse(
image_type=image_type.value,
@ -107,11 +128,12 @@ class DiskImageStorage(ImageStorageBase):
# TODO: DiskImageStorage should not be building URLs...?
image_url=f"api/v1/images/{image_type.value}/{filename}",
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
metadata=ImageMetadata(
timestamp=os.path.getctime(path),
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
metadata=ImageResponseMetadata(
created=int(os.path.getctime(path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
)
@ -142,26 +164,50 @@ class DiskImageStorage(ImageStorageBase):
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if is_thumbnail:
path = os.path.join(
self.__output_folder, image_type, "thumbnails", image_name
self.__output_folder, image_type, "thumbnails", basename
)
else:
path = os.path.join(self.__output_folder, image_type, image_name)
path = os.path.join(self.__output_folder, image_type, basename)
return path
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
image_subpath = os.path.join(image_type, image_name)
self.__pngWriter.save_image_and_prompt_to_png(
image, "", image_subpath, None
) # TODO: just pass full path to png writer
save_thumbnail(
image=image,
filename=image_name,
path=os.path.join(self.__output_folder, image_type, "thumbnails"),
)
def validate_path(self, path: str) -> bool:
try:
os.stat(path)
return True
except Exception:
return False
def save(
self,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> Tuple[str, str, int]:
image_path = self.get_path(image_type, image_name)
# TODO: Reading the image and then saving it strips the metadata...
if metadata:
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path) # this saved image has an empty info
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
thumbnail_image = make_thumbnail(image)
thumbnail_image.save(thumbnail_path)
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image)
return (image_path, thumbnail_path, int(os.path.getctime(image_path)))
def delete(self, image_type: ImageType, image_name: str) -> None:
image_path = self.get_path(image_type, image_name)

View File

@ -1,30 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import time
from abc import ABC, abstractmethod
from queue import Queue
import time
from pydantic import BaseModel, Field
# TODO: make this serializable
class InvocationQueueItem:
# session_id: str
graph_execution_state_id: str
invocation_id: str
invoke_all: bool
timestamp: float
def __init__(
self,
# session_id: str,
graph_execution_state_id: str,
invocation_id: str,
invoke_all: bool = False,
):
# self.session_id = session_id
self.graph_execution_state_id = graph_execution_state_id
self.invocation_id = invocation_id
self.invoke_all = invoke_all
self.timestamp = time.time()
class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
invocation_id: str = Field(description="The ID of the node being invoked")
invoke_all: bool = Field(default=False)
timestamp: float = Field(default_factory=time.time)
class InvocationQueueABC(ABC):

View File

@ -1,4 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.backend import ModelManager
from .events import EventServiceBase
@ -14,6 +15,7 @@ class InvocationServices:
events: EventServiceBase
latents: LatentsStorageBase
images: ImageStorageBase
metadata: MetadataServiceBase
queue: InvocationQueueABC
model_manager: ModelManager
restoration: RestorationServices
@ -29,6 +31,7 @@ class InvocationServices:
events: EventServiceBase,
latents: LatentsStorageBase,
images: ImageStorageBase,
metadata: MetadataServiceBase,
queue: InvocationQueueABC,
graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
@ -39,6 +42,7 @@ class InvocationServices:
self.events = events
self.latents = latents
self.images = images
self.metadata = metadata
self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager

View File

@ -0,0 +1,96 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, TypedDict
from PIL import Image, PngImagePlugin
from pydantic import BaseModel
from invokeai.app.models.image import ImageType, is_image_type
class MetadataImageField(TypedDict):
"""Pydantic-less ImageField, used for metadata parsing."""
image_type: ImageType
image_name: str
class MetadataLatentsField(TypedDict):
"""Pydantic-less LatentsField, used for metadata parsing."""
latents_name: str
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[
str, str | int | float | bool | MetadataImageField | MetadataLatentsField
]
class InvokeAIMetadata(TypedDict, total=False):
"""InvokeAI-specific metadata format."""
session_id: Optional[str]
node: Optional[NodeMetadata]
def build_invokeai_metadata_pnginfo(
metadata: InvokeAIMetadata | None,
) -> PngImagePlugin.PngInfo:
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
pnginfo = PngImagePlugin.PngInfo()
if metadata is not None:
pnginfo.add_text("invokeai", json.dumps(metadata))
return pnginfo
class MetadataServiceBase(ABC):
@abstractmethod
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
pass
@abstractmethod
def build_metadata(
self, session_id: str, node: BaseModel
) -> InvokeAIMetadata | None:
"""Builds an InvokeAIMetadata object"""
pass
class PngMetadataService(MetadataServiceBase):
"""Handles loading and building metadata for images."""
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
def _load_metadata(self, image: Image.Image) -> dict | None:
"""Loads a specific info entry from a PIL Image."""
try:
info = image.info.get("invokeai")
if type(info) is not str:
return None
loaded_metadata = json.loads(info)
if type(loaded_metadata) is not dict:
return None
if len(loaded_metadata.items()) == 0:
return None
return loaded_metadata
except:
return None
def get_metadata(self, image: Image.Image) -> dict | None:
"""Retrieves an image's metadata as a dict"""
loaded_metadata = self._load_metadata(image)
return loaded_metadata
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
return metadata

View File

@ -43,10 +43,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
queue_item.invocation_id
)
# get the source node id to provide to clients (the prepared node id is not as useful)
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
# Send starting event
self.__invoker.services.events.emit_invocation_started(
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
node=invocation.dict(),
source_node_id=source_node_id
)
# Invoke
@ -75,7 +79,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
node=invocation.dict(),
source_node_id=source_node_id,
result=outputs.dict(),
)
@ -99,7 +104,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send error event
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
node=invocation.dict(),
source_node_id=source_node_id,
error=error,
)

View File

@ -35,7 +35,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._create_table()
def _create_table(self):
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
item TEXT,
@ -44,27 +45,34 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._cursor.execute(
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
)
self._conn.commit()
finally:
self._lock.release()
def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0]
return parse_raw_as(item_type, item)
def set(self, item: T):
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(item.json(),),
)
self._conn.commit()
finally:
self._lock.release()
self._on_changed(item)
def get(self, id: str) -> Union[T, None]:
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
@ -72,15 +80,19 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
return self._parse_item(result[0])
def delete(self, id: str):
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
self._conn.commit()
finally:
self._lock.release()
self._on_deleted(id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
(per_page, page * per_page),
@ -91,6 +103,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
@ -101,7 +115,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[T]:
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
(f"%{query}%", per_page, page * per_page),
@ -115,6 +130,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
(f"%{query}%",),
)
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1

View File

@ -0,0 +1,5 @@
import datetime
def get_timestamp():
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())

View File

@ -1,25 +0,0 @@
import os
from PIL import Image
def save_thumbnail(
image: Image.Image,
filename: str,
path: str,
size: int = 256,
) -> str:
"""
Saves a thumbnail of an image, returning its path.
"""
base_filename = os.path.splitext(filename)[0]
thumbnail_path = os.path.join(path, base_filename + ".webp")
if os.path.exists(thumbnail_path):
return thumbnail_path
image_copy = image.copy()
image_copy.thumbnail(size=(size, size))
image_copy.save(thumbnail_path, "WEBP")
return thumbnail_path

View File

@ -1,16 +1,41 @@
import torch
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.models.exceptions import CanceledException
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
def fast_latents_step_callback(
sample: torch.Tensor,
step: int,
steps: int,
id: str,
def stable_diffusion_step_callback(
context: InvocationContext,
intermediate_state: PipelineIntermediateState,
node: dict,
source_node_id: str,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be. Use
# that estimate if it is available.
if intermediate_state.predicted_original is not None:
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
# TODO: This does not seem to be needed any more?
# # txt2img provides a Tensor in the step_callback
# # img2img provides a PipelineIntermediateState
# if isinstance(sample, PipelineIntermediateState):
# # this was an img2img
# print('img2img')
# latents = sample.latents
# step = sample.step
# else:
# print('txt2img')
# latents = sample
# step = intermediate_state.step
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
@ -21,23 +46,10 @@ def fast_latents_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
id,
{"width": width, "height": height, "dataURL": dataURL},
step,
steps,
graph_execution_state_id=context.graph_execution_state_id,
node=node,
source_node_id=source_node_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step,
total_steps=node["steps"],
)
def diffusers_step_callback_adapter(*cb_args, **kwargs):
"""
txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
This adapter grabs the needed data and passes it along to the callback function.
"""
if isinstance(cb_args[0], PipelineIntermediateState):
progress_state: PipelineIntermediateState = cb_args[0]
return fast_latents_step_callback(
progress_state.latents, progress_state.step, **kwargs
)
else:
return fast_latents_step_callback(*cb_args, **kwargs)

View File

@ -0,0 +1,15 @@
import os
from PIL import Image
def get_thumbnail_name(image_name: str) -> str:
"""Formats given an image name, returns the appropriate thumbnail image name"""
thumbnail_name = os.path.splitext(image_name)[0] + ".webp"
return thumbnail_name
def make_thumbnail(image: Image.Image, size: int = 256) -> Image.Image:
"""Makes a thumbnail from a PIL Image"""
thumbnail = image.copy()
thumbnail.thumbnail(size=(size, size))
return thumbnail

View File

@ -57,7 +57,7 @@ class HuggingFaceConceptsLibrary(object):
self.concept_list.extend(list(local_concepts_to_add))
return self.concept_list
return self.concept_list
else:
elif Globals.internet_available is True:
try:
models = self.hf_api.list_models(
filter=ModelFilter(model_name="sd-concepts-library/")
@ -73,6 +73,8 @@ class HuggingFaceConceptsLibrary(object):
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
)
return self.concept_list
else:
return self.concept_list
def get_concept_model_path(self, concept_name: str) -> str:
"""

View File

@ -6,3 +6,5 @@ stats.html
index.html
.yarn/
*.scss
src/services/api/
src/services/fixtures/*

View File

@ -3,4 +3,8 @@ dist/
node_modules/
patches/
stats.html
index.html
.yarn/
*.scss
src/services/api/
src/services/fixtures/*

View File

@ -0,0 +1,87 @@
# Generated axios API client
- [Generated axios API client](#generated-axios-api-client)
- [Generation](#generation)
- [Generate the API client from the nodes web server](#generate-the-api-client-from-the-nodes-web-server)
- [Generate the API client from JSON](#generate-the-api-client-from-json)
- [Getting the JSON from the nodes web server](#getting-the-json-from-the-nodes-web-server)
- [Getting the JSON with a python script](#getting-the-json-with-a-python-script)
- [Generate the API client](#generate-the-api-client)
- [The generated client](#the-generated-client)
- [API client customisation](#api-client-customisation)
This API client is generated by an [openapi code generator](https://github.com/ferdikoomen/openapi-typescript-codegen).
All files in `invokeai/frontend/web/src/services/api/` are made by the generator.
## Generation
The axios client may be generated by from the OpenAPI schema from the nodes web server, or from JSON.
### Generate the API client from the nodes web server
We need to start the nodes web server, which serves the OpenAPI schema to the generator.
1. Start the nodes web server.
```bash
# from the repo root
python scripts/invoke-new.py --web
```
2. Generate the API client.
```bash
# from invokeai/frontend/web/
yarn api:web
```
### Generate the API client from JSON
The JSON can be acquired from the nodes web server, or with a python script.
#### Getting the JSON from the nodes web server
Start the nodes web server as described above, then download the file.
```bash
# from invokeai/frontend/web/
curl http://localhost:9090/openapi.json -o openapi.json
```
#### Getting the JSON with a python script
Run this python script from the repo root, so it can access the nodes server modules.
The script will output `openapi.json` in the repo root. Then we need to move it to `invokeai/frontend/web/`.
```bash
# from the repo root
python invokeai/app/util/generate_openapi_json.py
mv invokeai/app/util/openapi.json invokeai/frontend/web/services/fixtures/
```
#### Generate the API client
Now we can generate the API client from the JSON.
```bash
# from invokeai/frontend/web/
yarn api:file
```
## The generated client
The client will be written to `invokeai/frontend/web/services/api/`:
- `axios` client
- TS types
- An easily parseable schema, which we can use to generate UI
## API client customisation
The generator has a default `request.ts` file that implements a base `axios` client. The generated client uses this base client.
One shortcoming of this is base client is it does not provide response headers unless the response body is empty. To fix this, we provide our own lightly-patched `request.ts`.
To access the headers, call `getHeaders(response)` on any response from the generated api client. This function is exported from `invokeai/frontend/web/src/services/util/getHeaders.ts`.

View File

@ -0,0 +1,21 @@
# Events
Events via `socket.io`
## `actions.ts`
Redux actions for all socket events. Payloads all include a timestamp, and optionally some other data.
Any reducer (or middleware) can respond to the actions.
## `middleware.ts`
Redux middleware for events.
Handles dispatching the event actions. Only put logic here if it can't really go anywhere else.
For example, on connect we want to load images to the gallery if it's not populated. This requires dispatching a thunk, so we need to directly dispatch this in the middleware.
## `types.ts`
Hand-written types for the socket events. Cannot generate these from the server, but fortunately they are few and simple.

View File

@ -0,0 +1,17 @@
# Node Editor Design
WIP
nodes
everything in `src/features/nodes/`
have a look at `state.nodes.invocation`
- on socket connect, if no schema saved, fetch `localhost:9090/openapi.json`, save JSON to `state.nodes.schema`
- on fulfilled schema fetch, `parseSchema()` the schema. this outputs a `Record<string, Invocation>` which is saved to `state.nodes.invocations` - `Invocation` is like a template for the node
- when you add a node, the the `Invocation` template is passed to `InvocationComponent.tsx` to build the UI component for that node
- inputs/outputs have field types - and each field type gets an `FieldComponent` which includes a dispatcher to write state changes to redux `nodesSlice`
- `reactflow` sends changes to nodes/edges to redux
- to invoke, `buildNodesGraph()` state, then send this
- changed onClick Invoke button actions to build the schema, then when schema builds it dispatches the actual network request to create the session - see `session.ts`

View File

@ -0,0 +1,29 @@
# Package Scripts
WIP walkthrough of `package.json` scripts.
## `theme` & `theme:watch`
These run the Chakra CLI to generate types for the theme, or watch for code change and re-generate the types.
The CLI essentially monkeypatches Chakra's files in `node_modules`.
## `postinstall`
The `postinstall` script patches a few packages and runs the Chakra CLI to generate types for the theme.
### Patch `@chakra-ui/cli`
See: <https://github.com/chakra-ui/chakra-ui/issues/7394>
### Patch `redux-persist`
We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`.
`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it.
So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that.
### Patch `redux-deep-persist`
This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work.

View File

@ -1,10 +1,16 @@
# InvokeAI Web UI
- [InvokeAI Web UI](#invokeai-web-ui)
- [Stack](#stack)
- [Contributing](#contributing)
- [Dev Environment](#dev-environment)
- [Production builds](#production-builds)
The UI is a fairly straightforward Typescript React app. The only really fancy stuff is the Unified Canvas.
Code in `invokeai/frontend/web/` if you want to have a look.
## Details
## Stack
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a custom redux middleware to help).
@ -32,7 +38,7 @@ Start everything in dev mode:
1. Start the dev server: `yarn dev`
2. Start the InvokeAI UI per usual: `invokeai --web`
3. Point your browser to the dev server address e.g. `http://localhost:5173/`
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
### Production builds

View File

@ -1,6 +1,7 @@
import React, { PropsWithChildren } from 'react';
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
import { InvokeTabName } from 'features/ui/store/tabMap';
export {};
@ -64,9 +65,25 @@ declare module '@invoke-ai/invoke-ai-ui' {
declare class SettingsModal extends React.Component<SettingsModalProps> {
public constructor(props: SettingsModalProps);
}
declare class StatusIndicator extends React.Component<StatusIndicatorProps> {
public constructor(props: StatusIndicatorProps);
}
declare class ModelSelect extends React.Component<ModelSelectProps> {
public constructor(props: ModelSelectProps);
}
}
declare function Invoke(props: PropsWithChildren): JSX.Element;
interface InvokeProps extends PropsWithChildren {
apiUrl?: string;
disabledPanels?: string[];
disabledTabs?: InvokeTabName[];
token?: string;
shouldTransformUrls?: boolean;
}
declare function Invoke(props: InvokeProps): JSX.Element;
export {
ThemeChanger,
@ -74,5 +91,7 @@ export {
IAIPopover,
IAIIconButton,
SettingsModal,
StatusIndicator,
ModelSelect,
};
export = Invoke;

View File

@ -5,7 +5,10 @@
"scripts": {
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
"preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",
@ -41,9 +44,11 @@
"@chakra-ui/react": "^2.5.1",
"@chakra-ui/styled-system": "^2.6.1",
"@chakra-ui/theme-tools": "^2.0.16",
"@dagrejs/graphlib": "^2.1.12",
"@emotion/react": "^11.10.6",
"@emotion/styled": "^11.10.6",
"@reduxjs/toolkit": "^1.9.2",
"@fontsource/inter": "^4.5.15",
"@reduxjs/toolkit": "^1.9.3",
"chakra-ui-contextmenu": "^1.0.5",
"dateformat": "^5.0.3",
"formik": "^2.2.9",
@ -67,15 +72,17 @@
"react-redux": "^8.0.5",
"react-transition-group": "^4.4.5",
"react-zoom-pan-pinch": "^2.6.1",
"reactflow": "^11.7.0",
"redux-deep-persist": "^1.0.7",
"redux-dynamic-middlewares": "^2.2.0",
"redux-persist": "^6.0.0",
"socket.io-client": "^4.6.0",
"use-image": "^1.1.0",
"uuid": "^9.0.0"
},
"devDependencies": {
"@fontsource/inter": "^4.5.15",
"@types/dateformat": "^5.0.0",
"@types/lodash": "^4.14.194",
"@types/react": "^18.0.28",
"@types/react-dom": "^18.0.11",
"@types/react-transition-group": "^4.4.5",
@ -83,6 +90,7 @@
"@typescript-eslint/eslint-plugin": "^5.52.0",
"@typescript-eslint/parser": "^5.52.0",
"@vitejs/plugin-react-swc": "^3.2.0",
"axios": "^1.3.4",
"babel-plugin-transform-imports": "^2.0.0",
"concurrently": "^7.6.0",
"eslint": "^8.34.0",
@ -90,13 +98,17 @@
"eslint-plugin-prettier": "^4.2.1",
"eslint-plugin-react": "^7.32.2",
"eslint-plugin-react-hooks": "^4.6.0",
"form-data": "^4.0.0",
"husky": "^8.0.3",
"lint-staged": "^13.1.2",
"madge": "^6.0.0",
"openapi-types": "^12.1.0",
"openapi-typescript-codegen": "^0.23.0",
"postinstall-postinstall": "^2.1.0",
"prettier": "^2.8.4",
"rollup-plugin-visualizer": "^5.9.0",
"terser": "^5.16.4",
"typescript": "4.9.5",
"vite": "^4.1.2",
"vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^4.0.5",

View File

@ -53,6 +53,7 @@
"txt2img": "Text To Image",
"img2img": "Image To Image",
"unifiedCanvas": "Unified Canvas",
"linear": "Linear",
"nodes": "Nodes",
"postprocessing": "Post Processing",
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
@ -525,6 +526,10 @@
"resetComplete": "Web UI has been reset. Refresh the page to reload."
},
"toast": {
"serverError": "Server Error",
"disconnected": "Disconnected from Server",
"connected": "Connected to Server",
"canceled": "Processing Canceled",
"tempFoldersEmptied": "Temp Folder Emptied",
"uploadFailed": "Upload failed",
"uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",

View File

@ -13,16 +13,42 @@ import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
import Lightbox from 'features/lightbox/components/Lightbox';
import { useAppSelector } from './storeHooks';
import { useAppDispatch, useAppSelector } from './storeHooks';
import { PropsWithChildren, useEffect } from 'react';
import { setDisabledPanels, setDisabledTabs } from 'features/ui/store/uiSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { shouldTransformUrlsChanged } from 'features/system/store/systemSlice';
keepGUIAlive();
const App = (props: PropsWithChildren) => {
interface Props extends PropsWithChildren {
options: {
disabledPanels: string[];
disabledTabs: InvokeTabName[];
shouldTransformUrls?: boolean;
};
}
const App = (props: Props) => {
useToastWatcher();
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
const { setColorMode } = useColorMode();
const dispatch = useAppDispatch();
useEffect(() => {
dispatch(setDisabledPanels(props.options.disabledPanels));
}, [dispatch, props.options.disabledPanels]);
useEffect(() => {
dispatch(setDisabledTabs(props.options.disabledTabs));
}, [dispatch, props.options.disabledTabs]);
useEffect(() => {
dispatch(
shouldTransformUrlsChanged(Boolean(props.options.shouldTransformUrls))
);
}, [dispatch, props.options.shouldTransformUrls]);
useEffect(() => {
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');

View File

@ -14,6 +14,8 @@
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { ImageMetadata, ImageType } from 'services/api';
import { AnyInvocation } from 'services/events/types';
/**
* TODO:
@ -113,7 +115,7 @@ export declare type Metadata = SystemGenerationMetadata & {
};
// An Image has a UUID, url, modified timestamp, width, height and maybe metadata
export declare type Image = {
export declare type _Image = {
uuid: string;
url: string;
thumbnail: string;
@ -124,11 +126,23 @@ export declare type Image = {
category: GalleryCategory;
isBase64?: boolean;
dreamPrompt?: 'string';
name?: string;
};
/**
* ResultImage
*/
export declare type Image = {
name: string;
type: ImageType;
url: string;
thumbnail: string;
metadata: ImageMetadata;
};
// GalleryImages is an array of Image.
export declare type GalleryImages = {
images: Array<Image>;
images: Array<_Image>;
};
/**
@ -275,7 +289,7 @@ export declare type SystemStatusResponse = SystemStatus;
export declare type SystemConfigResponse = SystemConfig;
export declare type ImageResultResponse = Omit<Image, 'uuid'> & {
export declare type ImageResultResponse = Omit<_Image, 'uuid'> & {
boundingBox?: IRect;
generationMode: InvokeTabName;
};
@ -296,7 +310,7 @@ export declare type ErrorResponse = {
};
export declare type GalleryImagesResponse = {
images: Array<Omit<Image, 'uuid'>>;
images: Array<Omit<_Image, 'uuid'>>;
areMoreImagesAvailable: boolean;
category: GalleryCategory;
};

View File

@ -20,6 +20,7 @@ export const readinessSelector = createSelector(
seedWeights,
initialImage,
seed,
isImageToImageEnabled,
} = generation;
const { isProcessing, isConnected } = system;
@ -33,7 +34,7 @@ export const readinessSelector = createSelector(
reasonsWhyNotReady.push('Missing prompt');
}
if (activeTabName === 'img2img' && !initialImage) {
if (isImageToImageEnabled && !initialImage) {
isReady = false;
reasonsWhyNotReady.push('No initial image selected');
}

View File

@ -13,9 +13,13 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
export const generateImage = createAction<InvokeTabName>(
'socketio/generateImage'
);
export const runESRGAN = createAction<InvokeAI.Image>('socketio/runESRGAN');
export const runFacetool = createAction<InvokeAI.Image>('socketio/runFacetool');
export const deleteImage = createAction<InvokeAI.Image>('socketio/deleteImage');
export const runESRGAN = createAction<InvokeAI._Image>('socketio/runESRGAN');
export const runFacetool = createAction<InvokeAI._Image>(
'socketio/runFacetool'
);
export const deleteImage = createAction<InvokeAI._Image>(
'socketio/deleteImage'
);
export const requestImages = createAction<GalleryCategory>(
'socketio/requestImages'
);

View File

@ -91,7 +91,7 @@ const makeSocketIOEmitters = (
})
);
},
emitRunESRGAN: (imageToProcess: InvokeAI.Image) => {
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
const {
@ -119,7 +119,7 @@ const makeSocketIOEmitters = (
})
);
},
emitRunFacetool: (imageToProcess: InvokeAI.Image) => {
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
const {
@ -150,7 +150,7 @@ const makeSocketIOEmitters = (
})
);
},
emitDeleteImage: (imageToDelete: InvokeAI.Image) => {
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
const { url, uuid, category, thumbnail } = imageToDelete;
dispatch(removeImage(imageToDelete));
socketio.emit('deleteImage', url, thumbnail, uuid, category);

View File

@ -34,8 +34,9 @@ import type { RootState } from 'app/store';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import {
clearInitialImage,
initialImageSelected,
setInfillMethod,
setInitialImage,
// setInitialImage,
setMaskPath,
} from 'features/parameters/store/generationSlice';
import { tabMap } from 'features/ui/store/tabMap';
@ -142,15 +143,17 @@ const makeSocketIOListeners = (
}
}
if (shouldLoopback) {
const activeTabName = tabMap[activeTab];
switch (activeTabName) {
case 'img2img': {
dispatch(setInitialImage(newImage));
break;
}
}
}
// TODO: fix
// if (shouldLoopback) {
// const activeTabName = tabMap[activeTab];
// switch (activeTabName) {
// case 'img2img': {
// dispatch(initialImageSelected(newImage.uuid));
// // dispatch(setInitialImage(newImage));
// break;
// }
// }
// }
dispatch(clearIntermediateImage());
@ -262,7 +265,7 @@ const makeSocketIOListeners = (
*/
// Generate a UUID for each image
const preparedImages = images.map((image): InvokeAI.Image => {
const preparedImages = images.map((image): InvokeAI._Image => {
return {
uuid: uuidv4(),
...image,
@ -334,7 +337,7 @@ const makeSocketIOListeners = (
if (
initialImage === url ||
(initialImage as InvokeAI.Image)?.url === url
(initialImage as InvokeAI._Image)?.url === url
) {
dispatch(clearInitialImage());
}

View File

@ -29,6 +29,8 @@ export const socketioMiddleware = () => {
path: `${window.location.pathname}socket.io`,
});
socketio.disconnect();
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {

View File

@ -2,18 +2,32 @@ import { combineReducers, configureStore } from '@reduxjs/toolkit';
import { persistReducer } from 'redux-persist';
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { getPersistConfig } from 'redux-deep-persist';
import canvasReducer from 'features/canvas/store/canvasSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
import resultsReducer from 'features/gallery/store/resultsSlice';
import uploadsReducer from 'features/gallery/store/uploadsSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice';
import uiReducer from 'features/ui/store/uiSlice';
import modelsReducer from 'features/system/store/modelSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import { socketioMiddleware } from './socketio/middleware';
import { socketMiddleware } from 'services/events/middleware';
import { canvasBlacklist } from 'features/canvas/store/canvasPersistBlacklist';
import { galleryBlacklist } from 'features/gallery/store/galleryPersistBlacklist';
import { generationBlacklist } from 'features/parameters/store/generationPersistBlacklist';
import { lightboxBlacklist } from 'features/lightbox/store/lightboxPersistBlacklist';
import { modelsBlacklist } from 'features/system/store/modelsPersistBlacklist';
import { nodesBlacklist } from 'features/nodes/store/nodesPersistBlacklist';
import { postprocessingBlacklist } from 'features/parameters/store/postprocessingPersistBlacklist';
import { systemBlacklist } from 'features/system/store/systemPersistsBlacklist';
import { uiBlacklist } from 'features/ui/store/uiPersistBlacklist';
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
@ -29,49 +43,18 @@ import { socketioMiddleware } from './socketio/middleware';
* The necesssary nested persistors with blacklists are configured below.
*/
const canvasBlacklist = [
'cursorPosition',
'isCanvasInitialized',
'doesCanvasNeedScaling',
].map((blacklistItem) => `canvas.${blacklistItem}`);
const systemBlacklist = [
'currentIteration',
'currentStatus',
'currentStep',
'isCancelable',
'isConnected',
'isESRGANAvailable',
'isGFPGANAvailable',
'isProcessing',
'socketId',
'totalIterations',
'totalSteps',
'openModel',
'cancelOptions.cancelAfter',
].map((blacklistItem) => `system.${blacklistItem}`);
const galleryBlacklist = [
'categories',
'currentCategory',
'currentImage',
'currentImageUuid',
'shouldAutoSwitchToNewImages',
'intermediateImage',
].map((blacklistItem) => `gallery.${blacklistItem}`);
const lightboxBlacklist = ['isLightboxOpen'].map(
(blacklistItem) => `lightbox.${blacklistItem}`
);
const rootReducer = combineReducers({
generation: generationReducer,
postprocessing: postprocessingReducer,
gallery: galleryReducer,
system: systemReducer,
canvas: canvasReducer,
ui: uiReducer,
gallery: galleryReducer,
generation: generationReducer,
lightbox: lightboxReducer,
models: modelsReducer,
nodes: nodesReducer,
postprocessing: postprocessingReducer,
results: resultsReducer,
system: systemReducer,
ui: uiReducer,
uploads: uploadsReducer,
});
const rootPersistConfig = getPersistConfig({
@ -80,23 +63,40 @@ const rootPersistConfig = getPersistConfig({
rootReducer,
blacklist: [
...canvasBlacklist,
...systemBlacklist,
...galleryBlacklist,
...generationBlacklist,
...lightboxBlacklist,
...modelsBlacklist,
...nodesBlacklist,
...postprocessingBlacklist,
// ...resultsBlacklist,
'results',
...systemBlacklist,
...uiBlacklist,
// ...uploadsBlacklist,
'uploads',
],
debounce: 300,
});
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
// Continue with store setup
// TODO: rip the old middleware out when nodes is complete
export function buildMiddleware() {
if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
return socketMiddleware();
} else {
return socketioMiddleware();
}
}
export const store = configureStore({
reducer: persistedReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
immutableCheck: false,
serializableCheck: false,
}).concat(socketioMiddleware()),
}).concat(dynamicMiddlewares),
devTools: {
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
actionsDenylist: [

View File

@ -0,0 +1,8 @@
import { createAsyncThunk } from '@reduxjs/toolkit';
import { AppDispatch, RootState } from './store';
// https://redux-toolkit.js.org/usage/usage-with-typescript#defining-a-pre-typed-createasyncthunk
export const createAppAsyncThunk = createAsyncThunk.withTypes<{
state: RootState;
dispatch: AppDispatch;
}>();

View File

@ -44,12 +44,10 @@ export type IAIFullSliderProps = {
inputReadOnly?: boolean;
withReset?: boolean;
handleReset?: () => void;
isResetDisabled?: boolean;
isSliderDisabled?: boolean;
isInputDisabled?: boolean;
tooltipSuffix?: string;
hideTooltip?: boolean;
isCompact?: boolean;
isDisabled?: boolean;
sliderFormControlProps?: FormControlProps;
sliderFormLabelProps?: FormLabelProps;
sliderMarkProps?: Omit<SliderMarkProps, 'value'>;
@ -80,10 +78,8 @@ const IAISlider = (props: IAIFullSliderProps) => {
withReset = false,
hideTooltip = false,
isCompact = false,
isDisabled = false,
handleReset,
isResetDisabled,
isSliderDisabled,
isInputDisabled,
sliderFormControlProps,
sliderFormLabelProps,
sliderMarkProps,
@ -149,6 +145,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
}
: {}
}
isDisabled={isDisabled}
{...sliderFormControlProps}
>
<FormLabel {...sliderFormLabelProps} mb={-1}>
@ -166,15 +163,13 @@ const IAISlider = (props: IAIFullSliderProps) => {
onMouseEnter={() => setShowTooltip(true)}
onMouseLeave={() => setShowTooltip(false)}
focusThumbOnChange={false}
isDisabled={isSliderDisabled}
// width={width}
isDisabled={isDisabled}
{...rest}
>
{withSliderMarks && (
<>
<SliderMark
value={min}
// insetInlineStart={0}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
@ -185,7 +180,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
</SliderMark>
<SliderMark
value={max}
// insetInlineEnd={0}
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
@ -221,7 +215,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
value={localInputValue}
onChange={handleInputChange}
onBlur={handleInputBlur}
isDisabled={isInputDisabled}
{...sliderNumberInputProps}
>
<NumberInputField
@ -246,8 +239,8 @@ const IAISlider = (props: IAIFullSliderProps) => {
aria-label={t('accessibility.reset')}
tooltip="Reset"
icon={<BiReset />}
isDisabled={isDisabled}
onClick={handleResetDisable}
isDisabled={isResetDisabled}
{...sliderIAIIconButtonProps}
/>
)}

View File

@ -0,0 +1,79 @@
import { Badge, Box, ButtonGroup, Flex } from '@chakra-ui/react';
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { useCallback } from 'react';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaUndo, FaUpload } from 'react-icons/fa';
import { useTranslation } from 'react-i18next';
import { Image } from 'app/invokeai';
type ImageToImageOverlayProps = {
setIsLoaded: (isLoaded: boolean) => void;
image: Image;
};
const ImageToImageOverlay = ({
setIsLoaded,
image,
}: ImageToImageOverlayProps) => {
const isImageToImageEnabled = useAppSelector(
(state: RootState) => state.generation.isImageToImageEnabled
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleResetInitialImage = useCallback(() => {
dispatch(clearInitialImage());
setIsLoaded(false);
}, [dispatch, setIsLoaded]);
return (
<Box
sx={{
top: 0,
left: 0,
w: 'full',
h: 'full',
position: 'absolute',
}}
>
<ButtonGroup
sx={{
position: 'absolute',
top: 0,
right: 0,
p: 2,
}}
>
<IAIIconButton
size="sm"
isDisabled={!isImageToImageEnabled}
icon={<FaUndo />}
aria-label={t('accessibility.reset')}
onClick={handleResetInitialImage}
/>
<IAIIconButton
size="sm"
isDisabled={!isImageToImageEnabled}
icon={<FaUpload />}
aria-label={t('common.upload')}
/>
</ButtonGroup>
<Flex
sx={{
position: 'absolute',
bottom: 0,
left: 0,
p: 2,
alignItems: 'flex-start',
}}
>
<Badge variant="solid" colorScheme="base">
{image.metadata?.width} × {image.metadata?.height}
</Badge>
</Flex>
</Box>
);
};
export default ImageToImageOverlay;

View File

@ -2,7 +2,6 @@ import { Box, useToast } from '@chakra-ui/react';
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import useImageUploader from 'common/hooks/useImageUploader';
import { uploadImage } from 'features/gallery/store/thunks/uploadImage';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { ResourceKey } from 'i18next';
import {
@ -15,6 +14,7 @@ import {
} from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { imageUploaded } from 'services/thunks/image';
import ImageUploadOverlay from './ImageUploadOverlay';
type ImageUploaderProps = {
@ -49,7 +49,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
const fileAcceptedCallback = useCallback(
async (file: File) => {
dispatch(uploadImage({ imageFile: file }));
dispatch(imageUploaded({ formData: { file } }));
},
[dispatch]
);
@ -124,7 +124,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
return;
}
dispatch(uploadImage({ imageFile: file }));
dispatch(imageUploaded({ formData: { file } }));
};
document.addEventListener('paste', pasteImageListener);
return () => {

View File

@ -0,0 +1,12 @@
import { Flex, Icon } from '@chakra-ui/react';
import { FaImage } from 'react-icons/fa';
const SelectImagePlaceholder = () => {
return (
<Flex sx={{ h: 36, alignItems: 'center', justifyContent: 'center' }}>
<Icon color="base.400" boxSize={32} as={FaImage}></Icon>
</Flex>
);
};
export default SelectImagePlaceholder;

View File

@ -1,27 +1,160 @@
import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import WorkInProgress from './WorkInProgress';
// import WorkInProgress from './WorkInProgress';
// import ReactFlow, {
// applyEdgeChanges,
// applyNodeChanges,
// Background,
// Controls,
// Edge,
// Handle,
// Node,
// NodeTypes,
// OnEdgesChange,
// OnNodesChange,
// Position,
// } from 'reactflow';
export default function NodesWIP() {
const { t } = useTranslation();
return (
<WorkInProgress>
<Flex
sx={{
flexDirection: 'column',
alignItems: 'center',
justifyContent: 'center',
w: '100%',
h: '100%',
gap: 4,
textAlign: 'center',
}}
>
<Heading>{t('common.nodes')}</Heading>
<VStack maxW="50rem" gap={4}>
<Text>{t('common.nodesDesc')}</Text>
</VStack>
</Flex>
</WorkInProgress>
);
}
// import 'reactflow/dist/style.css';
// import {
// Fragment,
// FunctionComponent,
// ReactNode,
// useCallback,
// useMemo,
// useState,
// } from 'react';
// import { OpenAPIV3 } from 'openapi-types';
// import { filter, map, reduce } from 'lodash';
// import {
// Box,
// Flex,
// FormControl,
// FormLabel,
// Input,
// Select,
// Switch,
// Text,
// NumberInput,
// NumberInputField,
// NumberInputStepper,
// NumberIncrementStepper,
// NumberDecrementStepper,
// Tooltip,
// chakra,
// Badge,
// Heading,
// VStack,
// HStack,
// Menu,
// MenuButton,
// MenuList,
// MenuItem,
// MenuItemOption,
// MenuGroup,
// MenuOptionGroup,
// MenuDivider,
// IconButton,
// } from '@chakra-ui/react';
// import { FaPlus } from 'react-icons/fa';
// import {
// FIELD_NAMES as FIELD_NAMES,
// FIELDS,
// INVOCATION_NAMES as INVOCATION_NAMES,
// INVOCATIONS,
// } from 'features/nodeEditor/constants';
// console.log('invocations', INVOCATIONS);
// const nodeTypes = reduce(
// INVOCATIONS,
// (acc, val, key) => {
// acc[key] = val.component;
// return acc;
// },
// {} as NodeTypes
// );
// console.log('nodeTypes', nodeTypes);
// // make initial nodes one of every node for now
// let n = 0;
// const initialNodes = map(INVOCATIONS, (i) => ({
// id: i.type,
// type: i.title,
// position: { x: (n += 20), y: (n += 20) },
// data: {},
// }));
// console.log('initialNodes', initialNodes);
// export default function NodesWIP() {
// const [nodes, setNodes] = useState<Node[]>([]);
// const [edges, setEdges] = useState<Edge[]>([]);
// const onNodesChange: OnNodesChange = useCallback(
// (changes) => setNodes((nds) => applyNodeChanges(changes, nds)),
// []
// );
// const onEdgesChange: OnEdgesChange = useCallback(
// (changes) => setEdges((eds: Edge[]) => applyEdgeChanges(changes, eds)),
// []
// );
// return (
// <Box
// sx={{
// position: 'relative',
// width: 'full',
// height: 'full',
// borderRadius: 'md',
// }}
// >
// <ReactFlow
// nodeTypes={nodeTypes}
// nodes={nodes}
// edges={edges}
// onNodesChange={onNodesChange}
// onEdgesChange={onEdgesChange}
// >
// <Background />
// <Controls />
// </ReactFlow>
// <HStack sx={{ position: 'absolute', top: 2, right: 2 }}>
// {FIELD_NAMES.map((field) => (
// <Badge
// key={field}
// colorScheme={FIELDS[field].color}
// sx={{ userSelect: 'none' }}
// >
// {field}
// </Badge>
// ))}
// </HStack>
// <Menu>
// <MenuButton
// as={IconButton}
// aria-label="Options"
// icon={<FaPlus />}
// sx={{ position: 'absolute', top: 2, left: 2 }}
// />
// <MenuList>
// {INVOCATION_NAMES.map((name) => {
// const invocation = INVOCATIONS[name];
// return (
// <Tooltip
// key={name}
// label={invocation.description}
// placement="end"
// hasArrow
// >
// <MenuItem>{invocation.title}</MenuItem>
// </Tooltip>
// );
// })}
// </MenuList>
// </Menu>
// </Box>
// );
// }
export default {};

View File

@ -14,6 +14,8 @@ const WorkInProgress = (props: WorkInProgressProps) => {
width: '100%',
height: '100%',
bg: 'base.850',
borderRadius: 'base',
position: 'relative',
}}
>
{children}

View File

@ -0,0 +1,119 @@
/**
* PARTIAL ZOD IMPLEMENTATION
*
* doesn't work well bc like most validators, zod is not built to skip invalid values.
* it mostly works but just seems clearer and simpler to manually parse for now.
*
* in the future it would be really nice if we could use zod for some things:
* - zodios (axios + zod): https://github.com/ecyrbe/zodios
* - openapi to zodios: https://github.com/astahmer/openapi-zod-client
*/
// import { z } from 'zod';
// const zMetadataStringField = z.string();
// export type MetadataStringField = z.infer<typeof zMetadataStringField>;
// const zMetadataIntegerField = z.number().int();
// export type MetadataIntegerField = z.infer<typeof zMetadataIntegerField>;
// const zMetadataFloatField = z.number();
// export type MetadataFloatField = z.infer<typeof zMetadataFloatField>;
// const zMetadataBooleanField = z.boolean();
// export type MetadataBooleanField = z.infer<typeof zMetadataBooleanField>;
// const zMetadataImageField = z.object({
// image_type: z.union([
// z.literal('results'),
// z.literal('uploads'),
// z.literal('intermediates'),
// ]),
// image_name: z.string().min(1),
// });
// export type MetadataImageField = z.infer<typeof zMetadataImageField>;
// const zMetadataLatentsField = z.object({
// latents_name: z.string().min(1),
// });
// export type MetadataLatentsField = z.infer<typeof zMetadataLatentsField>;
// /**
// * zod Schema for any node field. Use a `transform()` to manually parse, skipping invalid values.
// */
// const zAnyMetadataField = z.any().transform((val, ctx) => {
// // Grab the field name from the path
// const fieldName = String(ctx.path[ctx.path.length - 1]);
// // `id` and `type` must be strings if they exist
// if (['id', 'type'].includes(fieldName)) {
// const reservedStringPropertyResult = zMetadataStringField.safeParse(val);
// if (reservedStringPropertyResult.success) {
// return reservedStringPropertyResult.data;
// }
// return;
// }
// // Parse the rest of the fields, only returning the data if the parsing is successful
// const stringFieldResult = zMetadataStringField.safeParse(val);
// if (stringFieldResult.success) {
// return stringFieldResult.data;
// }
// const integerFieldResult = zMetadataIntegerField.safeParse(val);
// if (integerFieldResult.success) {
// return integerFieldResult.data;
// }
// const floatFieldResult = zMetadataFloatField.safeParse(val);
// if (floatFieldResult.success) {
// return floatFieldResult.data;
// }
// const booleanFieldResult = zMetadataBooleanField.safeParse(val);
// if (booleanFieldResult.success) {
// return booleanFieldResult.data;
// }
// const imageFieldResult = zMetadataImageField.safeParse(val);
// if (imageFieldResult.success) {
// return imageFieldResult.data;
// }
// const latentsFieldResult = zMetadataImageField.safeParse(val);
// if (latentsFieldResult.success) {
// return latentsFieldResult.data;
// }
// });
// /**
// * The node metadata schema.
// */
// const zNodeMetadata = z.object({
// session_id: z.string().min(1).optional(),
// node: z.record(z.string().min(1), zAnyMetadataField).optional(),
// });
// export type NodeMetadata = z.infer<typeof zNodeMetadata>;
// const zMetadata = z.object({
// invokeai: zNodeMetadata.optional(),
// 'sd-metadata': z.record(z.string().min(1), z.any()).optional(),
// });
// export type Metadata = z.infer<typeof zMetadata>;
// export const parseMetadata = (
// metadata: Record<string, any>
// ): Metadata | undefined => {
// const result = zMetadata.safeParse(metadata);
// if (!result.success) {
// console.log(result.error.issues);
// return;
// }
// return result.data;
// };
export default {};

View File

@ -0,0 +1,6 @@
import dateFormat from 'dateformat';
/**
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
*/
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');

View File

@ -0,0 +1,28 @@
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { OpenAPI } from 'services/api';
export const getUrlAlt = (url: string, shouldTransformUrls: boolean) => {
if (OpenAPI.BASE && shouldTransformUrls) {
return [OpenAPI.BASE, url].join('/');
}
return url;
};
export const useGetUrl = () => {
const shouldTransformUrls = useAppSelector(
(state: RootState) => state.system.shouldTransformUrls
);
return {
shouldTransformUrls,
getUrl: (url?: string) => {
if (OpenAPI.BASE && shouldTransformUrls) {
return [OpenAPI.BASE, url].join('/');
}
return url;
},
};
};

View File

@ -0,0 +1,169 @@
import { forEach, size } from 'lodash';
import { ImageField, LatentsField } from 'services/api';
const OBJECT_TYPESTRING = '[object Object]';
const STRING_TYPESTRING = '[object String]';
const NUMBER_TYPESTRING = '[object Number]';
const BOOLEAN_TYPESTRING = '[object Boolean]';
const ARRAY_TYPESTRING = '[object Array]';
const isObject = (obj: unknown): obj is Record<string | number, any> =>
Object.prototype.toString.call(obj) === OBJECT_TYPESTRING;
const isString = (obj: unknown): obj is string =>
Object.prototype.toString.call(obj) === STRING_TYPESTRING;
const isNumber = (obj: unknown): obj is number =>
Object.prototype.toString.call(obj) === NUMBER_TYPESTRING;
const isBoolean = (obj: unknown): obj is boolean =>
Object.prototype.toString.call(obj) === BOOLEAN_TYPESTRING;
const isArray = (obj: unknown): obj is Array<any> =>
Object.prototype.toString.call(obj) === ARRAY_TYPESTRING;
const parseImageField = (imageField: unknown): ImageField | undefined => {
// Must be an object
if (!isObject(imageField)) {
return;
}
// An ImageField must have both `image_name` and `image_type`
if (!('image_name' in imageField && 'image_type' in imageField)) {
return;
}
// An ImageField's `image_type` must be one of the allowed values
if (
!['results', 'uploads', 'intermediates'].includes(imageField.image_type)
) {
return;
}
// An ImageField's `image_name` must be a string
if (typeof imageField.image_name !== 'string') {
return;
}
// Build a valid ImageField
return {
image_type: imageField.image_type,
image_name: imageField.image_name,
};
};
const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
// Must be an object
if (!isObject(latentsField)) {
return;
}
// A LatentsField must have a `latents_name`
if (!('latents_name' in latentsField)) {
return;
}
// A LatentsField's `latents_name` must be a string
if (typeof latentsField.latents_name !== 'string') {
return;
}
// Build a valid LatentsField
return {
latents_name: latentsField.latents_name,
};
};
type NodeMetadata = {
[key: string]: string | number | boolean | ImageField | LatentsField;
};
type InvokeAIMetadata = {
session_id?: string;
node?: NodeMetadata;
};
export const parseNodeMetadata = (
nodeMetadata: Record<string | number, any>
): NodeMetadata | undefined => {
if (!isObject(nodeMetadata)) {
return;
}
const parsed: NodeMetadata = {};
forEach(nodeMetadata, (nodeItem, nodeKey) => {
// `id` and `type` must be strings if they are present
if (['id', 'type'].includes(nodeKey)) {
if (isString(nodeItem)) {
parsed[nodeKey] = nodeItem;
}
return;
}
// the only valid object types are ImageField and LatentsField
if (isObject(nodeItem)) {
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
const imageField = parseImageField(nodeItem);
if (imageField) {
parsed[nodeKey] = imageField;
}
return;
}
if ('latents_name' in nodeItem) {
const latentsField = parseLatentsField(nodeItem);
if (latentsField) {
parsed[nodeKey] = latentsField;
}
return;
}
}
// otherwise we accept any string, number or boolean
if (isString(nodeItem) || isNumber(nodeItem) || isBoolean(nodeItem)) {
parsed[nodeKey] = nodeItem;
return;
}
});
if (size(parsed) === 0) {
return;
}
return parsed;
};
export const parseInvokeAIMetadata = (
metadata: Record<string | number, any> | undefined
): InvokeAIMetadata | undefined => {
if (metadata === undefined) {
return;
}
if (!isObject(metadata)) {
return;
}
const parsed: InvokeAIMetadata = {};
forEach(metadata, (item, key) => {
if (key === 'session_id' && isString(item)) {
parsed['session_id'] = item;
}
if (key === 'node' && isObject(item)) {
const nodeMetadata = parseNodeMetadata(item);
if (nodeMetadata) {
parsed['node'] = nodeMetadata;
}
}
});
if (size(parsed) === 0) {
return;
}
return parsed;
};

View File

@ -1,8 +1,10 @@
import React, { lazy, PropsWithChildren } from 'react';
import React, { lazy, PropsWithChildren, useEffect, useState } from 'react';
import { Provider } from 'react-redux';
import { PersistGate } from 'redux-persist/integration/react';
import { store } from './app/store';
import { buildMiddleware, store } from './app/store';
import { persistor } from './persistor';
import { OpenAPI } from 'services/api';
import { InvokeTabName } from 'features/ui/store/tabMap';
import '@fontsource/inter/100.css';
import '@fontsource/inter/200.css';
import '@fontsource/inter/300.css';
@ -17,18 +19,61 @@ import Loading from './Loading';
// Localization
import './i18n';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
const App = lazy(() => import('./app/App'));
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
export default function Component(props: PropsWithChildren) {
interface Props extends PropsWithChildren {
apiUrl?: string;
disabledPanels?: string[];
disabledTabs?: InvokeTabName[];
token?: string;
shouldTransformUrls?: boolean;
}
export default function Component({
apiUrl,
disabledPanels = [],
disabledTabs = [],
token,
children,
shouldTransformUrls,
}: Props) {
useEffect(() => {
// configure API client token
if (token) {
OpenAPI.TOKEN = token;
}
// configure API client base url
if (apiUrl) {
OpenAPI.BASE = apiUrl;
}
// reset dynamically added middlewares
resetMiddlewares();
// TODO: at this point, after resetting the middleware, we really ought to clean up the socket
// stuff by calling `dispatch(socketReset())`. but we cannot dispatch from here as we are
// outside the provider. it's not needed until there is the possibility that we will change
// the `apiUrl`/`token` dynamically.
// rebuild socket middleware with token and apiUrl
addMiddleware(buildMiddleware());
}, [apiUrl, token]);
return (
<React.StrictMode>
<Provider store={store}>
<PersistGate loading={<Loading />} persistor={persistor}>
<React.Suspense fallback={<Loading showText />}>
<ThemeLocaleProvider>
<App>{props.children}</App>
<App
options={{ disabledPanels, disabledTabs, shouldTransformUrls }}
>
{children}
</App>
</ThemeLocaleProvider>
</React.Suspense>
</PersistGate>

View File

@ -5,6 +5,8 @@ import ThemeChanger from './features/system/components/ThemeChanger';
import IAIPopover from './common/components/IAIPopover';
import IAIIconButton from './common/components/IAIIconButton';
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
import StatusIndicator from './features/system/components/StatusIndicator';
import ModelSelect from 'features/system/components/ModelSelect';
export default Component;
export {
@ -13,4 +15,6 @@ export {
IAIPopover,
IAIIconButton,
SettingsModal,
StatusIndicator,
ModelSelect,
};

View File

@ -1,6 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { ImageConfig } from 'konva/lib/shapes/Image';
import { isEqual } from 'lodash';
@ -25,7 +26,7 @@ type Props = Omit<ImageConfig, 'image'>;
const IAICanvasIntermediateImage = (props: Props) => {
const { ...rest } = props;
const intermediateImage = useAppSelector(selector);
const { getUrl } = useGetUrl();
const [loadedImageElement, setLoadedImageElement] =
useState<HTMLImageElement | null>(null);
@ -36,8 +37,8 @@ const IAICanvasIntermediateImage = (props: Props) => {
tempImage.onload = () => {
setLoadedImageElement(tempImage);
};
tempImage.src = intermediateImage.url;
}, [intermediateImage]);
tempImage.src = getUrl(intermediateImage.url);
}, [intermediateImage, getUrl]);
if (!intermediateImage?.boundingBox) return null;

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { rgbaColorToString } from 'features/canvas/util/colorToString';
import { isEqual } from 'lodash';
@ -32,6 +33,7 @@ const selector = createSelector(
const IAICanvasObjectRenderer = () => {
const { objects } = useAppSelector(selector);
const { getUrl } = useGetUrl();
if (!objects) return null;
@ -40,7 +42,12 @@ const IAICanvasObjectRenderer = () => {
{objects.map((obj, i) => {
if (isCanvasBaseImage(obj)) {
return (
<IAICanvasImage key={i} x={obj.x} y={obj.y} url={obj.image.url} />
<IAICanvasImage
key={i}
x={obj.x}
y={obj.y}
url={getUrl(obj.image.url)}
/>
);
} else if (isCanvasBaseLine(obj)) {
const line = (

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { GroupConfig } from 'konva/lib/Group';
import { isEqual } from 'lodash';
@ -53,11 +54,16 @@ const IAICanvasStagingArea = (props: Props) => {
width,
height,
} = useAppSelector(selector);
const { getUrl } = useGetUrl();
return (
<Group {...rest}>
{shouldShowStagingImage && currentStagingAreaImage && (
<IAICanvasImage url={currentStagingAreaImage.image.url} x={x} y={y} />
<IAICanvasImage
url={getUrl(currentStagingAreaImage.image.url)}
x={x}
y={y}
/>
)}
{shouldShowStagingOutline && (
<Group>

View File

@ -0,0 +1,14 @@
import { CanvasState } from './canvasTypes';
/**
* Canvas slice persist blacklist
*/
const itemsToBlacklist: (keyof CanvasState)[] = [
'cursorPosition',
'isCanvasInitialized',
'doesCanvasNeedScaling',
];
export const canvasBlacklist = itemsToBlacklist.map(
(blacklistItem) => `canvas.${blacklistItem}`
);

View File

@ -156,7 +156,7 @@ export const canvasSlice = createSlice({
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
state.cursorPosition = action.payload;
},
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI._Image>) => {
const image = action.payload;
const { stageDimensions } = state;
@ -291,7 +291,7 @@ export const canvasSlice = createSlice({
state,
action: PayloadAction<{
boundingBox: IRect;
image: InvokeAI.Image;
image: InvokeAI._Image;
}>
) => {
const { boundingBox, image } = action.payload;

View File

@ -37,7 +37,7 @@ export type CanvasImage = {
y: number;
width: number;
height: number;
image: InvokeAI.Image;
image: InvokeAI._Image;
};
export type CanvasMaskLine = {
@ -125,7 +125,7 @@ export interface CanvasState {
cursorPosition: Vector2d | null;
doesCanvasNeedScaling: boolean;
futureLayerStates: CanvasLayerState[];
intermediateImage?: InvokeAI.Image;
intermediateImage?: InvokeAI._Image;
isCanvasInitialized: boolean;
isDrawing: boolean;
isMaskEnabled: boolean;

View File

@ -105,7 +105,7 @@ export const mergeAndUploadCanvas =
const { url, width, height } = image;
const newImage: InvokeAI.Image = {
const newImage: InvokeAI._Image = {
uuid: uuidv4(),
category: shouldSaveToGallery ? 'result' : 'user',
...image,

View File

@ -14,8 +14,9 @@ import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
import {
initialImageSelected,
setAllParameters,
setInitialImage,
// setInitialImage,
setSeed,
} from 'features/parameters/store/generationSlice';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
@ -48,11 +49,15 @@ import {
FaShareAlt,
FaTrash,
} from 'react-icons/fa';
import { gallerySelector } from '../store/gallerySelectors';
import {
gallerySelector,
selectedImageSelector,
} from '../store/gallerySelectors';
import DeleteImageModal from './DeleteImageModal';
import { useCallback } from 'react';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { useGetUrl } from 'common/util/getUrl';
const currentImageButtonsSelector = createSelector(
[
@ -62,6 +67,7 @@ const currentImageButtonsSelector = createSelector(
uiSelector,
lightboxSelector,
activeTabNameSelector,
selectedImageSelector,
],
(
system: SystemState,
@ -69,7 +75,8 @@ const currentImageButtonsSelector = createSelector(
postprocessing,
ui,
lightbox,
activeTabName
activeTabName,
selectedImage
) => {
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
system;
@ -95,6 +102,7 @@ const currentImageButtonsSelector = createSelector(
activeTabName,
isLightboxOpen,
shouldHidePreview,
selectedImage,
};
},
{
@ -121,27 +129,33 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
facetoolStrength,
shouldDisableToolbarButtons,
shouldShowImageDetails,
currentImage,
// currentImage,
isLightboxOpen,
activeTabName,
shouldHidePreview,
selectedImage,
} = useAppSelector(currentImageButtonsSelector);
const { getUrl, shouldTransformUrls } = useGetUrl();
const toast = useToast();
const { t } = useTranslation();
const setBothPrompts = useSetBothPrompts();
const handleClickUseAsInitialImage = () => {
if (!currentImage) return;
if (!selectedImage) return;
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
dispatch(setInitialImage(currentImage));
dispatch(setActiveTab('img2img'));
dispatch(initialImageSelected(selectedImage.name));
// dispatch(setInitialImage(currentImage));
// dispatch(setActiveTab('img2img'));
};
const handleCopyImage = async () => {
if (!currentImage) return;
if (!selectedImage) return;
const blob = await fetch(currentImage.url).then((res) => res.blob());
const blob = await fetch(getUrl(selectedImage.url)).then((res) =>
res.blob()
);
const data = [new ClipboardItem({ [blob.type]: blob })];
await navigator.clipboard.write(data);
@ -155,24 +169,26 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
};
const handleCopyImageLink = () => {
navigator.clipboard
.writeText(
currentImage ? window.location.toString() + currentImage.url : ''
)
.then(() => {
toast({
title: t('toast.imageLinkCopied'),
status: 'success',
duration: 2500,
isClosable: true,
});
const url = selectedImage
? shouldTransformUrls
? getUrl(selectedImage.url)
: window.location.toString() + selectedImage.url
: '';
navigator.clipboard.writeText(url).then(() => {
toast({
title: t('toast.imageLinkCopied'),
status: 'success',
duration: 2500,
isClosable: true,
});
});
};
useHotkeys(
'shift+i',
() => {
if (currentImage) {
if (selectedImage) {
handleClickUseAsInitialImage();
toast({
title: t('toast.sentToImageToImage'),
@ -190,7 +206,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[currentImage]
[selectedImage]
);
const handlePreviewVisibility = () => {
@ -198,20 +214,23 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
};
const handleClickUseAllParameters = () => {
if (!currentImage) return;
currentImage.metadata && dispatch(setAllParameters(currentImage.metadata));
if (currentImage.metadata?.image.type === 'img2img') {
dispatch(setActiveTab('img2img'));
} else if (currentImage.metadata?.image.type === 'txt2img') {
dispatch(setActiveTab('txt2img'));
}
if (!selectedImage) return;
// selectedImage.metadata &&
// dispatch(setAllParameters(selectedImage.metadata));
// if (selectedImage.metadata?.image.type === 'img2img') {
// dispatch(setActiveTab('img2img'));
// } else if (selectedImage.metadata?.image.type === 'txt2img') {
// dispatch(setActiveTab('txt2img'));
// }
};
useHotkeys(
'a',
() => {
if (
['txt2img', 'img2img'].includes(currentImage?.metadata?.image?.type)
['txt2img', 'img2img'].includes(
selectedImage?.metadata?.sd_metadata?.type
)
) {
handleClickUseAllParameters();
toast({
@ -230,18 +249,18 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[currentImage]
[selectedImage]
);
const handleClickUseSeed = () => {
currentImage?.metadata &&
dispatch(setSeed(currentImage.metadata.image.seed));
selectedImage?.metadata &&
dispatch(setSeed(selectedImage.metadata.sd_metadata.seed));
};
useHotkeys(
's',
() => {
if (currentImage?.metadata?.image?.seed) {
if (selectedImage?.metadata?.sd_metadata?.seed) {
handleClickUseSeed();
toast({
title: t('toast.seedSet'),
@ -259,19 +278,19 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[currentImage]
[selectedImage]
);
const handleClickUsePrompt = useCallback(() => {
if (currentImage?.metadata?.image?.prompt) {
setBothPrompts(currentImage?.metadata?.image?.prompt);
if (selectedImage?.metadata?.sd_metadata?.prompt) {
setBothPrompts(selectedImage?.metadata?.sd_metadata?.prompt);
}
}, [currentImage?.metadata?.image?.prompt, setBothPrompts]);
}, [selectedImage?.metadata?.sd_metadata?.prompt, setBothPrompts]);
useHotkeys(
'p',
() => {
if (currentImage?.metadata?.image?.prompt) {
if (selectedImage?.metadata?.sd_metadata?.prompt) {
handleClickUsePrompt();
toast({
title: t('toast.promptSet'),
@ -289,11 +308,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[currentImage]
[selectedImage]
);
const handleClickUpscale = () => {
currentImage && dispatch(runESRGAN(currentImage));
// selectedImage && dispatch(runESRGAN(selectedImage));
};
useHotkeys(
@ -317,7 +336,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}
},
[
currentImage,
selectedImage,
isESRGANAvailable,
shouldDisableToolbarButtons,
isConnected,
@ -327,7 +346,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
);
const handleClickFixFaces = () => {
currentImage && dispatch(runFacetool(currentImage));
// selectedImage && dispatch(runFacetool(selectedImage));
};
useHotkeys(
@ -351,7 +370,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}
},
[
currentImage,
selectedImage,
isGFPGANAvailable,
shouldDisableToolbarButtons,
isConnected,
@ -364,10 +383,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
dispatch(setShouldShowImageDetails(!shouldShowImageDetails));
const handleSendToCanvas = () => {
if (!currentImage) return;
if (!selectedImage) return;
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
dispatch(setInitialCanvasImage(currentImage));
// dispatch(setInitialCanvasImage(selectedImage));
dispatch(requestCanvasRescale());
if (activeTabName !== 'unifiedCanvas') {
@ -385,7 +404,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys(
'i',
() => {
if (currentImage) {
if (selectedImage) {
handleClickShowImageDetails();
} else {
toast({
@ -396,7 +415,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[currentImage, shouldShowImageDetails]
[selectedImage, shouldShowImageDetails]
);
const handleLightBox = () => {
@ -458,7 +477,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
{t('parameters.copyImageToLink')}
</IAIButton>
<Link download={true} href={currentImage?.url}>
<Link download={true} href={getUrl(selectedImage!.url)}>
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
{t('parameters.downloadImage')}
</IAIButton>
@ -502,7 +521,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaQuoteRight />}
tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={!currentImage?.metadata?.image?.prompt}
isDisabled={!selectedImage?.metadata?.sd_metadata?.prompt}
onClick={handleClickUsePrompt}
/>
@ -510,7 +529,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaSeedling />}
tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={!currentImage?.metadata?.image?.seed}
isDisabled={!selectedImage?.metadata?.sd_metadata?.seed}
onClick={handleClickUseSeed}
/>
@ -520,7 +539,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
aria-label={`${t('parameters.useAll')} (A)`}
isDisabled={
!['txt2img', 'img2img'].includes(
currentImage?.metadata?.image?.type
selectedImage?.metadata?.sd_metadata?.type
)
}
onClick={handleClickUseAllParameters}
@ -546,7 +565,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton
isDisabled={
!isGFPGANAvailable ||
!currentImage ||
!selectedImage ||
!(isConnected && !isProcessing) ||
!facetoolStrength
}
@ -575,7 +594,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton
isDisabled={
!isESRGANAvailable ||
!currentImage ||
!selectedImage ||
!(isConnected && !isProcessing) ||
!upscalingLevel
}
@ -597,15 +616,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
/>
</ButtonGroup>
<DeleteImageModal image={currentImage}>
{/* <DeleteImageModal image={selectedImage}>
<IAIIconButton
icon={<FaTrash />}
tooltip={`${t('parameters.deleteImage')} (Del)`}
aria-label={`${t('parameters.deleteImage')} (Del)`}
isDisabled={!currentImage || !isConnected || isProcessing}
isDisabled={!selectedImage || !isConnected || isProcessing}
colorScheme="error"
/>
</DeleteImageModal>
</DeleteImageModal> */}
</Flex>
);
};

View File

@ -4,17 +4,20 @@ import { useAppSelector } from 'app/storeHooks';
import { isEqual } from 'lodash';
import { MdPhoto } from 'react-icons/md';
import { gallerySelector } from '../store/gallerySelectors';
import {
gallerySelector,
selectedImageSelector,
} from '../store/gallerySelectors';
import CurrentImageButtons from './CurrentImageButtons';
import CurrentImagePreview from './CurrentImagePreview';
export const currentImageDisplaySelector = createSelector(
[gallerySelector],
(gallery) => {
[gallerySelector, selectedImageSelector],
(gallery, selectedImage) => {
const { currentImage, intermediateImage } = gallery;
return {
hasAnImageToDisplay: currentImage || intermediateImage,
hasAnImageToDisplay: selectedImage || intermediateImage,
};
},
{

View File

@ -1,28 +1,48 @@
import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { useGetUrl } from 'common/util/getUrl';
import { systemSelector } from 'features/system/store/systemSelectors';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import { ReactEventHandler } from 'react';
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
import { gallerySelector } from '../store/gallerySelectors';
import { selectedImageSelector } from '../store/gallerySelectors';
import CurrentImageFallback from './CurrentImageFallback';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
import CurrentImageHidden from './CurrentImageHidden';
export const imagesSelector = createSelector(
[gallerySelector, uiSelector],
(gallery: GalleryState, ui) => {
const { currentImage, intermediateImage } = gallery;
[uiSelector, selectedImageSelector, systemSelector],
(ui, selectedImage, system) => {
const { shouldShowImageDetails, shouldHidePreview } = ui;
const { progressImage } = system;
// TODO: Clean this up, this is really gross
const imageToDisplay = progressImage
? {
url: progressImage.dataURL,
width: progressImage.width,
height: progressImage.height,
isProgressImage: true,
image: progressImage,
}
: selectedImage
? {
url: selectedImage.url,
width: selectedImage.metadata.width,
height: selectedImage.metadata.height,
isProgressImage: false,
image: selectedImage,
}
: null;
return {
imageToDisplay: intermediateImage ? intermediateImage : currentImage,
isIntermediate: Boolean(intermediateImage),
shouldShowImageDetails,
shouldHidePreview,
imageToDisplay,
};
},
{
@ -33,12 +53,9 @@ export const imagesSelector = createSelector(
);
export default function CurrentImagePreview() {
const {
shouldShowImageDetails,
imageToDisplay,
isIntermediate,
shouldHidePreview,
} = useAppSelector(imagesSelector);
const { shouldShowImageDetails, imageToDisplay, shouldHidePreview } =
useAppSelector(imagesSelector);
const { getUrl } = useGetUrl();
return (
<Flex
@ -52,13 +69,19 @@ export default function CurrentImagePreview() {
>
{imageToDisplay && (
<Image
src={shouldHidePreview ? undefined : imageToDisplay.url}
src={
shouldHidePreview
? undefined
: imageToDisplay.isProgressImage
? imageToDisplay.url
: getUrl(imageToDisplay.url)
}
width={imageToDisplay.width}
height={imageToDisplay.height}
fallback={
shouldHidePreview ? (
<CurrentImageHidden />
) : !isIntermediate ? (
) : !imageToDisplay.isProgressImage ? (
<CurrentImageFallback />
) : undefined
}
@ -68,27 +91,31 @@ export default function CurrentImagePreview() {
maxHeight: '100%',
height: 'auto',
position: 'absolute',
imageRendering: isIntermediate ? 'pixelated' : 'initial',
imageRendering: imageToDisplay.isProgressImage
? 'pixelated'
: 'initial',
borderRadius: 'base',
}}
/>
)}
{!shouldShowImageDetails && <NextPrevImageButtons />}
{shouldShowImageDetails && imageToDisplay && (
<Box
sx={{
position: 'absolute',
top: '0',
width: '100%',
height: '100%',
borderRadius: 'base',
overflow: 'scroll',
maxHeight: APP_METADATA_HEIGHT,
}}
>
<ImageMetadataViewer image={imageToDisplay} />
</Box>
)}
{shouldShowImageDetails &&
imageToDisplay &&
'metadata' in imageToDisplay.image && (
<Box
sx={{
position: 'absolute',
top: '0',
width: '100%',
height: '100%',
borderRadius: 'base',
overflow: 'scroll',
maxHeight: APP_METADATA_HEIGHT,
}}
>
<ImageMetadataViewer image={imageToDisplay.image} />
</Box>
)}
</Flex>
);
}

View File

@ -52,7 +52,7 @@ interface DeleteImageModalProps {
/**
* The image to delete.
*/
image?: InvokeAI.Image;
image?: InvokeAI._Image;
}
/**

View File

@ -9,11 +9,14 @@ import {
useToast,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { setCurrentImage } from 'features/gallery/store/gallerySlice';
import {
imageSelected,
setCurrentImage,
} from 'features/gallery/store/gallerySlice';
import {
initialImageSelected,
setAllImageToImageParameters,
setAllParameters,
setInitialImage,
setSeed,
} from 'features/parameters/store/generationSlice';
import { DragEvent, memo, useState } from 'react';
@ -31,6 +34,7 @@ import { useTranslation } from 'react-i18next';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import IAIIconButton from 'common/components/IAIIconButton';
import { useGetUrl } from 'common/util/getUrl';
interface HoverableImageProps {
image: InvokeAI.Image;
@ -40,7 +44,7 @@ interface HoverableImageProps {
const memoEqualityCheck = (
prev: HoverableImageProps,
next: HoverableImageProps
) => prev.image.uuid === next.image.uuid && prev.isSelected === next.isSelected;
) => prev.image.name === next.image.name && prev.isSelected === next.isSelected;
/**
* Gallery image component with delete/use all/use seed buttons on hover.
@ -55,7 +59,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
shouldUseSingleGalleryColumn,
} = useAppSelector(hoverableImageSelector);
const { image, isSelected } = props;
const { url, thumbnail, uuid, metadata } = image;
const { url, thumbnail, name, metadata } = image;
const { getUrl } = useGetUrl();
const [isHovered, setIsHovered] = useState<boolean>(false);
@ -69,10 +74,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleMouseOut = () => setIsHovered(false);
const handleUsePrompt = () => {
if (image.metadata?.image?.prompt) {
setBothPrompts(image.metadata?.image?.prompt);
if (image.metadata?.sd_metadata?.prompt) {
setBothPrompts(image.metadata?.sd_metadata?.prompt);
}
toast({
title: t('toast.promptSet'),
status: 'success',
@ -82,7 +86,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseSeed = () => {
image.metadata && dispatch(setSeed(image.metadata.image.seed));
image.metadata.sd_metadata &&
dispatch(setSeed(image.metadata.sd_metadata.image.seed));
toast({
title: t('toast.seedSet'),
status: 'success',
@ -92,20 +97,11 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleSendToImageToImage = () => {
dispatch(setInitialImage(image));
if (activeTabName !== 'img2img') {
dispatch(setActiveTab('img2img'));
}
toast({
title: t('toast.sentToImageToImage'),
status: 'success',
duration: 2500,
isClosable: true,
});
dispatch(initialImageSelected(image.name));
};
const handleSendToCanvas = () => {
dispatch(setInitialCanvasImage(image));
// dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas());
@ -122,7 +118,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseAllParameters = () => {
metadata && dispatch(setAllParameters(metadata));
metadata.sd_metadata && dispatch(setAllParameters(metadata.sd_metadata));
toast({
title: t('toast.parametersSet'),
status: 'success',
@ -132,11 +128,13 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseInitialImage = async () => {
if (metadata?.image?.init_image_path) {
const response = await fetch(metadata.image.init_image_path);
if (metadata.sd_metadata?.image?.init_image_path) {
const response = await fetch(
metadata.sd_metadata?.image?.init_image_path
);
if (response.ok) {
dispatch(setActiveTab('img2img'));
dispatch(setAllImageToImageParameters(metadata));
dispatch(setAllImageToImageParameters(metadata?.sd_metadata));
toast({
title: t('toast.initialImageSet'),
status: 'success',
@ -155,16 +153,20 @@ const HoverableImage = memo((props: HoverableImageProps) => {
});
};
const handleSelectImage = () => dispatch(setCurrentImage(image));
const handleSelectImage = () => {
dispatch(imageSelected(image.name));
};
const handleDragStart = (e: DragEvent<HTMLDivElement>) => {
e.dataTransfer.setData('invokeai/imageUuid', uuid);
console.log('drag started');
e.dataTransfer.setData('invokeai/imageName', image.name);
e.dataTransfer.setData('invokeai/imageType', image.type);
e.dataTransfer.effectAllowed = 'move';
};
const handleLightBox = () => {
dispatch(setCurrentImage(image));
dispatch(setIsLightboxOpen(true));
// dispatch(setCurrentImage(image));
// dispatch(setIsLightboxOpen(true));
};
return (
@ -177,28 +179,30 @@ const HoverableImage = memo((props: HoverableImageProps) => {
</MenuItem>
<MenuItem
onClickCapture={handleUsePrompt}
isDisabled={image?.metadata?.image?.prompt === undefined}
isDisabled={image?.metadata?.sd_metadata?.prompt === undefined}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
onClickCapture={handleUseSeed}
isDisabled={image?.metadata?.image?.seed === undefined}
isDisabled={image?.metadata?.sd_metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem
onClickCapture={handleUseAllParameters}
isDisabled={
!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)
!['txt2img', 'img2img'].includes(
image?.metadata?.sd_metadata?.type
)
}
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
onClickCapture={handleUseInitialImage}
isDisabled={image?.metadata?.image?.type !== 'img2img'}
isDisabled={image?.metadata?.sd_metadata?.type !== 'img2img'}
>
{t('parameters.useInitImg')}
</MenuItem>
@ -209,9 +213,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
<MenuItem data-warning>
<DeleteImageModal image={image}>
{/* <DeleteImageModal image={image}>
<p>{t('parameters.deleteImage')}</p>
</DeleteImageModal>
</DeleteImageModal> */}
</MenuItem>
</MenuList>
)}
@ -219,7 +223,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{(ref) => (
<Box
position="relative"
key={uuid}
key={name}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
userSelect="none"
@ -244,7 +248,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
}
rounded="md"
src={thumbnail || url}
src={getUrl(thumbnail || url)}
loading="lazy"
sx={{
position: 'absolute',
@ -290,7 +294,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
insetInlineEnd: 1,
}}
>
<DeleteImageModal image={image}>
{/* <DeleteImageModal image={image}>
<IAIIconButton
aria-label={t('parameters.deleteImage')}
icon={<FaTrashAlt />}
@ -298,7 +302,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
fontSize={14}
isDisabled={!mayDeleteImage}
/>
</DeleteImageModal>
</DeleteImageModal> */}
</Box>
)}
</Box>

View File

@ -1,4 +1,4 @@
import { ButtonGroup, Flex, Grid, Icon, Text } from '@chakra-ui/react';
import { ButtonGroup, Flex, Grid, Icon, Image, Text } from '@chakra-ui/react';
import { requestImages } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAIButton from 'common/components/IAIButton';
@ -25,9 +25,44 @@ import HoverableImage from './HoverableImage';
import Scrollable from 'features/ui/components/common/Scrollable';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import {
resultsAdapter,
selectResultsAll,
selectResultsTotal,
} from '../store/resultsSlice';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { selectUploadsAll, uploadsAdapter } from '../store/uploadsSlice';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
const gallerySelector = createSelector(
[
(state: RootState) => state.uploads,
(state: RootState) => state.results,
(state: RootState) => state.gallery,
],
(uploads, results, gallery) => {
const { currentCategory } = gallery;
return currentCategory === 'result'
? {
images: resultsAdapter.getSelectors().selectAll(results),
isLoading: results.isLoading,
areMoreImagesAvailable: results.page < results.pages - 1,
}
: {
images: uploadsAdapter.getSelectors().selectAll(uploads),
isLoading: uploads.isLoading,
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
};
}
);
const ImageGalleryContent = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -35,7 +70,7 @@ const ImageGalleryContent = () => {
const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
const {
images,
// images,
currentCategory,
currentImageUuid,
shouldPinGallery,
@ -43,12 +78,24 @@ const ImageGalleryContent = () => {
galleryGridTemplateColumns,
galleryImageObjectFit,
shouldAutoSwitchToNewImages,
areMoreImagesAvailable,
// areMoreImagesAvailable,
shouldUseSingleGalleryColumn,
} = useAppSelector(imageGallerySelector);
const { images, areMoreImagesAvailable, isLoading } =
useAppSelector(gallerySelector);
// const handleClickLoadMore = () => {
// dispatch(requestImages(currentCategory));
// };
const handleClickLoadMore = () => {
dispatch(requestImages(currentCategory));
if (currentCategory === 'result') {
dispatch(receivedResultImagesPage());
}
if (currentCategory === 'user') {
dispatch(receivedUploadImagesPage());
}
};
const handleChangeGalleryImageMinimumWidth = (v: number) => {
@ -203,11 +250,11 @@ const ImageGalleryContent = () => {
style={{ gridTemplateColumns: galleryGridTemplateColumns }}
>
{images.map((image) => {
const { uuid } = image;
const isSelected = currentImageUuid === uuid;
const { name } = image;
const isSelected = currentImageUuid === name;
return (
<HoverableImage
key={uuid}
key={name}
image={image}
isSelected={isSelected}
/>
@ -217,6 +264,7 @@ const ImageGalleryContent = () => {
<IAIButton
onClick={handleClickLoadMore}
isDisabled={!areMoreImagesAvailable}
isLoading={isLoading}
flexShrink={0}
>
{areMoreImagesAvailable

View File

@ -33,12 +33,13 @@ const GALLERY_TAB_WIDTHS: Record<
InvokeTabName,
{ galleryMinWidth: number; galleryMaxWidth: number }
> = {
txt2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
img2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// txt2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// img2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
linear: { galleryMinWidth: 200, galleryMaxWidth: 500 },
unifiedCanvas: { galleryMinWidth: 200, galleryMaxWidth: 200 },
nodes: { galleryMinWidth: 200, galleryMaxWidth: 500 },
postprocessing: { galleryMinWidth: 200, galleryMaxWidth: 500 },
training: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// postprocessing: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// training: { galleryMinWidth: 200, galleryMaxWidth: 500 },
};
const galleryPanelSelector = createSelector(

View File

@ -11,6 +11,7 @@ import {
} from '@chakra-ui/react';
import * as InvokeAI from 'app/invokeai';
import { useAppDispatch } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import promptToString from 'common/util/promptToString';
import { seedWeightsToString } from 'common/util/seedWeightPairs';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
@ -18,7 +19,7 @@ import {
setCfgScale,
setHeight,
setImg2imgStrength,
setInitialImage,
// setInitialImage,
setMaskPath,
setPerlin,
setSampler,
@ -120,7 +121,7 @@ type ImageMetadataViewerProps = {
const memoEqualityCheck = (
prev: ImageMetadataViewerProps,
next: ImageMetadataViewerProps
) => prev.image.uuid === next.image.uuid;
) => prev.image.name === next.image.name;
// TODO: Show more interesting information in this component.
@ -137,34 +138,13 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
dispatch(setShouldShowImageDetails(false));
});
const metadata = image?.metadata?.image || {};
const dreamPrompt = image?.dreamPrompt;
const {
cfg_scale,
fit,
height,
hires_fix,
init_image_path,
mask_image_path,
orig_path,
perlin,
postprocessing,
prompt,
sampler,
seamless,
seed,
steps,
strength,
threshold,
type,
variations,
width,
} = metadata;
const sessionId = image.metadata.invokeai?.session_id;
const node = image.metadata.invokeai?.node as Record<string, any>;
const { t } = useTranslation();
const { getUrl } = useGetUrl();
const metadataJSON = JSON.stringify(image.metadata, null, 2);
const metadataJSON = JSON.stringify(image, null, 2);
return (
<Flex
@ -183,262 +163,134 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
>
<Flex gap={2}>
<Text fontWeight="semibold">File:</Text>
<Link href={image.url} isExternal maxW="calc(100% - 3rem)">
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
{image.url.length > 64
? image.url.substring(0, 64).concat('...')
: image.url}
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
{Object.keys(metadata).length > 0 ? (
{node && Object.keys(node).length > 0 ? (
<>
{type && <MetadataItem label="Generation type" value={type} />}
{image.metadata?.model_weights && (
<MetadataItem label="Model" value={image.metadata.model_weights} />
{node.type && (
<MetadataItem label="Invocation type" value={node.type} />
)}
{['esrgan', 'gfpgan'].includes(type) && (
<MetadataItem label="Original image" value={orig_path} />
)}
{prompt && (
{node.model && <MetadataItem label="Model" value={node.model} />}
{node.prompt && (
<MetadataItem
label="Prompt"
labelPosition="top"
value={
typeof prompt === 'string' ? prompt : promptToString(prompt)
typeof node.prompt === 'string'
? node.prompt
: promptToString(node.prompt)
}
onClick={() => setBothPrompts(prompt)}
onClick={() => setBothPrompts(node.prompt)}
/>
)}
{seed !== undefined && (
{node.seed !== undefined && (
<MetadataItem
label="Seed"
value={seed}
onClick={() => dispatch(setSeed(seed))}
value={node.seed}
onClick={() => dispatch(setSeed(Number(node.seed)))}
/>
)}
{threshold !== undefined && (
{node.threshold !== undefined && (
<MetadataItem
label="Noise Threshold"
value={threshold}
onClick={() => dispatch(setThreshold(threshold))}
value={node.threshold}
onClick={() => dispatch(setThreshold(Number(node.threshold)))}
/>
)}
{perlin !== undefined && (
{node.perlin !== undefined && (
<MetadataItem
label="Perlin Noise"
value={perlin}
onClick={() => dispatch(setPerlin(perlin))}
value={node.perlin}
onClick={() => dispatch(setPerlin(Number(node.perlin)))}
/>
)}
{sampler && (
{node.scheduler && (
<MetadataItem
label="Sampler"
value={sampler}
onClick={() => dispatch(setSampler(sampler))}
value={node.scheduler}
onClick={() => dispatch(setSampler(node.scheduler))}
/>
)}
{steps && (
{node.steps && (
<MetadataItem
label="Steps"
value={steps}
onClick={() => dispatch(setSteps(steps))}
value={node.steps}
onClick={() => dispatch(setSteps(Number(node.steps)))}
/>
)}
{cfg_scale !== undefined && (
{node.cfg_scale !== undefined && (
<MetadataItem
label="CFG scale"
value={cfg_scale}
onClick={() => dispatch(setCfgScale(cfg_scale))}
value={node.cfg_scale}
onClick={() => dispatch(setCfgScale(Number(node.cfg_scale)))}
/>
)}
{variations && variations.length > 0 && (
{node.variations && node.variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(variations)}
value={seedWeightsToString(node.variations)}
onClick={() =>
dispatch(setSeedWeights(seedWeightsToString(variations)))
dispatch(setSeedWeights(seedWeightsToString(node.variations)))
}
/>
)}
{seamless && (
{node.seamless && (
<MetadataItem
label="Seamless"
value={seamless}
onClick={() => dispatch(setSeamless(seamless))}
value={node.seamless}
onClick={() => dispatch(setSeamless(node.seamless))}
/>
)}
{hires_fix && (
{node.hires_fix && (
<MetadataItem
label="High Resolution Optimization"
value={hires_fix}
onClick={() => dispatch(setHiresFix(hires_fix))}
value={node.hires_fix}
onClick={() => dispatch(setHiresFix(node.hires_fix))}
/>
)}
{width && (
{node.width && (
<MetadataItem
label="Width"
value={width}
onClick={() => dispatch(setWidth(width))}
value={node.width}
onClick={() => dispatch(setWidth(Number(node.width)))}
/>
)}
{height && (
{node.height && (
<MetadataItem
label="Height"
value={height}
onClick={() => dispatch(setHeight(height))}
value={node.height}
onClick={() => dispatch(setHeight(Number(node.height)))}
/>
)}
{init_image_path && (
{/* {init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)}
{mask_image_path && (
<MetadataItem
label="Mask image"
value={mask_image_path}
isLink
onClick={() => dispatch(setMaskPath(mask_image_path))}
/>
)}
{type === 'img2img' && strength && (
)} */}
{node.strength && (
<MetadataItem
label="Image to image strength"
value={strength}
onClick={() => dispatch(setImg2imgStrength(strength))}
value={node.strength}
onClick={() =>
dispatch(setImg2imgStrength(Number(node.strength)))
}
/>
)}
{fit && (
{node.fit && (
<MetadataItem
label="Image to image fit"
value={fit}
onClick={() => dispatch(setShouldFitToWidthHeight(fit))}
value={node.fit}
onClick={() => dispatch(setShouldFitToWidthHeight(node.fit))}
/>
)}
{postprocessing && postprocessing.length > 0 && (
<>
<Heading size="sm">Postprocessing</Heading>
{postprocessing.map(
(
postprocess: InvokeAI.PostProcessedImageMetadata,
i: number
) => {
if (postprocess.type === 'esrgan') {
const { scale, strength, denoise_str } = postprocess;
return (
<Flex key={i} pl={8} gap={1} direction="column">
<Text size="md">{`${i + 1}: Upscale (ESRGAN)`}</Text>
<MetadataItem
label="Scale"
value={scale}
onClick={() => dispatch(setUpscalingLevel(scale))}
/>
<MetadataItem
label="Strength"
value={strength}
onClick={() =>
dispatch(setUpscalingStrength(strength))
}
/>
{denoise_str !== undefined && (
<MetadataItem
label="Denoising strength"
value={denoise_str}
onClick={() =>
dispatch(setUpscalingDenoising(denoise_str))
}
/>
)}
</Flex>
);
} else if (postprocess.type === 'gfpgan') {
const { strength } = postprocess;
return (
<Flex key={i} pl={8} gap={1} direction="column">
<Text size="md">{`${
i + 1
}: Face restoration (GFPGAN)`}</Text>
<MetadataItem
label="Strength"
value={strength}
onClick={() => {
dispatch(setFacetoolStrength(strength));
dispatch(setFacetoolType('gfpgan'));
}}
/>
</Flex>
);
} else if (postprocess.type === 'codeformer') {
const { strength, fidelity } = postprocess;
return (
<Flex key={i} pl={8} gap={1} direction="column">
<Text size="md">{`${
i + 1
}: Face restoration (Codeformer)`}</Text>
<MetadataItem
label="Strength"
value={strength}
onClick={() => {
dispatch(setFacetoolStrength(strength));
dispatch(setFacetoolType('codeformer'));
}}
/>
{fidelity && (
<MetadataItem
label="Fidelity"
value={fidelity}
onClick={() => {
dispatch(setCodeformerFidelity(fidelity));
dispatch(setFacetoolType('codeformer'));
}}
/>
)}
</Flex>
);
}
}
)}
</>
)}
{dreamPrompt && (
<MetadataItem withCopy label="Dream Prompt" value={dreamPrompt} />
)}
<Flex gap={2} direction="column">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
</>
) : (
<Center width="100%" pt={10}>
@ -447,6 +299,37 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
</Text>
</Center>
)}
<Flex gap={2} direction="column">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
</Flex>
);
}, memoEqualityCheck);

View File

@ -0,0 +1,470 @@
import { ExternalLinkIcon } from '@chakra-ui/icons';
import {
Box,
Center,
Flex,
Heading,
IconButton,
Link,
Text,
Tooltip,
} from '@chakra-ui/react';
import * as InvokeAI from 'app/invokeai';
import { useAppDispatch } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import promptToString from 'common/util/promptToString';
import { seedWeightsToString } from 'common/util/seedWeightPairs';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
import {
setCfgScale,
setHeight,
setImg2imgStrength,
// setInitialImage,
setMaskPath,
setPerlin,
setSampler,
setSeamless,
setSeed,
setSeedWeights,
setShouldFitToWidthHeight,
setSteps,
setThreshold,
setWidth,
} from 'features/parameters/store/generationSlice';
import {
setCodeformerFidelity,
setFacetoolStrength,
setFacetoolType,
setHiresFix,
setUpscalingDenoising,
setUpscalingLevel,
setUpscalingStrength,
} from 'features/parameters/store/postprocessingSlice';
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import * as png from '@stevebel/png';
type MetadataItemProps = {
isLink?: boolean;
label: string;
onClick?: () => void;
value: number | string | boolean;
labelPosition?: string;
withCopy?: boolean;
};
/**
* Component to display an individual metadata item or parameter.
*/
const MetadataItem = ({
label,
value,
onClick,
isLink,
labelPosition,
withCopy = false,
}: MetadataItemProps) => {
const { t } = useTranslation();
return (
<Flex gap={2}>
{onClick && (
<Tooltip label={`Recall ${label}`}>
<IconButton
aria-label={t('accessibility.useThisParameter')}
icon={<IoArrowUndoCircleOutline />}
size="xs"
variant="ghost"
fontSize={20}
onClick={onClick}
/>
</Tooltip>
)}
{withCopy && (
<Tooltip label={`Copy ${label}`}>
<IconButton
aria-label={`Copy ${label}`}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(value.toString())}
/>
</Tooltip>
)}
<Flex direction={labelPosition ? 'column' : 'row'}>
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
{label}:
</Text>
{isLink ? (
<Link href={value.toString()} isExternal wordBreak="break-all">
{value.toString()} <ExternalLinkIcon mx="2px" />
</Link>
) : (
<Text overflowY="scroll" wordBreak="break-all">
{value.toString()}
</Text>
)}
</Flex>
</Flex>
);
};
type ImageMetadataViewerProps = {
image: InvokeAI.Image;
};
// TODO: I don't know if this is needed.
const memoEqualityCheck = (
prev: ImageMetadataViewerProps,
next: ImageMetadataViewerProps
) => prev.image.name === next.image.name;
// TODO: Show more interesting information in this component.
/**
* Image metadata viewer overlays currently selected image and provides
* access to any of its metadata for use in processing.
*/
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch();
const setBothPrompts = useSetBothPrompts();
useHotkeys('esc', () => {
dispatch(setShouldShowImageDetails(false));
});
const metadata = image?.metadata.sd_metadata || {};
const dreamPrompt = image?.metadata.sd_metadata?.dreamPrompt;
const {
cfg_scale,
fit,
height,
hires_fix,
init_image_path,
mask_image_path,
orig_path,
perlin,
postprocessing,
prompt,
sampler,
seamless,
seed,
steps,
strength,
threshold,
type,
variations,
width,
model_weights,
} = metadata;
const { t } = useTranslation();
const { getUrl } = useGetUrl();
const metadataJSON = JSON.stringify(image, null, 2);
// fetch(getUrl(image.url))
// .then((r) => r.arrayBuffer())
// .then((buffer) => {
// const { text } = png.decode(buffer);
// const metadata = text?.['sd-metadata']
// ? JSON.parse(text['sd-metadata'] ?? {})
// : {};
// console.log(metadata);
// });
return (
<Flex
sx={{
padding: 4,
gap: 1,
flexDirection: 'column',
width: 'full',
height: 'full',
backdropFilter: 'blur(20px)',
bg: 'whiteAlpha.600',
_dark: {
bg: 'blackAlpha.600',
},
}}
>
<Flex gap={2}>
<Text fontWeight="semibold">File:</Text>
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
{image.url.length > 64
? image.url.substring(0, 64).concat('...')
: image.url}
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
<Flex gap={2} direction="column">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
{Object.keys(metadata).length > 0 ? (
<>
{type && <MetadataItem label="Generation type" value={type} />}
{model_weights && (
<MetadataItem label="Model" value={model_weights} />
)}
{['esrgan', 'gfpgan'].includes(type) && (
<MetadataItem label="Original image" value={orig_path} />
)}
{prompt && (
<MetadataItem
label="Prompt"
labelPosition="top"
value={
typeof prompt === 'string' ? prompt : promptToString(prompt)
}
onClick={() => setBothPrompts(prompt)}
/>
)}
{seed !== undefined && (
<MetadataItem
label="Seed"
value={seed}
onClick={() => dispatch(setSeed(seed))}
/>
)}
{threshold !== undefined && (
<MetadataItem
label="Noise Threshold"
value={threshold}
onClick={() => dispatch(setThreshold(threshold))}
/>
)}
{perlin !== undefined && (
<MetadataItem
label="Perlin Noise"
value={perlin}
onClick={() => dispatch(setPerlin(perlin))}
/>
)}
{sampler && (
<MetadataItem
label="Sampler"
value={sampler}
onClick={() => dispatch(setSampler(sampler))}
/>
)}
{steps && (
<MetadataItem
label="Steps"
value={steps}
onClick={() => dispatch(setSteps(steps))}
/>
)}
{cfg_scale !== undefined && (
<MetadataItem
label="CFG scale"
value={cfg_scale}
onClick={() => dispatch(setCfgScale(cfg_scale))}
/>
)}
{variations && variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(variations)}
onClick={() =>
dispatch(setSeedWeights(seedWeightsToString(variations)))
}
/>
)}
{seamless && (
<MetadataItem
label="Seamless"
value={seamless}
onClick={() => dispatch(setSeamless(seamless))}
/>
)}
{hires_fix && (
<MetadataItem
label="High Resolution Optimization"
value={hires_fix}
onClick={() => dispatch(setHiresFix(hires_fix))}
/>
)}
{width && (
<MetadataItem
label="Width"
value={width}
onClick={() => dispatch(setWidth(width))}
/>
)}
{height && (
<MetadataItem
label="Height"
value={height}
onClick={() => dispatch(setHeight(height))}
/>
)}
{/* {init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)} */}
{mask_image_path && (
<MetadataItem
label="Mask image"
value={mask_image_path}
isLink
onClick={() => dispatch(setMaskPath(mask_image_path))}
/>
)}
{type === 'img2img' && strength && (
<MetadataItem
label="Image to image strength"
value={strength}
onClick={() => dispatch(setImg2imgStrength(strength))}
/>
)}
{fit && (
<MetadataItem
label="Image to image fit"
value={fit}
onClick={() => dispatch(setShouldFitToWidthHeight(fit))}
/>
)}
{postprocessing && postprocessing.length > 0 && (
<>
<Heading size="sm">Postprocessing</Heading>
{postprocessing.map(
(
postprocess: InvokeAI.PostProcessedImageMetadata,
i: number
) => {
if (postprocess.type === 'esrgan') {
const { scale, strength, denoise_str } = postprocess;
return (
<Flex key={i} pl={8} gap={1} direction="column">
<Text size="md">{`${i + 1}: Upscale (ESRGAN)`}</Text>
<MetadataItem
label="Scale"
value={scale}
onClick={() => dispatch(setUpscalingLevel(scale))}
/>
<MetadataItem
label="Strength"
value={strength}
onClick={() =>
dispatch(setUpscalingStrength(strength))
}
/>
{denoise_str !== undefined && (
<MetadataItem
label="Denoising strength"
value={denoise_str}
onClick={() =>
dispatch(setUpscalingDenoising(denoise_str))
}
/>
)}
</Flex>
);
} else if (postprocess.type === 'gfpgan') {
const { strength } = postprocess;
return (
<Flex key={i} pl={8} gap={1} direction="column">
<Text size="md">{`${
i + 1
}: Face restoration (GFPGAN)`}</Text>
<MetadataItem
label="Strength"
value={strength}
onClick={() => {
dispatch(setFacetoolStrength(strength));
dispatch(setFacetoolType('gfpgan'));
}}
/>
</Flex>
);
} else if (postprocess.type === 'codeformer') {
const { strength, fidelity } = postprocess;
return (
<Flex key={i} pl={8} gap={1} direction="column">
<Text size="md">{`${
i + 1
}: Face restoration (Codeformer)`}</Text>
<MetadataItem
label="Strength"
value={strength}
onClick={() => {
dispatch(setFacetoolStrength(strength));
dispatch(setFacetoolType('codeformer'));
}}
/>
{fidelity && (
<MetadataItem
label="Fidelity"
value={fidelity}
onClick={() => {
dispatch(setCodeformerFidelity(fidelity));
dispatch(setFacetoolType('codeformer'));
}}
/>
)}
</Flex>
);
}
}
)}
</>
)}
{dreamPrompt && (
<MetadataItem withCopy label="Dream Prompt" value={dreamPrompt} />
)}
</>
) : (
<Center width="100%" pt={10}>
<Text fontSize="lg" fontWeight="semibold">
No metadata available
</Text>
</Center>
)}
</Flex>
);
}, memoEqualityCheck);
ImageMetadataViewer.displayName = 'ImageMetadataViewer';
export default ImageMetadataViewer;

View File

@ -0,0 +1,35 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { ImageType } from 'services/api';
import { selectResultsEntities } from '../store/resultsSlice';
import { selectUploadsEntities } from '../store/uploadsSlice';
const useGetImageByNameSelector = createSelector(
[selectResultsEntities, selectUploadsEntities],
(allResults, allUploads) => {
return { allResults, allUploads };
}
);
const useGetImageByNameAndType = () => {
const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector);
return (name: string, type: ImageType) => {
if (type === 'results') {
const resultImagesResult = allResults[name];
if (resultImagesResult) {
return resultImagesResult;
}
}
if (type === 'uploads') {
const userImagesResult = allUploads[name];
if (userImagesResult) {
return userImagesResult;
}
}
};
};
export default useGetImageByNameAndType;

View File

@ -0,0 +1,17 @@
import { GalleryState } from './gallerySlice';
/**
* Gallery slice persist blacklist
*/
const itemsToBlacklist: (keyof GalleryState)[] = [
'categories',
'currentCategory',
'currentImage',
'currentImageUuid',
'shouldAutoSwitchToNewImages',
'intermediateImage',
];
export const galleryBlacklist = itemsToBlacklist.map(
(blacklistItem) => `gallery.${blacklistItem}`
);

View File

@ -7,6 +7,16 @@ import {
uiSelector,
} from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import {
selectResultsAll,
selectResultsById,
selectResultsEntities,
} from './resultsSlice';
import {
selectUploadsAll,
selectUploadsById,
selectUploadsEntities,
} from './uploadsSlice';
export const gallerySelector = (state: RootState) => state.gallery;
@ -75,3 +85,18 @@ export const hoverableImageSelector = createSelector(
},
}
);
export const selectedImageSelector = createSelector(
[gallerySelector, selectResultsEntities, selectUploadsEntities],
(gallery, allResults, allUploads) => {
const selectedImageName = gallery.selectedImageName;
if (selectedImageName in allResults) {
return allResults[selectedImageName];
}
if (selectedImageName in allUploads) {
return allUploads[selectedImageName];
}
}
);

View File

@ -1,14 +1,17 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { invocationComplete } from 'services/events/actions';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { clamp } from 'lodash';
import { isImageOutput } from 'services/types/guards';
import { imageUploaded } from 'services/thunks/image';
export type GalleryCategory = 'user' | 'result';
export type AddImagesPayload = {
images: Array<InvokeAI.Image>;
images: Array<InvokeAI._Image>;
areMoreImagesAvailable: boolean;
category: GalleryCategory;
};
@ -16,16 +19,33 @@ export type AddImagesPayload = {
type GalleryImageObjectFitType = 'contain' | 'cover';
export type Gallery = {
images: InvokeAI.Image[];
images: InvokeAI._Image[];
latest_mtime?: number;
earliest_mtime?: number;
areMoreImagesAvailable: boolean;
};
export interface GalleryState {
currentImage?: InvokeAI.Image;
/**
* The selected image's unique name
* Use `selectedImageSelector` to access the image
*/
selectedImageName: string;
/**
* The currently selected image
* @deprecated See `state.gallery.selectedImageName`
*/
currentImage?: InvokeAI._Image;
/**
* The currently selected image's uuid.
* @deprecated See `state.gallery.selectedImageName`, use `selectedImageSelector` to access the image
*/
currentImageUuid: string;
intermediateImage?: InvokeAI.Image & {
/**
* The current progress image
* @deprecated See `state.system.progressImage`
*/
intermediateImage?: InvokeAI._Image & {
boundingBox?: IRect;
generationMode?: InvokeTabName;
};
@ -42,6 +62,7 @@ export interface GalleryState {
}
const initialState: GalleryState = {
selectedImageName: '',
currentImageUuid: '',
galleryImageMinimumWidth: 64,
galleryImageObjectFit: 'cover',
@ -69,7 +90,10 @@ export const gallerySlice = createSlice({
name: 'gallery',
initialState,
reducers: {
setCurrentImage: (state, action: PayloadAction<InvokeAI.Image>) => {
imageSelected: (state, action: PayloadAction<string>) => {
state.selectedImageName = action.payload;
},
setCurrentImage: (state, action: PayloadAction<InvokeAI._Image>) => {
state.currentImage = action.payload;
state.currentImageUuid = action.payload.uuid;
},
@ -124,7 +148,7 @@ export const gallerySlice = createSlice({
addImage: (
state,
action: PayloadAction<{
image: InvokeAI.Image;
image: InvokeAI._Image;
category: GalleryCategory;
}>
) => {
@ -150,7 +174,10 @@ export const gallerySlice = createSlice({
setIntermediateImage: (
state,
action: PayloadAction<
InvokeAI.Image & { boundingBox?: IRect; generationMode?: InvokeTabName }
InvokeAI._Image & {
boundingBox?: IRect;
generationMode?: InvokeTabName;
}
>
) => {
state.intermediateImage = action.payload;
@ -252,9 +279,31 @@ export const gallerySlice = createSlice({
state.shouldUseSingleGalleryColumn = action.payload;
},
},
extraReducers(builder) {
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
if (isImageOutput(data.result)) {
state.selectedImageName = data.result.image.image_name;
state.intermediateImage = undefined;
}
});
/**
* Upload Image - FULFILLED
*/
builder.addCase(imageUploaded.fulfilled, (state, action) => {
const { location } = action.payload;
const imageName = location.split('/').pop() || '';
state.selectedImageName = imageName;
});
},
});
export const {
imageSelected,
addImage,
clearIntermediateImage,
removeImage,

View File

@ -0,0 +1,12 @@
import { ResultsState } from './resultsSlice';
/**
* Results slice persist blacklist
*
* Currently blacklisting results slice entirely, see persist config in store.ts
*/
const itemsToBlacklist: (keyof ResultsState)[] = [];
export const resultsBlacklist = itemsToBlacklist.map(
(blacklistItem) => `results.${blacklistItem}`
);

View File

@ -0,0 +1,139 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/invokeai';
import { invocationComplete } from 'services/events/actions';
import { RootState } from 'app/store';
import {
receivedResultImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { isImageOutput } from 'services/types/guards';
import {
buildImageUrls,
extractTimestampFromImageName,
} from 'services/util/deserializeImageField';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
// use `createEntityAdapter` to create a slice for results images
// https://redux-toolkit.js.org/api/createEntityAdapter#overview
// the "Entity" is InvokeAI.ResultImage, while the "entities" are instances of that type
export const resultsAdapter = createEntityAdapter<Image>({
// Provide a callback to get a stable, unique identifier for each entity. This defaults to
// `(item) => item.id`, but for our result images, the `name` is the unique identifier.
selectId: (image) => image.name,
// Order all images by their time (in descending order)
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
});
// This type is intersected with the Entity type to create the shape of the state
type AdditionalResultsState = {
// these are a bit misleading; they refer to sessions, not results, but we don't have a route
// to list all images directly at this time...
page: number; // current page we are on
pages: number; // the total number of pages available
isLoading: boolean; // whether we are loading more images or not, mostly a placeholder
nextPage: number; // the next page to request
};
export const initialResultsState =
resultsAdapter.getInitialState<AdditionalResultsState>({
// provide the additional initial state
page: 0,
pages: 0,
isLoading: false,
nextPage: 0,
});
export type ResultsState = typeof initialResultsState;
const resultsSlice = createSlice({
name: 'results',
initialState: initialResultsState,
reducers: {
// the adapter provides some helper reducers; see the docs for all of them
// can use them as helper functions within a reducer, or use the function itself as a reducer
// here we just use the function itself as the reducer. we'll call this on `invocation_complete`
// to add a single result
resultAdded: resultsAdapter.upsertOne,
},
extraReducers: (builder) => {
// here we can respond to a fulfilled call of the `getNextResultsPage` thunk
// because we pass in the fulfilled thunk action creator, everything is typed
/**
* Received Result Images Page - PENDING
*/
builder.addCase(receivedResultImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Result Images Page - FULFILLED
*/
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const resultImages = items.map((image) =>
deserializeImageResponse(image)
);
// use the adapter reducer to append all the results to state
resultsAdapter.addMany(state, resultImages);
state.page = page;
state.pages = pages;
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
const { result, node, graph_execution_state_id } = data;
if (isImageOutput(result)) {
const name = result.image.image_name;
const type = result.image.image_type;
const { url, thumbnail } = buildImageUrls(type, name);
const timestamp = extractTimestampFromImageName(name);
const image: Image = {
name,
type,
url,
thumbnail,
metadata: {
created: timestamp,
width: result.width, // TODO: add tese dimensions
height: result.height,
invokeai: {
session_id: graph_execution_state_id,
...(node ? { node } : {}),
},
},
};
resultsAdapter.addOne(state, image);
}
});
},
});
// Create a set of memoized selectors based on the location of this entity state
// to be used as selectors in a `useAppSelector()` call
export const {
selectAll: selectResultsAll,
selectById: selectResultsById,
selectEntities: selectResultsEntities,
selectIds: selectResultsIds,
selectTotal: selectResultsTotal,
} = resultsAdapter.getSelectors<RootState>((state) => state.results);
export const { resultAdded } = resultsSlice.actions;
export default resultsSlice.reducer;

View File

@ -1,54 +0,0 @@
import { AnyAction, ThunkAction } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { RootState } from 'app/store';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { setInitialImage } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { v4 as uuidv4 } from 'uuid';
import { addImage } from '../gallerySlice';
type UploadImageConfig = {
imageFile: File;
};
export const uploadImage =
(
config: UploadImageConfig
): ThunkAction<void, RootState, unknown, AnyAction> =>
async (dispatch, getState) => {
const { imageFile } = config;
const state = getState() as RootState;
const activeTabName = activeTabNameSelector(state);
const formData = new FormData();
formData.append('file', imageFile, imageFile.name);
formData.append(
'data',
JSON.stringify({
kind: 'init',
})
);
const response = await fetch(`${window.location.origin}/upload`, {
method: 'POST',
body: formData,
});
const image = (await response.json()) as InvokeAI.ImageUploadResponse;
const newImage: InvokeAI.Image = {
uuid: uuidv4(),
category: 'user',
...image,
};
dispatch(addImage({ image: newImage, category: 'user' }));
if (activeTabName === 'unifiedCanvas') {
dispatch(setInitialCanvasImage(newImage));
} else if (activeTabName === 'img2img') {
dispatch(setInitialImage(newImage));
}
};

View File

@ -0,0 +1,12 @@
import { UploadsState } from './uploadsSlice';
/**
* Uploads slice persist blacklist
*
* Currently blacklisting uploads slice entirely, see persist config in store.ts
*/
const itemsToBlacklist: (keyof UploadsState)[] = [];
export const uploadsBlacklist = itemsToBlacklist.map(
(blacklistItem) => `uploads.${blacklistItem}`
);

View File

@ -0,0 +1,87 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/invokeai';
import { RootState } from 'app/store';
import {
receivedUploadImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { imageUploaded } from 'services/thunks/image';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
export const uploadsAdapter = createEntityAdapter<Image>({
selectId: (image) => image.name,
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
});
type AdditionalUploadsState = {
page: number;
pages: number;
isLoading: boolean;
nextPage: number;
};
const initialUploadsState =
uploadsAdapter.getInitialState<AdditionalUploadsState>({
page: 0,
pages: 0,
nextPage: 0,
isLoading: false,
});
export type UploadsState = typeof initialUploadsState;
const uploadsSlice = createSlice({
name: 'uploads',
initialState: initialUploadsState,
reducers: {
uploadAdded: uploadsAdapter.addOne,
},
extraReducers: (builder) => {
/**
* Received Upload Images Page - PENDING
*/
builder.addCase(receivedUploadImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Upload Images Page - FULFILLED
*/
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const images = items.map((image) => deserializeImageResponse(image));
uploadsAdapter.addMany(state, images);
state.page = page;
state.pages = pages;
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
/**
* Upload Image - FULFILLED
*/
builder.addCase(imageUploaded.fulfilled, (state, action) => {
const { location, response } = action.payload;
const uploadedImage = deserializeImageResponse(response);
uploadsAdapter.addOne(state, uploadedImage);
});
},
});
export const {
selectAll: selectUploadsAll,
selectById: selectUploadsById,
selectEntities: selectUploadsEntities,
selectIds: selectUploadsIds,
selectTotal: selectUploadsTotal,
} = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
export const { uploadAdded } = uploadsSlice.actions;
export default uploadsSlice.reducer;

View File

@ -1,9 +1,10 @@
import * as React from 'react';
import { TransformComponent, useTransformContext } from 'react-zoom-pan-pinch';
import * as InvokeAI from 'app/invokeai';
import { useGetUrl } from 'common/util/getUrl';
type ReactPanZoomProps = {
image: InvokeAI.Image;
image: InvokeAI._Image;
styleClass?: string;
alt?: string;
ref?: React.Ref<HTMLImageElement>;
@ -22,6 +23,7 @@ export default function ReactPanZoomImage({
scaleY,
}: ReactPanZoomProps) {
const { centerView } = useTransformContext();
const { getUrl } = useGetUrl();
return (
<TransformComponent
@ -35,7 +37,7 @@ export default function ReactPanZoomImage({
transform: `rotate(${rotation}deg) scaleX(${scaleX}) scaleY(${scaleY})`,
width: '100%',
}}
src={image.url}
src={getUrl(image.url)}
alt={alt}
ref={ref}
className={styleClass ? styleClass : ''}

View File

@ -0,0 +1,10 @@
import { LightboxState } from './lightboxSlice';
/**
* Lightbox slice persist blacklist
*/
const itemsToBlacklist: (keyof LightboxState)[] = ['isLightboxOpen'];
export const lightboxBlacklist = itemsToBlacklist.map(
(blacklistItem) => `lightbox.${blacklistItem}`
);

View File

@ -0,0 +1,63 @@
import { v4 as uuidv4 } from 'uuid';
import 'reactflow/dist/style.css';
import { useCallback } from 'react';
import {
Tooltip,
Menu,
MenuButton,
MenuList,
MenuItem,
IconButton,
} from '@chakra-ui/react';
import { FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { nodeAdded } from '../store/nodesSlice';
import { cloneDeep, map } from 'lodash';
import { RootState } from 'app/store';
import { useBuildInvocation } from '../hooks/useBuildInvocation';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/hooks/useToastWatcher';
export const AddNodeMenu = () => {
const dispatch = useAppDispatch();
const invocationTemplates = useAppSelector(
(state: RootState) => state.nodes.invocationTemplates
);
const buildInvocation = useBuildInvocation();
const addNode = useCallback(
(nodeType: string) => {
const invocation = buildInvocation(nodeType);
if (!invocation) {
const toast = makeToast({
status: 'error',
title: `Unknown Invocation type ${nodeType}`,
});
dispatch(addToast(toast));
return;
}
dispatch(nodeAdded(invocation));
},
[dispatch, buildInvocation]
);
return (
<Menu>
<MenuButton as={IconButton} aria-label="Add Node" icon={<FaPlus />} />
<MenuList>
{map(invocationTemplates, ({ title, description, type }, key) => {
return (
<Tooltip key={key} label={description} placement="end" hasArrow>
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
</Tooltip>
);
})}
</MenuList>
</Menu>
);
};

View File

@ -0,0 +1,69 @@
import { Tooltip } from '@chakra-ui/react';
import { CSSProperties, useMemo } from 'react';
import {
Handle,
Position,
Connection,
HandleType,
useReactFlow,
} from 'reactflow';
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../types/constants';
// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
import { InputFieldTemplate, OutputFieldTemplate } from '../types/types';
const handleBaseStyles: CSSProperties = {
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: 0,
};
const inputHandleStyles: CSSProperties = {
left: '-1.7rem',
};
const outputHandleStyles: CSSProperties = {
right: '-1.7rem',
};
const requiredConnectionStyles: CSSProperties = {
boxShadow: '0 0 0.5rem 0.5rem var(--invokeai-colors-error-400)',
};
type FieldHandleProps = {
nodeId: string;
field: InputFieldTemplate | OutputFieldTemplate;
isValidConnection: (connection: Connection) => boolean;
handleType: HandleType;
styles?: CSSProperties;
};
export const FieldHandle = (props: FieldHandleProps) => {
const { nodeId, field, isValidConnection, handleType, styles } = props;
const { name, title, type, description } = field;
return (
<Tooltip
key={name}
label={type}
placement={handleType === 'target' ? 'start' : 'end'}
hasArrow
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
>
<Handle
type={handleType}
id={name}
isValidConnection={isValidConnection}
position={handleType === 'target' ? Position.Left : Position.Right}
style={{
backgroundColor: FIELDS[type].colorCssVar,
...styles,
...handleBaseStyles,
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
// ...(inputRequirement === 'always' ? requiredConnectionStyles : {}),
// ...connectionEventStyles,
}}
/>
</Tooltip>
);
};

View File

@ -0,0 +1,18 @@
import 'reactflow/dist/style.css';
import { Tooltip, Badge, HStack } from '@chakra-ui/react';
import { map } from 'lodash';
import { FIELDS } from '../types/constants';
export const FieldTypeLegend = () => {
return (
<HStack>
{map(FIELDS, ({ title, description, color }, key) => (
<Tooltip key={key} label={description}>
<Badge colorScheme={color} sx={{ userSelect: 'none' }}>
{title}
</Badge>
</Tooltip>
))}
</HStack>
);
};

View File

@ -0,0 +1,104 @@
import {
Background,
Controls,
MiniMap,
OnConnect,
OnEdgesChange,
OnNodesChange,
ReactFlow,
ConnectionLineType,
OnConnectStart,
OnConnectEnd,
Panel,
} from 'reactflow';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store';
import {
connectionEnded,
connectionMade,
connectionStarted,
edgesChanged,
nodesChanged,
} from '../store/nodesSlice';
import { useCallback } from 'react';
import { InvocationComponent } from './InvocationComponent';
import { AddNodeMenu } from './AddNodeMenu';
import { FieldTypeLegend } from './FieldTypeLegend';
import { Button } from '@chakra-ui/react';
import { nodesGraphBuilt } from 'services/thunks/session';
const nodeTypes = { invocation: InvocationComponent };
export const Flow = () => {
const dispatch = useAppDispatch();
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
const edges = useAppSelector((state: RootState) => state.nodes.edges);
const onNodesChange: OnNodesChange = useCallback(
(changes) => {
dispatch(nodesChanged(changes));
},
[dispatch]
);
const onEdgesChange: OnEdgesChange = useCallback(
(changes) => {
dispatch(edgesChanged(changes));
},
[dispatch]
);
const onConnectStart: OnConnectStart = useCallback(
(event, params) => {
dispatch(connectionStarted(params));
},
[dispatch]
);
const onConnect: OnConnect = useCallback(
(connection) => {
dispatch(connectionMade(connection));
},
[dispatch]
);
const onConnectEnd: OnConnectEnd = useCallback(
(event) => {
dispatch(connectionEnded());
},
[dispatch]
);
const handleInvoke = useCallback(() => {
dispatch(nodesGraphBuilt());
}, [dispatch]);
return (
<ReactFlow
nodeTypes={nodeTypes}
nodes={nodes}
edges={edges}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onConnectStart={onConnectStart}
onConnect={onConnect}
onConnectEnd={onConnectEnd}
defaultEdgeOptions={{
style: { strokeWidth: 2 },
}}
>
<Panel position="top-left">
<AddNodeMenu />
</Panel>
<Panel position="top-center">
<Button onClick={handleInvoke}>Will it blend?</Button>
</Panel>
<Panel position="top-right">
<FieldTypeLegend />
</Panel>
<Background />
<Controls />
<MiniMap nodeStrokeWidth={3} zoomable pannable />
</ReactFlow>
);
};

View File

@ -0,0 +1,107 @@
import { Box } from '@chakra-ui/react';
import { InputFieldTemplate, InputFieldValue } from '../types/types';
import { ArrayInputFieldComponent } from './fields/ArrayInputField.tsx';
import { BooleanInputFieldComponent } from './fields/BooleanInputFieldComponent';
import { EnumInputFieldComponent } from './fields/EnumInputFieldComponent';
import { ImageInputFieldComponent } from './fields/ImageInputFieldComponent';
import { LatentsInputFieldComponent } from './fields/LatentsInputFieldComponent';
import { ModelInputFieldComponent } from './fields/ModelInputFieldComponent';
import { NumberInputFieldComponent } from './fields/NumberInputFieldComponent';
import { StringInputFieldComponent } from './fields/StringInputFieldComponent';
type InputFieldComponentProps = {
nodeId: string;
field: InputFieldValue;
template: InputFieldTemplate;
};
// build an individual input element based on the schema
export const InputFieldComponent = (props: InputFieldComponentProps) => {
const { nodeId, field, template } = props;
const { type, value } = field;
if (type === 'string' && template.type === 'string') {
return (
<StringInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'boolean' && template.type === 'boolean') {
return (
<BooleanInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (
(type === 'integer' && template.type === 'integer') ||
(type === 'float' && template.type === 'float')
) {
return (
<NumberInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'enum' && template.type === 'enum') {
return (
<EnumInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'image' && template.type === 'image') {
return (
<ImageInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'latents' && template.type === 'latents') {
return (
<LatentsInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'model' && template.type === 'model') {
return (
<ModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'array' && template.type === 'array') {
return (
<ArrayInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
return <Box p={2}>Unknown field type: {type}</Box>;
};

View File

@ -0,0 +1,243 @@
import { NodeProps, useReactFlow } from 'reactflow';
import {
Box,
Flex,
FormControl,
FormLabel,
Heading,
HStack,
Tooltip,
Icon,
Code,
Text,
} from '@chakra-ui/react';
import { FaExclamationCircle, FaInfoCircle } from 'react-icons/fa';
import { InvocationValue } from '../types/types';
import { InputFieldComponent } from './InputFieldComponent';
import { FieldHandle } from './FieldHandle';
import { isEqual, map, size } from 'lodash';
import { memo, useMemo, useRef } from 'react';
import { useIsValidConnection } from '../hooks/useIsValidConnection';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { useGetInvocationTemplate } from '../hooks/useInvocationTemplate';
const connectedInputFieldsSelector = createSelector(
[(state: RootState) => state.nodes.edges],
(edges) => {
// return edges.map((e) => e.targetHandle);
return edges;
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
const { id: nodeId, data, selected } = props;
const { type, inputs, outputs } = data;
const isValidConnection = useIsValidConnection();
const connectedInputs = useAppSelector(connectedInputFieldsSelector);
const getInvocationTemplate = useGetInvocationTemplate();
// TODO: determine if a field/handle is connected and disable the input if so
const template = useRef(getInvocationTemplate(type));
if (!template.current) {
return (
<Box
sx={{
padding: 4,
bg: 'base.800',
borderRadius: 'md',
boxShadow: 'dark-lg',
borderWidth: 2,
borderColor: selected ? 'base.400' : 'transparent',
}}
>
<Flex sx={{ alignItems: 'center', justifyContent: 'center' }}>
<Icon color="base.400" boxSize={32} as={FaExclamationCircle}></Icon>
</Flex>
</Box>
);
}
return (
<Box
sx={{
padding: 4,
bg: 'base.800',
borderRadius: 'md',
boxShadow: 'dark-lg',
borderWidth: 2,
borderColor: selected ? 'base.400' : 'transparent',
}}
>
<Flex flexDirection="column" gap={2}>
<>
<Code>{nodeId}</Code>
<HStack justifyContent="space-between">
<Heading size="sm" fontWeight={500} color="base.100">
{template.current.title}
</Heading>
<Tooltip
label={template.current.description}
placement="top"
hasArrow
shouldWrapChildren
>
<Icon color="base.300" as={FaInfoCircle} />
</Tooltip>
</HStack>
{map(inputs, (input, i) => {
const { id: fieldId } = input;
const inputTemplate = template.current?.inputs[input.name];
if (!inputTemplate) {
return (
<Box
key={fieldId}
position="relative"
p={2}
borderWidth={1}
borderRadius="md"
sx={{
borderColor: 'error.400',
}}
>
<FormControl isDisabled={true}>
<HStack justifyContent="space-between" alignItems="center">
<FormLabel>Unknown input: {input.name}</FormLabel>
</HStack>
</FormControl>
</Box>
);
}
const isConnected = Boolean(
connectedInputs.filter((connectedInput) => {
return (
connectedInput.target === nodeId &&
connectedInput.targetHandle === input.name
);
}).length
);
return (
<Box
key={fieldId}
position="relative"
p={2}
borderWidth={1}
borderRadius="md"
sx={{
borderColor:
!isConnected &&
['always', 'connectionOnly'].includes(
String(inputTemplate?.inputRequirement)
) &&
input.value === undefined
? 'warning.400'
: undefined,
}}
>
<FormControl isDisabled={isConnected}>
<HStack justifyContent="space-between" alignItems="center">
<FormLabel>{inputTemplate?.title}</FormLabel>
<Tooltip
label={inputTemplate?.description}
placement="top"
hasArrow
shouldWrapChildren
>
<Icon color="base.400" as={FaInfoCircle} />
</Tooltip>
</HStack>
<InputFieldComponent
nodeId={nodeId}
field={input}
template={inputTemplate}
/>
</FormControl>
{!['never', 'directOnly'].includes(
inputTemplate?.inputRequirement ?? ''
) && (
<FieldHandle
nodeId={nodeId}
field={inputTemplate}
isValidConnection={isValidConnection}
handleType="target"
/>
)}
</Box>
);
})}
{map(outputs).map((output, i) => {
const outputTemplate = template.current?.outputs[output.name];
const isConnected = Boolean(
connectedInputs.filter((connectedInput) => {
return (
connectedInput.source === nodeId &&
connectedInput.sourceHandle === output.name
);
}).length
);
if (!outputTemplate) {
return (
<Box
key={output.id}
position="relative"
p={2}
borderWidth={1}
borderRadius="md"
sx={{
borderColor: 'error.400',
}}
>
<FormControl isDisabled={true}>
<HStack justifyContent="space-between" alignItems="center">
<FormLabel>Unknown output: {output.name}</FormLabel>
</HStack>
</FormControl>
</Box>
);
}
return (
<Box
key={output.id}
position="relative"
p={2}
borderWidth={1}
borderRadius="md"
>
<FormControl isDisabled={isConnected}>
<FormLabel textAlign="end">
{outputTemplate?.title} Output
</FormLabel>
</FormControl>
<FieldHandle
key={output.id}
nodeId={nodeId}
field={outputTemplate}
isValidConnection={isValidConnection}
handleType="source"
/>
</Box>
);
})}
</>
</Flex>
<Flex></Flex>
</Box>
);
});
InvocationComponent.displayName = 'InvocationComponent';

View File

@ -0,0 +1,46 @@
import 'reactflow/dist/style.css';
import { Box } from '@chakra-ui/react';
import { ReactFlowProvider } from 'reactflow';
import { Flow } from './Flow';
import { useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store';
import { buildNodesGraph } from '../util/nodesGraphBuilder/buildNodesGraph';
const NodeEditor = () => {
const state = useAppSelector((state: RootState) => state);
const graph = buildNodesGraph(state);
return (
<Box
sx={{
position: 'relative',
width: 'full',
height: 'full',
borderRadius: 'md',
bg: 'base.850',
}}
>
<ReactFlowProvider>
<Flow />
</ReactFlowProvider>
<Box
as="pre"
fontFamily="monospace"
position="absolute"
top={2}
left={2}
width="full"
height="full"
userSelect="none"
pointerEvents="none"
opacity={0.7}
>
<Box w="50%">{JSON.stringify(graph, null, 2)}</Box>
</Box>
</Box>
);
};
export default NodeEditor;

View File

@ -0,0 +1,14 @@
import {
ArrayInputFieldTemplate,
ArrayInputFieldValue,
} from 'features/nodes/types';
import { FaImage, FaList } from 'react-icons/fa';
import { FieldComponentProps } from './types';
export const ArrayInputFieldComponent = (
props: FieldComponentProps<ArrayInputFieldValue, ArrayInputFieldTemplate>
) => {
const { nodeId, field } = props;
return <FaList />;
};

View File

@ -0,0 +1,31 @@
import { Switch } from '@chakra-ui/react';
import { useAppDispatch } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
BooleanInputFieldTemplate,
BooleanInputFieldValue,
} from 'features/nodes/types';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
export const BooleanInputFieldComponent = (
props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: e.target.checked,
})
);
};
return (
<Switch onChange={handleValueChanged} isChecked={field.value}></Switch>
);
};

View File

@ -0,0 +1,35 @@
import { Select } from '@chakra-ui/react';
import { useAppDispatch } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
EnumInputFieldTemplate,
EnumInputFieldValue,
} from 'features/nodes/types';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
export const EnumInputFieldComponent = (
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
) => {
const { nodeId, field, template } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: e.target.value,
})
);
};
return (
<Select onChange={handleValueChanged} value={field.value}>
{template.options.map((option) => (
<option key={option}>{option}</option>
))}
</Select>
);
};

Some files were not shown because too many files have changed in this diff Show More