mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
merge with main
This commit is contained in:
commit
98773b20ac
14
.github/CODEOWNERS
vendored
14
.github/CODEOWNERS
vendored
@ -2,7 +2,7 @@
|
|||||||
/.github/workflows/ @lstein @blessedcoolant
|
/.github/workflows/ @lstein @blessedcoolant
|
||||||
|
|
||||||
# documentation
|
# documentation
|
||||||
/docs/ @lstein @tildebyte @blessedcoolant
|
/docs/ @lstein @blessedcoolant @hipsterusername
|
||||||
/mkdocs.yml @lstein @blessedcoolant
|
/mkdocs.yml @lstein @blessedcoolant
|
||||||
|
|
||||||
# nodes
|
# nodes
|
||||||
@ -18,17 +18,17 @@
|
|||||||
/invokeai/version @lstein @blessedcoolant
|
/invokeai/version @lstein @blessedcoolant
|
||||||
|
|
||||||
# web ui
|
# web ui
|
||||||
/invokeai/frontend @blessedcoolant @psychedelicious @lstein
|
/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp
|
||||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein
|
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
|
||||||
|
|
||||||
# generation, model management, postprocessing
|
# generation, model management, postprocessing
|
||||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
|
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779
|
||||||
|
|
||||||
# front ends
|
# front ends
|
||||||
/invokeai/frontend/CLI @lstein
|
/invokeai/frontend/CLI @lstein
|
||||||
/invokeai/frontend/install @lstein @ebr
|
/invokeai/frontend/install @lstein @ebr
|
||||||
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
|
/invokeai/frontend/merge @lstein @blessedcoolant
|
||||||
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
|
/invokeai/frontend/training @lstein @blessedcoolant
|
||||||
/invokeai/frontend/web @psychedelicious @blessedcoolant
|
/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ import os
|
|||||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
from invokeai.app.services.images import ImageService
|
from invokeai.app.services.images import ImageService
|
||||||
from invokeai.app.services.metadata import CoreMetadataService
|
from invokeai.app.services.metadata import CoreMetadataService
|
||||||
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
@ -67,7 +68,7 @@ class ApiDependencies:
|
|||||||
metadata = CoreMetadataService()
|
metadata = CoreMetadataService()
|
||||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
|
names = SimpleNameService()
|
||||||
latents = ForwardCacheLatentsStorage(
|
latents = ForwardCacheLatentsStorage(
|
||||||
DiskLatentsStorage(f"{output_folder}/latents")
|
DiskLatentsStorage(f"{output_folder}/latents")
|
||||||
)
|
)
|
||||||
@ -78,6 +79,7 @@ class ApiDependencies:
|
|||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
url=urls,
|
url=urls,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
names=names,
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,39 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageType
|
|
||||||
|
|
||||||
|
|
||||||
class ImageResponseMetadata(BaseModel):
|
|
||||||
"""An image's metadata. Used only in HTTP responses."""
|
|
||||||
|
|
||||||
created: int = Field(description="The creation timestamp of the image")
|
|
||||||
width: int = Field(description="The width of the image in pixels")
|
|
||||||
height: int = Field(description="The height of the image in pixels")
|
|
||||||
# invokeai: Optional[InvokeAIMetadata] = Field(
|
|
||||||
# description="The image's InvokeAI-specific metadata"
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
class ImageResponse(BaseModel):
|
|
||||||
"""The response type for images"""
|
|
||||||
|
|
||||||
image_type: ImageType = Field(description="The type of the image")
|
|
||||||
image_name: str = Field(description="The name of the image")
|
|
||||||
image_url: str = Field(description="The url of the image")
|
|
||||||
thumbnail_url: str = Field(description="The url of the image's thumbnail")
|
|
||||||
metadata: ImageResponseMetadata = Field(description="The image's metadata")
|
|
||||||
|
|
||||||
|
|
||||||
class ProgressImage(BaseModel):
|
|
||||||
"""The progress image sent intermittently during processing"""
|
|
||||||
|
|
||||||
width: int = Field(description="The effective width of the image in pixels")
|
|
||||||
height: int = Field(description="The effective height of the image in pixels")
|
|
||||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
|
||||||
|
|
||||||
|
|
||||||
class SavedImage(BaseModel):
|
|
||||||
image_name: str = Field(description="The name of the saved image")
|
|
||||||
thumbnail_name: str = Field(description="The name of the saved thumbnail")
|
|
||||||
created: int = Field(description="The created timestamp of the saved image")
|
|
@ -1,13 +1,19 @@
|
|||||||
import io
|
import io
|
||||||
from fastapi import HTTPException, Path, Query, Request, Response, UploadFile
|
from typing import Optional
|
||||||
|
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from invokeai.app.models.image import (
|
from invokeai.app.models.image import (
|
||||||
ImageCategory,
|
ImageCategory,
|
||||||
ImageType,
|
ResourceOrigin,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
|
from invokeai.app.services.models.image_record import (
|
||||||
|
ImageDTO,
|
||||||
|
ImageRecordChanges,
|
||||||
|
ImageUrlsDTO,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.models.image_record import ImageDTO, ImageUrlsDTO
|
|
||||||
from invokeai.app.services.item_storage import PaginatedResults
|
from invokeai.app.services.item_storage import PaginatedResults
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
@ -27,10 +33,13 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
|||||||
)
|
)
|
||||||
async def upload_image(
|
async def upload_image(
|
||||||
file: UploadFile,
|
file: UploadFile,
|
||||||
image_type: ImageType,
|
|
||||||
request: Request,
|
request: Request,
|
||||||
response: Response,
|
response: Response,
|
||||||
image_category: ImageCategory = ImageCategory.GENERAL,
|
image_category: ImageCategory = Query(description="The category of the image"),
|
||||||
|
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
|
||||||
|
session_id: Optional[str] = Query(
|
||||||
|
default=None, description="The session ID associated with this upload, if any"
|
||||||
|
),
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Uploads an image"""
|
"""Uploads an image"""
|
||||||
if not file.content_type.startswith("image"):
|
if not file.content_type.startswith("image"):
|
||||||
@ -46,9 +55,11 @@ async def upload_image(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
image_dto = ApiDependencies.invoker.services.images.create(
|
image_dto = ApiDependencies.invoker.services.images.create(
|
||||||
pil_image,
|
image=pil_image,
|
||||||
image_type,
|
image_origin=ResourceOrigin.EXTERNAL,
|
||||||
image_category,
|
image_category=image_category,
|
||||||
|
session_id=session_id,
|
||||||
|
is_intermediate=is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
response.status_code = 201
|
response.status_code = 201
|
||||||
@ -59,41 +70,61 @@ async def upload_image(
|
|||||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||||
|
|
||||||
|
|
||||||
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
|
@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
|
||||||
async def delete_image(
|
async def delete_image(
|
||||||
image_type: ImageType = Query(description="The type of image to delete"),
|
image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
|
||||||
image_name: str = Path(description="The name of the image to delete"),
|
image_name: str = Path(description="The name of the image to delete"),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Deletes an image"""
|
"""Deletes an image"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ApiDependencies.invoker.services.images.delete(image_type, image_name)
|
ApiDependencies.invoker.services.images.delete(image_origin, image_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# TODO: Does this need any exception handling at all?
|
# TODO: Does this need any exception handling at all?
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@images_router.patch(
|
||||||
|
"/{image_origin}/{image_name}",
|
||||||
|
operation_id="update_image",
|
||||||
|
response_model=ImageDTO,
|
||||||
|
)
|
||||||
|
async def update_image(
|
||||||
|
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
|
||||||
|
image_name: str = Path(description="The name of the image to update"),
|
||||||
|
image_changes: ImageRecordChanges = Body(
|
||||||
|
description="The changes to apply to the image"
|
||||||
|
),
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""Updates an image"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return ApiDependencies.invoker.services.images.update(
|
||||||
|
image_origin, image_name, image_changes
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail="Failed to update image")
|
||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_type}/{image_name}/metadata",
|
"/{image_origin}/{image_name}/metadata",
|
||||||
operation_id="get_image_metadata",
|
operation_id="get_image_metadata",
|
||||||
response_model=ImageDTO,
|
response_model=ImageDTO,
|
||||||
)
|
)
|
||||||
async def get_image_metadata(
|
async def get_image_metadata(
|
||||||
image_type: ImageType = Path(description="The type of image to get"),
|
image_origin: ResourceOrigin = Path(description="The origin of image to get"),
|
||||||
image_name: str = Path(description="The name of image to get"),
|
image_name: str = Path(description="The name of image to get"),
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Gets an image's metadata"""
|
"""Gets an image's metadata"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return ApiDependencies.invoker.services.images.get_dto(
|
return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
|
||||||
image_type, image_name
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_type}/{image_name}",
|
"/{image_origin}/{image_name}",
|
||||||
operation_id="get_image_full",
|
operation_id="get_image_full",
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
responses={
|
responses={
|
||||||
@ -105,7 +136,7 @@ async def get_image_metadata(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_image_full(
|
async def get_image_full(
|
||||||
image_type: ImageType = Path(
|
image_origin: ResourceOrigin = Path(
|
||||||
description="The type of full-resolution image file to get"
|
description="The type of full-resolution image file to get"
|
||||||
),
|
),
|
||||||
image_name: str = Path(description="The name of full-resolution image file to get"),
|
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||||
@ -113,9 +144,7 @@ async def get_image_full(
|
|||||||
"""Gets a full-resolution image file"""
|
"""Gets a full-resolution image file"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
path = ApiDependencies.invoker.services.images.get_path(
|
path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
|
||||||
image_type, image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
@ -131,7 +160,7 @@ async def get_image_full(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_type}/{image_name}/thumbnail",
|
"/{image_origin}/{image_name}/thumbnail",
|
||||||
operation_id="get_image_thumbnail",
|
operation_id="get_image_thumbnail",
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
responses={
|
responses={
|
||||||
@ -143,14 +172,14 @@ async def get_image_full(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_image_thumbnail(
|
async def get_image_thumbnail(
|
||||||
image_type: ImageType = Path(description="The type of thumbnail image file to get"),
|
image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
|
||||||
image_name: str = Path(description="The name of thumbnail image file to get"),
|
image_name: str = Path(description="The name of thumbnail image file to get"),
|
||||||
) -> FileResponse:
|
) -> FileResponse:
|
||||||
"""Gets a thumbnail image file"""
|
"""Gets a thumbnail image file"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
path = ApiDependencies.invoker.services.images.get_path(
|
path = ApiDependencies.invoker.services.images.get_path(
|
||||||
image_type, image_name, thumbnail=True
|
image_origin, image_name, thumbnail=True
|
||||||
)
|
)
|
||||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
@ -163,25 +192,25 @@ async def get_image_thumbnail(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_type}/{image_name}/urls",
|
"/{image_origin}/{image_name}/urls",
|
||||||
operation_id="get_image_urls",
|
operation_id="get_image_urls",
|
||||||
response_model=ImageUrlsDTO,
|
response_model=ImageUrlsDTO,
|
||||||
)
|
)
|
||||||
async def get_image_urls(
|
async def get_image_urls(
|
||||||
image_type: ImageType = Path(description="The type of the image whose URL to get"),
|
image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
|
||||||
image_name: str = Path(description="The name of the image whose URL to get"),
|
image_name: str = Path(description="The name of the image whose URL to get"),
|
||||||
) -> ImageUrlsDTO:
|
) -> ImageUrlsDTO:
|
||||||
"""Gets an image and thumbnail URL"""
|
"""Gets an image and thumbnail URL"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image_url = ApiDependencies.invoker.services.images.get_url(
|
image_url = ApiDependencies.invoker.services.images.get_url(
|
||||||
image_type, image_name
|
image_origin, image_name
|
||||||
)
|
)
|
||||||
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
||||||
image_type, image_name, thumbnail=True
|
image_origin, image_name, thumbnail=True
|
||||||
)
|
)
|
||||||
return ImageUrlsDTO(
|
return ImageUrlsDTO(
|
||||||
image_type=image_type,
|
image_origin=image_origin,
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image_url=image_url,
|
image_url=image_url,
|
||||||
thumbnail_url=thumbnail_url,
|
thumbnail_url=thumbnail_url,
|
||||||
@ -193,23 +222,29 @@ async def get_image_urls(
|
|||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_images_with_metadata",
|
operation_id="list_images_with_metadata",
|
||||||
response_model=PaginatedResults[ImageDTO],
|
response_model=OffsetPaginatedResults[ImageDTO],
|
||||||
)
|
)
|
||||||
async def list_images_with_metadata(
|
async def list_images_with_metadata(
|
||||||
image_type: ImageType = Query(description="The type of images to list"),
|
image_origin: Optional[ResourceOrigin] = Query(
|
||||||
image_category: ImageCategory = Query(description="The kind of images to list"),
|
default=None, description="The origin of images to list"
|
||||||
page: int = Query(default=0, description="The page of image metadata to get"),
|
|
||||||
per_page: int = Query(
|
|
||||||
default=10, description="The number of image metadata per page"
|
|
||||||
),
|
),
|
||||||
) -> PaginatedResults[ImageDTO]:
|
categories: Optional[list[ImageCategory]] = Query(
|
||||||
"""Gets a list of images with metadata"""
|
default=None, description="The categories of image to include"
|
||||||
|
),
|
||||||
|
is_intermediate: Optional[bool] = Query(
|
||||||
|
default=None, description="Whether to list intermediate images"
|
||||||
|
),
|
||||||
|
offset: int = Query(default=0, description="The page offset"),
|
||||||
|
limit: int = Query(default=10, description="The number of images per page"),
|
||||||
|
) -> OffsetPaginatedResults[ImageDTO]:
|
||||||
|
"""Gets a list of images"""
|
||||||
|
|
||||||
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
||||||
image_type,
|
offset,
|
||||||
image_category,
|
limit,
|
||||||
page,
|
image_origin,
|
||||||
per_page,
|
categories,
|
||||||
|
is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return image_dtos
|
return image_dtos
|
||||||
|
@ -12,11 +12,10 @@ from pydantic import BaseModel, ValidationError
|
|||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
import invokeai.version
|
|
||||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
from invokeai.app.services.images import ImageService
|
from invokeai.app.services.images import ImageService
|
||||||
from invokeai.app.services.metadata import (CoreMetadataService,
|
from invokeai.app.services.metadata import CoreMetadataService
|
||||||
PngMetadataService)
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
|
|
||||||
from .cli.commands import (BaseCommand, CliContext, ExitCli,
|
from .cli.commands import (BaseCommand, CliContext, ExitCli,
|
||||||
@ -232,6 +231,7 @@ def invoke_cli():
|
|||||||
metadata = CoreMetadataService()
|
metadata = CoreMetadataService()
|
||||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
|
names = SimpleNameService()
|
||||||
|
|
||||||
images = ImageService(
|
images = ImageService(
|
||||||
image_record_storage=image_record_storage,
|
image_record_storage=image_record_storage,
|
||||||
@ -239,6 +239,7 @@ def invoke_cli():
|
|||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
url=urls,
|
url=urls,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
names=names,
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -78,6 +78,7 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
|
|
||||||
#fmt: off
|
#fmt: off
|
||||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
||||||
|
is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
|
||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
@ -95,6 +96,7 @@ class UIConfig(TypedDict, total=False):
|
|||||||
"image",
|
"image",
|
||||||
"latents",
|
"latents",
|
||||||
"model",
|
"model",
|
||||||
|
"control",
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
tags: List[str]
|
tags: List[str]
|
||||||
|
@ -22,6 +22,14 @@ class IntCollectionOutput(BaseInvocationOutput):
|
|||||||
# Outputs
|
# Outputs
|
||||||
collection: list[int] = Field(default=[], description="The int collection")
|
collection: list[int] = Field(default=[], description="The int collection")
|
||||||
|
|
||||||
|
class FloatCollectionOutput(BaseInvocationOutput):
|
||||||
|
"""A collection of floats"""
|
||||||
|
|
||||||
|
type: Literal["float_collection"] = "float_collection"
|
||||||
|
|
||||||
|
# Outputs
|
||||||
|
collection: list[float] = Field(default=[], description="The float collection")
|
||||||
|
|
||||||
|
|
||||||
class RangeInvocation(BaseInvocation):
|
class RangeInvocation(BaseInvocation):
|
||||||
"""Creates a range of numbers from start to stop with step"""
|
"""Creates a range of numbers from start to stop with step"""
|
||||||
|
428
invokeai/app/invocations/controlnet_image_processors.py
Normal file
428
invokeai/app/invocations/controlnet_image_processors.py
Normal file
@ -0,0 +1,428 @@
|
|||||||
|
# InvokeAI nodes for ControlNet image preprocessors
|
||||||
|
# initial implementation by Gregg Helt, 2023
|
||||||
|
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from typing import Literal, Optional, Union, List
|
||||||
|
from PIL import Image, ImageFilter, ImageOps
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from ..models.image import ImageField, ImageCategory, ResourceOrigin
|
||||||
|
from .baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
BaseInvocationOutput,
|
||||||
|
InvocationContext,
|
||||||
|
InvocationConfig,
|
||||||
|
)
|
||||||
|
from controlnet_aux import (
|
||||||
|
CannyDetector,
|
||||||
|
HEDdetector,
|
||||||
|
LineartDetector,
|
||||||
|
LineartAnimeDetector,
|
||||||
|
MidasDetector,
|
||||||
|
MLSDdetector,
|
||||||
|
NormalBaeDetector,
|
||||||
|
OpenposeDetector,
|
||||||
|
PidiNetDetector,
|
||||||
|
ContentShuffleDetector,
|
||||||
|
ZoeDetector,
|
||||||
|
MediapipeFaceDetector,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .image import ImageOutput, PILInvocationConfig
|
||||||
|
|
||||||
|
CONTROLNET_DEFAULT_MODELS = [
|
||||||
|
###########################################
|
||||||
|
# lllyasviel sd v1.5, ControlNet v1.0 models
|
||||||
|
##############################################
|
||||||
|
"lllyasviel/sd-controlnet-canny",
|
||||||
|
"lllyasviel/sd-controlnet-depth",
|
||||||
|
"lllyasviel/sd-controlnet-hed",
|
||||||
|
"lllyasviel/sd-controlnet-seg",
|
||||||
|
"lllyasviel/sd-controlnet-openpose",
|
||||||
|
"lllyasviel/sd-controlnet-scribble",
|
||||||
|
"lllyasviel/sd-controlnet-normal",
|
||||||
|
"lllyasviel/sd-controlnet-mlsd",
|
||||||
|
|
||||||
|
#############################################
|
||||||
|
# lllyasviel sd v1.5, ControlNet v1.1 models
|
||||||
|
#############################################
|
||||||
|
"lllyasviel/control_v11p_sd15_canny",
|
||||||
|
"lllyasviel/control_v11p_sd15_openpose",
|
||||||
|
"lllyasviel/control_v11p_sd15_seg",
|
||||||
|
# "lllyasviel/control_v11p_sd15_depth", # broken
|
||||||
|
"lllyasviel/control_v11f1p_sd15_depth",
|
||||||
|
"lllyasviel/control_v11p_sd15_normalbae",
|
||||||
|
"lllyasviel/control_v11p_sd15_scribble",
|
||||||
|
"lllyasviel/control_v11p_sd15_mlsd",
|
||||||
|
"lllyasviel/control_v11p_sd15_softedge",
|
||||||
|
"lllyasviel/control_v11p_sd15s2_lineart_anime",
|
||||||
|
"lllyasviel/control_v11p_sd15_lineart",
|
||||||
|
"lllyasviel/control_v11p_sd15_inpaint",
|
||||||
|
# "lllyasviel/control_v11u_sd15_tile",
|
||||||
|
# problem (temporary?) with huffingface "lllyasviel/control_v11u_sd15_tile",
|
||||||
|
# so for now replace "lllyasviel/control_v11f1e_sd15_tile",
|
||||||
|
"lllyasviel/control_v11e_sd15_shuffle",
|
||||||
|
"lllyasviel/control_v11e_sd15_ip2p",
|
||||||
|
"lllyasviel/control_v11f1e_sd15_tile",
|
||||||
|
|
||||||
|
#################################################
|
||||||
|
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
|
||||||
|
##################################################
|
||||||
|
"thibaud/controlnet-sd21-openpose-diffusers",
|
||||||
|
"thibaud/controlnet-sd21-canny-diffusers",
|
||||||
|
"thibaud/controlnet-sd21-depth-diffusers",
|
||||||
|
"thibaud/controlnet-sd21-scribble-diffusers",
|
||||||
|
"thibaud/controlnet-sd21-hed-diffusers",
|
||||||
|
"thibaud/controlnet-sd21-zoedepth-diffusers",
|
||||||
|
"thibaud/controlnet-sd21-color-diffusers",
|
||||||
|
"thibaud/controlnet-sd21-openposev2-diffusers",
|
||||||
|
"thibaud/controlnet-sd21-lineart-diffusers",
|
||||||
|
"thibaud/controlnet-sd21-normalbae-diffusers",
|
||||||
|
"thibaud/controlnet-sd21-ade20k-diffusers",
|
||||||
|
|
||||||
|
##############################################
|
||||||
|
# ControlNetMediaPipeface, ControlNet v1.1
|
||||||
|
##############################################
|
||||||
|
# ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5
|
||||||
|
# diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
|
||||||
|
# hacked t2l to split to model & subfolder if format is "model,subfolder"
|
||||||
|
"CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5
|
||||||
|
"CrucibleAI/ControlNetMediaPipeFace", # SD 2.1?
|
||||||
|
]
|
||||||
|
|
||||||
|
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
||||||
|
|
||||||
|
class ControlField(BaseModel):
|
||||||
|
image: ImageField = Field(default=None, description="processed image")
|
||||||
|
control_model: Optional[str] = Field(default=None, description="control model used")
|
||||||
|
control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||||
|
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||||
|
description="% of total steps at which controlnet is first applied")
|
||||||
|
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||||
|
description="% of total steps at which controlnet is last applied")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ControlOutput(BaseInvocationOutput):
|
||||||
|
"""node output for ControlNet info"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["control_output"] = "control_output"
|
||||||
|
control: ControlField = Field(default=None, description="The control info dict")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetInvocation(BaseInvocation):
|
||||||
|
"""Collects ControlNet info to pass to other nodes"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["controlnet"] = "controlnet"
|
||||||
|
# Inputs
|
||||||
|
image: ImageField = Field(default=None, description="image to process")
|
||||||
|
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
|
||||||
|
description="control model used")
|
||||||
|
control_weight: float = Field(default=1.0, ge=0, le=1, description="weight given to controlnet")
|
||||||
|
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
|
||||||
|
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||||
|
description="% of total steps at which controlnet is first applied")
|
||||||
|
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||||
|
description="% of total steps at which controlnet is last applied")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||||
|
|
||||||
|
return ControlOutput(
|
||||||
|
control=ControlField(
|
||||||
|
image=self.image,
|
||||||
|
control_model=self.control_model,
|
||||||
|
control_weight=self.control_weight,
|
||||||
|
begin_step_percent=self.begin_step_percent,
|
||||||
|
end_step_percent=self.end_step_percent,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: move image processors to separate file (image_analysis.py
|
||||||
|
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
|
"""Base class for invocations that preprocess images for ControlNet"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["image_processor"] = "image_processor"
|
||||||
|
# Inputs
|
||||||
|
image: ImageField = Field(default=None, description="image to process")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
def run_processor(self, image):
|
||||||
|
# superclass just passes through image without processing
|
||||||
|
return image
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
|
||||||
|
raw_image = context.services.images.get_pil_image(
|
||||||
|
self.image.image_origin, self.image.image_name
|
||||||
|
)
|
||||||
|
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||||
|
processed_image = self.run_processor(raw_image)
|
||||||
|
|
||||||
|
# FIXME: what happened to image metadata?
|
||||||
|
# metadata = context.services.metadata.build_metadata(
|
||||||
|
# session_id=context.graph_execution_state_id, node=self
|
||||||
|
# )
|
||||||
|
|
||||||
|
# currently can't see processed image in node UI without a showImage node,
|
||||||
|
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=processed_image,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
image_category=ImageCategory.CONTROL,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
node_id=self.id,
|
||||||
|
is_intermediate=self.is_intermediate
|
||||||
|
)
|
||||||
|
|
||||||
|
"""Builds an ImageOutput and its ImageField"""
|
||||||
|
processed_image_field = ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
image_origin=image_dto.image_origin,
|
||||||
|
)
|
||||||
|
return ImageOutput(
|
||||||
|
image=processed_image_field,
|
||||||
|
# width=processed_image.width,
|
||||||
|
width = image_dto.width,
|
||||||
|
# height=processed_image.height,
|
||||||
|
height = image_dto.height,
|
||||||
|
# mode=processed_image.mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
"""Canny edge detection for ControlNet"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["canny_image_processor"] = "canny_image_processor"
|
||||||
|
# Input
|
||||||
|
low_threshold: float = Field(default=100, ge=0, description="low threshold of Canny pixel gradient")
|
||||||
|
high_threshold: float = Field(default=200, ge=0, description="high threshold of Canny pixel gradient")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def run_processor(self, image):
|
||||||
|
canny_processor = CannyDetector()
|
||||||
|
processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
class HedImageprocessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
"""Applies HED edge detection to image"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["hed_image_processor"] = "hed_image_processor"
|
||||||
|
# Inputs
|
||||||
|
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
||||||
|
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
||||||
|
# safe not supported in controlnet_aux v0.0.3
|
||||||
|
# safe: bool = Field(default=False, description="whether to use safe mode")
|
||||||
|
scribble: bool = Field(default=False, description="whether to use scribble mode")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def run_processor(self, image):
|
||||||
|
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
||||||
|
processed_image = hed_processor(image,
|
||||||
|
detect_resolution=self.detect_resolution,
|
||||||
|
image_resolution=self.image_resolution,
|
||||||
|
# safe not supported in controlnet_aux v0.0.3
|
||||||
|
# safe=self.safe,
|
||||||
|
scribble=self.scribble,
|
||||||
|
)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
"""Applies line art processing to image"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
||||||
|
# Inputs
|
||||||
|
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
||||||
|
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
||||||
|
coarse: bool = Field(default=False, description="whether to use coarse mode")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def run_processor(self, image):
|
||||||
|
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
|
processed_image = lineart_processor(image,
|
||||||
|
detect_resolution=self.detect_resolution,
|
||||||
|
image_resolution=self.image_resolution,
|
||||||
|
coarse=self.coarse)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
"""Applies line art anime processing to image"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
||||||
|
# Inputs
|
||||||
|
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
||||||
|
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def run_processor(self, image):
|
||||||
|
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
|
processed_image = processor(image,
|
||||||
|
detect_resolution=self.detect_resolution,
|
||||||
|
image_resolution=self.image_resolution,
|
||||||
|
)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
"""Applies Openpose processing to image"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
||||||
|
# Inputs
|
||||||
|
hand_and_face: bool = Field(default=False, description="whether to use hands and face mode")
|
||||||
|
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
||||||
|
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def run_processor(self, image):
|
||||||
|
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
|
processed_image = openpose_processor(image,
|
||||||
|
detect_resolution=self.detect_resolution,
|
||||||
|
image_resolution=self.image_resolution,
|
||||||
|
hand_and_face=self.hand_and_face,
|
||||||
|
)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
"""Applies Midas depth processing to image"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
||||||
|
# Inputs
|
||||||
|
a_mult: float = Field(default=2.0, ge=0, description="Midas parameter a = amult * PI")
|
||||||
|
bg_th: float = Field(default=0.1, ge=0, description="Midas parameter bg_th")
|
||||||
|
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||||
|
# depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def run_processor(self, image):
|
||||||
|
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
|
processed_image = midas_processor(image,
|
||||||
|
a=np.pi * self.a_mult,
|
||||||
|
bg_th=self.bg_th,
|
||||||
|
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||||
|
# depth_and_normal=self.depth_and_normal,
|
||||||
|
)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
"""Applies NormalBae processing to image"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
||||||
|
# Inputs
|
||||||
|
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
||||||
|
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def run_processor(self, image):
|
||||||
|
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
|
processed_image = normalbae_processor(image,
|
||||||
|
detect_resolution=self.detect_resolution,
|
||||||
|
image_resolution=self.image_resolution)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
"""Applies MLSD processing to image"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
||||||
|
# Inputs
|
||||||
|
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
||||||
|
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
||||||
|
thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter thr_v")
|
||||||
|
thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter thr_d")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def run_processor(self, image):
|
||||||
|
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||||
|
processed_image = mlsd_processor(image,
|
||||||
|
detect_resolution=self.detect_resolution,
|
||||||
|
image_resolution=self.image_resolution,
|
||||||
|
thr_v=self.thr_v,
|
||||||
|
thr_d=self.thr_d)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
"""Applies PIDI processing to image"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
||||||
|
# Inputs
|
||||||
|
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
||||||
|
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
||||||
|
safe: bool = Field(default=False, description="whether to use safe mode")
|
||||||
|
scribble: bool = Field(default=False, description="whether to use scribble mode")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def run_processor(self, image):
|
||||||
|
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
|
processed_image = pidi_processor(image,
|
||||||
|
detect_resolution=self.detect_resolution,
|
||||||
|
image_resolution=self.image_resolution,
|
||||||
|
safe=self.safe,
|
||||||
|
scribble=self.scribble)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
"""Applies content shuffle processing to image"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
||||||
|
# Inputs
|
||||||
|
detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
|
||||||
|
image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
|
||||||
|
h: Union[int | None] = Field(default=512, ge=0, description="content shuffle h parameter")
|
||||||
|
w: Union[int | None] = Field(default=512, ge=0, description="content shuffle w parameter")
|
||||||
|
f: Union[int | None] = Field(default=256, ge=0, description="cont")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def run_processor(self, image):
|
||||||
|
content_shuffle_processor = ContentShuffleDetector()
|
||||||
|
processed_image = content_shuffle_processor(image,
|
||||||
|
detect_resolution=self.detect_resolution,
|
||||||
|
image_resolution=self.image_resolution,
|
||||||
|
h=self.h,
|
||||||
|
w=self.w,
|
||||||
|
f=self.f
|
||||||
|
)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||||
|
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
"""Applies Zoe depth processing to image"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def run_processor(self, image):
|
||||||
|
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
|
processed_image = zoe_depth_processor(image)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
"""Applies mediapipe face processing to image"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
||||||
|
# Inputs
|
||||||
|
max_faces: int = Field(default=1, ge=1, description="maximum number of faces to detect")
|
||||||
|
min_confidence: float = Field(default=0.5, ge=0, le=1, description="minimum confidence for face detection")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def run_processor(self, image):
|
||||||
|
mediapipe_face_processor = MediapipeFaceDetector()
|
||||||
|
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
|
||||||
|
return processed_image
|
@ -7,7 +7,7 @@ import numpy
|
|||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
|
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
|
|
||||||
@ -37,10 +37,10 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
mask = context.services.images.get_pil_image(
|
mask = context.services.images.get_pil_image(
|
||||||
self.mask.image_type, self.mask.image_name
|
self.mask.image_origin, self.mask.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert to cv image/mask
|
# Convert to cv image/mask
|
||||||
@ -57,16 +57,17 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=image_inpainted,
|
image=image_inpainted,
|
||||||
image_type=ImageType.INTERMEDIATE,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
|
@ -3,16 +3,20 @@
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Literal, Optional, Union, get_args
|
from typing import Literal, Optional, Union, get_args
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers import ControlNetModel
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageType, ColorField, ImageField
|
from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
|
||||||
|
ResourceOrigin)
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.backend.generator.inpaint import infill_methods
|
from invokeai.backend.generator.inpaint import infill_methods
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
|
||||||
from .image import ImageOutput
|
from ...backend.generator import Img2Img, Inpaint, InvokeAIGenerator, Txt2Img
|
||||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ..util.step_callback import stable_diffusion_step_callback
|
from ..util.step_callback import stable_diffusion_step_callback
|
||||||
|
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
||||||
|
from .image import ImageOutput
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||||
@ -53,6 +57,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||||
|
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||||
|
control_model: Optional[str] = Field(default=None, description="The control model to use")
|
||||||
|
control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||||
@ -73,17 +80,35 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
model = context.services.model_manager.get_model(self.model,node=self,context=context)
|
model = context.services.model_manager.get_model(self.model,node=self,context=context)
|
||||||
|
|
||||||
|
# loading controlnet image (currently requires pre-processed image)
|
||||||
|
control_image = (
|
||||||
|
None if self.control_image is None
|
||||||
|
else context.services.images.get_pil_image(
|
||||||
|
self.control_image.image_origin, self.control_image.image_name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# loading controlnet model
|
||||||
|
if (self.control_model is None or self.control_model==''):
|
||||||
|
control_model = None
|
||||||
|
else:
|
||||||
|
# FIXME: change this to dropdown menu?
|
||||||
|
# FIXME: generalize so don't have to hardcode torch_dtype and device
|
||||||
|
control_model = ControlNetModel.from_pretrained(self.control_model,
|
||||||
|
torch_dtype=torch.float16).to("cuda")
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# Get the source node id (we are invoking the prepared node)
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(
|
graph_execution_state = context.services.graph_execution_manager.get(
|
||||||
context.graph_execution_state_id
|
context.graph_execution_state_id
|
||||||
)
|
)
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
outputs = Txt2Img(model).generate(
|
txt2img = Txt2Img(model, control_model=control_model)
|
||||||
|
outputs = txt2img.generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||||
|
control_image=control_image,
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt"}
|
exclude={"prompt", "control_image" }
|
||||||
), # Shorthand for passing all of the parameters above manually
|
), # Shorthand for passing all of the parameters above manually
|
||||||
)
|
)
|
||||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||||
@ -92,16 +117,17 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=generate_output.image,
|
image=generate_output.image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -141,7 +167,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
None
|
None
|
||||||
if self.image is None
|
if self.image is None
|
||||||
else context.services.images.get_pil_image(
|
else context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -172,16 +198,17 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=generator_output.image,
|
image=generator_output.image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -253,13 +280,13 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
None
|
None
|
||||||
if self.image is None
|
if self.image is None
|
||||||
else context.services.images.get_pil_image(
|
else context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
mask = (
|
mask = (
|
||||||
None
|
None
|
||||||
if self.mask is None
|
if self.mask is None
|
||||||
else context.services.images.get_pil_image(self.mask.image_type, self.mask.image_name)
|
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
@ -287,16 +314,17 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=generator_output.image,
|
image=generator_output.image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
|
@ -7,7 +7,7 @@ import numpy
|
|||||||
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from ..models.image import ImageCategory, ImageField, ImageType
|
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
@ -72,12 +72,12 @@ class LoadImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_type, self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=self.image.image_name,
|
image_name=self.image.image_name,
|
||||||
image_type=self.image.image_type,
|
image_origin=self.image.image_origin,
|
||||||
),
|
),
|
||||||
width=image.width,
|
width=image.width,
|
||||||
height=image.height,
|
height=image.height,
|
||||||
@ -96,7 +96,7 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
if image:
|
if image:
|
||||||
image.show()
|
image.show()
|
||||||
@ -106,7 +106,7 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=self.image.image_name,
|
image_name=self.image.image_name,
|
||||||
image_type=self.image.image_type,
|
image_origin=self.image.image_origin,
|
||||||
),
|
),
|
||||||
width=image.width,
|
width=image.width,
|
||||||
height=image.height,
|
height=image.height,
|
||||||
@ -129,7 +129,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
image_crop = Image.new(
|
image_crop = Image.new(
|
||||||
@ -139,16 +139,17 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=image_crop,
|
image=image_crop,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -171,17 +172,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
base_image = context.services.images.get_pil_image(
|
base_image = context.services.images.get_pil_image(
|
||||||
self.base_image.image_type, self.base_image.image_name
|
self.base_image.image_origin, self.base_image.image_name
|
||||||
)
|
)
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
mask = (
|
mask = (
|
||||||
None
|
None
|
||||||
if self.mask is None
|
if self.mask is None
|
||||||
else ImageOps.invert(
|
else ImageOps.invert(
|
||||||
context.services.images.get_pil_image(
|
context.services.images.get_pil_image(
|
||||||
self.mask.image_type, self.mask.image_name
|
self.mask.image_origin, self.mask.image_name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -200,16 +201,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=new_image,
|
image=new_image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -229,7 +231,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
image_mask = image.split()[-1]
|
image_mask = image.split()[-1]
|
||||||
@ -238,15 +240,16 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=image_mask,
|
image=image_mask,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.MASK,
|
image_category=ImageCategory.MASK,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return MaskOutput(
|
return MaskOutput(
|
||||||
mask=ImageField(
|
mask=ImageField(
|
||||||
image_type=image_dto.image_type, image_name=image_dto.image_name
|
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -266,25 +269,26 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image1 = context.services.images.get_pil_image(
|
image1 = context.services.images.get_pil_image(
|
||||||
self.image1.image_type, self.image1.image_name
|
self.image1.image_origin, self.image1.image_name
|
||||||
)
|
)
|
||||||
image2 = context.services.images.get_pil_image(
|
image2 = context.services.images.get_pil_image(
|
||||||
self.image2.image_type, self.image2.image_name
|
self.image2.image_origin, self.image2.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
multiply_image = ImageChops.multiply(image1, image2)
|
multiply_image = ImageChops.multiply(image1, image2)
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=multiply_image,
|
image=multiply_image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_type=image_dto.image_type, image_name=image_dto.image_name
|
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -307,22 +311,23 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
channel_image = image.getchannel(self.channel)
|
channel_image = image.getchannel(self.channel)
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=channel_image,
|
image=channel_image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_type=image_dto.image_type, image_name=image_dto.image_name
|
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -345,22 +350,23 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
converted_image = image.convert(self.mode)
|
converted_image = image.convert(self.mode)
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=converted_image,
|
image=converted_image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_type=image_dto.image_type, image_name=image_dto.image_name
|
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -381,7 +387,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
blur = (
|
blur = (
|
||||||
@ -393,16 +399,126 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=blur_image,
|
image=blur_image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
PIL_RESAMPLING_MODES = Literal[
|
||||||
|
"nearest",
|
||||||
|
"box",
|
||||||
|
"bilinear",
|
||||||
|
"hamming",
|
||||||
|
"bicubic",
|
||||||
|
"lanczos",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
PIL_RESAMPLING_MAP = {
|
||||||
|
"nearest": Image.Resampling.NEAREST,
|
||||||
|
"box": Image.Resampling.BOX,
|
||||||
|
"bilinear": Image.Resampling.BILINEAR,
|
||||||
|
"hamming": Image.Resampling.HAMMING,
|
||||||
|
"bicubic": Image.Resampling.BICUBIC,
|
||||||
|
"lanczos": Image.Resampling.LANCZOS,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
|
"""Resizes an image to specific dimensions"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["img_resize"] = "img_resize"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image: Union[ImageField, None] = Field(default=None, description="The image to resize")
|
||||||
|
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||||
|
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||||
|
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image = context.services.images.get_pil_image(
|
||||||
|
self.image.image_origin, self.image.image_name
|
||||||
|
)
|
||||||
|
|
||||||
|
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||||
|
|
||||||
|
resize_image = image.resize(
|
||||||
|
(self.width, self.height),
|
||||||
|
resample=resample_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=resize_image,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
image_origin=image_dto.image_origin,
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
|
"""Scales an image by a factor"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["img_scale"] = "img_scale"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image: Union[ImageField, None] = Field(default=None, description="The image to scale")
|
||||||
|
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
|
||||||
|
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image = context.services.images.get_pil_image(
|
||||||
|
self.image.image_origin, self.image.image_name
|
||||||
|
)
|
||||||
|
|
||||||
|
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||||
|
width = int(image.width * self.scale_factor)
|
||||||
|
height = int(image.height * self.scale_factor)
|
||||||
|
|
||||||
|
resize_image = image.resize(
|
||||||
|
(width, height),
|
||||||
|
resample=resample_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=resize_image,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -423,7 +539,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
||||||
@ -433,16 +549,17 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=lerp_image,
|
image=lerp_image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -463,7 +580,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||||
@ -478,16 +595,17 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=ilerp_image,
|
image=ilerp_image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
|
@ -11,7 +11,7 @@ from invokeai.app.invocations.image import ImageOutput
|
|||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||||
|
|
||||||
from ..models.image import ColorField, ImageCategory, ImageField, ImageType
|
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
@ -135,7 +135,7 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||||
@ -145,16 +145,17 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=infilled,
|
image=infilled,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -179,7 +180,7 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
infilled = tile_fill_missing(
|
infilled = tile_fill_missing(
|
||||||
@ -189,16 +190,17 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=infilled,
|
image=infilled,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -216,7 +218,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
if PatchMatch.patchmatch_available():
|
if PatchMatch.patchmatch_available():
|
||||||
@ -226,16 +228,17 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=infilled,
|
image=infilled,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
|
@ -1,37 +1,36 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Literal, Optional, Union
|
from contextlib import ExitStack
|
||||||
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import torch
|
import torch
|
||||||
|
from diffusers import ControlNetModel
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, Field, validator
|
||||||
from contextlib import ExitStack
|
|
||||||
|
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
|
|
||||||
|
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from ...backend.image_util.seamless import configure_model_padding
|
from ...backend.image_util.seamless import configure_model_padding
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||||
ConditioningData, StableDiffusionGeneratorPipeline,
|
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
|
||||||
image_resized_to_grid_as_tensor)
|
image_resized_to_grid_as_tensor)
|
||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
|
||||||
PostprocessingSettings
|
PostprocessingSettings
|
||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||||
from ..services.image_file_storage import ImageType
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ..services.model_manager_service import ModelManagerService
|
|
||||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||||
InvocationConfig, InvocationContext)
|
InvocationConfig, InvocationContext)
|
||||||
from .compel import ConditioningField
|
from .compel import ConditioningField
|
||||||
from .image import ImageCategory, ImageField, ImageOutput
|
from .controlnet_image_processors import ControlField
|
||||||
|
from .image import ImageOutput
|
||||||
from .model import ModelInfo, UNetField, VaeField
|
from .model import ModelInfo, UNetField, VaeField
|
||||||
|
|
||||||
from ...backend.model_management.lora import LoRAHelper
|
|
||||||
|
|
||||||
|
|
||||||
class LatentsField(BaseModel):
|
class LatentsField(BaseModel):
|
||||||
"""A latents field used for passing latents between invocations"""
|
"""A latents field used for passing latents between invocations"""
|
||||||
|
|
||||||
@ -93,10 +92,12 @@ def get_scheduler(
|
|||||||
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
|
orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict())
|
||||||
with orig_scheduler_info as orig_scheduler:
|
with orig_scheduler_info as orig_scheduler:
|
||||||
scheduler_config = orig_scheduler.config
|
scheduler_config = orig_scheduler.config
|
||||||
|
|
||||||
if "_backup" in scheduler_config:
|
if "_backup" in scheduler_config:
|
||||||
scheduler_config = scheduler_config["_backup"]
|
scheduler_config = scheduler_config["_backup"]
|
||||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||||
scheduler = scheduler_class.from_config(scheduler_config)
|
scheduler = scheduler_class.from_config(scheduler_config)
|
||||||
|
|
||||||
# hack copied over from generate.py
|
# hack copied over from generate.py
|
||||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||||
scheduler.uses_inpainting_model = lambda: False
|
scheduler.uses_inpainting_model = lambda: False
|
||||||
@ -171,12 +172,13 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
|
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
|
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
@ -184,6 +186,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {
|
||||||
"tags": ["latents", "image"],
|
"tags": ["latents", "image"],
|
||||||
|
"type_hints": {
|
||||||
|
"model": "model",
|
||||||
|
"control": "control",
|
||||||
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -244,6 +250,82 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
#precision="float16", # TODO:
|
#precision="float16", # TODO:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def prep_control_data(self,
|
||||||
|
context: InvocationContext,
|
||||||
|
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
|
||||||
|
control_input: List[ControlField],
|
||||||
|
latents_shape: List[int],
|
||||||
|
do_classifier_free_guidance: bool = True,
|
||||||
|
) -> List[ControlNetData]:
|
||||||
|
# assuming fixed dimensional scaling of 8:1 for image:latents
|
||||||
|
control_height_resize = latents_shape[2] * 8
|
||||||
|
control_width_resize = latents_shape[3] * 8
|
||||||
|
if control_input is None:
|
||||||
|
# print("control input is None")
|
||||||
|
control_list = None
|
||||||
|
elif isinstance(control_input, list) and len(control_input) == 0:
|
||||||
|
# print("control input is empty list")
|
||||||
|
control_list = None
|
||||||
|
elif isinstance(control_input, ControlField):
|
||||||
|
# print("control input is ControlField")
|
||||||
|
control_list = [control_input]
|
||||||
|
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
|
||||||
|
# print("control input is list[ControlField]")
|
||||||
|
control_list = control_input
|
||||||
|
else:
|
||||||
|
# print("input control is unrecognized:", type(self.control))
|
||||||
|
control_list = None
|
||||||
|
if (control_list is None):
|
||||||
|
control_data = None
|
||||||
|
# from above handling, any control that is not None should now be of type list[ControlField]
|
||||||
|
else:
|
||||||
|
# FIXME: add checks to skip entry if model or image is None
|
||||||
|
# and if weight is None, populate with default 1.0?
|
||||||
|
control_data = []
|
||||||
|
control_models = []
|
||||||
|
for control_info in control_list:
|
||||||
|
# handle control models
|
||||||
|
if ("," in control_info.control_model):
|
||||||
|
control_model_split = control_info.control_model.split(",")
|
||||||
|
control_name = control_model_split[0]
|
||||||
|
control_subfolder = control_model_split[1]
|
||||||
|
print("Using HF model subfolders")
|
||||||
|
print(" control_name: ", control_name)
|
||||||
|
print(" control_subfolder: ", control_subfolder)
|
||||||
|
control_model = ControlNetModel.from_pretrained(control_name,
|
||||||
|
subfolder=control_subfolder,
|
||||||
|
torch_dtype=model.unet.dtype).to(model.device)
|
||||||
|
else:
|
||||||
|
control_model = ControlNetModel.from_pretrained(control_info.control_model,
|
||||||
|
torch_dtype=model.unet.dtype).to(model.device)
|
||||||
|
control_models.append(control_model)
|
||||||
|
control_image_field = control_info.image
|
||||||
|
input_image = context.services.images.get_pil_image(control_image_field.image_origin,
|
||||||
|
control_image_field.image_name)
|
||||||
|
# self.image.image_type, self.image.image_name
|
||||||
|
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||||
|
# and add in batch_size, num_images_per_prompt?
|
||||||
|
# and do real check for classifier_free_guidance?
|
||||||
|
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
||||||
|
control_image = model.prepare_control_image(
|
||||||
|
image=input_image,
|
||||||
|
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||||
|
width=control_width_resize,
|
||||||
|
height=control_height_resize,
|
||||||
|
# batch_size=batch_size * num_images_per_prompt,
|
||||||
|
# num_images_per_prompt=num_images_per_prompt,
|
||||||
|
device=control_model.device,
|
||||||
|
dtype=control_model.dtype,
|
||||||
|
)
|
||||||
|
control_item = ControlNetData(model=control_model,
|
||||||
|
image_tensor=control_image,
|
||||||
|
weight=control_info.control_weight,
|
||||||
|
begin_step_percent=control_info.begin_step_percent,
|
||||||
|
end_step_percent=control_info.end_step_percent)
|
||||||
|
control_data.append(control_item)
|
||||||
|
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||||
|
return control_data
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
|
|
||||||
@ -269,13 +351,19 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||||
|
|
||||||
with LoRAHelper.apply_lora_unet(pipeline.unet, loras):
|
print("type of control input: ", type(self.control))
|
||||||
|
control_data = self.prep_control_data(model=pipeline, context=context, control_input=self.control,
|
||||||
|
latents_shape=noise.shape,
|
||||||
|
do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
|
|
||||||
|
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
|
||||||
# TODO: Verify the noise is the right size
|
# TODO: Verify the noise is the right size
|
||||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||||
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
|
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
|
||||||
noise=noise,
|
noise=noise,
|
||||||
num_inference_steps=self.steps,
|
num_inference_steps=self.steps,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
|
control_data=control_data, # list[ControlNetData]
|
||||||
callback=step_callback
|
callback=step_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -286,7 +374,6 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
context.services.latents.save(name, result_latents)
|
context.services.latents.save(name, result_latents)
|
||||||
return build_latents_output(latents_name=name, latents=result_latents)
|
return build_latents_output(latents_name=name, latents=result_latents)
|
||||||
|
|
||||||
|
|
||||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||||
"""Generates latents using latents as base image."""
|
"""Generates latents using latents as base image."""
|
||||||
|
|
||||||
@ -294,13 +381,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {
|
||||||
"tags": ["latents"],
|
"tags": ["latents"],
|
||||||
|
"type_hints": {
|
||||||
|
"model": "model",
|
||||||
|
"control": "control",
|
||||||
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -315,7 +406,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
def step_callback(state: PipelineIntermediateState):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
self.dispatch_progress(context, source_node_id, state)
|
self.dispatch_progress(context, source_node_id, state)
|
||||||
|
|
||||||
#unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict(),
|
**self.unet.unet.dict(),
|
||||||
)
|
)
|
||||||
@ -345,7 +435,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
|
|
||||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||||
|
|
||||||
with LoRAHelper.apply_lora_unet(pipeline.unet, loras):
|
with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
|
||||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||||
latents=initial_latents,
|
latents=initial_latents,
|
||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
@ -413,7 +503,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=image,
|
image=image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -459,6 +549,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
|
# context.services.latents.set(name, resized_latents)
|
||||||
context.services.latents.save(name, resized_latents)
|
context.services.latents.save(name, resized_latents)
|
||||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||||
|
|
||||||
@ -489,6 +580,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
|
# context.services.latents.set(name, resized_latents)
|
||||||
context.services.latents.save(name, resized_latents)
|
context.services.latents.save(name, resized_latents)
|
||||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||||
|
|
||||||
@ -513,8 +605,11 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
|
# image = context.services.images.get(
|
||||||
|
# self.image.image_type, self.image.image_name
|
||||||
|
# )
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||||
@ -543,6 +638,6 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
latents = 0.18215 * latents
|
latents = 0.18215 * latents
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
|
# context.services.latents.set(name, latents)
|
||||||
context.services.latents.save(name, latents)
|
context.services.latents.save(name, latents)
|
||||||
return build_latents_output(latents_name=name, latents=latents)
|
return build_latents_output(latents_name=name, latents=latents)
|
||||||
|
|
||||||
|
@ -34,6 +34,15 @@ class IntOutput(BaseInvocationOutput):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class FloatOutput(BaseInvocationOutput):
|
||||||
|
"""A float output"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["float_output"] = "float_output"
|
||||||
|
param: float = Field(default=None, description="The output float")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
class AddInvocation(BaseInvocation, MathInvocationConfig):
|
class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||||
"""Adds two numbers"""
|
"""Adds two numbers"""
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||||
from .math import IntOutput
|
from .math import IntOutput, FloatOutput
|
||||||
|
|
||||||
# Pass-through parameter nodes - used by subgraphs
|
# Pass-through parameter nodes - used by subgraphs
|
||||||
|
|
||||||
@ -16,3 +16,13 @@ class ParamIntInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
return IntOutput(a=self.a)
|
return IntOutput(a=self.a)
|
||||||
|
|
||||||
|
class ParamFloatInvocation(BaseInvocation):
|
||||||
|
"""A float parameter"""
|
||||||
|
#fmt: off
|
||||||
|
type: Literal["param_float"] = "param_float"
|
||||||
|
param: float = Field(default=0.0, description="The float value")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||||
|
return FloatOutput(param=self.param)
|
||||||
|
@ -2,7 +2,7 @@ from typing import Literal, Union
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
|
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
@ -29,7 +29,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
results = context.services.restoration.upscale_and_reconstruct(
|
||||||
image_list=[[image, 0]],
|
image_list=[[image, 0]],
|
||||||
@ -43,16 +43,17 @@ class RestoreFaceInvocation(BaseInvocation):
|
|||||||
# TODO: can this return multiple results?
|
# TODO: can this return multiple results?
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=results[0][0],
|
image=results[0][0],
|
||||||
image_type=ImageType.INTERMEDIATE,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
|
@ -4,7 +4,7 @@ from typing import Literal, Union
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
|
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ class UpscaleInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
results = context.services.restoration.upscale_and_reconstruct(
|
||||||
image_list=[[image, 0]],
|
image_list=[[image, 0]],
|
||||||
@ -45,16 +45,17 @@ class UpscaleInvocation(BaseInvocation):
|
|||||||
# TODO: can this return multiple results?
|
# TODO: can this return multiple results?
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=results[0][0],
|
image=results[0][0],
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
|
@ -5,31 +5,52 @@ from pydantic import BaseModel, Field
|
|||||||
from invokeai.app.util.metaenum import MetaEnum
|
from invokeai.app.util.metaenum import MetaEnum
|
||||||
|
|
||||||
|
|
||||||
class ImageType(str, Enum, metaclass=MetaEnum):
|
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||||
"""The type of an image."""
|
"""The origin of a resource (eg image).
|
||||||
|
|
||||||
RESULT = "results"
|
- INTERNAL: The resource was created by the application.
|
||||||
UPLOAD = "uploads"
|
- EXTERNAL: The resource was not created by the application.
|
||||||
INTERMEDIATE = "intermediates"
|
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||||
|
"""
|
||||||
|
|
||||||
|
INTERNAL = "internal"
|
||||||
|
"""The resource was created by the application."""
|
||||||
|
EXTERNAL = "external"
|
||||||
|
"""The resource was not created by the application.
|
||||||
|
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class InvalidImageTypeException(ValueError):
|
class InvalidOriginException(ValueError):
|
||||||
"""Raised when a provided value is not a valid ImageType.
|
"""Raised when a provided value is not a valid ResourceOrigin.
|
||||||
|
|
||||||
Subclasses `ValueError`.
|
Subclasses `ValueError`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, message="Invalid image type."):
|
def __init__(self, message="Invalid resource origin."):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
||||||
"""The category of an image. Use ImageCategory.OTHER for non-default categories."""
|
"""The category of an image.
|
||||||
|
|
||||||
|
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
|
||||||
|
- MASK: The image is a mask image.
|
||||||
|
- CONTROL: The image is a ControlNet control image.
|
||||||
|
- USER: The image is a user-provide image.
|
||||||
|
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
GENERAL = "general"
|
GENERAL = "general"
|
||||||
CONTROL = "control"
|
"""GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
|
||||||
MASK = "mask"
|
MASK = "mask"
|
||||||
|
"""MASK: The image is a mask image."""
|
||||||
|
CONTROL = "control"
|
||||||
|
"""CONTROL: The image is a ControlNet control image."""
|
||||||
|
USER = "user"
|
||||||
|
"""USER: The image is a user-provide image."""
|
||||||
OTHER = "other"
|
OTHER = "other"
|
||||||
|
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
|
||||||
|
|
||||||
|
|
||||||
class InvalidImageCategoryException(ValueError):
|
class InvalidImageCategoryException(ValueError):
|
||||||
@ -45,13 +66,13 @@ class InvalidImageCategoryException(ValueError):
|
|||||||
class ImageField(BaseModel):
|
class ImageField(BaseModel):
|
||||||
"""An image field used for passing image objects between invocations"""
|
"""An image field used for passing image objects between invocations"""
|
||||||
|
|
||||||
image_type: ImageType = Field(
|
image_origin: ResourceOrigin = Field(
|
||||||
default=ImageType.RESULT, description="The type of the image"
|
default=ResourceOrigin.INTERNAL, description="The type of the image"
|
||||||
)
|
)
|
||||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {"required": ["image_type", "image_name"]}
|
schema_extra = {"required": ["image_origin", "image_name"]}
|
||||||
|
|
||||||
|
|
||||||
class ColorField(BaseModel):
|
class ColorField(BaseModel):
|
||||||
@ -62,3 +83,11 @@ class ColorField(BaseModel):
|
|||||||
|
|
||||||
def tuple(self) -> Tuple[int, int, int, int]:
|
def tuple(self) -> Tuple[int, int, int, int]:
|
||||||
return (self.r, self.g, self.b, self.a)
|
return (self.r, self.g, self.b, self.a)
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressImage(BaseModel):
|
||||||
|
"""The progress image sent intermittently during processing"""
|
||||||
|
|
||||||
|
width: int = Field(description="The effective width of the image in pixels")
|
||||||
|
height: int = Field(description="The effective height of the image in pixels")
|
||||||
|
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
from invokeai.app.api.models.images import ProgressImage
|
from invokeai.app.models.image import ProgressImage
|
||||||
from invokeai.app.util.misc import get_timestamp
|
from invokeai.app.util.misc import get_timestamp
|
||||||
from invokeai.app.services.model_manager_service import SDModelType, SDModelInfo
|
from invokeai.app.services.model_manager_service import SDModelType, SDModelInfo
|
||||||
from invokeai.app.models.exceptions import CanceledException
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
|
@ -60,6 +60,35 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
|
|||||||
node_input_field = node_inputs.get(field) or None
|
node_input_field = node_inputs.get(field) or None
|
||||||
return node_input_field
|
return node_input_field
|
||||||
|
|
||||||
|
from typing import Optional, Union, List, get_args
|
||||||
|
|
||||||
|
def is_union_subtype(t1, t2):
|
||||||
|
t1_args = get_args(t1)
|
||||||
|
t2_args = get_args(t2)
|
||||||
|
|
||||||
|
if not t1_args:
|
||||||
|
# t1 is a single type
|
||||||
|
return t1 in t2_args
|
||||||
|
else:
|
||||||
|
# t1 is a Union, check that all of its types are in t2_args
|
||||||
|
return all(arg in t2_args for arg in t1_args)
|
||||||
|
|
||||||
|
def is_list_or_contains_list(t):
|
||||||
|
t_args = get_args(t)
|
||||||
|
|
||||||
|
# If the type is a List
|
||||||
|
if get_origin(t) is list:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# If the type is a Union
|
||||||
|
elif t_args:
|
||||||
|
# Check if any of the types in the Union is a List
|
||||||
|
for arg in t_args:
|
||||||
|
if get_origin(arg) is list:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
||||||
if not from_type:
|
if not from_type:
|
||||||
@ -85,7 +114,8 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
|||||||
if to_type in get_args(from_type):
|
if to_type in get_args(from_type):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if not issubclass(from_type, to_type):
|
# if not issubclass(from_type, to_type):
|
||||||
|
if not is_union_subtype(from_type, to_type):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
@ -694,7 +724,11 @@ class Graph(BaseModel):
|
|||||||
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
|
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
|
||||||
|
|
||||||
# Verify that all outputs are lists
|
# Verify that all outputs are lists
|
||||||
if not all((get_origin(f) == list for f in output_fields)):
|
# if not all((get_origin(f) == list for f in output_fields)):
|
||||||
|
# return False
|
||||||
|
|
||||||
|
# Verify that all outputs are lists
|
||||||
|
if not all(is_list_or_contains_list(f) for f in output_fields):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Verify that all outputs match the input type (are a base class or the same class)
|
# Verify that all outputs match the input type (are a base class or the same class)
|
||||||
|
@ -9,7 +9,7 @@ from PIL.Image import Image as PILImageType
|
|||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
from send2trash import send2trash
|
from send2trash import send2trash
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageType
|
from invokeai.app.models.image import ResourceOrigin
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||||
|
|
||||||
@ -40,13 +40,13 @@ class ImageFileStorageBase(ABC):
|
|||||||
"""Low-level service responsible for storing and retrieving image files."""
|
"""Low-level service responsible for storing and retrieving image files."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||||
"""Retrieves an image as PIL Image."""
|
"""Retrieves an image as PIL Image."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_path(
|
def get_path(
|
||||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Gets the internal path to an image or thumbnail."""
|
"""Gets the internal path to an image or thumbnail."""
|
||||||
pass
|
pass
|
||||||
@ -62,7 +62,7 @@ class ImageFileStorageBase(ABC):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
metadata: Optional[ImageMetadata] = None,
|
metadata: Optional[ImageMetadata] = None,
|
||||||
thumbnail_size: int = 256,
|
thumbnail_size: int = 256,
|
||||||
@ -71,7 +71,7 @@ class ImageFileStorageBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||||
"""Deletes an image and its thumbnail (if one exists)."""
|
"""Deletes an image and its thumbnail (if one exists)."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -93,17 +93,17 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
||||||
for image_type in ImageType:
|
for image_origin in ResourceOrigin:
|
||||||
Path(os.path.join(output_folder, image_type)).mkdir(
|
Path(os.path.join(output_folder, image_origin)).mkdir(
|
||||||
parents=True, exist_ok=True
|
parents=True, exist_ok=True
|
||||||
)
|
)
|
||||||
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
|
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
|
||||||
parents=True, exist_ok=True
|
parents=True, exist_ok=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||||
try:
|
try:
|
||||||
image_path = self.get_path(image_type, image_name)
|
image_path = self.get_path(image_origin, image_name)
|
||||||
cache_item = self.__get_cache(image_path)
|
cache_item = self.__get_cache(image_path)
|
||||||
if cache_item:
|
if cache_item:
|
||||||
return cache_item
|
return cache_item
|
||||||
@ -117,13 +117,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
metadata: Optional[ImageMetadata] = None,
|
metadata: Optional[ImageMetadata] = None,
|
||||||
thumbnail_size: int = 256,
|
thumbnail_size: int = 256,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
image_path = self.get_path(image_type, image_name)
|
image_path = self.get_path(image_origin, image_name)
|
||||||
|
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
pnginfo = PngImagePlugin.PngInfo()
|
pnginfo = PngImagePlugin.PngInfo()
|
||||||
@ -133,7 +133,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
image.save(image_path, "PNG")
|
image.save(image_path, "PNG")
|
||||||
|
|
||||||
thumbnail_name = get_thumbnail_name(image_name)
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
|
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
|
||||||
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
||||||
thumbnail_image.save(thumbnail_path)
|
thumbnail_image.save(thumbnail_path)
|
||||||
|
|
||||||
@ -142,10 +142,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ImageFileSaveException from e
|
raise ImageFileSaveException from e
|
||||||
|
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||||
try:
|
try:
|
||||||
basename = os.path.basename(image_name)
|
basename = os.path.basename(image_name)
|
||||||
image_path = self.get_path(image_type, basename)
|
image_path = self.get_path(image_origin, basename)
|
||||||
|
|
||||||
if os.path.exists(image_path):
|
if os.path.exists(image_path):
|
||||||
send2trash(image_path)
|
send2trash(image_path)
|
||||||
@ -153,7 +153,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
del self.__cache[image_path]
|
del self.__cache[image_path]
|
||||||
|
|
||||||
thumbnail_name = get_thumbnail_name(image_name)
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
|
thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
|
||||||
|
|
||||||
if os.path.exists(thumbnail_path):
|
if os.path.exists(thumbnail_path):
|
||||||
send2trash(thumbnail_path)
|
send2trash(thumbnail_path)
|
||||||
@ -164,7 +164,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||||
def get_path(
|
def get_path(
|
||||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
# strip out any relative path shenanigans
|
# strip out any relative path shenanigans
|
||||||
basename = os.path.basename(image_name)
|
basename = os.path.basename(image_name)
|
||||||
@ -172,10 +172,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
if thumbnail:
|
if thumbnail:
|
||||||
thumbnail_name = get_thumbnail_name(basename)
|
thumbnail_name = get_thumbnail_name(basename)
|
||||||
path = os.path.join(
|
path = os.path.join(
|
||||||
self.__output_folder, image_type, "thumbnails", thumbnail_name
|
self.__output_folder, image_origin, "thumbnails", thumbnail_name
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
path = os.path.join(self.__output_folder, image_type, basename)
|
path = os.path.join(self.__output_folder, image_origin, basename)
|
||||||
|
|
||||||
abspath = os.path.abspath(path)
|
abspath = os.path.abspath(path)
|
||||||
|
|
||||||
|
@ -1,20 +1,35 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, cast
|
from typing import Generic, Optional, TypeVar, cast
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from pydantic.generics import GenericModel
|
||||||
|
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
from invokeai.app.models.image import (
|
from invokeai.app.models.image import (
|
||||||
ImageCategory,
|
ImageCategory,
|
||||||
ImageType,
|
ResourceOrigin,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.models.image_record import (
|
from invokeai.app.services.models.image_record import (
|
||||||
ImageRecord,
|
ImageRecord,
|
||||||
|
ImageRecordChanges,
|
||||||
deserialize_image_record,
|
deserialize_image_record,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.item_storage import PaginatedResults
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
class OffsetPaginatedResults(GenericModel, Generic[T]):
|
||||||
|
"""Offset-paginated results"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
items: list[T] = Field(description="Items")
|
||||||
|
offset: int = Field(description="Offset from which to retrieve items")
|
||||||
|
limit: int = Field(description="Limit of items to get")
|
||||||
|
total: int = Field(description="Total number of items in result")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
# TODO: Should these excpetions subclass existing python exceptions?
|
# TODO: Should these excpetions subclass existing python exceptions?
|
||||||
@ -45,25 +60,36 @@ class ImageRecordStorageBase(ABC):
|
|||||||
# TODO: Implement an `update()` method
|
# TODO: Implement an `update()` method
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||||
"""Gets an image record."""
|
"""Gets an image record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
image_origin: ResourceOrigin,
|
||||||
|
changes: ImageRecordChanges,
|
||||||
|
) -> None:
|
||||||
|
"""Updates an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_many(
|
def get_many(
|
||||||
self,
|
self,
|
||||||
image_type: ImageType,
|
offset: int = 0,
|
||||||
image_category: ImageCategory,
|
limit: int = 10,
|
||||||
page: int = 0,
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
per_page: int = 10,
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
) -> PaginatedResults[ImageRecord]:
|
is_intermediate: Optional[bool] = None,
|
||||||
|
) -> OffsetPaginatedResults[ImageRecord]:
|
||||||
"""Gets a page of image records."""
|
"""Gets a page of image records."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||||
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||||
"""Deletes an image record."""
|
"""Deletes an image record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -71,13 +97,14 @@ class ImageRecordStorageBase(ABC):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
width: int,
|
width: int,
|
||||||
height: int,
|
height: int,
|
||||||
session_id: Optional[str],
|
session_id: Optional[str],
|
||||||
node_id: Optional[str],
|
node_id: Optional[str],
|
||||||
metadata: Optional[ImageMetadata],
|
metadata: Optional[ImageMetadata],
|
||||||
|
is_intermediate: bool = False,
|
||||||
) -> datetime:
|
) -> datetime:
|
||||||
"""Saves an image record."""
|
"""Saves an image record."""
|
||||||
pass
|
pass
|
||||||
@ -91,7 +118,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, filename: str) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self._filename = filename
|
self._filename = filename
|
||||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
@ -117,7 +143,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
CREATE TABLE IF NOT EXISTS images (
|
CREATE TABLE IF NOT EXISTS images (
|
||||||
image_name TEXT NOT NULL PRIMARY KEY,
|
image_name TEXT NOT NULL PRIMARY KEY,
|
||||||
-- This is an enum in python, unrestricted string here for flexibility
|
-- This is an enum in python, unrestricted string here for flexibility
|
||||||
image_type TEXT NOT NULL,
|
image_origin TEXT NOT NULL,
|
||||||
-- This is an enum in python, unrestricted string here for flexibility
|
-- This is an enum in python, unrestricted string here for flexibility
|
||||||
image_category TEXT NOT NULL,
|
image_category TEXT NOT NULL,
|
||||||
width INTEGER NOT NULL,
|
width INTEGER NOT NULL,
|
||||||
@ -125,9 +151,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
session_id TEXT,
|
session_id TEXT,
|
||||||
node_id TEXT,
|
node_id TEXT,
|
||||||
metadata TEXT,
|
metadata TEXT,
|
||||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
-- Updated via trigger
|
-- Updated via trigger
|
||||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
-- Soft delete, currently unused
|
-- Soft delete, currently unused
|
||||||
deleted_at DATETIME
|
deleted_at DATETIME
|
||||||
);
|
);
|
||||||
@ -142,7 +169,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
)
|
)
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type);
|
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -169,7 +196,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]:
|
def get(
|
||||||
|
self, image_origin: ResourceOrigin, image_name: str
|
||||||
|
) -> Union[ImageRecord, None]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
|
||||||
@ -193,38 +222,110 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
|
|
||||||
return deserialize_image_record(dict(result))
|
return deserialize_image_record(dict(result))
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
image_origin: ResourceOrigin,
|
||||||
|
changes: ImageRecordChanges,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
# Change the category of the image
|
||||||
|
if changes.image_category is not None:
|
||||||
|
self._cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
UPDATE images
|
||||||
|
SET image_category = ?
|
||||||
|
WHERE image_name = ?;
|
||||||
|
""",
|
||||||
|
(changes.image_category, image_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Change the session associated with the image
|
||||||
|
if changes.session_id is not None:
|
||||||
|
self._cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
UPDATE images
|
||||||
|
SET session_id = ?
|
||||||
|
WHERE image_name = ?;
|
||||||
|
""",
|
||||||
|
(changes.session_id, image_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Change the image's `is_intermediate`` flag
|
||||||
|
if changes.is_intermediate is not None:
|
||||||
|
self._cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
UPDATE images
|
||||||
|
SET is_intermediate = ?
|
||||||
|
WHERE image_name = ?;
|
||||||
|
""",
|
||||||
|
(changes.is_intermediate, image_name),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise ImageRecordSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
def get_many(
|
def get_many(
|
||||||
self,
|
self,
|
||||||
image_type: ImageType,
|
offset: int = 0,
|
||||||
image_category: ImageCategory,
|
limit: int = 10,
|
||||||
page: int = 0,
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
per_page: int = 10,
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
) -> PaginatedResults[ImageRecord]:
|
is_intermediate: Optional[bool] = None,
|
||||||
|
) -> OffsetPaginatedResults[ImageRecord]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
|
||||||
self._cursor.execute(
|
# Manually build two queries - one for the count, one for the records
|
||||||
f"""--sql
|
|
||||||
SELECT * FROM images
|
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
|
||||||
WHERE image_type = ? AND image_category = ?
|
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
|
||||||
ORDER BY created_at DESC
|
|
||||||
LIMIT ? OFFSET ?;
|
query_conditions = ""
|
||||||
""",
|
query_params = []
|
||||||
(image_type.value, image_category.value, per_page, page * per_page),
|
|
||||||
|
if image_origin is not None:
|
||||||
|
query_conditions += f"""AND image_origin = ?\n"""
|
||||||
|
query_params.append(image_origin.value)
|
||||||
|
|
||||||
|
if categories is not None:
|
||||||
|
## Convert the enum values to unique list of strings
|
||||||
|
category_strings = list(
|
||||||
|
map(lambda c: c.value, set(categories))
|
||||||
)
|
)
|
||||||
|
# Create the correct length of placeholders
|
||||||
|
placeholders = ",".join("?" * len(category_strings))
|
||||||
|
query_conditions += f"AND image_category IN ( {placeholders} )\n"
|
||||||
|
|
||||||
|
# Unpack the included categories into the query params
|
||||||
|
for c in category_strings:
|
||||||
|
query_params.append(c)
|
||||||
|
|
||||||
|
if is_intermediate is not None:
|
||||||
|
query_conditions += f"""AND is_intermediate = ?\n"""
|
||||||
|
query_params.append(is_intermediate)
|
||||||
|
|
||||||
|
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
|
||||||
|
|
||||||
|
# Final images query with pagination
|
||||||
|
images_query += query_conditions + query_pagination + ";"
|
||||||
|
# Add all the parameters
|
||||||
|
images_params = query_params.copy()
|
||||||
|
images_params.append(limit)
|
||||||
|
images_params.append(offset)
|
||||||
|
# Build the list of images, deserializing each row
|
||||||
|
self._cursor.execute(images_query, images_params)
|
||||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
|
||||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||||
|
|
||||||
self._cursor.execute(
|
# Set up and execute the count query, without pagination
|
||||||
"""--sql
|
count_query += query_conditions + ";"
|
||||||
SELECT count(*) FROM images
|
count_params = query_params.copy()
|
||||||
WHERE image_type = ? AND image_category = ?
|
self._cursor.execute(count_query, count_params)
|
||||||
""",
|
|
||||||
(image_type.value, image_category.value),
|
|
||||||
)
|
|
||||||
|
|
||||||
count = self._cursor.fetchone()[0]
|
count = self._cursor.fetchone()[0]
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
@ -232,13 +333,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
pageCount = int(count / per_page) + 1
|
return OffsetPaginatedResults(
|
||||||
|
items=images, offset=offset, limit=limit, total=count
|
||||||
return PaginatedResults(
|
|
||||||
items=images, page=page, pages=pageCount, per_page=per_page, total=count
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -258,13 +357,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
session_id: Optional[str],
|
session_id: Optional[str],
|
||||||
width: int,
|
width: int,
|
||||||
height: int,
|
height: int,
|
||||||
node_id: Optional[str],
|
node_id: Optional[str],
|
||||||
metadata: Optional[ImageMetadata],
|
metadata: Optional[ImageMetadata],
|
||||||
|
is_intermediate: bool = False,
|
||||||
) -> datetime:
|
) -> datetime:
|
||||||
try:
|
try:
|
||||||
metadata_json = (
|
metadata_json = (
|
||||||
@ -275,25 +375,27 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
"""--sql
|
"""--sql
|
||||||
INSERT OR IGNORE INTO images (
|
INSERT OR IGNORE INTO images (
|
||||||
image_name,
|
image_name,
|
||||||
image_type,
|
image_origin,
|
||||||
image_category,
|
image_category,
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
node_id,
|
node_id,
|
||||||
session_id,
|
session_id,
|
||||||
metadata
|
metadata,
|
||||||
|
is_intermediate
|
||||||
)
|
)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?);
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
image_name,
|
image_name,
|
||||||
image_type.value,
|
image_origin.value,
|
||||||
image_category.value,
|
image_category.value,
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
node_id,
|
node_id,
|
||||||
session_id,
|
session_id,
|
||||||
metadata_json,
|
metadata_json,
|
||||||
|
is_intermediate,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from typing import Optional, TYPE_CHECKING, Union
|
from typing import Optional, TYPE_CHECKING, Union
|
||||||
import uuid
|
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
from invokeai.app.models.image import (
|
from invokeai.app.models.image import (
|
||||||
ImageCategory,
|
ImageCategory,
|
||||||
ImageType,
|
ResourceOrigin,
|
||||||
InvalidImageCategoryException,
|
InvalidImageCategoryException,
|
||||||
InvalidImageTypeException,
|
InvalidOriginException,
|
||||||
)
|
)
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
from invokeai.app.services.image_record_storage import (
|
from invokeai.app.services.image_record_storage import (
|
||||||
@ -16,10 +15,12 @@ from invokeai.app.services.image_record_storage import (
|
|||||||
ImageRecordNotFoundException,
|
ImageRecordNotFoundException,
|
||||||
ImageRecordSaveException,
|
ImageRecordSaveException,
|
||||||
ImageRecordStorageBase,
|
ImageRecordStorageBase,
|
||||||
|
OffsetPaginatedResults,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.models.image_record import (
|
from invokeai.app.services.models.image_record import (
|
||||||
ImageRecord,
|
ImageRecord,
|
||||||
ImageDTO,
|
ImageDTO,
|
||||||
|
ImageRecordChanges,
|
||||||
image_record_to_dto,
|
image_record_to_dto,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.image_file_storage import (
|
from invokeai.app.services.image_file_storage import (
|
||||||
@ -30,8 +31,8 @@ from invokeai.app.services.image_file_storage import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
|
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
|
||||||
from invokeai.app.services.metadata import MetadataServiceBase
|
from invokeai.app.services.metadata import MetadataServiceBase
|
||||||
|
from invokeai.app.services.resource_name import NameServiceBase
|
||||||
from invokeai.app.services.urls import UrlServiceBase
|
from invokeai.app.services.urls import UrlServiceBase
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from invokeai.app.services.graph import GraphExecutionState
|
from invokeai.app.services.graph import GraphExecutionState
|
||||||
@ -44,32 +45,42 @@ class ImageServiceABC(ABC):
|
|||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
metadata: Optional[ImageMetadata] = None,
|
intermediate: bool = False,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Creates an image, storing the file and its metadata."""
|
"""Creates an image, storing the file and its metadata."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
def update(
|
||||||
|
self,
|
||||||
|
image_origin: ResourceOrigin,
|
||||||
|
image_name: str,
|
||||||
|
changes: ImageRecordChanges,
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""Updates an image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||||
"""Gets an image as a PIL image."""
|
"""Gets an image as a PIL image."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||||
"""Gets an image record."""
|
"""Gets an image record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||||
"""Gets an image DTO."""
|
"""Gets an image DTO."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
|
||||||
"""Gets an image's path."""
|
"""Gets an image's path."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -80,7 +91,7 @@ class ImageServiceABC(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_url(
|
def get_url(
|
||||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Gets an image's or thumbnail's URL."""
|
"""Gets an image's or thumbnail's URL."""
|
||||||
pass
|
pass
|
||||||
@ -88,16 +99,17 @@ class ImageServiceABC(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_many(
|
def get_many(
|
||||||
self,
|
self,
|
||||||
image_type: ImageType,
|
offset: int = 0,
|
||||||
image_category: ImageCategory,
|
limit: int = 10,
|
||||||
page: int = 0,
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
per_page: int = 10,
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
) -> PaginatedResults[ImageDTO]:
|
is_intermediate: Optional[bool] = None,
|
||||||
|
) -> OffsetPaginatedResults[ImageDTO]:
|
||||||
"""Gets a paginated list of image DTOs."""
|
"""Gets a paginated list of image DTOs."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete(self, image_type: ImageType, image_name: str):
|
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||||
"""Deletes an image."""
|
"""Deletes an image."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -110,6 +122,7 @@ class ImageServiceDependencies:
|
|||||||
metadata: MetadataServiceBase
|
metadata: MetadataServiceBase
|
||||||
urls: UrlServiceBase
|
urls: UrlServiceBase
|
||||||
logger: Logger
|
logger: Logger
|
||||||
|
names: NameServiceBase
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -119,6 +132,7 @@ class ImageServiceDependencies:
|
|||||||
metadata: MetadataServiceBase,
|
metadata: MetadataServiceBase,
|
||||||
url: UrlServiceBase,
|
url: UrlServiceBase,
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
|
names: NameServiceBase,
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||||
):
|
):
|
||||||
self.records = image_record_storage
|
self.records = image_record_storage
|
||||||
@ -126,6 +140,7 @@ class ImageServiceDependencies:
|
|||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
self.urls = url
|
self.urls = url
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
self.names = names
|
||||||
self.graph_execution_manager = graph_execution_manager
|
self.graph_execution_manager = graph_execution_manager
|
||||||
|
|
||||||
|
|
||||||
@ -139,6 +154,7 @@ class ImageService(ImageServiceABC):
|
|||||||
metadata: MetadataServiceBase,
|
metadata: MetadataServiceBase,
|
||||||
url: UrlServiceBase,
|
url: UrlServiceBase,
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
|
names: NameServiceBase,
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||||
):
|
):
|
||||||
self._services = ImageServiceDependencies(
|
self._services = ImageServiceDependencies(
|
||||||
@ -147,29 +163,26 @@ class ImageService(ImageServiceABC):
|
|||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
url=url,
|
url=url,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
names=names,
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
|
is_intermediate: bool = False,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
if image_type not in ImageType:
|
if image_origin not in ResourceOrigin:
|
||||||
raise InvalidImageTypeException
|
raise InvalidOriginException
|
||||||
|
|
||||||
if image_category not in ImageCategory:
|
if image_category not in ImageCategory:
|
||||||
raise InvalidImageCategoryException
|
raise InvalidImageCategoryException
|
||||||
|
|
||||||
image_name = self._create_image_name(
|
image_name = self._services.names.create_image_name()
|
||||||
image_type=image_type,
|
|
||||||
image_category=image_category,
|
|
||||||
node_id=node_id,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
metadata = self._get_metadata(session_id, node_id)
|
metadata = self._get_metadata(session_id, node_id)
|
||||||
|
|
||||||
@ -180,10 +193,12 @@ class ImageService(ImageServiceABC):
|
|||||||
created_at = self._services.records.save(
|
created_at = self._services.records.save(
|
||||||
# Non-nullable fields
|
# Non-nullable fields
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image_type=image_type,
|
image_origin=image_origin,
|
||||||
image_category=image_category,
|
image_category=image_category,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
|
# Meta fields
|
||||||
|
is_intermediate=is_intermediate,
|
||||||
# Nullable fields
|
# Nullable fields
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@ -191,21 +206,21 @@ class ImageService(ImageServiceABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._services.files.save(
|
self._services.files.save(
|
||||||
image_type=image_type,
|
image_origin=image_origin,
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image=image,
|
image=image,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_url = self._services.urls.get_image_url(image_type, image_name)
|
image_url = self._services.urls.get_image_url(image_origin, image_name)
|
||||||
thumbnail_url = self._services.urls.get_image_url(
|
thumbnail_url = self._services.urls.get_image_url(
|
||||||
image_type, image_name, True
|
image_origin, image_name, True
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageDTO(
|
return ImageDTO(
|
||||||
# Non-nullable fields
|
# Non-nullable fields
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image_type=image_type,
|
image_origin=image_origin,
|
||||||
image_category=image_category,
|
image_category=image_category,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
@ -217,6 +232,7 @@ class ImageService(ImageServiceABC):
|
|||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
updated_at=created_at, # this is always the same as the created_at at this time
|
updated_at=created_at, # this is always the same as the created_at at this time
|
||||||
deleted_at=None,
|
deleted_at=None,
|
||||||
|
is_intermediate=is_intermediate,
|
||||||
# Extra non-nullable fields for DTO
|
# Extra non-nullable fields for DTO
|
||||||
image_url=image_url,
|
image_url=image_url,
|
||||||
thumbnail_url=thumbnail_url,
|
thumbnail_url=thumbnail_url,
|
||||||
@ -231,9 +247,25 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.error("Problem saving image record and file")
|
self._services.logger.error("Problem saving image record and file")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
def update(
|
||||||
|
self,
|
||||||
|
image_origin: ResourceOrigin,
|
||||||
|
image_name: str,
|
||||||
|
changes: ImageRecordChanges,
|
||||||
|
) -> ImageDTO:
|
||||||
try:
|
try:
|
||||||
return self._services.files.get(image_type, image_name)
|
self._services.records.update(image_name, image_origin, changes)
|
||||||
|
return self.get_dto(image_origin, image_name)
|
||||||
|
except ImageRecordSaveException:
|
||||||
|
self._services.logger.error("Failed to update image record")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem updating image record")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||||
|
try:
|
||||||
|
return self._services.files.get(image_origin, image_name)
|
||||||
except ImageFileNotFoundException:
|
except ImageFileNotFoundException:
|
||||||
self._services.logger.error("Failed to get image file")
|
self._services.logger.error("Failed to get image file")
|
||||||
raise
|
raise
|
||||||
@ -241,9 +273,9 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.error("Problem getting image file")
|
self._services.logger.error("Problem getting image file")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||||
try:
|
try:
|
||||||
return self._services.records.get(image_type, image_name)
|
return self._services.records.get(image_origin, image_name)
|
||||||
except ImageRecordNotFoundException:
|
except ImageRecordNotFoundException:
|
||||||
self._services.logger.error("Image record not found")
|
self._services.logger.error("Image record not found")
|
||||||
raise
|
raise
|
||||||
@ -251,14 +283,14 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.error("Problem getting image record")
|
self._services.logger.error("Problem getting image record")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||||
try:
|
try:
|
||||||
image_record = self._services.records.get(image_type, image_name)
|
image_record = self._services.records.get(image_origin, image_name)
|
||||||
|
|
||||||
image_dto = image_record_to_dto(
|
image_dto = image_record_to_dto(
|
||||||
image_record,
|
image_record,
|
||||||
self._services.urls.get_image_url(image_type, image_name),
|
self._services.urls.get_image_url(image_origin, image_name),
|
||||||
self._services.urls.get_image_url(image_type, image_name, True),
|
self._services.urls.get_image_url(image_origin, image_name, True),
|
||||||
)
|
)
|
||||||
|
|
||||||
return image_dto
|
return image_dto
|
||||||
@ -270,10 +302,10 @@ class ImageService(ImageServiceABC):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_path(
|
def get_path(
|
||||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
try:
|
try:
|
||||||
return self._services.files.get_path(image_type, image_name, thumbnail)
|
return self._services.files.get_path(image_origin, image_name, thumbnail)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem getting image path")
|
self._services.logger.error("Problem getting image path")
|
||||||
raise e
|
raise e
|
||||||
@ -286,57 +318,58 @@ class ImageService(ImageServiceABC):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_url(
|
def get_url(
|
||||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
try:
|
try:
|
||||||
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
|
return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem getting image path")
|
self._services.logger.error("Problem getting image path")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_many(
|
def get_many(
|
||||||
self,
|
self,
|
||||||
image_type: ImageType,
|
offset: int = 0,
|
||||||
image_category: ImageCategory,
|
limit: int = 10,
|
||||||
page: int = 0,
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
per_page: int = 10,
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
) -> PaginatedResults[ImageDTO]:
|
is_intermediate: Optional[bool] = None,
|
||||||
|
) -> OffsetPaginatedResults[ImageDTO]:
|
||||||
try:
|
try:
|
||||||
results = self._services.records.get_many(
|
results = self._services.records.get_many(
|
||||||
image_type,
|
offset,
|
||||||
image_category,
|
limit,
|
||||||
page,
|
image_origin,
|
||||||
per_page,
|
categories,
|
||||||
|
is_intermediate,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_dtos = list(
|
image_dtos = list(
|
||||||
map(
|
map(
|
||||||
lambda r: image_record_to_dto(
|
lambda r: image_record_to_dto(
|
||||||
r,
|
r,
|
||||||
self._services.urls.get_image_url(image_type, r.image_name),
|
self._services.urls.get_image_url(r.image_origin, r.image_name),
|
||||||
self._services.urls.get_image_url(
|
self._services.urls.get_image_url(
|
||||||
image_type, r.image_name, True
|
r.image_origin, r.image_name, True
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
results.items,
|
results.items,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return PaginatedResults[ImageDTO](
|
return OffsetPaginatedResults[ImageDTO](
|
||||||
items=image_dtos,
|
items=image_dtos,
|
||||||
page=results.page,
|
offset=results.offset,
|
||||||
pages=results.pages,
|
limit=results.limit,
|
||||||
per_page=results.per_page,
|
|
||||||
total=results.total,
|
total=results.total,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem getting paginated image DTOs")
|
self._services.logger.error("Problem getting paginated image DTOs")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def delete(self, image_type: ImageType, image_name: str):
|
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||||
try:
|
try:
|
||||||
self._services.files.delete(image_type, image_name)
|
self._services.files.delete(image_origin, image_name)
|
||||||
self._services.records.delete(image_type, image_name)
|
self._services.records.delete(image_origin, image_name)
|
||||||
except ImageRecordDeleteException:
|
except ImageRecordDeleteException:
|
||||||
self._services.logger.error(f"Failed to delete image record")
|
self._services.logger.error(f"Failed to delete image record")
|
||||||
raise
|
raise
|
||||||
@ -347,21 +380,6 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.error("Problem deleting image record and file")
|
self._services.logger.error("Problem deleting image record and file")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def _create_image_name(
|
|
||||||
self,
|
|
||||||
image_type: ImageType,
|
|
||||||
image_category: ImageCategory,
|
|
||||||
node_id: Optional[str] = None,
|
|
||||||
session_id: Optional[str] = None,
|
|
||||||
) -> str:
|
|
||||||
"""Create a unique image name."""
|
|
||||||
uuid_str = str(uuid.uuid4())
|
|
||||||
|
|
||||||
if node_id is not None and session_id is not None:
|
|
||||||
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
|
|
||||||
|
|
||||||
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
|
|
||||||
|
|
||||||
def _get_metadata(
|
def _get_metadata(
|
||||||
self, session_id: Optional[str] = None, node_id: Optional[str] = None
|
self, session_id: Optional[str] = None, node_id: Optional[str] = None
|
||||||
) -> Union[ImageMetadata, None]:
|
) -> Union[ImageMetadata, None]:
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import datetime
|
import datetime
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||||
from invokeai.app.models.image import ImageCategory, ImageType
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
|
|
||||||
@ -11,8 +11,8 @@ class ImageRecord(BaseModel):
|
|||||||
|
|
||||||
image_name: str = Field(description="The unique name of the image.")
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
"""The unique name of the image."""
|
"""The unique name of the image."""
|
||||||
image_type: ImageType = Field(description="The type of the image.")
|
image_origin: ResourceOrigin = Field(description="The type of the image.")
|
||||||
"""The type of the image."""
|
"""The origin of the image."""
|
||||||
image_category: ImageCategory = Field(description="The category of the image.")
|
image_category: ImageCategory = Field(description="The category of the image.")
|
||||||
"""The category of the image."""
|
"""The category of the image."""
|
||||||
width: int = Field(description="The width of the image in px.")
|
width: int = Field(description="The width of the image in px.")
|
||||||
@ -31,6 +31,8 @@ class ImageRecord(BaseModel):
|
|||||||
description="The deleted timestamp of the image."
|
description="The deleted timestamp of the image."
|
||||||
)
|
)
|
||||||
"""The deleted timestamp of the image."""
|
"""The deleted timestamp of the image."""
|
||||||
|
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
||||||
|
"""Whether this is an intermediate image."""
|
||||||
session_id: Optional[str] = Field(
|
session_id: Optional[str] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The session ID that generated this image, if it is a generated image.",
|
description="The session ID that generated this image, if it is a generated image.",
|
||||||
@ -48,13 +50,37 @@ class ImageRecord(BaseModel):
|
|||||||
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
|
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
||||||
|
"""A set of changes to apply to an image record.
|
||||||
|
|
||||||
|
Only limited changes are valid:
|
||||||
|
- `image_category`: change the category of an image
|
||||||
|
- `session_id`: change the session associated with an image
|
||||||
|
- `is_intermediate`: change the image's `is_intermediate` flag
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_category: Optional[ImageCategory] = Field(
|
||||||
|
description="The image's new category."
|
||||||
|
)
|
||||||
|
"""The image's new category."""
|
||||||
|
session_id: Optional[StrictStr] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The image's new session ID.",
|
||||||
|
)
|
||||||
|
"""The image's new session ID."""
|
||||||
|
is_intermediate: Optional[StrictBool] = Field(
|
||||||
|
default=None, description="The image's new `is_intermediate` flag."
|
||||||
|
)
|
||||||
|
"""The image's new `is_intermediate` flag."""
|
||||||
|
|
||||||
|
|
||||||
class ImageUrlsDTO(BaseModel):
|
class ImageUrlsDTO(BaseModel):
|
||||||
"""The URLs for an image and its thumbnail."""
|
"""The URLs for an image and its thumbnail."""
|
||||||
|
|
||||||
image_name: str = Field(description="The unique name of the image.")
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
"""The unique name of the image."""
|
"""The unique name of the image."""
|
||||||
image_type: ImageType = Field(description="The type of the image.")
|
image_origin: ResourceOrigin = Field(description="The type of the image.")
|
||||||
"""The type of the image."""
|
"""The origin of the image."""
|
||||||
image_url: str = Field(description="The URL of the image.")
|
image_url: str = Field(description="The URL of the image.")
|
||||||
"""The URL of the image."""
|
"""The URL of the image."""
|
||||||
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||||
@ -84,7 +110,9 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
|||||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||||
|
|
||||||
image_name = image_dict.get("image_name", "unknown")
|
image_name = image_dict.get("image_name", "unknown")
|
||||||
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
|
image_origin = ResourceOrigin(
|
||||||
|
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
|
||||||
|
)
|
||||||
image_category = ImageCategory(
|
image_category = ImageCategory(
|
||||||
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
||||||
)
|
)
|
||||||
@ -95,6 +123,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
|||||||
created_at = image_dict.get("created_at", get_iso_timestamp())
|
created_at = image_dict.get("created_at", get_iso_timestamp())
|
||||||
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
||||||
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||||
|
is_intermediate = image_dict.get("is_intermediate", False)
|
||||||
|
|
||||||
raw_metadata = image_dict.get("metadata")
|
raw_metadata = image_dict.get("metadata")
|
||||||
|
|
||||||
@ -105,7 +134,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
|||||||
|
|
||||||
return ImageRecord(
|
return ImageRecord(
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image_type=image_type,
|
image_origin=image_origin,
|
||||||
image_category=image_category,
|
image_category=image_category,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
@ -115,4 +144,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
|||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
updated_at=updated_at,
|
updated_at=updated_at,
|
||||||
deleted_at=deleted_at,
|
deleted_at=deleted_at,
|
||||||
|
is_intermediate=is_intermediate,
|
||||||
)
|
)
|
||||||
|
30
invokeai/app/services/resource_name.py
Normal file
30
invokeai/app/services/resource_name.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum, EnumMeta
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceType(str, Enum, metaclass=EnumMeta):
|
||||||
|
"""Enum for resource types."""
|
||||||
|
|
||||||
|
IMAGE = "image"
|
||||||
|
LATENT = "latent"
|
||||||
|
|
||||||
|
|
||||||
|
class NameServiceBase(ABC):
|
||||||
|
"""Low-level service responsible for naming resources (images, latents, etc)."""
|
||||||
|
|
||||||
|
# TODO: Add customizable naming schemes
|
||||||
|
@abstractmethod
|
||||||
|
def create_image_name(self) -> str:
|
||||||
|
"""Creates a name for an image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleNameService(NameServiceBase):
|
||||||
|
"""Creates image names from UUIDs."""
|
||||||
|
|
||||||
|
# TODO: Add customizable naming schemes
|
||||||
|
def create_image_name(self) -> str:
|
||||||
|
uuid_str = str(uuid.uuid4())
|
||||||
|
filename = f"{uuid_str}.png"
|
||||||
|
return filename
|
@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageType
|
from invokeai.app.models.image import ResourceOrigin
|
||||||
from invokeai.app.util.thumbnails import get_thumbnail_name
|
from invokeai.app.util.thumbnails import get_thumbnail_name
|
||||||
|
|
||||||
|
|
||||||
@ -10,7 +10,7 @@ class UrlServiceBase(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_image_url(
|
def get_image_url(
|
||||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Gets the URL for an image or thumbnail."""
|
"""Gets the URL for an image or thumbnail."""
|
||||||
pass
|
pass
|
||||||
@ -21,14 +21,14 @@ class LocalUrlService(UrlServiceBase):
|
|||||||
self._base_url = base_url
|
self._base_url = base_url
|
||||||
|
|
||||||
def get_image_url(
|
def get_image_url(
|
||||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
image_basename = os.path.basename(image_name)
|
image_basename = os.path.basename(image_name)
|
||||||
|
|
||||||
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
||||||
if thumbnail:
|
if thumbnail:
|
||||||
return (
|
return (
|
||||||
f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail"
|
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
|
||||||
)
|
)
|
||||||
|
|
||||||
return f"{self._base_url}/images/{image_type.value}/{image_basename}"
|
return f"{self._base_url}/images/{image_origin.value}/{image_basename}"
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from invokeai.app.api.models.images import ProgressImage
|
|
||||||
from invokeai.app.models.exceptions import CanceledException
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
|
from invokeai.app.models.image import ProgressImage
|
||||||
from ..invocations.baseinvocation import InvocationContext
|
from ..invocations.baseinvocation import InvocationContext
|
||||||
from ...backend.util.util import image_to_dataURL
|
from ...backend.util.util import image_to_dataURL
|
||||||
from ...backend.generator.base import Generator
|
from ...backend.generator.base import Generator
|
||||||
|
@ -75,9 +75,11 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_info: dict,
|
model_info: dict,
|
||||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.model_info=model_info
|
self.model_info=model_info
|
||||||
self.params=params
|
self.params=params
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
prompt: str='',
|
prompt: str='',
|
||||||
@ -120,7 +122,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
)
|
)
|
||||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
||||||
gen_class = self._generator_class()
|
gen_class = self._generator_class()
|
||||||
generator = gen_class(model, self.params.precision)
|
generator = gen_class(model, self.params.precision, **self.kwargs)
|
||||||
if self.params.variation_amount > 0:
|
if self.params.variation_amount > 0:
|
||||||
generator.set_variation(generator_args.get('seed'),
|
generator.set_variation(generator_args.get('seed'),
|
||||||
generator_args.get('variation_amount'),
|
generator_args.get('variation_amount'),
|
||||||
@ -275,7 +277,7 @@ class Generator:
|
|||||||
precision: str
|
precision: str
|
||||||
model: DiffusionPipeline
|
model: DiffusionPipeline
|
||||||
|
|
||||||
def __init__(self, model: DiffusionPipeline, precision: str):
|
def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
self.seed = None
|
self.seed = None
|
||||||
|
@ -4,6 +4,10 @@ invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
|
|||||||
import PIL.Image
|
import PIL.Image
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
||||||
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
||||||
|
|
||||||
from ..stable_diffusion import (
|
from ..stable_diffusion import (
|
||||||
ConditioningData,
|
ConditioningData,
|
||||||
PostprocessingSettings,
|
PostprocessingSettings,
|
||||||
@ -13,8 +17,13 @@ from .base import Generator
|
|||||||
|
|
||||||
|
|
||||||
class Txt2Img(Generator):
|
class Txt2Img(Generator):
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision,
|
||||||
super().__init__(model, precision)
|
control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
|
||||||
|
**kwargs):
|
||||||
|
self.control_model = control_model
|
||||||
|
if isinstance(self.control_model, list):
|
||||||
|
self.control_model = MultiControlNetModel(self.control_model)
|
||||||
|
super().__init__(model, precision, **kwargs)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_make_image(
|
def get_make_image(
|
||||||
@ -42,9 +51,12 @@ class Txt2Img(Generator):
|
|||||||
kwargs are 'width' and 'height'
|
kwargs are 'width' and 'height'
|
||||||
"""
|
"""
|
||||||
self.perlin = perlin
|
self.perlin = perlin
|
||||||
|
control_image = kwargs.get("control_image", None)
|
||||||
|
do_classifier_free_guidance = cfg_scale > 1.0
|
||||||
|
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||||
|
pipeline.control_model = self.control_model
|
||||||
pipeline.scheduler = sampler
|
pipeline.scheduler = sampler
|
||||||
|
|
||||||
uc, c, extra_conditioning_info = conditioning
|
uc, c, extra_conditioning_info = conditioning
|
||||||
@ -61,6 +73,37 @@ class Txt2Img(Generator):
|
|||||||
),
|
),
|
||||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||||
|
|
||||||
|
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||||
|
# and add in batch_size, num_images_per_prompt?
|
||||||
|
if control_image is not None:
|
||||||
|
if isinstance(self.control_model, ControlNetModel):
|
||||||
|
control_image = pipeline.prepare_control_image(
|
||||||
|
image=control_image,
|
||||||
|
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
# batch_size=batch_size * num_images_per_prompt,
|
||||||
|
# num_images_per_prompt=num_images_per_prompt,
|
||||||
|
device=self.control_model.device,
|
||||||
|
dtype=self.control_model.dtype,
|
||||||
|
)
|
||||||
|
elif isinstance(self.control_model, MultiControlNetModel):
|
||||||
|
images = []
|
||||||
|
for image_ in control_image:
|
||||||
|
image_ = pipeline.prepare_control_image(
|
||||||
|
image=image_,
|
||||||
|
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
# batch_size=batch_size * num_images_per_prompt,
|
||||||
|
# num_images_per_prompt=num_images_per_prompt,
|
||||||
|
device=self.control_model.device,
|
||||||
|
dtype=self.control_model.dtype,
|
||||||
|
)
|
||||||
|
images.append(image_)
|
||||||
|
control_image = images
|
||||||
|
kwargs["control_image"] = control_image
|
||||||
|
|
||||||
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
|
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
|
||||||
pipeline_output = pipeline.image_from_embeddings(
|
pipeline_output = pipeline.image_from_embeddings(
|
||||||
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
|
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
|
||||||
@ -68,6 +111,7 @@ class Txt2Img(Generator):
|
|||||||
num_inference_steps=steps,
|
num_inference_steps=steps,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
callback=step_callback,
|
callback=step_callback,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -2,23 +2,29 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import inspect
|
import inspect
|
||||||
|
import math
|
||||||
import secrets
|
import secrets
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
|
import numpy as np
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
from compel import EmbeddingsProvider
|
from compel import EmbeddingsProvider
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
|
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
)
|
)
|
||||||
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
||||||
|
|
||||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
|
||||||
StableDiffusionImg2ImgPipeline,
|
StableDiffusionImg2ImgPipeline,
|
||||||
)
|
)
|
||||||
@ -27,6 +33,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
|
|||||||
)
|
)
|
||||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||||
|
from diffusers.utils import PIL_INTERPOLATION
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
from diffusers.utils.outputs import BaseOutput
|
from diffusers.utils.outputs import BaseOutput
|
||||||
from torchvision.transforms.functional import resize as tv_resize
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
@ -68,10 +75,10 @@ class AddsMaskLatents:
|
|||||||
initial_image_latents: torch.Tensor
|
initial_image_latents: torch.Tensor
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor
|
self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor, **kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
model_input = self.add_mask_channels(latents)
|
model_input = self.add_mask_channels(latents)
|
||||||
return self.forward(model_input, t, text_embeddings)
|
return self.forward(model_input, t, text_embeddings, **kwargs)
|
||||||
|
|
||||||
def add_mask_channels(self, latents):
|
def add_mask_channels(self, latents):
|
||||||
batch_size = latents.size(0)
|
batch_size = latents.size(0)
|
||||||
@ -207,6 +214,13 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
|
|||||||
raise AssertionError("why was that an empty generator?")
|
raise AssertionError("why was that an empty generator?")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ControlNetData:
|
||||||
|
model: ControlNetModel = Field(default=None)
|
||||||
|
image_tensor: torch.Tensor= Field(default=None)
|
||||||
|
weight: float = Field(default=1.0)
|
||||||
|
begin_step_percent: float = Field(default=0.0)
|
||||||
|
end_step_percent: float = Field(default=1.0)
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ConditioningData:
|
class ConditioningData:
|
||||||
@ -302,6 +316,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
feature_extractor: Optional[CLIPFeatureExtractor],
|
feature_extractor: Optional[CLIPFeatureExtractor],
|
||||||
requires_safety_checker: bool = False,
|
requires_safety_checker: bool = False,
|
||||||
precision: str = "float32",
|
precision: str = "float32",
|
||||||
|
control_model: ControlNetModel = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
vae,
|
vae,
|
||||||
@ -322,6 +337,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
safety_checker=safety_checker,
|
safety_checker=safety_checker,
|
||||||
feature_extractor=feature_extractor,
|
feature_extractor=feature_extractor,
|
||||||
|
# FIXME: can't currently register control module
|
||||||
|
# control_model=control_model,
|
||||||
)
|
)
|
||||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(
|
self.invokeai_diffuser = InvokeAIDiffuserComponent(
|
||||||
self.unet, self._unet_forward, is_running_diffusers=True
|
self.unet, self._unet_forward, is_running_diffusers=True
|
||||||
@ -341,6 +358,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
self._model_group = FullyLoadedModelGroup(self.unet.device)
|
self._model_group = FullyLoadedModelGroup(self.unet.device)
|
||||||
self._model_group.install(*self._submodels)
|
self._model_group.install(*self._submodels)
|
||||||
|
self.control_model = control_model
|
||||||
|
|
||||||
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
|
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
@ -463,6 +481,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
noise: torch.Tensor,
|
noise: torch.Tensor,
|
||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
run_id=None,
|
run_id=None,
|
||||||
|
**kwargs,
|
||||||
) -> InvokeAIStableDiffusionPipelineOutput:
|
) -> InvokeAIStableDiffusionPipelineOutput:
|
||||||
r"""
|
r"""
|
||||||
Function invoked when calling the pipeline for generation.
|
Function invoked when calling the pipeline for generation.
|
||||||
@ -483,6 +502,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
noise=noise,
|
noise=noise,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -507,6 +527,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
run_id=None,
|
run_id=None,
|
||||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||||
|
control_data: List[ControlNetData] = None,
|
||||||
|
**kwargs,
|
||||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||||
if self.scheduler.config.get("cpu_only", False):
|
if self.scheduler.config.get("cpu_only", False):
|
||||||
scheduler_device = torch.device('cpu')
|
scheduler_device = torch.device('cpu')
|
||||||
@ -527,6 +549,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
|
control_data=control_data,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
return result.latents, result.attention_map_saver
|
return result.latents, result.attention_map_saver
|
||||||
|
|
||||||
@ -539,6 +563,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
noise: torch.Tensor,
|
noise: torch.Tensor,
|
||||||
run_id: str = None,
|
run_id: str = None,
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
|
control_data: List[ControlNetData] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
self._adjust_memory_efficient_attention(latents)
|
self._adjust_memory_efficient_attention(latents)
|
||||||
if run_id is None:
|
if run_id is None:
|
||||||
@ -568,7 +594,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||||
|
|
||||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||||
|
# print("timesteps:", timesteps)
|
||||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
batched_t.fill_(t)
|
batched_t.fill_(t)
|
||||||
step_output = self.step(
|
step_output = self.step(
|
||||||
@ -578,6 +604,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
step_index=i,
|
step_index=i,
|
||||||
total_step_count=len(timesteps),
|
total_step_count=len(timesteps),
|
||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
|
control_data=control_data,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
latents = step_output.prev_sample
|
latents = step_output.prev_sample
|
||||||
|
|
||||||
@ -618,10 +646,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
step_index: int,
|
step_index: int,
|
||||||
total_step_count: int,
|
total_step_count: int,
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
|
control_data: List[ControlNetData] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||||
timestep = t[0]
|
timestep = t[0]
|
||||||
|
|
||||||
if additional_guidance is None:
|
if additional_guidance is None:
|
||||||
additional_guidance = []
|
additional_guidance = []
|
||||||
|
|
||||||
@ -629,6 +658,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||||
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
||||||
|
|
||||||
|
# default is no controlnet, so set controlnet processing output to None
|
||||||
|
down_block_res_samples, mid_block_res_sample = None, None
|
||||||
|
|
||||||
|
if control_data is not None:
|
||||||
|
if conditioning_data.guidance_scale > 1.0:
|
||||||
|
# expand the latents input to control model if doing classifier free guidance
|
||||||
|
# (which I think for now is always true, there is conditional elsewhere that stops execution if
|
||||||
|
# classifier_free_guidance is <= 1.0 ?)
|
||||||
|
latent_control_input = torch.cat([latent_model_input] * 2)
|
||||||
|
else:
|
||||||
|
latent_control_input = latent_model_input
|
||||||
|
# control_data should be type List[ControlNetData]
|
||||||
|
# this loop covers both ControlNet (one ControlNetData in list)
|
||||||
|
# and MultiControlNet (multiple ControlNetData in list)
|
||||||
|
for i, control_datum in enumerate(control_data):
|
||||||
|
# print("controlnet", i, "==>", type(control_datum))
|
||||||
|
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
|
||||||
|
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
|
||||||
|
# only apply controlnet if current step is within the controlnet's begin/end step range
|
||||||
|
if step_index >= first_control_step and step_index <= last_control_step:
|
||||||
|
# print("running controlnet", i, "for step", step_index)
|
||||||
|
down_samples, mid_sample = control_datum.model(
|
||||||
|
sample=latent_control_input,
|
||||||
|
timestep=timestep,
|
||||||
|
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
|
||||||
|
conditioning_data.text_embeddings]),
|
||||||
|
controlnet_cond=control_datum.image_tensor,
|
||||||
|
conditioning_scale=control_datum.weight,
|
||||||
|
# cross_attention_kwargs,
|
||||||
|
guess_mode=False,
|
||||||
|
return_dict=False,
|
||||||
|
)
|
||||||
|
if down_block_res_samples is None and mid_block_res_sample is None:
|
||||||
|
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
|
||||||
|
else:
|
||||||
|
# add controlnet outputs together if have multiple controlnets
|
||||||
|
down_block_res_samples = [
|
||||||
|
samples_prev + samples_curr
|
||||||
|
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
|
||||||
|
]
|
||||||
|
mid_block_res_sample += mid_sample
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
noise_pred = self.invokeai_diffuser.do_diffusion_step(
|
noise_pred = self.invokeai_diffuser.do_diffusion_step(
|
||||||
latent_model_input,
|
latent_model_input,
|
||||||
@ -638,6 +709,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
conditioning_data.guidance_scale,
|
conditioning_data.guidance_scale,
|
||||||
step_index=step_index,
|
step_index=step_index,
|
||||||
total_step_count=total_step_count,
|
total_step_count=total_step_count,
|
||||||
|
down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
|
||||||
|
mid_block_additional_residual=mid_block_res_sample, # from controlnet(s)
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
@ -659,6 +732,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
t,
|
t,
|
||||||
text_embeddings,
|
text_embeddings,
|
||||||
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
cross_attention_kwargs: Optional[dict[str, Any]] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""predict the noise residual"""
|
"""predict the noise residual"""
|
||||||
if is_inpainting_model(self.unet) and latents.size(1) == 4:
|
if is_inpainting_model(self.unet) and latents.size(1) == 4:
|
||||||
@ -678,7 +752,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
# First three args should be positional, not keywords, so torch hooks can see them.
|
# First three args should be positional, not keywords, so torch hooks can see them.
|
||||||
return self.unet(
|
return self.unet(
|
||||||
latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs
|
latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
**kwargs,
|
||||||
).sample
|
).sample
|
||||||
|
|
||||||
def img2img_from_embeddings(
|
def img2img_from_embeddings(
|
||||||
@ -940,3 +1015,51 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
debug_image(
|
debug_image(
|
||||||
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
|
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
|
||||||
|
# Returns torch.Tensor of shape (batch_size, 3, height, width)
|
||||||
|
@staticmethod
|
||||||
|
def prepare_control_image(
|
||||||
|
image,
|
||||||
|
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
|
||||||
|
# latents,
|
||||||
|
width=512, # should be 8 * latent.shape[3]
|
||||||
|
height=512, # should be 8 * latent height[2]
|
||||||
|
batch_size=1,
|
||||||
|
num_images_per_prompt=1,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.float16,
|
||||||
|
do_classifier_free_guidance=True,
|
||||||
|
):
|
||||||
|
|
||||||
|
if not isinstance(image, torch.Tensor):
|
||||||
|
if isinstance(image, PIL.Image.Image):
|
||||||
|
image = [image]
|
||||||
|
|
||||||
|
if isinstance(image[0], PIL.Image.Image):
|
||||||
|
images = []
|
||||||
|
for image_ in image:
|
||||||
|
image_ = image_.convert("RGB")
|
||||||
|
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||||
|
image_ = np.array(image_)
|
||||||
|
image_ = image_[None, :]
|
||||||
|
images.append(image_)
|
||||||
|
image = images
|
||||||
|
image = np.concatenate(image, axis=0)
|
||||||
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
image = image.transpose(0, 3, 1, 2)
|
||||||
|
image = torch.from_numpy(image)
|
||||||
|
elif isinstance(image[0], torch.Tensor):
|
||||||
|
image = torch.cat(image, dim=0)
|
||||||
|
|
||||||
|
image_batch_size = image.shape[0]
|
||||||
|
if image_batch_size == 1:
|
||||||
|
repeat_by = batch_size
|
||||||
|
else:
|
||||||
|
# image batch size is the same as prompt batch size
|
||||||
|
repeat_by = num_images_per_prompt
|
||||||
|
image = image.repeat_interleave(repeat_by, dim=0)
|
||||||
|
image = image.to(device=device, dtype=dtype)
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
image = torch.cat([image] * 2)
|
||||||
|
return image
|
||||||
|
@ -181,6 +181,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditional_guidance_scale: float,
|
unconditional_guidance_scale: float,
|
||||||
step_index: Optional[int] = None,
|
step_index: Optional[int] = None,
|
||||||
total_step_count: Optional[int] = None,
|
total_step_count: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param x: current latents
|
:param x: current latents
|
||||||
@ -209,7 +210,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
if wants_hybrid_conditioning:
|
if wants_hybrid_conditioning:
|
||||||
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
|
unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
|
||||||
x, sigma, unconditioning, conditioning
|
x, sigma, unconditioning, conditioning, **kwargs,
|
||||||
)
|
)
|
||||||
elif wants_cross_attention_control:
|
elif wants_cross_attention_control:
|
||||||
(
|
(
|
||||||
@ -221,13 +222,14 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioning,
|
unconditioning,
|
||||||
conditioning,
|
conditioning,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
elif self.sequential_guidance:
|
elif self.sequential_guidance:
|
||||||
(
|
(
|
||||||
unconditioned_next_x,
|
unconditioned_next_x,
|
||||||
conditioned_next_x,
|
conditioned_next_x,
|
||||||
) = self._apply_standard_conditioning_sequentially(
|
) = self._apply_standard_conditioning_sequentially(
|
||||||
x, sigma, unconditioning, conditioning
|
x, sigma, unconditioning, conditioning, **kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -235,7 +237,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioned_next_x,
|
unconditioned_next_x,
|
||||||
conditioned_next_x,
|
conditioned_next_x,
|
||||||
) = self._apply_standard_conditioning(
|
) = self._apply_standard_conditioning(
|
||||||
x, sigma, unconditioning, conditioning
|
x, sigma, unconditioning, conditioning, **kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
combined_next_x = self._combine(
|
combined_next_x = self._combine(
|
||||||
@ -282,13 +284,13 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||||
|
|
||||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
|
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||||
# fast batched path
|
# fast batched path
|
||||||
x_twice = torch.cat([x] * 2)
|
x_twice = torch.cat([x] * 2)
|
||||||
sigma_twice = torch.cat([sigma] * 2)
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||||
both_results = self.model_forward_callback(
|
both_results = self.model_forward_callback(
|
||||||
x_twice, sigma_twice, both_conditionings
|
x_twice, sigma_twice, both_conditionings, **kwargs,
|
||||||
)
|
)
|
||||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||||
if conditioned_next_x.device.type == "mps":
|
if conditioned_next_x.device.type == "mps":
|
||||||
@ -302,16 +304,17 @@ class InvokeAIDiffuserComponent:
|
|||||||
sigma,
|
sigma,
|
||||||
unconditioning: torch.Tensor,
|
unconditioning: torch.Tensor,
|
||||||
conditioning: torch.Tensor,
|
conditioning: torch.Tensor,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
# low-memory sequential path
|
# low-memory sequential path
|
||||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
|
||||||
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
|
conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
|
||||||
if conditioned_next_x.device.type == "mps":
|
if conditioned_next_x.device.type == "mps":
|
||||||
# prevent a result filled with zeros. seems to be a torch bug.
|
# prevent a result filled with zeros. seems to be a torch bug.
|
||||||
conditioned_next_x = conditioned_next_x.clone()
|
conditioned_next_x = conditioned_next_x.clone()
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
|
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||||
assert isinstance(conditioning, dict)
|
assert isinstance(conditioning, dict)
|
||||||
assert isinstance(unconditioning, dict)
|
assert isinstance(unconditioning, dict)
|
||||||
x_twice = torch.cat([x] * 2)
|
x_twice = torch.cat([x] * 2)
|
||||||
@ -326,7 +329,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
else:
|
else:
|
||||||
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
|
both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
|
||||||
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
|
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
|
||||||
x_twice, sigma_twice, both_conditionings
|
x_twice, sigma_twice, both_conditionings, **kwargs,
|
||||||
).chunk(2)
|
).chunk(2)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
@ -337,6 +340,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioning,
|
unconditioning,
|
||||||
conditioning,
|
conditioning,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
if self.is_running_diffusers:
|
if self.is_running_diffusers:
|
||||||
return self._apply_cross_attention_controlled_conditioning__diffusers(
|
return self._apply_cross_attention_controlled_conditioning__diffusers(
|
||||||
@ -345,6 +349,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioning,
|
unconditioning,
|
||||||
conditioning,
|
conditioning,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self._apply_cross_attention_controlled_conditioning__compvis(
|
return self._apply_cross_attention_controlled_conditioning__compvis(
|
||||||
@ -353,6 +358,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioning,
|
unconditioning,
|
||||||
conditioning,
|
conditioning,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _apply_cross_attention_controlled_conditioning__diffusers(
|
def _apply_cross_attention_controlled_conditioning__diffusers(
|
||||||
@ -362,6 +368,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioning,
|
unconditioning,
|
||||||
conditioning,
|
conditioning,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
context: Context = self.cross_attention_control_context
|
context: Context = self.cross_attention_control_context
|
||||||
|
|
||||||
@ -377,6 +384,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
sigma,
|
sigma,
|
||||||
unconditioning,
|
unconditioning,
|
||||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
{"swap_cross_attn_context": cross_attn_processor_context},
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# do requested cross attention types for conditioning (positive prompt)
|
# do requested cross attention types for conditioning (positive prompt)
|
||||||
@ -388,6 +396,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
sigma,
|
sigma,
|
||||||
conditioning,
|
conditioning,
|
||||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
{"swap_cross_attn_context": cross_attn_processor_context},
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
@ -398,6 +407,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioning,
|
unconditioning,
|
||||||
conditioning,
|
conditioning,
|
||||||
cross_attention_control_types_to_do,
|
cross_attention_control_types_to_do,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
||||||
# slower non-batched path (20% slower on mac MPS)
|
# slower non-batched path (20% slower on mac MPS)
|
||||||
@ -411,13 +421,13 @@ class InvokeAIDiffuserComponent:
|
|||||||
context: Context = self.cross_attention_control_context
|
context: Context = self.cross_attention_control_context
|
||||||
|
|
||||||
try:
|
try:
|
||||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
|
||||||
|
|
||||||
# process x using the original prompt, saving the attention maps
|
# process x using the original prompt, saving the attention maps
|
||||||
# print("saving attention maps for", cross_attention_control_types_to_do)
|
# print("saving attention maps for", cross_attention_control_types_to_do)
|
||||||
for ca_type in cross_attention_control_types_to_do:
|
for ca_type in cross_attention_control_types_to_do:
|
||||||
context.request_save_attention_maps(ca_type)
|
context.request_save_attention_maps(ca_type)
|
||||||
_ = self.model_forward_callback(x, sigma, conditioning)
|
_ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
|
||||||
context.clear_requests(cleanup=False)
|
context.clear_requests(cleanup=False)
|
||||||
|
|
||||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
||||||
@ -428,7 +438,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
self.conditioning.cross_attention_control_args.edited_conditioning
|
self.conditioning.cross_attention_control_args.edited_conditioning
|
||||||
)
|
)
|
||||||
conditioned_next_x = self.model_forward_callback(
|
conditioned_next_x = self.model_forward_callback(
|
||||||
x, sigma, edited_conditioning
|
x, sigma, edited_conditioning, **kwargs,
|
||||||
)
|
)
|
||||||
context.clear_requests(cleanup=True)
|
context.clear_requests(cleanup=True)
|
||||||
|
|
||||||
|
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@ -12,7 +12,7 @@
|
|||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
<script type="module" crossorigin src="./assets/index-5fb14ef2.js"></script>
|
<script type="module" crossorigin src="./assets/index-251c2c6e.js"></script>
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
<body dir="ltr">
|
<body dir="ltr">
|
||||||
|
10
invokeai/frontend/web/dist/locales/en.json
vendored
10
invokeai/frontend/web/dist/locales/en.json
vendored
@ -122,7 +122,9 @@
|
|||||||
"noImagesInGallery": "No Images In Gallery",
|
"noImagesInGallery": "No Images In Gallery",
|
||||||
"deleteImage": "Delete Image",
|
"deleteImage": "Delete Image",
|
||||||
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
|
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
|
||||||
"deleteImagePermanent": "Deleted images cannot be restored."
|
"deleteImagePermanent": "Deleted images cannot be restored.",
|
||||||
|
"images": "Images",
|
||||||
|
"assets": "Assets"
|
||||||
},
|
},
|
||||||
"hotkeys": {
|
"hotkeys": {
|
||||||
"keyboardShortcuts": "Keyboard Shortcuts",
|
"keyboardShortcuts": "Keyboard Shortcuts",
|
||||||
@ -452,6 +454,8 @@
|
|||||||
"height": "Height",
|
"height": "Height",
|
||||||
"scheduler": "Scheduler",
|
"scheduler": "Scheduler",
|
||||||
"seed": "Seed",
|
"seed": "Seed",
|
||||||
|
"boundingBoxWidth": "Bounding Box Width",
|
||||||
|
"boundingBoxHeight": "Bounding Box Height",
|
||||||
"imageToImage": "Image to Image",
|
"imageToImage": "Image to Image",
|
||||||
"randomizeSeed": "Randomize Seed",
|
"randomizeSeed": "Randomize Seed",
|
||||||
"shuffle": "Shuffle Seed",
|
"shuffle": "Shuffle Seed",
|
||||||
@ -524,7 +528,7 @@
|
|||||||
},
|
},
|
||||||
"settings": {
|
"settings": {
|
||||||
"models": "Models",
|
"models": "Models",
|
||||||
"displayInProgress": "Display In-Progress Images",
|
"displayInProgress": "Display Progress Images",
|
||||||
"saveSteps": "Save images every n steps",
|
"saveSteps": "Save images every n steps",
|
||||||
"confirmOnDelete": "Confirm On Delete",
|
"confirmOnDelete": "Confirm On Delete",
|
||||||
"displayHelpIcons": "Display Help Icons",
|
"displayHelpIcons": "Display Help Icons",
|
||||||
@ -564,6 +568,8 @@
|
|||||||
"canvasMerged": "Canvas Merged",
|
"canvasMerged": "Canvas Merged",
|
||||||
"sentToImageToImage": "Sent To Image To Image",
|
"sentToImageToImage": "Sent To Image To Image",
|
||||||
"sentToUnifiedCanvas": "Sent to Unified Canvas",
|
"sentToUnifiedCanvas": "Sent to Unified Canvas",
|
||||||
|
"parameterSet": "Parameter set",
|
||||||
|
"parameterNotSet": "Parameter not set",
|
||||||
"parametersSet": "Parameters Set",
|
"parametersSet": "Parameters Set",
|
||||||
"parametersNotSet": "Parameters Not Set",
|
"parametersNotSet": "Parameters Not Set",
|
||||||
"parametersNotSetDesc": "No metadata found for this image.",
|
"parametersNotSetDesc": "No metadata found for this image.",
|
||||||
|
@ -101,7 +101,8 @@
|
|||||||
"serialize-error": "^11.0.0",
|
"serialize-error": "^11.0.0",
|
||||||
"socket.io-client": "^4.6.0",
|
"socket.io-client": "^4.6.0",
|
||||||
"use-image": "^1.1.0",
|
"use-image": "^1.1.0",
|
||||||
"uuid": "^9.0.0"
|
"uuid": "^9.0.0",
|
||||||
|
"zod": "^3.21.4"
|
||||||
},
|
},
|
||||||
"peerDependencies": {
|
"peerDependencies": {
|
||||||
"@chakra-ui/cli": "^2.4.0",
|
"@chakra-ui/cli": "^2.4.0",
|
||||||
|
@ -122,7 +122,9 @@
|
|||||||
"noImagesInGallery": "No Images In Gallery",
|
"noImagesInGallery": "No Images In Gallery",
|
||||||
"deleteImage": "Delete Image",
|
"deleteImage": "Delete Image",
|
||||||
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
|
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
|
||||||
"deleteImagePermanent": "Deleted images cannot be restored."
|
"deleteImagePermanent": "Deleted images cannot be restored.",
|
||||||
|
"images": "Images",
|
||||||
|
"assets": "Assets"
|
||||||
},
|
},
|
||||||
"hotkeys": {
|
"hotkeys": {
|
||||||
"keyboardShortcuts": "Keyboard Shortcuts",
|
"keyboardShortcuts": "Keyboard Shortcuts",
|
||||||
@ -452,6 +454,8 @@
|
|||||||
"height": "Height",
|
"height": "Height",
|
||||||
"scheduler": "Scheduler",
|
"scheduler": "Scheduler",
|
||||||
"seed": "Seed",
|
"seed": "Seed",
|
||||||
|
"boundingBoxWidth": "Bounding Box Width",
|
||||||
|
"boundingBoxHeight": "Bounding Box Height",
|
||||||
"imageToImage": "Image to Image",
|
"imageToImage": "Image to Image",
|
||||||
"randomizeSeed": "Randomize Seed",
|
"randomizeSeed": "Randomize Seed",
|
||||||
"shuffle": "Shuffle Seed",
|
"shuffle": "Shuffle Seed",
|
||||||
@ -524,7 +528,7 @@
|
|||||||
},
|
},
|
||||||
"settings": {
|
"settings": {
|
||||||
"models": "Models",
|
"models": "Models",
|
||||||
"displayInProgress": "Display In-Progress Images",
|
"displayInProgress": "Display Progress Images",
|
||||||
"saveSteps": "Save images every n steps",
|
"saveSteps": "Save images every n steps",
|
||||||
"confirmOnDelete": "Confirm On Delete",
|
"confirmOnDelete": "Confirm On Delete",
|
||||||
"displayHelpIcons": "Display Help Icons",
|
"displayHelpIcons": "Display Help Icons",
|
||||||
@ -564,6 +568,8 @@
|
|||||||
"canvasMerged": "Canvas Merged",
|
"canvasMerged": "Canvas Merged",
|
||||||
"sentToImageToImage": "Sent To Image To Image",
|
"sentToImageToImage": "Sent To Image To Image",
|
||||||
"sentToUnifiedCanvas": "Sent to Unified Canvas",
|
"sentToUnifiedCanvas": "Sent to Unified Canvas",
|
||||||
|
"parameterSet": "Parameter set",
|
||||||
|
"parameterNotSet": "Parameter not set",
|
||||||
"parametersSet": "Parameters Set",
|
"parametersSet": "Parameters Set",
|
||||||
"parametersNotSet": "Parameters Not Set",
|
"parametersNotSet": "Parameters Not Set",
|
||||||
"parametersNotSetDesc": "No metadata found for this image.",
|
"parametersNotSetDesc": "No metadata found for this image.",
|
||||||
|
@ -21,25 +21,11 @@ export const SCHEDULERS = [
|
|||||||
|
|
||||||
export type Scheduler = (typeof SCHEDULERS)[number];
|
export type Scheduler = (typeof SCHEDULERS)[number];
|
||||||
|
|
||||||
export const isScheduler = (x: string): x is Scheduler =>
|
|
||||||
SCHEDULERS.includes(x as Scheduler);
|
|
||||||
|
|
||||||
// Valid image widths
|
|
||||||
export const WIDTHS: Array<number> = Array.from(Array(64)).map(
|
|
||||||
(_x, i) => (i + 1) * 64
|
|
||||||
);
|
|
||||||
|
|
||||||
// Valid image heights
|
|
||||||
export const HEIGHTS: Array<number> = Array.from(Array(64)).map(
|
|
||||||
(_x, i) => (i + 1) * 64
|
|
||||||
);
|
|
||||||
|
|
||||||
// Valid upscaling levels
|
// Valid upscaling levels
|
||||||
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
|
export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [
|
||||||
{ key: '2x', value: 2 },
|
{ key: '2x', value: 2 },
|
||||||
{ key: '4x', value: 4 },
|
{ key: '4x', value: 4 },
|
||||||
];
|
];
|
||||||
|
|
||||||
export const NUMPY_RAND_MIN = 0;
|
export const NUMPY_RAND_MIN = 0;
|
||||||
|
|
||||||
export const NUMPY_RAND_MAX = 2147483647;
|
export const NUMPY_RAND_MAX = 2147483647;
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
|
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
|
||||||
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
|
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
|
||||||
import { resultsPersistDenylist } from 'features/gallery/store/resultsPersistDenylist';
|
|
||||||
import { uploadsPersistDenylist } from 'features/gallery/store/uploadsPersistDenylist';
|
|
||||||
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
|
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
|
||||||
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||||
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
||||||
@ -22,11 +20,9 @@ const serializationDenylist: {
|
|||||||
models: modelsPersistDenylist,
|
models: modelsPersistDenylist,
|
||||||
nodes: nodesPersistDenylist,
|
nodes: nodesPersistDenylist,
|
||||||
postprocessing: postprocessingPersistDenylist,
|
postprocessing: postprocessingPersistDenylist,
|
||||||
results: resultsPersistDenylist,
|
|
||||||
system: systemPersistDenylist,
|
system: systemPersistDenylist,
|
||||||
// config: configPersistDenyList,
|
// config: configPersistDenyList,
|
||||||
ui: uiPersistDenylist,
|
ui: uiPersistDenylist,
|
||||||
uploads: uploadsPersistDenylist,
|
|
||||||
// hotkeys: hotkeysPersistDenylist,
|
// hotkeys: hotkeysPersistDenylist,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
|
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
|
||||||
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
|
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
|
||||||
import { initialResultsState } from 'features/gallery/store/resultsSlice';
|
import { initialImagesState } from 'features/gallery/store/imagesSlice';
|
||||||
import { initialUploadsState } from 'features/gallery/store/uploadsSlice';
|
|
||||||
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
|
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
|
||||||
import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
||||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||||
@ -24,12 +23,11 @@ const initialStates: {
|
|||||||
models: initialModelsState,
|
models: initialModelsState,
|
||||||
nodes: initialNodesState,
|
nodes: initialNodesState,
|
||||||
postprocessing: initialPostprocessingState,
|
postprocessing: initialPostprocessingState,
|
||||||
results: initialResultsState,
|
|
||||||
system: initialSystemState,
|
system: initialSystemState,
|
||||||
config: initialConfigState,
|
config: initialConfigState,
|
||||||
ui: initialUIState,
|
ui: initialUIState,
|
||||||
uploads: initialUploadsState,
|
|
||||||
hotkeys: initialHotkeysState,
|
hotkeys: initialHotkeysState,
|
||||||
|
images: initialImagesState,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const unserialize: UnserializeFunction = (data, key) => {
|
export const unserialize: UnserializeFunction = (data, key) => {
|
||||||
|
@ -7,5 +7,6 @@ export const actionsDenylist = [
|
|||||||
'canvas/setBoundingBoxDimensions',
|
'canvas/setBoundingBoxDimensions',
|
||||||
'canvas/setIsDrawing',
|
'canvas/setIsDrawing',
|
||||||
'canvas/addPointToCurrentLine',
|
'canvas/addPointToCurrentLine',
|
||||||
'socket/generatorProgress',
|
'socket/socketGeneratorProgress',
|
||||||
|
'socket/appSocketGeneratorProgress',
|
||||||
];
|
];
|
||||||
|
@ -8,9 +8,16 @@ import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit';
|
|||||||
|
|
||||||
import type { RootState, AppDispatch } from '../../store';
|
import type { RootState, AppDispatch } from '../../store';
|
||||||
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
|
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
|
||||||
import { addImageResultReceivedListener } from './listeners/invocationComplete';
|
import {
|
||||||
import { addImageUploadedListener } from './listeners/imageUploaded';
|
addImageUploadedFulfilledListener,
|
||||||
import { addRequestedImageDeletionListener } from './listeners/imageDeleted';
|
addImageUploadedRejectedListener,
|
||||||
|
} from './listeners/imageUploaded';
|
||||||
|
import {
|
||||||
|
addImageDeletedFulfilledListener,
|
||||||
|
addImageDeletedPendingListener,
|
||||||
|
addImageDeletedRejectedListener,
|
||||||
|
addRequestedImageDeletionListener,
|
||||||
|
} from './listeners/imageDeleted';
|
||||||
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
||||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||||
@ -19,6 +26,50 @@ import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGaller
|
|||||||
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
|
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
|
||||||
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
|
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
|
||||||
import { addCanvasMergedListener } from './listeners/canvasMerged';
|
import { addCanvasMergedListener } from './listeners/canvasMerged';
|
||||||
|
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
|
||||||
|
import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
|
||||||
|
import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
|
||||||
|
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
|
||||||
|
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
|
||||||
|
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
|
||||||
|
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
|
||||||
|
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
|
||||||
|
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
|
||||||
|
import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke';
|
||||||
|
import {
|
||||||
|
addImageMetadataReceivedFulfilledListener,
|
||||||
|
addImageMetadataReceivedRejectedListener,
|
||||||
|
} from './listeners/imageMetadataReceived';
|
||||||
|
import {
|
||||||
|
addImageUrlsReceivedFulfilledListener,
|
||||||
|
addImageUrlsReceivedRejectedListener,
|
||||||
|
} from './listeners/imageUrlsReceived';
|
||||||
|
import {
|
||||||
|
addSessionCreatedFulfilledListener,
|
||||||
|
addSessionCreatedPendingListener,
|
||||||
|
addSessionCreatedRejectedListener,
|
||||||
|
} from './listeners/sessionCreated';
|
||||||
|
import {
|
||||||
|
addSessionInvokedFulfilledListener,
|
||||||
|
addSessionInvokedPendingListener,
|
||||||
|
addSessionInvokedRejectedListener,
|
||||||
|
} from './listeners/sessionInvoked';
|
||||||
|
import {
|
||||||
|
addSessionCanceledFulfilledListener,
|
||||||
|
addSessionCanceledPendingListener,
|
||||||
|
addSessionCanceledRejectedListener,
|
||||||
|
} from './listeners/sessionCanceled';
|
||||||
|
import {
|
||||||
|
addImageUpdatedFulfilledListener,
|
||||||
|
addImageUpdatedRejectedListener,
|
||||||
|
} from './listeners/imageUpdated';
|
||||||
|
import {
|
||||||
|
addReceivedPageOfImagesFulfilledListener,
|
||||||
|
addReceivedPageOfImagesRejectedListener,
|
||||||
|
} from './listeners/receivedPageOfImages';
|
||||||
|
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
|
||||||
|
import { addCommitStagingAreaImageListener } from './listeners/addCommitStagingAreaImageListener';
|
||||||
|
import { addImageCategoriesChangedListener } from './listeners/imageCategoriesChanged';
|
||||||
|
|
||||||
export const listenerMiddleware = createListenerMiddleware();
|
export const listenerMiddleware = createListenerMiddleware();
|
||||||
|
|
||||||
@ -38,17 +89,87 @@ export type AppListenerEffect = ListenerEffect<
|
|||||||
AppDispatch
|
AppDispatch
|
||||||
>;
|
>;
|
||||||
|
|
||||||
addImageUploadedListener();
|
// Image uploaded
|
||||||
addInitialImageSelectedListener();
|
addImageUploadedFulfilledListener();
|
||||||
addImageResultReceivedListener();
|
addImageUploadedRejectedListener();
|
||||||
addRequestedImageDeletionListener();
|
|
||||||
|
|
||||||
|
// Image updated
|
||||||
|
addImageUpdatedFulfilledListener();
|
||||||
|
addImageUpdatedRejectedListener();
|
||||||
|
|
||||||
|
// Image selected
|
||||||
|
addInitialImageSelectedListener();
|
||||||
|
|
||||||
|
// Image deleted
|
||||||
|
addRequestedImageDeletionListener();
|
||||||
|
addImageDeletedPendingListener();
|
||||||
|
addImageDeletedFulfilledListener();
|
||||||
|
addImageDeletedRejectedListener();
|
||||||
|
|
||||||
|
// Image metadata
|
||||||
|
addImageMetadataReceivedFulfilledListener();
|
||||||
|
addImageMetadataReceivedRejectedListener();
|
||||||
|
|
||||||
|
// Image URLs
|
||||||
|
addImageUrlsReceivedFulfilledListener();
|
||||||
|
addImageUrlsReceivedRejectedListener();
|
||||||
|
|
||||||
|
// User Invoked
|
||||||
addUserInvokedCanvasListener();
|
addUserInvokedCanvasListener();
|
||||||
addUserInvokedNodesListener();
|
addUserInvokedNodesListener();
|
||||||
addUserInvokedTextToImageListener();
|
addUserInvokedTextToImageListener();
|
||||||
addUserInvokedImageToImageListener();
|
addUserInvokedImageToImageListener();
|
||||||
|
addSessionReadyToInvokeListener();
|
||||||
|
|
||||||
|
// Canvas actions
|
||||||
addCanvasSavedToGalleryListener();
|
addCanvasSavedToGalleryListener();
|
||||||
addCanvasDownloadedAsImageListener();
|
addCanvasDownloadedAsImageListener();
|
||||||
addCanvasCopiedToClipboardListener();
|
addCanvasCopiedToClipboardListener();
|
||||||
addCanvasMergedListener();
|
addCanvasMergedListener();
|
||||||
|
addStagingAreaImageSavedListener();
|
||||||
|
addCommitStagingAreaImageListener();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Socket.IO Events - these handle SIO events directly and pass on internal application actions.
|
||||||
|
* We don't handle SIO events in slices via `extraReducers` because some of these events shouldn't
|
||||||
|
* actually be handled at all.
|
||||||
|
*
|
||||||
|
* For example, we don't want to respond to progress events for canceled sessions. To avoid
|
||||||
|
* duplicating the logic to determine if an event should be responded to, we handle all of that
|
||||||
|
* "is this session canceled?" logic in these listeners.
|
||||||
|
*
|
||||||
|
* The `socketGeneratorProgress` listener will then only dispatch the `appSocketGeneratorProgress`
|
||||||
|
* action if it should be handled by the rest of the application. It is this `appSocketGeneratorProgress`
|
||||||
|
* action that is handled by reducers in slices.
|
||||||
|
*/
|
||||||
|
addGeneratorProgressListener();
|
||||||
|
addGraphExecutionStateCompleteListener();
|
||||||
|
addInvocationCompleteListener();
|
||||||
|
addInvocationErrorListener();
|
||||||
|
addInvocationStartedListener();
|
||||||
|
addSocketConnectedListener();
|
||||||
|
addSocketDisconnectedListener();
|
||||||
|
addSocketSubscribedListener();
|
||||||
|
addSocketUnsubscribedListener();
|
||||||
|
|
||||||
|
// Session Created
|
||||||
|
addSessionCreatedPendingListener();
|
||||||
|
addSessionCreatedFulfilledListener();
|
||||||
|
addSessionCreatedRejectedListener();
|
||||||
|
|
||||||
|
// Session Invoked
|
||||||
|
addSessionInvokedPendingListener();
|
||||||
|
addSessionInvokedFulfilledListener();
|
||||||
|
addSessionInvokedRejectedListener();
|
||||||
|
|
||||||
|
// Session Canceled
|
||||||
|
addSessionCanceledPendingListener();
|
||||||
|
addSessionCanceledFulfilledListener();
|
||||||
|
addSessionCanceledRejectedListener();
|
||||||
|
|
||||||
|
// Fetching images
|
||||||
|
addReceivedPageOfImagesFulfilledListener();
|
||||||
|
addReceivedPageOfImagesRejectedListener();
|
||||||
|
|
||||||
|
// Gallery
|
||||||
|
addImageCategoriesChangedListener();
|
||||||
|
@ -0,0 +1,42 @@
|
|||||||
|
import { startAppListening } from '..';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { commitStagingAreaImage } from 'features/canvas/store/canvasSlice';
|
||||||
|
import { sessionCanceled } from 'services/thunks/session';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'canvas' });
|
||||||
|
|
||||||
|
export const addCommitStagingAreaImageListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: commitStagingAreaImage,
|
||||||
|
effect: async (action, { dispatch, getState }) => {
|
||||||
|
const state = getState();
|
||||||
|
const { sessionId, isProcessing } = state.system;
|
||||||
|
const canvasSessionId = action.payload;
|
||||||
|
|
||||||
|
if (!isProcessing) {
|
||||||
|
// Only need to cancel if we are processing
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!canvasSessionId) {
|
||||||
|
moduleLog.debug('No canvas session, skipping cancel');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (canvasSessionId !== sessionId) {
|
||||||
|
moduleLog.debug(
|
||||||
|
{
|
||||||
|
data: {
|
||||||
|
canvasSessionId,
|
||||||
|
sessionId,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
'Canvas session does not match global session, skipping cancel'
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(sessionCanceled({ sessionId }));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -52,10 +52,11 @@ export const addCanvasMergedListener = () => {
|
|||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
imageUploaded({
|
imageUploaded({
|
||||||
imageType: 'intermediates',
|
|
||||||
formData: {
|
formData: {
|
||||||
file: new File([blob], filename, { type: 'image/png' }),
|
file: new File([blob], filename, { type: 'image/png' }),
|
||||||
},
|
},
|
||||||
|
imageCategory: 'general',
|
||||||
|
isIntermediate: true,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -65,7 +66,7 @@ export const addCanvasMergedListener = () => {
|
|||||||
action.meta.arg.formData.file.name === filename
|
action.meta.arg.formData.file.name === filename
|
||||||
);
|
);
|
||||||
|
|
||||||
const mergedCanvasImage = payload.response;
|
const mergedCanvasImage = payload;
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
setMergedCanvas({
|
setMergedCanvas({
|
||||||
|
@ -4,16 +4,18 @@ import { log } from 'app/logging/useLogger';
|
|||||||
import { imageUploaded } from 'services/thunks/image';
|
import { imageUploaded } from 'services/thunks/image';
|
||||||
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
|
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
|
||||||
|
|
||||||
export const addCanvasSavedToGalleryListener = () => {
|
export const addCanvasSavedToGalleryListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: canvasSavedToGallery,
|
actionCreator: canvasSavedToGallery,
|
||||||
effect: async (action, { dispatch, getState }) => {
|
effect: async (action, { dispatch, getState, take }) => {
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
const blob = await getBaseLayerBlob(state);
|
const blob = await getBaseLayerBlob(state, true);
|
||||||
|
|
||||||
if (!blob) {
|
if (!blob) {
|
||||||
moduleLog.error('Problem getting base layer blob');
|
moduleLog.error('Problem getting base layer blob');
|
||||||
@ -27,14 +29,25 @@ export const addCanvasSavedToGalleryListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const filename = `mergedCanvas_${uuidv4()}.png`;
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
imageUploaded({
|
imageUploaded({
|
||||||
imageType: 'results',
|
|
||||||
formData: {
|
formData: {
|
||||||
file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }),
|
file: new File([blob], filename, { type: 'image/png' }),
|
||||||
},
|
},
|
||||||
|
imageCategory: 'general',
|
||||||
|
isIntermediate: false,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const [{ payload: uploadedImageDTO }] = await take(
|
||||||
|
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||||
|
imageUploaded.fulfilled.match(action) &&
|
||||||
|
action.meta.arg.formData.file.name === filename
|
||||||
|
);
|
||||||
|
|
||||||
|
dispatch(imageUpserted(uploadedImageDTO));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,24 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { receivedPageOfImages } from 'services/thunks/image';
|
||||||
|
import {
|
||||||
|
imageCategoriesChanged,
|
||||||
|
selectFilteredImagesAsArray,
|
||||||
|
} from 'features/gallery/store/imagesSlice';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'gallery' });
|
||||||
|
|
||||||
|
export const addImageCategoriesChangedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: imageCategoriesChanged,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const filteredImagesCount = selectFilteredImagesAsArray(
|
||||||
|
getState()
|
||||||
|
).length;
|
||||||
|
|
||||||
|
if (!filteredImagesCount) {
|
||||||
|
dispatch(receivedPageOfImages());
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -4,9 +4,18 @@ import { imageDeleted } from 'services/thunks/image';
|
|||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { clamp } from 'lodash-es';
|
import { clamp } from 'lodash-es';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
|
import {
|
||||||
|
imageRemoved,
|
||||||
|
imagesAdapter,
|
||||||
|
selectImagesEntities,
|
||||||
|
selectImagesIds,
|
||||||
|
} from 'features/gallery/store/imagesSlice';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called when the user requests an image deletion
|
||||||
|
*/
|
||||||
export const addRequestedImageDeletionListener = () => {
|
export const addRequestedImageDeletionListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: requestedImageDeletion,
|
actionCreator: requestedImageDeletion,
|
||||||
@ -17,24 +26,20 @@ export const addRequestedImageDeletionListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { image_name, image_type } = image;
|
const { image_name, image_origin } = image;
|
||||||
|
|
||||||
if (image_type !== 'uploads' && image_type !== 'results') {
|
const state = getState();
|
||||||
moduleLog.warn({ data: image }, `Invalid image type ${image_type}`);
|
const selectedImage = state.gallery.selectedImage;
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const selectedImageName = getState().gallery.selectedImage?.image_name;
|
if (selectedImage && selectedImage.image_name === image_name) {
|
||||||
|
const ids = selectImagesIds(state);
|
||||||
|
const entities = selectImagesEntities(state);
|
||||||
|
|
||||||
if (selectedImageName === image_name) {
|
const deletedImageIndex = ids.findIndex(
|
||||||
const allIds = getState()[image_type].ids;
|
|
||||||
const allEntities = getState()[image_type].entities;
|
|
||||||
|
|
||||||
const deletedImageIndex = allIds.findIndex(
|
|
||||||
(result) => result.toString() === image_name
|
(result) => result.toString() === image_name
|
||||||
);
|
);
|
||||||
|
|
||||||
const filteredIds = allIds.filter((id) => id.toString() !== image_name);
|
const filteredIds = ids.filter((id) => id.toString() !== image_name);
|
||||||
|
|
||||||
const newSelectedImageIndex = clamp(
|
const newSelectedImageIndex = clamp(
|
||||||
deletedImageIndex,
|
deletedImageIndex,
|
||||||
@ -44,7 +49,7 @@ export const addRequestedImageDeletionListener = () => {
|
|||||||
|
|
||||||
const newSelectedImageId = filteredIds[newSelectedImageIndex];
|
const newSelectedImageId = filteredIds[newSelectedImageIndex];
|
||||||
|
|
||||||
const newSelectedImage = allEntities[newSelectedImageId];
|
const newSelectedImage = entities[newSelectedImageId];
|
||||||
|
|
||||||
if (newSelectedImageId) {
|
if (newSelectedImageId) {
|
||||||
dispatch(imageSelected(newSelectedImage));
|
dispatch(imageSelected(newSelectedImage));
|
||||||
@ -53,7 +58,52 @@ export const addRequestedImageDeletionListener = () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(imageDeleted({ imageName: image_name, imageType: image_type }));
|
dispatch(imageRemoved(image_name));
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
imageDeleted({ imageName: image_name, imageOrigin: image_origin })
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called when the actual delete request is sent to the server
|
||||||
|
*/
|
||||||
|
export const addImageDeletedPendingListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: imageDeleted.pending,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
const { imageName, imageOrigin } = action.meta.arg;
|
||||||
|
// Preemptively remove the image from the gallery
|
||||||
|
imagesAdapter.removeOne(getState().images, imageName);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called on successful delete
|
||||||
|
*/
|
||||||
|
export const addImageDeletedFulfilledListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: imageDeleted.fulfilled,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
moduleLog.debug({ data: { image: action.meta.arg } }, 'Image deleted');
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called on failed delete
|
||||||
|
*/
|
||||||
|
export const addImageDeletedRejectedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: imageDeleted.rejected,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
moduleLog.debug(
|
||||||
|
{ data: { image: action.meta.arg } },
|
||||||
|
'Unable to delete image'
|
||||||
|
);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,33 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { imageMetadataReceived } from 'services/thunks/image';
|
||||||
|
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'image' });
|
||||||
|
|
||||||
|
export const addImageMetadataReceivedFulfilledListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: imageMetadataReceived.fulfilled,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const image = action.payload;
|
||||||
|
if (image.is_intermediate) {
|
||||||
|
// No further actions needed for intermediate images
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
moduleLog.debug({ data: { image } }, 'Image metadata received');
|
||||||
|
dispatch(imageUpserted(image));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
export const addImageMetadataReceivedRejectedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: imageMetadataReceived.rejected,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
moduleLog.debug(
|
||||||
|
{ data: { image: action.meta.arg } },
|
||||||
|
'Problem receiving image metadata'
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,26 @@
|
|||||||
|
import { startAppListening } from '..';
|
||||||
|
import { imageUpdated } from 'services/thunks/image';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'image' });
|
||||||
|
|
||||||
|
export const addImageUpdatedFulfilledListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: imageUpdated.fulfilled,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
moduleLog.debug(
|
||||||
|
{ oldImage: action.meta.arg, updatedImage: action.payload },
|
||||||
|
'Image updated'
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
export const addImageUpdatedRejectedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: imageUpdated.rejected,
|
||||||
|
effect: (action, { dispatch }) => {
|
||||||
|
moduleLog.debug({ oldImage: action.meta.arg }, 'Image update failed');
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -1,44 +1,46 @@
|
|||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { uploadAdded } from 'features/gallery/store/uploadsSlice';
|
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
|
||||||
import { imageUploaded } from 'services/thunks/image';
|
import { imageUploaded } from 'services/thunks/image';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||||
import { resultAdded } from 'features/gallery/store/resultsSlice';
|
|
||||||
import { isResultsImageDTO, isUploadsImageDTO } from 'services/types/guards';
|
|
||||||
|
|
||||||
export const addImageUploadedListener = () => {
|
const moduleLog = log.child({ namespace: 'image' });
|
||||||
|
|
||||||
|
export const addImageUploadedFulfilledListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
actionCreator: imageUploaded.fulfilled,
|
||||||
imageUploaded.fulfilled.match(action) &&
|
|
||||||
action.payload.response.image_type !== 'intermediates',
|
|
||||||
effect: (action, { dispatch, getState }) => {
|
effect: (action, { dispatch, getState }) => {
|
||||||
const { response: image } = action.payload;
|
const image = action.payload;
|
||||||
|
|
||||||
|
moduleLog.debug({ arg: '<Blob>', image }, 'Image uploaded');
|
||||||
|
|
||||||
|
if (action.payload.is_intermediate) {
|
||||||
|
// No further actions needed for intermediate images
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
if (isUploadsImageDTO(image)) {
|
dispatch(imageUpserted(image));
|
||||||
dispatch(uploadAdded(image));
|
|
||||||
|
|
||||||
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
|
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
|
||||||
|
},
|
||||||
if (state.gallery.shouldAutoSwitchToNewImages) {
|
});
|
||||||
dispatch(imageSelected(image));
|
};
|
||||||
}
|
|
||||||
|
export const addImageUploadedRejectedListener = () => {
|
||||||
if (action.meta.arg.activeTabName === 'img2img') {
|
startAppListening({
|
||||||
dispatch(initialImageSelected(image));
|
actionCreator: imageUploaded.rejected,
|
||||||
}
|
effect: (action, { dispatch }) => {
|
||||||
|
const { formData, ...rest } = action.meta.arg;
|
||||||
if (action.meta.arg.activeTabName === 'unifiedCanvas') {
|
const sanitizedData = { arg: { ...rest, formData: { file: '<Blob>' } } };
|
||||||
dispatch(setInitialCanvasImage(image));
|
moduleLog.error({ data: sanitizedData }, 'Image upload failed');
|
||||||
}
|
dispatch(
|
||||||
}
|
addToast({
|
||||||
|
title: 'Image Upload Failed',
|
||||||
if (isResultsImageDTO(image)) {
|
description: action.error.message,
|
||||||
dispatch(resultAdded(image));
|
status: 'error',
|
||||||
}
|
})
|
||||||
|
);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,38 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { imageUrlsReceived } from 'services/thunks/image';
|
||||||
|
import { imagesAdapter } from 'features/gallery/store/imagesSlice';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'image' });
|
||||||
|
|
||||||
|
export const addImageUrlsReceivedFulfilledListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: imageUrlsReceived.fulfilled,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const image = action.payload;
|
||||||
|
moduleLog.debug({ data: { image } }, 'Image URLs received');
|
||||||
|
|
||||||
|
const { image_name, image_url, thumbnail_url } = image;
|
||||||
|
|
||||||
|
imagesAdapter.updateOne(getState().images, {
|
||||||
|
id: image_name,
|
||||||
|
changes: {
|
||||||
|
image_url,
|
||||||
|
thumbnail_url,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
export const addImageUrlsReceivedRejectedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: imageUrlsReceived.rejected,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
moduleLog.debug(
|
||||||
|
{ data: { image: action.meta.arg } },
|
||||||
|
'Problem getting image URLs'
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -1,6 +1,4 @@
|
|||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
import { selectResultsById } from 'features/gallery/store/resultsSlice';
|
|
||||||
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
|
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
@ -9,7 +7,7 @@ import {
|
|||||||
isImageDTO,
|
isImageDTO,
|
||||||
} from 'features/parameters/store/actions';
|
} from 'features/parameters/store/actions';
|
||||||
import { makeToast } from 'app/components/Toaster';
|
import { makeToast } from 'app/components/Toaster';
|
||||||
import { ImageDTO } from 'services/api';
|
import { selectImagesById } from 'features/gallery/store/imagesSlice';
|
||||||
|
|
||||||
export const addInitialImageSelectedListener = () => {
|
export const addInitialImageSelectedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
@ -30,16 +28,8 @@ export const addInitialImageSelectedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { image_name, image_type } = action.payload;
|
const imageName = action.payload;
|
||||||
|
const image = selectImagesById(getState(), imageName);
|
||||||
let image: ImageDTO | undefined;
|
|
||||||
const state = getState();
|
|
||||||
|
|
||||||
if (image_type === 'results') {
|
|
||||||
image = selectResultsById(state, image_name);
|
|
||||||
} else if (image_type === 'uploads') {
|
|
||||||
image = selectUploadsById(state, image_name);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!image) {
|
if (!image) {
|
||||||
dispatch(
|
dispatch(
|
||||||
|
@ -1,62 +0,0 @@
|
|||||||
import { invocationComplete } from 'services/events/actions';
|
|
||||||
import { isImageOutput } from 'services/types/guards';
|
|
||||||
import {
|
|
||||||
imageMetadataReceived,
|
|
||||||
imageUrlsReceived,
|
|
||||||
} from 'services/thunks/image';
|
|
||||||
import { startAppListening } from '..';
|
|
||||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
|
||||||
|
|
||||||
const nodeDenylist = ['dataURL_image'];
|
|
||||||
|
|
||||||
export const addImageResultReceivedListener = () => {
|
|
||||||
startAppListening({
|
|
||||||
predicate: (action) => {
|
|
||||||
if (
|
|
||||||
invocationComplete.match(action) &&
|
|
||||||
isImageOutput(action.payload.data.result)
|
|
||||||
) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
},
|
|
||||||
effect: async (action, { getState, dispatch, take }) => {
|
|
||||||
if (!invocationComplete.match(action)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const { data } = action.payload;
|
|
||||||
const { result, node, graph_execution_state_id } = data;
|
|
||||||
|
|
||||||
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
|
|
||||||
const { image_name, image_type } = result.image;
|
|
||||||
|
|
||||||
dispatch(
|
|
||||||
imageUrlsReceived({ imageName: image_name, imageType: image_type })
|
|
||||||
);
|
|
||||||
|
|
||||||
dispatch(
|
|
||||||
imageMetadataReceived({
|
|
||||||
imageName: image_name,
|
|
||||||
imageType: image_type,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
|
|
||||||
// Handle canvas image
|
|
||||||
if (
|
|
||||||
graph_execution_state_id ===
|
|
||||||
getState().canvas.layerState.stagingArea.sessionId
|
|
||||||
) {
|
|
||||||
const [{ payload: image }] = await take(
|
|
||||||
(
|
|
||||||
action
|
|
||||||
): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
|
|
||||||
imageMetadataReceived.fulfilled.match(action) &&
|
|
||||||
action.payload.image_name === image_name
|
|
||||||
);
|
|
||||||
dispatch(addImageToStagingArea(image));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
|
@ -0,0 +1,33 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { serializeError } from 'serialize-error';
|
||||||
|
import { receivedPageOfImages } from 'services/thunks/image';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'gallery' });
|
||||||
|
|
||||||
|
export const addReceivedPageOfImagesFulfilledListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: receivedPageOfImages.fulfilled,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const page = action.payload;
|
||||||
|
moduleLog.debug(
|
||||||
|
{ data: { payload: action.payload } },
|
||||||
|
`Received ${page.items.length} images`
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
export const addReceivedPageOfImagesRejectedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: receivedPageOfImages.rejected,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
if (action.payload) {
|
||||||
|
moduleLog.debug(
|
||||||
|
{ data: { error: serializeError(action.payload) } },
|
||||||
|
'Problem receiving images'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,48 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { sessionCanceled } from 'services/thunks/session';
|
||||||
|
import { serializeError } from 'serialize-error';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'session' });
|
||||||
|
|
||||||
|
export const addSessionCanceledPendingListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: sessionCanceled.pending,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
//
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
export const addSessionCanceledFulfilledListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: sessionCanceled.fulfilled,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const { sessionId } = action.meta.arg;
|
||||||
|
moduleLog.debug(
|
||||||
|
{ data: { sessionId } },
|
||||||
|
`Session canceled (${sessionId})`
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
export const addSessionCanceledRejectedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: sessionCanceled.rejected,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
if (action.payload) {
|
||||||
|
const { arg, error } = action.payload;
|
||||||
|
moduleLog.error(
|
||||||
|
{
|
||||||
|
data: {
|
||||||
|
arg,
|
||||||
|
error: serializeError(error),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
`Problem canceling session`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,45 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { sessionCreated } from 'services/thunks/session';
|
||||||
|
import { serializeError } from 'serialize-error';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'session' });
|
||||||
|
|
||||||
|
export const addSessionCreatedPendingListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: sessionCreated.pending,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
//
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
export const addSessionCreatedFulfilledListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: sessionCreated.fulfilled,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const session = action.payload;
|
||||||
|
moduleLog.debug({ data: { session } }, `Session created (${session.id})`);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
export const addSessionCreatedRejectedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: sessionCreated.rejected,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
if (action.payload) {
|
||||||
|
const { arg, error } = action.payload;
|
||||||
|
moduleLog.error(
|
||||||
|
{
|
||||||
|
data: {
|
||||||
|
arg,
|
||||||
|
error: serializeError(error),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
`Problem creating session`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,48 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { sessionInvoked } from 'services/thunks/session';
|
||||||
|
import { serializeError } from 'serialize-error';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'session' });
|
||||||
|
|
||||||
|
export const addSessionInvokedPendingListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: sessionInvoked.pending,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
//
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
export const addSessionInvokedFulfilledListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: sessionInvoked.fulfilled,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const { sessionId } = action.meta.arg;
|
||||||
|
moduleLog.debug(
|
||||||
|
{ data: { sessionId } },
|
||||||
|
`Session invoked (${sessionId})`
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
export const addSessionInvokedRejectedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: sessionInvoked.rejected,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
if (action.payload) {
|
||||||
|
const { arg, error } = action.payload;
|
||||||
|
moduleLog.error(
|
||||||
|
{
|
||||||
|
data: {
|
||||||
|
arg,
|
||||||
|
error: serializeError(error),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
`Problem invoking session`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,22 @@
|
|||||||
|
import { startAppListening } from '..';
|
||||||
|
import { sessionInvoked } from 'services/thunks/session';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'session' });
|
||||||
|
|
||||||
|
export const addSessionReadyToInvokeListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: sessionReadyToInvoke,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const { sessionId } = getState().system;
|
||||||
|
if (sessionId) {
|
||||||
|
moduleLog.debug(
|
||||||
|
{ sessionId },
|
||||||
|
`Session ready to invoke (${sessionId})})`
|
||||||
|
);
|
||||||
|
dispatch(sessionInvoked({ sessionId }));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,38 @@
|
|||||||
|
import { startAppListening } from '../..';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||||
|
import { receivedPageOfImages } from 'services/thunks/image';
|
||||||
|
import { receivedModels } from 'services/thunks/model';
|
||||||
|
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
|
export const addSocketConnectedEventListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketConnected,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
const { timestamp } = action.payload;
|
||||||
|
|
||||||
|
moduleLog.debug({ timestamp }, 'Connected');
|
||||||
|
|
||||||
|
const { models, nodes, config, images } = getState();
|
||||||
|
|
||||||
|
const { disabledTabs } = config;
|
||||||
|
|
||||||
|
if (!images.ids.length) {
|
||||||
|
dispatch(receivedPageOfImages());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!models.ids.length) {
|
||||||
|
dispatch(receivedModels());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
||||||
|
dispatch(receivedOpenAPISchema());
|
||||||
|
}
|
||||||
|
|
||||||
|
// pass along the socket event as an application action
|
||||||
|
dispatch(appSocketConnected(action.payload));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,19 @@
|
|||||||
|
import { startAppListening } from '../..';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import {
|
||||||
|
socketDisconnected,
|
||||||
|
appSocketDisconnected,
|
||||||
|
} from 'services/events/actions';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
|
export const addSocketDisconnectedEventListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketDisconnected,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
moduleLog.debug(action.payload, 'Disconnected');
|
||||||
|
// pass along the socket event as an application action
|
||||||
|
dispatch(appSocketDisconnected(action.payload));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,34 @@
|
|||||||
|
import { startAppListening } from '../..';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import {
|
||||||
|
appSocketGeneratorProgress,
|
||||||
|
socketGeneratorProgress,
|
||||||
|
} from 'services/events/actions';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
|
export const addGeneratorProgressEventListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketGeneratorProgress,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
if (
|
||||||
|
getState().system.canceledSession ===
|
||||||
|
action.payload.data.graph_execution_state_id
|
||||||
|
) {
|
||||||
|
moduleLog.trace(
|
||||||
|
action.payload,
|
||||||
|
'Ignored generator progress for canceled session'
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
moduleLog.trace(
|
||||||
|
action.payload,
|
||||||
|
`Generator progress (${action.payload.data.node.type})`
|
||||||
|
);
|
||||||
|
|
||||||
|
// pass along the socket event as an application action
|
||||||
|
dispatch(appSocketGeneratorProgress(action.payload));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,22 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import {
|
||||||
|
appSocketGraphExecutionStateComplete,
|
||||||
|
socketGraphExecutionStateComplete,
|
||||||
|
} from 'services/events/actions';
|
||||||
|
import { startAppListening } from '../..';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
|
export const addGraphExecutionStateCompleteEventListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketGraphExecutionStateComplete,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
moduleLog.debug(
|
||||||
|
action.payload,
|
||||||
|
`Session invocation complete (${action.payload.data.graph_execution_state_id})`
|
||||||
|
);
|
||||||
|
// pass along the socket event as an application action
|
||||||
|
dispatch(appSocketGraphExecutionStateComplete(action.payload));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,67 @@
|
|||||||
|
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||||
|
import { startAppListening } from '../..';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import {
|
||||||
|
appSocketInvocationComplete,
|
||||||
|
socketInvocationComplete,
|
||||||
|
} from 'services/events/actions';
|
||||||
|
import { imageMetadataReceived } from 'services/thunks/image';
|
||||||
|
import { sessionCanceled } from 'services/thunks/session';
|
||||||
|
import { isImageOutput } from 'services/types/guards';
|
||||||
|
import { progressImageSet } from 'features/system/store/systemSlice';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
const nodeDenylist = ['dataURL_image'];
|
||||||
|
|
||||||
|
export const addInvocationCompleteEventListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketInvocationComplete,
|
||||||
|
effect: async (action, { dispatch, getState, take }) => {
|
||||||
|
moduleLog.debug(
|
||||||
|
action.payload,
|
||||||
|
`Invocation complete (${action.payload.data.node.type})`
|
||||||
|
);
|
||||||
|
|
||||||
|
const sessionId = action.payload.data.graph_execution_state_id;
|
||||||
|
|
||||||
|
const { cancelType, isCancelScheduled } = getState().system;
|
||||||
|
|
||||||
|
// Handle scheduled cancelation
|
||||||
|
if (cancelType === 'scheduled' && isCancelScheduled) {
|
||||||
|
dispatch(sessionCanceled({ sessionId }));
|
||||||
|
}
|
||||||
|
|
||||||
|
const { data } = action.payload;
|
||||||
|
const { result, node, graph_execution_state_id } = data;
|
||||||
|
|
||||||
|
// This complete event has an associated image output
|
||||||
|
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
|
||||||
|
const { image_name, image_origin } = result.image;
|
||||||
|
|
||||||
|
// Get its metadata
|
||||||
|
dispatch(
|
||||||
|
imageMetadataReceived({
|
||||||
|
imageName: image_name,
|
||||||
|
imageOrigin: image_origin,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
const [{ payload: imageDTO }] = await take(
|
||||||
|
imageMetadataReceived.fulfilled.match
|
||||||
|
);
|
||||||
|
|
||||||
|
// Handle canvas image
|
||||||
|
if (
|
||||||
|
graph_execution_state_id ===
|
||||||
|
getState().canvas.layerState.stagingArea.sessionId
|
||||||
|
) {
|
||||||
|
dispatch(addImageToStagingArea(imageDTO));
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(progressImageSet(null));
|
||||||
|
}
|
||||||
|
// pass along the socket event as an application action
|
||||||
|
dispatch(appSocketInvocationComplete(action.payload));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,21 @@
|
|||||||
|
import { startAppListening } from '../..';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import {
|
||||||
|
appSocketInvocationError,
|
||||||
|
socketInvocationError,
|
||||||
|
} from 'services/events/actions';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
|
export const addInvocationErrorEventListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketInvocationError,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
moduleLog.error(
|
||||||
|
action.payload,
|
||||||
|
`Invocation error (${action.payload.data.node.type})`
|
||||||
|
);
|
||||||
|
dispatch(appSocketInvocationError(action.payload));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,32 @@
|
|||||||
|
import { startAppListening } from '../..';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import {
|
||||||
|
appSocketInvocationStarted,
|
||||||
|
socketInvocationStarted,
|
||||||
|
} from 'services/events/actions';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
|
export const addInvocationStartedEventListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketInvocationStarted,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
if (
|
||||||
|
getState().system.canceledSession ===
|
||||||
|
action.payload.data.graph_execution_state_id
|
||||||
|
) {
|
||||||
|
moduleLog.trace(
|
||||||
|
action.payload,
|
||||||
|
'Ignored invocation started for canceled session'
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
moduleLog.debug(
|
||||||
|
action.payload,
|
||||||
|
`Invocation started (${action.payload.data.node.type})`
|
||||||
|
);
|
||||||
|
dispatch(appSocketInvocationStarted(action.payload));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,18 @@
|
|||||||
|
import { startAppListening } from '../..';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { appSocketSubscribed, socketSubscribed } from 'services/events/actions';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
|
export const addSocketSubscribedEventListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketSubscribed,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
moduleLog.debug(
|
||||||
|
action.payload,
|
||||||
|
`Subscribed (${action.payload.sessionId}))`
|
||||||
|
);
|
||||||
|
dispatch(appSocketSubscribed(action.payload));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,21 @@
|
|||||||
|
import { startAppListening } from '../..';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import {
|
||||||
|
appSocketUnsubscribed,
|
||||||
|
socketUnsubscribed,
|
||||||
|
} from 'services/events/actions';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
|
export const addSocketUnsubscribedEventListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: socketUnsubscribed,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
moduleLog.debug(
|
||||||
|
action.payload,
|
||||||
|
`Unsubscribed (${action.payload.sessionId})`
|
||||||
|
);
|
||||||
|
dispatch(appSocketUnsubscribed(action.payload));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -0,0 +1,54 @@
|
|||||||
|
import { stagingAreaImageSaved } from 'features/canvas/store/actions';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { imageUpdated } from 'services/thunks/image';
|
||||||
|
import { imageUpserted } from 'features/gallery/store/imagesSlice';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'canvas' });
|
||||||
|
|
||||||
|
export const addStagingAreaImageSavedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: stagingAreaImageSaved,
|
||||||
|
effect: async (action, { dispatch, getState, take }) => {
|
||||||
|
const { image_name, image_origin } = action.payload;
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
imageUpdated({
|
||||||
|
imageName: image_name,
|
||||||
|
imageOrigin: image_origin,
|
||||||
|
requestBody: {
|
||||||
|
is_intermediate: false,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
const [imageUpdatedAction] = await take(
|
||||||
|
(action) =>
|
||||||
|
(imageUpdated.fulfilled.match(action) ||
|
||||||
|
imageUpdated.rejected.match(action)) &&
|
||||||
|
action.meta.arg.imageName === image_name
|
||||||
|
);
|
||||||
|
|
||||||
|
if (imageUpdated.rejected.match(imageUpdatedAction)) {
|
||||||
|
moduleLog.error(
|
||||||
|
{ data: { arg: imageUpdatedAction.meta.arg } },
|
||||||
|
'Image saving failed'
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
addToast({
|
||||||
|
title: 'Image Saving Failed',
|
||||||
|
description: imageUpdatedAction.error.message,
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (imageUpdated.fulfilled.match(imageUpdatedAction)) {
|
||||||
|
dispatch(imageUpserted(imageUpdatedAction.payload));
|
||||||
|
dispatch(addToast({ title: 'Image Saved', status: 'success' }));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -1,9 +1,9 @@
|
|||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { sessionCreated, sessionInvoked } from 'services/thunks/session';
|
import { sessionCreated } from 'services/thunks/session';
|
||||||
import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||||
import { imageUploaded } from 'services/thunks/image';
|
import { imageUpdated, imageUploaded } from 'services/thunks/image';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
import { Graph } from 'services/api';
|
import { Graph } from 'services/api';
|
||||||
import {
|
import {
|
||||||
@ -15,12 +15,22 @@ import { getCanvasData } from 'features/canvas/util/getCanvasData';
|
|||||||
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
|
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
|
||||||
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
|
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
|
||||||
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
|
||||||
|
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'invoke' });
|
const moduleLog = log.child({ namespace: 'invoke' });
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This listener is responsible for building the canvas graph and blobs when the user invokes the canvas.
|
* This listener is responsible invoking the canvas. This involves a number of steps:
|
||||||
* It is also responsible for uploading the base and mask layers to the server.
|
*
|
||||||
|
* 1. Generate image blobs from the canvas layers
|
||||||
|
* 2. Determine the generation mode from the layers (txt2img, img2img, inpaint)
|
||||||
|
* 3. Build the canvas graph
|
||||||
|
* 4. Create the session with the graph
|
||||||
|
* 5. Upload the init image if necessary
|
||||||
|
* 6. Upload the mask image if necessary
|
||||||
|
* 7. Update the init and mask images with the session ID
|
||||||
|
* 8. Initialize the staging area if not yet initialized
|
||||||
|
* 9. Dispatch the sessionReadyToInvoke action to invoke the session
|
||||||
*/
|
*/
|
||||||
export const addUserInvokedCanvasListener = () => {
|
export const addUserInvokedCanvasListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
@ -70,63 +80,7 @@ export const addUserInvokedCanvasListener = () => {
|
|||||||
|
|
||||||
const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
|
const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
|
||||||
|
|
||||||
// Upload the base layer, to be used as init image
|
// Assemble! Note that this graph *does not have the init or mask image set yet!*
|
||||||
const baseFilename = `${uuidv4()}.png`;
|
|
||||||
|
|
||||||
dispatch(
|
|
||||||
imageUploaded({
|
|
||||||
imageType: 'intermediates',
|
|
||||||
formData: {
|
|
||||||
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
);
|
|
||||||
|
|
||||||
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
|
|
||||||
const [{ payload: basePayload }] = await take(
|
|
||||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
|
||||||
imageUploaded.fulfilled.match(action) &&
|
|
||||||
action.meta.arg.formData.file.name === baseFilename
|
|
||||||
);
|
|
||||||
|
|
||||||
const { image_name: baseName, image_type: baseType } =
|
|
||||||
basePayload.response;
|
|
||||||
|
|
||||||
baseNode.image = {
|
|
||||||
image_name: baseName,
|
|
||||||
image_type: baseType,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Upload the mask layer image
|
|
||||||
const maskFilename = `${uuidv4()}.png`;
|
|
||||||
|
|
||||||
if (baseNode.type === 'inpaint') {
|
|
||||||
dispatch(
|
|
||||||
imageUploaded({
|
|
||||||
imageType: 'intermediates',
|
|
||||||
formData: {
|
|
||||||
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
);
|
|
||||||
|
|
||||||
const [{ payload: maskPayload }] = await take(
|
|
||||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
|
||||||
imageUploaded.fulfilled.match(action) &&
|
|
||||||
action.meta.arg.formData.file.name === maskFilename
|
|
||||||
);
|
|
||||||
|
|
||||||
const { image_name: maskName, image_type: maskType } =
|
|
||||||
maskPayload.response;
|
|
||||||
|
|
||||||
baseNode.mask = {
|
|
||||||
image_name: maskName,
|
|
||||||
image_type: maskType,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assemble!
|
|
||||||
const nodes: Graph['nodes'] = {
|
const nodes: Graph['nodes'] = {
|
||||||
[rangeNode.id]: rangeNode,
|
[rangeNode.id]: rangeNode,
|
||||||
[iterateNode.id]: iterateNode,
|
[iterateNode.id]: iterateNode,
|
||||||
@ -136,15 +90,92 @@ export const addUserInvokedCanvasListener = () => {
|
|||||||
const graph = { nodes, edges };
|
const graph = { nodes, edges };
|
||||||
|
|
||||||
dispatch(canvasGraphBuilt(graph));
|
dispatch(canvasGraphBuilt(graph));
|
||||||
moduleLog({ data: graph }, 'Canvas graph built');
|
|
||||||
|
|
||||||
// Actually create the session
|
moduleLog.debug({ data: graph }, 'Canvas graph built');
|
||||||
|
|
||||||
|
// If we are generating img2img or inpaint, we need to upload the init images
|
||||||
|
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
|
||||||
|
const baseFilename = `${uuidv4()}.png`;
|
||||||
|
dispatch(
|
||||||
|
imageUploaded({
|
||||||
|
formData: {
|
||||||
|
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
|
||||||
|
},
|
||||||
|
imageCategory: 'general',
|
||||||
|
isIntermediate: true,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
// Wait for the image to be uploaded
|
||||||
|
const [{ payload: baseImageDTO }] = await take(
|
||||||
|
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||||
|
imageUploaded.fulfilled.match(action) &&
|
||||||
|
action.meta.arg.formData.file.name === baseFilename
|
||||||
|
);
|
||||||
|
|
||||||
|
// Update the base node with the image name and type
|
||||||
|
baseNode.image = {
|
||||||
|
image_name: baseImageDTO.image_name,
|
||||||
|
image_origin: baseImageDTO.image_origin,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// For inpaint, we also need to upload the mask layer
|
||||||
|
if (baseNode.type === 'inpaint') {
|
||||||
|
const maskFilename = `${uuidv4()}.png`;
|
||||||
|
dispatch(
|
||||||
|
imageUploaded({
|
||||||
|
formData: {
|
||||||
|
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
|
||||||
|
},
|
||||||
|
imageCategory: 'mask',
|
||||||
|
isIntermediate: true,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
// Wait for the mask to be uploaded
|
||||||
|
const [{ payload: maskImageDTO }] = await take(
|
||||||
|
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||||
|
imageUploaded.fulfilled.match(action) &&
|
||||||
|
action.meta.arg.formData.file.name === maskFilename
|
||||||
|
);
|
||||||
|
|
||||||
|
// Update the base node with the image name and type
|
||||||
|
baseNode.mask = {
|
||||||
|
image_name: maskImageDTO.image_name,
|
||||||
|
image_origin: maskImageDTO.image_origin,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the session and wait for response
|
||||||
dispatch(sessionCreated({ graph }));
|
dispatch(sessionCreated({ graph }));
|
||||||
|
const [sessionCreatedAction] = await take(sessionCreated.fulfilled.match);
|
||||||
|
const sessionId = sessionCreatedAction.payload.id;
|
||||||
|
|
||||||
// Wait for the session to be invoked (this is just the HTTP request to start processing)
|
// Associate the init image with the session, now that we have the session ID
|
||||||
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
|
if (
|
||||||
|
(baseNode.type === 'img2img' || baseNode.type === 'inpaint') &&
|
||||||
|
baseNode.image
|
||||||
|
) {
|
||||||
|
dispatch(
|
||||||
|
imageUpdated({
|
||||||
|
imageName: baseNode.image.image_name,
|
||||||
|
imageOrigin: baseNode.image.image_origin,
|
||||||
|
requestBody: { session_id: sessionId },
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
const { sessionId } = meta.arg;
|
// Associate the mask image with the session, now that we have the session ID
|
||||||
|
if (baseNode.type === 'inpaint' && baseNode.mask) {
|
||||||
|
dispatch(
|
||||||
|
imageUpdated({
|
||||||
|
imageName: baseNode.mask.image_name,
|
||||||
|
imageOrigin: baseNode.mask.image_origin,
|
||||||
|
requestBody: { session_id: sessionId },
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
||||||
dispatch(
|
dispatch(
|
||||||
@ -158,7 +189,11 @@ export const addUserInvokedCanvasListener = () => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Flag the session with the canvas session ID
|
||||||
dispatch(canvasSessionIdChanged(sessionId));
|
dispatch(canvasSessionIdChanged(sessionId));
|
||||||
|
|
||||||
|
// We are ready to invoke the session!
|
||||||
|
dispatch(sessionReadyToInvoke());
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session';
|
|||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
|
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||||
import { userInvoked } from 'app/store/actions';
|
import { userInvoked } from 'app/store/actions';
|
||||||
|
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'invoke' });
|
const moduleLog = log.child({ namespace: 'invoke' });
|
||||||
|
|
||||||
@ -11,14 +12,18 @@ export const addUserInvokedImageToImageListener = () => {
|
|||||||
startAppListening({
|
startAppListening({
|
||||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||||
userInvoked.match(action) && action.payload === 'img2img',
|
userInvoked.match(action) && action.payload === 'img2img',
|
||||||
effect: (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch, take }) => {
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
const graph = buildImageToImageGraph(state);
|
const graph = buildImageToImageGraph(state);
|
||||||
dispatch(imageToImageGraphBuilt(graph));
|
dispatch(imageToImageGraphBuilt(graph));
|
||||||
moduleLog({ data: graph }, 'Image to Image graph built');
|
moduleLog.debug({ data: graph }, 'Image to Image graph built');
|
||||||
|
|
||||||
dispatch(sessionCreated({ graph }));
|
dispatch(sessionCreated({ graph }));
|
||||||
|
|
||||||
|
await take(sessionCreated.fulfilled.match);
|
||||||
|
|
||||||
|
dispatch(sessionReadyToInvoke());
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -4,6 +4,7 @@ import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGra
|
|||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { nodesGraphBuilt } from 'features/nodes/store/actions';
|
import { nodesGraphBuilt } from 'features/nodes/store/actions';
|
||||||
import { userInvoked } from 'app/store/actions';
|
import { userInvoked } from 'app/store/actions';
|
||||||
|
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'invoke' });
|
const moduleLog = log.child({ namespace: 'invoke' });
|
||||||
|
|
||||||
@ -11,14 +12,18 @@ export const addUserInvokedNodesListener = () => {
|
|||||||
startAppListening({
|
startAppListening({
|
||||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||||
userInvoked.match(action) && action.payload === 'nodes',
|
userInvoked.match(action) && action.payload === 'nodes',
|
||||||
effect: (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch, take }) => {
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
const graph = buildNodesGraph(state);
|
const graph = buildNodesGraph(state);
|
||||||
dispatch(nodesGraphBuilt(graph));
|
dispatch(nodesGraphBuilt(graph));
|
||||||
moduleLog({ data: graph }, 'Nodes graph built');
|
moduleLog.debug({ data: graph }, 'Nodes graph built');
|
||||||
|
|
||||||
dispatch(sessionCreated({ graph }));
|
dispatch(sessionCreated({ graph }));
|
||||||
|
|
||||||
|
await take(sessionCreated.fulfilled.match);
|
||||||
|
|
||||||
|
dispatch(sessionReadyToInvoke());
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session';
|
|||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
|
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||||
import { userInvoked } from 'app/store/actions';
|
import { userInvoked } from 'app/store/actions';
|
||||||
|
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'invoke' });
|
const moduleLog = log.child({ namespace: 'invoke' });
|
||||||
|
|
||||||
@ -11,14 +12,20 @@ export const addUserInvokedTextToImageListener = () => {
|
|||||||
startAppListening({
|
startAppListening({
|
||||||
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
predicate: (action): action is ReturnType<typeof userInvoked> =>
|
||||||
userInvoked.match(action) && action.payload === 'txt2img',
|
userInvoked.match(action) && action.payload === 'txt2img',
|
||||||
effect: (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch, take }) => {
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
const graph = buildTextToImageGraph(state);
|
const graph = buildTextToImageGraph(state);
|
||||||
|
|
||||||
dispatch(textToImageGraphBuilt(graph));
|
dispatch(textToImageGraphBuilt(graph));
|
||||||
moduleLog({ data: graph }, 'Text to Image graph built');
|
|
||||||
|
moduleLog.debug({ data: graph }, 'Text to Image graph built');
|
||||||
|
|
||||||
dispatch(sessionCreated({ graph }));
|
dispatch(sessionCreated({ graph }));
|
||||||
|
|
||||||
|
await take(sessionCreated.fulfilled.match);
|
||||||
|
|
||||||
|
dispatch(sessionReadyToInvoke());
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -10,12 +10,12 @@ import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
|||||||
|
|
||||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||||
import resultsReducer from 'features/gallery/store/resultsSlice';
|
import imagesReducer from 'features/gallery/store/imagesSlice';
|
||||||
import uploadsReducer from 'features/gallery/store/uploadsSlice';
|
|
||||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||||
import generationReducer from 'features/parameters/store/generationSlice';
|
import generationReducer from 'features/parameters/store/generationSlice';
|
||||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||||
import systemReducer from 'features/system/store/systemSlice';
|
import systemReducer from 'features/system/store/systemSlice';
|
||||||
|
// import sessionReducer from 'features/system/store/sessionSlice';
|
||||||
import configReducer from 'features/system/store/configSlice';
|
import configReducer from 'features/system/store/configSlice';
|
||||||
import uiReducer from 'features/ui/store/uiSlice';
|
import uiReducer from 'features/ui/store/uiSlice';
|
||||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||||
@ -40,12 +40,12 @@ const allReducers = {
|
|||||||
models: modelsReducer,
|
models: modelsReducer,
|
||||||
nodes: nodesReducer,
|
nodes: nodesReducer,
|
||||||
postprocessing: postprocessingReducer,
|
postprocessing: postprocessingReducer,
|
||||||
results: resultsReducer,
|
|
||||||
system: systemReducer,
|
system: systemReducer,
|
||||||
config: configReducer,
|
config: configReducer,
|
||||||
ui: uiReducer,
|
ui: uiReducer,
|
||||||
uploads: uploadsReducer,
|
|
||||||
hotkeys: hotkeysReducer,
|
hotkeys: hotkeysReducer,
|
||||||
|
images: imagesReducer,
|
||||||
|
// session: sessionReducer,
|
||||||
};
|
};
|
||||||
|
|
||||||
const rootReducer = combineReducers(allReducers);
|
const rootReducer = combineReducers(allReducers);
|
||||||
@ -63,8 +63,6 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
|||||||
'system',
|
'system',
|
||||||
'ui',
|
'ui',
|
||||||
// 'hotkeys',
|
// 'hotkeys',
|
||||||
// 'results',
|
|
||||||
// 'uploads',
|
|
||||||
// 'config',
|
// 'config',
|
||||||
];
|
];
|
||||||
|
|
||||||
|
@ -1,316 +1,82 @@
|
|||||||
/**
|
|
||||||
* Types for images, the things they are made of, and the things
|
|
||||||
* they make up.
|
|
||||||
*
|
|
||||||
* Generated images are txt2img and img2img images. They may have
|
|
||||||
* had additional postprocessing done on them when they were first
|
|
||||||
* generated.
|
|
||||||
*
|
|
||||||
* Postprocessed images are images which were not generated here
|
|
||||||
* but only postprocessed by the app. They only get postprocessing
|
|
||||||
* metadata and have a different image type, e.g. 'esrgan' or
|
|
||||||
* 'gfpgan'.
|
|
||||||
*/
|
|
||||||
|
|
||||||
import { SelectedImage } from 'features/parameters/store/actions';
|
|
||||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||||
import { IRect } from 'konva/lib/types';
|
|
||||||
import { ImageResponseMetadata, ImageType } from 'services/api';
|
|
||||||
import { O } from 'ts-toolbelt';
|
import { O } from 'ts-toolbelt';
|
||||||
|
|
||||||
/**
|
// These are old types from the model management UI
|
||||||
* TODO:
|
|
||||||
* Once an image has been generated, if it is postprocessed again,
|
|
||||||
* additional postprocessing steps are added to its postprocessing
|
|
||||||
* array.
|
|
||||||
*
|
|
||||||
* TODO: Better documentation of types.
|
|
||||||
*/
|
|
||||||
|
|
||||||
export type PromptItem = {
|
// export type ModelStatus = 'active' | 'cached' | 'not loaded';
|
||||||
prompt: string;
|
|
||||||
weight: number;
|
|
||||||
};
|
|
||||||
|
|
||||||
// TECHDEBT: We need to retain compatibility with plain prompt strings and the structure Prompt type
|
// export type Model = {
|
||||||
export type Prompt = Array<PromptItem> | string;
|
// status: ModelStatus;
|
||||||
|
// description: string;
|
||||||
export type SeedWeightPair = {
|
// weights: string;
|
||||||
seed: number;
|
// config?: string;
|
||||||
weight: number;
|
// vae?: string;
|
||||||
};
|
// width?: number;
|
||||||
|
// height?: number;
|
||||||
export type SeedWeights = Array<SeedWeightPair>;
|
// default?: boolean;
|
||||||
|
// format?: string;
|
||||||
// All generated images contain these metadata.
|
|
||||||
export type CommonGeneratedImageMetadata = {
|
|
||||||
postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>;
|
|
||||||
sampler:
|
|
||||||
| 'ddim'
|
|
||||||
| 'ddpm'
|
|
||||||
| 'deis'
|
|
||||||
| 'lms'
|
|
||||||
| 'pndm'
|
|
||||||
| 'heun'
|
|
||||||
| 'heun_k'
|
|
||||||
| 'euler'
|
|
||||||
| 'euler_k'
|
|
||||||
| 'euler_a'
|
|
||||||
| 'kdpm_2'
|
|
||||||
| 'kdpm_2_a'
|
|
||||||
| 'dpmpp_2s'
|
|
||||||
| 'dpmpp_2m'
|
|
||||||
| 'dpmpp_2m_k'
|
|
||||||
| 'unipc';
|
|
||||||
prompt: Prompt;
|
|
||||||
seed: number;
|
|
||||||
variations: SeedWeights;
|
|
||||||
steps: number;
|
|
||||||
cfg_scale: number;
|
|
||||||
width: number;
|
|
||||||
height: number;
|
|
||||||
seamless: boolean;
|
|
||||||
hires_fix: boolean;
|
|
||||||
extra: null | Record<string, never>; // Pending development of RFC #266
|
|
||||||
};
|
|
||||||
|
|
||||||
// txt2img and img2img images have some unique attributes.
|
|
||||||
export type Txt2ImgMetadata = CommonGeneratedImageMetadata & {
|
|
||||||
type: 'txt2img';
|
|
||||||
};
|
|
||||||
|
|
||||||
export type Img2ImgMetadata = CommonGeneratedImageMetadata & {
|
|
||||||
type: 'img2img';
|
|
||||||
orig_hash: string;
|
|
||||||
strength: number;
|
|
||||||
fit: boolean;
|
|
||||||
init_image_path: string;
|
|
||||||
mask_image_path?: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Superset of generated image metadata types.
|
|
||||||
export type GeneratedImageMetadata = Txt2ImgMetadata | Img2ImgMetadata;
|
|
||||||
|
|
||||||
// All post processed images contain these metadata.
|
|
||||||
export type CommonPostProcessedImageMetadata = {
|
|
||||||
orig_path: string;
|
|
||||||
orig_hash: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
// esrgan and gfpgan images have some unique attributes.
|
|
||||||
export type ESRGANMetadata = CommonPostProcessedImageMetadata & {
|
|
||||||
type: 'esrgan';
|
|
||||||
scale: 2 | 4;
|
|
||||||
strength: number;
|
|
||||||
denoise_str: number;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type FacetoolMetadata = CommonPostProcessedImageMetadata & {
|
|
||||||
type: 'gfpgan' | 'codeformer';
|
|
||||||
strength: number;
|
|
||||||
fidelity?: number;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Superset of all postprocessed image metadata types..
|
|
||||||
export type PostProcessedImageMetadata = ESRGANMetadata | FacetoolMetadata;
|
|
||||||
|
|
||||||
// Metadata includes the system config and image metadata.
|
|
||||||
// export type Metadata = SystemGenerationMetadata & {
|
|
||||||
// image: GeneratedImageMetadata | PostProcessedImageMetadata;
|
|
||||||
// };
|
// };
|
||||||
|
|
||||||
/**
|
// export type DiffusersModel = {
|
||||||
* ResultImage
|
// status: ModelStatus;
|
||||||
*/
|
// description: string;
|
||||||
// export ty`pe Image = {
|
// repo_id?: string;
|
||||||
|
// path?: string;
|
||||||
|
// vae?: {
|
||||||
|
// repo_id?: string;
|
||||||
|
// path?: string;
|
||||||
|
// };
|
||||||
|
// format?: string;
|
||||||
|
// default?: boolean;
|
||||||
|
// };
|
||||||
|
|
||||||
|
// export type ModelList = Record<string, Model & DiffusersModel>;
|
||||||
|
|
||||||
|
// export type FoundModel = {
|
||||||
// name: string;
|
// name: string;
|
||||||
// type: ImageType;
|
// location: string;
|
||||||
// url: string;
|
|
||||||
// thumbnail: string;
|
|
||||||
// metadata: ImageResponseMetadata;
|
|
||||||
// };
|
// };
|
||||||
|
|
||||||
// export const isInvokeAIImage = (obj: Image | SelectedImage): obj is Image => {
|
// export type InvokeModelConfigProps = {
|
||||||
// if ('url' in obj && 'thumbnail' in obj) {
|
// name: string | undefined;
|
||||||
// return true;
|
// description: string | undefined;
|
||||||
// }
|
// config: string | undefined;
|
||||||
|
// weights: string | undefined;
|
||||||
// return false;
|
// vae: string | undefined;
|
||||||
|
// width: number | undefined;
|
||||||
|
// height: number | undefined;
|
||||||
|
// default: boolean | undefined;
|
||||||
|
// format: string | undefined;
|
||||||
// };
|
// };
|
||||||
|
|
||||||
/**
|
// export type InvokeDiffusersModelConfigProps = {
|
||||||
* Types related to the system status.
|
// name: string | undefined;
|
||||||
*/
|
// description: string | undefined;
|
||||||
|
// repo_id: string | undefined;
|
||||||
// // This represents the processing status of the backend.
|
// path: string | undefined;
|
||||||
// export type SystemStatus = {
|
// default: boolean | undefined;
|
||||||
// isProcessing: boolean;
|
// format: string | undefined;
|
||||||
// currentStep: number;
|
// vae: {
|
||||||
// totalSteps: number;
|
// repo_id: string | undefined;
|
||||||
// currentIteration: number;
|
// path: string | undefined;
|
||||||
// totalIterations: number;
|
// };
|
||||||
// currentStatus: string;
|
|
||||||
// currentStatusHasSteps: boolean;
|
|
||||||
// hasError: boolean;
|
|
||||||
// };
|
// };
|
||||||
|
|
||||||
// export type SystemGenerationMetadata = {
|
// export type InvokeModelConversionProps = {
|
||||||
// model: string;
|
// model_name: string;
|
||||||
// model_weights?: string;
|
// save_location: string;
|
||||||
// model_id?: string;
|
// custom_location: string | null;
|
||||||
// model_hash: string;
|
|
||||||
// app_id: string;
|
|
||||||
// app_version: string;
|
|
||||||
// };
|
// };
|
||||||
|
|
||||||
// export type SystemConfig = SystemGenerationMetadata & {
|
// export type InvokeModelMergingProps = {
|
||||||
// model_list: ModelList;
|
// models_to_merge: string[];
|
||||||
// infill_methods: string[];
|
// alpha: number;
|
||||||
|
// interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
|
||||||
|
// force: boolean;
|
||||||
|
// merged_model_name: string;
|
||||||
|
// model_merge_save_path: string | null;
|
||||||
// };
|
// };
|
||||||
|
|
||||||
export type ModelStatus = 'active' | 'cached' | 'not loaded';
|
|
||||||
|
|
||||||
export type Model = {
|
|
||||||
status: ModelStatus;
|
|
||||||
description: string;
|
|
||||||
weights: string;
|
|
||||||
config?: string;
|
|
||||||
vae?: string;
|
|
||||||
width?: number;
|
|
||||||
height?: number;
|
|
||||||
default?: boolean;
|
|
||||||
format?: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type DiffusersModel = {
|
|
||||||
status: ModelStatus;
|
|
||||||
description: string;
|
|
||||||
repo_id?: string;
|
|
||||||
path?: string;
|
|
||||||
vae?: {
|
|
||||||
repo_id?: string;
|
|
||||||
path?: string;
|
|
||||||
};
|
|
||||||
format?: string;
|
|
||||||
default?: boolean;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type ModelList = Record<string, Model & DiffusersModel>;
|
|
||||||
|
|
||||||
export type FoundModel = {
|
|
||||||
name: string;
|
|
||||||
location: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type InvokeModelConfigProps = {
|
|
||||||
name: string | undefined;
|
|
||||||
description: string | undefined;
|
|
||||||
config: string | undefined;
|
|
||||||
weights: string | undefined;
|
|
||||||
vae: string | undefined;
|
|
||||||
width: number | undefined;
|
|
||||||
height: number | undefined;
|
|
||||||
default: boolean | undefined;
|
|
||||||
format: string | undefined;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type InvokeDiffusersModelConfigProps = {
|
|
||||||
name: string | undefined;
|
|
||||||
description: string | undefined;
|
|
||||||
repo_id: string | undefined;
|
|
||||||
path: string | undefined;
|
|
||||||
default: boolean | undefined;
|
|
||||||
format: string | undefined;
|
|
||||||
vae: {
|
|
||||||
repo_id: string | undefined;
|
|
||||||
path: string | undefined;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
export type InvokeModelConversionProps = {
|
|
||||||
model_name: string;
|
|
||||||
save_location: string;
|
|
||||||
custom_location: string | null;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type InvokeModelMergingProps = {
|
|
||||||
models_to_merge: string[];
|
|
||||||
alpha: number;
|
|
||||||
interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
|
|
||||||
force: boolean;
|
|
||||||
merged_model_name: string;
|
|
||||||
model_merge_save_path: string | null;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* These types type data received from the server via socketio.
|
|
||||||
*/
|
|
||||||
|
|
||||||
export type ModelChangeResponse = {
|
|
||||||
model_name: string;
|
|
||||||
model_list: ModelList;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type ModelConvertedResponse = {
|
|
||||||
converted_model_name: string;
|
|
||||||
model_list: ModelList;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type ModelsMergedResponse = {
|
|
||||||
merged_models: string[];
|
|
||||||
merged_model_name: string;
|
|
||||||
model_list: ModelList;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type ModelAddedResponse = {
|
|
||||||
new_model_name: string;
|
|
||||||
model_list: ModelList;
|
|
||||||
update: boolean;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type ModelDeletedResponse = {
|
|
||||||
deleted_model_name: string;
|
|
||||||
model_list: ModelList;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type FoundModelResponse = {
|
|
||||||
search_folder: string;
|
|
||||||
found_models: FoundModel[];
|
|
||||||
};
|
|
||||||
|
|
||||||
// export type SystemStatusResponse = SystemStatus;
|
|
||||||
|
|
||||||
// export type SystemConfigResponse = SystemConfig;
|
|
||||||
|
|
||||||
export type ImageResultResponse = Omit<Image, 'uuid'> & {
|
|
||||||
boundingBox?: IRect;
|
|
||||||
generationMode: InvokeTabName;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type ImageUploadResponse = {
|
|
||||||
// image: Omit<Image, 'uuid' | 'metadata' | 'category'>;
|
|
||||||
url: string;
|
|
||||||
mtime: number;
|
|
||||||
width: number;
|
|
||||||
height: number;
|
|
||||||
thumbnail: string;
|
|
||||||
// bbox: [number, number, number, number];
|
|
||||||
};
|
|
||||||
|
|
||||||
export type ErrorResponse = {
|
|
||||||
message: string;
|
|
||||||
additionalData?: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type ImageUrlResponse = {
|
|
||||||
url: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type UploadOutpaintingMergeImagePayload = {
|
|
||||||
dataURL: string;
|
|
||||||
name: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A disable-able application feature
|
* A disable-able application feature
|
||||||
*/
|
*/
|
||||||
@ -322,7 +88,8 @@ export type AppFeature =
|
|||||||
| 'githubLink'
|
| 'githubLink'
|
||||||
| 'discordLink'
|
| 'discordLink'
|
||||||
| 'bugLink'
|
| 'bugLink'
|
||||||
| 'localization';
|
| 'localization'
|
||||||
|
| 'consoleLogging';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A disable-able Stable Diffusion feature
|
* A disable-able Stable Diffusion feature
|
||||||
@ -351,6 +118,7 @@ export type AppConfig = {
|
|||||||
disabledSDFeatures: SDFeature[];
|
disabledSDFeatures: SDFeature[];
|
||||||
canRestoreDeletedImagesFromBin: boolean;
|
canRestoreDeletedImagesFromBin: boolean;
|
||||||
sd: {
|
sd: {
|
||||||
|
defaultModel?: string;
|
||||||
iterations: {
|
iterations: {
|
||||||
initial: number;
|
initial: number;
|
||||||
min: number;
|
min: number;
|
||||||
|
@ -21,9 +21,12 @@ import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
|||||||
|
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
|
|
||||||
|
export type ItemTooltips = { [key: string]: string };
|
||||||
|
|
||||||
type IAICustomSelectProps = {
|
type IAICustomSelectProps = {
|
||||||
label?: string;
|
label?: string;
|
||||||
items: string[];
|
items: string[];
|
||||||
|
itemTooltips?: ItemTooltips;
|
||||||
selectedItem: string;
|
selectedItem: string;
|
||||||
setSelectedItem: (v: string | null | undefined) => void;
|
setSelectedItem: (v: string | null | undefined) => void;
|
||||||
withCheckIcon?: boolean;
|
withCheckIcon?: boolean;
|
||||||
@ -37,6 +40,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
const {
|
const {
|
||||||
label,
|
label,
|
||||||
items,
|
items,
|
||||||
|
itemTooltips,
|
||||||
setSelectedItem,
|
setSelectedItem,
|
||||||
selectedItem,
|
selectedItem,
|
||||||
withCheckIcon,
|
withCheckIcon,
|
||||||
@ -118,6 +122,13 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
>
|
>
|
||||||
<OverlayScrollbarsComponent>
|
<OverlayScrollbarsComponent>
|
||||||
{items.map((item, index) => (
|
{items.map((item, index) => (
|
||||||
|
<Tooltip
|
||||||
|
isDisabled={!itemTooltips}
|
||||||
|
key={`${item}${index}`}
|
||||||
|
label={itemTooltips?.[item]}
|
||||||
|
hasArrow
|
||||||
|
placement="right"
|
||||||
|
>
|
||||||
<ListItem
|
<ListItem
|
||||||
sx={{
|
sx={{
|
||||||
bg: highlightedIndex === index ? 'base.700' : undefined,
|
bg: highlightedIndex === index ? 'base.700' : undefined,
|
||||||
@ -160,6 +171,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
|||||||
</Text>
|
</Text>
|
||||||
)}
|
)}
|
||||||
</ListItem>
|
</ListItem>
|
||||||
|
</Tooltip>
|
||||||
))}
|
))}
|
||||||
</OverlayScrollbarsComponent>
|
</OverlayScrollbarsComponent>
|
||||||
</List>
|
</List>
|
||||||
|
@ -4,7 +4,6 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
|||||||
type ImageUploadOverlayProps = {
|
type ImageUploadOverlayProps = {
|
||||||
isDragAccept: boolean;
|
isDragAccept: boolean;
|
||||||
isDragReject: boolean;
|
isDragReject: boolean;
|
||||||
overlaySecondaryText: string;
|
|
||||||
setIsHandlingUpload: (isHandlingUpload: boolean) => void;
|
setIsHandlingUpload: (isHandlingUpload: boolean) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -12,7 +11,6 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
|
|||||||
const {
|
const {
|
||||||
isDragAccept,
|
isDragAccept,
|
||||||
isDragReject: _isDragAccept,
|
isDragReject: _isDragAccept,
|
||||||
overlaySecondaryText,
|
|
||||||
setIsHandlingUpload,
|
setIsHandlingUpload,
|
||||||
} = props;
|
} = props;
|
||||||
|
|
||||||
@ -48,7 +46,7 @@ const ImageUploadOverlay = (props: ImageUploadOverlayProps) => {
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{isDragAccept ? (
|
{isDragAccept ? (
|
||||||
<Heading size="lg">Upload Image{overlaySecondaryText}</Heading>
|
<Heading size="lg">Drop to Upload</Heading>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<Heading size="lg">Invalid Upload</Heading>
|
<Heading size="lg">Invalid Upload</Heading>
|
||||||
|
@ -68,13 +68,13 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
|||||||
async (file: File) => {
|
async (file: File) => {
|
||||||
dispatch(
|
dispatch(
|
||||||
imageUploaded({
|
imageUploaded({
|
||||||
imageType: 'uploads',
|
|
||||||
formData: { file },
|
formData: { file },
|
||||||
activeTabName,
|
imageCategory: 'user',
|
||||||
|
isIntermediate: false,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
[dispatch, activeTabName]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
const onDrop = useCallback(
|
const onDrop = useCallback(
|
||||||
@ -145,14 +145,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
|||||||
};
|
};
|
||||||
}, [inputRef, open, setOpenUploaderFunction]);
|
}, [inputRef, open, setOpenUploaderFunction]);
|
||||||
|
|
||||||
const overlaySecondaryText = useMemo(() => {
|
|
||||||
if (['img2img', 'unifiedCanvas'].includes(activeTabName)) {
|
|
||||||
return ` to ${String(t(`common.${activeTabName}` as ResourceKey))}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
return '';
|
|
||||||
}, [t, activeTabName]);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box
|
<Box
|
||||||
{...getRootProps({ style: {} })}
|
{...getRootProps({ style: {} })}
|
||||||
@ -167,7 +159,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
|||||||
<ImageUploadOverlay
|
<ImageUploadOverlay
|
||||||
isDragAccept={isDragAccept}
|
isDragAccept={isDragAccept}
|
||||||
isDragReject={isDragReject}
|
isDragReject={isDragReject}
|
||||||
overlaySecondaryText={overlaySecondaryText}
|
|
||||||
setIsHandlingUpload={setIsHandlingUpload}
|
setIsHandlingUpload={setIsHandlingUpload}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
@ -1,119 +0,0 @@
|
|||||||
/**
|
|
||||||
* PARTIAL ZOD IMPLEMENTATION
|
|
||||||
*
|
|
||||||
* doesn't work well bc like most validators, zod is not built to skip invalid values.
|
|
||||||
* it mostly works but just seems clearer and simpler to manually parse for now.
|
|
||||||
*
|
|
||||||
* in the future it would be really nice if we could use zod for some things:
|
|
||||||
* - zodios (axios + zod): https://github.com/ecyrbe/zodios
|
|
||||||
* - openapi to zodios: https://github.com/astahmer/openapi-zod-client
|
|
||||||
*/
|
|
||||||
|
|
||||||
// import { z } from 'zod';
|
|
||||||
|
|
||||||
// const zMetadataStringField = z.string();
|
|
||||||
// export type MetadataStringField = z.infer<typeof zMetadataStringField>;
|
|
||||||
|
|
||||||
// const zMetadataIntegerField = z.number().int();
|
|
||||||
// export type MetadataIntegerField = z.infer<typeof zMetadataIntegerField>;
|
|
||||||
|
|
||||||
// const zMetadataFloatField = z.number();
|
|
||||||
// export type MetadataFloatField = z.infer<typeof zMetadataFloatField>;
|
|
||||||
|
|
||||||
// const zMetadataBooleanField = z.boolean();
|
|
||||||
// export type MetadataBooleanField = z.infer<typeof zMetadataBooleanField>;
|
|
||||||
|
|
||||||
// const zMetadataImageField = z.object({
|
|
||||||
// image_type: z.union([
|
|
||||||
// z.literal('results'),
|
|
||||||
// z.literal('uploads'),
|
|
||||||
// z.literal('intermediates'),
|
|
||||||
// ]),
|
|
||||||
// image_name: z.string().min(1),
|
|
||||||
// });
|
|
||||||
// export type MetadataImageField = z.infer<typeof zMetadataImageField>;
|
|
||||||
|
|
||||||
// const zMetadataLatentsField = z.object({
|
|
||||||
// latents_name: z.string().min(1),
|
|
||||||
// });
|
|
||||||
// export type MetadataLatentsField = z.infer<typeof zMetadataLatentsField>;
|
|
||||||
|
|
||||||
// /**
|
|
||||||
// * zod Schema for any node field. Use a `transform()` to manually parse, skipping invalid values.
|
|
||||||
// */
|
|
||||||
// const zAnyMetadataField = z.any().transform((val, ctx) => {
|
|
||||||
// // Grab the field name from the path
|
|
||||||
// const fieldName = String(ctx.path[ctx.path.length - 1]);
|
|
||||||
|
|
||||||
// // `id` and `type` must be strings if they exist
|
|
||||||
// if (['id', 'type'].includes(fieldName)) {
|
|
||||||
// const reservedStringPropertyResult = zMetadataStringField.safeParse(val);
|
|
||||||
// if (reservedStringPropertyResult.success) {
|
|
||||||
// return reservedStringPropertyResult.data;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Parse the rest of the fields, only returning the data if the parsing is successful
|
|
||||||
|
|
||||||
// const stringFieldResult = zMetadataStringField.safeParse(val);
|
|
||||||
// if (stringFieldResult.success) {
|
|
||||||
// return stringFieldResult.data;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// const integerFieldResult = zMetadataIntegerField.safeParse(val);
|
|
||||||
// if (integerFieldResult.success) {
|
|
||||||
// return integerFieldResult.data;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// const floatFieldResult = zMetadataFloatField.safeParse(val);
|
|
||||||
// if (floatFieldResult.success) {
|
|
||||||
// return floatFieldResult.data;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// const booleanFieldResult = zMetadataBooleanField.safeParse(val);
|
|
||||||
// if (booleanFieldResult.success) {
|
|
||||||
// return booleanFieldResult.data;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// const imageFieldResult = zMetadataImageField.safeParse(val);
|
|
||||||
// if (imageFieldResult.success) {
|
|
||||||
// return imageFieldResult.data;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// const latentsFieldResult = zMetadataImageField.safeParse(val);
|
|
||||||
// if (latentsFieldResult.success) {
|
|
||||||
// return latentsFieldResult.data;
|
|
||||||
// }
|
|
||||||
// });
|
|
||||||
|
|
||||||
// /**
|
|
||||||
// * The node metadata schema.
|
|
||||||
// */
|
|
||||||
// const zNodeMetadata = z.object({
|
|
||||||
// session_id: z.string().min(1).optional(),
|
|
||||||
// node: z.record(z.string().min(1), zAnyMetadataField).optional(),
|
|
||||||
// });
|
|
||||||
|
|
||||||
// export type NodeMetadata = z.infer<typeof zNodeMetadata>;
|
|
||||||
|
|
||||||
// const zMetadata = z.object({
|
|
||||||
// invokeai: zNodeMetadata.optional(),
|
|
||||||
// 'sd-metadata': z.record(z.string().min(1), z.any()).optional(),
|
|
||||||
// });
|
|
||||||
// export type Metadata = z.infer<typeof zMetadata>;
|
|
||||||
|
|
||||||
// export const parseMetadata = (
|
|
||||||
// metadata: Record<string, any>
|
|
||||||
// ): Metadata | undefined => {
|
|
||||||
// const result = zMetadata.safeParse(metadata);
|
|
||||||
// if (!result.success) {
|
|
||||||
// console.log(result.error.issues);
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// return result.data;
|
|
||||||
// };
|
|
||||||
|
|
||||||
export default {};
|
|
@ -1,334 +0,0 @@
|
|||||||
import { forEach, size } from 'lodash-es';
|
|
||||||
import {
|
|
||||||
ImageField,
|
|
||||||
LatentsField,
|
|
||||||
ConditioningField,
|
|
||||||
UNetField,
|
|
||||||
ClipField,
|
|
||||||
VaeField,
|
|
||||||
} from 'services/api';
|
|
||||||
|
|
||||||
const OBJECT_TYPESTRING = '[object Object]';
|
|
||||||
const STRING_TYPESTRING = '[object String]';
|
|
||||||
const NUMBER_TYPESTRING = '[object Number]';
|
|
||||||
const BOOLEAN_TYPESTRING = '[object Boolean]';
|
|
||||||
const ARRAY_TYPESTRING = '[object Array]';
|
|
||||||
|
|
||||||
const isObject = (obj: unknown): obj is Record<string | number, any> =>
|
|
||||||
Object.prototype.toString.call(obj) === OBJECT_TYPESTRING;
|
|
||||||
|
|
||||||
const isString = (obj: unknown): obj is string =>
|
|
||||||
Object.prototype.toString.call(obj) === STRING_TYPESTRING;
|
|
||||||
|
|
||||||
const isNumber = (obj: unknown): obj is number =>
|
|
||||||
Object.prototype.toString.call(obj) === NUMBER_TYPESTRING;
|
|
||||||
|
|
||||||
const isBoolean = (obj: unknown): obj is boolean =>
|
|
||||||
Object.prototype.toString.call(obj) === BOOLEAN_TYPESTRING;
|
|
||||||
|
|
||||||
const isArray = (obj: unknown): obj is Array<any> =>
|
|
||||||
Object.prototype.toString.call(obj) === ARRAY_TYPESTRING;
|
|
||||||
|
|
||||||
const parseImageField = (imageField: unknown): ImageField | undefined => {
|
|
||||||
// Must be an object
|
|
||||||
if (!isObject(imageField)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// An ImageField must have both `image_name` and `image_type`
|
|
||||||
if (!('image_name' in imageField && 'image_type' in imageField)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// An ImageField's `image_type` must be one of the allowed values
|
|
||||||
if (
|
|
||||||
!['results', 'uploads', 'intermediates'].includes(imageField.image_type)
|
|
||||||
) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// An ImageField's `image_name` must be a string
|
|
||||||
if (typeof imageField.image_name !== 'string') {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build a valid ImageField
|
|
||||||
return {
|
|
||||||
image_type: imageField.image_type,
|
|
||||||
image_name: imageField.image_name,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
|
|
||||||
// Must be an object
|
|
||||||
if (!isObject(latentsField)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// A LatentsField must have a `latents_name`
|
|
||||||
if (!('latents_name' in latentsField)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// A LatentsField's `latents_name` must be a string
|
|
||||||
if (typeof latentsField.latents_name !== 'string') {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build a valid LatentsField
|
|
||||||
return {
|
|
||||||
latents_name: latentsField.latents_name,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
const parseConditioningField = (
|
|
||||||
conditioningField: unknown
|
|
||||||
): ConditioningField | undefined => {
|
|
||||||
// Must be an object
|
|
||||||
if (!isObject(conditioningField)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// A ConditioningField must have a `conditioning_name`
|
|
||||||
if (!('conditioning_name' in conditioningField)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// A ConditioningField's `conditioning_name` must be a string
|
|
||||||
if (typeof conditioningField.conditioning_name !== 'string') {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build a valid ConditioningField
|
|
||||||
return {
|
|
||||||
conditioning_name: conditioningField.conditioning_name,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
const _parseModelInfo = (modelInfo: unknown): ModelInfo | undefined => {
|
|
||||||
// Must be an object
|
|
||||||
if (!isObject(modelInfo)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!('model_name' in modelInfo && typeof modelInfo.model_name == 'string')) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!('model_type' in modelInfo && typeof modelInfo.model_type == 'string')) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!('submodel' in modelInfo && typeof modelInfo.submodel == 'string')) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
model_name: modelInfo.model_name,
|
|
||||||
model_type: modelInfo.model_type,
|
|
||||||
submodel: modelInfo.submodel,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
const parseUNetField = (unetField: unknown): UNetField | undefined => {
|
|
||||||
// Must be an object
|
|
||||||
if (!isObject(unetField)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!('unet' in unetField && 'scheduler' in unetField)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const unet = _parseModelInfo(unetField.unet);
|
|
||||||
const scheduler = _parseModelInfo(unetField.scheduler);
|
|
||||||
|
|
||||||
if (!(unet && scheduler)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build a valid UNetField
|
|
||||||
return {
|
|
||||||
unet: unet,
|
|
||||||
scheduler: scheduler,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
const parseClipField = (clipField: unknown): ClipField | undefined => {
|
|
||||||
// Must be an object
|
|
||||||
if (!isObject(clipField)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!('tokenizer' in clipField && 'text_encoder' in clipField)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const tokenizer = _parseModelInfo(clipField.tokenizer);
|
|
||||||
const text_encoder = _parseModelInfo(clipField.text_encoder);
|
|
||||||
|
|
||||||
if (!(tokenizer && text_encoder)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build a valid ClipField
|
|
||||||
return {
|
|
||||||
tokenizer: tokenizer,
|
|
||||||
text_encoder: text_encoder,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
const parseVaeField = (vaeField: unknown): VaeField | undefined => {
|
|
||||||
// Must be an object
|
|
||||||
if (!isObject(vaeField)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!('vae' in vaeField)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const vae = _parseModelInfo(vaeField.vae);
|
|
||||||
|
|
||||||
if (!vae) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build a valid VaeField
|
|
||||||
return {
|
|
||||||
vae: vae,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
type NodeMetadata = {
|
|
||||||
[key: string]:
|
|
||||||
| string
|
|
||||||
| number
|
|
||||||
| boolean
|
|
||||||
| ImageField
|
|
||||||
| LatentsField
|
|
||||||
| ConditioningField
|
|
||||||
| UNetField
|
|
||||||
| ClipField
|
|
||||||
| VaeField;
|
|
||||||
};
|
|
||||||
|
|
||||||
type InvokeAIMetadata = {
|
|
||||||
session_id?: string;
|
|
||||||
node?: NodeMetadata;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const parseNodeMetadata = (
|
|
||||||
nodeMetadata: Record<string | number, any>
|
|
||||||
): NodeMetadata | undefined => {
|
|
||||||
if (!isObject(nodeMetadata)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const parsed: NodeMetadata = {};
|
|
||||||
|
|
||||||
forEach(nodeMetadata, (nodeItem, nodeKey) => {
|
|
||||||
// `id` and `type` must be strings if they are present
|
|
||||||
if (['id', 'type'].includes(nodeKey)) {
|
|
||||||
if (isString(nodeItem)) {
|
|
||||||
parsed[nodeKey] = nodeItem;
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// valid object types are:
|
|
||||||
// ImageField, LatentsField ConditioningField, UNetField, ClipField, VaeField
|
|
||||||
if (isObject(nodeItem)) {
|
|
||||||
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
|
|
||||||
const imageField = parseImageField(nodeItem);
|
|
||||||
if (imageField) {
|
|
||||||
parsed[nodeKey] = imageField;
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ('latents_name' in nodeItem) {
|
|
||||||
const latentsField = parseLatentsField(nodeItem);
|
|
||||||
if (latentsField) {
|
|
||||||
parsed[nodeKey] = latentsField;
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ('conditioning_name' in nodeItem) {
|
|
||||||
const conditioningField = parseConditioningField(nodeItem);
|
|
||||||
if (conditioningField) {
|
|
||||||
parsed[nodeKey] = conditioningField;
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ('unet' in nodeItem && 'scheduler' in nodeItem) {
|
|
||||||
const unetField = parseUNetField(nodeItem);
|
|
||||||
if (unetField) {
|
|
||||||
parsed[nodeKey] = unetField;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if ('tokenizer' in nodeItem && 'text_encoder' in nodeItem) {
|
|
||||||
const clipField = parseClipField(nodeItem);
|
|
||||||
if (clipField) {
|
|
||||||
parsed[nodeKey] = clipField;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if ('vae' in nodeItem) {
|
|
||||||
const vaeField = parseVaeField(nodeItem);
|
|
||||||
if (vaeField) {
|
|
||||||
parsed[nodeKey] = vaeField;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// otherwise we accept any string, number or boolean
|
|
||||||
if (isString(nodeItem) || isNumber(nodeItem) || isBoolean(nodeItem)) {
|
|
||||||
parsed[nodeKey] = nodeItem;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
if (size(parsed) === 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
return parsed;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const parseInvokeAIMetadata = (
|
|
||||||
metadata: Record<string | number, any> | undefined
|
|
||||||
): InvokeAIMetadata | undefined => {
|
|
||||||
if (metadata === undefined) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isObject(metadata)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const parsed: InvokeAIMetadata = {};
|
|
||||||
|
|
||||||
forEach(metadata, (item, key) => {
|
|
||||||
if (key === 'session_id' && isString(item)) {
|
|
||||||
parsed['session_id'] = item;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (key === 'node' && isObject(item)) {
|
|
||||||
const nodeMetadata = parseNodeMetadata(item);
|
|
||||||
|
|
||||||
if (nodeMetadata) {
|
|
||||||
parsed['node'] = nodeMetadata;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
if (size(parsed) === 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
return parsed;
|
|
||||||
};
|
|
@ -1,18 +1,24 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useGetUrl } from 'common/util/getUrl';
|
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||||
import { GalleryState } from 'features/gallery/store/gallerySlice';
|
|
||||||
import { ImageConfig } from 'konva/lib/shapes/Image';
|
import { ImageConfig } from 'konva/lib/shapes/Image';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
|
||||||
import { useEffect, useState } from 'react';
|
import { useEffect, useState } from 'react';
|
||||||
import { Image as KonvaImage } from 'react-konva';
|
import { Image as KonvaImage } from 'react-konva';
|
||||||
|
import { canvasSelector } from '../store/canvasSelectors';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[(state: RootState) => state.gallery],
|
[systemSelector, canvasSelector],
|
||||||
(gallery: GalleryState) => {
|
(system, canvas) => {
|
||||||
return gallery.intermediateImage ? gallery.intermediateImage : null;
|
const { progressImage, sessionId } = system;
|
||||||
|
const { sessionId: canvasSessionId, boundingBox } =
|
||||||
|
canvas.layerState.stagingArea;
|
||||||
|
|
||||||
|
return {
|
||||||
|
boundingBox,
|
||||||
|
progressImage: sessionId === canvasSessionId ? progressImage : undefined,
|
||||||
|
};
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
memoizeOptions: {
|
memoizeOptions: {
|
||||||
@ -25,33 +31,34 @@ type Props = Omit<ImageConfig, 'image'>;
|
|||||||
|
|
||||||
const IAICanvasIntermediateImage = (props: Props) => {
|
const IAICanvasIntermediateImage = (props: Props) => {
|
||||||
const { ...rest } = props;
|
const { ...rest } = props;
|
||||||
const intermediateImage = useAppSelector(selector);
|
const { progressImage, boundingBox } = useAppSelector(selector);
|
||||||
const { getUrl } = useGetUrl();
|
|
||||||
const [loadedImageElement, setLoadedImageElement] =
|
const [loadedImageElement, setLoadedImageElement] =
|
||||||
useState<HTMLImageElement | null>(null);
|
useState<HTMLImageElement | null>(null);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!intermediateImage) return;
|
if (!progressImage) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const tempImage = new Image();
|
const tempImage = new Image();
|
||||||
|
|
||||||
tempImage.onload = () => {
|
tempImage.onload = () => {
|
||||||
setLoadedImageElement(tempImage);
|
setLoadedImageElement(tempImage);
|
||||||
};
|
};
|
||||||
tempImage.src = getUrl(intermediateImage.url);
|
|
||||||
}, [intermediateImage, getUrl]);
|
|
||||||
|
|
||||||
if (!intermediateImage?.boundingBox) return null;
|
tempImage.src = progressImage.dataURL;
|
||||||
|
}, [progressImage]);
|
||||||
|
|
||||||
const {
|
if (!(progressImage && boundingBox)) {
|
||||||
boundingBox: { x, y, width, height },
|
return null;
|
||||||
} = intermediateImage;
|
}
|
||||||
|
|
||||||
return loadedImageElement ? (
|
return loadedImageElement ? (
|
||||||
<KonvaImage
|
<KonvaImage
|
||||||
x={x}
|
x={boundingBox.x}
|
||||||
y={y}
|
y={boundingBox.y}
|
||||||
width={width}
|
width={boundingBox.width}
|
||||||
height={height}
|
height={boundingBox.height}
|
||||||
image={loadedImageElement}
|
image={loadedImageElement}
|
||||||
listening={false}
|
listening={false}
|
||||||
{...rest}
|
{...rest}
|
||||||
|
@ -62,7 +62,7 @@ const IAICanvasStagingArea = (props: Props) => {
|
|||||||
<Group {...rest}>
|
<Group {...rest}>
|
||||||
{shouldShowStagingImage && currentStagingAreaImage && (
|
{shouldShowStagingImage && currentStagingAreaImage && (
|
||||||
<IAICanvasImage
|
<IAICanvasImage
|
||||||
url={getUrl(currentStagingAreaImage.image.image_url)}
|
url={getUrl(currentStagingAreaImage.image.image_url) ?? ''}
|
||||||
x={x}
|
x={x}
|
||||||
y={y}
|
y={y}
|
||||||
/>
|
/>
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
// import { saveStagingAreaImageToGallery } from 'app/socketio/actions';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||||
@ -26,13 +25,14 @@ import {
|
|||||||
FaPlus,
|
FaPlus,
|
||||||
FaSave,
|
FaSave,
|
||||||
} from 'react-icons/fa';
|
} from 'react-icons/fa';
|
||||||
|
import { stagingAreaImageSaved } from '../store/actions';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[canvasSelector],
|
[canvasSelector],
|
||||||
(canvas) => {
|
(canvas) => {
|
||||||
const {
|
const {
|
||||||
layerState: {
|
layerState: {
|
||||||
stagingArea: { images, selectedImageIndex },
|
stagingArea: { images, selectedImageIndex, sessionId },
|
||||||
},
|
},
|
||||||
shouldShowStagingOutline,
|
shouldShowStagingOutline,
|
||||||
shouldShowStagingImage,
|
shouldShowStagingImage,
|
||||||
@ -45,6 +45,7 @@ const selector = createSelector(
|
|||||||
isOnLastImage: selectedImageIndex === images.length - 1,
|
isOnLastImage: selectedImageIndex === images.length - 1,
|
||||||
shouldShowStagingImage,
|
shouldShowStagingImage,
|
||||||
shouldShowStagingOutline,
|
shouldShowStagingOutline,
|
||||||
|
sessionId,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -61,6 +62,7 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
isOnLastImage,
|
isOnLastImage,
|
||||||
currentStagingAreaImage,
|
currentStagingAreaImage,
|
||||||
shouldShowStagingImage,
|
shouldShowStagingImage,
|
||||||
|
sessionId,
|
||||||
} = useAppSelector(selector);
|
} = useAppSelector(selector);
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
@ -106,9 +108,20 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
const handlePrevImage = () => dispatch(prevStagingAreaImage());
|
const handlePrevImage = useCallback(
|
||||||
const handleNextImage = () => dispatch(nextStagingAreaImage());
|
() => dispatch(prevStagingAreaImage()),
|
||||||
const handleAccept = () => dispatch(commitStagingAreaImage());
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleNextImage = useCallback(
|
||||||
|
() => dispatch(nextStagingAreaImage()),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleAccept = useCallback(
|
||||||
|
() => dispatch(commitStagingAreaImage(sessionId)),
|
||||||
|
[dispatch, sessionId]
|
||||||
|
);
|
||||||
|
|
||||||
if (!currentStagingAreaImage) return null;
|
if (!currentStagingAreaImage) return null;
|
||||||
|
|
||||||
@ -157,19 +170,15 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
}
|
}
|
||||||
colorScheme="accent"
|
colorScheme="accent"
|
||||||
/>
|
/>
|
||||||
{/* <IAIIconButton
|
<IAIIconButton
|
||||||
tooltip={t('unifiedCanvas.saveToGallery')}
|
tooltip={t('unifiedCanvas.saveToGallery')}
|
||||||
aria-label={t('unifiedCanvas.saveToGallery')}
|
aria-label={t('unifiedCanvas.saveToGallery')}
|
||||||
icon={<FaSave />}
|
icon={<FaSave />}
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
dispatch(
|
dispatch(stagingAreaImageSaved(currentStagingAreaImage.image))
|
||||||
saveStagingAreaImageToGallery(
|
|
||||||
currentStagingAreaImage.image.image_url
|
|
||||||
)
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
colorScheme="accent"
|
colorScheme="accent"
|
||||||
/> */}
|
/>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
tooltip={t('unifiedCanvas.discardAll')}
|
tooltip={t('unifiedCanvas.discardAll')}
|
||||||
aria-label={t('unifiedCanvas.discardAll')}
|
aria-label={t('unifiedCanvas.discardAll')}
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
|
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
|
||||||
|
|
||||||
@ -11,3 +12,7 @@ export const canvasDownloadedAsImage = createAction(
|
|||||||
);
|
);
|
||||||
|
|
||||||
export const canvasMerged = createAction('canvas/canvasMerged');
|
export const canvasMerged = createAction('canvas/canvasMerged');
|
||||||
|
|
||||||
|
export const stagingAreaImageSaved = createAction<ImageDTO>(
|
||||||
|
'canvas/stagingAreaImageSaved'
|
||||||
|
);
|
||||||
|
@ -29,6 +29,7 @@ import {
|
|||||||
isCanvasMaskLine,
|
isCanvasMaskLine,
|
||||||
} from './canvasTypes';
|
} from './canvasTypes';
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
|
import { sessionCanceled } from 'services/thunks/session';
|
||||||
|
|
||||||
export const initialLayerState: CanvasLayerState = {
|
export const initialLayerState: CanvasLayerState = {
|
||||||
objects: [],
|
objects: [],
|
||||||
@ -696,7 +697,10 @@ export const canvasSlice = createSlice({
|
|||||||
0
|
0
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
commitStagingAreaImage: (state) => {
|
commitStagingAreaImage: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<string | undefined>
|
||||||
|
) => {
|
||||||
if (!state.layerState.stagingArea.images.length) {
|
if (!state.layerState.stagingArea.images.length) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -841,6 +845,13 @@ export const canvasSlice = createSlice({
|
|||||||
state.isTransformingBoundingBox = false;
|
state.isTransformingBoundingBox = false;
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
extraReducers: (builder) => {
|
||||||
|
builder.addCase(sessionCanceled.pending, (state) => {
|
||||||
|
if (!state.layerState.stagingArea.images.length) {
|
||||||
|
state.layerState.stagingArea = initialLayerState.stagingArea;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
export const {
|
export const {
|
||||||
|
@ -9,7 +9,8 @@ import { IRect } from 'konva/lib/types';
|
|||||||
*/
|
*/
|
||||||
const createMaskStage = async (
|
const createMaskStage = async (
|
||||||
lines: CanvasMaskLine[],
|
lines: CanvasMaskLine[],
|
||||||
boundingBox: IRect
|
boundingBox: IRect,
|
||||||
|
shouldInvertMask: boolean
|
||||||
): Promise<Konva.Stage> => {
|
): Promise<Konva.Stage> => {
|
||||||
// create an offscreen canvas and add the mask to it
|
// create an offscreen canvas and add the mask to it
|
||||||
const { width, height } = boundingBox;
|
const { width, height } = boundingBox;
|
||||||
@ -29,7 +30,7 @@ const createMaskStage = async (
|
|||||||
baseLayer.add(
|
baseLayer.add(
|
||||||
new Konva.Rect({
|
new Konva.Rect({
|
||||||
...boundingBox,
|
...boundingBox,
|
||||||
fill: 'white',
|
fill: shouldInvertMask ? 'black' : 'white',
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -37,7 +38,7 @@ const createMaskStage = async (
|
|||||||
maskLayer.add(
|
maskLayer.add(
|
||||||
new Konva.Line({
|
new Konva.Line({
|
||||||
points: line.points,
|
points: line.points,
|
||||||
stroke: 'black',
|
stroke: shouldInvertMask ? 'white' : 'black',
|
||||||
strokeWidth: line.strokeWidth * 2,
|
strokeWidth: line.strokeWidth * 2,
|
||||||
tension: 0,
|
tension: 0,
|
||||||
lineCap: 'round',
|
lineCap: 'round',
|
||||||
|
@ -25,6 +25,7 @@ export const getCanvasData = async (state: RootState) => {
|
|||||||
boundingBoxCoordinates,
|
boundingBoxCoordinates,
|
||||||
boundingBoxDimensions,
|
boundingBoxDimensions,
|
||||||
isMaskEnabled,
|
isMaskEnabled,
|
||||||
|
shouldPreserveMaskedArea,
|
||||||
} = state.canvas;
|
} = state.canvas;
|
||||||
|
|
||||||
const boundingBox = {
|
const boundingBox = {
|
||||||
@ -58,7 +59,8 @@ export const getCanvasData = async (state: RootState) => {
|
|||||||
// For the mask layer, use the normal boundingBox
|
// For the mask layer, use the normal boundingBox
|
||||||
const maskStage = await createMaskStage(
|
const maskStage = await createMaskStage(
|
||||||
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled
|
isMaskEnabled ? objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled
|
||||||
boundingBox
|
boundingBox,
|
||||||
|
shouldPreserveMaskedArea
|
||||||
);
|
);
|
||||||
const maskBlob = await konvaNodeToBlob(maskStage, boundingBox);
|
const maskBlob = await konvaNodeToBlob(maskStage, boundingBox);
|
||||||
const maskImageData = await konvaNodeToImageData(maskStage, boundingBox);
|
const maskImageData = await konvaNodeToImageData(maskStage, boundingBox);
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { isEqual, isString } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
ButtonGroup,
|
ButtonGroup,
|
||||||
@ -25,8 +25,8 @@ import {
|
|||||||
} from 'features/ui/store/uiSelectors';
|
} from 'features/ui/store/uiSelectors';
|
||||||
import {
|
import {
|
||||||
setActiveTab,
|
setActiveTab,
|
||||||
setShouldHidePreview,
|
|
||||||
setShouldShowImageDetails,
|
setShouldShowImageDetails,
|
||||||
|
setShouldShowProgressInViewer,
|
||||||
} from 'features/ui/store/uiSlice';
|
} from 'features/ui/store/uiSlice';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@ -37,23 +37,19 @@ import {
|
|||||||
FaDownload,
|
FaDownload,
|
||||||
FaExpand,
|
FaExpand,
|
||||||
FaExpandArrowsAlt,
|
FaExpandArrowsAlt,
|
||||||
FaEye,
|
|
||||||
FaEyeSlash,
|
|
||||||
FaGrinStars,
|
FaGrinStars,
|
||||||
|
FaHourglassHalf,
|
||||||
FaQuoteRight,
|
FaQuoteRight,
|
||||||
FaSeedling,
|
FaSeedling,
|
||||||
FaShare,
|
FaShare,
|
||||||
FaShareAlt,
|
FaShareAlt,
|
||||||
FaTrash,
|
|
||||||
FaWrench,
|
|
||||||
} from 'react-icons/fa';
|
} from 'react-icons/fa';
|
||||||
import { gallerySelector } from '../store/gallerySelectors';
|
import { gallerySelector } from '../store/gallerySelectors';
|
||||||
import DeleteImageModal from './DeleteImageModal';
|
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
import { useGetUrl } from 'common/util/getUrl';
|
import { useGetUrl } from 'common/util/getUrl';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
import { useParameters } from 'features/parameters/hooks/useParameters';
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||||
import {
|
import {
|
||||||
requestedImageDeletion,
|
requestedImageDeletion,
|
||||||
@ -62,7 +58,6 @@ import {
|
|||||||
} from '../store/actions';
|
} from '../store/actions';
|
||||||
import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceRestore/FaceRestoreSettings';
|
import FaceRestoreSettings from 'features/parameters/components/Parameters/FaceRestore/FaceRestoreSettings';
|
||||||
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
|
import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/UpscaleSettings';
|
||||||
import { allParametersSet } from 'features/parameters/store/generationSlice';
|
|
||||||
import DeleteImageButton from './ImageActionButtons/DeleteImageButton';
|
import DeleteImageButton from './ImageActionButtons/DeleteImageButton';
|
||||||
import { useAppToaster } from 'app/components/Toaster';
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||||
@ -90,7 +85,11 @@ const currentImageButtonsSelector = createSelector(
|
|||||||
|
|
||||||
const { isLightboxOpen } = lightbox;
|
const { isLightboxOpen } = lightbox;
|
||||||
|
|
||||||
const { shouldShowImageDetails, shouldHidePreview } = ui;
|
const {
|
||||||
|
shouldShowImageDetails,
|
||||||
|
shouldHidePreview,
|
||||||
|
shouldShowProgressInViewer,
|
||||||
|
} = ui;
|
||||||
|
|
||||||
const { selectedImage } = gallery;
|
const { selectedImage } = gallery;
|
||||||
|
|
||||||
@ -112,6 +111,7 @@ const currentImageButtonsSelector = createSelector(
|
|||||||
seed: selectedImage?.metadata?.seed,
|
seed: selectedImage?.metadata?.seed,
|
||||||
prompt: selectedImage?.metadata?.positive_conditioning,
|
prompt: selectedImage?.metadata?.positive_conditioning,
|
||||||
negativePrompt: selectedImage?.metadata?.negative_conditioning,
|
negativePrompt: selectedImage?.metadata?.negative_conditioning,
|
||||||
|
shouldShowProgressInViewer,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -145,6 +145,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
image,
|
image,
|
||||||
canDeleteImage,
|
canDeleteImage,
|
||||||
shouldConfirmOnDelete,
|
shouldConfirmOnDelete,
|
||||||
|
shouldShowProgressInViewer,
|
||||||
} = useAppSelector(currentImageButtonsSelector);
|
} = useAppSelector(currentImageButtonsSelector);
|
||||||
|
|
||||||
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
||||||
@ -163,7 +164,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
const toaster = useAppToaster();
|
const toaster = useAppToaster();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { recallPrompt, recallSeed, recallAllParameters } = useParameters();
|
const { recallBothPrompts, recallSeed, recallAllParameters } =
|
||||||
|
useRecallParameters();
|
||||||
|
|
||||||
// const handleCopyImage = useCallback(async () => {
|
// const handleCopyImage = useCallback(async () => {
|
||||||
// if (!image?.url) {
|
// if (!image?.url) {
|
||||||
@ -229,10 +231,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
});
|
});
|
||||||
}, [toaster, shouldTransformUrls, getUrl, t, image]);
|
}, [toaster, shouldTransformUrls, getUrl, t, image]);
|
||||||
|
|
||||||
const handlePreviewVisibility = useCallback(() => {
|
|
||||||
dispatch(setShouldHidePreview(!shouldHidePreview));
|
|
||||||
}, [dispatch, shouldHidePreview]);
|
|
||||||
|
|
||||||
const handleClickUseAllParameters = useCallback(() => {
|
const handleClickUseAllParameters = useCallback(() => {
|
||||||
recallAllParameters(image);
|
recallAllParameters(image);
|
||||||
}, [image, recallAllParameters]);
|
}, [image, recallAllParameters]);
|
||||||
@ -252,11 +250,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
useHotkeys('s', handleUseSeed, [image]);
|
useHotkeys('s', handleUseSeed, [image]);
|
||||||
|
|
||||||
const handleUsePrompt = useCallback(() => {
|
const handleUsePrompt = useCallback(() => {
|
||||||
recallPrompt(
|
recallBothPrompts(
|
||||||
image?.metadata?.positive_conditioning,
|
image?.metadata?.positive_conditioning,
|
||||||
image?.metadata?.negative_conditioning
|
image?.metadata?.negative_conditioning
|
||||||
);
|
);
|
||||||
}, [image, recallPrompt]);
|
}, [image, recallBothPrompts]);
|
||||||
|
|
||||||
useHotkeys('p', handleUsePrompt, [image]);
|
useHotkeys('p', handleUsePrompt, [image]);
|
||||||
|
|
||||||
@ -386,6 +384,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
}
|
}
|
||||||
}, [shouldConfirmOnDelete, onDeleteDialogOpen, handleDelete]);
|
}, [shouldConfirmOnDelete, onDeleteDialogOpen, handleDelete]);
|
||||||
|
|
||||||
|
const handleClickProgressImagesToggle = useCallback(() => {
|
||||||
|
dispatch(setShouldShowProgressInViewer(!shouldShowProgressInViewer));
|
||||||
|
}, [dispatch, shouldShowProgressInViewer]);
|
||||||
|
|
||||||
useHotkeys('delete', handleInitiateDelete, [
|
useHotkeys('delete', handleInitiateDelete, [
|
||||||
image,
|
image,
|
||||||
shouldConfirmOnDelete,
|
shouldConfirmOnDelete,
|
||||||
@ -412,8 +414,9 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
<IAIPopover
|
<IAIPopover
|
||||||
triggerComponent={
|
triggerComponent={
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
isDisabled={!image}
|
|
||||||
aria-label={`${t('parameters.sendTo')}...`}
|
aria-label={`${t('parameters.sendTo')}...`}
|
||||||
|
tooltip={`${t('parameters.sendTo')}...`}
|
||||||
|
isDisabled={!image}
|
||||||
icon={<FaShareAlt />}
|
icon={<FaShareAlt />}
|
||||||
/>
|
/>
|
||||||
}
|
}
|
||||||
@ -458,28 +461,17 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
{t('parameters.copyImageToLink')}
|
{t('parameters.copyImageToLink')}
|
||||||
</IAIButton>
|
</IAIButton>
|
||||||
|
|
||||||
<Link download={true} href={getUrl(image?.image_url ?? '')}>
|
<Link
|
||||||
|
download={true}
|
||||||
|
href={getUrl(image?.image_url ?? '')}
|
||||||
|
target="_blank"
|
||||||
|
>
|
||||||
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
|
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
|
||||||
{t('parameters.downloadImage')}
|
{t('parameters.downloadImage')}
|
||||||
</IAIButton>
|
</IAIButton>
|
||||||
</Link>
|
</Link>
|
||||||
</Flex>
|
</Flex>
|
||||||
</IAIPopover>
|
</IAIPopover>
|
||||||
{/* <IAIIconButton
|
|
||||||
icon={shouldHidePreview ? <FaEyeSlash /> : <FaEye />}
|
|
||||||
tooltip={
|
|
||||||
!shouldHidePreview
|
|
||||||
? t('parameters.hidePreview')
|
|
||||||
: t('parameters.showPreview')
|
|
||||||
}
|
|
||||||
aria-label={
|
|
||||||
!shouldHidePreview
|
|
||||||
? t('parameters.hidePreview')
|
|
||||||
: t('parameters.showPreview')
|
|
||||||
}
|
|
||||||
isChecked={shouldHidePreview}
|
|
||||||
onClick={handlePreviewVisibility}
|
|
||||||
/> */}
|
|
||||||
{isLightboxEnabled && (
|
{isLightboxEnabled && (
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
icon={<FaExpand />}
|
icon={<FaExpand />}
|
||||||
@ -604,6 +596,16 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
/>
|
/>
|
||||||
</ButtonGroup>
|
</ButtonGroup>
|
||||||
|
|
||||||
|
<ButtonGroup isAttached={true}>
|
||||||
|
<IAIIconButton
|
||||||
|
aria-label={t('settings.displayInProgress')}
|
||||||
|
tooltip={t('settings.displayInProgress')}
|
||||||
|
icon={<FaHourglassHalf />}
|
||||||
|
isChecked={shouldShowProgressInViewer}
|
||||||
|
onClick={handleClickProgressImagesToggle}
|
||||||
|
/>
|
||||||
|
</ButtonGroup>
|
||||||
|
|
||||||
<ButtonGroup isAttached={true}>
|
<ButtonGroup isAttached={true}>
|
||||||
<DeleteImageButton image={image} />
|
<DeleteImageButton image={image} />
|
||||||
</ButtonGroup>
|
</ButtonGroup>
|
||||||
|
@ -62,7 +62,6 @@ const CurrentImagePreview = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
e.dataTransfer.setData('invokeai/imageName', image.image_name);
|
e.dataTransfer.setData('invokeai/imageName', image.image_name);
|
||||||
e.dataTransfer.setData('invokeai/imageType', image.image_type);
|
|
||||||
e.dataTransfer.effectAllowed = 'move';
|
e.dataTransfer.effectAllowed = 'move';
|
||||||
},
|
},
|
||||||
[image]
|
[image]
|
||||||
|
@ -30,7 +30,7 @@ import { lightboxSelector } from 'features/lightbox/store/lightboxSelectors';
|
|||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
import { useParameters } from 'features/parameters/hooks/useParameters';
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||||
import {
|
import {
|
||||||
requestedImageDeletion,
|
requestedImageDeletion,
|
||||||
@ -114,8 +114,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
|
||||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||||
|
|
||||||
const { recallSeed, recallPrompt, recallInitialImage, recallAllParameters } =
|
const { recallBothPrompts, recallSeed, recallAllParameters } =
|
||||||
useParameters();
|
useRecallParameters();
|
||||||
|
|
||||||
const handleMouseOver = () => setIsHovered(true);
|
const handleMouseOver = () => setIsHovered(true);
|
||||||
const handleMouseOut = () => setIsHovered(false);
|
const handleMouseOut = () => setIsHovered(false);
|
||||||
@ -147,7 +147,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
const handleDragStart = useCallback(
|
const handleDragStart = useCallback(
|
||||||
(e: DragEvent<HTMLDivElement>) => {
|
(e: DragEvent<HTMLDivElement>) => {
|
||||||
e.dataTransfer.setData('invokeai/imageName', image.image_name);
|
e.dataTransfer.setData('invokeai/imageName', image.image_name);
|
||||||
e.dataTransfer.setData('invokeai/imageType', image.image_type);
|
|
||||||
e.dataTransfer.effectAllowed = 'move';
|
e.dataTransfer.effectAllowed = 'move';
|
||||||
},
|
},
|
||||||
[image]
|
[image]
|
||||||
@ -155,11 +154,15 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
|
|
||||||
// Recall parameters handlers
|
// Recall parameters handlers
|
||||||
const handleRecallPrompt = useCallback(() => {
|
const handleRecallPrompt = useCallback(() => {
|
||||||
recallPrompt(
|
recallBothPrompts(
|
||||||
image.metadata?.positive_conditioning,
|
image.metadata?.positive_conditioning,
|
||||||
image.metadata?.negative_conditioning
|
image.metadata?.negative_conditioning
|
||||||
);
|
);
|
||||||
}, [image, recallPrompt]);
|
}, [
|
||||||
|
image.metadata?.negative_conditioning,
|
||||||
|
image.metadata?.positive_conditioning,
|
||||||
|
recallBothPrompts,
|
||||||
|
]);
|
||||||
|
|
||||||
const handleRecallSeed = useCallback(() => {
|
const handleRecallSeed = useCallback(() => {
|
||||||
recallSeed(image.metadata?.seed);
|
recallSeed(image.metadata?.seed);
|
||||||
|
@ -16,7 +16,6 @@ import IAIPopover from 'common/components/IAIPopover';
|
|||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
|
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
|
||||||
import {
|
import {
|
||||||
setCurrentCategory,
|
|
||||||
setGalleryImageMinimumWidth,
|
setGalleryImageMinimumWidth,
|
||||||
setGalleryImageObjectFit,
|
setGalleryImageObjectFit,
|
||||||
setShouldAutoSwitchToNewImages,
|
setShouldAutoSwitchToNewImages,
|
||||||
@ -31,59 +30,46 @@ import {
|
|||||||
memo,
|
memo,
|
||||||
useCallback,
|
useCallback,
|
||||||
useEffect,
|
useEffect,
|
||||||
|
useMemo,
|
||||||
useRef,
|
useRef,
|
||||||
useState,
|
useState,
|
||||||
} from 'react';
|
} from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs';
|
import { BsPinAngle, BsPinAngleFill } from 'react-icons/bs';
|
||||||
import { FaImage, FaUser, FaWrench } from 'react-icons/fa';
|
import { FaImage, FaServer, FaWrench } from 'react-icons/fa';
|
||||||
import { MdPhotoLibrary } from 'react-icons/md';
|
import { MdPhotoLibrary } from 'react-icons/md';
|
||||||
import HoverableImage from './HoverableImage';
|
import HoverableImage from './HoverableImage';
|
||||||
|
|
||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
import { resultsAdapter } from '../store/resultsSlice';
|
|
||||||
import {
|
|
||||||
receivedResultImagesPage,
|
|
||||||
receivedUploadImagesPage,
|
|
||||||
} from 'services/thunks/gallery';
|
|
||||||
import { uploadsAdapter } from '../store/uploadsSlice';
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
|
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import GalleryProgressImage from './GalleryProgressImage';
|
|
||||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { ImageDTO } from 'services/api';
|
import {
|
||||||
|
ASSETS_CATEGORIES,
|
||||||
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
|
IMAGE_CATEGORIES,
|
||||||
const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER';
|
imageCategoriesChanged,
|
||||||
|
selectImagesAll,
|
||||||
|
} from '../store/imagesSlice';
|
||||||
|
import { receivedPageOfImages } from 'services/thunks/image';
|
||||||
|
|
||||||
const categorySelector = createSelector(
|
const categorySelector = createSelector(
|
||||||
[(state: RootState) => state],
|
[(state: RootState) => state],
|
||||||
(state) => {
|
(state) => {
|
||||||
const { results, uploads, system, gallery } = state;
|
const { images } = state;
|
||||||
const { currentCategory } = gallery;
|
const { categories } = images;
|
||||||
|
|
||||||
if (currentCategory === 'results') {
|
const allImages = selectImagesAll(state);
|
||||||
const tempImages: (ImageDTO | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = [];
|
const filteredImages = allImages.filter((i) =>
|
||||||
|
categories.includes(i.image_category)
|
||||||
if (system.progressImage) {
|
);
|
||||||
tempImages.push(PROGRESS_IMAGE_PLACEHOLDER);
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
images: tempImages.concat(
|
images: filteredImages,
|
||||||
resultsAdapter.getSelectors().selectAll(results)
|
isLoading: images.isLoading,
|
||||||
),
|
areMoreImagesAvailable: filteredImages.length < images.total,
|
||||||
isLoading: results.isLoading,
|
categories: images.categories,
|
||||||
areMoreImagesAvailable: results.page < results.pages - 1,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
images: uploadsAdapter.getSelectors().selectAll(uploads),
|
|
||||||
isLoading: uploads.isLoading,
|
|
||||||
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
|
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
@ -93,7 +79,6 @@ const mainSelector = createSelector(
|
|||||||
[gallerySelector, uiSelector],
|
[gallerySelector, uiSelector],
|
||||||
(gallery, ui) => {
|
(gallery, ui) => {
|
||||||
const {
|
const {
|
||||||
currentCategory,
|
|
||||||
galleryImageMinimumWidth,
|
galleryImageMinimumWidth,
|
||||||
galleryImageObjectFit,
|
galleryImageObjectFit,
|
||||||
shouldAutoSwitchToNewImages,
|
shouldAutoSwitchToNewImages,
|
||||||
@ -104,7 +89,6 @@ const mainSelector = createSelector(
|
|||||||
const { shouldPinGallery } = ui;
|
const { shouldPinGallery } = ui;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
currentCategory,
|
|
||||||
shouldPinGallery,
|
shouldPinGallery,
|
||||||
galleryImageMinimumWidth,
|
galleryImageMinimumWidth,
|
||||||
galleryImageObjectFit,
|
galleryImageObjectFit,
|
||||||
@ -120,7 +104,6 @@ const ImageGalleryContent = () => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const resizeObserverRef = useRef<HTMLDivElement>(null);
|
const resizeObserverRef = useRef<HTMLDivElement>(null);
|
||||||
const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
|
|
||||||
const rootRef = useRef(null);
|
const rootRef = useRef(null);
|
||||||
const [scroller, setScroller] = useState<HTMLElement | null>(null);
|
const [scroller, setScroller] = useState<HTMLElement | null>(null);
|
||||||
const [initialize, osInstance] = useOverlayScrollbars({
|
const [initialize, osInstance] = useOverlayScrollbars({
|
||||||
@ -137,7 +120,6 @@ const ImageGalleryContent = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
const {
|
const {
|
||||||
currentCategory,
|
|
||||||
shouldPinGallery,
|
shouldPinGallery,
|
||||||
galleryImageMinimumWidth,
|
galleryImageMinimumWidth,
|
||||||
galleryImageObjectFit,
|
galleryImageObjectFit,
|
||||||
@ -146,18 +128,19 @@ const ImageGalleryContent = () => {
|
|||||||
selectedImage,
|
selectedImage,
|
||||||
} = useAppSelector(mainSelector);
|
} = useAppSelector(mainSelector);
|
||||||
|
|
||||||
const { images, areMoreImagesAvailable, isLoading } =
|
const { images, areMoreImagesAvailable, isLoading, categories } =
|
||||||
useAppSelector(categorySelector);
|
useAppSelector(categorySelector);
|
||||||
|
|
||||||
const handleClickLoadMore = () => {
|
const handleLoadMoreImages = useCallback(() => {
|
||||||
if (currentCategory === 'results') {
|
dispatch(receivedPageOfImages());
|
||||||
dispatch(receivedResultImagesPage());
|
}, [dispatch]);
|
||||||
}
|
|
||||||
|
|
||||||
if (currentCategory === 'uploads') {
|
const handleEndReached = useMemo(() => {
|
||||||
dispatch(receivedUploadImagesPage());
|
if (areMoreImagesAvailable && !isLoading) {
|
||||||
|
return handleLoadMoreImages;
|
||||||
}
|
}
|
||||||
};
|
return undefined;
|
||||||
|
}, [areMoreImagesAvailable, handleLoadMoreImages, isLoading]);
|
||||||
|
|
||||||
const handleChangeGalleryImageMinimumWidth = (v: number) => {
|
const handleChangeGalleryImageMinimumWidth = (v: number) => {
|
||||||
dispatch(setGalleryImageMinimumWidth(v));
|
dispatch(setGalleryImageMinimumWidth(v));
|
||||||
@ -168,28 +151,6 @@ const ImageGalleryContent = () => {
|
|||||||
dispatch(requestCanvasRescale());
|
dispatch(requestCanvasRescale());
|
||||||
};
|
};
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (!resizeObserverRef.current) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const resizeObserver = new ResizeObserver(() => {
|
|
||||||
if (!resizeObserverRef.current) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
resizeObserverRef.current.clientWidth < GALLERY_SHOW_BUTTONS_MIN_WIDTH
|
|
||||||
) {
|
|
||||||
setShouldShouldIconButtons(true);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
setShouldShouldIconButtons(false);
|
|
||||||
});
|
|
||||||
resizeObserver.observe(resizeObserverRef.current);
|
|
||||||
return () => resizeObserver.disconnect(); // clean up
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const { current: root } = rootRef;
|
const { current: root } = rootRef;
|
||||||
if (scroller && root) {
|
if (scroller && root) {
|
||||||
@ -209,13 +170,13 @@ const ImageGalleryContent = () => {
|
|||||||
}
|
}
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const handleEndReached = useCallback(() => {
|
const handleClickImagesCategory = useCallback(() => {
|
||||||
if (currentCategory === 'results') {
|
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
|
||||||
dispatch(receivedResultImagesPage());
|
}, [dispatch]);
|
||||||
} else if (currentCategory === 'uploads') {
|
|
||||||
dispatch(receivedUploadImagesPage());
|
const handleClickAssetsCategory = useCallback(() => {
|
||||||
}
|
dispatch(imageCategoriesChanged(ASSETS_CATEGORIES));
|
||||||
}, [dispatch, currentCategory]);
|
}, [dispatch]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
@ -232,59 +193,31 @@ const ImageGalleryContent = () => {
|
|||||||
alignItems="center"
|
alignItems="center"
|
||||||
justifyContent="space-between"
|
justifyContent="space-between"
|
||||||
>
|
>
|
||||||
<ButtonGroup
|
<ButtonGroup isAttached>
|
||||||
size="sm"
|
|
||||||
isAttached
|
|
||||||
w="max-content"
|
|
||||||
justifyContent="stretch"
|
|
||||||
>
|
|
||||||
{shouldShouldIconButtons ? (
|
|
||||||
<>
|
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
aria-label={t('gallery.showGenerations')}
|
tooltip={t('gallery.images')}
|
||||||
tooltip={t('gallery.showGenerations')}
|
aria-label={t('gallery.images')}
|
||||||
isChecked={currentCategory === 'results'}
|
onClick={handleClickImagesCategory}
|
||||||
role="radio"
|
isChecked={categories === IMAGE_CATEGORIES}
|
||||||
|
size="sm"
|
||||||
icon={<FaImage />}
|
icon={<FaImage />}
|
||||||
onClick={() => dispatch(setCurrentCategory('results'))}
|
|
||||||
/>
|
/>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
aria-label={t('gallery.showUploads')}
|
tooltip={t('gallery.assets')}
|
||||||
tooltip={t('gallery.showUploads')}
|
aria-label={t('gallery.assets')}
|
||||||
role="radio"
|
onClick={handleClickAssetsCategory}
|
||||||
isChecked={currentCategory === 'uploads'}
|
isChecked={categories === ASSETS_CATEGORIES}
|
||||||
icon={<FaUser />}
|
size="sm"
|
||||||
onClick={() => dispatch(setCurrentCategory('uploads'))}
|
icon={<FaServer />}
|
||||||
/>
|
/>
|
||||||
</>
|
|
||||||
) : (
|
|
||||||
<>
|
|
||||||
<IAIButton
|
|
||||||
size="sm"
|
|
||||||
isChecked={currentCategory === 'results'}
|
|
||||||
onClick={() => dispatch(setCurrentCategory('results'))}
|
|
||||||
flexGrow={1}
|
|
||||||
>
|
|
||||||
{t('gallery.generations')}
|
|
||||||
</IAIButton>
|
|
||||||
<IAIButton
|
|
||||||
size="sm"
|
|
||||||
isChecked={currentCategory === 'uploads'}
|
|
||||||
onClick={() => dispatch(setCurrentCategory('uploads'))}
|
|
||||||
flexGrow={1}
|
|
||||||
>
|
|
||||||
{t('gallery.uploads')}
|
|
||||||
</IAIButton>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
</ButtonGroup>
|
</ButtonGroup>
|
||||||
|
|
||||||
<Flex gap={2}>
|
<Flex gap={2}>
|
||||||
<IAIPopover
|
<IAIPopover
|
||||||
triggerComponent={
|
triggerComponent={
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
size="sm"
|
tooltip={t('gallery.gallerySettings')}
|
||||||
aria-label={t('gallery.gallerySettings')}
|
aria-label={t('gallery.gallerySettings')}
|
||||||
|
size="sm"
|
||||||
icon={<FaWrench />}
|
icon={<FaWrench />}
|
||||||
/>
|
/>
|
||||||
}
|
}
|
||||||
@ -347,28 +280,17 @@ const ImageGalleryContent = () => {
|
|||||||
data={images}
|
data={images}
|
||||||
endReached={handleEndReached}
|
endReached={handleEndReached}
|
||||||
scrollerRef={(ref) => setScrollerRef(ref)}
|
scrollerRef={(ref) => setScrollerRef(ref)}
|
||||||
itemContent={(index, image) => {
|
itemContent={(index, image) => (
|
||||||
const isSelected =
|
|
||||||
image === PROGRESS_IMAGE_PLACEHOLDER
|
|
||||||
? false
|
|
||||||
: selectedImage?.image_name === image?.image_name;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex sx={{ pb: 2 }}>
|
<Flex sx={{ pb: 2 }}>
|
||||||
{image === PROGRESS_IMAGE_PLACEHOLDER ? (
|
|
||||||
<GalleryProgressImage
|
|
||||||
key={PROGRESS_IMAGE_PLACEHOLDER}
|
|
||||||
/>
|
|
||||||
) : (
|
|
||||||
<HoverableImage
|
<HoverableImage
|
||||||
key={`${image.image_name}-${image.thumbnail_url}`}
|
key={`${image.image_name}-${image.thumbnail_url}`}
|
||||||
image={image}
|
image={image}
|
||||||
isSelected={isSelected}
|
isSelected={
|
||||||
|
selectedImage?.image_name === image?.image_name
|
||||||
|
}
|
||||||
/>
|
/>
|
||||||
)}
|
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
)}
|
||||||
}}
|
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
<VirtuosoGrid
|
<VirtuosoGrid
|
||||||
@ -380,27 +302,20 @@ const ImageGalleryContent = () => {
|
|||||||
List: ListContainer,
|
List: ListContainer,
|
||||||
}}
|
}}
|
||||||
scrollerRef={setScroller}
|
scrollerRef={setScroller}
|
||||||
itemContent={(index, image) => {
|
itemContent={(index, image) => (
|
||||||
const isSelected =
|
|
||||||
image === PROGRESS_IMAGE_PLACEHOLDER
|
|
||||||
? false
|
|
||||||
: selectedImage?.image_name === image?.image_name;
|
|
||||||
|
|
||||||
return image === PROGRESS_IMAGE_PLACEHOLDER ? (
|
|
||||||
<GalleryProgressImage key={PROGRESS_IMAGE_PLACEHOLDER} />
|
|
||||||
) : (
|
|
||||||
<HoverableImage
|
<HoverableImage
|
||||||
key={`${image.image_name}-${image.thumbnail_url}`}
|
key={`${image.image_name}-${image.thumbnail_url}`}
|
||||||
image={image}
|
image={image}
|
||||||
isSelected={isSelected}
|
isSelected={
|
||||||
|
selectedImage?.image_name === image?.image_name
|
||||||
|
}
|
||||||
/>
|
/>
|
||||||
);
|
)}
|
||||||
}}
|
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</Box>
|
</Box>
|
||||||
<IAIButton
|
<IAIButton
|
||||||
onClick={handleClickLoadMore}
|
onClick={handleLoadMoreImages}
|
||||||
isDisabled={!areMoreImagesAvailable}
|
isDisabled={!areMoreImagesAvailable}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
loadingText="Loading"
|
loadingText="Loading"
|
||||||
|
@ -31,6 +31,7 @@ import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
|||||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
import { Scheduler } from 'app/constants';
|
import { Scheduler } from 'app/constants';
|
||||||
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
|
|
||||||
type MetadataItemProps = {
|
type MetadataItemProps = {
|
||||||
isLink?: boolean;
|
isLink?: boolean;
|
||||||
@ -53,6 +54,11 @@ const MetadataItem = ({
|
|||||||
withCopy = false,
|
withCopy = false,
|
||||||
}: MetadataItemProps) => {
|
}: MetadataItemProps) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
if (!value) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex gap={2}>
|
<Flex gap={2}>
|
||||||
{onClick && (
|
{onClick && (
|
||||||
@ -115,6 +121,21 @@ const memoEqualityCheck = (
|
|||||||
*/
|
*/
|
||||||
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const {
|
||||||
|
recallBothPrompts,
|
||||||
|
recallPositivePrompt,
|
||||||
|
recallNegativePrompt,
|
||||||
|
recallSeed,
|
||||||
|
recallInitialImage,
|
||||||
|
recallCfgScale,
|
||||||
|
recallModel,
|
||||||
|
recallScheduler,
|
||||||
|
recallSteps,
|
||||||
|
recallWidth,
|
||||||
|
recallHeight,
|
||||||
|
recallStrength,
|
||||||
|
recallAllParameters,
|
||||||
|
} = useRecallParameters();
|
||||||
|
|
||||||
useHotkeys('esc', () => {
|
useHotkeys('esc', () => {
|
||||||
dispatch(setShouldShowImageDetails(false));
|
dispatch(setShouldShowImageDetails(false));
|
||||||
@ -161,52 +182,53 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
|||||||
{metadata.type && (
|
{metadata.type && (
|
||||||
<MetadataItem label="Invocation type" value={metadata.type} />
|
<MetadataItem label="Invocation type" value={metadata.type} />
|
||||||
)}
|
)}
|
||||||
{metadata.width && (
|
{sessionId && <MetadataItem label="Session ID" value={sessionId} />}
|
||||||
<MetadataItem
|
|
||||||
label="Width"
|
|
||||||
value={metadata.width}
|
|
||||||
onClick={() => dispatch(setWidth(Number(metadata.width)))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{metadata.height && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Height"
|
|
||||||
value={metadata.height}
|
|
||||||
onClick={() => dispatch(setHeight(Number(metadata.height)))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{metadata.model && (
|
|
||||||
<MetadataItem label="Model" value={metadata.model} />
|
|
||||||
)}
|
|
||||||
{metadata.positive_conditioning && (
|
{metadata.positive_conditioning && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Prompt"
|
label="Positive Prompt"
|
||||||
labelPosition="top"
|
labelPosition="top"
|
||||||
value={
|
value={metadata.positive_conditioning}
|
||||||
typeof metadata.positive_conditioning === 'string'
|
onClick={() =>
|
||||||
? metadata.positive_conditioning
|
recallPositivePrompt(metadata.positive_conditioning)
|
||||||
: promptToString(metadata.positive_conditioning)
|
|
||||||
}
|
}
|
||||||
onClick={() => setPositivePrompt(metadata.positive_conditioning!)}
|
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{metadata.negative_conditioning && (
|
{metadata.negative_conditioning && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Prompt"
|
label="Negative Prompt"
|
||||||
labelPosition="top"
|
labelPosition="top"
|
||||||
value={
|
value={metadata.negative_conditioning}
|
||||||
typeof metadata.negative_conditioning === 'string'
|
onClick={() =>
|
||||||
? metadata.negative_conditioning
|
recallNegativePrompt(metadata.negative_conditioning)
|
||||||
: promptToString(metadata.negative_conditioning)
|
|
||||||
}
|
}
|
||||||
onClick={() => setNegativePrompt(metadata.negative_conditioning!)}
|
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{metadata.seed !== undefined && (
|
{metadata.seed !== undefined && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Seed"
|
label="Seed"
|
||||||
value={metadata.seed}
|
value={metadata.seed}
|
||||||
onClick={() => dispatch(setSeed(Number(metadata.seed)))}
|
onClick={() => recallSeed(metadata.seed)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{metadata.model !== undefined && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Model"
|
||||||
|
value={metadata.model}
|
||||||
|
onClick={() => recallModel(metadata.model)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{metadata.width && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Width"
|
||||||
|
value={metadata.width}
|
||||||
|
onClick={() => recallWidth(metadata.width)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{metadata.height && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Height"
|
||||||
|
value={metadata.height}
|
||||||
|
onClick={() => recallHeight(metadata.height)}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{/* {metadata.threshold !== undefined && (
|
{/* {metadata.threshold !== undefined && (
|
||||||
@ -227,23 +249,21 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
|||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Scheduler"
|
label="Scheduler"
|
||||||
value={metadata.scheduler}
|
value={metadata.scheduler}
|
||||||
onClick={() =>
|
onClick={() => recallScheduler(metadata.scheduler)}
|
||||||
dispatch(setScheduler(metadata.scheduler as Scheduler))
|
|
||||||
}
|
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{metadata.steps && (
|
{metadata.steps && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Steps"
|
label="Steps"
|
||||||
value={metadata.steps}
|
value={metadata.steps}
|
||||||
onClick={() => dispatch(setSteps(Number(metadata.steps)))}
|
onClick={() => recallSteps(metadata.steps)}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{metadata.cfg_scale !== undefined && (
|
{metadata.cfg_scale !== undefined && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="CFG scale"
|
label="CFG scale"
|
||||||
value={metadata.cfg_scale}
|
value={metadata.cfg_scale}
|
||||||
onClick={() => dispatch(setCfgScale(Number(metadata.cfg_scale)))}
|
onClick={() => recallCfgScale(metadata.cfg_scale)}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{/* {metadata.variations && metadata.variations.length > 0 && (
|
{/* {metadata.variations && metadata.variations.length > 0 && (
|
||||||
@ -284,9 +304,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
|||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Image to image strength"
|
label="Image to image strength"
|
||||||
value={metadata.strength}
|
value={metadata.strength}
|
||||||
onClick={() =>
|
onClick={() => recallStrength(metadata.strength)}
|
||||||
dispatch(setImg2imgStrength(Number(metadata.strength)))
|
|
||||||
}
|
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{/* {metadata.fit && (
|
{/* {metadata.fit && (
|
||||||
|
@ -9,6 +9,10 @@ import { gallerySelector } from '../store/gallerySelectors';
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { imageSelected } from '../store/gallerySlice';
|
import { imageSelected } from '../store/gallerySlice';
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
import { useHotkeys } from 'react-hotkeys-hook';
|
||||||
|
import {
|
||||||
|
selectFilteredImagesAsObject,
|
||||||
|
selectFilteredImagesIds,
|
||||||
|
} from '../store/imagesSlice';
|
||||||
|
|
||||||
const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = {
|
const nextPrevButtonTriggerAreaStyles: ChakraProps['sx'] = {
|
||||||
height: '100%',
|
height: '100%',
|
||||||
@ -21,9 +25,14 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export const nextPrevImageButtonsSelector = createSelector(
|
export const nextPrevImageButtonsSelector = createSelector(
|
||||||
[(state: RootState) => state, gallerySelector],
|
[
|
||||||
(state, gallery) => {
|
(state: RootState) => state,
|
||||||
const { selectedImage, currentCategory } = gallery;
|
gallerySelector,
|
||||||
|
selectFilteredImagesAsObject,
|
||||||
|
selectFilteredImagesIds,
|
||||||
|
],
|
||||||
|
(state, gallery, filteredImagesAsObject, filteredImageIds) => {
|
||||||
|
const { selectedImage } = gallery;
|
||||||
|
|
||||||
if (!selectedImage) {
|
if (!selectedImage) {
|
||||||
return {
|
return {
|
||||||
@ -32,29 +41,29 @@ export const nextPrevImageButtonsSelector = createSelector(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
const currentImageIndex = state[currentCategory].ids.findIndex(
|
const currentImageIndex = filteredImageIds.findIndex(
|
||||||
(i) => i === selectedImage.image_name
|
(i) => i === selectedImage.image_name
|
||||||
);
|
);
|
||||||
|
|
||||||
const nextImageIndex = clamp(
|
const nextImageIndex = clamp(
|
||||||
currentImageIndex + 1,
|
currentImageIndex + 1,
|
||||||
0,
|
0,
|
||||||
state[currentCategory].ids.length - 1
|
filteredImageIds.length - 1
|
||||||
);
|
);
|
||||||
|
|
||||||
const prevImageIndex = clamp(
|
const prevImageIndex = clamp(
|
||||||
currentImageIndex - 1,
|
currentImageIndex - 1,
|
||||||
0,
|
0,
|
||||||
state[currentCategory].ids.length - 1
|
filteredImageIds.length - 1
|
||||||
);
|
);
|
||||||
|
|
||||||
const nextImageId = state[currentCategory].ids[nextImageIndex];
|
const nextImageId = filteredImageIds[nextImageIndex];
|
||||||
const prevImageId = state[currentCategory].ids[prevImageIndex];
|
const prevImageId = filteredImageIds[prevImageIndex];
|
||||||
|
|
||||||
const nextImage = state[currentCategory].entities[nextImageId];
|
const nextImage = filteredImagesAsObject[nextImageId];
|
||||||
const prevImage = state[currentCategory].entities[prevImageId];
|
const prevImage = filteredImagesAsObject[prevImageId];
|
||||||
|
|
||||||
const imagesLength = state[currentCategory].ids.length;
|
const imagesLength = filteredImageIds.length;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
isOnFirstImage: currentImageIndex === 0,
|
isOnFirstImage: currentImageIndex === 0,
|
||||||
|
@ -1,33 +1,18 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import { ImageType } from 'services/api';
|
import { selectImagesEntities } from '../store/imagesSlice';
|
||||||
import { selectResultsEntities } from '../store/resultsSlice';
|
import { useCallback } from 'react';
|
||||||
import { selectUploadsEntities } from '../store/uploadsSlice';
|
|
||||||
|
|
||||||
const useGetImageByNameSelector = createSelector(
|
const useGetImageByName = () => {
|
||||||
[selectResultsEntities, selectUploadsEntities],
|
const images = useAppSelector(selectImagesEntities);
|
||||||
(allResults, allUploads) => {
|
return useCallback(
|
||||||
return { allResults, allUploads };
|
(name: string | undefined) => {
|
||||||
|
if (!name) {
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
);
|
return images[name];
|
||||||
|
},
|
||||||
const useGetImageByNameAndType = () => {
|
[images]
|
||||||
const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector);
|
);
|
||||||
return (name: string, type: ImageType) => {
|
|
||||||
if (type === 'results') {
|
|
||||||
const resultImagesResult = allResults[name];
|
|
||||||
if (resultImagesResult) {
|
|
||||||
return resultImagesResult;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (type === 'uploads') {
|
|
||||||
const userImagesResult = allUploads[name];
|
|
||||||
if (userImagesResult) {
|
|
||||||
return userImagesResult;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export default useGetImageByNameAndType;
|
export default useGetImageByName;
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
import { ImageNameAndType } from 'features/parameters/store/actions';
|
import { ImageNameAndOrigin } from 'features/parameters/store/actions';
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
export const requestedImageDeletion = createAction<
|
export const requestedImageDeletion = createAction<
|
||||||
ImageDTO | ImageNameAndType | undefined
|
ImageDTO | ImageNameAndOrigin | undefined
|
||||||
>('gallery/requestedImageDeletion');
|
>('gallery/requestedImageDeletion');
|
||||||
|
|
||||||
export const sentImageToCanvas = createAction('gallery/sentImageToCanvas');
|
export const sentImageToCanvas = createAction('gallery/sentImageToCanvas');
|
||||||
|
@ -4,6 +4,5 @@ import { GalleryState } from './gallerySlice';
|
|||||||
* Gallery slice persist denylist
|
* Gallery slice persist denylist
|
||||||
*/
|
*/
|
||||||
export const galleryPersistDenylist: (keyof GalleryState)[] = [
|
export const galleryPersistDenylist: (keyof GalleryState)[] = [
|
||||||
'currentCategory',
|
|
||||||
'shouldAutoSwitchToNewImages',
|
'shouldAutoSwitchToNewImages',
|
||||||
];
|
];
|
||||||
|
@ -1,10 +1,7 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import {
|
|
||||||
receivedResultImagesPage,
|
|
||||||
receivedUploadImagesPage,
|
|
||||||
} from '../../../services/thunks/gallery';
|
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
|
import { imageUpserted } from './imagesSlice';
|
||||||
|
|
||||||
type GalleryImageObjectFitType = 'contain' | 'cover';
|
type GalleryImageObjectFitType = 'contain' | 'cover';
|
||||||
|
|
||||||
@ -14,7 +11,6 @@ export interface GalleryState {
|
|||||||
galleryImageObjectFit: GalleryImageObjectFitType;
|
galleryImageObjectFit: GalleryImageObjectFitType;
|
||||||
shouldAutoSwitchToNewImages: boolean;
|
shouldAutoSwitchToNewImages: boolean;
|
||||||
shouldUseSingleGalleryColumn: boolean;
|
shouldUseSingleGalleryColumn: boolean;
|
||||||
currentCategory: 'results' | 'uploads';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export const initialGalleryState: GalleryState = {
|
export const initialGalleryState: GalleryState = {
|
||||||
@ -22,7 +18,6 @@ export const initialGalleryState: GalleryState = {
|
|||||||
galleryImageObjectFit: 'cover',
|
galleryImageObjectFit: 'cover',
|
||||||
shouldAutoSwitchToNewImages: true,
|
shouldAutoSwitchToNewImages: true,
|
||||||
shouldUseSingleGalleryColumn: false,
|
shouldUseSingleGalleryColumn: false,
|
||||||
currentCategory: 'results',
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export const gallerySlice = createSlice({
|
export const gallerySlice = createSlice({
|
||||||
@ -46,12 +41,6 @@ export const gallerySlice = createSlice({
|
|||||||
setShouldAutoSwitchToNewImages: (state, action: PayloadAction<boolean>) => {
|
setShouldAutoSwitchToNewImages: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldAutoSwitchToNewImages = action.payload;
|
state.shouldAutoSwitchToNewImages = action.payload;
|
||||||
},
|
},
|
||||||
setCurrentCategory: (
|
|
||||||
state,
|
|
||||||
action: PayloadAction<'results' | 'uploads'>
|
|
||||||
) => {
|
|
||||||
state.currentCategory = action.payload;
|
|
||||||
},
|
|
||||||
setShouldUseSingleGalleryColumn: (
|
setShouldUseSingleGalleryColumn: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<boolean>
|
action: PayloadAction<boolean>
|
||||||
@ -59,37 +48,10 @@ export const gallerySlice = createSlice({
|
|||||||
state.shouldUseSingleGalleryColumn = action.payload;
|
state.shouldUseSingleGalleryColumn = action.payload;
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
extraReducers(builder) {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
builder.addCase(imageUpserted, (state, action) => {
|
||||||
// rehydrate selectedImage URL when results list comes in
|
if (state.shouldAutoSwitchToNewImages) {
|
||||||
// solves case when outdated URL is in local storage
|
state.selectedImage = action.payload;
|
||||||
const selectedImage = state.selectedImage;
|
|
||||||
if (selectedImage) {
|
|
||||||
const selectedImageInResults = action.payload.items.find(
|
|
||||||
(image) => image.image_name === selectedImage.image_name
|
|
||||||
);
|
|
||||||
|
|
||||||
if (selectedImageInResults) {
|
|
||||||
selectedImage.image_url = selectedImageInResults.image_url;
|
|
||||||
selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url;
|
|
||||||
state.selectedImage = selectedImage;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
|
|
||||||
// rehydrate selectedImage URL when results list comes in
|
|
||||||
// solves case when outdated URL is in local storage
|
|
||||||
const selectedImage = state.selectedImage;
|
|
||||||
if (selectedImage) {
|
|
||||||
const selectedImageInResults = action.payload.items.find(
|
|
||||||
(image) => image.image_name === selectedImage.image_name
|
|
||||||
);
|
|
||||||
|
|
||||||
if (selectedImageInResults) {
|
|
||||||
selectedImage.image_url = selectedImageInResults.image_url;
|
|
||||||
selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url;
|
|
||||||
state.selectedImage = selectedImage;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
@ -101,7 +63,6 @@ export const {
|
|||||||
setGalleryImageObjectFit,
|
setGalleryImageObjectFit,
|
||||||
setShouldAutoSwitchToNewImages,
|
setShouldAutoSwitchToNewImages,
|
||||||
setShouldUseSingleGalleryColumn,
|
setShouldUseSingleGalleryColumn,
|
||||||
setCurrentCategory,
|
|
||||||
} = gallerySlice.actions;
|
} = gallerySlice.actions;
|
||||||
|
|
||||||
export default gallerySlice.reducer;
|
export default gallerySlice.reducer;
|
||||||
|
135
invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts
Normal file
135
invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
import {
|
||||||
|
PayloadAction,
|
||||||
|
createEntityAdapter,
|
||||||
|
createSelector,
|
||||||
|
createSlice,
|
||||||
|
} from '@reduxjs/toolkit';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { ImageCategory, ImageDTO } from 'services/api';
|
||||||
|
import { dateComparator } from 'common/util/dateComparator';
|
||||||
|
import { isString, keyBy } from 'lodash-es';
|
||||||
|
import { receivedPageOfImages } from 'services/thunks/image';
|
||||||
|
|
||||||
|
export const imagesAdapter = createEntityAdapter<ImageDTO>({
|
||||||
|
selectId: (image) => image.image_name,
|
||||||
|
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
|
||||||
|
});
|
||||||
|
|
||||||
|
export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
|
||||||
|
export const ASSETS_CATEGORIES: ImageCategory[] = [
|
||||||
|
'control',
|
||||||
|
'mask',
|
||||||
|
'user',
|
||||||
|
'other',
|
||||||
|
];
|
||||||
|
|
||||||
|
type AdditionaImagesState = {
|
||||||
|
offset: number;
|
||||||
|
limit: number;
|
||||||
|
total: number;
|
||||||
|
isLoading: boolean;
|
||||||
|
categories: ImageCategory[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export const initialImagesState =
|
||||||
|
imagesAdapter.getInitialState<AdditionaImagesState>({
|
||||||
|
offset: 0,
|
||||||
|
limit: 0,
|
||||||
|
total: 0,
|
||||||
|
isLoading: false,
|
||||||
|
categories: IMAGE_CATEGORIES,
|
||||||
|
});
|
||||||
|
|
||||||
|
export type ImagesState = typeof initialImagesState;
|
||||||
|
|
||||||
|
const imagesSlice = createSlice({
|
||||||
|
name: 'images',
|
||||||
|
initialState: initialImagesState,
|
||||||
|
reducers: {
|
||||||
|
imageUpserted: (state, action: PayloadAction<ImageDTO>) => {
|
||||||
|
imagesAdapter.upsertOne(state, action.payload);
|
||||||
|
},
|
||||||
|
imageRemoved: (state, action: PayloadAction<string | ImageDTO>) => {
|
||||||
|
if (isString(action.payload)) {
|
||||||
|
imagesAdapter.removeOne(state, action.payload);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
imagesAdapter.removeOne(state, action.payload.image_name);
|
||||||
|
},
|
||||||
|
imageCategoriesChanged: (state, action: PayloadAction<ImageCategory[]>) => {
|
||||||
|
state.categories = action.payload;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
extraReducers: (builder) => {
|
||||||
|
builder.addCase(receivedPageOfImages.pending, (state) => {
|
||||||
|
state.isLoading = true;
|
||||||
|
});
|
||||||
|
builder.addCase(receivedPageOfImages.rejected, (state) => {
|
||||||
|
state.isLoading = false;
|
||||||
|
});
|
||||||
|
builder.addCase(receivedPageOfImages.fulfilled, (state, action) => {
|
||||||
|
state.isLoading = false;
|
||||||
|
const { items, offset, limit, total } = action.payload;
|
||||||
|
state.offset = offset;
|
||||||
|
state.limit = limit;
|
||||||
|
state.total = total;
|
||||||
|
imagesAdapter.upsertMany(state, items);
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const {
|
||||||
|
selectAll: selectImagesAll,
|
||||||
|
selectById: selectImagesById,
|
||||||
|
selectEntities: selectImagesEntities,
|
||||||
|
selectIds: selectImagesIds,
|
||||||
|
selectTotal: selectImagesTotal,
|
||||||
|
} = imagesAdapter.getSelectors<RootState>((state) => state.images);
|
||||||
|
|
||||||
|
export const { imageUpserted, imageRemoved, imageCategoriesChanged } =
|
||||||
|
imagesSlice.actions;
|
||||||
|
|
||||||
|
export default imagesSlice.reducer;
|
||||||
|
|
||||||
|
export const selectFilteredImagesAsArray = createSelector(
|
||||||
|
(state: RootState) => state,
|
||||||
|
(state) => {
|
||||||
|
const {
|
||||||
|
images: { categories },
|
||||||
|
} = state;
|
||||||
|
|
||||||
|
return selectImagesAll(state).filter((i) =>
|
||||||
|
categories.includes(i.image_category)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
export const selectFilteredImagesAsObject = createSelector(
|
||||||
|
(state: RootState) => state,
|
||||||
|
(state) => {
|
||||||
|
const {
|
||||||
|
images: { categories },
|
||||||
|
} = state;
|
||||||
|
|
||||||
|
return keyBy(
|
||||||
|
selectImagesAll(state).filter((i) =>
|
||||||
|
categories.includes(i.image_category)
|
||||||
|
),
|
||||||
|
'image_name'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
export const selectFilteredImagesIds = createSelector(
|
||||||
|
(state: RootState) => state,
|
||||||
|
(state) => {
|
||||||
|
const {
|
||||||
|
images: { categories },
|
||||||
|
} = state;
|
||||||
|
|
||||||
|
return selectImagesAll(state)
|
||||||
|
.filter((i) => categories.includes(i.image_category))
|
||||||
|
.map((i) => i.image_name);
|
||||||
|
}
|
||||||
|
);
|
@ -1,8 +0,0 @@
|
|||||||
import { ResultsState } from './resultsSlice';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Results slice persist denylist
|
|
||||||
*
|
|
||||||
* Currently denylisting results slice entirely, see `serialize.ts`
|
|
||||||
*/
|
|
||||||
export const resultsPersistDenylist: (keyof ResultsState)[] = [];
|
|
@ -1,125 +0,0 @@
|
|||||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import {
|
|
||||||
receivedResultImagesPage,
|
|
||||||
IMAGES_PER_PAGE,
|
|
||||||
} from 'services/thunks/gallery';
|
|
||||||
import {
|
|
||||||
imageDeleted,
|
|
||||||
imageMetadataReceived,
|
|
||||||
imageUrlsReceived,
|
|
||||||
} from 'services/thunks/image';
|
|
||||||
import { ImageDTO } from 'services/api';
|
|
||||||
import { dateComparator } from 'common/util/dateComparator';
|
|
||||||
|
|
||||||
export type ResultsImageDTO = Omit<ImageDTO, 'image_type'> & {
|
|
||||||
image_type: 'results';
|
|
||||||
};
|
|
||||||
|
|
||||||
export const resultsAdapter = createEntityAdapter<ResultsImageDTO>({
|
|
||||||
selectId: (image) => image.image_name,
|
|
||||||
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
|
|
||||||
});
|
|
||||||
|
|
||||||
type AdditionalResultsState = {
|
|
||||||
page: number;
|
|
||||||
pages: number;
|
|
||||||
isLoading: boolean;
|
|
||||||
nextPage: number;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const initialResultsState =
|
|
||||||
resultsAdapter.getInitialState<AdditionalResultsState>({
|
|
||||||
page: 0,
|
|
||||||
pages: 0,
|
|
||||||
isLoading: false,
|
|
||||||
nextPage: 0,
|
|
||||||
});
|
|
||||||
|
|
||||||
export type ResultsState = typeof initialResultsState;
|
|
||||||
|
|
||||||
const resultsSlice = createSlice({
|
|
||||||
name: 'results',
|
|
||||||
initialState: initialResultsState,
|
|
||||||
reducers: {
|
|
||||||
resultAdded: resultsAdapter.upsertOne,
|
|
||||||
},
|
|
||||||
extraReducers: (builder) => {
|
|
||||||
/**
|
|
||||||
* Received Result Images Page - PENDING
|
|
||||||
*/
|
|
||||||
builder.addCase(receivedResultImagesPage.pending, (state) => {
|
|
||||||
state.isLoading = true;
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Received Result Images Page - FULFILLED
|
|
||||||
*/
|
|
||||||
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
|
||||||
const { page, pages } = action.payload;
|
|
||||||
|
|
||||||
// We know these will all be of the results type, but it's not represented in the API types
|
|
||||||
const items = action.payload.items as ResultsImageDTO[];
|
|
||||||
|
|
||||||
resultsAdapter.setMany(state, items);
|
|
||||||
|
|
||||||
state.page = page;
|
|
||||||
state.pages = pages;
|
|
||||||
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
|
|
||||||
state.isLoading = false;
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Image Metadata Received - FULFILLED
|
|
||||||
*/
|
|
||||||
builder.addCase(imageMetadataReceived.fulfilled, (state, action) => {
|
|
||||||
const { image_type } = action.payload;
|
|
||||||
|
|
||||||
if (image_type === 'results') {
|
|
||||||
resultsAdapter.upsertOne(state, action.payload as ResultsImageDTO);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Image URLs Received - FULFILLED
|
|
||||||
*/
|
|
||||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
|
||||||
const { image_name, image_type, image_url, thumbnail_url } =
|
|
||||||
action.payload;
|
|
||||||
|
|
||||||
if (image_type === 'results') {
|
|
||||||
resultsAdapter.updateOne(state, {
|
|
||||||
id: image_name,
|
|
||||||
changes: {
|
|
||||||
image_url: image_url,
|
|
||||||
thumbnail_url: thumbnail_url,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Delete Image - PENDING
|
|
||||||
* Pre-emptively remove the image from the gallery
|
|
||||||
*/
|
|
||||||
builder.addCase(imageDeleted.pending, (state, action) => {
|
|
||||||
const { imageType, imageName } = action.meta.arg;
|
|
||||||
|
|
||||||
if (imageType === 'results') {
|
|
||||||
resultsAdapter.removeOne(state, imageName);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
export const {
|
|
||||||
selectAll: selectResultsAll,
|
|
||||||
selectById: selectResultsById,
|
|
||||||
selectEntities: selectResultsEntities,
|
|
||||||
selectIds: selectResultsIds,
|
|
||||||
selectTotal: selectResultsTotal,
|
|
||||||
} = resultsAdapter.getSelectors<RootState>((state) => state.results);
|
|
||||||
|
|
||||||
export const { resultAdded } = resultsSlice.actions;
|
|
||||||
|
|
||||||
export default resultsSlice.reducer;
|
|
@ -1,8 +0,0 @@
|
|||||||
import { UploadsState } from './uploadsSlice';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Uploads slice persist denylist
|
|
||||||
*
|
|
||||||
* Currently denylisting uploads slice entirely, see `serialize.ts`
|
|
||||||
*/
|
|
||||||
export const uploadsPersistDenylist: (keyof UploadsState)[] = [];
|
|
@ -1,111 +0,0 @@
|
|||||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
|
||||||
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import {
|
|
||||||
receivedUploadImagesPage,
|
|
||||||
IMAGES_PER_PAGE,
|
|
||||||
} from 'services/thunks/gallery';
|
|
||||||
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
|
|
||||||
import { ImageDTO } from 'services/api';
|
|
||||||
import { dateComparator } from 'common/util/dateComparator';
|
|
||||||
|
|
||||||
export type UploadsImageDTO = Omit<ImageDTO, 'image_type'> & {
|
|
||||||
image_type: 'uploads';
|
|
||||||
};
|
|
||||||
|
|
||||||
export const uploadsAdapter = createEntityAdapter<UploadsImageDTO>({
|
|
||||||
selectId: (image) => image.image_name,
|
|
||||||
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
|
|
||||||
});
|
|
||||||
|
|
||||||
type AdditionalUploadsState = {
|
|
||||||
page: number;
|
|
||||||
pages: number;
|
|
||||||
isLoading: boolean;
|
|
||||||
nextPage: number;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const initialUploadsState =
|
|
||||||
uploadsAdapter.getInitialState<AdditionalUploadsState>({
|
|
||||||
page: 0,
|
|
||||||
pages: 0,
|
|
||||||
nextPage: 0,
|
|
||||||
isLoading: false,
|
|
||||||
});
|
|
||||||
|
|
||||||
export type UploadsState = typeof initialUploadsState;
|
|
||||||
|
|
||||||
const uploadsSlice = createSlice({
|
|
||||||
name: 'uploads',
|
|
||||||
initialState: initialUploadsState,
|
|
||||||
reducers: {
|
|
||||||
uploadAdded: uploadsAdapter.upsertOne,
|
|
||||||
},
|
|
||||||
extraReducers: (builder) => {
|
|
||||||
/**
|
|
||||||
* Received Upload Images Page - PENDING
|
|
||||||
*/
|
|
||||||
builder.addCase(receivedUploadImagesPage.pending, (state) => {
|
|
||||||
state.isLoading = true;
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Received Upload Images Page - FULFILLED
|
|
||||||
*/
|
|
||||||
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
|
|
||||||
const { page, pages } = action.payload;
|
|
||||||
|
|
||||||
// We know these will all be of the uploads type, but it's not represented in the API types
|
|
||||||
const items = action.payload.items as UploadsImageDTO[];
|
|
||||||
|
|
||||||
uploadsAdapter.setMany(state, items);
|
|
||||||
|
|
||||||
state.page = page;
|
|
||||||
state.pages = pages;
|
|
||||||
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
|
|
||||||
state.isLoading = false;
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Image URLs Received - FULFILLED
|
|
||||||
*/
|
|
||||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
|
||||||
const { image_name, image_type, image_url, thumbnail_url } =
|
|
||||||
action.payload;
|
|
||||||
|
|
||||||
if (image_type === 'uploads') {
|
|
||||||
uploadsAdapter.updateOne(state, {
|
|
||||||
id: image_name,
|
|
||||||
changes: {
|
|
||||||
image_url: image_url,
|
|
||||||
thumbnail_url: thumbnail_url,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Delete Image - pending
|
|
||||||
* Pre-emptively remove the image from the gallery
|
|
||||||
*/
|
|
||||||
builder.addCase(imageDeleted.pending, (state, action) => {
|
|
||||||
const { imageType, imageName } = action.meta.arg;
|
|
||||||
|
|
||||||
if (imageType === 'uploads') {
|
|
||||||
uploadsAdapter.removeOne(state, imageName);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
export const {
|
|
||||||
selectAll: selectUploadsAll,
|
|
||||||
selectById: selectUploadsById,
|
|
||||||
selectEntities: selectUploadsEntities,
|
|
||||||
selectIds: selectUploadsIds,
|
|
||||||
selectTotal: selectUploadsTotal,
|
|
||||||
} = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
|
|
||||||
|
|
||||||
export const { uploadAdded } = uploadsSlice.actions;
|
|
||||||
|
|
||||||
export default uploadsSlice.reducer;
|
|
@ -10,6 +10,7 @@ import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComp
|
|||||||
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
|
import UNetInputFieldComponent from './fields/UNetInputFieldComponent';
|
||||||
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
||||||
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
||||||
|
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
||||||
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
|
||||||
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
|
||||||
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
||||||
@ -130,6 +131,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (type === 'control' && template.type === 'control') {
|
||||||
|
return (
|
||||||
|
<ControlInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (type === 'model' && template.type === 'model') {
|
if (type === 'model' && template.type === 'model') {
|
||||||
return (
|
return (
|
||||||
<ModelInputFieldComponent
|
<ModelInputFieldComponent
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user