mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' of https://github.com/invoke-ai/InvokeAI into responsive-ui
This commit is contained in:
commit
3fb433cb91
14
.github/CODEOWNERS
vendored
14
.github/CODEOWNERS
vendored
@ -1,16 +1,16 @@
|
|||||||
# continuous integration
|
# continuous integration
|
||||||
/.github/workflows/ @mauwii @lstein @blessedcoolant
|
/.github/workflows/ @lstein @blessedcoolant
|
||||||
|
|
||||||
# documentation
|
# documentation
|
||||||
/docs/ @lstein @mauwii @tildebyte @blessedcoolant
|
/docs/ @lstein @tildebyte @blessedcoolant
|
||||||
/mkdocs.yml @lstein @mauwii @blessedcoolant
|
/mkdocs.yml @lstein @blessedcoolant
|
||||||
|
|
||||||
# nodes
|
# nodes
|
||||||
/invokeai/app/ @Kyle0654 @blessedcoolant
|
/invokeai/app/ @Kyle0654 @blessedcoolant
|
||||||
|
|
||||||
# installation and configuration
|
# installation and configuration
|
||||||
/pyproject.toml @mauwii @lstein @blessedcoolant
|
/pyproject.toml @lstein @blessedcoolant
|
||||||
/docker/ @mauwii @lstein @blessedcoolant
|
/docker/ @lstein @blessedcoolant
|
||||||
/scripts/ @ebr @lstein
|
/scripts/ @ebr @lstein
|
||||||
/installer/ @lstein @ebr
|
/installer/ @lstein @ebr
|
||||||
/invokeai/assets @lstein @ebr
|
/invokeai/assets @lstein @ebr
|
||||||
@ -22,11 +22,11 @@
|
|||||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein
|
/invokeai/backend @blessedcoolant @psychedelicious @lstein
|
||||||
|
|
||||||
# generation, model management, postprocessing
|
# generation, model management, postprocessing
|
||||||
/invokeai/backend @keturn @damian0815 @lstein @blessedcoolant @jpphoto
|
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
|
||||||
|
|
||||||
# front ends
|
# front ends
|
||||||
/invokeai/frontend/CLI @lstein
|
/invokeai/frontend/CLI @lstein
|
||||||
/invokeai/frontend/install @lstein @ebr @mauwii
|
/invokeai/frontend/install @lstein @ebr
|
||||||
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
|
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
|
||||||
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
|
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
|
||||||
/invokeai/frontend/web @psychedelicious @blessedcoolant
|
/invokeai/frontend/web @psychedelicious @blessedcoolant
|
||||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -9,6 +9,8 @@ models/ldm/stable-diffusion-v1/model.ckpt
|
|||||||
configs/models.user.yaml
|
configs/models.user.yaml
|
||||||
config/models.user.yml
|
config/models.user.yml
|
||||||
invokeai.init
|
invokeai.init
|
||||||
|
.version
|
||||||
|
.last_model
|
||||||
|
|
||||||
# ignore the Anaconda/Miniconda installer used while building Docker image
|
# ignore the Anaconda/Miniconda installer used while building Docker image
|
||||||
anaconda.sh
|
anaconda.sh
|
||||||
|
@ -148,6 +148,11 @@ not supported.
|
|||||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||||
```
|
```
|
||||||
|
|
||||||
|
_For non-GPU systems:_
|
||||||
|
```terminal
|
||||||
|
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
|
```
|
||||||
|
|
||||||
_For Macintoshes, either Intel or M1/M2:_
|
_For Macintoshes, either Intel or M1/M2:_
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
|
@ -32,7 +32,7 @@ turned on and off on the command line using `--nsfw_checker` and
|
|||||||
At installation time, InvokeAI will ask whether the checker should be
|
At installation time, InvokeAI will ask whether the checker should be
|
||||||
activated by default (neither argument given on the command line). The
|
activated by default (neither argument given on the command line). The
|
||||||
response is stored in the InvokeAI initialization file (usually
|
response is stored in the InvokeAI initialization file (usually
|
||||||
`.invokeai` in your home directory). You can change the default at any
|
`invokeai.init` in your home directory). You can change the default at any
|
||||||
time by opening this file in a text editor and commenting or
|
time by opening this file in a text editor and commenting or
|
||||||
uncommenting the line `--nsfw_checker`.
|
uncommenting the line `--nsfw_checker`.
|
||||||
|
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
import os
|
import os
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
|
from invokeai.app.services.metadata import PngMetadataService, MetadataServiceBase
|
||||||
|
|
||||||
from ..services.default_graphs import create_system_graphs
|
from ..services.default_graphs import create_system_graphs
|
||||||
|
|
||||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
@ -60,7 +62,9 @@ class ApiDependencies:
|
|||||||
|
|
||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
|
||||||
|
|
||||||
images = DiskImageStorage(f'{output_folder}/images')
|
metadata = PngMetadataService()
|
||||||
|
|
||||||
|
images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
|
||||||
|
|
||||||
# TODO: build a file/path manager?
|
# TODO: build a file/path manager?
|
||||||
db_location = os.path.join(output_folder, "invokeai.db")
|
db_location = os.path.join(output_folder, "invokeai.db")
|
||||||
@ -70,6 +74,7 @@ class ApiDependencies:
|
|||||||
events=events,
|
events=events,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
images=images,
|
images=images,
|
||||||
|
metadata=metadata,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](
|
graph_library=SqliteItemStorage[LibraryGraph](
|
||||||
filename=db_location, table_name="graphs"
|
filename=db_location, table_name="graphs"
|
||||||
|
@ -1,7 +1,19 @@
|
|||||||
|
from typing import Optional
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageType
|
from invokeai.app.models.image import ImageType
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
from invokeai.app.services.metadata import InvokeAIMetadata
|
||||||
|
|
||||||
|
|
||||||
|
class ImageResponseMetadata(BaseModel):
|
||||||
|
"""An image's metadata. Used only in HTTP responses."""
|
||||||
|
|
||||||
|
created: int = Field(description="The creation timestamp of the image")
|
||||||
|
width: int = Field(description="The width of the image in pixels")
|
||||||
|
height: int = Field(description="The height of the image in pixels")
|
||||||
|
invokeai: Optional[InvokeAIMetadata] = Field(
|
||||||
|
description="The image's InvokeAI-specific metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageResponse(BaseModel):
|
class ImageResponse(BaseModel):
|
||||||
@ -11,4 +23,12 @@ class ImageResponse(BaseModel):
|
|||||||
image_name: str = Field(description="The name of the image")
|
image_name: str = Field(description="The name of the image")
|
||||||
image_url: str = Field(description="The url of the image")
|
image_url: str = Field(description="The url of the image")
|
||||||
thumbnail_url: str = Field(description="The url of the image's thumbnail")
|
thumbnail_url: str = Field(description="The url of the image's thumbnail")
|
||||||
metadata: ImageMetadata = Field(description="The image's metadata")
|
metadata: ImageResponseMetadata = Field(description="The image's metadata")
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressImage(BaseModel):
|
||||||
|
"""The progress image sent intermittently during processing"""
|
||||||
|
|
||||||
|
width: int = Field(description="The effective width of the image in pixels")
|
||||||
|
height: int = Field(description="The effective height of the image in pixels")
|
||||||
|
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||||
|
@ -1,13 +1,17 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
import io
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import Path, Query, Request, UploadFile
|
from fastapi import HTTPException, Path, Query, Request, UploadFile
|
||||||
from fastapi.responses import FileResponse, Response
|
from fastapi.responses import FileResponse, Response
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from invokeai.app.api.models.images import ImageResponse
|
from invokeai.app.api.models.images import ImageResponse, ImageResponseMetadata
|
||||||
|
from invokeai.app.services.metadata import InvokeAIMetadata
|
||||||
from invokeai.app.services.item_storage import PaginatedResults
|
from invokeai.app.services.item_storage import PaginatedResults
|
||||||
|
|
||||||
from ...services.image_storage import ImageType
|
from ...services.image_storage import ImageType
|
||||||
@ -15,70 +19,110 @@ from ..dependencies import ApiDependencies
|
|||||||
|
|
||||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||||
|
|
||||||
|
|
||||||
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
||||||
async def get_image(
|
async def get_image(
|
||||||
image_type: ImageType = Path(description="The type of image to get"),
|
image_type: ImageType = Path(description="The type of image to get"),
|
||||||
image_name: str = Path(description="The name of the image to get"),
|
image_name: str = Path(description="The name of the image to get"),
|
||||||
):
|
) -> FileResponse | Response:
|
||||||
"""Gets a result"""
|
"""Gets a result"""
|
||||||
# TODO: This is not really secure at all. At least make sure only output results are served
|
|
||||||
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
|
|
||||||
return FileResponse(filename)
|
|
||||||
|
|
||||||
@images_router.get("/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail")
|
path = ApiDependencies.invoker.services.images.get_path(
|
||||||
|
image_type=image_type, image_name=image_name
|
||||||
|
)
|
||||||
|
|
||||||
|
if ApiDependencies.invoker.services.images.validate_path(path):
|
||||||
|
return FileResponse(path)
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
|
@images_router.get(
|
||||||
|
"/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail"
|
||||||
|
)
|
||||||
async def get_thumbnail(
|
async def get_thumbnail(
|
||||||
image_type: ImageType = Path(description="The type of image to get"),
|
image_type: ImageType = Path(description="The type of image to get"),
|
||||||
image_name: str = Path(description="The name of the image to get"),
|
image_name: str = Path(description="The name of the image to get"),
|
||||||
):
|
) -> FileResponse | Response:
|
||||||
"""Gets a thumbnail"""
|
"""Gets a thumbnail"""
|
||||||
# TODO: This is not really secure at all. At least make sure only output results are served
|
|
||||||
filename = ApiDependencies.invoker.services.images.get_path(image_type, 'thumbnails/' + image_name)
|
path = ApiDependencies.invoker.services.images.get_path(
|
||||||
return FileResponse(filename)
|
image_type=image_type, image_name=image_name, is_thumbnail=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if ApiDependencies.invoker.services.images.validate_path(path):
|
||||||
|
return FileResponse(path)
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
@images_router.post(
|
@images_router.post(
|
||||||
"/uploads/",
|
"/uploads/",
|
||||||
operation_id="upload_image",
|
operation_id="upload_image",
|
||||||
responses={
|
responses={
|
||||||
201: {"description": "The image was uploaded successfully"},
|
201: {
|
||||||
404: {"description": "Session not found"},
|
"description": "The image was uploaded successfully",
|
||||||
|
"model": ImageResponse,
|
||||||
|
},
|
||||||
|
415: {"description": "Image upload failed"},
|
||||||
},
|
},
|
||||||
|
status_code=201,
|
||||||
)
|
)
|
||||||
async def upload_image(file: UploadFile, request: Request):
|
async def upload_image(
|
||||||
|
file: UploadFile, request: Request, response: Response
|
||||||
|
) -> ImageResponse:
|
||||||
if not file.content_type.startswith("image"):
|
if not file.content_type.startswith("image"):
|
||||||
return Response(status_code=415)
|
raise HTTPException(status_code=415, detail="Not an image")
|
||||||
|
|
||||||
contents = await file.read()
|
contents = await file.read()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
im = Image.open(contents)
|
img = Image.open(io.BytesIO(contents))
|
||||||
except:
|
except:
|
||||||
# Error opening the image
|
# Error opening the image
|
||||||
return Response(status_code=415)
|
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||||
|
|
||||||
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||||
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
|
|
||||||
|
|
||||||
return Response(
|
(image_path, thumbnail_path, ctime) = ApiDependencies.invoker.services.images.save(
|
||||||
status_code=201,
|
ImageType.UPLOAD, filename, img
|
||||||
headers={
|
|
||||||
"Location": request.url_for(
|
|
||||||
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
|
||||||
|
|
||||||
|
res = ImageResponse(
|
||||||
|
image_type=ImageType.UPLOAD,
|
||||||
|
image_name=filename,
|
||||||
|
image_url=f"api/v1/images/{ImageType.UPLOAD.value}/{filename}",
|
||||||
|
thumbnail_url=f"api/v1/images/{ImageType.UPLOAD.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
||||||
|
metadata=ImageResponseMetadata(
|
||||||
|
created=ctime,
|
||||||
|
width=img.width,
|
||||||
|
height=img.height,
|
||||||
|
invokeai=invokeai_metadata,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
response.status_code = 201
|
||||||
|
response.headers["Location"] = request.url_for(
|
||||||
|
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
|
||||||
|
)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_images",
|
operation_id="list_images",
|
||||||
responses={200: {"model": PaginatedResults[ImageResponse]}},
|
responses={200: {"model": PaginatedResults[ImageResponse]}},
|
||||||
)
|
)
|
||||||
async def list_images(
|
async def list_images(
|
||||||
image_type: ImageType = Query(default=ImageType.RESULT, description="The type of images to get"),
|
image_type: ImageType = Query(
|
||||||
|
default=ImageType.RESULT, description="The type of images to get"
|
||||||
|
),
|
||||||
page: int = Query(default=0, description="The page of images to get"),
|
page: int = Query(default=0, description="The page of images to get"),
|
||||||
per_page: int = Query(default=10, description="The number of images per page"),
|
per_page: int = Query(default=10, description="The number of images per page"),
|
||||||
) -> PaginatedResults[ImageResponse]:
|
) -> PaginatedResults[ImageResponse]:
|
||||||
"""Gets a list of images"""
|
"""Gets a list of images"""
|
||||||
result = ApiDependencies.invoker.services.images.list(
|
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
|
||||||
image_type, page, per_page
|
|
||||||
)
|
|
||||||
return result
|
return result
|
||||||
|
@ -13,6 +13,8 @@ from typing import (
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
|
from invokeai.app.services.metadata import PngMetadataService
|
||||||
|
|
||||||
from .services.default_graphs import create_system_graphs
|
from .services.default_graphs import create_system_graphs
|
||||||
|
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
@ -200,6 +202,8 @@ def invoke_cli():
|
|||||||
|
|
||||||
events = EventServiceBase()
|
events = EventServiceBase()
|
||||||
|
|
||||||
|
metadata = PngMetadataService()
|
||||||
|
|
||||||
output_folder = os.path.abspath(
|
output_folder = os.path.abspath(
|
||||||
os.path.join(os.path.dirname(__file__), "../../../outputs")
|
os.path.join(os.path.dirname(__file__), "../../../outputs")
|
||||||
)
|
)
|
||||||
@ -211,7 +215,8 @@ def invoke_cli():
|
|||||||
model_manager=model_manager,
|
model_manager=model_manager,
|
||||||
events=events,
|
events=events,
|
||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||||
images=DiskImageStorage(f'{output_folder}/images'),
|
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
|
||||||
|
metadata=metadata,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](
|
graph_library=SqliteItemStorage[LibraryGraph](
|
||||||
filename=db_location, table_name="graphs"
|
filename=db_location, table_name="graphs"
|
||||||
|
@ -95,7 +95,7 @@ class UIConfig(TypedDict, total=False):
|
|||||||
],
|
],
|
||||||
]
|
]
|
||||||
tags: List[str]
|
tags: List[str]
|
||||||
|
title: str
|
||||||
|
|
||||||
class CustomisedSchemaExtra(TypedDict):
|
class CustomisedSchemaExtra(TypedDict):
|
||||||
ui: UIConfig
|
ui: UIConfig
|
||||||
|
@ -1,16 +1,17 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Literal
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import cv2 as cv
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.random
|
import numpy.random
|
||||||
from PIL import Image, ImageOps
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from ..services.image_storage import ImageType
|
from .baseinvocation import (
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
|
BaseInvocation,
|
||||||
from .image import ImageField, ImageOutput
|
InvocationConfig,
|
||||||
|
InvocationContext,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class IntCollectionOutput(BaseInvocationOutput):
|
class IntCollectionOutput(BaseInvocationOutput):
|
||||||
@ -33,7 +34,9 @@ class RangeInvocation(BaseInvocation):
|
|||||||
step: int = Field(default=1, description="The step of the range")
|
step: int = Field(default=1, description="The step of the range")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||||
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
return IntCollectionOutput(
|
||||||
|
collection=list(range(self.start, self.stop, self.step))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RandomRangeInvocation(BaseInvocation):
|
class RandomRangeInvocation(BaseInvocation):
|
||||||
@ -43,8 +46,19 @@ class RandomRangeInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
low: int = Field(default=0, description="The inclusive low value")
|
low: int = Field(default=0, description="The inclusive low value")
|
||||||
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
high: int = Field(
|
||||||
|
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||||
|
)
|
||||||
size: int = Field(default=1, description="The number of values to generate")
|
size: int = Field(default=1, description="The number of values to generate")
|
||||||
|
seed: Optional[int] = Field(
|
||||||
|
ge=0,
|
||||||
|
le=np.iinfo(np.int32).max,
|
||||||
|
description="The seed for the RNG",
|
||||||
|
default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||||
return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size)))
|
rng = np.random.default_rng(self.seed)
|
||||||
|
return IntCollectionOutput(
|
||||||
|
collection=list(rng.integers(low=self.low, high=self.high, size=self.size))
|
||||||
|
)
|
||||||
|
@ -9,7 +9,7 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
from invokeai.app.models.image import ImageField, ImageType
|
from invokeai.app.models.image import ImageField, ImageType
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput, build_image_output
|
||||||
|
|
||||||
|
|
||||||
class CvInvocationConfig(BaseModel):
|
class CvInvocationConfig(BaseModel):
|
||||||
@ -56,7 +56,14 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
context.services.images.save(image_type, image_name, image_inpainted)
|
|
||||||
return ImageOutput(
|
metadata = context.services.metadata.build_metadata(
|
||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
session_id=context.graph_execution_state_id, node=self
|
||||||
)
|
)
|
||||||
|
|
||||||
|
context.services.images.save(image_type, image_name, image_inpainted, metadata)
|
||||||
|
return build_image_output(
|
||||||
|
image_type=image_type,
|
||||||
|
image_name=image_name,
|
||||||
|
image=image_inpainted,
|
||||||
|
)
|
@ -9,13 +9,12 @@ from torch import Tensor
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageField, ImageType
|
from invokeai.app.models.image import ImageField, ImageType
|
||||||
from invokeai.app.invocations.util.get_model import choose_model
|
from invokeai.app.invocations.util.choose_model import choose_model
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput, build_image_output
|
||||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ..models.exceptions import CanceledException
|
from ..util.step_callback import stable_diffusion_step_callback
|
||||||
from ..util.step_callback import diffusers_step_callback_adapter
|
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||||
|
|
||||||
@ -58,28 +57,31 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
|
|
||||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
self,
|
||||||
|
context: InvocationContext,
|
||||||
|
source_node_id: str,
|
||||||
|
intermediate_state: PipelineIntermediateState,
|
||||||
) -> None:
|
) -> None:
|
||||||
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
stable_diffusion_step_callback(
|
||||||
raise CanceledException
|
context=context,
|
||||||
|
intermediate_state=intermediate_state,
|
||||||
step = intermediate_state.step
|
node=self.dict(),
|
||||||
if intermediate_state.predicted_original is not None:
|
source_node_id=source_node_id,
|
||||||
# Some schedulers report not only the noisy latents at the current timestep,
|
)
|
||||||
# but also their estimate so far of what the de-noised latents will be.
|
|
||||||
sample = intermediate_state.predicted_original
|
|
||||||
else:
|
|
||||||
sample = intermediate_state.latents
|
|
||||||
|
|
||||||
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
model = choose_model(context.services.model_manager, self.model)
|
model = choose_model(context.services.model_manager, self.model)
|
||||||
|
|
||||||
|
# Get the source node id (we are invoking the prepared node)
|
||||||
|
graph_execution_state = context.services.graph_execution_manager.get(
|
||||||
|
context.graph_execution_state_id
|
||||||
|
)
|
||||||
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
outputs = Txt2Img(model).generate(
|
outputs = Txt2Img(model).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
step_callback=partial(self.dispatch_progress, context),
|
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt"}
|
exclude={"prompt"}
|
||||||
), # Shorthand for passing all of the parameters above manually
|
), # Shorthand for passing all of the parameters above manually
|
||||||
@ -95,9 +97,18 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
context.services.images.save(image_type, image_name, generate_output.image)
|
|
||||||
return ImageOutput(
|
metadata = context.services.metadata.build_metadata(
|
||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
session_id=context.graph_execution_state_id, node=self
|
||||||
|
)
|
||||||
|
|
||||||
|
context.services.images.save(
|
||||||
|
image_type, image_name, generate_output.image, metadata
|
||||||
|
)
|
||||||
|
return build_image_output(
|
||||||
|
image_type=image_type,
|
||||||
|
image_name=image_name,
|
||||||
|
image=generate_output.image,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -117,20 +128,17 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
self,
|
||||||
) -> None:
|
context: InvocationContext,
|
||||||
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
source_node_id: str,
|
||||||
raise CanceledException
|
intermediate_state: PipelineIntermediateState,
|
||||||
|
) -> None:
|
||||||
step = intermediate_state.step
|
stable_diffusion_step_callback(
|
||||||
if intermediate_state.predicted_original is not None:
|
context=context,
|
||||||
# Some schedulers report not only the noisy latents at the current timestep,
|
intermediate_state=intermediate_state,
|
||||||
# but also their estimate so far of what the de-noised latents will be.
|
node=self.dict(),
|
||||||
sample = intermediate_state.predicted_original
|
source_node_id=source_node_id,
|
||||||
else:
|
)
|
||||||
sample = intermediate_state.latents
|
|
||||||
|
|
||||||
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = (
|
image = (
|
||||||
@ -145,15 +153,21 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
model = choose_model(context.services.model_manager, self.model)
|
model = choose_model(context.services.model_manager, self.model)
|
||||||
|
|
||||||
|
# Get the source node id (we are invoking the prepared node)
|
||||||
|
graph_execution_state = context.services.graph_execution_manager.get(
|
||||||
|
context.graph_execution_state_id
|
||||||
|
)
|
||||||
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
outputs = Img2Img(model).generate(
|
outputs = Img2Img(model).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
init_image=image,
|
init_image=image,
|
||||||
init_mask=mask,
|
init_mask=mask,
|
||||||
step_callback=partial(self.dispatch_progress, context),
|
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt", "image", "mask"}
|
exclude={"prompt", "image", "mask"}
|
||||||
), # Shorthand for passing all of the parameters above manually
|
), # Shorthand for passing all of the parameters above manually
|
||||||
)
|
)
|
||||||
|
|
||||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||||
# each time it is called. We only need the first one.
|
# each time it is called. We only need the first one.
|
||||||
@ -168,11 +182,19 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
context.services.images.save(image_type, image_name, result_image)
|
|
||||||
return ImageOutput(
|
metadata = context.services.metadata.build_metadata(
|
||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
session_id=context.graph_execution_state_id, node=self
|
||||||
)
|
)
|
||||||
|
|
||||||
|
context.services.images.save(image_type, image_name, result_image, metadata)
|
||||||
|
return build_image_output(
|
||||||
|
image_type=image_type,
|
||||||
|
image_name=image_name,
|
||||||
|
image=result_image,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class InpaintInvocation(ImageToImageInvocation):
|
class InpaintInvocation(ImageToImageInvocation):
|
||||||
"""Generates an image using inpaint."""
|
"""Generates an image using inpaint."""
|
||||||
|
|
||||||
@ -188,20 +210,17 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
self,
|
||||||
) -> None:
|
context: InvocationContext,
|
||||||
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
source_node_id: str,
|
||||||
raise CanceledException
|
intermediate_state: PipelineIntermediateState,
|
||||||
|
) -> None:
|
||||||
step = intermediate_state.step
|
stable_diffusion_step_callback(
|
||||||
if intermediate_state.predicted_original is not None:
|
context=context,
|
||||||
# Some schedulers report not only the noisy latents at the current timestep,
|
intermediate_state=intermediate_state,
|
||||||
# but also their estimate so far of what the de-noised latents will be.
|
node=self.dict(),
|
||||||
sample = intermediate_state.predicted_original
|
source_node_id=source_node_id,
|
||||||
else:
|
)
|
||||||
sample = intermediate_state.latents
|
|
||||||
|
|
||||||
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = (
|
image = (
|
||||||
@ -218,17 +237,23 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
model = choose_model(context.services.model_manager, self.model)
|
model = choose_model(context.services.model_manager, self.model)
|
||||||
|
|
||||||
|
# Get the source node id (we are invoking the prepared node)
|
||||||
|
graph_execution_state = context.services.graph_execution_manager.get(
|
||||||
|
context.graph_execution_state_id
|
||||||
|
)
|
||||||
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
outputs = Inpaint(model).generate(
|
outputs = Inpaint(model).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
init_img=image,
|
init_img=image,
|
||||||
init_mask=mask,
|
init_mask=mask,
|
||||||
step_callback=partial(self.dispatch_progress, context),
|
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt", "image", "mask"}
|
exclude={"prompt", "image", "mask"}
|
||||||
), # Shorthand for passing all of the parameters above manually
|
), # Shorthand for passing all of the parameters above manually
|
||||||
)
|
)
|
||||||
|
|
||||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||||
# each time it is called. We only need the first one.
|
# each time it is called. We only need the first one.
|
||||||
@ -243,7 +268,14 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
context.services.images.save(image_type, image_name, result_image)
|
|
||||||
return ImageOutput(
|
metadata = context.services.metadata.build_metadata(
|
||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
session_id=context.graph_execution_state_id, node=self
|
||||||
|
)
|
||||||
|
|
||||||
|
context.services.images.save(image_type, image_name, result_image, metadata)
|
||||||
|
return build_image_output(
|
||||||
|
image_type=image_type,
|
||||||
|
image_name=image_name,
|
||||||
|
image=result_image,
|
||||||
)
|
)
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
@ -8,8 +7,12 @@ from PIL import Image, ImageFilter, ImageOps
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from ..models.image import ImageField, ImageType
|
from ..models.image import ImageField, ImageType
|
||||||
from ..services.invocation_services import InvocationServices
|
from .baseinvocation import (
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
InvocationContext,
|
||||||
|
InvocationConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PILInvocationConfig(BaseModel):
|
class PILInvocationConfig(BaseModel):
|
||||||
@ -22,50 +25,73 @@ class PILInvocationConfig(BaseModel):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ImageOutput(BaseInvocationOutput):
|
class ImageOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output an image"""
|
"""Base class for invocations that output an image"""
|
||||||
#fmt: off
|
|
||||||
|
# fmt: off
|
||||||
type: Literal["image"] = "image"
|
type: Literal["image"] = "image"
|
||||||
image: ImageField = Field(default=None, description="The output image")
|
image: ImageField = Field(default=None, description="The output image")
|
||||||
#fmt: on
|
width: Optional[int] = Field(default=None, description="The width of the image in pixels")
|
||||||
|
height: Optional[int] = Field(default=None, description="The height of the image in pixels")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
'required': [
|
"required": ["type", "image", "width", "height", "mode"]
|
||||||
'type',
|
|
||||||
'image',
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_image_output(
|
||||||
|
image_type: ImageType, image_name: str, image: Image.Image
|
||||||
|
) -> ImageOutput:
|
||||||
|
"""Builds an ImageOutput and its ImageField"""
|
||||||
|
image_field = ImageField(
|
||||||
|
image_name=image_name,
|
||||||
|
image_type=image_type,
|
||||||
|
)
|
||||||
|
return ImageOutput(
|
||||||
|
image=image_field,
|
||||||
|
width=image.width,
|
||||||
|
height=image.height,
|
||||||
|
mode=image.mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MaskOutput(BaseInvocationOutput):
|
class MaskOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output a mask"""
|
"""Base class for invocations that output a mask"""
|
||||||
#fmt: off
|
|
||||||
|
# fmt: off
|
||||||
type: Literal["mask"] = "mask"
|
type: Literal["mask"] = "mask"
|
||||||
mask: ImageField = Field(default=None, description="The output mask")
|
mask: ImageField = Field(default=None, description="The output mask")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
'required': [
|
"required": [
|
||||||
'type',
|
"type",
|
||||||
'mask',
|
"mask",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
# TODO: this isn't really necessary anymore
|
|
||||||
class LoadImageInvocation(BaseInvocation):
|
class LoadImageInvocation(BaseInvocation):
|
||||||
"""Load an image from a filename and provide it as output."""
|
"""Load an image and provide it as output."""
|
||||||
#fmt: off
|
|
||||||
|
# fmt: off
|
||||||
type: Literal["load_image"] = "load_image"
|
type: Literal["load_image"] = "load_image"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image_type: ImageType = Field(description="The type of the image")
|
image_type: ImageType = Field(description="The type of the image")
|
||||||
image_name: str = Field(description="The name of the image")
|
image_name: str = Field(description="The name of the image")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
return ImageOutput(
|
image = context.services.images.get(self.image_type, self.image_name)
|
||||||
image=ImageField(image_type=self.image_type, image_name=self.image_name)
|
|
||||||
|
return build_image_output(
|
||||||
|
image_type=self.image_type,
|
||||||
|
image_name=self.image_name,
|
||||||
|
image=image,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -86,16 +112,17 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# TODO: how to handle failure?
|
# TODO: how to handle failure?
|
||||||
|
|
||||||
return ImageOutput(
|
return build_image_output(
|
||||||
image=ImageField(
|
image_type=self.image.image_type,
|
||||||
image_type=self.image.image_type, image_name=self.image.image_name
|
image_name=self.image.image_name,
|
||||||
)
|
image=image,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||||
#fmt: off
|
|
||||||
|
# fmt: off
|
||||||
type: Literal["crop"] = "crop"
|
type: Literal["crop"] = "crop"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
@ -104,7 +131,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
||||||
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
||||||
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
|
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
@ -120,15 +147,23 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
context.services.images.save(image_type, image_name, image_crop)
|
|
||||||
return ImageOutput(
|
metadata = context.services.metadata.build_metadata(
|
||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
session_id=context.graph_execution_state_id, node=self
|
||||||
|
)
|
||||||
|
|
||||||
|
context.services.images.save(image_type, image_name, image_crop, metadata)
|
||||||
|
return build_image_output(
|
||||||
|
image_type=image_type,
|
||||||
|
image_name=image_name,
|
||||||
|
image=image_crop,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Pastes an image into another image."""
|
"""Pastes an image into another image."""
|
||||||
#fmt: off
|
|
||||||
|
# fmt: off
|
||||||
type: Literal["paste"] = "paste"
|
type: Literal["paste"] = "paste"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
@ -137,7 +172,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
||||||
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
||||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
base_image = context.services.images.get(
|
base_image = context.services.images.get(
|
||||||
@ -170,21 +205,29 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
context.services.images.save(image_type, image_name, new_image)
|
|
||||||
return ImageOutput(
|
metadata = context.services.metadata.build_metadata(
|
||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
session_id=context.graph_execution_state_id, node=self
|
||||||
|
)
|
||||||
|
|
||||||
|
context.services.images.save(image_type, image_name, new_image, metadata)
|
||||||
|
return build_image_output(
|
||||||
|
image_type=image_type,
|
||||||
|
image_name=image_name,
|
||||||
|
image=new_image,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Extracts the alpha channel of an image as a mask."""
|
"""Extracts the alpha channel of an image as a mask."""
|
||||||
#fmt: off
|
|
||||||
|
# fmt: off
|
||||||
type: Literal["tomask"] = "tomask"
|
type: Literal["tomask"] = "tomask"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to create the mask from")
|
image: ImageField = Field(default=None, description="The image to create the mask from")
|
||||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
@ -199,22 +242,27 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
context.services.images.save(image_type, image_name, image_mask)
|
|
||||||
|
metadata = context.services.metadata.build_metadata(
|
||||||
|
session_id=context.graph_execution_state_id, node=self
|
||||||
|
)
|
||||||
|
|
||||||
|
context.services.images.save(image_type, image_name, image_mask, metadata)
|
||||||
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
|
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
|
||||||
|
|
||||||
|
|
||||||
class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Blurs an image"""
|
"""Blurs an image"""
|
||||||
|
|
||||||
#fmt: off
|
# fmt: off
|
||||||
type: Literal["blur"] = "blur"
|
type: Literal["blur"] = "blur"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to blur")
|
image: ImageField = Field(default=None, description="The image to blur")
|
||||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
||||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
@ -231,22 +279,28 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
context.services.images.save(image_type, image_name, blur_image)
|
|
||||||
return ImageOutput(
|
metadata = context.services.metadata.build_metadata(
|
||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
session_id=context.graph_execution_state_id, node=self
|
||||||
|
)
|
||||||
|
|
||||||
|
context.services.images.save(image_type, image_name, blur_image, metadata)
|
||||||
|
return build_image_output(
|
||||||
|
image_type=image_type, image_name=image_name, image=blur_image
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LerpInvocation(BaseInvocation, PILInvocationConfig):
|
class LerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Linear interpolation of all pixels of an image"""
|
"""Linear interpolation of all pixels of an image"""
|
||||||
#fmt: off
|
|
||||||
|
# fmt: off
|
||||||
type: Literal["lerp"] = "lerp"
|
type: Literal["lerp"] = "lerp"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to lerp")
|
image: ImageField = Field(default=None, description="The image to lerp")
|
||||||
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
||||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
@ -262,23 +316,29 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
context.services.images.save(image_type, image_name, lerp_image)
|
|
||||||
return ImageOutput(
|
metadata = context.services.metadata.build_metadata(
|
||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
session_id=context.graph_execution_state_id, node=self
|
||||||
|
)
|
||||||
|
|
||||||
|
context.services.images.save(image_type, image_name, lerp_image, metadata)
|
||||||
|
return build_image_output(
|
||||||
|
image_type=image_type, image_name=image_name, image=lerp_image
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Inverse linear interpolation of all pixels of an image"""
|
"""Inverse linear interpolation of all pixels of an image"""
|
||||||
#fmt: off
|
|
||||||
|
# fmt: off
|
||||||
type: Literal["ilerp"] = "ilerp"
|
type: Literal["ilerp"] = "ilerp"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to lerp")
|
image: ImageField = Field(default=None, description="The image to lerp")
|
||||||
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
||||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
@ -298,7 +358,12 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
context.services.images.save(image_type, image_name, ilerp_image)
|
|
||||||
return ImageOutput(
|
metadata = context.services.metadata.build_metadata(
|
||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
session_id=context.graph_execution_state_id, node=self
|
||||||
|
)
|
||||||
|
|
||||||
|
context.services.images.save(image_type, image_name, ilerp_image, metadata)
|
||||||
|
return build_image_output(
|
||||||
|
image_type=image_type, image_name=image_name, image=ilerp_image
|
||||||
)
|
)
|
||||||
|
@ -5,9 +5,9 @@ from typing import Literal, Optional
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from invokeai.app.models.exceptions import CanceledException
|
from invokeai.app.invocations.util.choose_model import choose_model
|
||||||
from invokeai.app.invocations.util.get_model import choose_model
|
|
||||||
from invokeai.app.util.step_callback import diffusers_step_callback_adapter
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
|
|
||||||
from ...backend.model_management.model_manager import ModelManager
|
from ...backend.model_management.model_manager import ModelManager
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||||
@ -19,7 +19,7 @@ from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationCont
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from ..services.image_storage import ImageType
|
from ..services.image_storage import ImageType
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput
|
from .image import ImageField, ImageOutput, build_image_output
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
import diffusers
|
import diffusers
|
||||||
@ -31,6 +31,8 @@ class LatentsField(BaseModel):
|
|||||||
|
|
||||||
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {"required": ["latents_name"]}
|
||||||
|
|
||||||
class LatentsOutput(BaseInvocationOutput):
|
class LatentsOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output latents"""
|
"""Base class for invocations that output latents"""
|
||||||
@ -170,22 +172,15 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
|
||||||
) -> None:
|
) -> None:
|
||||||
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
stable_diffusion_step_callback(
|
||||||
raise CanceledException
|
context=context,
|
||||||
|
intermediate_state=intermediate_state,
|
||||||
|
node=self.dict(),
|
||||||
|
source_node_id=source_node_id,
|
||||||
|
)
|
||||||
|
|
||||||
step = intermediate_state.step
|
|
||||||
if intermediate_state.predicted_original is not None:
|
|
||||||
# Some schedulers report not only the noisy latents at the current timestep,
|
|
||||||
# but also their estimate so far of what the de-noised latents will be.
|
|
||||||
sample = intermediate_state.predicted_original
|
|
||||||
else:
|
|
||||||
sample = intermediate_state.latents
|
|
||||||
|
|
||||||
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
|
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
|
||||||
model_info = choose_model(model_manager, self.model)
|
model_info = choose_model(model_manager, self.model)
|
||||||
model_name = model_info['model_name']
|
model_name = model_info['model_name']
|
||||||
@ -195,7 +190,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
model=model,
|
model=model,
|
||||||
scheduler_name=self.scheduler
|
scheduler_name=self.scheduler
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(model, DiffusionPipeline):
|
if isinstance(model, DiffusionPipeline):
|
||||||
for component in [model.unet, model.vae]:
|
for component in [model.unet, model.vae]:
|
||||||
configure_model_padding(component,
|
configure_model_padding(component,
|
||||||
@ -231,8 +226,12 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
|
|
||||||
|
# Get the source node id (we are invoking the prepared node)
|
||||||
|
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||||
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
def step_callback(state: PipelineIntermediateState):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
self.dispatch_progress(context, state)
|
self.dispatch_progress(context, source_node_id, state)
|
||||||
|
|
||||||
model = self.get_model(context.services.model_manager)
|
model = self.get_model(context.services.model_manager)
|
||||||
conditioning_data = self.get_conditioning_data(model)
|
conditioning_data = self.get_conditioning_data(model)
|
||||||
@ -281,8 +280,12 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
noise = context.services.latents.get(self.noise.latents_name)
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
latent = context.services.latents.get(self.latents.latents_name)
|
latent = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
|
# Get the source node id (we are invoking the prepared node)
|
||||||
|
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||||
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
def step_callback(state: PipelineIntermediateState):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
self.dispatch_progress(context, state)
|
self.dispatch_progress(context, source_node_id, state)
|
||||||
|
|
||||||
model = self.get_model(context.services.model_manager)
|
model = self.get_model(context.services.model_manager)
|
||||||
conditioning_data = self.get_conditioning_data(model)
|
conditioning_data = self.get_conditioning_data(model)
|
||||||
@ -292,57 +295,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||||
latent, device=model.device, dtype=latent.dtype
|
latent, device=model.device, dtype=latent.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
timesteps, _ = model.get_img2img_timesteps(
|
|
||||||
self.steps,
|
|
||||||
self.strength,
|
|
||||||
device=model.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
|
||||||
latents=initial_latents,
|
|
||||||
timesteps=timesteps,
|
|
||||||
noise=noise,
|
|
||||||
num_inference_steps=self.steps,
|
|
||||||
conditioning_data=conditioning_data,
|
|
||||||
callback=step_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
|
||||||
context.services.latents.set(name, result_latents)
|
|
||||||
return LatentsOutput(
|
|
||||||
latents=LatentsField(latents_name=name)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|
||||||
"""Generates latents using latents as base image."""
|
|
||||||
|
|
||||||
type: Literal["l2l"] = "l2l"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
|
||||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
|
||||||
latent = context.services.latents.get(self.latents.latents_name)
|
|
||||||
|
|
||||||
def step_callback(state: PipelineIntermediateState):
|
|
||||||
self.dispatch_progress(context, state)
|
|
||||||
|
|
||||||
model = self.get_model(context.services.model_manager)
|
|
||||||
conditioning_data = self.get_conditioning_data(model)
|
|
||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
|
||||||
|
|
||||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
|
||||||
latent, device=model.device, dtype=latent.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
timesteps, _ = model.get_img2img_timesteps(
|
timesteps, _ = model.get_img2img_timesteps(
|
||||||
self.steps,
|
self.steps,
|
||||||
self.strength,
|
self.strength,
|
||||||
@ -405,7 +358,14 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
context.services.images.save(image_type, image_name, image)
|
|
||||||
return ImageOutput(
|
metadata = context.services.metadata.build_metadata(
|
||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
session_id=context.graph_execution_state_id, node=self
|
||||||
|
)
|
||||||
|
|
||||||
|
context.services.images.save(image_type, image_name, image, metadata)
|
||||||
|
return build_image_output(
|
||||||
|
image_type=image_type,
|
||||||
|
image_name=image_name,
|
||||||
|
image=image
|
||||||
)
|
)
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Literal, Union
|
from typing import Literal, Union
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageField, ImageType
|
from invokeai.app.models.image import ImageField, ImageType
|
||||||
from ..services.invocation_services import InvocationServices
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput, build_image_output
|
||||||
|
|
||||||
class RestoreFaceInvocation(BaseInvocation):
|
class RestoreFaceInvocation(BaseInvocation):
|
||||||
"""Restores faces in an image."""
|
"""Restores faces in an image."""
|
||||||
@ -44,7 +43,14 @@ class RestoreFaceInvocation(BaseInvocation):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
context.services.images.save(image_type, image_name, results[0][0])
|
|
||||||
return ImageOutput(
|
metadata = context.services.metadata.build_metadata(
|
||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
session_id=context.graph_execution_state_id, node=self
|
||||||
)
|
)
|
||||||
|
|
||||||
|
context.services.images.save(image_type, image_name, results[0][0], metadata)
|
||||||
|
return build_image_output(
|
||||||
|
image_type=image_type,
|
||||||
|
image_name=image_name,
|
||||||
|
image=results[0][0]
|
||||||
|
)
|
@ -1,14 +1,12 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Literal, Union
|
from typing import Literal, Union
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageField, ImageType
|
from invokeai.app.models.image import ImageField, ImageType
|
||||||
from ..services.invocation_services import InvocationServices
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput, build_image_output
|
||||||
|
|
||||||
|
|
||||||
class UpscaleInvocation(BaseInvocation):
|
class UpscaleInvocation(BaseInvocation):
|
||||||
@ -49,7 +47,14 @@ class UpscaleInvocation(BaseInvocation):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
context.services.images.save(image_type, image_name, results[0][0])
|
|
||||||
return ImageOutput(
|
metadata = context.services.metadata.build_metadata(
|
||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
session_id=context.graph_execution_state_id, node=self
|
||||||
)
|
)
|
||||||
|
|
||||||
|
context.services.images.save(image_type, image_name, results[0][0], metadata)
|
||||||
|
return build_image_output(
|
||||||
|
image_type=image_type,
|
||||||
|
image_name=image_name,
|
||||||
|
image=results[0][0]
|
||||||
|
)
|
@ -1,11 +1,14 @@
|
|||||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
|
||||||
from invokeai.backend.model_management.model_manager import ModelManager
|
from invokeai.backend.model_management.model_manager import ModelManager
|
||||||
|
|
||||||
|
|
||||||
def choose_model(model_manager: ModelManager, model_name: str):
|
def choose_model(model_manager: ModelManager, model_name: str):
|
||||||
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
||||||
if model_manager.valid_model(model_name):
|
if model_manager.valid_model(model_name):
|
||||||
return model_manager.get_model(model_name)
|
model = model_manager.get_model(model_name)
|
||||||
else:
|
else:
|
||||||
print(f"* Warning: '{model_name}' is not a valid model name. Using default model instead.")
|
model = model_manager.get_model()
|
||||||
return model_manager.get_model()
|
print(
|
||||||
|
f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
@ -9,6 +9,14 @@ class ImageType(str, Enum):
|
|||||||
UPLOAD = "uploads"
|
UPLOAD = "uploads"
|
||||||
|
|
||||||
|
|
||||||
|
def is_image_type(obj):
|
||||||
|
try:
|
||||||
|
ImageType(obj)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class ImageField(BaseModel):
|
class ImageField(BaseModel):
|
||||||
"""An image field used for passing image objects between invocations"""
|
"""An image field used for passing image objects between invocations"""
|
||||||
|
|
||||||
@ -18,9 +26,4 @@ class ImageField(BaseModel):
|
|||||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {"required": ["image_type", "image_name"]}
|
||||||
"required": [
|
|
||||||
"image_type",
|
|
||||||
"image_name",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
@ -1,11 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
class ImageMetadata(BaseModel):
|
|
||||||
"""An image's metadata"""
|
|
||||||
|
|
||||||
timestamp: float = Field(description="The creation timestamp of the image")
|
|
||||||
width: int = Field(description="The width of the image in pixels")
|
|
||||||
height: int = Field(description="The height of the image in pixels")
|
|
||||||
# TODO: figure out metadata
|
|
||||||
sd_metadata: Optional[dict] = Field(default={}, description="The image's SD-specific metadata")
|
|
@ -1,10 +1,9 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Any, Dict, TypedDict
|
from typing import Any
|
||||||
|
from invokeai.app.api.models.images import ProgressImage
|
||||||
|
from invokeai.app.util.misc import get_timestamp
|
||||||
|
|
||||||
ProgressImage = TypedDict(
|
|
||||||
"ProgressImage", {"dataURL": str, "width": int, "height": int}
|
|
||||||
)
|
|
||||||
|
|
||||||
class EventServiceBase:
|
class EventServiceBase:
|
||||||
session_event: str = "session_event"
|
session_event: str = "session_event"
|
||||||
@ -14,7 +13,8 @@ class EventServiceBase:
|
|||||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __emit_session_event(self, event_name: str, payload: Dict) -> None:
|
def __emit_session_event(self, event_name: str, payload: dict) -> None:
|
||||||
|
payload["timestamp"] = get_timestamp()
|
||||||
self.dispatch(
|
self.dispatch(
|
||||||
event_name=EventServiceBase.session_event,
|
event_name=EventServiceBase.session_event,
|
||||||
payload=dict(event=event_name, data=payload),
|
payload=dict(event=event_name, data=payload),
|
||||||
@ -25,7 +25,8 @@ class EventServiceBase:
|
|||||||
def emit_generator_progress(
|
def emit_generator_progress(
|
||||||
self,
|
self,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
invocation_id: str,
|
node: dict,
|
||||||
|
source_node_id: str,
|
||||||
progress_image: ProgressImage | None,
|
progress_image: ProgressImage | None,
|
||||||
step: int,
|
step: int,
|
||||||
total_steps: int,
|
total_steps: int,
|
||||||
@ -35,48 +36,60 @@ class EventServiceBase:
|
|||||||
event_name="generator_progress",
|
event_name="generator_progress",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
invocation_id=invocation_id,
|
node=node,
|
||||||
progress_image=progress_image,
|
source_node_id=source_node_id,
|
||||||
|
progress_image=progress_image.dict() if progress_image is not None else None,
|
||||||
step=step,
|
step=step,
|
||||||
total_steps=total_steps,
|
total_steps=total_steps,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_invocation_complete(
|
def emit_invocation_complete(
|
||||||
self, graph_execution_state_id: str, invocation_id: str, result: Dict
|
self,
|
||||||
|
graph_execution_state_id: str,
|
||||||
|
result: dict,
|
||||||
|
node: dict,
|
||||||
|
source_node_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when an invocation has completed"""
|
"""Emitted when an invocation has completed"""
|
||||||
self.__emit_session_event(
|
self.__emit_session_event(
|
||||||
event_name="invocation_complete",
|
event_name="invocation_complete",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
invocation_id=invocation_id,
|
node=node,
|
||||||
|
source_node_id=source_node_id,
|
||||||
result=result,
|
result=result,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_invocation_error(
|
def emit_invocation_error(
|
||||||
self, graph_execution_state_id: str, invocation_id: str, error: str
|
self,
|
||||||
|
graph_execution_state_id: str,
|
||||||
|
node: dict,
|
||||||
|
source_node_id: str,
|
||||||
|
error: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when an invocation has completed"""
|
"""Emitted when an invocation has completed"""
|
||||||
self.__emit_session_event(
|
self.__emit_session_event(
|
||||||
event_name="invocation_error",
|
event_name="invocation_error",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
invocation_id=invocation_id,
|
node=node,
|
||||||
|
source_node_id=source_node_id,
|
||||||
error=error,
|
error=error,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_invocation_started(
|
def emit_invocation_started(
|
||||||
self, graph_execution_state_id: str, invocation_id: str
|
self, graph_execution_state_id: str, node: dict, source_node_id: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Emitted when an invocation has started"""
|
"""Emitted when an invocation has started"""
|
||||||
self.__emit_session_event(
|
self.__emit_session_event(
|
||||||
event_name="invocation_started",
|
event_name="invocation_started",
|
||||||
payload=dict(
|
payload=dict(
|
||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
invocation_id=invocation_id,
|
node=node,
|
||||||
|
source_node_id=source_node_id,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -84,5 +97,7 @@ class EventServiceBase:
|
|||||||
"""Emitted when a session has completed all invocations"""
|
"""Emitted when a session has completed all invocations"""
|
||||||
self.__emit_session_event(
|
self.__emit_session_event(
|
||||||
event_name="graph_execution_state_complete",
|
event_name="graph_execution_state_complete",
|
||||||
payload=dict(graph_execution_state_id=graph_execution_state_id),
|
payload=dict(
|
||||||
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
import traceback
|
|
||||||
import uuid
|
import uuid
|
||||||
from types import NoneType
|
from types import NoneType
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -26,7 +25,6 @@ from ..invocations.baseinvocation import (
|
|||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
)
|
)
|
||||||
from .invocation_services import InvocationServices
|
|
||||||
|
|
||||||
|
|
||||||
class EdgeConnection(BaseModel):
|
class EdgeConnection(BaseModel):
|
||||||
@ -215,7 +213,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
|
|||||||
|
|
||||||
|
|
||||||
class Graph(BaseModel):
|
class Graph(BaseModel):
|
||||||
id: str = Field(description="The id of this graph", default_factory=uuid.uuid4)
|
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
||||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||||
description="The nodes in this graph", default_factory=dict
|
description="The nodes in this graph", default_factory=dict
|
||||||
@ -750,9 +748,7 @@ class Graph(BaseModel):
|
|||||||
class GraphExecutionState(BaseModel):
|
class GraphExecutionState(BaseModel):
|
||||||
"""Tracks the state of a graph execution"""
|
"""Tracks the state of a graph execution"""
|
||||||
|
|
||||||
id: str = Field(
|
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
|
||||||
description="The id of the execution state", default_factory=uuid.uuid4
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: Store a reference to the graph instead of the actual graph?
|
# TODO: Store a reference to the graph instead of the actual graph?
|
||||||
graph: Graph = Field(description="The graph being executed")
|
graph: Graph = Field(description="The graph being executed")
|
||||||
@ -1171,7 +1167,7 @@ class LibraryGraph(BaseModel):
|
|||||||
if len(v) != len(set(i.alias for i in v)):
|
if len(v) != len(set(i.alias for i in v)):
|
||||||
raise ValueError("Duplicate exposed alias")
|
raise ValueError("Duplicate exposed alias")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@root_validator
|
@root_validator
|
||||||
def validate_exposed_nodes(cls, values):
|
def validate_exposed_nodes(cls, values):
|
||||||
graph = values['graph']
|
graph = values['graph']
|
||||||
|
@ -1,24 +1,24 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import datetime
|
|
||||||
import os
|
import os
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Callable, Dict, List
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
import PIL.Image as PILImage
|
import PIL.Image as PILImage
|
||||||
from pydantic import BaseModel
|
from invokeai.app.api.models.images import ImageResponse, ImageResponseMetadata
|
||||||
from invokeai.app.api.models.images import ImageResponse
|
from invokeai.app.models.image import ImageType
|
||||||
from invokeai.app.models.image import ImageField, ImageType
|
from invokeai.app.services.metadata import (
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
InvokeAIMetadata,
|
||||||
|
MetadataServiceBase,
|
||||||
|
build_invokeai_metadata_pnginfo,
|
||||||
|
)
|
||||||
from invokeai.app.services.item_storage import PaginatedResults
|
from invokeai.app.services.item_storage import PaginatedResults
|
||||||
from invokeai.app.util.save_thumbnail import save_thumbnail
|
from invokeai.app.util.misc import get_timestamp
|
||||||
|
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||||
from invokeai.backend.image_util import PngWriter
|
|
||||||
|
|
||||||
|
|
||||||
class ImageStorageBase(ABC):
|
class ImageStorageBase(ABC):
|
||||||
@ -26,12 +26,14 @@ class ImageStorageBase(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||||
|
"""Retrieves an image as PIL Image."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list(
|
def list(
|
||||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||||
) -> PaginatedResults[ImageResponse]:
|
) -> PaginatedResults[ImageResponse]:
|
||||||
|
"""Gets a paginated list of images."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||||
@ -39,35 +41,51 @@ class ImageStorageBase(ABC):
|
|||||||
def get_path(
|
def get_path(
|
||||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
|
"""Gets the path to an image or its thumbnail."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||||
|
@abstractmethod
|
||||||
|
def validate_path(self, path: str) -> bool:
|
||||||
|
"""Validates an image path."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
def save(
|
||||||
|
self,
|
||||||
|
image_type: ImageType,
|
||||||
|
image_name: str,
|
||||||
|
image: Image,
|
||||||
|
metadata: InvokeAIMetadata | None = None,
|
||||||
|
) -> Tuple[str, str, int]:
|
||||||
|
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image path, thumbnail path, and created timestamp."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||||
|
"""Deletes an image and its thumbnail (if one exists)."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def create_name(self, context_id: str, node_id: str) -> str:
|
def create_name(self, context_id: str, node_id: str) -> str:
|
||||||
return f"{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png"
|
"""Creates a unique contextual image filename."""
|
||||||
|
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
|
||||||
|
|
||||||
|
|
||||||
class DiskImageStorage(ImageStorageBase):
|
class DiskImageStorage(ImageStorageBase):
|
||||||
"""Stores images on disk"""
|
"""Stores images on disk"""
|
||||||
|
|
||||||
__output_folder: str
|
__output_folder: str
|
||||||
__pngWriter: PngWriter
|
|
||||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||||
__cache: Dict[str, Image]
|
__cache: Dict[str, Image]
|
||||||
__max_cache_size: int
|
__max_cache_size: int
|
||||||
|
__metadata_service: MetadataServiceBase
|
||||||
|
|
||||||
def __init__(self, output_folder: str):
|
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
|
||||||
self.__output_folder = output_folder
|
self.__output_folder = output_folder
|
||||||
self.__pngWriter = PngWriter(output_folder)
|
|
||||||
self.__cache = dict()
|
self.__cache = dict()
|
||||||
self.__cache_ids = Queue()
|
self.__cache_ids = Queue()
|
||||||
self.__max_cache_size = 10 # TODO: get this from config
|
self.__max_cache_size = 10 # TODO: get this from config
|
||||||
|
self.__metadata_service = metadata_service
|
||||||
|
|
||||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@ -100,6 +118,9 @@ class DiskImageStorage(ImageStorageBase):
|
|||||||
for path in page_of_image_paths:
|
for path in page_of_image_paths:
|
||||||
filename = os.path.basename(path)
|
filename = os.path.basename(path)
|
||||||
img = PILImage.open(path)
|
img = PILImage.open(path)
|
||||||
|
|
||||||
|
invokeai_metadata = self.__metadata_service.get_metadata(img)
|
||||||
|
|
||||||
page_of_images.append(
|
page_of_images.append(
|
||||||
ImageResponse(
|
ImageResponse(
|
||||||
image_type=image_type.value,
|
image_type=image_type.value,
|
||||||
@ -107,11 +128,12 @@ class DiskImageStorage(ImageStorageBase):
|
|||||||
# TODO: DiskImageStorage should not be building URLs...?
|
# TODO: DiskImageStorage should not be building URLs...?
|
||||||
image_url=f"api/v1/images/{image_type.value}/{filename}",
|
image_url=f"api/v1/images/{image_type.value}/{filename}",
|
||||||
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
||||||
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
|
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
|
||||||
metadata=ImageMetadata(
|
metadata=ImageResponseMetadata(
|
||||||
timestamp=os.path.getctime(path),
|
created=int(os.path.getctime(path)),
|
||||||
width=img.width,
|
width=img.width,
|
||||||
height=img.height,
|
height=img.height,
|
||||||
|
invokeai=invokeai_metadata,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -142,26 +164,50 @@ class DiskImageStorage(ImageStorageBase):
|
|||||||
def get_path(
|
def get_path(
|
||||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
|
# strip out any relative path shenanigans
|
||||||
|
basename = os.path.basename(image_name)
|
||||||
|
|
||||||
if is_thumbnail:
|
if is_thumbnail:
|
||||||
path = os.path.join(
|
path = os.path.join(
|
||||||
self.__output_folder, image_type, "thumbnails", image_name
|
self.__output_folder, image_type, "thumbnails", basename
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
path = os.path.join(self.__output_folder, image_type, basename)
|
||||||
|
|
||||||
return path
|
return path
|
||||||
|
|
||||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
def validate_path(self, path: str) -> bool:
|
||||||
image_subpath = os.path.join(image_type, image_name)
|
try:
|
||||||
self.__pngWriter.save_image_and_prompt_to_png(
|
os.stat(path)
|
||||||
image, "", image_subpath, None
|
return True
|
||||||
) # TODO: just pass full path to png writer
|
except Exception:
|
||||||
save_thumbnail(
|
return False
|
||||||
image=image,
|
|
||||||
filename=image_name,
|
def save(
|
||||||
path=os.path.join(self.__output_folder, image_type, "thumbnails"),
|
self,
|
||||||
)
|
image_type: ImageType,
|
||||||
|
image_name: str,
|
||||||
|
image: Image,
|
||||||
|
metadata: InvokeAIMetadata | None = None,
|
||||||
|
) -> Tuple[str, str, int]:
|
||||||
image_path = self.get_path(image_type, image_name)
|
image_path = self.get_path(image_type, image_name)
|
||||||
|
|
||||||
|
# TODO: Reading the image and then saving it strips the metadata...
|
||||||
|
if metadata:
|
||||||
|
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
|
||||||
|
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||||
|
else:
|
||||||
|
image.save(image_path) # this saved image has an empty info
|
||||||
|
|
||||||
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
|
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
|
||||||
|
thumbnail_image = make_thumbnail(image)
|
||||||
|
thumbnail_image.save(thumbnail_path)
|
||||||
|
|
||||||
self.__set_cache(image_path, image)
|
self.__set_cache(image_path, image)
|
||||||
|
self.__set_cache(thumbnail_path, thumbnail_image)
|
||||||
|
|
||||||
|
return (image_path, thumbnail_path, int(os.path.getctime(image_path)))
|
||||||
|
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||||
image_path = self.get_path(image_type, image_name)
|
image_path = self.get_path(image_type, image_name)
|
||||||
|
@ -1,30 +1,17 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
import time
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
# TODO: make this serializable
|
class InvocationQueueItem(BaseModel):
|
||||||
class InvocationQueueItem:
|
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||||
# session_id: str
|
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||||
graph_execution_state_id: str
|
invoke_all: bool = Field(default=False)
|
||||||
invocation_id: str
|
timestamp: float = Field(default_factory=time.time)
|
||||||
invoke_all: bool
|
|
||||||
timestamp: float
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
# session_id: str,
|
|
||||||
graph_execution_state_id: str,
|
|
||||||
invocation_id: str,
|
|
||||||
invoke_all: bool = False,
|
|
||||||
):
|
|
||||||
# self.session_id = session_id
|
|
||||||
self.graph_execution_state_id = graph_execution_state_id
|
|
||||||
self.invocation_id = invocation_id
|
|
||||||
self.invoke_all = invoke_all
|
|
||||||
self.timestamp = time.time()
|
|
||||||
|
|
||||||
|
|
||||||
class InvocationQueueABC(ABC):
|
class InvocationQueueABC(ABC):
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
from invokeai.app.services.metadata import MetadataServiceBase
|
||||||
from invokeai.backend import ModelManager
|
from invokeai.backend import ModelManager
|
||||||
|
|
||||||
from .events import EventServiceBase
|
from .events import EventServiceBase
|
||||||
@ -14,6 +15,7 @@ class InvocationServices:
|
|||||||
events: EventServiceBase
|
events: EventServiceBase
|
||||||
latents: LatentsStorageBase
|
latents: LatentsStorageBase
|
||||||
images: ImageStorageBase
|
images: ImageStorageBase
|
||||||
|
metadata: MetadataServiceBase
|
||||||
queue: InvocationQueueABC
|
queue: InvocationQueueABC
|
||||||
model_manager: ModelManager
|
model_manager: ModelManager
|
||||||
restoration: RestorationServices
|
restoration: RestorationServices
|
||||||
@ -29,6 +31,7 @@ class InvocationServices:
|
|||||||
events: EventServiceBase,
|
events: EventServiceBase,
|
||||||
latents: LatentsStorageBase,
|
latents: LatentsStorageBase,
|
||||||
images: ImageStorageBase,
|
images: ImageStorageBase,
|
||||||
|
metadata: MetadataServiceBase,
|
||||||
queue: InvocationQueueABC,
|
queue: InvocationQueueABC,
|
||||||
graph_library: ItemStorageABC["LibraryGraph"],
|
graph_library: ItemStorageABC["LibraryGraph"],
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||||
@ -39,6 +42,7 @@ class InvocationServices:
|
|||||||
self.events = events
|
self.events = events
|
||||||
self.latents = latents
|
self.latents = latents
|
||||||
self.images = images
|
self.images = images
|
||||||
|
self.metadata = metadata
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
self.graph_library = graph_library
|
self.graph_library = graph_library
|
||||||
self.graph_execution_manager = graph_execution_manager
|
self.graph_execution_manager = graph_execution_manager
|
||||||
|
96
invokeai/app/services/metadata.py
Normal file
96
invokeai/app/services/metadata.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
import json
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, Optional, TypedDict
|
||||||
|
from PIL import Image, PngImagePlugin
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from invokeai.app.models.image import ImageType, is_image_type
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataImageField(TypedDict):
|
||||||
|
"""Pydantic-less ImageField, used for metadata parsing."""
|
||||||
|
|
||||||
|
image_type: ImageType
|
||||||
|
image_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataLatentsField(TypedDict):
|
||||||
|
"""Pydantic-less LatentsField, used for metadata parsing."""
|
||||||
|
|
||||||
|
latents_name: str
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
|
||||||
|
NodeMetadata = Dict[
|
||||||
|
str, str | int | float | bool | MetadataImageField | MetadataLatentsField
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class InvokeAIMetadata(TypedDict, total=False):
|
||||||
|
"""InvokeAI-specific metadata format."""
|
||||||
|
|
||||||
|
session_id: Optional[str]
|
||||||
|
node: Optional[NodeMetadata]
|
||||||
|
|
||||||
|
|
||||||
|
def build_invokeai_metadata_pnginfo(
|
||||||
|
metadata: InvokeAIMetadata | None,
|
||||||
|
) -> PngImagePlugin.PngInfo:
|
||||||
|
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
|
||||||
|
pnginfo = PngImagePlugin.PngInfo()
|
||||||
|
|
||||||
|
if metadata is not None:
|
||||||
|
pnginfo.add_text("invokeai", json.dumps(metadata))
|
||||||
|
|
||||||
|
return pnginfo
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataServiceBase(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
|
||||||
|
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def build_metadata(
|
||||||
|
self, session_id: str, node: BaseModel
|
||||||
|
) -> InvokeAIMetadata | None:
|
||||||
|
"""Builds an InvokeAIMetadata object"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PngMetadataService(MetadataServiceBase):
|
||||||
|
"""Handles loading and building metadata for images."""
|
||||||
|
|
||||||
|
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
|
||||||
|
def _load_metadata(self, image: Image.Image) -> dict | None:
|
||||||
|
"""Loads a specific info entry from a PIL Image."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
info = image.info.get("invokeai")
|
||||||
|
|
||||||
|
if type(info) is not str:
|
||||||
|
return None
|
||||||
|
|
||||||
|
loaded_metadata = json.loads(info)
|
||||||
|
|
||||||
|
if type(loaded_metadata) is not dict:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(loaded_metadata.items()) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return loaded_metadata
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_metadata(self, image: Image.Image) -> dict | None:
|
||||||
|
"""Retrieves an image's metadata as a dict"""
|
||||||
|
loaded_metadata = self._load_metadata(image)
|
||||||
|
|
||||||
|
return loaded_metadata
|
||||||
|
|
||||||
|
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
|
||||||
|
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
|
||||||
|
|
||||||
|
return metadata
|
@ -43,10 +43,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
queue_item.invocation_id
|
queue_item.invocation_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||||
|
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
|
||||||
|
|
||||||
# Send starting event
|
# Send starting event
|
||||||
self.__invoker.services.events.emit_invocation_started(
|
self.__invoker.services.events.emit_invocation_started(
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
invocation_id=invocation.id,
|
node=invocation.dict(),
|
||||||
|
source_node_id=source_node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Invoke
|
# Invoke
|
||||||
@ -75,7 +79,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
# Send complete event
|
# Send complete event
|
||||||
self.__invoker.services.events.emit_invocation_complete(
|
self.__invoker.services.events.emit_invocation_complete(
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
invocation_id=invocation.id,
|
node=invocation.dict(),
|
||||||
|
source_node_id=source_node_id,
|
||||||
result=outputs.dict(),
|
result=outputs.dict(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -99,7 +104,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
# Send error event
|
# Send error event
|
||||||
self.__invoker.services.events.emit_invocation_error(
|
self.__invoker.services.events.emit_invocation_error(
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
invocation_id=invocation.id,
|
node=invocation.dict(),
|
||||||
|
source_node_id=source_node_id,
|
||||||
error=error,
|
error=error,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -35,7 +35,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._create_table()
|
self._create_table()
|
||||||
|
|
||||||
def _create_table(self):
|
def _create_table(self):
|
||||||
with self._lock:
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
|
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
|
||||||
item TEXT,
|
item TEXT,
|
||||||
@ -44,27 +45,34 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
|
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
def _parse_item(self, item: str) -> T:
|
def _parse_item(self, item: str) -> T:
|
||||||
item_type = get_args(self.__orig_class__)[0]
|
item_type = get_args(self.__orig_class__)[0]
|
||||||
return parse_raw_as(item_type, item)
|
return parse_raw_as(item_type, item)
|
||||||
|
|
||||||
def set(self, item: T):
|
def set(self, item: T):
|
||||||
with self._lock:
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
||||||
(item.json(),),
|
(item.json(),),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
self._on_changed(item)
|
self._on_changed(item)
|
||||||
|
|
||||||
def get(self, id: str) -> Union[T, None]:
|
def get(self, id: str) -> Union[T, None]:
|
||||||
with self._lock:
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||||
)
|
)
|
||||||
result = self._cursor.fetchone()
|
result = self._cursor.fetchone()
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
return None
|
return None
|
||||||
@ -72,15 +80,19 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
return self._parse_item(result[0])
|
return self._parse_item(result[0])
|
||||||
|
|
||||||
def delete(self, id: str):
|
def delete(self, id: str):
|
||||||
with self._lock:
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
self._on_deleted(id)
|
self._on_deleted(id)
|
||||||
|
|
||||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||||
with self._lock:
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
|
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
|
||||||
(per_page, page * per_page),
|
(per_page, page * per_page),
|
||||||
@ -91,6 +103,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
|
|
||||||
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
||||||
count = self._cursor.fetchone()[0]
|
count = self._cursor.fetchone()[0]
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
pageCount = int(count / per_page) + 1
|
pageCount = int(count / per_page) + 1
|
||||||
|
|
||||||
@ -101,7 +115,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
def search(
|
def search(
|
||||||
self, query: str, page: int = 0, per_page: int = 10
|
self, query: str, page: int = 0, per_page: int = 10
|
||||||
) -> PaginatedResults[T]:
|
) -> PaginatedResults[T]:
|
||||||
with self._lock:
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
|
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
|
||||||
(f"%{query}%", per_page, page * per_page),
|
(f"%{query}%", per_page, page * per_page),
|
||||||
@ -115,6 +130,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
(f"%{query}%",),
|
(f"%{query}%",),
|
||||||
)
|
)
|
||||||
count = self._cursor.fetchone()[0]
|
count = self._cursor.fetchone()[0]
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
pageCount = int(count / per_page) + 1
|
pageCount = int(count / per_page) + 1
|
||||||
|
|
||||||
|
5
invokeai/app/util/misc.py
Normal file
5
invokeai/app/util/misc.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
|
def get_timestamp():
|
||||||
|
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
@ -1,25 +0,0 @@
|
|||||||
import os
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
def save_thumbnail(
|
|
||||||
image: Image.Image,
|
|
||||||
filename: str,
|
|
||||||
path: str,
|
|
||||||
size: int = 256,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Saves a thumbnail of an image, returning its path.
|
|
||||||
"""
|
|
||||||
base_filename = os.path.splitext(filename)[0]
|
|
||||||
thumbnail_path = os.path.join(path, base_filename + ".webp")
|
|
||||||
|
|
||||||
if os.path.exists(thumbnail_path):
|
|
||||||
return thumbnail_path
|
|
||||||
|
|
||||||
image_copy = image.copy()
|
|
||||||
image_copy.thumbnail(size=(size, size))
|
|
||||||
|
|
||||||
image_copy.save(thumbnail_path, "WEBP")
|
|
||||||
|
|
||||||
return thumbnail_path
|
|
@ -1,16 +1,41 @@
|
|||||||
import torch
|
from invokeai.app.api.models.images import ProgressImage
|
||||||
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
from ..invocations.baseinvocation import InvocationContext
|
from ..invocations.baseinvocation import InvocationContext
|
||||||
from ...backend.util.util import image_to_dataURL
|
from ...backend.util.util import image_to_dataURL
|
||||||
from ...backend.generator.base import Generator
|
from ...backend.generator.base import Generator
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
|
|
||||||
def fast_latents_step_callback(
|
|
||||||
sample: torch.Tensor,
|
def stable_diffusion_step_callback(
|
||||||
step: int,
|
|
||||||
steps: int,
|
|
||||||
id: str,
|
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
|
intermediate_state: PipelineIntermediateState,
|
||||||
|
node: dict,
|
||||||
|
source_node_id: str,
|
||||||
):
|
):
|
||||||
|
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||||
|
raise CanceledException
|
||||||
|
|
||||||
|
# Some schedulers report not only the noisy latents at the current timestep,
|
||||||
|
# but also their estimate so far of what the de-noised latents will be. Use
|
||||||
|
# that estimate if it is available.
|
||||||
|
if intermediate_state.predicted_original is not None:
|
||||||
|
sample = intermediate_state.predicted_original
|
||||||
|
else:
|
||||||
|
sample = intermediate_state.latents
|
||||||
|
|
||||||
|
# TODO: This does not seem to be needed any more?
|
||||||
|
# # txt2img provides a Tensor in the step_callback
|
||||||
|
# # img2img provides a PipelineIntermediateState
|
||||||
|
# if isinstance(sample, PipelineIntermediateState):
|
||||||
|
# # this was an img2img
|
||||||
|
# print('img2img')
|
||||||
|
# latents = sample.latents
|
||||||
|
# step = sample.step
|
||||||
|
# else:
|
||||||
|
# print('txt2img')
|
||||||
|
# latents = sample
|
||||||
|
# step = intermediate_state.step
|
||||||
|
|
||||||
# TODO: only output a preview image when requested
|
# TODO: only output a preview image when requested
|
||||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||||
|
|
||||||
@ -21,23 +46,10 @@ def fast_latents_step_callback(
|
|||||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||||
|
|
||||||
context.services.events.emit_generator_progress(
|
context.services.events.emit_generator_progress(
|
||||||
context.graph_execution_state_id,
|
graph_execution_state_id=context.graph_execution_state_id,
|
||||||
id,
|
node=node,
|
||||||
{"width": width, "height": height, "dataURL": dataURL},
|
source_node_id=source_node_id,
|
||||||
step,
|
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||||
steps,
|
step=intermediate_state.step,
|
||||||
|
total_steps=node["steps"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
|
||||||
"""
|
|
||||||
txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
|
|
||||||
This adapter grabs the needed data and passes it along to the callback function.
|
|
||||||
"""
|
|
||||||
if isinstance(cb_args[0], PipelineIntermediateState):
|
|
||||||
progress_state: PipelineIntermediateState = cb_args[0]
|
|
||||||
return fast_latents_step_callback(
|
|
||||||
progress_state.latents, progress_state.step, **kwargs
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return fast_latents_step_callback(*cb_args, **kwargs)
|
|
||||||
|
15
invokeai/app/util/thumbnails.py
Normal file
15
invokeai/app/util/thumbnails.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def get_thumbnail_name(image_name: str) -> str:
|
||||||
|
"""Formats given an image name, returns the appropriate thumbnail image name"""
|
||||||
|
thumbnail_name = os.path.splitext(image_name)[0] + ".webp"
|
||||||
|
return thumbnail_name
|
||||||
|
|
||||||
|
|
||||||
|
def make_thumbnail(image: Image.Image, size: int = 256) -> Image.Image:
|
||||||
|
"""Makes a thumbnail from a PIL Image"""
|
||||||
|
thumbnail = image.copy()
|
||||||
|
thumbnail.thumbnail(size=(size, size))
|
||||||
|
return thumbnail
|
@ -57,7 +57,7 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
self.concept_list.extend(list(local_concepts_to_add))
|
self.concept_list.extend(list(local_concepts_to_add))
|
||||||
return self.concept_list
|
return self.concept_list
|
||||||
return self.concept_list
|
return self.concept_list
|
||||||
else:
|
elif Globals.internet_available is True:
|
||||||
try:
|
try:
|
||||||
models = self.hf_api.list_models(
|
models = self.hf_api.list_models(
|
||||||
filter=ModelFilter(model_name="sd-concepts-library/")
|
filter=ModelFilter(model_name="sd-concepts-library/")
|
||||||
@ -73,6 +73,8 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||||
)
|
)
|
||||||
return self.concept_list
|
return self.concept_list
|
||||||
|
else:
|
||||||
|
return self.concept_list
|
||||||
|
|
||||||
def get_concept_model_path(self, concept_name: str) -> str:
|
def get_concept_model_path(self, concept_name: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -6,3 +6,5 @@ stats.html
|
|||||||
index.html
|
index.html
|
||||||
.yarn/
|
.yarn/
|
||||||
*.scss
|
*.scss
|
||||||
|
src/services/api/
|
||||||
|
src/services/fixtures/*
|
||||||
|
@ -3,4 +3,8 @@ dist/
|
|||||||
node_modules/
|
node_modules/
|
||||||
patches/
|
patches/
|
||||||
stats.html
|
stats.html
|
||||||
|
index.html
|
||||||
.yarn/
|
.yarn/
|
||||||
|
*.scss
|
||||||
|
src/services/api/
|
||||||
|
src/services/fixtures/*
|
||||||
|
87
invokeai/frontend/web/docs/API_CLIENT.md
Normal file
87
invokeai/frontend/web/docs/API_CLIENT.md
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
# Generated axios API client
|
||||||
|
|
||||||
|
- [Generated axios API client](#generated-axios-api-client)
|
||||||
|
- [Generation](#generation)
|
||||||
|
- [Generate the API client from the nodes web server](#generate-the-api-client-from-the-nodes-web-server)
|
||||||
|
- [Generate the API client from JSON](#generate-the-api-client-from-json)
|
||||||
|
- [Getting the JSON from the nodes web server](#getting-the-json-from-the-nodes-web-server)
|
||||||
|
- [Getting the JSON with a python script](#getting-the-json-with-a-python-script)
|
||||||
|
- [Generate the API client](#generate-the-api-client)
|
||||||
|
- [The generated client](#the-generated-client)
|
||||||
|
- [API client customisation](#api-client-customisation)
|
||||||
|
|
||||||
|
This API client is generated by an [openapi code generator](https://github.com/ferdikoomen/openapi-typescript-codegen).
|
||||||
|
|
||||||
|
All files in `invokeai/frontend/web/src/services/api/` are made by the generator.
|
||||||
|
|
||||||
|
## Generation
|
||||||
|
|
||||||
|
The axios client may be generated by from the OpenAPI schema from the nodes web server, or from JSON.
|
||||||
|
|
||||||
|
### Generate the API client from the nodes web server
|
||||||
|
|
||||||
|
We need to start the nodes web server, which serves the OpenAPI schema to the generator.
|
||||||
|
|
||||||
|
1. Start the nodes web server.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# from the repo root
|
||||||
|
python scripts/invoke-new.py --web
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Generate the API client.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# from invokeai/frontend/web/
|
||||||
|
yarn api:web
|
||||||
|
```
|
||||||
|
|
||||||
|
### Generate the API client from JSON
|
||||||
|
|
||||||
|
The JSON can be acquired from the nodes web server, or with a python script.
|
||||||
|
|
||||||
|
#### Getting the JSON from the nodes web server
|
||||||
|
|
||||||
|
Start the nodes web server as described above, then download the file.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# from invokeai/frontend/web/
|
||||||
|
curl http://localhost:9090/openapi.json -o openapi.json
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Getting the JSON with a python script
|
||||||
|
|
||||||
|
Run this python script from the repo root, so it can access the nodes server modules.
|
||||||
|
|
||||||
|
The script will output `openapi.json` in the repo root. Then we need to move it to `invokeai/frontend/web/`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# from the repo root
|
||||||
|
python invokeai/app/util/generate_openapi_json.py
|
||||||
|
mv invokeai/app/util/openapi.json invokeai/frontend/web/services/fixtures/
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Generate the API client
|
||||||
|
|
||||||
|
Now we can generate the API client from the JSON.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# from invokeai/frontend/web/
|
||||||
|
yarn api:file
|
||||||
|
```
|
||||||
|
|
||||||
|
## The generated client
|
||||||
|
|
||||||
|
The client will be written to `invokeai/frontend/web/services/api/`:
|
||||||
|
|
||||||
|
- `axios` client
|
||||||
|
- TS types
|
||||||
|
- An easily parseable schema, which we can use to generate UI
|
||||||
|
|
||||||
|
## API client customisation
|
||||||
|
|
||||||
|
The generator has a default `request.ts` file that implements a base `axios` client. The generated client uses this base client.
|
||||||
|
|
||||||
|
One shortcoming of this is base client is it does not provide response headers unless the response body is empty. To fix this, we provide our own lightly-patched `request.ts`.
|
||||||
|
|
||||||
|
To access the headers, call `getHeaders(response)` on any response from the generated api client. This function is exported from `invokeai/frontend/web/src/services/util/getHeaders.ts`.
|
21
invokeai/frontend/web/docs/EVENTS.md
Normal file
21
invokeai/frontend/web/docs/EVENTS.md
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# Events
|
||||||
|
|
||||||
|
Events via `socket.io`
|
||||||
|
|
||||||
|
## `actions.ts`
|
||||||
|
|
||||||
|
Redux actions for all socket events. Payloads all include a timestamp, and optionally some other data.
|
||||||
|
|
||||||
|
Any reducer (or middleware) can respond to the actions.
|
||||||
|
|
||||||
|
## `middleware.ts`
|
||||||
|
|
||||||
|
Redux middleware for events.
|
||||||
|
|
||||||
|
Handles dispatching the event actions. Only put logic here if it can't really go anywhere else.
|
||||||
|
|
||||||
|
For example, on connect we want to load images to the gallery if it's not populated. This requires dispatching a thunk, so we need to directly dispatch this in the middleware.
|
||||||
|
|
||||||
|
## `types.ts`
|
||||||
|
|
||||||
|
Hand-written types for the socket events. Cannot generate these from the server, but fortunately they are few and simple.
|
17
invokeai/frontend/web/docs/NODE_EDITOR.md
Normal file
17
invokeai/frontend/web/docs/NODE_EDITOR.md
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# Node Editor Design
|
||||||
|
|
||||||
|
WIP
|
||||||
|
|
||||||
|
nodes
|
||||||
|
|
||||||
|
everything in `src/features/nodes/`
|
||||||
|
|
||||||
|
have a look at `state.nodes.invocation`
|
||||||
|
|
||||||
|
- on socket connect, if no schema saved, fetch `localhost:9090/openapi.json`, save JSON to `state.nodes.schema`
|
||||||
|
- on fulfilled schema fetch, `parseSchema()` the schema. this outputs a `Record<string, Invocation>` which is saved to `state.nodes.invocations` - `Invocation` is like a template for the node
|
||||||
|
- when you add a node, the the `Invocation` template is passed to `InvocationComponent.tsx` to build the UI component for that node
|
||||||
|
- inputs/outputs have field types - and each field type gets an `FieldComponent` which includes a dispatcher to write state changes to redux `nodesSlice`
|
||||||
|
- `reactflow` sends changes to nodes/edges to redux
|
||||||
|
- to invoke, `buildNodesGraph()` state, then send this
|
||||||
|
- changed onClick Invoke button actions to build the schema, then when schema builds it dispatches the actual network request to create the session - see `session.ts`
|
29
invokeai/frontend/web/docs/PACKAGE_SCRIPTS.md
Normal file
29
invokeai/frontend/web/docs/PACKAGE_SCRIPTS.md
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# Package Scripts
|
||||||
|
|
||||||
|
WIP walkthrough of `package.json` scripts.
|
||||||
|
|
||||||
|
## `theme` & `theme:watch`
|
||||||
|
|
||||||
|
These run the Chakra CLI to generate types for the theme, or watch for code change and re-generate the types.
|
||||||
|
|
||||||
|
The CLI essentially monkeypatches Chakra's files in `node_modules`.
|
||||||
|
|
||||||
|
## `postinstall`
|
||||||
|
|
||||||
|
The `postinstall` script patches a few packages and runs the Chakra CLI to generate types for the theme.
|
||||||
|
|
||||||
|
### Patch `@chakra-ui/cli`
|
||||||
|
|
||||||
|
See: <https://github.com/chakra-ui/chakra-ui/issues/7394>
|
||||||
|
|
||||||
|
### Patch `redux-persist`
|
||||||
|
|
||||||
|
We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`.
|
||||||
|
|
||||||
|
`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it.
|
||||||
|
|
||||||
|
So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that.
|
||||||
|
|
||||||
|
### Patch `redux-deep-persist`
|
||||||
|
|
||||||
|
This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work.
|
@ -1,10 +1,16 @@
|
|||||||
# InvokeAI Web UI
|
# InvokeAI Web UI
|
||||||
|
|
||||||
|
- [InvokeAI Web UI](#invokeai-web-ui)
|
||||||
|
- [Stack](#stack)
|
||||||
|
- [Contributing](#contributing)
|
||||||
|
- [Dev Environment](#dev-environment)
|
||||||
|
- [Production builds](#production-builds)
|
||||||
|
|
||||||
The UI is a fairly straightforward Typescript React app. The only really fancy stuff is the Unified Canvas.
|
The UI is a fairly straightforward Typescript React app. The only really fancy stuff is the Unified Canvas.
|
||||||
|
|
||||||
Code in `invokeai/frontend/web/` if you want to have a look.
|
Code in `invokeai/frontend/web/` if you want to have a look.
|
||||||
|
|
||||||
## Details
|
## Stack
|
||||||
|
|
||||||
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a custom redux middleware to help).
|
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a custom redux middleware to help).
|
||||||
|
|
||||||
@ -32,7 +38,7 @@ Start everything in dev mode:
|
|||||||
|
|
||||||
1. Start the dev server: `yarn dev`
|
1. Start the dev server: `yarn dev`
|
||||||
2. Start the InvokeAI UI per usual: `invokeai --web`
|
2. Start the InvokeAI UI per usual: `invokeai --web`
|
||||||
3. Point your browser to the dev server address e.g. `http://localhost:5173/`
|
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
|
||||||
|
|
||||||
### Production builds
|
### Production builds
|
||||||
|
|
21
invokeai/frontend/web/index.d.ts
vendored
21
invokeai/frontend/web/index.d.ts
vendored
@ -1,6 +1,7 @@
|
|||||||
import React, { PropsWithChildren } from 'react';
|
import React, { PropsWithChildren } from 'react';
|
||||||
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
|
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
|
||||||
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
|
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
|
||||||
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
|
|
||||||
export {};
|
export {};
|
||||||
|
|
||||||
@ -64,9 +65,25 @@ declare module '@invoke-ai/invoke-ai-ui' {
|
|||||||
declare class SettingsModal extends React.Component<SettingsModalProps> {
|
declare class SettingsModal extends React.Component<SettingsModalProps> {
|
||||||
public constructor(props: SettingsModalProps);
|
public constructor(props: SettingsModalProps);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
declare class StatusIndicator extends React.Component<StatusIndicatorProps> {
|
||||||
|
public constructor(props: StatusIndicatorProps);
|
||||||
|
}
|
||||||
|
|
||||||
|
declare class ModelSelect extends React.Component<ModelSelectProps> {
|
||||||
|
public constructor(props: ModelSelectProps);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
declare function Invoke(props: PropsWithChildren): JSX.Element;
|
interface InvokeProps extends PropsWithChildren {
|
||||||
|
apiUrl?: string;
|
||||||
|
disabledPanels?: string[];
|
||||||
|
disabledTabs?: InvokeTabName[];
|
||||||
|
token?: string;
|
||||||
|
shouldTransformUrls?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
declare function Invoke(props: InvokeProps): JSX.Element;
|
||||||
|
|
||||||
export {
|
export {
|
||||||
ThemeChanger,
|
ThemeChanger,
|
||||||
@ -74,5 +91,7 @@ export {
|
|||||||
IAIPopover,
|
IAIPopover,
|
||||||
IAIIconButton,
|
IAIIconButton,
|
||||||
SettingsModal,
|
SettingsModal,
|
||||||
|
StatusIndicator,
|
||||||
|
ModelSelect,
|
||||||
};
|
};
|
||||||
export = Invoke;
|
export = Invoke;
|
||||||
|
@ -5,7 +5,10 @@
|
|||||||
"scripts": {
|
"scripts": {
|
||||||
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
|
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
|
||||||
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
||||||
|
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
|
||||||
"build": "yarn run lint && vite build",
|
"build": "yarn run lint && vite build",
|
||||||
|
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
|
||||||
|
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
|
||||||
"preview": "vite preview",
|
"preview": "vite preview",
|
||||||
"lint:madge": "madge --circular src/main.tsx",
|
"lint:madge": "madge --circular src/main.tsx",
|
||||||
"lint:eslint": "eslint --max-warnings=0 .",
|
"lint:eslint": "eslint --max-warnings=0 .",
|
||||||
@ -41,9 +44,11 @@
|
|||||||
"@chakra-ui/react": "^2.5.1",
|
"@chakra-ui/react": "^2.5.1",
|
||||||
"@chakra-ui/styled-system": "^2.6.1",
|
"@chakra-ui/styled-system": "^2.6.1",
|
||||||
"@chakra-ui/theme-tools": "^2.0.16",
|
"@chakra-ui/theme-tools": "^2.0.16",
|
||||||
|
"@dagrejs/graphlib": "^2.1.12",
|
||||||
"@emotion/react": "^11.10.6",
|
"@emotion/react": "^11.10.6",
|
||||||
"@emotion/styled": "^11.10.6",
|
"@emotion/styled": "^11.10.6",
|
||||||
"@reduxjs/toolkit": "^1.9.2",
|
"@fontsource/inter": "^4.5.15",
|
||||||
|
"@reduxjs/toolkit": "^1.9.3",
|
||||||
"chakra-ui-contextmenu": "^1.0.5",
|
"chakra-ui-contextmenu": "^1.0.5",
|
||||||
"dateformat": "^5.0.3",
|
"dateformat": "^5.0.3",
|
||||||
"formik": "^2.2.9",
|
"formik": "^2.2.9",
|
||||||
@ -67,15 +72,17 @@
|
|||||||
"react-redux": "^8.0.5",
|
"react-redux": "^8.0.5",
|
||||||
"react-transition-group": "^4.4.5",
|
"react-transition-group": "^4.4.5",
|
||||||
"react-zoom-pan-pinch": "^2.6.1",
|
"react-zoom-pan-pinch": "^2.6.1",
|
||||||
|
"reactflow": "^11.7.0",
|
||||||
"redux-deep-persist": "^1.0.7",
|
"redux-deep-persist": "^1.0.7",
|
||||||
|
"redux-dynamic-middlewares": "^2.2.0",
|
||||||
"redux-persist": "^6.0.0",
|
"redux-persist": "^6.0.0",
|
||||||
"socket.io-client": "^4.6.0",
|
"socket.io-client": "^4.6.0",
|
||||||
"use-image": "^1.1.0",
|
"use-image": "^1.1.0",
|
||||||
"uuid": "^9.0.0"
|
"uuid": "^9.0.0"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@fontsource/inter": "^4.5.15",
|
|
||||||
"@types/dateformat": "^5.0.0",
|
"@types/dateformat": "^5.0.0",
|
||||||
|
"@types/lodash": "^4.14.194",
|
||||||
"@types/react": "^18.0.28",
|
"@types/react": "^18.0.28",
|
||||||
"@types/react-dom": "^18.0.11",
|
"@types/react-dom": "^18.0.11",
|
||||||
"@types/react-transition-group": "^4.4.5",
|
"@types/react-transition-group": "^4.4.5",
|
||||||
@ -83,6 +90,7 @@
|
|||||||
"@typescript-eslint/eslint-plugin": "^5.52.0",
|
"@typescript-eslint/eslint-plugin": "^5.52.0",
|
||||||
"@typescript-eslint/parser": "^5.52.0",
|
"@typescript-eslint/parser": "^5.52.0",
|
||||||
"@vitejs/plugin-react-swc": "^3.2.0",
|
"@vitejs/plugin-react-swc": "^3.2.0",
|
||||||
|
"axios": "^1.3.4",
|
||||||
"babel-plugin-transform-imports": "^2.0.0",
|
"babel-plugin-transform-imports": "^2.0.0",
|
||||||
"concurrently": "^7.6.0",
|
"concurrently": "^7.6.0",
|
||||||
"eslint": "^8.34.0",
|
"eslint": "^8.34.0",
|
||||||
@ -90,13 +98,17 @@
|
|||||||
"eslint-plugin-prettier": "^4.2.1",
|
"eslint-plugin-prettier": "^4.2.1",
|
||||||
"eslint-plugin-react": "^7.32.2",
|
"eslint-plugin-react": "^7.32.2",
|
||||||
"eslint-plugin-react-hooks": "^4.6.0",
|
"eslint-plugin-react-hooks": "^4.6.0",
|
||||||
|
"form-data": "^4.0.0",
|
||||||
"husky": "^8.0.3",
|
"husky": "^8.0.3",
|
||||||
"lint-staged": "^13.1.2",
|
"lint-staged": "^13.1.2",
|
||||||
"madge": "^6.0.0",
|
"madge": "^6.0.0",
|
||||||
|
"openapi-types": "^12.1.0",
|
||||||
|
"openapi-typescript-codegen": "^0.23.0",
|
||||||
"postinstall-postinstall": "^2.1.0",
|
"postinstall-postinstall": "^2.1.0",
|
||||||
"prettier": "^2.8.4",
|
"prettier": "^2.8.4",
|
||||||
"rollup-plugin-visualizer": "^5.9.0",
|
"rollup-plugin-visualizer": "^5.9.0",
|
||||||
"terser": "^5.16.4",
|
"terser": "^5.16.4",
|
||||||
|
"typescript": "4.9.5",
|
||||||
"vite": "^4.1.2",
|
"vite": "^4.1.2",
|
||||||
"vite-plugin-eslint": "^1.8.1",
|
"vite-plugin-eslint": "^1.8.1",
|
||||||
"vite-tsconfig-paths": "^4.0.5",
|
"vite-tsconfig-paths": "^4.0.5",
|
||||||
|
@ -53,6 +53,7 @@
|
|||||||
"txt2img": "Text To Image",
|
"txt2img": "Text To Image",
|
||||||
"img2img": "Image To Image",
|
"img2img": "Image To Image",
|
||||||
"unifiedCanvas": "Unified Canvas",
|
"unifiedCanvas": "Unified Canvas",
|
||||||
|
"linear": "Linear",
|
||||||
"nodes": "Nodes",
|
"nodes": "Nodes",
|
||||||
"postprocessing": "Post Processing",
|
"postprocessing": "Post Processing",
|
||||||
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
||||||
@ -525,6 +526,10 @@
|
|||||||
"resetComplete": "Web UI has been reset. Refresh the page to reload."
|
"resetComplete": "Web UI has been reset. Refresh the page to reload."
|
||||||
},
|
},
|
||||||
"toast": {
|
"toast": {
|
||||||
|
"serverError": "Server Error",
|
||||||
|
"disconnected": "Disconnected from Server",
|
||||||
|
"connected": "Connected to Server",
|
||||||
|
"canceled": "Processing Canceled",
|
||||||
"tempFoldersEmptied": "Temp Folder Emptied",
|
"tempFoldersEmptied": "Temp Folder Emptied",
|
||||||
"uploadFailed": "Upload failed",
|
"uploadFailed": "Upload failed",
|
||||||
"uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",
|
"uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",
|
||||||
|
@ -13,16 +13,42 @@ import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
|
|||||||
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
||||||
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
|
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
|
||||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||||
import { useAppSelector } from './storeHooks';
|
import { useAppDispatch, useAppSelector } from './storeHooks';
|
||||||
import { PropsWithChildren, useEffect } from 'react';
|
import { PropsWithChildren, useEffect } from 'react';
|
||||||
|
import { setDisabledPanels, setDisabledTabs } from 'features/ui/store/uiSlice';
|
||||||
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
|
import { shouldTransformUrlsChanged } from 'features/system/store/systemSlice';
|
||||||
|
|
||||||
keepGUIAlive();
|
keepGUIAlive();
|
||||||
|
|
||||||
const App = (props: PropsWithChildren) => {
|
interface Props extends PropsWithChildren {
|
||||||
|
options: {
|
||||||
|
disabledPanels: string[];
|
||||||
|
disabledTabs: InvokeTabName[];
|
||||||
|
shouldTransformUrls?: boolean;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const App = (props: Props) => {
|
||||||
useToastWatcher();
|
useToastWatcher();
|
||||||
|
|
||||||
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
|
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
|
||||||
const { setColorMode } = useColorMode();
|
const { setColorMode } = useColorMode();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
dispatch(setDisabledPanels(props.options.disabledPanels));
|
||||||
|
}, [dispatch, props.options.disabledPanels]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
dispatch(setDisabledTabs(props.options.disabledTabs));
|
||||||
|
}, [dispatch, props.options.disabledTabs]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
dispatch(
|
||||||
|
shouldTransformUrlsChanged(Boolean(props.options.shouldTransformUrls))
|
||||||
|
);
|
||||||
|
}, [dispatch, props.options.shouldTransformUrls]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');
|
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');
|
||||||
|
22
invokeai/frontend/web/src/app/invokeai.d.ts
vendored
22
invokeai/frontend/web/src/app/invokeai.d.ts
vendored
@ -14,6 +14,8 @@
|
|||||||
|
|
||||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
import { IRect } from 'konva/lib/types';
|
import { IRect } from 'konva/lib/types';
|
||||||
|
import { ImageMetadata, ImageType } from 'services/api';
|
||||||
|
import { AnyInvocation } from 'services/events/types';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TODO:
|
* TODO:
|
||||||
@ -113,7 +115,7 @@ export declare type Metadata = SystemGenerationMetadata & {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// An Image has a UUID, url, modified timestamp, width, height and maybe metadata
|
// An Image has a UUID, url, modified timestamp, width, height and maybe metadata
|
||||||
export declare type Image = {
|
export declare type _Image = {
|
||||||
uuid: string;
|
uuid: string;
|
||||||
url: string;
|
url: string;
|
||||||
thumbnail: string;
|
thumbnail: string;
|
||||||
@ -124,11 +126,23 @@ export declare type Image = {
|
|||||||
category: GalleryCategory;
|
category: GalleryCategory;
|
||||||
isBase64?: boolean;
|
isBase64?: boolean;
|
||||||
dreamPrompt?: 'string';
|
dreamPrompt?: 'string';
|
||||||
|
name?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ResultImage
|
||||||
|
*/
|
||||||
|
export declare type Image = {
|
||||||
|
name: string;
|
||||||
|
type: ImageType;
|
||||||
|
url: string;
|
||||||
|
thumbnail: string;
|
||||||
|
metadata: ImageMetadata;
|
||||||
};
|
};
|
||||||
|
|
||||||
// GalleryImages is an array of Image.
|
// GalleryImages is an array of Image.
|
||||||
export declare type GalleryImages = {
|
export declare type GalleryImages = {
|
||||||
images: Array<Image>;
|
images: Array<_Image>;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -275,7 +289,7 @@ export declare type SystemStatusResponse = SystemStatus;
|
|||||||
|
|
||||||
export declare type SystemConfigResponse = SystemConfig;
|
export declare type SystemConfigResponse = SystemConfig;
|
||||||
|
|
||||||
export declare type ImageResultResponse = Omit<Image, 'uuid'> & {
|
export declare type ImageResultResponse = Omit<_Image, 'uuid'> & {
|
||||||
boundingBox?: IRect;
|
boundingBox?: IRect;
|
||||||
generationMode: InvokeTabName;
|
generationMode: InvokeTabName;
|
||||||
};
|
};
|
||||||
@ -296,7 +310,7 @@ export declare type ErrorResponse = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export declare type GalleryImagesResponse = {
|
export declare type GalleryImagesResponse = {
|
||||||
images: Array<Omit<Image, 'uuid'>>;
|
images: Array<Omit<_Image, 'uuid'>>;
|
||||||
areMoreImagesAvailable: boolean;
|
areMoreImagesAvailable: boolean;
|
||||||
category: GalleryCategory;
|
category: GalleryCategory;
|
||||||
};
|
};
|
||||||
|
@ -20,6 +20,7 @@ export const readinessSelector = createSelector(
|
|||||||
seedWeights,
|
seedWeights,
|
||||||
initialImage,
|
initialImage,
|
||||||
seed,
|
seed,
|
||||||
|
isImageToImageEnabled,
|
||||||
} = generation;
|
} = generation;
|
||||||
|
|
||||||
const { isProcessing, isConnected } = system;
|
const { isProcessing, isConnected } = system;
|
||||||
@ -33,7 +34,7 @@ export const readinessSelector = createSelector(
|
|||||||
reasonsWhyNotReady.push('Missing prompt');
|
reasonsWhyNotReady.push('Missing prompt');
|
||||||
}
|
}
|
||||||
|
|
||||||
if (activeTabName === 'img2img' && !initialImage) {
|
if (isImageToImageEnabled && !initialImage) {
|
||||||
isReady = false;
|
isReady = false;
|
||||||
reasonsWhyNotReady.push('No initial image selected');
|
reasonsWhyNotReady.push('No initial image selected');
|
||||||
}
|
}
|
||||||
|
@ -13,9 +13,13 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
|
|||||||
export const generateImage = createAction<InvokeTabName>(
|
export const generateImage = createAction<InvokeTabName>(
|
||||||
'socketio/generateImage'
|
'socketio/generateImage'
|
||||||
);
|
);
|
||||||
export const runESRGAN = createAction<InvokeAI.Image>('socketio/runESRGAN');
|
export const runESRGAN = createAction<InvokeAI._Image>('socketio/runESRGAN');
|
||||||
export const runFacetool = createAction<InvokeAI.Image>('socketio/runFacetool');
|
export const runFacetool = createAction<InvokeAI._Image>(
|
||||||
export const deleteImage = createAction<InvokeAI.Image>('socketio/deleteImage');
|
'socketio/runFacetool'
|
||||||
|
);
|
||||||
|
export const deleteImage = createAction<InvokeAI._Image>(
|
||||||
|
'socketio/deleteImage'
|
||||||
|
);
|
||||||
export const requestImages = createAction<GalleryCategory>(
|
export const requestImages = createAction<GalleryCategory>(
|
||||||
'socketio/requestImages'
|
'socketio/requestImages'
|
||||||
);
|
);
|
||||||
|
@ -91,7 +91,7 @@ const makeSocketIOEmitters = (
|
|||||||
})
|
})
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
emitRunESRGAN: (imageToProcess: InvokeAI.Image) => {
|
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
|
||||||
dispatch(setIsProcessing(true));
|
dispatch(setIsProcessing(true));
|
||||||
|
|
||||||
const {
|
const {
|
||||||
@ -119,7 +119,7 @@ const makeSocketIOEmitters = (
|
|||||||
})
|
})
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
emitRunFacetool: (imageToProcess: InvokeAI.Image) => {
|
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
|
||||||
dispatch(setIsProcessing(true));
|
dispatch(setIsProcessing(true));
|
||||||
|
|
||||||
const {
|
const {
|
||||||
@ -150,7 +150,7 @@ const makeSocketIOEmitters = (
|
|||||||
})
|
})
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
emitDeleteImage: (imageToDelete: InvokeAI.Image) => {
|
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
|
||||||
const { url, uuid, category, thumbnail } = imageToDelete;
|
const { url, uuid, category, thumbnail } = imageToDelete;
|
||||||
dispatch(removeImage(imageToDelete));
|
dispatch(removeImage(imageToDelete));
|
||||||
socketio.emit('deleteImage', url, thumbnail, uuid, category);
|
socketio.emit('deleteImage', url, thumbnail, uuid, category);
|
||||||
|
@ -34,8 +34,9 @@ import type { RootState } from 'app/store';
|
|||||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||||
import {
|
import {
|
||||||
clearInitialImage,
|
clearInitialImage,
|
||||||
|
initialImageSelected,
|
||||||
setInfillMethod,
|
setInfillMethod,
|
||||||
setInitialImage,
|
// setInitialImage,
|
||||||
setMaskPath,
|
setMaskPath,
|
||||||
} from 'features/parameters/store/generationSlice';
|
} from 'features/parameters/store/generationSlice';
|
||||||
import { tabMap } from 'features/ui/store/tabMap';
|
import { tabMap } from 'features/ui/store/tabMap';
|
||||||
@ -142,15 +143,17 @@ const makeSocketIOListeners = (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (shouldLoopback) {
|
// TODO: fix
|
||||||
const activeTabName = tabMap[activeTab];
|
// if (shouldLoopback) {
|
||||||
switch (activeTabName) {
|
// const activeTabName = tabMap[activeTab];
|
||||||
case 'img2img': {
|
// switch (activeTabName) {
|
||||||
dispatch(setInitialImage(newImage));
|
// case 'img2img': {
|
||||||
break;
|
// dispatch(initialImageSelected(newImage.uuid));
|
||||||
}
|
// // dispatch(setInitialImage(newImage));
|
||||||
}
|
// break;
|
||||||
}
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
dispatch(clearIntermediateImage());
|
dispatch(clearIntermediateImage());
|
||||||
|
|
||||||
@ -262,7 +265,7 @@ const makeSocketIOListeners = (
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
// Generate a UUID for each image
|
// Generate a UUID for each image
|
||||||
const preparedImages = images.map((image): InvokeAI.Image => {
|
const preparedImages = images.map((image): InvokeAI._Image => {
|
||||||
return {
|
return {
|
||||||
uuid: uuidv4(),
|
uuid: uuidv4(),
|
||||||
...image,
|
...image,
|
||||||
@ -334,7 +337,7 @@ const makeSocketIOListeners = (
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
initialImage === url ||
|
initialImage === url ||
|
||||||
(initialImage as InvokeAI.Image)?.url === url
|
(initialImage as InvokeAI._Image)?.url === url
|
||||||
) {
|
) {
|
||||||
dispatch(clearInitialImage());
|
dispatch(clearInitialImage());
|
||||||
}
|
}
|
||||||
|
@ -29,6 +29,8 @@ export const socketioMiddleware = () => {
|
|||||||
path: `${window.location.pathname}socket.io`,
|
path: `${window.location.pathname}socket.io`,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
socketio.disconnect();
|
||||||
|
|
||||||
let areListenersSet = false;
|
let areListenersSet = false;
|
||||||
|
|
||||||
const middleware: Middleware = (store) => (next) => (action) => {
|
const middleware: Middleware = (store) => (next) => (action) => {
|
||||||
|
@ -2,18 +2,32 @@ import { combineReducers, configureStore } from '@reduxjs/toolkit';
|
|||||||
|
|
||||||
import { persistReducer } from 'redux-persist';
|
import { persistReducer } from 'redux-persist';
|
||||||
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
|
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
|
||||||
|
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||||
import { getPersistConfig } from 'redux-deep-persist';
|
import { getPersistConfig } from 'redux-deep-persist';
|
||||||
|
|
||||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||||
|
import resultsReducer from 'features/gallery/store/resultsSlice';
|
||||||
|
import uploadsReducer from 'features/gallery/store/uploadsSlice';
|
||||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||||
import generationReducer from 'features/parameters/store/generationSlice';
|
import generationReducer from 'features/parameters/store/generationSlice';
|
||||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||||
import systemReducer from 'features/system/store/systemSlice';
|
import systemReducer from 'features/system/store/systemSlice';
|
||||||
import uiReducer from 'features/ui/store/uiSlice';
|
import uiReducer from 'features/ui/store/uiSlice';
|
||||||
|
import modelsReducer from 'features/system/store/modelSlice';
|
||||||
|
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||||
|
|
||||||
import { socketioMiddleware } from './socketio/middleware';
|
import { socketioMiddleware } from './socketio/middleware';
|
||||||
|
import { socketMiddleware } from 'services/events/middleware';
|
||||||
|
import { canvasBlacklist } from 'features/canvas/store/canvasPersistBlacklist';
|
||||||
|
import { galleryBlacklist } from 'features/gallery/store/galleryPersistBlacklist';
|
||||||
|
import { generationBlacklist } from 'features/parameters/store/generationPersistBlacklist';
|
||||||
|
import { lightboxBlacklist } from 'features/lightbox/store/lightboxPersistBlacklist';
|
||||||
|
import { modelsBlacklist } from 'features/system/store/modelsPersistBlacklist';
|
||||||
|
import { nodesBlacklist } from 'features/nodes/store/nodesPersistBlacklist';
|
||||||
|
import { postprocessingBlacklist } from 'features/parameters/store/postprocessingPersistBlacklist';
|
||||||
|
import { systemBlacklist } from 'features/system/store/systemPersistsBlacklist';
|
||||||
|
import { uiBlacklist } from 'features/ui/store/uiPersistBlacklist';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* redux-persist provides an easy and reliable way to persist state across reloads.
|
* redux-persist provides an easy and reliable way to persist state across reloads.
|
||||||
@ -29,49 +43,18 @@ import { socketioMiddleware } from './socketio/middleware';
|
|||||||
* The necesssary nested persistors with blacklists are configured below.
|
* The necesssary nested persistors with blacklists are configured below.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const canvasBlacklist = [
|
|
||||||
'cursorPosition',
|
|
||||||
'isCanvasInitialized',
|
|
||||||
'doesCanvasNeedScaling',
|
|
||||||
].map((blacklistItem) => `canvas.${blacklistItem}`);
|
|
||||||
|
|
||||||
const systemBlacklist = [
|
|
||||||
'currentIteration',
|
|
||||||
'currentStatus',
|
|
||||||
'currentStep',
|
|
||||||
'isCancelable',
|
|
||||||
'isConnected',
|
|
||||||
'isESRGANAvailable',
|
|
||||||
'isGFPGANAvailable',
|
|
||||||
'isProcessing',
|
|
||||||
'socketId',
|
|
||||||
'totalIterations',
|
|
||||||
'totalSteps',
|
|
||||||
'openModel',
|
|
||||||
'cancelOptions.cancelAfter',
|
|
||||||
].map((blacklistItem) => `system.${blacklistItem}`);
|
|
||||||
|
|
||||||
const galleryBlacklist = [
|
|
||||||
'categories',
|
|
||||||
'currentCategory',
|
|
||||||
'currentImage',
|
|
||||||
'currentImageUuid',
|
|
||||||
'shouldAutoSwitchToNewImages',
|
|
||||||
'intermediateImage',
|
|
||||||
].map((blacklistItem) => `gallery.${blacklistItem}`);
|
|
||||||
|
|
||||||
const lightboxBlacklist = ['isLightboxOpen'].map(
|
|
||||||
(blacklistItem) => `lightbox.${blacklistItem}`
|
|
||||||
);
|
|
||||||
|
|
||||||
const rootReducer = combineReducers({
|
const rootReducer = combineReducers({
|
||||||
generation: generationReducer,
|
|
||||||
postprocessing: postprocessingReducer,
|
|
||||||
gallery: galleryReducer,
|
|
||||||
system: systemReducer,
|
|
||||||
canvas: canvasReducer,
|
canvas: canvasReducer,
|
||||||
ui: uiReducer,
|
gallery: galleryReducer,
|
||||||
|
generation: generationReducer,
|
||||||
lightbox: lightboxReducer,
|
lightbox: lightboxReducer,
|
||||||
|
models: modelsReducer,
|
||||||
|
nodes: nodesReducer,
|
||||||
|
postprocessing: postprocessingReducer,
|
||||||
|
results: resultsReducer,
|
||||||
|
system: systemReducer,
|
||||||
|
ui: uiReducer,
|
||||||
|
uploads: uploadsReducer,
|
||||||
});
|
});
|
||||||
|
|
||||||
const rootPersistConfig = getPersistConfig({
|
const rootPersistConfig = getPersistConfig({
|
||||||
@ -80,23 +63,40 @@ const rootPersistConfig = getPersistConfig({
|
|||||||
rootReducer,
|
rootReducer,
|
||||||
blacklist: [
|
blacklist: [
|
||||||
...canvasBlacklist,
|
...canvasBlacklist,
|
||||||
...systemBlacklist,
|
|
||||||
...galleryBlacklist,
|
...galleryBlacklist,
|
||||||
|
...generationBlacklist,
|
||||||
...lightboxBlacklist,
|
...lightboxBlacklist,
|
||||||
|
...modelsBlacklist,
|
||||||
|
...nodesBlacklist,
|
||||||
|
...postprocessingBlacklist,
|
||||||
|
// ...resultsBlacklist,
|
||||||
|
'results',
|
||||||
|
...systemBlacklist,
|
||||||
|
...uiBlacklist,
|
||||||
|
// ...uploadsBlacklist,
|
||||||
|
'uploads',
|
||||||
],
|
],
|
||||||
debounce: 300,
|
debounce: 300,
|
||||||
});
|
});
|
||||||
|
|
||||||
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
|
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
|
||||||
|
|
||||||
// Continue with store setup
|
// TODO: rip the old middleware out when nodes is complete
|
||||||
|
export function buildMiddleware() {
|
||||||
|
if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
|
||||||
|
return socketMiddleware();
|
||||||
|
} else {
|
||||||
|
return socketioMiddleware();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export const store = configureStore({
|
export const store = configureStore({
|
||||||
reducer: persistedReducer,
|
reducer: persistedReducer,
|
||||||
middleware: (getDefaultMiddleware) =>
|
middleware: (getDefaultMiddleware) =>
|
||||||
getDefaultMiddleware({
|
getDefaultMiddleware({
|
||||||
immutableCheck: false,
|
immutableCheck: false,
|
||||||
serializableCheck: false,
|
serializableCheck: false,
|
||||||
}).concat(socketioMiddleware()),
|
}).concat(dynamicMiddlewares),
|
||||||
devTools: {
|
devTools: {
|
||||||
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
|
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
|
||||||
actionsDenylist: [
|
actionsDenylist: [
|
||||||
|
8
invokeai/frontend/web/src/app/storeUtils.ts
Normal file
8
invokeai/frontend/web/src/app/storeUtils.ts
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
import { createAsyncThunk } from '@reduxjs/toolkit';
|
||||||
|
import { AppDispatch, RootState } from './store';
|
||||||
|
|
||||||
|
// https://redux-toolkit.js.org/usage/usage-with-typescript#defining-a-pre-typed-createasyncthunk
|
||||||
|
export const createAppAsyncThunk = createAsyncThunk.withTypes<{
|
||||||
|
state: RootState;
|
||||||
|
dispatch: AppDispatch;
|
||||||
|
}>();
|
@ -44,12 +44,10 @@ export type IAIFullSliderProps = {
|
|||||||
inputReadOnly?: boolean;
|
inputReadOnly?: boolean;
|
||||||
withReset?: boolean;
|
withReset?: boolean;
|
||||||
handleReset?: () => void;
|
handleReset?: () => void;
|
||||||
isResetDisabled?: boolean;
|
|
||||||
isSliderDisabled?: boolean;
|
|
||||||
isInputDisabled?: boolean;
|
|
||||||
tooltipSuffix?: string;
|
tooltipSuffix?: string;
|
||||||
hideTooltip?: boolean;
|
hideTooltip?: boolean;
|
||||||
isCompact?: boolean;
|
isCompact?: boolean;
|
||||||
|
isDisabled?: boolean;
|
||||||
sliderFormControlProps?: FormControlProps;
|
sliderFormControlProps?: FormControlProps;
|
||||||
sliderFormLabelProps?: FormLabelProps;
|
sliderFormLabelProps?: FormLabelProps;
|
||||||
sliderMarkProps?: Omit<SliderMarkProps, 'value'>;
|
sliderMarkProps?: Omit<SliderMarkProps, 'value'>;
|
||||||
@ -80,10 +78,8 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
withReset = false,
|
withReset = false,
|
||||||
hideTooltip = false,
|
hideTooltip = false,
|
||||||
isCompact = false,
|
isCompact = false,
|
||||||
|
isDisabled = false,
|
||||||
handleReset,
|
handleReset,
|
||||||
isResetDisabled,
|
|
||||||
isSliderDisabled,
|
|
||||||
isInputDisabled,
|
|
||||||
sliderFormControlProps,
|
sliderFormControlProps,
|
||||||
sliderFormLabelProps,
|
sliderFormLabelProps,
|
||||||
sliderMarkProps,
|
sliderMarkProps,
|
||||||
@ -149,6 +145,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
}
|
}
|
||||||
: {}
|
: {}
|
||||||
}
|
}
|
||||||
|
isDisabled={isDisabled}
|
||||||
{...sliderFormControlProps}
|
{...sliderFormControlProps}
|
||||||
>
|
>
|
||||||
<FormLabel {...sliderFormLabelProps} mb={-1}>
|
<FormLabel {...sliderFormLabelProps} mb={-1}>
|
||||||
@ -166,15 +163,13 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
onMouseEnter={() => setShowTooltip(true)}
|
onMouseEnter={() => setShowTooltip(true)}
|
||||||
onMouseLeave={() => setShowTooltip(false)}
|
onMouseLeave={() => setShowTooltip(false)}
|
||||||
focusThumbOnChange={false}
|
focusThumbOnChange={false}
|
||||||
isDisabled={isSliderDisabled}
|
isDisabled={isDisabled}
|
||||||
// width={width}
|
|
||||||
{...rest}
|
{...rest}
|
||||||
>
|
>
|
||||||
{withSliderMarks && (
|
{withSliderMarks && (
|
||||||
<>
|
<>
|
||||||
<SliderMark
|
<SliderMark
|
||||||
value={min}
|
value={min}
|
||||||
// insetInlineStart={0}
|
|
||||||
sx={{
|
sx={{
|
||||||
insetInlineStart: '0 !important',
|
insetInlineStart: '0 !important',
|
||||||
insetInlineEnd: 'unset !important',
|
insetInlineEnd: 'unset !important',
|
||||||
@ -185,7 +180,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
</SliderMark>
|
</SliderMark>
|
||||||
<SliderMark
|
<SliderMark
|
||||||
value={max}
|
value={max}
|
||||||
// insetInlineEnd={0}
|
|
||||||
sx={{
|
sx={{
|
||||||
insetInlineStart: 'unset !important',
|
insetInlineStart: 'unset !important',
|
||||||
insetInlineEnd: '0 !important',
|
insetInlineEnd: '0 !important',
|
||||||
@ -221,7 +215,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
value={localInputValue}
|
value={localInputValue}
|
||||||
onChange={handleInputChange}
|
onChange={handleInputChange}
|
||||||
onBlur={handleInputBlur}
|
onBlur={handleInputBlur}
|
||||||
isDisabled={isInputDisabled}
|
|
||||||
{...sliderNumberInputProps}
|
{...sliderNumberInputProps}
|
||||||
>
|
>
|
||||||
<NumberInputField
|
<NumberInputField
|
||||||
@ -246,8 +239,8 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
|||||||
aria-label={t('accessibility.reset')}
|
aria-label={t('accessibility.reset')}
|
||||||
tooltip="Reset"
|
tooltip="Reset"
|
||||||
icon={<BiReset />}
|
icon={<BiReset />}
|
||||||
|
isDisabled={isDisabled}
|
||||||
onClick={handleResetDisable}
|
onClick={handleResetDisable}
|
||||||
isDisabled={isResetDisabled}
|
|
||||||
{...sliderIAIIconButtonProps}
|
{...sliderIAIIconButtonProps}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
@ -0,0 +1,79 @@
|
|||||||
|
import { Badge, Box, ButtonGroup, Flex } from '@chakra-ui/react';
|
||||||
|
import { RootState } from 'app/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||||
|
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import { FaUndo, FaUpload } from 'react-icons/fa';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { Image } from 'app/invokeai';
|
||||||
|
|
||||||
|
type ImageToImageOverlayProps = {
|
||||||
|
setIsLoaded: (isLoaded: boolean) => void;
|
||||||
|
image: Image;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ImageToImageOverlay = ({
|
||||||
|
setIsLoaded,
|
||||||
|
image,
|
||||||
|
}: ImageToImageOverlayProps) => {
|
||||||
|
const isImageToImageEnabled = useAppSelector(
|
||||||
|
(state: RootState) => state.generation.isImageToImageEnabled
|
||||||
|
);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const handleResetInitialImage = useCallback(() => {
|
||||||
|
dispatch(clearInitialImage());
|
||||||
|
setIsLoaded(false);
|
||||||
|
}, [dispatch, setIsLoaded]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
top: 0,
|
||||||
|
left: 0,
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
position: 'absolute',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<ButtonGroup
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
right: 0,
|
||||||
|
p: 2,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<IAIIconButton
|
||||||
|
size="sm"
|
||||||
|
isDisabled={!isImageToImageEnabled}
|
||||||
|
icon={<FaUndo />}
|
||||||
|
aria-label={t('accessibility.reset')}
|
||||||
|
onClick={handleResetInitialImage}
|
||||||
|
/>
|
||||||
|
<IAIIconButton
|
||||||
|
size="sm"
|
||||||
|
isDisabled={!isImageToImageEnabled}
|
||||||
|
icon={<FaUpload />}
|
||||||
|
aria-label={t('common.upload')}
|
||||||
|
/>
|
||||||
|
</ButtonGroup>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
bottom: 0,
|
||||||
|
left: 0,
|
||||||
|
p: 2,
|
||||||
|
alignItems: 'flex-start',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Badge variant="solid" colorScheme="base">
|
||||||
|
{image.metadata?.width} × {image.metadata?.height}
|
||||||
|
</Badge>
|
||||||
|
</Flex>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ImageToImageOverlay;
|
@ -2,7 +2,6 @@ import { Box, useToast } from '@chakra-ui/react';
|
|||||||
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
|
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||||
import useImageUploader from 'common/hooks/useImageUploader';
|
import useImageUploader from 'common/hooks/useImageUploader';
|
||||||
import { uploadImage } from 'features/gallery/store/thunks/uploadImage';
|
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { ResourceKey } from 'i18next';
|
import { ResourceKey } from 'i18next';
|
||||||
import {
|
import {
|
||||||
@ -15,6 +14,7 @@ import {
|
|||||||
} from 'react';
|
} from 'react';
|
||||||
import { FileRejection, useDropzone } from 'react-dropzone';
|
import { FileRejection, useDropzone } from 'react-dropzone';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { imageUploaded } from 'services/thunks/image';
|
||||||
import ImageUploadOverlay from './ImageUploadOverlay';
|
import ImageUploadOverlay from './ImageUploadOverlay';
|
||||||
|
|
||||||
type ImageUploaderProps = {
|
type ImageUploaderProps = {
|
||||||
@ -49,7 +49,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
|||||||
|
|
||||||
const fileAcceptedCallback = useCallback(
|
const fileAcceptedCallback = useCallback(
|
||||||
async (file: File) => {
|
async (file: File) => {
|
||||||
dispatch(uploadImage({ imageFile: file }));
|
dispatch(imageUploaded({ formData: { file } }));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
@ -124,7 +124,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(uploadImage({ imageFile: file }));
|
dispatch(imageUploaded({ formData: { file } }));
|
||||||
};
|
};
|
||||||
document.addEventListener('paste', pasteImageListener);
|
document.addEventListener('paste', pasteImageListener);
|
||||||
return () => {
|
return () => {
|
||||||
|
@ -0,0 +1,12 @@
|
|||||||
|
import { Flex, Icon } from '@chakra-ui/react';
|
||||||
|
import { FaImage } from 'react-icons/fa';
|
||||||
|
|
||||||
|
const SelectImagePlaceholder = () => {
|
||||||
|
return (
|
||||||
|
<Flex sx={{ h: 36, alignItems: 'center', justifyContent: 'center' }}>
|
||||||
|
<Icon color="base.400" boxSize={32} as={FaImage}></Icon>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SelectImagePlaceholder;
|
@ -1,27 +1,160 @@
|
|||||||
import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
|
// import WorkInProgress from './WorkInProgress';
|
||||||
import { useTranslation } from 'react-i18next';
|
// import ReactFlow, {
|
||||||
import WorkInProgress from './WorkInProgress';
|
// applyEdgeChanges,
|
||||||
|
// applyNodeChanges,
|
||||||
|
// Background,
|
||||||
|
// Controls,
|
||||||
|
// Edge,
|
||||||
|
// Handle,
|
||||||
|
// Node,
|
||||||
|
// NodeTypes,
|
||||||
|
// OnEdgesChange,
|
||||||
|
// OnNodesChange,
|
||||||
|
// Position,
|
||||||
|
// } from 'reactflow';
|
||||||
|
|
||||||
export default function NodesWIP() {
|
// import 'reactflow/dist/style.css';
|
||||||
const { t } = useTranslation();
|
// import {
|
||||||
return (
|
// Fragment,
|
||||||
<WorkInProgress>
|
// FunctionComponent,
|
||||||
<Flex
|
// ReactNode,
|
||||||
sx={{
|
// useCallback,
|
||||||
flexDirection: 'column',
|
// useMemo,
|
||||||
alignItems: 'center',
|
// useState,
|
||||||
justifyContent: 'center',
|
// } from 'react';
|
||||||
w: '100%',
|
// import { OpenAPIV3 } from 'openapi-types';
|
||||||
h: '100%',
|
// import { filter, map, reduce } from 'lodash';
|
||||||
gap: 4,
|
// import {
|
||||||
textAlign: 'center',
|
// Box,
|
||||||
}}
|
// Flex,
|
||||||
>
|
// FormControl,
|
||||||
<Heading>{t('common.nodes')}</Heading>
|
// FormLabel,
|
||||||
<VStack maxW="50rem" gap={4}>
|
// Input,
|
||||||
<Text>{t('common.nodesDesc')}</Text>
|
// Select,
|
||||||
</VStack>
|
// Switch,
|
||||||
</Flex>
|
// Text,
|
||||||
</WorkInProgress>
|
// NumberInput,
|
||||||
);
|
// NumberInputField,
|
||||||
}
|
// NumberInputStepper,
|
||||||
|
// NumberIncrementStepper,
|
||||||
|
// NumberDecrementStepper,
|
||||||
|
// Tooltip,
|
||||||
|
// chakra,
|
||||||
|
// Badge,
|
||||||
|
// Heading,
|
||||||
|
// VStack,
|
||||||
|
// HStack,
|
||||||
|
// Menu,
|
||||||
|
// MenuButton,
|
||||||
|
// MenuList,
|
||||||
|
// MenuItem,
|
||||||
|
// MenuItemOption,
|
||||||
|
// MenuGroup,
|
||||||
|
// MenuOptionGroup,
|
||||||
|
// MenuDivider,
|
||||||
|
// IconButton,
|
||||||
|
// } from '@chakra-ui/react';
|
||||||
|
// import { FaPlus } from 'react-icons/fa';
|
||||||
|
// import {
|
||||||
|
// FIELD_NAMES as FIELD_NAMES,
|
||||||
|
// FIELDS,
|
||||||
|
// INVOCATION_NAMES as INVOCATION_NAMES,
|
||||||
|
// INVOCATIONS,
|
||||||
|
// } from 'features/nodeEditor/constants';
|
||||||
|
|
||||||
|
// console.log('invocations', INVOCATIONS);
|
||||||
|
|
||||||
|
// const nodeTypes = reduce(
|
||||||
|
// INVOCATIONS,
|
||||||
|
// (acc, val, key) => {
|
||||||
|
// acc[key] = val.component;
|
||||||
|
// return acc;
|
||||||
|
// },
|
||||||
|
// {} as NodeTypes
|
||||||
|
// );
|
||||||
|
|
||||||
|
// console.log('nodeTypes', nodeTypes);
|
||||||
|
|
||||||
|
// // make initial nodes one of every node for now
|
||||||
|
// let n = 0;
|
||||||
|
// const initialNodes = map(INVOCATIONS, (i) => ({
|
||||||
|
// id: i.type,
|
||||||
|
// type: i.title,
|
||||||
|
// position: { x: (n += 20), y: (n += 20) },
|
||||||
|
// data: {},
|
||||||
|
// }));
|
||||||
|
|
||||||
|
// console.log('initialNodes', initialNodes);
|
||||||
|
|
||||||
|
// export default function NodesWIP() {
|
||||||
|
// const [nodes, setNodes] = useState<Node[]>([]);
|
||||||
|
// const [edges, setEdges] = useState<Edge[]>([]);
|
||||||
|
|
||||||
|
// const onNodesChange: OnNodesChange = useCallback(
|
||||||
|
// (changes) => setNodes((nds) => applyNodeChanges(changes, nds)),
|
||||||
|
// []
|
||||||
|
// );
|
||||||
|
|
||||||
|
// const onEdgesChange: OnEdgesChange = useCallback(
|
||||||
|
// (changes) => setEdges((eds: Edge[]) => applyEdgeChanges(changes, eds)),
|
||||||
|
// []
|
||||||
|
// );
|
||||||
|
|
||||||
|
// return (
|
||||||
|
// <Box
|
||||||
|
// sx={{
|
||||||
|
// position: 'relative',
|
||||||
|
// width: 'full',
|
||||||
|
// height: 'full',
|
||||||
|
// borderRadius: 'md',
|
||||||
|
// }}
|
||||||
|
// >
|
||||||
|
// <ReactFlow
|
||||||
|
// nodeTypes={nodeTypes}
|
||||||
|
// nodes={nodes}
|
||||||
|
// edges={edges}
|
||||||
|
// onNodesChange={onNodesChange}
|
||||||
|
// onEdgesChange={onEdgesChange}
|
||||||
|
// >
|
||||||
|
// <Background />
|
||||||
|
// <Controls />
|
||||||
|
// </ReactFlow>
|
||||||
|
// <HStack sx={{ position: 'absolute', top: 2, right: 2 }}>
|
||||||
|
// {FIELD_NAMES.map((field) => (
|
||||||
|
// <Badge
|
||||||
|
// key={field}
|
||||||
|
// colorScheme={FIELDS[field].color}
|
||||||
|
// sx={{ userSelect: 'none' }}
|
||||||
|
// >
|
||||||
|
// {field}
|
||||||
|
// </Badge>
|
||||||
|
// ))}
|
||||||
|
// </HStack>
|
||||||
|
// <Menu>
|
||||||
|
// <MenuButton
|
||||||
|
// as={IconButton}
|
||||||
|
// aria-label="Options"
|
||||||
|
// icon={<FaPlus />}
|
||||||
|
// sx={{ position: 'absolute', top: 2, left: 2 }}
|
||||||
|
// />
|
||||||
|
// <MenuList>
|
||||||
|
// {INVOCATION_NAMES.map((name) => {
|
||||||
|
// const invocation = INVOCATIONS[name];
|
||||||
|
// return (
|
||||||
|
// <Tooltip
|
||||||
|
// key={name}
|
||||||
|
// label={invocation.description}
|
||||||
|
// placement="end"
|
||||||
|
// hasArrow
|
||||||
|
// >
|
||||||
|
// <MenuItem>{invocation.title}</MenuItem>
|
||||||
|
// </Tooltip>
|
||||||
|
// );
|
||||||
|
// })}
|
||||||
|
// </MenuList>
|
||||||
|
// </Menu>
|
||||||
|
// </Box>
|
||||||
|
// );
|
||||||
|
// }
|
||||||
|
|
||||||
|
export default {};
|
||||||
|
@ -14,6 +14,8 @@ const WorkInProgress = (props: WorkInProgressProps) => {
|
|||||||
width: '100%',
|
width: '100%',
|
||||||
height: '100%',
|
height: '100%',
|
||||||
bg: 'base.850',
|
bg: 'base.850',
|
||||||
|
borderRadius: 'base',
|
||||||
|
position: 'relative',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{children}
|
{children}
|
||||||
|
119
invokeai/frontend/web/src/common/util/_parseMetadataZod.ts
Normal file
119
invokeai/frontend/web/src/common/util/_parseMetadataZod.ts
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
/**
|
||||||
|
* PARTIAL ZOD IMPLEMENTATION
|
||||||
|
*
|
||||||
|
* doesn't work well bc like most validators, zod is not built to skip invalid values.
|
||||||
|
* it mostly works but just seems clearer and simpler to manually parse for now.
|
||||||
|
*
|
||||||
|
* in the future it would be really nice if we could use zod for some things:
|
||||||
|
* - zodios (axios + zod): https://github.com/ecyrbe/zodios
|
||||||
|
* - openapi to zodios: https://github.com/astahmer/openapi-zod-client
|
||||||
|
*/
|
||||||
|
|
||||||
|
// import { z } from 'zod';
|
||||||
|
|
||||||
|
// const zMetadataStringField = z.string();
|
||||||
|
// export type MetadataStringField = z.infer<typeof zMetadataStringField>;
|
||||||
|
|
||||||
|
// const zMetadataIntegerField = z.number().int();
|
||||||
|
// export type MetadataIntegerField = z.infer<typeof zMetadataIntegerField>;
|
||||||
|
|
||||||
|
// const zMetadataFloatField = z.number();
|
||||||
|
// export type MetadataFloatField = z.infer<typeof zMetadataFloatField>;
|
||||||
|
|
||||||
|
// const zMetadataBooleanField = z.boolean();
|
||||||
|
// export type MetadataBooleanField = z.infer<typeof zMetadataBooleanField>;
|
||||||
|
|
||||||
|
// const zMetadataImageField = z.object({
|
||||||
|
// image_type: z.union([
|
||||||
|
// z.literal('results'),
|
||||||
|
// z.literal('uploads'),
|
||||||
|
// z.literal('intermediates'),
|
||||||
|
// ]),
|
||||||
|
// image_name: z.string().min(1),
|
||||||
|
// });
|
||||||
|
// export type MetadataImageField = z.infer<typeof zMetadataImageField>;
|
||||||
|
|
||||||
|
// const zMetadataLatentsField = z.object({
|
||||||
|
// latents_name: z.string().min(1),
|
||||||
|
// });
|
||||||
|
// export type MetadataLatentsField = z.infer<typeof zMetadataLatentsField>;
|
||||||
|
|
||||||
|
// /**
|
||||||
|
// * zod Schema for any node field. Use a `transform()` to manually parse, skipping invalid values.
|
||||||
|
// */
|
||||||
|
// const zAnyMetadataField = z.any().transform((val, ctx) => {
|
||||||
|
// // Grab the field name from the path
|
||||||
|
// const fieldName = String(ctx.path[ctx.path.length - 1]);
|
||||||
|
|
||||||
|
// // `id` and `type` must be strings if they exist
|
||||||
|
// if (['id', 'type'].includes(fieldName)) {
|
||||||
|
// const reservedStringPropertyResult = zMetadataStringField.safeParse(val);
|
||||||
|
// if (reservedStringPropertyResult.success) {
|
||||||
|
// return reservedStringPropertyResult.data;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // Parse the rest of the fields, only returning the data if the parsing is successful
|
||||||
|
|
||||||
|
// const stringFieldResult = zMetadataStringField.safeParse(val);
|
||||||
|
// if (stringFieldResult.success) {
|
||||||
|
// return stringFieldResult.data;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const integerFieldResult = zMetadataIntegerField.safeParse(val);
|
||||||
|
// if (integerFieldResult.success) {
|
||||||
|
// return integerFieldResult.data;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const floatFieldResult = zMetadataFloatField.safeParse(val);
|
||||||
|
// if (floatFieldResult.success) {
|
||||||
|
// return floatFieldResult.data;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const booleanFieldResult = zMetadataBooleanField.safeParse(val);
|
||||||
|
// if (booleanFieldResult.success) {
|
||||||
|
// return booleanFieldResult.data;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const imageFieldResult = zMetadataImageField.safeParse(val);
|
||||||
|
// if (imageFieldResult.success) {
|
||||||
|
// return imageFieldResult.data;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const latentsFieldResult = zMetadataImageField.safeParse(val);
|
||||||
|
// if (latentsFieldResult.success) {
|
||||||
|
// return latentsFieldResult.data;
|
||||||
|
// }
|
||||||
|
// });
|
||||||
|
|
||||||
|
// /**
|
||||||
|
// * The node metadata schema.
|
||||||
|
// */
|
||||||
|
// const zNodeMetadata = z.object({
|
||||||
|
// session_id: z.string().min(1).optional(),
|
||||||
|
// node: z.record(z.string().min(1), zAnyMetadataField).optional(),
|
||||||
|
// });
|
||||||
|
|
||||||
|
// export type NodeMetadata = z.infer<typeof zNodeMetadata>;
|
||||||
|
|
||||||
|
// const zMetadata = z.object({
|
||||||
|
// invokeai: zNodeMetadata.optional(),
|
||||||
|
// 'sd-metadata': z.record(z.string().min(1), z.any()).optional(),
|
||||||
|
// });
|
||||||
|
// export type Metadata = z.infer<typeof zMetadata>;
|
||||||
|
|
||||||
|
// export const parseMetadata = (
|
||||||
|
// metadata: Record<string, any>
|
||||||
|
// ): Metadata | undefined => {
|
||||||
|
// const result = zMetadata.safeParse(metadata);
|
||||||
|
// if (!result.success) {
|
||||||
|
// console.log(result.error.issues);
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return result.data;
|
||||||
|
// };
|
||||||
|
|
||||||
|
export default {};
|
6
invokeai/frontend/web/src/common/util/getTimestamp.ts
Normal file
6
invokeai/frontend/web/src/common/util/getTimestamp.ts
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
import dateFormat from 'dateformat';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
|
||||||
|
*/
|
||||||
|
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');
|
28
invokeai/frontend/web/src/common/util/getUrl.ts
Normal file
28
invokeai/frontend/web/src/common/util/getUrl.ts
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import { RootState } from 'app/store';
|
||||||
|
import { useAppSelector } from 'app/storeHooks';
|
||||||
|
import { OpenAPI } from 'services/api';
|
||||||
|
|
||||||
|
export const getUrlAlt = (url: string, shouldTransformUrls: boolean) => {
|
||||||
|
if (OpenAPI.BASE && shouldTransformUrls) {
|
||||||
|
return [OpenAPI.BASE, url].join('/');
|
||||||
|
}
|
||||||
|
|
||||||
|
return url;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const useGetUrl = () => {
|
||||||
|
const shouldTransformUrls = useAppSelector(
|
||||||
|
(state: RootState) => state.system.shouldTransformUrls
|
||||||
|
);
|
||||||
|
|
||||||
|
return {
|
||||||
|
shouldTransformUrls,
|
||||||
|
getUrl: (url?: string) => {
|
||||||
|
if (OpenAPI.BASE && shouldTransformUrls) {
|
||||||
|
return [OpenAPI.BASE, url].join('/');
|
||||||
|
}
|
||||||
|
|
||||||
|
return url;
|
||||||
|
},
|
||||||
|
};
|
||||||
|
};
|
169
invokeai/frontend/web/src/common/util/parseMetadata.ts
Normal file
169
invokeai/frontend/web/src/common/util/parseMetadata.ts
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
import { forEach, size } from 'lodash';
|
||||||
|
import { ImageField, LatentsField } from 'services/api';
|
||||||
|
|
||||||
|
const OBJECT_TYPESTRING = '[object Object]';
|
||||||
|
const STRING_TYPESTRING = '[object String]';
|
||||||
|
const NUMBER_TYPESTRING = '[object Number]';
|
||||||
|
const BOOLEAN_TYPESTRING = '[object Boolean]';
|
||||||
|
const ARRAY_TYPESTRING = '[object Array]';
|
||||||
|
|
||||||
|
const isObject = (obj: unknown): obj is Record<string | number, any> =>
|
||||||
|
Object.prototype.toString.call(obj) === OBJECT_TYPESTRING;
|
||||||
|
|
||||||
|
const isString = (obj: unknown): obj is string =>
|
||||||
|
Object.prototype.toString.call(obj) === STRING_TYPESTRING;
|
||||||
|
|
||||||
|
const isNumber = (obj: unknown): obj is number =>
|
||||||
|
Object.prototype.toString.call(obj) === NUMBER_TYPESTRING;
|
||||||
|
|
||||||
|
const isBoolean = (obj: unknown): obj is boolean =>
|
||||||
|
Object.prototype.toString.call(obj) === BOOLEAN_TYPESTRING;
|
||||||
|
|
||||||
|
const isArray = (obj: unknown): obj is Array<any> =>
|
||||||
|
Object.prototype.toString.call(obj) === ARRAY_TYPESTRING;
|
||||||
|
|
||||||
|
const parseImageField = (imageField: unknown): ImageField | undefined => {
|
||||||
|
// Must be an object
|
||||||
|
if (!isObject(imageField)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// An ImageField must have both `image_name` and `image_type`
|
||||||
|
if (!('image_name' in imageField && 'image_type' in imageField)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// An ImageField's `image_type` must be one of the allowed values
|
||||||
|
if (
|
||||||
|
!['results', 'uploads', 'intermediates'].includes(imageField.image_type)
|
||||||
|
) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// An ImageField's `image_name` must be a string
|
||||||
|
if (typeof imageField.image_name !== 'string') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a valid ImageField
|
||||||
|
return {
|
||||||
|
image_type: imageField.image_type,
|
||||||
|
image_name: imageField.image_name,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
|
||||||
|
// Must be an object
|
||||||
|
if (!isObject(latentsField)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// A LatentsField must have a `latents_name`
|
||||||
|
if (!('latents_name' in latentsField)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// A LatentsField's `latents_name` must be a string
|
||||||
|
if (typeof latentsField.latents_name !== 'string') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a valid LatentsField
|
||||||
|
return {
|
||||||
|
latents_name: latentsField.latents_name,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
type NodeMetadata = {
|
||||||
|
[key: string]: string | number | boolean | ImageField | LatentsField;
|
||||||
|
};
|
||||||
|
|
||||||
|
type InvokeAIMetadata = {
|
||||||
|
session_id?: string;
|
||||||
|
node?: NodeMetadata;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const parseNodeMetadata = (
|
||||||
|
nodeMetadata: Record<string | number, any>
|
||||||
|
): NodeMetadata | undefined => {
|
||||||
|
if (!isObject(nodeMetadata)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const parsed: NodeMetadata = {};
|
||||||
|
|
||||||
|
forEach(nodeMetadata, (nodeItem, nodeKey) => {
|
||||||
|
// `id` and `type` must be strings if they are present
|
||||||
|
if (['id', 'type'].includes(nodeKey)) {
|
||||||
|
if (isString(nodeItem)) {
|
||||||
|
parsed[nodeKey] = nodeItem;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// the only valid object types are ImageField and LatentsField
|
||||||
|
if (isObject(nodeItem)) {
|
||||||
|
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
|
||||||
|
const imageField = parseImageField(nodeItem);
|
||||||
|
if (imageField) {
|
||||||
|
parsed[nodeKey] = imageField;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ('latents_name' in nodeItem) {
|
||||||
|
const latentsField = parseLatentsField(nodeItem);
|
||||||
|
if (latentsField) {
|
||||||
|
parsed[nodeKey] = latentsField;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// otherwise we accept any string, number or boolean
|
||||||
|
if (isString(nodeItem) || isNumber(nodeItem) || isBoolean(nodeItem)) {
|
||||||
|
parsed[nodeKey] = nodeItem;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if (size(parsed) === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsed;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const parseInvokeAIMetadata = (
|
||||||
|
metadata: Record<string | number, any> | undefined
|
||||||
|
): InvokeAIMetadata | undefined => {
|
||||||
|
if (metadata === undefined) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isObject(metadata)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const parsed: InvokeAIMetadata = {};
|
||||||
|
|
||||||
|
forEach(metadata, (item, key) => {
|
||||||
|
if (key === 'session_id' && isString(item)) {
|
||||||
|
parsed['session_id'] = item;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (key === 'node' && isObject(item)) {
|
||||||
|
const nodeMetadata = parseNodeMetadata(item);
|
||||||
|
|
||||||
|
if (nodeMetadata) {
|
||||||
|
parsed['node'] = nodeMetadata;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if (size(parsed) === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsed;
|
||||||
|
};
|
@ -1,8 +1,10 @@
|
|||||||
import React, { lazy, PropsWithChildren } from 'react';
|
import React, { lazy, PropsWithChildren, useEffect, useState } from 'react';
|
||||||
import { Provider } from 'react-redux';
|
import { Provider } from 'react-redux';
|
||||||
import { PersistGate } from 'redux-persist/integration/react';
|
import { PersistGate } from 'redux-persist/integration/react';
|
||||||
import { store } from './app/store';
|
import { buildMiddleware, store } from './app/store';
|
||||||
import { persistor } from './persistor';
|
import { persistor } from './persistor';
|
||||||
|
import { OpenAPI } from 'services/api';
|
||||||
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
import '@fontsource/inter/100.css';
|
import '@fontsource/inter/100.css';
|
||||||
import '@fontsource/inter/200.css';
|
import '@fontsource/inter/200.css';
|
||||||
import '@fontsource/inter/300.css';
|
import '@fontsource/inter/300.css';
|
||||||
@ -17,18 +19,61 @@ import Loading from './Loading';
|
|||||||
|
|
||||||
// Localization
|
// Localization
|
||||||
import './i18n';
|
import './i18n';
|
||||||
|
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
||||||
|
|
||||||
const App = lazy(() => import('./app/App'));
|
const App = lazy(() => import('./app/App'));
|
||||||
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
|
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
|
||||||
|
|
||||||
export default function Component(props: PropsWithChildren) {
|
interface Props extends PropsWithChildren {
|
||||||
|
apiUrl?: string;
|
||||||
|
disabledPanels?: string[];
|
||||||
|
disabledTabs?: InvokeTabName[];
|
||||||
|
token?: string;
|
||||||
|
shouldTransformUrls?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function Component({
|
||||||
|
apiUrl,
|
||||||
|
disabledPanels = [],
|
||||||
|
disabledTabs = [],
|
||||||
|
token,
|
||||||
|
children,
|
||||||
|
shouldTransformUrls,
|
||||||
|
}: Props) {
|
||||||
|
useEffect(() => {
|
||||||
|
// configure API client token
|
||||||
|
if (token) {
|
||||||
|
OpenAPI.TOKEN = token;
|
||||||
|
}
|
||||||
|
|
||||||
|
// configure API client base url
|
||||||
|
if (apiUrl) {
|
||||||
|
OpenAPI.BASE = apiUrl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// reset dynamically added middlewares
|
||||||
|
resetMiddlewares();
|
||||||
|
|
||||||
|
// TODO: at this point, after resetting the middleware, we really ought to clean up the socket
|
||||||
|
// stuff by calling `dispatch(socketReset())`. but we cannot dispatch from here as we are
|
||||||
|
// outside the provider. it's not needed until there is the possibility that we will change
|
||||||
|
// the `apiUrl`/`token` dynamically.
|
||||||
|
|
||||||
|
// rebuild socket middleware with token and apiUrl
|
||||||
|
addMiddleware(buildMiddleware());
|
||||||
|
}, [apiUrl, token]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<React.StrictMode>
|
<React.StrictMode>
|
||||||
<Provider store={store}>
|
<Provider store={store}>
|
||||||
<PersistGate loading={<Loading />} persistor={persistor}>
|
<PersistGate loading={<Loading />} persistor={persistor}>
|
||||||
<React.Suspense fallback={<Loading showText />}>
|
<React.Suspense fallback={<Loading showText />}>
|
||||||
<ThemeLocaleProvider>
|
<ThemeLocaleProvider>
|
||||||
<App>{props.children}</App>
|
<App
|
||||||
|
options={{ disabledPanels, disabledTabs, shouldTransformUrls }}
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
</App>
|
||||||
</ThemeLocaleProvider>
|
</ThemeLocaleProvider>
|
||||||
</React.Suspense>
|
</React.Suspense>
|
||||||
</PersistGate>
|
</PersistGate>
|
||||||
|
@ -5,6 +5,8 @@ import ThemeChanger from './features/system/components/ThemeChanger';
|
|||||||
import IAIPopover from './common/components/IAIPopover';
|
import IAIPopover from './common/components/IAIPopover';
|
||||||
import IAIIconButton from './common/components/IAIIconButton';
|
import IAIIconButton from './common/components/IAIIconButton';
|
||||||
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
|
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
|
||||||
|
import StatusIndicator from './features/system/components/StatusIndicator';
|
||||||
|
import ModelSelect from 'features/system/components/ModelSelect';
|
||||||
|
|
||||||
export default Component;
|
export default Component;
|
||||||
export {
|
export {
|
||||||
@ -13,4 +15,6 @@ export {
|
|||||||
IAIPopover,
|
IAIPopover,
|
||||||
IAIIconButton,
|
IAIIconButton,
|
||||||
SettingsModal,
|
SettingsModal,
|
||||||
|
StatusIndicator,
|
||||||
|
ModelSelect,
|
||||||
};
|
};
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RootState } from 'app/store';
|
import { RootState } from 'app/store';
|
||||||
import { useAppSelector } from 'app/storeHooks';
|
import { useAppSelector } from 'app/storeHooks';
|
||||||
|
import { useGetUrl } from 'common/util/getUrl';
|
||||||
import { GalleryState } from 'features/gallery/store/gallerySlice';
|
import { GalleryState } from 'features/gallery/store/gallerySlice';
|
||||||
import { ImageConfig } from 'konva/lib/shapes/Image';
|
import { ImageConfig } from 'konva/lib/shapes/Image';
|
||||||
import { isEqual } from 'lodash';
|
import { isEqual } from 'lodash';
|
||||||
@ -25,7 +26,7 @@ type Props = Omit<ImageConfig, 'image'>;
|
|||||||
const IAICanvasIntermediateImage = (props: Props) => {
|
const IAICanvasIntermediateImage = (props: Props) => {
|
||||||
const { ...rest } = props;
|
const { ...rest } = props;
|
||||||
const intermediateImage = useAppSelector(selector);
|
const intermediateImage = useAppSelector(selector);
|
||||||
|
const { getUrl } = useGetUrl();
|
||||||
const [loadedImageElement, setLoadedImageElement] =
|
const [loadedImageElement, setLoadedImageElement] =
|
||||||
useState<HTMLImageElement | null>(null);
|
useState<HTMLImageElement | null>(null);
|
||||||
|
|
||||||
@ -36,8 +37,8 @@ const IAICanvasIntermediateImage = (props: Props) => {
|
|||||||
tempImage.onload = () => {
|
tempImage.onload = () => {
|
||||||
setLoadedImageElement(tempImage);
|
setLoadedImageElement(tempImage);
|
||||||
};
|
};
|
||||||
tempImage.src = intermediateImage.url;
|
tempImage.src = getUrl(intermediateImage.url);
|
||||||
}, [intermediateImage]);
|
}, [intermediateImage, getUrl]);
|
||||||
|
|
||||||
if (!intermediateImage?.boundingBox) return null;
|
if (!intermediateImage?.boundingBox) return null;
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/storeHooks';
|
import { useAppSelector } from 'app/storeHooks';
|
||||||
|
import { useGetUrl } from 'common/util/getUrl';
|
||||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||||
import { rgbaColorToString } from 'features/canvas/util/colorToString';
|
import { rgbaColorToString } from 'features/canvas/util/colorToString';
|
||||||
import { isEqual } from 'lodash';
|
import { isEqual } from 'lodash';
|
||||||
@ -32,6 +33,7 @@ const selector = createSelector(
|
|||||||
|
|
||||||
const IAICanvasObjectRenderer = () => {
|
const IAICanvasObjectRenderer = () => {
|
||||||
const { objects } = useAppSelector(selector);
|
const { objects } = useAppSelector(selector);
|
||||||
|
const { getUrl } = useGetUrl();
|
||||||
|
|
||||||
if (!objects) return null;
|
if (!objects) return null;
|
||||||
|
|
||||||
@ -40,7 +42,12 @@ const IAICanvasObjectRenderer = () => {
|
|||||||
{objects.map((obj, i) => {
|
{objects.map((obj, i) => {
|
||||||
if (isCanvasBaseImage(obj)) {
|
if (isCanvasBaseImage(obj)) {
|
||||||
return (
|
return (
|
||||||
<IAICanvasImage key={i} x={obj.x} y={obj.y} url={obj.image.url} />
|
<IAICanvasImage
|
||||||
|
key={i}
|
||||||
|
x={obj.x}
|
||||||
|
y={obj.y}
|
||||||
|
url={getUrl(obj.image.url)}
|
||||||
|
/>
|
||||||
);
|
);
|
||||||
} else if (isCanvasBaseLine(obj)) {
|
} else if (isCanvasBaseLine(obj)) {
|
||||||
const line = (
|
const line = (
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/storeHooks';
|
import { useAppSelector } from 'app/storeHooks';
|
||||||
|
import { useGetUrl } from 'common/util/getUrl';
|
||||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||||
import { GroupConfig } from 'konva/lib/Group';
|
import { GroupConfig } from 'konva/lib/Group';
|
||||||
import { isEqual } from 'lodash';
|
import { isEqual } from 'lodash';
|
||||||
@ -53,11 +54,16 @@ const IAICanvasStagingArea = (props: Props) => {
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
} = useAppSelector(selector);
|
} = useAppSelector(selector);
|
||||||
|
const { getUrl } = useGetUrl();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Group {...rest}>
|
<Group {...rest}>
|
||||||
{shouldShowStagingImage && currentStagingAreaImage && (
|
{shouldShowStagingImage && currentStagingAreaImage && (
|
||||||
<IAICanvasImage url={currentStagingAreaImage.image.url} x={x} y={y} />
|
<IAICanvasImage
|
||||||
|
url={getUrl(currentStagingAreaImage.image.url)}
|
||||||
|
x={x}
|
||||||
|
y={y}
|
||||||
|
/>
|
||||||
)}
|
)}
|
||||||
{shouldShowStagingOutline && (
|
{shouldShowStagingOutline && (
|
||||||
<Group>
|
<Group>
|
||||||
|
@ -0,0 +1,14 @@
|
|||||||
|
import { CanvasState } from './canvasTypes';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Canvas slice persist blacklist
|
||||||
|
*/
|
||||||
|
const itemsToBlacklist: (keyof CanvasState)[] = [
|
||||||
|
'cursorPosition',
|
||||||
|
'isCanvasInitialized',
|
||||||
|
'doesCanvasNeedScaling',
|
||||||
|
];
|
||||||
|
|
||||||
|
export const canvasBlacklist = itemsToBlacklist.map(
|
||||||
|
(blacklistItem) => `canvas.${blacklistItem}`
|
||||||
|
);
|
@ -156,7 +156,7 @@ export const canvasSlice = createSlice({
|
|||||||
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
|
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
|
||||||
state.cursorPosition = action.payload;
|
state.cursorPosition = action.payload;
|
||||||
},
|
},
|
||||||
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
|
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI._Image>) => {
|
||||||
const image = action.payload;
|
const image = action.payload;
|
||||||
const { stageDimensions } = state;
|
const { stageDimensions } = state;
|
||||||
|
|
||||||
@ -291,7 +291,7 @@ export const canvasSlice = createSlice({
|
|||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
boundingBox: IRect;
|
boundingBox: IRect;
|
||||||
image: InvokeAI.Image;
|
image: InvokeAI._Image;
|
||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { boundingBox, image } = action.payload;
|
const { boundingBox, image } = action.payload;
|
||||||
|
@ -37,7 +37,7 @@ export type CanvasImage = {
|
|||||||
y: number;
|
y: number;
|
||||||
width: number;
|
width: number;
|
||||||
height: number;
|
height: number;
|
||||||
image: InvokeAI.Image;
|
image: InvokeAI._Image;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type CanvasMaskLine = {
|
export type CanvasMaskLine = {
|
||||||
@ -125,7 +125,7 @@ export interface CanvasState {
|
|||||||
cursorPosition: Vector2d | null;
|
cursorPosition: Vector2d | null;
|
||||||
doesCanvasNeedScaling: boolean;
|
doesCanvasNeedScaling: boolean;
|
||||||
futureLayerStates: CanvasLayerState[];
|
futureLayerStates: CanvasLayerState[];
|
||||||
intermediateImage?: InvokeAI.Image;
|
intermediateImage?: InvokeAI._Image;
|
||||||
isCanvasInitialized: boolean;
|
isCanvasInitialized: boolean;
|
||||||
isDrawing: boolean;
|
isDrawing: boolean;
|
||||||
isMaskEnabled: boolean;
|
isMaskEnabled: boolean;
|
||||||
|
@ -105,7 +105,7 @@ export const mergeAndUploadCanvas =
|
|||||||
|
|
||||||
const { url, width, height } = image;
|
const { url, width, height } = image;
|
||||||
|
|
||||||
const newImage: InvokeAI.Image = {
|
const newImage: InvokeAI._Image = {
|
||||||
uuid: uuidv4(),
|
uuid: uuidv4(),
|
||||||
category: shouldSaveToGallery ? 'result' : 'user',
|
category: shouldSaveToGallery ? 'result' : 'user',
|
||||||
...image,
|
...image,
|
||||||
|
@ -14,8 +14,9 @@ import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
|||||||
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
|
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
|
||||||
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
|
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
|
||||||
import {
|
import {
|
||||||
|
initialImageSelected,
|
||||||
setAllParameters,
|
setAllParameters,
|
||||||
setInitialImage,
|
// setInitialImage,
|
||||||
setSeed,
|
setSeed,
|
||||||
} from 'features/parameters/store/generationSlice';
|
} from 'features/parameters/store/generationSlice';
|
||||||
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
|
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
|
||||||
@ -48,11 +49,15 @@ import {
|
|||||||
FaShareAlt,
|
FaShareAlt,
|
||||||
FaTrash,
|
FaTrash,
|
||||||
} from 'react-icons/fa';
|
} from 'react-icons/fa';
|
||||||
import { gallerySelector } from '../store/gallerySelectors';
|
import {
|
||||||
|
gallerySelector,
|
||||||
|
selectedImageSelector,
|
||||||
|
} from '../store/gallerySelectors';
|
||||||
import DeleteImageModal from './DeleteImageModal';
|
import DeleteImageModal from './DeleteImageModal';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
|
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
|
||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
|
import { useGetUrl } from 'common/util/getUrl';
|
||||||
|
|
||||||
const currentImageButtonsSelector = createSelector(
|
const currentImageButtonsSelector = createSelector(
|
||||||
[
|
[
|
||||||
@ -62,6 +67,7 @@ const currentImageButtonsSelector = createSelector(
|
|||||||
uiSelector,
|
uiSelector,
|
||||||
lightboxSelector,
|
lightboxSelector,
|
||||||
activeTabNameSelector,
|
activeTabNameSelector,
|
||||||
|
selectedImageSelector,
|
||||||
],
|
],
|
||||||
(
|
(
|
||||||
system: SystemState,
|
system: SystemState,
|
||||||
@ -69,7 +75,8 @@ const currentImageButtonsSelector = createSelector(
|
|||||||
postprocessing,
|
postprocessing,
|
||||||
ui,
|
ui,
|
||||||
lightbox,
|
lightbox,
|
||||||
activeTabName
|
activeTabName,
|
||||||
|
selectedImage
|
||||||
) => {
|
) => {
|
||||||
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
|
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
|
||||||
system;
|
system;
|
||||||
@ -95,6 +102,7 @@ const currentImageButtonsSelector = createSelector(
|
|||||||
activeTabName,
|
activeTabName,
|
||||||
isLightboxOpen,
|
isLightboxOpen,
|
||||||
shouldHidePreview,
|
shouldHidePreview,
|
||||||
|
selectedImage,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -121,27 +129,33 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
facetoolStrength,
|
facetoolStrength,
|
||||||
shouldDisableToolbarButtons,
|
shouldDisableToolbarButtons,
|
||||||
shouldShowImageDetails,
|
shouldShowImageDetails,
|
||||||
currentImage,
|
// currentImage,
|
||||||
isLightboxOpen,
|
isLightboxOpen,
|
||||||
activeTabName,
|
activeTabName,
|
||||||
shouldHidePreview,
|
shouldHidePreview,
|
||||||
|
selectedImage,
|
||||||
} = useAppSelector(currentImageButtonsSelector);
|
} = useAppSelector(currentImageButtonsSelector);
|
||||||
|
const { getUrl, shouldTransformUrls } = useGetUrl();
|
||||||
|
|
||||||
const toast = useToast();
|
const toast = useToast();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const setBothPrompts = useSetBothPrompts();
|
const setBothPrompts = useSetBothPrompts();
|
||||||
|
|
||||||
const handleClickUseAsInitialImage = () => {
|
const handleClickUseAsInitialImage = () => {
|
||||||
if (!currentImage) return;
|
if (!selectedImage) return;
|
||||||
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
|
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
|
||||||
dispatch(setInitialImage(currentImage));
|
dispatch(initialImageSelected(selectedImage.name));
|
||||||
dispatch(setActiveTab('img2img'));
|
// dispatch(setInitialImage(currentImage));
|
||||||
|
|
||||||
|
// dispatch(setActiveTab('img2img'));
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleCopyImage = async () => {
|
const handleCopyImage = async () => {
|
||||||
if (!currentImage) return;
|
if (!selectedImage) return;
|
||||||
|
|
||||||
const blob = await fetch(currentImage.url).then((res) => res.blob());
|
const blob = await fetch(getUrl(selectedImage.url)).then((res) =>
|
||||||
|
res.blob()
|
||||||
|
);
|
||||||
const data = [new ClipboardItem({ [blob.type]: blob })];
|
const data = [new ClipboardItem({ [blob.type]: blob })];
|
||||||
|
|
||||||
await navigator.clipboard.write(data);
|
await navigator.clipboard.write(data);
|
||||||
@ -155,24 +169,26 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleCopyImageLink = () => {
|
const handleCopyImageLink = () => {
|
||||||
navigator.clipboard
|
const url = selectedImage
|
||||||
.writeText(
|
? shouldTransformUrls
|
||||||
currentImage ? window.location.toString() + currentImage.url : ''
|
? getUrl(selectedImage.url)
|
||||||
)
|
: window.location.toString() + selectedImage.url
|
||||||
.then(() => {
|
: '';
|
||||||
toast({
|
|
||||||
title: t('toast.imageLinkCopied'),
|
navigator.clipboard.writeText(url).then(() => {
|
||||||
status: 'success',
|
toast({
|
||||||
duration: 2500,
|
title: t('toast.imageLinkCopied'),
|
||||||
isClosable: true,
|
status: 'success',
|
||||||
});
|
duration: 2500,
|
||||||
|
isClosable: true,
|
||||||
});
|
});
|
||||||
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'shift+i',
|
'shift+i',
|
||||||
() => {
|
() => {
|
||||||
if (currentImage) {
|
if (selectedImage) {
|
||||||
handleClickUseAsInitialImage();
|
handleClickUseAsInitialImage();
|
||||||
toast({
|
toast({
|
||||||
title: t('toast.sentToImageToImage'),
|
title: t('toast.sentToImageToImage'),
|
||||||
@ -190,7 +206,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[currentImage]
|
[selectedImage]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handlePreviewVisibility = () => {
|
const handlePreviewVisibility = () => {
|
||||||
@ -198,20 +214,23 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleClickUseAllParameters = () => {
|
const handleClickUseAllParameters = () => {
|
||||||
if (!currentImage) return;
|
if (!selectedImage) return;
|
||||||
currentImage.metadata && dispatch(setAllParameters(currentImage.metadata));
|
// selectedImage.metadata &&
|
||||||
if (currentImage.metadata?.image.type === 'img2img') {
|
// dispatch(setAllParameters(selectedImage.metadata));
|
||||||
dispatch(setActiveTab('img2img'));
|
// if (selectedImage.metadata?.image.type === 'img2img') {
|
||||||
} else if (currentImage.metadata?.image.type === 'txt2img') {
|
// dispatch(setActiveTab('img2img'));
|
||||||
dispatch(setActiveTab('txt2img'));
|
// } else if (selectedImage.metadata?.image.type === 'txt2img') {
|
||||||
}
|
// dispatch(setActiveTab('txt2img'));
|
||||||
|
// }
|
||||||
};
|
};
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'a',
|
'a',
|
||||||
() => {
|
() => {
|
||||||
if (
|
if (
|
||||||
['txt2img', 'img2img'].includes(currentImage?.metadata?.image?.type)
|
['txt2img', 'img2img'].includes(
|
||||||
|
selectedImage?.metadata?.sd_metadata?.type
|
||||||
|
)
|
||||||
) {
|
) {
|
||||||
handleClickUseAllParameters();
|
handleClickUseAllParameters();
|
||||||
toast({
|
toast({
|
||||||
@ -230,18 +249,18 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[currentImage]
|
[selectedImage]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleClickUseSeed = () => {
|
const handleClickUseSeed = () => {
|
||||||
currentImage?.metadata &&
|
selectedImage?.metadata &&
|
||||||
dispatch(setSeed(currentImage.metadata.image.seed));
|
dispatch(setSeed(selectedImage.metadata.sd_metadata.seed));
|
||||||
};
|
};
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
's',
|
's',
|
||||||
() => {
|
() => {
|
||||||
if (currentImage?.metadata?.image?.seed) {
|
if (selectedImage?.metadata?.sd_metadata?.seed) {
|
||||||
handleClickUseSeed();
|
handleClickUseSeed();
|
||||||
toast({
|
toast({
|
||||||
title: t('toast.seedSet'),
|
title: t('toast.seedSet'),
|
||||||
@ -259,19 +278,19 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[currentImage]
|
[selectedImage]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleClickUsePrompt = useCallback(() => {
|
const handleClickUsePrompt = useCallback(() => {
|
||||||
if (currentImage?.metadata?.image?.prompt) {
|
if (selectedImage?.metadata?.sd_metadata?.prompt) {
|
||||||
setBothPrompts(currentImage?.metadata?.image?.prompt);
|
setBothPrompts(selectedImage?.metadata?.sd_metadata?.prompt);
|
||||||
}
|
}
|
||||||
}, [currentImage?.metadata?.image?.prompt, setBothPrompts]);
|
}, [selectedImage?.metadata?.sd_metadata?.prompt, setBothPrompts]);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'p',
|
'p',
|
||||||
() => {
|
() => {
|
||||||
if (currentImage?.metadata?.image?.prompt) {
|
if (selectedImage?.metadata?.sd_metadata?.prompt) {
|
||||||
handleClickUsePrompt();
|
handleClickUsePrompt();
|
||||||
toast({
|
toast({
|
||||||
title: t('toast.promptSet'),
|
title: t('toast.promptSet'),
|
||||||
@ -289,11 +308,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[currentImage]
|
[selectedImage]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleClickUpscale = () => {
|
const handleClickUpscale = () => {
|
||||||
currentImage && dispatch(runESRGAN(currentImage));
|
// selectedImage && dispatch(runESRGAN(selectedImage));
|
||||||
};
|
};
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
@ -317,7 +336,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
currentImage,
|
selectedImage,
|
||||||
isESRGANAvailable,
|
isESRGANAvailable,
|
||||||
shouldDisableToolbarButtons,
|
shouldDisableToolbarButtons,
|
||||||
isConnected,
|
isConnected,
|
||||||
@ -327,7 +346,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const handleClickFixFaces = () => {
|
const handleClickFixFaces = () => {
|
||||||
currentImage && dispatch(runFacetool(currentImage));
|
// selectedImage && dispatch(runFacetool(selectedImage));
|
||||||
};
|
};
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
@ -351,7 +370,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
currentImage,
|
selectedImage,
|
||||||
isGFPGANAvailable,
|
isGFPGANAvailable,
|
||||||
shouldDisableToolbarButtons,
|
shouldDisableToolbarButtons,
|
||||||
isConnected,
|
isConnected,
|
||||||
@ -364,10 +383,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
dispatch(setShouldShowImageDetails(!shouldShowImageDetails));
|
dispatch(setShouldShowImageDetails(!shouldShowImageDetails));
|
||||||
|
|
||||||
const handleSendToCanvas = () => {
|
const handleSendToCanvas = () => {
|
||||||
if (!currentImage) return;
|
if (!selectedImage) return;
|
||||||
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
|
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
|
||||||
|
|
||||||
dispatch(setInitialCanvasImage(currentImage));
|
// dispatch(setInitialCanvasImage(selectedImage));
|
||||||
dispatch(requestCanvasRescale());
|
dispatch(requestCanvasRescale());
|
||||||
|
|
||||||
if (activeTabName !== 'unifiedCanvas') {
|
if (activeTabName !== 'unifiedCanvas') {
|
||||||
@ -385,7 +404,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
useHotkeys(
|
useHotkeys(
|
||||||
'i',
|
'i',
|
||||||
() => {
|
() => {
|
||||||
if (currentImage) {
|
if (selectedImage) {
|
||||||
handleClickShowImageDetails();
|
handleClickShowImageDetails();
|
||||||
} else {
|
} else {
|
||||||
toast({
|
toast({
|
||||||
@ -396,7 +415,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[currentImage, shouldShowImageDetails]
|
[selectedImage, shouldShowImageDetails]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleLightBox = () => {
|
const handleLightBox = () => {
|
||||||
@ -458,7 +477,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
{t('parameters.copyImageToLink')}
|
{t('parameters.copyImageToLink')}
|
||||||
</IAIButton>
|
</IAIButton>
|
||||||
|
|
||||||
<Link download={true} href={currentImage?.url}>
|
<Link download={true} href={getUrl(selectedImage!.url)}>
|
||||||
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
|
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
|
||||||
{t('parameters.downloadImage')}
|
{t('parameters.downloadImage')}
|
||||||
</IAIButton>
|
</IAIButton>
|
||||||
@ -502,7 +521,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
icon={<FaQuoteRight />}
|
icon={<FaQuoteRight />}
|
||||||
tooltip={`${t('parameters.usePrompt')} (P)`}
|
tooltip={`${t('parameters.usePrompt')} (P)`}
|
||||||
aria-label={`${t('parameters.usePrompt')} (P)`}
|
aria-label={`${t('parameters.usePrompt')} (P)`}
|
||||||
isDisabled={!currentImage?.metadata?.image?.prompt}
|
isDisabled={!selectedImage?.metadata?.sd_metadata?.prompt}
|
||||||
onClick={handleClickUsePrompt}
|
onClick={handleClickUsePrompt}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
@ -510,7 +529,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
icon={<FaSeedling />}
|
icon={<FaSeedling />}
|
||||||
tooltip={`${t('parameters.useSeed')} (S)`}
|
tooltip={`${t('parameters.useSeed')} (S)`}
|
||||||
aria-label={`${t('parameters.useSeed')} (S)`}
|
aria-label={`${t('parameters.useSeed')} (S)`}
|
||||||
isDisabled={!currentImage?.metadata?.image?.seed}
|
isDisabled={!selectedImage?.metadata?.sd_metadata?.seed}
|
||||||
onClick={handleClickUseSeed}
|
onClick={handleClickUseSeed}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
@ -520,7 +539,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
aria-label={`${t('parameters.useAll')} (A)`}
|
aria-label={`${t('parameters.useAll')} (A)`}
|
||||||
isDisabled={
|
isDisabled={
|
||||||
!['txt2img', 'img2img'].includes(
|
!['txt2img', 'img2img'].includes(
|
||||||
currentImage?.metadata?.image?.type
|
selectedImage?.metadata?.sd_metadata?.type
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
onClick={handleClickUseAllParameters}
|
onClick={handleClickUseAllParameters}
|
||||||
@ -546,7 +565,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
<IAIButton
|
<IAIButton
|
||||||
isDisabled={
|
isDisabled={
|
||||||
!isGFPGANAvailable ||
|
!isGFPGANAvailable ||
|
||||||
!currentImage ||
|
!selectedImage ||
|
||||||
!(isConnected && !isProcessing) ||
|
!(isConnected && !isProcessing) ||
|
||||||
!facetoolStrength
|
!facetoolStrength
|
||||||
}
|
}
|
||||||
@ -575,7 +594,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
<IAIButton
|
<IAIButton
|
||||||
isDisabled={
|
isDisabled={
|
||||||
!isESRGANAvailable ||
|
!isESRGANAvailable ||
|
||||||
!currentImage ||
|
!selectedImage ||
|
||||||
!(isConnected && !isProcessing) ||
|
!(isConnected && !isProcessing) ||
|
||||||
!upscalingLevel
|
!upscalingLevel
|
||||||
}
|
}
|
||||||
@ -597,15 +616,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
/>
|
/>
|
||||||
</ButtonGroup>
|
</ButtonGroup>
|
||||||
|
|
||||||
<DeleteImageModal image={currentImage}>
|
{/* <DeleteImageModal image={selectedImage}>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
icon={<FaTrash />}
|
icon={<FaTrash />}
|
||||||
tooltip={`${t('parameters.deleteImage')} (Del)`}
|
tooltip={`${t('parameters.deleteImage')} (Del)`}
|
||||||
aria-label={`${t('parameters.deleteImage')} (Del)`}
|
aria-label={`${t('parameters.deleteImage')} (Del)`}
|
||||||
isDisabled={!currentImage || !isConnected || isProcessing}
|
isDisabled={!selectedImage || !isConnected || isProcessing}
|
||||||
colorScheme="error"
|
colorScheme="error"
|
||||||
/>
|
/>
|
||||||
</DeleteImageModal>
|
</DeleteImageModal> */}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -4,17 +4,20 @@ import { useAppSelector } from 'app/storeHooks';
|
|||||||
import { isEqual } from 'lodash';
|
import { isEqual } from 'lodash';
|
||||||
|
|
||||||
import { MdPhoto } from 'react-icons/md';
|
import { MdPhoto } from 'react-icons/md';
|
||||||
import { gallerySelector } from '../store/gallerySelectors';
|
import {
|
||||||
|
gallerySelector,
|
||||||
|
selectedImageSelector,
|
||||||
|
} from '../store/gallerySelectors';
|
||||||
import CurrentImageButtons from './CurrentImageButtons';
|
import CurrentImageButtons from './CurrentImageButtons';
|
||||||
import CurrentImagePreview from './CurrentImagePreview';
|
import CurrentImagePreview from './CurrentImagePreview';
|
||||||
|
|
||||||
export const currentImageDisplaySelector = createSelector(
|
export const currentImageDisplaySelector = createSelector(
|
||||||
[gallerySelector],
|
[gallerySelector, selectedImageSelector],
|
||||||
(gallery) => {
|
(gallery, selectedImage) => {
|
||||||
const { currentImage, intermediateImage } = gallery;
|
const { currentImage, intermediateImage } = gallery;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
hasAnImageToDisplay: currentImage || intermediateImage,
|
hasAnImageToDisplay: selectedImage || intermediateImage,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1,28 +1,48 @@
|
|||||||
import { Box, Flex, Image } from '@chakra-ui/react';
|
import { Box, Flex, Image } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { useAppSelector } from 'app/storeHooks';
|
import { useAppSelector } from 'app/storeHooks';
|
||||||
import { GalleryState } from 'features/gallery/store/gallerySlice';
|
import { useGetUrl } from 'common/util/getUrl';
|
||||||
|
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { isEqual } from 'lodash';
|
import { isEqual } from 'lodash';
|
||||||
|
import { ReactEventHandler } from 'react';
|
||||||
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
|
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
|
||||||
|
|
||||||
import { gallerySelector } from '../store/gallerySelectors';
|
import { selectedImageSelector } from '../store/gallerySelectors';
|
||||||
import CurrentImageFallback from './CurrentImageFallback';
|
import CurrentImageFallback from './CurrentImageFallback';
|
||||||
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||||
import CurrentImageHidden from './CurrentImageHidden';
|
import CurrentImageHidden from './CurrentImageHidden';
|
||||||
|
|
||||||
export const imagesSelector = createSelector(
|
export const imagesSelector = createSelector(
|
||||||
[gallerySelector, uiSelector],
|
[uiSelector, selectedImageSelector, systemSelector],
|
||||||
(gallery: GalleryState, ui) => {
|
(ui, selectedImage, system) => {
|
||||||
const { currentImage, intermediateImage } = gallery;
|
|
||||||
const { shouldShowImageDetails, shouldHidePreview } = ui;
|
const { shouldShowImageDetails, shouldHidePreview } = ui;
|
||||||
|
const { progressImage } = system;
|
||||||
|
|
||||||
|
// TODO: Clean this up, this is really gross
|
||||||
|
const imageToDisplay = progressImage
|
||||||
|
? {
|
||||||
|
url: progressImage.dataURL,
|
||||||
|
width: progressImage.width,
|
||||||
|
height: progressImage.height,
|
||||||
|
isProgressImage: true,
|
||||||
|
image: progressImage,
|
||||||
|
}
|
||||||
|
: selectedImage
|
||||||
|
? {
|
||||||
|
url: selectedImage.url,
|
||||||
|
width: selectedImage.metadata.width,
|
||||||
|
height: selectedImage.metadata.height,
|
||||||
|
isProgressImage: false,
|
||||||
|
image: selectedImage,
|
||||||
|
}
|
||||||
|
: null;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
imageToDisplay: intermediateImage ? intermediateImage : currentImage,
|
|
||||||
isIntermediate: Boolean(intermediateImage),
|
|
||||||
shouldShowImageDetails,
|
shouldShowImageDetails,
|
||||||
shouldHidePreview,
|
shouldHidePreview,
|
||||||
|
imageToDisplay,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -33,12 +53,9 @@ export const imagesSelector = createSelector(
|
|||||||
);
|
);
|
||||||
|
|
||||||
export default function CurrentImagePreview() {
|
export default function CurrentImagePreview() {
|
||||||
const {
|
const { shouldShowImageDetails, imageToDisplay, shouldHidePreview } =
|
||||||
shouldShowImageDetails,
|
useAppSelector(imagesSelector);
|
||||||
imageToDisplay,
|
const { getUrl } = useGetUrl();
|
||||||
isIntermediate,
|
|
||||||
shouldHidePreview,
|
|
||||||
} = useAppSelector(imagesSelector);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
@ -52,13 +69,19 @@ export default function CurrentImagePreview() {
|
|||||||
>
|
>
|
||||||
{imageToDisplay && (
|
{imageToDisplay && (
|
||||||
<Image
|
<Image
|
||||||
src={shouldHidePreview ? undefined : imageToDisplay.url}
|
src={
|
||||||
|
shouldHidePreview
|
||||||
|
? undefined
|
||||||
|
: imageToDisplay.isProgressImage
|
||||||
|
? imageToDisplay.url
|
||||||
|
: getUrl(imageToDisplay.url)
|
||||||
|
}
|
||||||
width={imageToDisplay.width}
|
width={imageToDisplay.width}
|
||||||
height={imageToDisplay.height}
|
height={imageToDisplay.height}
|
||||||
fallback={
|
fallback={
|
||||||
shouldHidePreview ? (
|
shouldHidePreview ? (
|
||||||
<CurrentImageHidden />
|
<CurrentImageHidden />
|
||||||
) : !isIntermediate ? (
|
) : !imageToDisplay.isProgressImage ? (
|
||||||
<CurrentImageFallback />
|
<CurrentImageFallback />
|
||||||
) : undefined
|
) : undefined
|
||||||
}
|
}
|
||||||
@ -68,27 +91,31 @@ export default function CurrentImagePreview() {
|
|||||||
maxHeight: '100%',
|
maxHeight: '100%',
|
||||||
height: 'auto',
|
height: 'auto',
|
||||||
position: 'absolute',
|
position: 'absolute',
|
||||||
imageRendering: isIntermediate ? 'pixelated' : 'initial',
|
imageRendering: imageToDisplay.isProgressImage
|
||||||
|
? 'pixelated'
|
||||||
|
: 'initial',
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{!shouldShowImageDetails && <NextPrevImageButtons />}
|
{!shouldShowImageDetails && <NextPrevImageButtons />}
|
||||||
{shouldShowImageDetails && imageToDisplay && (
|
{shouldShowImageDetails &&
|
||||||
<Box
|
imageToDisplay &&
|
||||||
sx={{
|
'metadata' in imageToDisplay.image && (
|
||||||
position: 'absolute',
|
<Box
|
||||||
top: '0',
|
sx={{
|
||||||
width: '100%',
|
position: 'absolute',
|
||||||
height: '100%',
|
top: '0',
|
||||||
borderRadius: 'base',
|
width: '100%',
|
||||||
overflow: 'scroll',
|
height: '100%',
|
||||||
maxHeight: APP_METADATA_HEIGHT,
|
borderRadius: 'base',
|
||||||
}}
|
overflow: 'scroll',
|
||||||
>
|
maxHeight: APP_METADATA_HEIGHT,
|
||||||
<ImageMetadataViewer image={imageToDisplay} />
|
}}
|
||||||
</Box>
|
>
|
||||||
)}
|
<ImageMetadataViewer image={imageToDisplay.image} />
|
||||||
|
</Box>
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -52,7 +52,7 @@ interface DeleteImageModalProps {
|
|||||||
/**
|
/**
|
||||||
* The image to delete.
|
* The image to delete.
|
||||||
*/
|
*/
|
||||||
image?: InvokeAI.Image;
|
image?: InvokeAI._Image;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -9,11 +9,14 @@ import {
|
|||||||
useToast,
|
useToast,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||||
import { setCurrentImage } from 'features/gallery/store/gallerySlice';
|
|
||||||
import {
|
import {
|
||||||
|
imageSelected,
|
||||||
|
setCurrentImage,
|
||||||
|
} from 'features/gallery/store/gallerySlice';
|
||||||
|
import {
|
||||||
|
initialImageSelected,
|
||||||
setAllImageToImageParameters,
|
setAllImageToImageParameters,
|
||||||
setAllParameters,
|
setAllParameters,
|
||||||
setInitialImage,
|
|
||||||
setSeed,
|
setSeed,
|
||||||
} from 'features/parameters/store/generationSlice';
|
} from 'features/parameters/store/generationSlice';
|
||||||
import { DragEvent, memo, useState } from 'react';
|
import { DragEvent, memo, useState } from 'react';
|
||||||
@ -31,6 +34,7 @@ import { useTranslation } from 'react-i18next';
|
|||||||
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
|
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
|
||||||
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import { useGetUrl } from 'common/util/getUrl';
|
||||||
|
|
||||||
interface HoverableImageProps {
|
interface HoverableImageProps {
|
||||||
image: InvokeAI.Image;
|
image: InvokeAI.Image;
|
||||||
@ -40,7 +44,7 @@ interface HoverableImageProps {
|
|||||||
const memoEqualityCheck = (
|
const memoEqualityCheck = (
|
||||||
prev: HoverableImageProps,
|
prev: HoverableImageProps,
|
||||||
next: HoverableImageProps
|
next: HoverableImageProps
|
||||||
) => prev.image.uuid === next.image.uuid && prev.isSelected === next.isSelected;
|
) => prev.image.name === next.image.name && prev.isSelected === next.isSelected;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gallery image component with delete/use all/use seed buttons on hover.
|
* Gallery image component with delete/use all/use seed buttons on hover.
|
||||||
@ -55,7 +59,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
shouldUseSingleGalleryColumn,
|
shouldUseSingleGalleryColumn,
|
||||||
} = useAppSelector(hoverableImageSelector);
|
} = useAppSelector(hoverableImageSelector);
|
||||||
const { image, isSelected } = props;
|
const { image, isSelected } = props;
|
||||||
const { url, thumbnail, uuid, metadata } = image;
|
const { url, thumbnail, name, metadata } = image;
|
||||||
|
const { getUrl } = useGetUrl();
|
||||||
|
|
||||||
const [isHovered, setIsHovered] = useState<boolean>(false);
|
const [isHovered, setIsHovered] = useState<boolean>(false);
|
||||||
|
|
||||||
@ -69,10 +74,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
const handleMouseOut = () => setIsHovered(false);
|
const handleMouseOut = () => setIsHovered(false);
|
||||||
|
|
||||||
const handleUsePrompt = () => {
|
const handleUsePrompt = () => {
|
||||||
if (image.metadata?.image?.prompt) {
|
if (image.metadata?.sd_metadata?.prompt) {
|
||||||
setBothPrompts(image.metadata?.image?.prompt);
|
setBothPrompts(image.metadata?.sd_metadata?.prompt);
|
||||||
}
|
}
|
||||||
|
|
||||||
toast({
|
toast({
|
||||||
title: t('toast.promptSet'),
|
title: t('toast.promptSet'),
|
||||||
status: 'success',
|
status: 'success',
|
||||||
@ -82,7 +86,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleUseSeed = () => {
|
const handleUseSeed = () => {
|
||||||
image.metadata && dispatch(setSeed(image.metadata.image.seed));
|
image.metadata.sd_metadata &&
|
||||||
|
dispatch(setSeed(image.metadata.sd_metadata.image.seed));
|
||||||
toast({
|
toast({
|
||||||
title: t('toast.seedSet'),
|
title: t('toast.seedSet'),
|
||||||
status: 'success',
|
status: 'success',
|
||||||
@ -92,20 +97,11 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleSendToImageToImage = () => {
|
const handleSendToImageToImage = () => {
|
||||||
dispatch(setInitialImage(image));
|
dispatch(initialImageSelected(image.name));
|
||||||
if (activeTabName !== 'img2img') {
|
|
||||||
dispatch(setActiveTab('img2img'));
|
|
||||||
}
|
|
||||||
toast({
|
|
||||||
title: t('toast.sentToImageToImage'),
|
|
||||||
status: 'success',
|
|
||||||
duration: 2500,
|
|
||||||
isClosable: true,
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleSendToCanvas = () => {
|
const handleSendToCanvas = () => {
|
||||||
dispatch(setInitialCanvasImage(image));
|
// dispatch(setInitialCanvasImage(image));
|
||||||
|
|
||||||
dispatch(resizeAndScaleCanvas());
|
dispatch(resizeAndScaleCanvas());
|
||||||
|
|
||||||
@ -122,7 +118,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleUseAllParameters = () => {
|
const handleUseAllParameters = () => {
|
||||||
metadata && dispatch(setAllParameters(metadata));
|
metadata.sd_metadata && dispatch(setAllParameters(metadata.sd_metadata));
|
||||||
toast({
|
toast({
|
||||||
title: t('toast.parametersSet'),
|
title: t('toast.parametersSet'),
|
||||||
status: 'success',
|
status: 'success',
|
||||||
@ -132,11 +128,13 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleUseInitialImage = async () => {
|
const handleUseInitialImage = async () => {
|
||||||
if (metadata?.image?.init_image_path) {
|
if (metadata.sd_metadata?.image?.init_image_path) {
|
||||||
const response = await fetch(metadata.image.init_image_path);
|
const response = await fetch(
|
||||||
|
metadata.sd_metadata?.image?.init_image_path
|
||||||
|
);
|
||||||
if (response.ok) {
|
if (response.ok) {
|
||||||
dispatch(setActiveTab('img2img'));
|
dispatch(setActiveTab('img2img'));
|
||||||
dispatch(setAllImageToImageParameters(metadata));
|
dispatch(setAllImageToImageParameters(metadata?.sd_metadata));
|
||||||
toast({
|
toast({
|
||||||
title: t('toast.initialImageSet'),
|
title: t('toast.initialImageSet'),
|
||||||
status: 'success',
|
status: 'success',
|
||||||
@ -155,16 +153,20 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleSelectImage = () => dispatch(setCurrentImage(image));
|
const handleSelectImage = () => {
|
||||||
|
dispatch(imageSelected(image.name));
|
||||||
|
};
|
||||||
|
|
||||||
const handleDragStart = (e: DragEvent<HTMLDivElement>) => {
|
const handleDragStart = (e: DragEvent<HTMLDivElement>) => {
|
||||||
e.dataTransfer.setData('invokeai/imageUuid', uuid);
|
console.log('drag started');
|
||||||
|
e.dataTransfer.setData('invokeai/imageName', image.name);
|
||||||
|
e.dataTransfer.setData('invokeai/imageType', image.type);
|
||||||
e.dataTransfer.effectAllowed = 'move';
|
e.dataTransfer.effectAllowed = 'move';
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleLightBox = () => {
|
const handleLightBox = () => {
|
||||||
dispatch(setCurrentImage(image));
|
// dispatch(setCurrentImage(image));
|
||||||
dispatch(setIsLightboxOpen(true));
|
// dispatch(setIsLightboxOpen(true));
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -177,28 +179,30 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
</MenuItem>
|
</MenuItem>
|
||||||
<MenuItem
|
<MenuItem
|
||||||
onClickCapture={handleUsePrompt}
|
onClickCapture={handleUsePrompt}
|
||||||
isDisabled={image?.metadata?.image?.prompt === undefined}
|
isDisabled={image?.metadata?.sd_metadata?.prompt === undefined}
|
||||||
>
|
>
|
||||||
{t('parameters.usePrompt')}
|
{t('parameters.usePrompt')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
|
|
||||||
<MenuItem
|
<MenuItem
|
||||||
onClickCapture={handleUseSeed}
|
onClickCapture={handleUseSeed}
|
||||||
isDisabled={image?.metadata?.image?.seed === undefined}
|
isDisabled={image?.metadata?.sd_metadata?.seed === undefined}
|
||||||
>
|
>
|
||||||
{t('parameters.useSeed')}
|
{t('parameters.useSeed')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
<MenuItem
|
<MenuItem
|
||||||
onClickCapture={handleUseAllParameters}
|
onClickCapture={handleUseAllParameters}
|
||||||
isDisabled={
|
isDisabled={
|
||||||
!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)
|
!['txt2img', 'img2img'].includes(
|
||||||
|
image?.metadata?.sd_metadata?.type
|
||||||
|
)
|
||||||
}
|
}
|
||||||
>
|
>
|
||||||
{t('parameters.useAll')}
|
{t('parameters.useAll')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
<MenuItem
|
<MenuItem
|
||||||
onClickCapture={handleUseInitialImage}
|
onClickCapture={handleUseInitialImage}
|
||||||
isDisabled={image?.metadata?.image?.type !== 'img2img'}
|
isDisabled={image?.metadata?.sd_metadata?.type !== 'img2img'}
|
||||||
>
|
>
|
||||||
{t('parameters.useInitImg')}
|
{t('parameters.useInitImg')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
@ -209,9 +213,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
{t('parameters.sendToUnifiedCanvas')}
|
{t('parameters.sendToUnifiedCanvas')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
<MenuItem data-warning>
|
<MenuItem data-warning>
|
||||||
<DeleteImageModal image={image}>
|
{/* <DeleteImageModal image={image}>
|
||||||
<p>{t('parameters.deleteImage')}</p>
|
<p>{t('parameters.deleteImage')}</p>
|
||||||
</DeleteImageModal>
|
</DeleteImageModal> */}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
</MenuList>
|
</MenuList>
|
||||||
)}
|
)}
|
||||||
@ -219,7 +223,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
{(ref) => (
|
{(ref) => (
|
||||||
<Box
|
<Box
|
||||||
position="relative"
|
position="relative"
|
||||||
key={uuid}
|
key={name}
|
||||||
onMouseOver={handleMouseOver}
|
onMouseOver={handleMouseOver}
|
||||||
onMouseOut={handleMouseOut}
|
onMouseOut={handleMouseOut}
|
||||||
userSelect="none"
|
userSelect="none"
|
||||||
@ -244,7 +248,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
|
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
|
||||||
}
|
}
|
||||||
rounded="md"
|
rounded="md"
|
||||||
src={thumbnail || url}
|
src={getUrl(thumbnail || url)}
|
||||||
loading="lazy"
|
loading="lazy"
|
||||||
sx={{
|
sx={{
|
||||||
position: 'absolute',
|
position: 'absolute',
|
||||||
@ -290,7 +294,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
insetInlineEnd: 1,
|
insetInlineEnd: 1,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<DeleteImageModal image={image}>
|
{/* <DeleteImageModal image={image}>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
aria-label={t('parameters.deleteImage')}
|
aria-label={t('parameters.deleteImage')}
|
||||||
icon={<FaTrashAlt />}
|
icon={<FaTrashAlt />}
|
||||||
@ -298,7 +302,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
fontSize={14}
|
fontSize={14}
|
||||||
isDisabled={!mayDeleteImage}
|
isDisabled={!mayDeleteImage}
|
||||||
/>
|
/>
|
||||||
</DeleteImageModal>
|
</DeleteImageModal> */}
|
||||||
</Box>
|
</Box>
|
||||||
)}
|
)}
|
||||||
</Box>
|
</Box>
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { ButtonGroup, Flex, Grid, Icon, Text } from '@chakra-ui/react';
|
import { ButtonGroup, Flex, Grid, Icon, Image, Text } from '@chakra-ui/react';
|
||||||
import { requestImages } from 'app/socketio/actions';
|
import { requestImages } from 'app/socketio/actions';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
@ -25,9 +25,44 @@ import HoverableImage from './HoverableImage';
|
|||||||
|
|
||||||
import Scrollable from 'features/ui/components/common/Scrollable';
|
import Scrollable from 'features/ui/components/common/Scrollable';
|
||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
|
import {
|
||||||
|
resultsAdapter,
|
||||||
|
selectResultsAll,
|
||||||
|
selectResultsTotal,
|
||||||
|
} from '../store/resultsSlice';
|
||||||
|
import {
|
||||||
|
receivedResultImagesPage,
|
||||||
|
receivedUploadImagesPage,
|
||||||
|
} from 'services/thunks/gallery';
|
||||||
|
import { selectUploadsAll, uploadsAdapter } from '../store/uploadsSlice';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { RootState } from 'app/store';
|
||||||
|
|
||||||
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
|
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
|
||||||
|
|
||||||
|
const gallerySelector = createSelector(
|
||||||
|
[
|
||||||
|
(state: RootState) => state.uploads,
|
||||||
|
(state: RootState) => state.results,
|
||||||
|
(state: RootState) => state.gallery,
|
||||||
|
],
|
||||||
|
(uploads, results, gallery) => {
|
||||||
|
const { currentCategory } = gallery;
|
||||||
|
|
||||||
|
return currentCategory === 'result'
|
||||||
|
? {
|
||||||
|
images: resultsAdapter.getSelectors().selectAll(results),
|
||||||
|
isLoading: results.isLoading,
|
||||||
|
areMoreImagesAvailable: results.page < results.pages - 1,
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
images: uploadsAdapter.getSelectors().selectAll(uploads),
|
||||||
|
isLoading: uploads.isLoading,
|
||||||
|
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
const ImageGalleryContent = () => {
|
const ImageGalleryContent = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
@ -35,7 +70,7 @@ const ImageGalleryContent = () => {
|
|||||||
const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
|
const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
|
||||||
|
|
||||||
const {
|
const {
|
||||||
images,
|
// images,
|
||||||
currentCategory,
|
currentCategory,
|
||||||
currentImageUuid,
|
currentImageUuid,
|
||||||
shouldPinGallery,
|
shouldPinGallery,
|
||||||
@ -43,12 +78,24 @@ const ImageGalleryContent = () => {
|
|||||||
galleryGridTemplateColumns,
|
galleryGridTemplateColumns,
|
||||||
galleryImageObjectFit,
|
galleryImageObjectFit,
|
||||||
shouldAutoSwitchToNewImages,
|
shouldAutoSwitchToNewImages,
|
||||||
areMoreImagesAvailable,
|
// areMoreImagesAvailable,
|
||||||
shouldUseSingleGalleryColumn,
|
shouldUseSingleGalleryColumn,
|
||||||
} = useAppSelector(imageGallerySelector);
|
} = useAppSelector(imageGallerySelector);
|
||||||
|
|
||||||
|
const { images, areMoreImagesAvailable, isLoading } =
|
||||||
|
useAppSelector(gallerySelector);
|
||||||
|
|
||||||
|
// const handleClickLoadMore = () => {
|
||||||
|
// dispatch(requestImages(currentCategory));
|
||||||
|
// };
|
||||||
const handleClickLoadMore = () => {
|
const handleClickLoadMore = () => {
|
||||||
dispatch(requestImages(currentCategory));
|
if (currentCategory === 'result') {
|
||||||
|
dispatch(receivedResultImagesPage());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (currentCategory === 'user') {
|
||||||
|
dispatch(receivedUploadImagesPage());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleChangeGalleryImageMinimumWidth = (v: number) => {
|
const handleChangeGalleryImageMinimumWidth = (v: number) => {
|
||||||
@ -203,11 +250,11 @@ const ImageGalleryContent = () => {
|
|||||||
style={{ gridTemplateColumns: galleryGridTemplateColumns }}
|
style={{ gridTemplateColumns: galleryGridTemplateColumns }}
|
||||||
>
|
>
|
||||||
{images.map((image) => {
|
{images.map((image) => {
|
||||||
const { uuid } = image;
|
const { name } = image;
|
||||||
const isSelected = currentImageUuid === uuid;
|
const isSelected = currentImageUuid === name;
|
||||||
return (
|
return (
|
||||||
<HoverableImage
|
<HoverableImage
|
||||||
key={uuid}
|
key={name}
|
||||||
image={image}
|
image={image}
|
||||||
isSelected={isSelected}
|
isSelected={isSelected}
|
||||||
/>
|
/>
|
||||||
@ -217,6 +264,7 @@ const ImageGalleryContent = () => {
|
|||||||
<IAIButton
|
<IAIButton
|
||||||
onClick={handleClickLoadMore}
|
onClick={handleClickLoadMore}
|
||||||
isDisabled={!areMoreImagesAvailable}
|
isDisabled={!areMoreImagesAvailable}
|
||||||
|
isLoading={isLoading}
|
||||||
flexShrink={0}
|
flexShrink={0}
|
||||||
>
|
>
|
||||||
{areMoreImagesAvailable
|
{areMoreImagesAvailable
|
||||||
|
@ -33,12 +33,13 @@ const GALLERY_TAB_WIDTHS: Record<
|
|||||||
InvokeTabName,
|
InvokeTabName,
|
||||||
{ galleryMinWidth: number; galleryMaxWidth: number }
|
{ galleryMinWidth: number; galleryMaxWidth: number }
|
||||||
> = {
|
> = {
|
||||||
txt2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
// txt2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||||
img2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
// img2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||||
|
linear: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||||
unifiedCanvas: { galleryMinWidth: 200, galleryMaxWidth: 200 },
|
unifiedCanvas: { galleryMinWidth: 200, galleryMaxWidth: 200 },
|
||||||
nodes: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
nodes: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||||
postprocessing: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
// postprocessing: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||||
training: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
// training: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||||
};
|
};
|
||||||
|
|
||||||
const galleryPanelSelector = createSelector(
|
const galleryPanelSelector = createSelector(
|
||||||
|
@ -11,6 +11,7 @@ import {
|
|||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import * as InvokeAI from 'app/invokeai';
|
import * as InvokeAI from 'app/invokeai';
|
||||||
import { useAppDispatch } from 'app/storeHooks';
|
import { useAppDispatch } from 'app/storeHooks';
|
||||||
|
import { useGetUrl } from 'common/util/getUrl';
|
||||||
import promptToString from 'common/util/promptToString';
|
import promptToString from 'common/util/promptToString';
|
||||||
import { seedWeightsToString } from 'common/util/seedWeightPairs';
|
import { seedWeightsToString } from 'common/util/seedWeightPairs';
|
||||||
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
|
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
|
||||||
@ -18,7 +19,7 @@ import {
|
|||||||
setCfgScale,
|
setCfgScale,
|
||||||
setHeight,
|
setHeight,
|
||||||
setImg2imgStrength,
|
setImg2imgStrength,
|
||||||
setInitialImage,
|
// setInitialImage,
|
||||||
setMaskPath,
|
setMaskPath,
|
||||||
setPerlin,
|
setPerlin,
|
||||||
setSampler,
|
setSampler,
|
||||||
@ -120,7 +121,7 @@ type ImageMetadataViewerProps = {
|
|||||||
const memoEqualityCheck = (
|
const memoEqualityCheck = (
|
||||||
prev: ImageMetadataViewerProps,
|
prev: ImageMetadataViewerProps,
|
||||||
next: ImageMetadataViewerProps
|
next: ImageMetadataViewerProps
|
||||||
) => prev.image.uuid === next.image.uuid;
|
) => prev.image.name === next.image.name;
|
||||||
|
|
||||||
// TODO: Show more interesting information in this component.
|
// TODO: Show more interesting information in this component.
|
||||||
|
|
||||||
@ -137,34 +138,13 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
|||||||
dispatch(setShouldShowImageDetails(false));
|
dispatch(setShouldShowImageDetails(false));
|
||||||
});
|
});
|
||||||
|
|
||||||
const metadata = image?.metadata?.image || {};
|
const sessionId = image.metadata.invokeai?.session_id;
|
||||||
const dreamPrompt = image?.dreamPrompt;
|
const node = image.metadata.invokeai?.node as Record<string, any>;
|
||||||
|
|
||||||
const {
|
|
||||||
cfg_scale,
|
|
||||||
fit,
|
|
||||||
height,
|
|
||||||
hires_fix,
|
|
||||||
init_image_path,
|
|
||||||
mask_image_path,
|
|
||||||
orig_path,
|
|
||||||
perlin,
|
|
||||||
postprocessing,
|
|
||||||
prompt,
|
|
||||||
sampler,
|
|
||||||
seamless,
|
|
||||||
seed,
|
|
||||||
steps,
|
|
||||||
strength,
|
|
||||||
threshold,
|
|
||||||
type,
|
|
||||||
variations,
|
|
||||||
width,
|
|
||||||
} = metadata;
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const { getUrl } = useGetUrl();
|
||||||
|
|
||||||
const metadataJSON = JSON.stringify(image.metadata, null, 2);
|
const metadataJSON = JSON.stringify(image, null, 2);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
@ -183,262 +163,134 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
|||||||
>
|
>
|
||||||
<Flex gap={2}>
|
<Flex gap={2}>
|
||||||
<Text fontWeight="semibold">File:</Text>
|
<Text fontWeight="semibold">File:</Text>
|
||||||
<Link href={image.url} isExternal maxW="calc(100% - 3rem)">
|
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
|
||||||
{image.url.length > 64
|
{image.url.length > 64
|
||||||
? image.url.substring(0, 64).concat('...')
|
? image.url.substring(0, 64).concat('...')
|
||||||
: image.url}
|
: image.url}
|
||||||
<ExternalLinkIcon mx="2px" />
|
<ExternalLinkIcon mx="2px" />
|
||||||
</Link>
|
</Link>
|
||||||
</Flex>
|
</Flex>
|
||||||
{Object.keys(metadata).length > 0 ? (
|
{node && Object.keys(node).length > 0 ? (
|
||||||
<>
|
<>
|
||||||
{type && <MetadataItem label="Generation type" value={type} />}
|
{node.type && (
|
||||||
{image.metadata?.model_weights && (
|
<MetadataItem label="Invocation type" value={node.type} />
|
||||||
<MetadataItem label="Model" value={image.metadata.model_weights} />
|
|
||||||
)}
|
)}
|
||||||
{['esrgan', 'gfpgan'].includes(type) && (
|
{node.model && <MetadataItem label="Model" value={node.model} />}
|
||||||
<MetadataItem label="Original image" value={orig_path} />
|
{node.prompt && (
|
||||||
)}
|
|
||||||
{prompt && (
|
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Prompt"
|
label="Prompt"
|
||||||
labelPosition="top"
|
labelPosition="top"
|
||||||
value={
|
value={
|
||||||
typeof prompt === 'string' ? prompt : promptToString(prompt)
|
typeof node.prompt === 'string'
|
||||||
|
? node.prompt
|
||||||
|
: promptToString(node.prompt)
|
||||||
}
|
}
|
||||||
onClick={() => setBothPrompts(prompt)}
|
onClick={() => setBothPrompts(node.prompt)}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{seed !== undefined && (
|
{node.seed !== undefined && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Seed"
|
label="Seed"
|
||||||
value={seed}
|
value={node.seed}
|
||||||
onClick={() => dispatch(setSeed(seed))}
|
onClick={() => dispatch(setSeed(Number(node.seed)))}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{threshold !== undefined && (
|
{node.threshold !== undefined && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Noise Threshold"
|
label="Noise Threshold"
|
||||||
value={threshold}
|
value={node.threshold}
|
||||||
onClick={() => dispatch(setThreshold(threshold))}
|
onClick={() => dispatch(setThreshold(Number(node.threshold)))}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{perlin !== undefined && (
|
{node.perlin !== undefined && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Perlin Noise"
|
label="Perlin Noise"
|
||||||
value={perlin}
|
value={node.perlin}
|
||||||
onClick={() => dispatch(setPerlin(perlin))}
|
onClick={() => dispatch(setPerlin(Number(node.perlin)))}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{sampler && (
|
{node.scheduler && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Sampler"
|
label="Sampler"
|
||||||
value={sampler}
|
value={node.scheduler}
|
||||||
onClick={() => dispatch(setSampler(sampler))}
|
onClick={() => dispatch(setSampler(node.scheduler))}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{steps && (
|
{node.steps && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Steps"
|
label="Steps"
|
||||||
value={steps}
|
value={node.steps}
|
||||||
onClick={() => dispatch(setSteps(steps))}
|
onClick={() => dispatch(setSteps(Number(node.steps)))}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{cfg_scale !== undefined && (
|
{node.cfg_scale !== undefined && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="CFG scale"
|
label="CFG scale"
|
||||||
value={cfg_scale}
|
value={node.cfg_scale}
|
||||||
onClick={() => dispatch(setCfgScale(cfg_scale))}
|
onClick={() => dispatch(setCfgScale(Number(node.cfg_scale)))}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{variations && variations.length > 0 && (
|
{node.variations && node.variations.length > 0 && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Seed-weight pairs"
|
label="Seed-weight pairs"
|
||||||
value={seedWeightsToString(variations)}
|
value={seedWeightsToString(node.variations)}
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
dispatch(setSeedWeights(seedWeightsToString(variations)))
|
dispatch(setSeedWeights(seedWeightsToString(node.variations)))
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{seamless && (
|
{node.seamless && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Seamless"
|
label="Seamless"
|
||||||
value={seamless}
|
value={node.seamless}
|
||||||
onClick={() => dispatch(setSeamless(seamless))}
|
onClick={() => dispatch(setSeamless(node.seamless))}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{hires_fix && (
|
{node.hires_fix && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="High Resolution Optimization"
|
label="High Resolution Optimization"
|
||||||
value={hires_fix}
|
value={node.hires_fix}
|
||||||
onClick={() => dispatch(setHiresFix(hires_fix))}
|
onClick={() => dispatch(setHiresFix(node.hires_fix))}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{width && (
|
{node.width && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Width"
|
label="Width"
|
||||||
value={width}
|
value={node.width}
|
||||||
onClick={() => dispatch(setWidth(width))}
|
onClick={() => dispatch(setWidth(Number(node.width)))}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{height && (
|
{node.height && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Height"
|
label="Height"
|
||||||
value={height}
|
value={node.height}
|
||||||
onClick={() => dispatch(setHeight(height))}
|
onClick={() => dispatch(setHeight(Number(node.height)))}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{init_image_path && (
|
{/* {init_image_path && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Initial image"
|
label="Initial image"
|
||||||
value={init_image_path}
|
value={init_image_path}
|
||||||
isLink
|
isLink
|
||||||
onClick={() => dispatch(setInitialImage(init_image_path))}
|
onClick={() => dispatch(setInitialImage(init_image_path))}
|
||||||
/>
|
/>
|
||||||
)}
|
)} */}
|
||||||
{mask_image_path && (
|
{node.strength && (
|
||||||
<MetadataItem
|
|
||||||
label="Mask image"
|
|
||||||
value={mask_image_path}
|
|
||||||
isLink
|
|
||||||
onClick={() => dispatch(setMaskPath(mask_image_path))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{type === 'img2img' && strength && (
|
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Image to image strength"
|
label="Image to image strength"
|
||||||
value={strength}
|
value={node.strength}
|
||||||
onClick={() => dispatch(setImg2imgStrength(strength))}
|
onClick={() =>
|
||||||
|
dispatch(setImg2imgStrength(Number(node.strength)))
|
||||||
|
}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{fit && (
|
{node.fit && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Image to image fit"
|
label="Image to image fit"
|
||||||
value={fit}
|
value={node.fit}
|
||||||
onClick={() => dispatch(setShouldFitToWidthHeight(fit))}
|
onClick={() => dispatch(setShouldFitToWidthHeight(node.fit))}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{postprocessing && postprocessing.length > 0 && (
|
|
||||||
<>
|
|
||||||
<Heading size="sm">Postprocessing</Heading>
|
|
||||||
{postprocessing.map(
|
|
||||||
(
|
|
||||||
postprocess: InvokeAI.PostProcessedImageMetadata,
|
|
||||||
i: number
|
|
||||||
) => {
|
|
||||||
if (postprocess.type === 'esrgan') {
|
|
||||||
const { scale, strength, denoise_str } = postprocess;
|
|
||||||
return (
|
|
||||||
<Flex key={i} pl={8} gap={1} direction="column">
|
|
||||||
<Text size="md">{`${i + 1}: Upscale (ESRGAN)`}</Text>
|
|
||||||
<MetadataItem
|
|
||||||
label="Scale"
|
|
||||||
value={scale}
|
|
||||||
onClick={() => dispatch(setUpscalingLevel(scale))}
|
|
||||||
/>
|
|
||||||
<MetadataItem
|
|
||||||
label="Strength"
|
|
||||||
value={strength}
|
|
||||||
onClick={() =>
|
|
||||||
dispatch(setUpscalingStrength(strength))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
{denoise_str !== undefined && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Denoising strength"
|
|
||||||
value={denoise_str}
|
|
||||||
onClick={() =>
|
|
||||||
dispatch(setUpscalingDenoising(denoise_str))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
} else if (postprocess.type === 'gfpgan') {
|
|
||||||
const { strength } = postprocess;
|
|
||||||
return (
|
|
||||||
<Flex key={i} pl={8} gap={1} direction="column">
|
|
||||||
<Text size="md">{`${
|
|
||||||
i + 1
|
|
||||||
}: Face restoration (GFPGAN)`}</Text>
|
|
||||||
|
|
||||||
<MetadataItem
|
|
||||||
label="Strength"
|
|
||||||
value={strength}
|
|
||||||
onClick={() => {
|
|
||||||
dispatch(setFacetoolStrength(strength));
|
|
||||||
dispatch(setFacetoolType('gfpgan'));
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
} else if (postprocess.type === 'codeformer') {
|
|
||||||
const { strength, fidelity } = postprocess;
|
|
||||||
return (
|
|
||||||
<Flex key={i} pl={8} gap={1} direction="column">
|
|
||||||
<Text size="md">{`${
|
|
||||||
i + 1
|
|
||||||
}: Face restoration (Codeformer)`}</Text>
|
|
||||||
|
|
||||||
<MetadataItem
|
|
||||||
label="Strength"
|
|
||||||
value={strength}
|
|
||||||
onClick={() => {
|
|
||||||
dispatch(setFacetoolStrength(strength));
|
|
||||||
dispatch(setFacetoolType('codeformer'));
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
{fidelity && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Fidelity"
|
|
||||||
value={fidelity}
|
|
||||||
onClick={() => {
|
|
||||||
dispatch(setCodeformerFidelity(fidelity));
|
|
||||||
dispatch(setFacetoolType('codeformer'));
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)}
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
{dreamPrompt && (
|
|
||||||
<MetadataItem withCopy label="Dream Prompt" value={dreamPrompt} />
|
|
||||||
)}
|
|
||||||
<Flex gap={2} direction="column">
|
|
||||||
<Flex gap={2}>
|
|
||||||
<Tooltip label="Copy metadata JSON">
|
|
||||||
<IconButton
|
|
||||||
aria-label={t('accessibility.copyMetadataJson')}
|
|
||||||
icon={<FaCopy />}
|
|
||||||
size="xs"
|
|
||||||
variant="ghost"
|
|
||||||
fontSize={14}
|
|
||||||
onClick={() => navigator.clipboard.writeText(metadataJSON)}
|
|
||||||
/>
|
|
||||||
</Tooltip>
|
|
||||||
<Text fontWeight="semibold">Metadata JSON:</Text>
|
|
||||||
</Flex>
|
|
||||||
<Box
|
|
||||||
sx={{
|
|
||||||
mt: 0,
|
|
||||||
mr: 2,
|
|
||||||
mb: 4,
|
|
||||||
ml: 2,
|
|
||||||
padding: 4,
|
|
||||||
borderRadius: 'base',
|
|
||||||
overflowX: 'scroll',
|
|
||||||
wordBreak: 'break-all',
|
|
||||||
bg: 'whiteAlpha.500',
|
|
||||||
_dark: { bg: 'blackAlpha.500' },
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<pre>{metadataJSON}</pre>
|
|
||||||
</Box>
|
|
||||||
</Flex>
|
|
||||||
</>
|
</>
|
||||||
) : (
|
) : (
|
||||||
<Center width="100%" pt={10}>
|
<Center width="100%" pt={10}>
|
||||||
@ -447,6 +299,37 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
|||||||
</Text>
|
</Text>
|
||||||
</Center>
|
</Center>
|
||||||
)}
|
)}
|
||||||
|
<Flex gap={2} direction="column">
|
||||||
|
<Flex gap={2}>
|
||||||
|
<Tooltip label="Copy metadata JSON">
|
||||||
|
<IconButton
|
||||||
|
aria-label={t('accessibility.copyMetadataJson')}
|
||||||
|
icon={<FaCopy />}
|
||||||
|
size="xs"
|
||||||
|
variant="ghost"
|
||||||
|
fontSize={14}
|
||||||
|
onClick={() => navigator.clipboard.writeText(metadataJSON)}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
<Text fontWeight="semibold">Metadata JSON:</Text>
|
||||||
|
</Flex>
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
mt: 0,
|
||||||
|
mr: 2,
|
||||||
|
mb: 4,
|
||||||
|
ml: 2,
|
||||||
|
padding: 4,
|
||||||
|
borderRadius: 'base',
|
||||||
|
overflowX: 'scroll',
|
||||||
|
wordBreak: 'break-all',
|
||||||
|
bg: 'whiteAlpha.500',
|
||||||
|
_dark: { bg: 'blackAlpha.500' },
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<pre>{metadataJSON}</pre>
|
||||||
|
</Box>
|
||||||
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
}, memoEqualityCheck);
|
}, memoEqualityCheck);
|
||||||
|
@ -0,0 +1,470 @@
|
|||||||
|
import { ExternalLinkIcon } from '@chakra-ui/icons';
|
||||||
|
import {
|
||||||
|
Box,
|
||||||
|
Center,
|
||||||
|
Flex,
|
||||||
|
Heading,
|
||||||
|
IconButton,
|
||||||
|
Link,
|
||||||
|
Text,
|
||||||
|
Tooltip,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import * as InvokeAI from 'app/invokeai';
|
||||||
|
import { useAppDispatch } from 'app/storeHooks';
|
||||||
|
import { useGetUrl } from 'common/util/getUrl';
|
||||||
|
import promptToString from 'common/util/promptToString';
|
||||||
|
import { seedWeightsToString } from 'common/util/seedWeightPairs';
|
||||||
|
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
|
||||||
|
import {
|
||||||
|
setCfgScale,
|
||||||
|
setHeight,
|
||||||
|
setImg2imgStrength,
|
||||||
|
// setInitialImage,
|
||||||
|
setMaskPath,
|
||||||
|
setPerlin,
|
||||||
|
setSampler,
|
||||||
|
setSeamless,
|
||||||
|
setSeed,
|
||||||
|
setSeedWeights,
|
||||||
|
setShouldFitToWidthHeight,
|
||||||
|
setSteps,
|
||||||
|
setThreshold,
|
||||||
|
setWidth,
|
||||||
|
} from 'features/parameters/store/generationSlice';
|
||||||
|
import {
|
||||||
|
setCodeformerFidelity,
|
||||||
|
setFacetoolStrength,
|
||||||
|
setFacetoolType,
|
||||||
|
setHiresFix,
|
||||||
|
setUpscalingDenoising,
|
||||||
|
setUpscalingLevel,
|
||||||
|
setUpscalingStrength,
|
||||||
|
} from 'features/parameters/store/postprocessingSlice';
|
||||||
|
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { FaCopy } from 'react-icons/fa';
|
||||||
|
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||||
|
import * as png from '@stevebel/png';
|
||||||
|
|
||||||
|
type MetadataItemProps = {
|
||||||
|
isLink?: boolean;
|
||||||
|
label: string;
|
||||||
|
onClick?: () => void;
|
||||||
|
value: number | string | boolean;
|
||||||
|
labelPosition?: string;
|
||||||
|
withCopy?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Component to display an individual metadata item or parameter.
|
||||||
|
*/
|
||||||
|
const MetadataItem = ({
|
||||||
|
label,
|
||||||
|
value,
|
||||||
|
onClick,
|
||||||
|
isLink,
|
||||||
|
labelPosition,
|
||||||
|
withCopy = false,
|
||||||
|
}: MetadataItemProps) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex gap={2}>
|
||||||
|
{onClick && (
|
||||||
|
<Tooltip label={`Recall ${label}`}>
|
||||||
|
<IconButton
|
||||||
|
aria-label={t('accessibility.useThisParameter')}
|
||||||
|
icon={<IoArrowUndoCircleOutline />}
|
||||||
|
size="xs"
|
||||||
|
variant="ghost"
|
||||||
|
fontSize={20}
|
||||||
|
onClick={onClick}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
)}
|
||||||
|
{withCopy && (
|
||||||
|
<Tooltip label={`Copy ${label}`}>
|
||||||
|
<IconButton
|
||||||
|
aria-label={`Copy ${label}`}
|
||||||
|
icon={<FaCopy />}
|
||||||
|
size="xs"
|
||||||
|
variant="ghost"
|
||||||
|
fontSize={14}
|
||||||
|
onClick={() => navigator.clipboard.writeText(value.toString())}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
)}
|
||||||
|
<Flex direction={labelPosition ? 'column' : 'row'}>
|
||||||
|
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
|
||||||
|
{label}:
|
||||||
|
</Text>
|
||||||
|
{isLink ? (
|
||||||
|
<Link href={value.toString()} isExternal wordBreak="break-all">
|
||||||
|
{value.toString()} <ExternalLinkIcon mx="2px" />
|
||||||
|
</Link>
|
||||||
|
) : (
|
||||||
|
<Text overflowY="scroll" wordBreak="break-all">
|
||||||
|
{value.toString()}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
type ImageMetadataViewerProps = {
|
||||||
|
image: InvokeAI.Image;
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: I don't know if this is needed.
|
||||||
|
const memoEqualityCheck = (
|
||||||
|
prev: ImageMetadataViewerProps,
|
||||||
|
next: ImageMetadataViewerProps
|
||||||
|
) => prev.image.name === next.image.name;
|
||||||
|
|
||||||
|
// TODO: Show more interesting information in this component.
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Image metadata viewer overlays currently selected image and provides
|
||||||
|
* access to any of its metadata for use in processing.
|
||||||
|
*/
|
||||||
|
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const setBothPrompts = useSetBothPrompts();
|
||||||
|
|
||||||
|
useHotkeys('esc', () => {
|
||||||
|
dispatch(setShouldShowImageDetails(false));
|
||||||
|
});
|
||||||
|
|
||||||
|
const metadata = image?.metadata.sd_metadata || {};
|
||||||
|
const dreamPrompt = image?.metadata.sd_metadata?.dreamPrompt;
|
||||||
|
|
||||||
|
const {
|
||||||
|
cfg_scale,
|
||||||
|
fit,
|
||||||
|
height,
|
||||||
|
hires_fix,
|
||||||
|
init_image_path,
|
||||||
|
mask_image_path,
|
||||||
|
orig_path,
|
||||||
|
perlin,
|
||||||
|
postprocessing,
|
||||||
|
prompt,
|
||||||
|
sampler,
|
||||||
|
seamless,
|
||||||
|
seed,
|
||||||
|
steps,
|
||||||
|
strength,
|
||||||
|
threshold,
|
||||||
|
type,
|
||||||
|
variations,
|
||||||
|
width,
|
||||||
|
model_weights,
|
||||||
|
} = metadata;
|
||||||
|
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const { getUrl } = useGetUrl();
|
||||||
|
|
||||||
|
const metadataJSON = JSON.stringify(image, null, 2);
|
||||||
|
|
||||||
|
// fetch(getUrl(image.url))
|
||||||
|
// .then((r) => r.arrayBuffer())
|
||||||
|
// .then((buffer) => {
|
||||||
|
// const { text } = png.decode(buffer);
|
||||||
|
// const metadata = text?.['sd-metadata']
|
||||||
|
// ? JSON.parse(text['sd-metadata'] ?? {})
|
||||||
|
// : {};
|
||||||
|
// console.log(metadata);
|
||||||
|
// });
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
padding: 4,
|
||||||
|
gap: 1,
|
||||||
|
flexDirection: 'column',
|
||||||
|
width: 'full',
|
||||||
|
height: 'full',
|
||||||
|
backdropFilter: 'blur(20px)',
|
||||||
|
bg: 'whiteAlpha.600',
|
||||||
|
_dark: {
|
||||||
|
bg: 'blackAlpha.600',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<Text fontWeight="semibold">File:</Text>
|
||||||
|
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
|
||||||
|
{image.url.length > 64
|
||||||
|
? image.url.substring(0, 64).concat('...')
|
||||||
|
: image.url}
|
||||||
|
<ExternalLinkIcon mx="2px" />
|
||||||
|
</Link>
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={2} direction="column">
|
||||||
|
<Flex gap={2}>
|
||||||
|
<Tooltip label="Copy metadata JSON">
|
||||||
|
<IconButton
|
||||||
|
aria-label={t('accessibility.copyMetadataJson')}
|
||||||
|
icon={<FaCopy />}
|
||||||
|
size="xs"
|
||||||
|
variant="ghost"
|
||||||
|
fontSize={14}
|
||||||
|
onClick={() => navigator.clipboard.writeText(metadataJSON)}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
<Text fontWeight="semibold">Metadata JSON:</Text>
|
||||||
|
</Flex>
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
mt: 0,
|
||||||
|
mr: 2,
|
||||||
|
mb: 4,
|
||||||
|
ml: 2,
|
||||||
|
padding: 4,
|
||||||
|
borderRadius: 'base',
|
||||||
|
overflowX: 'scroll',
|
||||||
|
wordBreak: 'break-all',
|
||||||
|
bg: 'whiteAlpha.500',
|
||||||
|
_dark: { bg: 'blackAlpha.500' },
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<pre>{metadataJSON}</pre>
|
||||||
|
</Box>
|
||||||
|
</Flex>
|
||||||
|
{Object.keys(metadata).length > 0 ? (
|
||||||
|
<>
|
||||||
|
{type && <MetadataItem label="Generation type" value={type} />}
|
||||||
|
{model_weights && (
|
||||||
|
<MetadataItem label="Model" value={model_weights} />
|
||||||
|
)}
|
||||||
|
{['esrgan', 'gfpgan'].includes(type) && (
|
||||||
|
<MetadataItem label="Original image" value={orig_path} />
|
||||||
|
)}
|
||||||
|
{prompt && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Prompt"
|
||||||
|
labelPosition="top"
|
||||||
|
value={
|
||||||
|
typeof prompt === 'string' ? prompt : promptToString(prompt)
|
||||||
|
}
|
||||||
|
onClick={() => setBothPrompts(prompt)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{seed !== undefined && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Seed"
|
||||||
|
value={seed}
|
||||||
|
onClick={() => dispatch(setSeed(seed))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{threshold !== undefined && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Noise Threshold"
|
||||||
|
value={threshold}
|
||||||
|
onClick={() => dispatch(setThreshold(threshold))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{perlin !== undefined && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Perlin Noise"
|
||||||
|
value={perlin}
|
||||||
|
onClick={() => dispatch(setPerlin(perlin))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{sampler && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Sampler"
|
||||||
|
value={sampler}
|
||||||
|
onClick={() => dispatch(setSampler(sampler))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{steps && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Steps"
|
||||||
|
value={steps}
|
||||||
|
onClick={() => dispatch(setSteps(steps))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{cfg_scale !== undefined && (
|
||||||
|
<MetadataItem
|
||||||
|
label="CFG scale"
|
||||||
|
value={cfg_scale}
|
||||||
|
onClick={() => dispatch(setCfgScale(cfg_scale))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{variations && variations.length > 0 && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Seed-weight pairs"
|
||||||
|
value={seedWeightsToString(variations)}
|
||||||
|
onClick={() =>
|
||||||
|
dispatch(setSeedWeights(seedWeightsToString(variations)))
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{seamless && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Seamless"
|
||||||
|
value={seamless}
|
||||||
|
onClick={() => dispatch(setSeamless(seamless))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{hires_fix && (
|
||||||
|
<MetadataItem
|
||||||
|
label="High Resolution Optimization"
|
||||||
|
value={hires_fix}
|
||||||
|
onClick={() => dispatch(setHiresFix(hires_fix))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{width && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Width"
|
||||||
|
value={width}
|
||||||
|
onClick={() => dispatch(setWidth(width))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{height && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Height"
|
||||||
|
value={height}
|
||||||
|
onClick={() => dispatch(setHeight(height))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{/* {init_image_path && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Initial image"
|
||||||
|
value={init_image_path}
|
||||||
|
isLink
|
||||||
|
onClick={() => dispatch(setInitialImage(init_image_path))}
|
||||||
|
/>
|
||||||
|
)} */}
|
||||||
|
{mask_image_path && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Mask image"
|
||||||
|
value={mask_image_path}
|
||||||
|
isLink
|
||||||
|
onClick={() => dispatch(setMaskPath(mask_image_path))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{type === 'img2img' && strength && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Image to image strength"
|
||||||
|
value={strength}
|
||||||
|
onClick={() => dispatch(setImg2imgStrength(strength))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{fit && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Image to image fit"
|
||||||
|
value={fit}
|
||||||
|
onClick={() => dispatch(setShouldFitToWidthHeight(fit))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{postprocessing && postprocessing.length > 0 && (
|
||||||
|
<>
|
||||||
|
<Heading size="sm">Postprocessing</Heading>
|
||||||
|
{postprocessing.map(
|
||||||
|
(
|
||||||
|
postprocess: InvokeAI.PostProcessedImageMetadata,
|
||||||
|
i: number
|
||||||
|
) => {
|
||||||
|
if (postprocess.type === 'esrgan') {
|
||||||
|
const { scale, strength, denoise_str } = postprocess;
|
||||||
|
return (
|
||||||
|
<Flex key={i} pl={8} gap={1} direction="column">
|
||||||
|
<Text size="md">{`${i + 1}: Upscale (ESRGAN)`}</Text>
|
||||||
|
<MetadataItem
|
||||||
|
label="Scale"
|
||||||
|
value={scale}
|
||||||
|
onClick={() => dispatch(setUpscalingLevel(scale))}
|
||||||
|
/>
|
||||||
|
<MetadataItem
|
||||||
|
label="Strength"
|
||||||
|
value={strength}
|
||||||
|
onClick={() =>
|
||||||
|
dispatch(setUpscalingStrength(strength))
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
{denoise_str !== undefined && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Denoising strength"
|
||||||
|
value={denoise_str}
|
||||||
|
onClick={() =>
|
||||||
|
dispatch(setUpscalingDenoising(denoise_str))
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
} else if (postprocess.type === 'gfpgan') {
|
||||||
|
const { strength } = postprocess;
|
||||||
|
return (
|
||||||
|
<Flex key={i} pl={8} gap={1} direction="column">
|
||||||
|
<Text size="md">{`${
|
||||||
|
i + 1
|
||||||
|
}: Face restoration (GFPGAN)`}</Text>
|
||||||
|
|
||||||
|
<MetadataItem
|
||||||
|
label="Strength"
|
||||||
|
value={strength}
|
||||||
|
onClick={() => {
|
||||||
|
dispatch(setFacetoolStrength(strength));
|
||||||
|
dispatch(setFacetoolType('gfpgan'));
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
} else if (postprocess.type === 'codeformer') {
|
||||||
|
const { strength, fidelity } = postprocess;
|
||||||
|
return (
|
||||||
|
<Flex key={i} pl={8} gap={1} direction="column">
|
||||||
|
<Text size="md">{`${
|
||||||
|
i + 1
|
||||||
|
}: Face restoration (Codeformer)`}</Text>
|
||||||
|
|
||||||
|
<MetadataItem
|
||||||
|
label="Strength"
|
||||||
|
value={strength}
|
||||||
|
onClick={() => {
|
||||||
|
dispatch(setFacetoolStrength(strength));
|
||||||
|
dispatch(setFacetoolType('codeformer'));
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
{fidelity && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Fidelity"
|
||||||
|
value={fidelity}
|
||||||
|
onClick={() => {
|
||||||
|
dispatch(setCodeformerFidelity(fidelity));
|
||||||
|
dispatch(setFacetoolType('codeformer'));
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{dreamPrompt && (
|
||||||
|
<MetadataItem withCopy label="Dream Prompt" value={dreamPrompt} />
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<Center width="100%" pt={10}>
|
||||||
|
<Text fontSize="lg" fontWeight="semibold">
|
||||||
|
No metadata available
|
||||||
|
</Text>
|
||||||
|
</Center>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
}, memoEqualityCheck);
|
||||||
|
|
||||||
|
ImageMetadataViewer.displayName = 'ImageMetadataViewer';
|
||||||
|
|
||||||
|
export default ImageMetadataViewer;
|
@ -0,0 +1,35 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { useAppSelector } from 'app/storeHooks';
|
||||||
|
import { ImageType } from 'services/api';
|
||||||
|
import { selectResultsEntities } from '../store/resultsSlice';
|
||||||
|
import { selectUploadsEntities } from '../store/uploadsSlice';
|
||||||
|
|
||||||
|
const useGetImageByNameSelector = createSelector(
|
||||||
|
[selectResultsEntities, selectUploadsEntities],
|
||||||
|
(allResults, allUploads) => {
|
||||||
|
return { allResults, allUploads };
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const useGetImageByNameAndType = () => {
|
||||||
|
const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector);
|
||||||
|
|
||||||
|
return (name: string, type: ImageType) => {
|
||||||
|
if (type === 'results') {
|
||||||
|
const resultImagesResult = allResults[name];
|
||||||
|
|
||||||
|
if (resultImagesResult) {
|
||||||
|
return resultImagesResult;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'uploads') {
|
||||||
|
const userImagesResult = allUploads[name];
|
||||||
|
if (userImagesResult) {
|
||||||
|
return userImagesResult;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export default useGetImageByNameAndType;
|
@ -0,0 +1,17 @@
|
|||||||
|
import { GalleryState } from './gallerySlice';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gallery slice persist blacklist
|
||||||
|
*/
|
||||||
|
const itemsToBlacklist: (keyof GalleryState)[] = [
|
||||||
|
'categories',
|
||||||
|
'currentCategory',
|
||||||
|
'currentImage',
|
||||||
|
'currentImageUuid',
|
||||||
|
'shouldAutoSwitchToNewImages',
|
||||||
|
'intermediateImage',
|
||||||
|
];
|
||||||
|
|
||||||
|
export const galleryBlacklist = itemsToBlacklist.map(
|
||||||
|
(blacklistItem) => `gallery.${blacklistItem}`
|
||||||
|
);
|
@ -7,6 +7,16 @@ import {
|
|||||||
uiSelector,
|
uiSelector,
|
||||||
} from 'features/ui/store/uiSelectors';
|
} from 'features/ui/store/uiSelectors';
|
||||||
import { isEqual } from 'lodash';
|
import { isEqual } from 'lodash';
|
||||||
|
import {
|
||||||
|
selectResultsAll,
|
||||||
|
selectResultsById,
|
||||||
|
selectResultsEntities,
|
||||||
|
} from './resultsSlice';
|
||||||
|
import {
|
||||||
|
selectUploadsAll,
|
||||||
|
selectUploadsById,
|
||||||
|
selectUploadsEntities,
|
||||||
|
} from './uploadsSlice';
|
||||||
|
|
||||||
export const gallerySelector = (state: RootState) => state.gallery;
|
export const gallerySelector = (state: RootState) => state.gallery;
|
||||||
|
|
||||||
@ -75,3 +85,18 @@ export const hoverableImageSelector = createSelector(
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
export const selectedImageSelector = createSelector(
|
||||||
|
[gallerySelector, selectResultsEntities, selectUploadsEntities],
|
||||||
|
(gallery, allResults, allUploads) => {
|
||||||
|
const selectedImageName = gallery.selectedImageName;
|
||||||
|
|
||||||
|
if (selectedImageName in allResults) {
|
||||||
|
return allResults[selectedImageName];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (selectedImageName in allUploads) {
|
||||||
|
return allUploads[selectedImageName];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
@ -1,14 +1,17 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import * as InvokeAI from 'app/invokeai';
|
import * as InvokeAI from 'app/invokeai';
|
||||||
|
import { invocationComplete } from 'services/events/actions';
|
||||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
import { IRect } from 'konva/lib/types';
|
import { IRect } from 'konva/lib/types';
|
||||||
import { clamp } from 'lodash';
|
import { clamp } from 'lodash';
|
||||||
|
import { isImageOutput } from 'services/types/guards';
|
||||||
|
import { imageUploaded } from 'services/thunks/image';
|
||||||
|
|
||||||
export type GalleryCategory = 'user' | 'result';
|
export type GalleryCategory = 'user' | 'result';
|
||||||
|
|
||||||
export type AddImagesPayload = {
|
export type AddImagesPayload = {
|
||||||
images: Array<InvokeAI.Image>;
|
images: Array<InvokeAI._Image>;
|
||||||
areMoreImagesAvailable: boolean;
|
areMoreImagesAvailable: boolean;
|
||||||
category: GalleryCategory;
|
category: GalleryCategory;
|
||||||
};
|
};
|
||||||
@ -16,16 +19,33 @@ export type AddImagesPayload = {
|
|||||||
type GalleryImageObjectFitType = 'contain' | 'cover';
|
type GalleryImageObjectFitType = 'contain' | 'cover';
|
||||||
|
|
||||||
export type Gallery = {
|
export type Gallery = {
|
||||||
images: InvokeAI.Image[];
|
images: InvokeAI._Image[];
|
||||||
latest_mtime?: number;
|
latest_mtime?: number;
|
||||||
earliest_mtime?: number;
|
earliest_mtime?: number;
|
||||||
areMoreImagesAvailable: boolean;
|
areMoreImagesAvailable: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
export interface GalleryState {
|
export interface GalleryState {
|
||||||
currentImage?: InvokeAI.Image;
|
/**
|
||||||
|
* The selected image's unique name
|
||||||
|
* Use `selectedImageSelector` to access the image
|
||||||
|
*/
|
||||||
|
selectedImageName: string;
|
||||||
|
/**
|
||||||
|
* The currently selected image
|
||||||
|
* @deprecated See `state.gallery.selectedImageName`
|
||||||
|
*/
|
||||||
|
currentImage?: InvokeAI._Image;
|
||||||
|
/**
|
||||||
|
* The currently selected image's uuid.
|
||||||
|
* @deprecated See `state.gallery.selectedImageName`, use `selectedImageSelector` to access the image
|
||||||
|
*/
|
||||||
currentImageUuid: string;
|
currentImageUuid: string;
|
||||||
intermediateImage?: InvokeAI.Image & {
|
/**
|
||||||
|
* The current progress image
|
||||||
|
* @deprecated See `state.system.progressImage`
|
||||||
|
*/
|
||||||
|
intermediateImage?: InvokeAI._Image & {
|
||||||
boundingBox?: IRect;
|
boundingBox?: IRect;
|
||||||
generationMode?: InvokeTabName;
|
generationMode?: InvokeTabName;
|
||||||
};
|
};
|
||||||
@ -42,6 +62,7 @@ export interface GalleryState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const initialState: GalleryState = {
|
const initialState: GalleryState = {
|
||||||
|
selectedImageName: '',
|
||||||
currentImageUuid: '',
|
currentImageUuid: '',
|
||||||
galleryImageMinimumWidth: 64,
|
galleryImageMinimumWidth: 64,
|
||||||
galleryImageObjectFit: 'cover',
|
galleryImageObjectFit: 'cover',
|
||||||
@ -69,7 +90,10 @@ export const gallerySlice = createSlice({
|
|||||||
name: 'gallery',
|
name: 'gallery',
|
||||||
initialState,
|
initialState,
|
||||||
reducers: {
|
reducers: {
|
||||||
setCurrentImage: (state, action: PayloadAction<InvokeAI.Image>) => {
|
imageSelected: (state, action: PayloadAction<string>) => {
|
||||||
|
state.selectedImageName = action.payload;
|
||||||
|
},
|
||||||
|
setCurrentImage: (state, action: PayloadAction<InvokeAI._Image>) => {
|
||||||
state.currentImage = action.payload;
|
state.currentImage = action.payload;
|
||||||
state.currentImageUuid = action.payload.uuid;
|
state.currentImageUuid = action.payload.uuid;
|
||||||
},
|
},
|
||||||
@ -124,7 +148,7 @@ export const gallerySlice = createSlice({
|
|||||||
addImage: (
|
addImage: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
image: InvokeAI.Image;
|
image: InvokeAI._Image;
|
||||||
category: GalleryCategory;
|
category: GalleryCategory;
|
||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
@ -150,7 +174,10 @@ export const gallerySlice = createSlice({
|
|||||||
setIntermediateImage: (
|
setIntermediateImage: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<
|
action: PayloadAction<
|
||||||
InvokeAI.Image & { boundingBox?: IRect; generationMode?: InvokeTabName }
|
InvokeAI._Image & {
|
||||||
|
boundingBox?: IRect;
|
||||||
|
generationMode?: InvokeTabName;
|
||||||
|
}
|
||||||
>
|
>
|
||||||
) => {
|
) => {
|
||||||
state.intermediateImage = action.payload;
|
state.intermediateImage = action.payload;
|
||||||
@ -252,9 +279,31 @@ export const gallerySlice = createSlice({
|
|||||||
state.shouldUseSingleGalleryColumn = action.payload;
|
state.shouldUseSingleGalleryColumn = action.payload;
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
extraReducers(builder) {
|
||||||
|
/**
|
||||||
|
* Invocation Complete
|
||||||
|
*/
|
||||||
|
builder.addCase(invocationComplete, (state, action) => {
|
||||||
|
const { data } = action.payload;
|
||||||
|
if (isImageOutput(data.result)) {
|
||||||
|
state.selectedImageName = data.result.image.image_name;
|
||||||
|
state.intermediateImage = undefined;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Upload Image - FULFILLED
|
||||||
|
*/
|
||||||
|
builder.addCase(imageUploaded.fulfilled, (state, action) => {
|
||||||
|
const { location } = action.payload;
|
||||||
|
const imageName = location.split('/').pop() || '';
|
||||||
|
state.selectedImageName = imageName;
|
||||||
|
});
|
||||||
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
export const {
|
export const {
|
||||||
|
imageSelected,
|
||||||
addImage,
|
addImage,
|
||||||
clearIntermediateImage,
|
clearIntermediateImage,
|
||||||
removeImage,
|
removeImage,
|
||||||
|
@ -0,0 +1,12 @@
|
|||||||
|
import { ResultsState } from './resultsSlice';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Results slice persist blacklist
|
||||||
|
*
|
||||||
|
* Currently blacklisting results slice entirely, see persist config in store.ts
|
||||||
|
*/
|
||||||
|
const itemsToBlacklist: (keyof ResultsState)[] = [];
|
||||||
|
|
||||||
|
export const resultsBlacklist = itemsToBlacklist.map(
|
||||||
|
(blacklistItem) => `results.${blacklistItem}`
|
||||||
|
);
|
139
invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts
Normal file
139
invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||||
|
import { Image } from 'app/invokeai';
|
||||||
|
import { invocationComplete } from 'services/events/actions';
|
||||||
|
|
||||||
|
import { RootState } from 'app/store';
|
||||||
|
import {
|
||||||
|
receivedResultImagesPage,
|
||||||
|
IMAGES_PER_PAGE,
|
||||||
|
} from 'services/thunks/gallery';
|
||||||
|
import { isImageOutput } from 'services/types/guards';
|
||||||
|
import {
|
||||||
|
buildImageUrls,
|
||||||
|
extractTimestampFromImageName,
|
||||||
|
} from 'services/util/deserializeImageField';
|
||||||
|
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||||
|
|
||||||
|
// use `createEntityAdapter` to create a slice for results images
|
||||||
|
// https://redux-toolkit.js.org/api/createEntityAdapter#overview
|
||||||
|
|
||||||
|
// the "Entity" is InvokeAI.ResultImage, while the "entities" are instances of that type
|
||||||
|
export const resultsAdapter = createEntityAdapter<Image>({
|
||||||
|
// Provide a callback to get a stable, unique identifier for each entity. This defaults to
|
||||||
|
// `(item) => item.id`, but for our result images, the `name` is the unique identifier.
|
||||||
|
selectId: (image) => image.name,
|
||||||
|
// Order all images by their time (in descending order)
|
||||||
|
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
|
||||||
|
});
|
||||||
|
|
||||||
|
// This type is intersected with the Entity type to create the shape of the state
|
||||||
|
type AdditionalResultsState = {
|
||||||
|
// these are a bit misleading; they refer to sessions, not results, but we don't have a route
|
||||||
|
// to list all images directly at this time...
|
||||||
|
page: number; // current page we are on
|
||||||
|
pages: number; // the total number of pages available
|
||||||
|
isLoading: boolean; // whether we are loading more images or not, mostly a placeholder
|
||||||
|
nextPage: number; // the next page to request
|
||||||
|
};
|
||||||
|
|
||||||
|
export const initialResultsState =
|
||||||
|
resultsAdapter.getInitialState<AdditionalResultsState>({
|
||||||
|
// provide the additional initial state
|
||||||
|
page: 0,
|
||||||
|
pages: 0,
|
||||||
|
isLoading: false,
|
||||||
|
nextPage: 0,
|
||||||
|
});
|
||||||
|
|
||||||
|
export type ResultsState = typeof initialResultsState;
|
||||||
|
|
||||||
|
const resultsSlice = createSlice({
|
||||||
|
name: 'results',
|
||||||
|
initialState: initialResultsState,
|
||||||
|
reducers: {
|
||||||
|
// the adapter provides some helper reducers; see the docs for all of them
|
||||||
|
// can use them as helper functions within a reducer, or use the function itself as a reducer
|
||||||
|
|
||||||
|
// here we just use the function itself as the reducer. we'll call this on `invocation_complete`
|
||||||
|
// to add a single result
|
||||||
|
resultAdded: resultsAdapter.upsertOne,
|
||||||
|
},
|
||||||
|
extraReducers: (builder) => {
|
||||||
|
// here we can respond to a fulfilled call of the `getNextResultsPage` thunk
|
||||||
|
// because we pass in the fulfilled thunk action creator, everything is typed
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Received Result Images Page - PENDING
|
||||||
|
*/
|
||||||
|
builder.addCase(receivedResultImagesPage.pending, (state) => {
|
||||||
|
state.isLoading = true;
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Received Result Images Page - FULFILLED
|
||||||
|
*/
|
||||||
|
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
||||||
|
const { items, page, pages } = action.payload;
|
||||||
|
|
||||||
|
const resultImages = items.map((image) =>
|
||||||
|
deserializeImageResponse(image)
|
||||||
|
);
|
||||||
|
|
||||||
|
// use the adapter reducer to append all the results to state
|
||||||
|
resultsAdapter.addMany(state, resultImages);
|
||||||
|
|
||||||
|
state.page = page;
|
||||||
|
state.pages = pages;
|
||||||
|
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
|
||||||
|
state.isLoading = false;
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Invocation Complete
|
||||||
|
*/
|
||||||
|
builder.addCase(invocationComplete, (state, action) => {
|
||||||
|
const { data } = action.payload;
|
||||||
|
const { result, node, graph_execution_state_id } = data;
|
||||||
|
|
||||||
|
if (isImageOutput(result)) {
|
||||||
|
const name = result.image.image_name;
|
||||||
|
const type = result.image.image_type;
|
||||||
|
const { url, thumbnail } = buildImageUrls(type, name);
|
||||||
|
|
||||||
|
const timestamp = extractTimestampFromImageName(name);
|
||||||
|
|
||||||
|
const image: Image = {
|
||||||
|
name,
|
||||||
|
type,
|
||||||
|
url,
|
||||||
|
thumbnail,
|
||||||
|
metadata: {
|
||||||
|
created: timestamp,
|
||||||
|
width: result.width, // TODO: add tese dimensions
|
||||||
|
height: result.height,
|
||||||
|
invokeai: {
|
||||||
|
session_id: graph_execution_state_id,
|
||||||
|
...(node ? { node } : {}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
resultsAdapter.addOne(state, image);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create a set of memoized selectors based on the location of this entity state
|
||||||
|
// to be used as selectors in a `useAppSelector()` call
|
||||||
|
export const {
|
||||||
|
selectAll: selectResultsAll,
|
||||||
|
selectById: selectResultsById,
|
||||||
|
selectEntities: selectResultsEntities,
|
||||||
|
selectIds: selectResultsIds,
|
||||||
|
selectTotal: selectResultsTotal,
|
||||||
|
} = resultsAdapter.getSelectors<RootState>((state) => state.results);
|
||||||
|
|
||||||
|
export const { resultAdded } = resultsSlice.actions;
|
||||||
|
|
||||||
|
export default resultsSlice.reducer;
|
@ -1,54 +0,0 @@
|
|||||||
import { AnyAction, ThunkAction } from '@reduxjs/toolkit';
|
|
||||||
import * as InvokeAI from 'app/invokeai';
|
|
||||||
import { RootState } from 'app/store';
|
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
|
||||||
import { setInitialImage } from 'features/parameters/store/generationSlice';
|
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
|
||||||
import { addImage } from '../gallerySlice';
|
|
||||||
|
|
||||||
type UploadImageConfig = {
|
|
||||||
imageFile: File;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const uploadImage =
|
|
||||||
(
|
|
||||||
config: UploadImageConfig
|
|
||||||
): ThunkAction<void, RootState, unknown, AnyAction> =>
|
|
||||||
async (dispatch, getState) => {
|
|
||||||
const { imageFile } = config;
|
|
||||||
|
|
||||||
const state = getState() as RootState;
|
|
||||||
|
|
||||||
const activeTabName = activeTabNameSelector(state);
|
|
||||||
|
|
||||||
const formData = new FormData();
|
|
||||||
|
|
||||||
formData.append('file', imageFile, imageFile.name);
|
|
||||||
formData.append(
|
|
||||||
'data',
|
|
||||||
JSON.stringify({
|
|
||||||
kind: 'init',
|
|
||||||
})
|
|
||||||
);
|
|
||||||
|
|
||||||
const response = await fetch(`${window.location.origin}/upload`, {
|
|
||||||
method: 'POST',
|
|
||||||
body: formData,
|
|
||||||
});
|
|
||||||
|
|
||||||
const image = (await response.json()) as InvokeAI.ImageUploadResponse;
|
|
||||||
const newImage: InvokeAI.Image = {
|
|
||||||
uuid: uuidv4(),
|
|
||||||
category: 'user',
|
|
||||||
...image,
|
|
||||||
};
|
|
||||||
|
|
||||||
dispatch(addImage({ image: newImage, category: 'user' }));
|
|
||||||
|
|
||||||
if (activeTabName === 'unifiedCanvas') {
|
|
||||||
dispatch(setInitialCanvasImage(newImage));
|
|
||||||
} else if (activeTabName === 'img2img') {
|
|
||||||
dispatch(setInitialImage(newImage));
|
|
||||||
}
|
|
||||||
};
|
|
@ -0,0 +1,12 @@
|
|||||||
|
import { UploadsState } from './uploadsSlice';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Uploads slice persist blacklist
|
||||||
|
*
|
||||||
|
* Currently blacklisting uploads slice entirely, see persist config in store.ts
|
||||||
|
*/
|
||||||
|
const itemsToBlacklist: (keyof UploadsState)[] = [];
|
||||||
|
|
||||||
|
export const uploadsBlacklist = itemsToBlacklist.map(
|
||||||
|
(blacklistItem) => `uploads.${blacklistItem}`
|
||||||
|
);
|
@ -0,0 +1,87 @@
|
|||||||
|
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||||
|
import { Image } from 'app/invokeai';
|
||||||
|
|
||||||
|
import { RootState } from 'app/store';
|
||||||
|
import {
|
||||||
|
receivedUploadImagesPage,
|
||||||
|
IMAGES_PER_PAGE,
|
||||||
|
} from 'services/thunks/gallery';
|
||||||
|
import { imageUploaded } from 'services/thunks/image';
|
||||||
|
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||||
|
|
||||||
|
export const uploadsAdapter = createEntityAdapter<Image>({
|
||||||
|
selectId: (image) => image.name,
|
||||||
|
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
|
||||||
|
});
|
||||||
|
|
||||||
|
type AdditionalUploadsState = {
|
||||||
|
page: number;
|
||||||
|
pages: number;
|
||||||
|
isLoading: boolean;
|
||||||
|
nextPage: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
const initialUploadsState =
|
||||||
|
uploadsAdapter.getInitialState<AdditionalUploadsState>({
|
||||||
|
page: 0,
|
||||||
|
pages: 0,
|
||||||
|
nextPage: 0,
|
||||||
|
isLoading: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
export type UploadsState = typeof initialUploadsState;
|
||||||
|
|
||||||
|
const uploadsSlice = createSlice({
|
||||||
|
name: 'uploads',
|
||||||
|
initialState: initialUploadsState,
|
||||||
|
reducers: {
|
||||||
|
uploadAdded: uploadsAdapter.addOne,
|
||||||
|
},
|
||||||
|
extraReducers: (builder) => {
|
||||||
|
/**
|
||||||
|
* Received Upload Images Page - PENDING
|
||||||
|
*/
|
||||||
|
builder.addCase(receivedUploadImagesPage.pending, (state) => {
|
||||||
|
state.isLoading = true;
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Received Upload Images Page - FULFILLED
|
||||||
|
*/
|
||||||
|
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
|
||||||
|
const { items, page, pages } = action.payload;
|
||||||
|
|
||||||
|
const images = items.map((image) => deserializeImageResponse(image));
|
||||||
|
|
||||||
|
uploadsAdapter.addMany(state, images);
|
||||||
|
|
||||||
|
state.page = page;
|
||||||
|
state.pages = pages;
|
||||||
|
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
|
||||||
|
state.isLoading = false;
|
||||||
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Upload Image - FULFILLED
|
||||||
|
*/
|
||||||
|
builder.addCase(imageUploaded.fulfilled, (state, action) => {
|
||||||
|
const { location, response } = action.payload;
|
||||||
|
|
||||||
|
const uploadedImage = deserializeImageResponse(response);
|
||||||
|
|
||||||
|
uploadsAdapter.addOne(state, uploadedImage);
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const {
|
||||||
|
selectAll: selectUploadsAll,
|
||||||
|
selectById: selectUploadsById,
|
||||||
|
selectEntities: selectUploadsEntities,
|
||||||
|
selectIds: selectUploadsIds,
|
||||||
|
selectTotal: selectUploadsTotal,
|
||||||
|
} = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
|
||||||
|
|
||||||
|
export const { uploadAdded } = uploadsSlice.actions;
|
||||||
|
|
||||||
|
export default uploadsSlice.reducer;
|
@ -1,9 +1,10 @@
|
|||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
import { TransformComponent, useTransformContext } from 'react-zoom-pan-pinch';
|
import { TransformComponent, useTransformContext } from 'react-zoom-pan-pinch';
|
||||||
import * as InvokeAI from 'app/invokeai';
|
import * as InvokeAI from 'app/invokeai';
|
||||||
|
import { useGetUrl } from 'common/util/getUrl';
|
||||||
|
|
||||||
type ReactPanZoomProps = {
|
type ReactPanZoomProps = {
|
||||||
image: InvokeAI.Image;
|
image: InvokeAI._Image;
|
||||||
styleClass?: string;
|
styleClass?: string;
|
||||||
alt?: string;
|
alt?: string;
|
||||||
ref?: React.Ref<HTMLImageElement>;
|
ref?: React.Ref<HTMLImageElement>;
|
||||||
@ -22,6 +23,7 @@ export default function ReactPanZoomImage({
|
|||||||
scaleY,
|
scaleY,
|
||||||
}: ReactPanZoomProps) {
|
}: ReactPanZoomProps) {
|
||||||
const { centerView } = useTransformContext();
|
const { centerView } = useTransformContext();
|
||||||
|
const { getUrl } = useGetUrl();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<TransformComponent
|
<TransformComponent
|
||||||
@ -35,7 +37,7 @@ export default function ReactPanZoomImage({
|
|||||||
transform: `rotate(${rotation}deg) scaleX(${scaleX}) scaleY(${scaleY})`,
|
transform: `rotate(${rotation}deg) scaleX(${scaleX}) scaleY(${scaleY})`,
|
||||||
width: '100%',
|
width: '100%',
|
||||||
}}
|
}}
|
||||||
src={image.url}
|
src={getUrl(image.url)}
|
||||||
alt={alt}
|
alt={alt}
|
||||||
ref={ref}
|
ref={ref}
|
||||||
className={styleClass ? styleClass : ''}
|
className={styleClass ? styleClass : ''}
|
||||||
|
@ -0,0 +1,10 @@
|
|||||||
|
import { LightboxState } from './lightboxSlice';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Lightbox slice persist blacklist
|
||||||
|
*/
|
||||||
|
const itemsToBlacklist: (keyof LightboxState)[] = ['isLightboxOpen'];
|
||||||
|
|
||||||
|
export const lightboxBlacklist = itemsToBlacklist.map(
|
||||||
|
(blacklistItem) => `lightbox.${blacklistItem}`
|
||||||
|
);
|
@ -0,0 +1,63 @@
|
|||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
|
||||||
|
import 'reactflow/dist/style.css';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import {
|
||||||
|
Tooltip,
|
||||||
|
Menu,
|
||||||
|
MenuButton,
|
||||||
|
MenuList,
|
||||||
|
MenuItem,
|
||||||
|
IconButton,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { FaPlus } from 'react-icons/fa';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||||
|
import { nodeAdded } from '../store/nodesSlice';
|
||||||
|
import { cloneDeep, map } from 'lodash';
|
||||||
|
import { RootState } from 'app/store';
|
||||||
|
import { useBuildInvocation } from '../hooks/useBuildInvocation';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { makeToast } from 'features/system/hooks/useToastWatcher';
|
||||||
|
|
||||||
|
export const AddNodeMenu = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const invocationTemplates = useAppSelector(
|
||||||
|
(state: RootState) => state.nodes.invocationTemplates
|
||||||
|
);
|
||||||
|
|
||||||
|
const buildInvocation = useBuildInvocation();
|
||||||
|
|
||||||
|
const addNode = useCallback(
|
||||||
|
(nodeType: string) => {
|
||||||
|
const invocation = buildInvocation(nodeType);
|
||||||
|
|
||||||
|
if (!invocation) {
|
||||||
|
const toast = makeToast({
|
||||||
|
status: 'error',
|
||||||
|
title: `Unknown Invocation type ${nodeType}`,
|
||||||
|
});
|
||||||
|
dispatch(addToast(toast));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(nodeAdded(invocation));
|
||||||
|
},
|
||||||
|
[dispatch, buildInvocation]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Menu>
|
||||||
|
<MenuButton as={IconButton} aria-label="Add Node" icon={<FaPlus />} />
|
||||||
|
<MenuList>
|
||||||
|
{map(invocationTemplates, ({ title, description, type }, key) => {
|
||||||
|
return (
|
||||||
|
<Tooltip key={key} label={description} placement="end" hasArrow>
|
||||||
|
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</MenuList>
|
||||||
|
</Menu>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,69 @@
|
|||||||
|
import { Tooltip } from '@chakra-ui/react';
|
||||||
|
import { CSSProperties, useMemo } from 'react';
|
||||||
|
import {
|
||||||
|
Handle,
|
||||||
|
Position,
|
||||||
|
Connection,
|
||||||
|
HandleType,
|
||||||
|
useReactFlow,
|
||||||
|
} from 'reactflow';
|
||||||
|
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../types/constants';
|
||||||
|
// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
|
||||||
|
import { InputFieldTemplate, OutputFieldTemplate } from '../types/types';
|
||||||
|
|
||||||
|
const handleBaseStyles: CSSProperties = {
|
||||||
|
position: 'absolute',
|
||||||
|
width: '1rem',
|
||||||
|
height: '1rem',
|
||||||
|
borderWidth: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
const inputHandleStyles: CSSProperties = {
|
||||||
|
left: '-1.7rem',
|
||||||
|
};
|
||||||
|
|
||||||
|
const outputHandleStyles: CSSProperties = {
|
||||||
|
right: '-1.7rem',
|
||||||
|
};
|
||||||
|
|
||||||
|
const requiredConnectionStyles: CSSProperties = {
|
||||||
|
boxShadow: '0 0 0.5rem 0.5rem var(--invokeai-colors-error-400)',
|
||||||
|
};
|
||||||
|
|
||||||
|
type FieldHandleProps = {
|
||||||
|
nodeId: string;
|
||||||
|
field: InputFieldTemplate | OutputFieldTemplate;
|
||||||
|
isValidConnection: (connection: Connection) => boolean;
|
||||||
|
handleType: HandleType;
|
||||||
|
styles?: CSSProperties;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const FieldHandle = (props: FieldHandleProps) => {
|
||||||
|
const { nodeId, field, isValidConnection, handleType, styles } = props;
|
||||||
|
const { name, title, type, description } = field;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Tooltip
|
||||||
|
key={name}
|
||||||
|
label={type}
|
||||||
|
placement={handleType === 'target' ? 'start' : 'end'}
|
||||||
|
hasArrow
|
||||||
|
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||||
|
>
|
||||||
|
<Handle
|
||||||
|
type={handleType}
|
||||||
|
id={name}
|
||||||
|
isValidConnection={isValidConnection}
|
||||||
|
position={handleType === 'target' ? Position.Left : Position.Right}
|
||||||
|
style={{
|
||||||
|
backgroundColor: FIELDS[type].colorCssVar,
|
||||||
|
...styles,
|
||||||
|
...handleBaseStyles,
|
||||||
|
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
|
||||||
|
// ...(inputRequirement === 'always' ? requiredConnectionStyles : {}),
|
||||||
|
// ...connectionEventStyles,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,18 @@
|
|||||||
|
import 'reactflow/dist/style.css';
|
||||||
|
import { Tooltip, Badge, HStack } from '@chakra-ui/react';
|
||||||
|
import { map } from 'lodash';
|
||||||
|
import { FIELDS } from '../types/constants';
|
||||||
|
|
||||||
|
export const FieldTypeLegend = () => {
|
||||||
|
return (
|
||||||
|
<HStack>
|
||||||
|
{map(FIELDS, ({ title, description, color }, key) => (
|
||||||
|
<Tooltip key={key} label={description}>
|
||||||
|
<Badge colorScheme={color} sx={{ userSelect: 'none' }}>
|
||||||
|
{title}
|
||||||
|
</Badge>
|
||||||
|
</Tooltip>
|
||||||
|
))}
|
||||||
|
</HStack>
|
||||||
|
);
|
||||||
|
};
|
104
invokeai/frontend/web/src/features/nodes/components/Flow.tsx
Normal file
104
invokeai/frontend/web/src/features/nodes/components/Flow.tsx
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
import {
|
||||||
|
Background,
|
||||||
|
Controls,
|
||||||
|
MiniMap,
|
||||||
|
OnConnect,
|
||||||
|
OnEdgesChange,
|
||||||
|
OnNodesChange,
|
||||||
|
ReactFlow,
|
||||||
|
ConnectionLineType,
|
||||||
|
OnConnectStart,
|
||||||
|
OnConnectEnd,
|
||||||
|
Panel,
|
||||||
|
} from 'reactflow';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||||
|
import { RootState } from 'app/store';
|
||||||
|
import {
|
||||||
|
connectionEnded,
|
||||||
|
connectionMade,
|
||||||
|
connectionStarted,
|
||||||
|
edgesChanged,
|
||||||
|
nodesChanged,
|
||||||
|
} from '../store/nodesSlice';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import { InvocationComponent } from './InvocationComponent';
|
||||||
|
import { AddNodeMenu } from './AddNodeMenu';
|
||||||
|
import { FieldTypeLegend } from './FieldTypeLegend';
|
||||||
|
import { Button } from '@chakra-ui/react';
|
||||||
|
import { nodesGraphBuilt } from 'services/thunks/session';
|
||||||
|
|
||||||
|
const nodeTypes = { invocation: InvocationComponent };
|
||||||
|
|
||||||
|
export const Flow = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
|
||||||
|
const edges = useAppSelector((state: RootState) => state.nodes.edges);
|
||||||
|
|
||||||
|
const onNodesChange: OnNodesChange = useCallback(
|
||||||
|
(changes) => {
|
||||||
|
dispatch(nodesChanged(changes));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const onEdgesChange: OnEdgesChange = useCallback(
|
||||||
|
(changes) => {
|
||||||
|
dispatch(edgesChanged(changes));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const onConnectStart: OnConnectStart = useCallback(
|
||||||
|
(event, params) => {
|
||||||
|
dispatch(connectionStarted(params));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const onConnect: OnConnect = useCallback(
|
||||||
|
(connection) => {
|
||||||
|
dispatch(connectionMade(connection));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const onConnectEnd: OnConnectEnd = useCallback(
|
||||||
|
(event) => {
|
||||||
|
dispatch(connectionEnded());
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleInvoke = useCallback(() => {
|
||||||
|
dispatch(nodesGraphBuilt());
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<ReactFlow
|
||||||
|
nodeTypes={nodeTypes}
|
||||||
|
nodes={nodes}
|
||||||
|
edges={edges}
|
||||||
|
onNodesChange={onNodesChange}
|
||||||
|
onEdgesChange={onEdgesChange}
|
||||||
|
onConnectStart={onConnectStart}
|
||||||
|
onConnect={onConnect}
|
||||||
|
onConnectEnd={onConnectEnd}
|
||||||
|
defaultEdgeOptions={{
|
||||||
|
style: { strokeWidth: 2 },
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Panel position="top-left">
|
||||||
|
<AddNodeMenu />
|
||||||
|
</Panel>
|
||||||
|
<Panel position="top-center">
|
||||||
|
<Button onClick={handleInvoke}>Will it blend?</Button>
|
||||||
|
</Panel>
|
||||||
|
<Panel position="top-right">
|
||||||
|
<FieldTypeLegend />
|
||||||
|
</Panel>
|
||||||
|
<Background />
|
||||||
|
<Controls />
|
||||||
|
<MiniMap nodeStrokeWidth={3} zoomable pannable />
|
||||||
|
</ReactFlow>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,107 @@
|
|||||||
|
import { Box } from '@chakra-ui/react';
|
||||||
|
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||||
|
import { ArrayInputFieldComponent } from './fields/ArrayInputField.tsx';
|
||||||
|
import { BooleanInputFieldComponent } from './fields/BooleanInputFieldComponent';
|
||||||
|
import { EnumInputFieldComponent } from './fields/EnumInputFieldComponent';
|
||||||
|
import { ImageInputFieldComponent } from './fields/ImageInputFieldComponent';
|
||||||
|
import { LatentsInputFieldComponent } from './fields/LatentsInputFieldComponent';
|
||||||
|
import { ModelInputFieldComponent } from './fields/ModelInputFieldComponent';
|
||||||
|
import { NumberInputFieldComponent } from './fields/NumberInputFieldComponent';
|
||||||
|
import { StringInputFieldComponent } from './fields/StringInputFieldComponent';
|
||||||
|
|
||||||
|
type InputFieldComponentProps = {
|
||||||
|
nodeId: string;
|
||||||
|
field: InputFieldValue;
|
||||||
|
template: InputFieldTemplate;
|
||||||
|
};
|
||||||
|
|
||||||
|
// build an individual input element based on the schema
|
||||||
|
export const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||||
|
const { nodeId, field, template } = props;
|
||||||
|
const { type, value } = field;
|
||||||
|
|
||||||
|
if (type === 'string' && template.type === 'string') {
|
||||||
|
return (
|
||||||
|
<StringInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'boolean' && template.type === 'boolean') {
|
||||||
|
return (
|
||||||
|
<BooleanInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
(type === 'integer' && template.type === 'integer') ||
|
||||||
|
(type === 'float' && template.type === 'float')
|
||||||
|
) {
|
||||||
|
return (
|
||||||
|
<NumberInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'enum' && template.type === 'enum') {
|
||||||
|
return (
|
||||||
|
<EnumInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'image' && template.type === 'image') {
|
||||||
|
return (
|
||||||
|
<ImageInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'latents' && template.type === 'latents') {
|
||||||
|
return (
|
||||||
|
<LatentsInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'model' && template.type === 'model') {
|
||||||
|
return (
|
||||||
|
<ModelInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === 'array' && template.type === 'array') {
|
||||||
|
return (
|
||||||
|
<ArrayInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return <Box p={2}>Unknown field type: {type}</Box>;
|
||||||
|
};
|
@ -0,0 +1,243 @@
|
|||||||
|
import { NodeProps, useReactFlow } from 'reactflow';
|
||||||
|
import {
|
||||||
|
Box,
|
||||||
|
Flex,
|
||||||
|
FormControl,
|
||||||
|
FormLabel,
|
||||||
|
Heading,
|
||||||
|
HStack,
|
||||||
|
Tooltip,
|
||||||
|
Icon,
|
||||||
|
Code,
|
||||||
|
Text,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { FaExclamationCircle, FaInfoCircle } from 'react-icons/fa';
|
||||||
|
import { InvocationValue } from '../types/types';
|
||||||
|
import { InputFieldComponent } from './InputFieldComponent';
|
||||||
|
import { FieldHandle } from './FieldHandle';
|
||||||
|
import { isEqual, map, size } from 'lodash';
|
||||||
|
import { memo, useMemo, useRef } from 'react';
|
||||||
|
import { useIsValidConnection } from '../hooks/useIsValidConnection';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { RootState } from 'app/store';
|
||||||
|
import { useAppSelector } from 'app/storeHooks';
|
||||||
|
import { useGetInvocationTemplate } from '../hooks/useInvocationTemplate';
|
||||||
|
|
||||||
|
const connectedInputFieldsSelector = createSelector(
|
||||||
|
[(state: RootState) => state.nodes.edges],
|
||||||
|
(edges) => {
|
||||||
|
// return edges.map((e) => e.targetHandle);
|
||||||
|
return edges;
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
|
||||||
|
const { id: nodeId, data, selected } = props;
|
||||||
|
const { type, inputs, outputs } = data;
|
||||||
|
|
||||||
|
const isValidConnection = useIsValidConnection();
|
||||||
|
|
||||||
|
const connectedInputs = useAppSelector(connectedInputFieldsSelector);
|
||||||
|
const getInvocationTemplate = useGetInvocationTemplate();
|
||||||
|
// TODO: determine if a field/handle is connected and disable the input if so
|
||||||
|
|
||||||
|
const template = useRef(getInvocationTemplate(type));
|
||||||
|
|
||||||
|
if (!template.current) {
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
padding: 4,
|
||||||
|
bg: 'base.800',
|
||||||
|
borderRadius: 'md',
|
||||||
|
boxShadow: 'dark-lg',
|
||||||
|
borderWidth: 2,
|
||||||
|
borderColor: selected ? 'base.400' : 'transparent',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Flex sx={{ alignItems: 'center', justifyContent: 'center' }}>
|
||||||
|
<Icon color="base.400" boxSize={32} as={FaExclamationCircle}></Icon>
|
||||||
|
</Flex>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
padding: 4,
|
||||||
|
bg: 'base.800',
|
||||||
|
borderRadius: 'md',
|
||||||
|
boxShadow: 'dark-lg',
|
||||||
|
borderWidth: 2,
|
||||||
|
borderColor: selected ? 'base.400' : 'transparent',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Flex flexDirection="column" gap={2}>
|
||||||
|
<>
|
||||||
|
<Code>{nodeId}</Code>
|
||||||
|
<HStack justifyContent="space-between">
|
||||||
|
<Heading size="sm" fontWeight={500} color="base.100">
|
||||||
|
{template.current.title}
|
||||||
|
</Heading>
|
||||||
|
<Tooltip
|
||||||
|
label={template.current.description}
|
||||||
|
placement="top"
|
||||||
|
hasArrow
|
||||||
|
shouldWrapChildren
|
||||||
|
>
|
||||||
|
<Icon color="base.300" as={FaInfoCircle} />
|
||||||
|
</Tooltip>
|
||||||
|
</HStack>
|
||||||
|
{map(inputs, (input, i) => {
|
||||||
|
const { id: fieldId } = input;
|
||||||
|
const inputTemplate = template.current?.inputs[input.name];
|
||||||
|
|
||||||
|
if (!inputTemplate) {
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
key={fieldId}
|
||||||
|
position="relative"
|
||||||
|
p={2}
|
||||||
|
borderWidth={1}
|
||||||
|
borderRadius="md"
|
||||||
|
sx={{
|
||||||
|
borderColor: 'error.400',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<FormControl isDisabled={true}>
|
||||||
|
<HStack justifyContent="space-between" alignItems="center">
|
||||||
|
<FormLabel>Unknown input: {input.name}</FormLabel>
|
||||||
|
</HStack>
|
||||||
|
</FormControl>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const isConnected = Boolean(
|
||||||
|
connectedInputs.filter((connectedInput) => {
|
||||||
|
return (
|
||||||
|
connectedInput.target === nodeId &&
|
||||||
|
connectedInput.targetHandle === input.name
|
||||||
|
);
|
||||||
|
}).length
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
key={fieldId}
|
||||||
|
position="relative"
|
||||||
|
p={2}
|
||||||
|
borderWidth={1}
|
||||||
|
borderRadius="md"
|
||||||
|
sx={{
|
||||||
|
borderColor:
|
||||||
|
!isConnected &&
|
||||||
|
['always', 'connectionOnly'].includes(
|
||||||
|
String(inputTemplate?.inputRequirement)
|
||||||
|
) &&
|
||||||
|
input.value === undefined
|
||||||
|
? 'warning.400'
|
||||||
|
: undefined,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<FormControl isDisabled={isConnected}>
|
||||||
|
<HStack justifyContent="space-between" alignItems="center">
|
||||||
|
<FormLabel>{inputTemplate?.title}</FormLabel>
|
||||||
|
<Tooltip
|
||||||
|
label={inputTemplate?.description}
|
||||||
|
placement="top"
|
||||||
|
hasArrow
|
||||||
|
shouldWrapChildren
|
||||||
|
>
|
||||||
|
<Icon color="base.400" as={FaInfoCircle} />
|
||||||
|
</Tooltip>
|
||||||
|
</HStack>
|
||||||
|
<InputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={input}
|
||||||
|
template={inputTemplate}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
{!['never', 'directOnly'].includes(
|
||||||
|
inputTemplate?.inputRequirement ?? ''
|
||||||
|
) && (
|
||||||
|
<FieldHandle
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={inputTemplate}
|
||||||
|
isValidConnection={isValidConnection}
|
||||||
|
handleType="target"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
{map(outputs).map((output, i) => {
|
||||||
|
const outputTemplate = template.current?.outputs[output.name];
|
||||||
|
|
||||||
|
const isConnected = Boolean(
|
||||||
|
connectedInputs.filter((connectedInput) => {
|
||||||
|
return (
|
||||||
|
connectedInput.source === nodeId &&
|
||||||
|
connectedInput.sourceHandle === output.name
|
||||||
|
);
|
||||||
|
}).length
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!outputTemplate) {
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
key={output.id}
|
||||||
|
position="relative"
|
||||||
|
p={2}
|
||||||
|
borderWidth={1}
|
||||||
|
borderRadius="md"
|
||||||
|
sx={{
|
||||||
|
borderColor: 'error.400',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<FormControl isDisabled={true}>
|
||||||
|
<HStack justifyContent="space-between" alignItems="center">
|
||||||
|
<FormLabel>Unknown output: {output.name}</FormLabel>
|
||||||
|
</HStack>
|
||||||
|
</FormControl>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
key={output.id}
|
||||||
|
position="relative"
|
||||||
|
p={2}
|
||||||
|
borderWidth={1}
|
||||||
|
borderRadius="md"
|
||||||
|
>
|
||||||
|
<FormControl isDisabled={isConnected}>
|
||||||
|
<FormLabel textAlign="end">
|
||||||
|
{outputTemplate?.title} Output
|
||||||
|
</FormLabel>
|
||||||
|
</FormControl>
|
||||||
|
<FieldHandle
|
||||||
|
key={output.id}
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={outputTemplate}
|
||||||
|
isValidConnection={isValidConnection}
|
||||||
|
handleType="source"
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</>
|
||||||
|
</Flex>
|
||||||
|
<Flex></Flex>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
InvocationComponent.displayName = 'InvocationComponent';
|
@ -0,0 +1,46 @@
|
|||||||
|
import 'reactflow/dist/style.css';
|
||||||
|
import { Box } from '@chakra-ui/react';
|
||||||
|
import { ReactFlowProvider } from 'reactflow';
|
||||||
|
|
||||||
|
import { Flow } from './Flow';
|
||||||
|
import { useAppSelector } from 'app/storeHooks';
|
||||||
|
import { RootState } from 'app/store';
|
||||||
|
import { buildNodesGraph } from '../util/nodesGraphBuilder/buildNodesGraph';
|
||||||
|
|
||||||
|
const NodeEditor = () => {
|
||||||
|
const state = useAppSelector((state: RootState) => state);
|
||||||
|
|
||||||
|
const graph = buildNodesGraph(state);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
position: 'relative',
|
||||||
|
width: 'full',
|
||||||
|
height: 'full',
|
||||||
|
borderRadius: 'md',
|
||||||
|
bg: 'base.850',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<ReactFlowProvider>
|
||||||
|
<Flow />
|
||||||
|
</ReactFlowProvider>
|
||||||
|
<Box
|
||||||
|
as="pre"
|
||||||
|
fontFamily="monospace"
|
||||||
|
position="absolute"
|
||||||
|
top={2}
|
||||||
|
left={2}
|
||||||
|
width="full"
|
||||||
|
height="full"
|
||||||
|
userSelect="none"
|
||||||
|
pointerEvents="none"
|
||||||
|
opacity={0.7}
|
||||||
|
>
|
||||||
|
<Box w="50%">{JSON.stringify(graph, null, 2)}</Box>
|
||||||
|
</Box>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default NodeEditor;
|
@ -0,0 +1,14 @@
|
|||||||
|
import {
|
||||||
|
ArrayInputFieldTemplate,
|
||||||
|
ArrayInputFieldValue,
|
||||||
|
} from 'features/nodes/types';
|
||||||
|
import { FaImage, FaList } from 'react-icons/fa';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
export const ArrayInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<ArrayInputFieldValue, ArrayInputFieldTemplate>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
|
return <FaList />;
|
||||||
|
};
|
@ -0,0 +1,31 @@
|
|||||||
|
import { Switch } from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch } from 'app/storeHooks';
|
||||||
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import {
|
||||||
|
BooleanInputFieldTemplate,
|
||||||
|
BooleanInputFieldValue,
|
||||||
|
} from 'features/nodes/types';
|
||||||
|
import { ChangeEvent } from 'react';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
export const BooleanInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const handleValueChanged = (e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
dispatch(
|
||||||
|
fieldValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value: e.target.checked,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Switch onChange={handleValueChanged} isChecked={field.value}></Switch>
|
||||||
|
);
|
||||||
|
};
|
@ -0,0 +1,35 @@
|
|||||||
|
import { Select } from '@chakra-ui/react';
|
||||||
|
import { useAppDispatch } from 'app/storeHooks';
|
||||||
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import {
|
||||||
|
EnumInputFieldTemplate,
|
||||||
|
EnumInputFieldValue,
|
||||||
|
} from 'features/nodes/types';
|
||||||
|
import { ChangeEvent } from 'react';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
export const EnumInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field, template } = props;
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
|
||||||
|
dispatch(
|
||||||
|
fieldValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value: e.target.value,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Select onChange={handleValueChanged} value={field.value}>
|
||||||
|
{template.options.map((option) => (
|
||||||
|
<option key={option}>{option}</option>
|
||||||
|
))}
|
||||||
|
</Select>
|
||||||
|
);
|
||||||
|
};
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user