Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 00cb8a0c64: Merge branch 'main' into doc_updates_23
.github/CODEOWNERS (vendored): 14 changes

@@ -2,7 +2,7 @@
 /.github/workflows/ @lstein @blessedcoolant
 
 # documentation
-/docs/ @lstein @tildebyte @blessedcoolant
+/docs/ @lstein @blessedcoolant @hipsterusername
 /mkdocs.yml @lstein @blessedcoolant
 
 # nodes
@@ -18,17 +18,17 @@
 /invokeai/version @lstein @blessedcoolant
 
 # web ui
-/invokeai/frontend @blessedcoolant @psychedelicious @lstein
+/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp
-/invokeai/backend @blessedcoolant @psychedelicious @lstein
+/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
 
 # generation, model management, postprocessing
-/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
+/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2 @StAlKeR7779
 
 # front ends
 /invokeai/frontend/CLI @lstein
 /invokeai/frontend/install @lstein @ebr
-/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
+/invokeai/frontend/merge @lstein @blessedcoolant
-/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
+/invokeai/frontend/training @lstein @blessedcoolant
-/invokeai/frontend/web @psychedelicious @blessedcoolant
+/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp
.github/workflows/test-invoke-pip.yml (vendored): 1 change

@@ -125,6 +125,7 @@ jobs:
             --no-nsfw_checker
             --precision=float32
             --always_use_cpu
+            --use_memory_db
             --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
             --from_file ${{ env.TEST_PROMPTS }}
invokeai/app/api/routers/images.py

@@ -1,5 +1,6 @@
 import io
-from fastapi import HTTPException, Path, Query, Request, Response, UploadFile
+from typing import Optional
+from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
 from fastapi.routing import APIRouter
 from fastapi.responses import FileResponse
 from PIL import Image
@@ -7,7 +8,11 @@ from invokeai.app.models.image import (
     ImageCategory,
     ImageType,
 )
-from invokeai.app.services.models.image_record import ImageDTO, ImageUrlsDTO
+from invokeai.app.services.models.image_record import (
+    ImageDTO,
+    ImageRecordChanges,
+    ImageUrlsDTO,
+)
 from invokeai.app.services.item_storage import PaginatedResults
 
 from ..dependencies import ApiDependencies
@@ -27,10 +32,17 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
 )
 async def upload_image(
     file: UploadFile,
-    image_type: ImageType,
     request: Request,
     response: Response,
-    image_category: ImageCategory = ImageCategory.GENERAL,
+    image_category: ImageCategory = Query(
+        default=ImageCategory.GENERAL, description="The category of the image"
+    ),
+    is_intermediate: bool = Query(
+        default=False, description="Whether this is an intermediate image"
+    ),
+    session_id: Optional[str] = Query(
+        default=None, description="The session ID associated with this upload, if any"
+    ),
 ) -> ImageDTO:
     """Uploads an image"""
     if not file.content_type.startswith("image"):
@@ -46,9 +58,11 @@ async def upload_image(
 
     try:
         image_dto = ApiDependencies.invoker.services.images.create(
-            pil_image,
-            image_type,
-            image_category,
+            image=pil_image,
+            image_type=ImageType.UPLOAD,
+            image_category=image_category,
+            session_id=session_id,
+            is_intermediate=is_intermediate,
         )
 
         response.status_code = 201
@@ -61,7 +75,7 @@ async def upload_image(
 
 @images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
 async def delete_image(
-    image_type: ImageType = Query(description="The type of image to delete"),
+    image_type: ImageType = Path(description="The type of image to delete"),
     image_name: str = Path(description="The name of the image to delete"),
 ) -> None:
     """Deletes an image"""
@@ -73,6 +87,28 @@ async def delete_image(
         pass
 
 
+@images_router.patch(
+    "/{image_type}/{image_name}",
+    operation_id="update_image",
+    response_model=ImageDTO,
+)
+async def update_image(
+    image_type: ImageType = Path(description="The type of image to update"),
+    image_name: str = Path(description="The name of the image to update"),
+    image_changes: ImageRecordChanges = Body(
+        description="The changes to apply to the image"
+    ),
+) -> ImageDTO:
+    """Updates an image"""
+
+    try:
+        return ApiDependencies.invoker.services.images.update(
+            image_type, image_name, image_changes
+        )
+    except Exception as e:
+        raise HTTPException(status_code=400, detail="Failed to update image")
+
+
 @images_router.get(
     "/{image_type}/{image_name}/metadata",
     operation_id="get_image_metadata",
@@ -85,9 +121,7 @@ async def get_image_metadata(
     """Gets an image's metadata"""
 
     try:
-        return ApiDependencies.invoker.services.images.get_dto(
-            image_type, image_name
-        )
+        return ApiDependencies.invoker.services.images.get_dto(image_type, image_name)
     except Exception as e:
         raise HTTPException(status_code=404)
 
@@ -113,9 +147,7 @@ async def get_image_full(
     """Gets a full-resolution image file"""
 
    try:
-        path = ApiDependencies.invoker.services.images.get_path(
-            image_type, image_name
-        )
+        path = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
 
         if not ApiDependencies.invoker.services.images.validate_path(path):
             raise HTTPException(status_code=404)
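For orientation, a minimal client sketch (not part of the commit) against the new update_image route. The base URL, the image-type path value, and the file name are assumptions for illustration; the body is an ImageRecordChanges payload, whose "image_category" field name is taken from this diff.

import requests

BASE = "http://127.0.0.1:9090/api/v1/images"  # assumed local InvokeAI server

resp = requests.patch(
    f"{BASE}/results/some_image.png",    # hypothetical image_type/image_name
    json={"image_category": "general"},  # partial update via ImageRecordChanges
)
resp.raise_for_status()
print(resp.json())  # the updated ImageDTO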
invokeai/app/cli_app.py

@@ -13,10 +13,13 @@ from typing import (
 
 from pydantic import BaseModel, ValidationError
 from pydantic.fields import Field
+from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
+from invokeai.app.services.images import ImageService
+from invokeai.app.services.metadata import CoreMetadataService
+from invokeai.app.services.urls import LocalUrlService
 
 
 import invokeai.backend.util.logging as logger
-from invokeai.app.services.metadata import PngMetadataService
 from .services.default_graphs import create_system_graphs
 from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
 
@@ -188,6 +191,9 @@ def invoke_all(context: CliContext):
         raise SessionError()
 
 
+logger = logger.InvokeAILogger.getLogger()
+
+
 def invoke_cli():
     # this gets the basic configuration
     config = get_invokeai_config()
@@ -206,24 +212,43 @@ def invoke_cli():
 
     events = EventServiceBase()
     output_folder = config.output_path
-    metadata = PngMetadataService()
 
     # TODO: build a file/path manager?
-    db_location = os.path.join(output_folder, "invokeai.db")
+    if config.use_memory_db:
+        db_location = ":memory:"
+    else:
+        db_location = os.path.join(output_folder, "invokeai.db")
+
+    logger.info(f'InvokeAI database location is "{db_location}"')
+
+    graph_execution_manager = SqliteItemStorage[GraphExecutionState](
+        filename=db_location, table_name="graph_executions"
+    )
+
+    urls = LocalUrlService()
+    metadata = CoreMetadataService()
+    image_record_storage = SqliteImageRecordStorage(db_location)
+    image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
+
+    images = ImageService(
+        image_record_storage=image_record_storage,
+        image_file_storage=image_file_storage,
+        metadata=metadata,
+        url=urls,
+        logger=logger,
+        graph_execution_manager=graph_execution_manager,
+    )
 
     services = InvocationServices(
         model_manager=model_manager,
         events=events,
         latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
-        images=DiskImageFileStorage(f'{output_folder}/images', metadata_service=metadata),
-        metadata=metadata,
+        images=images,
         queue=MemoryInvocationQueue(),
         graph_library=SqliteItemStorage[LibraryGraph](
             filename=db_location, table_name="graphs"
         ),
-        graph_execution_manager=SqliteItemStorage[GraphExecutionState](
-            filename=db_location, table_name="graph_executions"
-        ),
+        graph_execution_manager=graph_execution_manager,
        processor=DefaultInvocationProcessor(),
        restoration=RestorationServices(config,logger=logger),
        logger=logger,
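The use_memory_db branch above is what the --use_memory_db flag added to the CI workflow earlier in this diff switches on. A standalone sketch of the same idea (plain sqlite3 here rather than InvokeAI's SqliteItemStorage; function name is mine):

import os
import sqlite3

def open_db(output_folder: str, use_memory_db: bool) -> sqlite3.Connection:
    # ":memory:" tells SQLite to keep the whole database in RAM; nothing is
    # written to disk, so each test run starts from a clean state.
    db_location = ":memory:" if use_memory_db else os.path.join(output_folder, "invokeai.db")
    return sqlite3.connect(db_location)

conn = open_db("outputs", use_memory_db=True)
conn.execute("CREATE TABLE graph_executions (id TEXT PRIMARY KEY, item TEXT)")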
invokeai/app/invocations/baseinvocation.py

@@ -78,6 +78,7 @@ class BaseInvocation(ABC, BaseModel):
 
     #fmt: off
     id: str = Field(description="The id of this node. Must be unique among all nodes.")
+    is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
     #fmt: on
 
 
@@ -95,6 +96,7 @@ class UIConfig(TypedDict, total=False):
             "image",
             "latents",
             "model",
+            "control",
         ],
     ]
     tags: List[str]
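A minimal illustration (not from the commit) of how the new field behaves: because is_intermediate defaults to False, previously serialized nodes deserialize unchanged, and nodes opt in explicitly.

from pydantic import BaseModel, Field

class Node(BaseModel):
    # mirrors the new BaseInvocation fields
    id: str = Field(description="The id of this node. Must be unique among all nodes.")
    is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")

print(Node(id="n1").is_intermediate)                        # False
print(Node(id="n2", is_intermediate=True).is_intermediate)  # True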
invokeai/app/invocations/collections.py

@@ -22,6 +22,14 @@ class IntCollectionOutput(BaseInvocationOutput):
     # Outputs
     collection: list[int] = Field(default=[], description="The int collection")
 
+
+class FloatCollectionOutput(BaseInvocationOutput):
+    """A collection of floats"""
+
+    type: Literal["float_collection"] = "float_collection"
+
+    # Outputs
+    collection: list[float] = Field(default=[], description="The float collection")
 
 
 class RangeInvocation(BaseInvocation):
     """Creates a range of numbers from start to stop with step"""
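A self-contained imitation (not from the commit) of the pattern the new output class follows: a Literal "type" tag identifying the output kind plus the payload field. Plain BaseModel stands in for BaseInvocationOutput so the snippet runs on its own.

from typing import Literal
from pydantic import BaseModel, Field

class FloatCollectionOutputSketch(BaseModel):
    """A collection of floats"""
    # the Literal tag is what lets the node graph dispatch on output type
    type: Literal["float_collection"] = "float_collection"
    collection: list[float] = Field(default=[], description="The float collection")

out = FloatCollectionOutputSketch(collection=[0.25, 0.5, 0.75])
print(out.type, out.collection)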
invokeai/app/invocations/controlnet_image_processors.py (new file): 428 additions

@@ -0,0 +1,428 @@
+# InvokeAI nodes for ControlNet image preprocessors
+# initial implementation by Gregg Helt, 2023
+# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
+
+import numpy as np
+from typing import Literal, Optional, Union, List
+from PIL import Image, ImageFilter, ImageOps
+from pydantic import BaseModel, Field
+
+from ..models.image import ImageField, ImageType, ImageCategory
+from .baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    InvocationContext,
+    InvocationConfig,
+)
+from controlnet_aux import (
+    CannyDetector,
+    HEDdetector,
+    LineartDetector,
+    LineartAnimeDetector,
+    MidasDetector,
+    MLSDdetector,
+    NormalBaeDetector,
+    OpenposeDetector,
+    PidiNetDetector,
+    ContentShuffleDetector,
+    ZoeDetector,
+    MediapipeFaceDetector,
+)
+
+from .image import ImageOutput, PILInvocationConfig
+
+CONTROLNET_DEFAULT_MODELS = [
+    ###########################################
+    # lllyasviel sd v1.5, ControlNet v1.0 models
+    ##############################################
+    "lllyasviel/sd-controlnet-canny",
+    "lllyasviel/sd-controlnet-depth",
+    "lllyasviel/sd-controlnet-hed",
+    "lllyasviel/sd-controlnet-seg",
+    "lllyasviel/sd-controlnet-openpose",
+    "lllyasviel/sd-controlnet-scribble",
+    "lllyasviel/sd-controlnet-normal",
+    "lllyasviel/sd-controlnet-mlsd",
+
+    #############################################
+    # lllyasviel sd v1.5, ControlNet v1.1 models
+    #############################################
+    "lllyasviel/control_v11p_sd15_canny",
+    "lllyasviel/control_v11p_sd15_openpose",
+    "lllyasviel/control_v11p_sd15_seg",
+    # "lllyasviel/control_v11p_sd15_depth",  # broken
+    "lllyasviel/control_v11f1p_sd15_depth",
+    "lllyasviel/control_v11p_sd15_normalbae",
+    "lllyasviel/control_v11p_sd15_scribble",
+    "lllyasviel/control_v11p_sd15_mlsd",
+    "lllyasviel/control_v11p_sd15_softedge",
+    "lllyasviel/control_v11p_sd15s2_lineart_anime",
+    "lllyasviel/control_v11p_sd15_lineart",
+    "lllyasviel/control_v11p_sd15_inpaint",
+    # "lllyasviel/control_v11u_sd15_tile",
+    #   problem (temporary?) with huffingface "lllyasviel/control_v11u_sd15_tile",
+    #   so for now replace "lllyasviel/control_v11f1e_sd15_tile",
+    "lllyasviel/control_v11e_sd15_shuffle",
+    "lllyasviel/control_v11e_sd15_ip2p",
+    "lllyasviel/control_v11f1e_sd15_tile",
+
+    #################################################
+    # thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
+    ##################################################
+    "thibaud/controlnet-sd21-openpose-diffusers",
+    "thibaud/controlnet-sd21-canny-diffusers",
+    "thibaud/controlnet-sd21-depth-diffusers",
+    "thibaud/controlnet-sd21-scribble-diffusers",
+    "thibaud/controlnet-sd21-hed-diffusers",
+    "thibaud/controlnet-sd21-zoedepth-diffusers",
+    "thibaud/controlnet-sd21-color-diffusers",
+    "thibaud/controlnet-sd21-openposev2-diffusers",
+    "thibaud/controlnet-sd21-lineart-diffusers",
+    "thibaud/controlnet-sd21-normalbae-diffusers",
+    "thibaud/controlnet-sd21-ade20k-diffusers",
+
+    ##############################################
+    # ControlNetMediaPipeface, ControlNet v1.1
+    ##############################################
+    # ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"],  # SD 1.5
+    #   diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg
+    #   hacked t2l to split to model & subfolder if format is "model,subfolder"
+    "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15",  # SD 1.5
+    "CrucibleAI/ControlNetMediaPipeFace",  # SD 2.1?
+]
+
+CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
+
+class ControlField(BaseModel):
+    image: ImageField = Field(default=None, description="processed image")
+    control_model: Optional[str] = Field(default=None, description="control model used")
+    control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
+    begin_step_percent: float = Field(default=0, ge=0, le=1,
+                                      description="% of total steps at which controlnet is first applied")
+    end_step_percent: float = Field(default=1, ge=0, le=1,
+                                    description="% of total steps at which controlnet is last applied")
+
+    class Config:
+        schema_extra = {
+            "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"]
+        }
+
+
+class ControlOutput(BaseInvocationOutput):
+    """node output for ControlNet info"""
+    # fmt: off
+    type: Literal["control_output"] = "control_output"
+    control: ControlField = Field(default=None, description="The control info dict")
+    # fmt: on
+
+
+class ControlNetInvocation(BaseInvocation):
+    """Collects ControlNet info to pass to other nodes"""
+    # fmt: off
+    type: Literal["controlnet"] = "controlnet"
+    # Inputs
+    image: ImageField = Field(default=None, description="image to process")
+    control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
+                                                  description="control model used")
+    control_weight: float = Field(default=1.0, ge=0, le=1, description="weight given to controlnet")
+    # TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
+    begin_step_percent: float = Field(default=0, ge=0, le=1,
+                                      description="% of total steps at which controlnet is first applied")
+    end_step_percent: float = Field(default=1, ge=0, le=1,
+                                    description="% of total steps at which controlnet is last applied")
+    # fmt: on
+
+
+    def invoke(self, context: InvocationContext) -> ControlOutput:
+
+        return ControlOutput(
+            control=ControlField(
+                image=self.image,
+                control_model=self.control_model,
+                control_weight=self.control_weight,
+                begin_step_percent=self.begin_step_percent,
+                end_step_percent=self.end_step_percent,
+            ),
+        )
+
+# TODO: move image processors to separate file (image_analysis.py
+class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
+    """Base class for invocations that preprocess images for ControlNet"""
+
+    # fmt: off
+    type: Literal["image_processor"] = "image_processor"
+    # Inputs
+    image: ImageField = Field(default=None, description="image to process")
+    # fmt: on
+
+
+    def run_processor(self, image):
+        # superclass just passes through image without processing
+        return image
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+
+        raw_image = context.services.images.get_pil_image(
+            self.image.image_type, self.image.image_name
+        )
+        # image type should be PIL.PngImagePlugin.PngImageFile ?
+        processed_image = self.run_processor(raw_image)
+
+        # FIXME: what happened to image metadata?
+        # metadata = context.services.metadata.build_metadata(
+        #     session_id=context.graph_execution_state_id, node=self
+        # )
+
+        # currently can't see processed image in node UI without a showImage node,
+        #    so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
+        image_dto = context.services.images.create(
+            image=processed_image,
+            image_type=ImageType.RESULT,
+            image_category=ImageCategory.GENERAL,
+            session_id=context.graph_execution_state_id,
+            node_id=self.id,
+            is_intermediate=self.is_intermediate
+        )
+
+        """Builds an ImageOutput and its ImageField"""
+        processed_image_field = ImageField(
+            image_name=image_dto.image_name,
+            image_type=image_dto.image_type,
+        )
+        return ImageOutput(
+            image=processed_image_field,
+            # width=processed_image.width,
+            width = image_dto.width,
+            # height=processed_image.height,
+            height = image_dto.height,
+            # mode=processed_image.mode,
+        )
+
+
+class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Canny edge detection for ControlNet"""
+    # fmt: off
+    type: Literal["canny_image_processor"] = "canny_image_processor"
+    # Input
+    low_threshold: float = Field(default=100, ge=0, description="low threshold of Canny pixel gradient")
+    high_threshold: float = Field(default=200, ge=0, description="high threshold of Canny pixel gradient")
+    # fmt: on
+
+    def run_processor(self, image):
+        canny_processor = CannyDetector()
+        processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
+        return processed_image
+
+
+class HedImageprocessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies HED edge detection to image"""
+    # fmt: off
+    type: Literal["hed_image_processor"] = "hed_image_processor"
+    # Inputs
+    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+    # safe not supported in controlnet_aux v0.0.3
+    # safe: bool = Field(default=False, description="whether to use safe mode")
+    scribble: bool = Field(default=False, description="whether to use scribble mode")
+    # fmt: on
+
+    def run_processor(self, image):
+        hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
+        processed_image = hed_processor(image,
+                                        detect_resolution=self.detect_resolution,
+                                        image_resolution=self.image_resolution,
+                                        # safe not supported in controlnet_aux v0.0.3
+                                        # safe=self.safe,
+                                        scribble=self.scribble,
+                                        )
+        return processed_image
+
+
+class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies line art processing to image"""
+    # fmt: off
+    type: Literal["lineart_image_processor"] = "lineart_image_processor"
+    # Inputs
+    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+    coarse: bool = Field(default=False, description="whether to use coarse mode")
+    # fmt: on
+
+    def run_processor(self, image):
+        lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
+        processed_image = lineart_processor(image,
+                                            detect_resolution=self.detect_resolution,
+                                            image_resolution=self.image_resolution,
+                                            coarse=self.coarse)
+        return processed_image
+
+
+class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies line art anime processing to image"""
+    # fmt: off
+    type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
+    # Inputs
+    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+    # fmt: on
+
+    def run_processor(self, image):
+        processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
+        processed_image = processor(image,
+                                    detect_resolution=self.detect_resolution,
+                                    image_resolution=self.image_resolution,
+                                    )
+        return processed_image
+
+
+class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies Openpose processing to image"""
+    # fmt: off
+    type: Literal["openpose_image_processor"] = "openpose_image_processor"
+    # Inputs
+    hand_and_face: bool = Field(default=False, description="whether to use hands and face mode")
+    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+    # fmt: on
+
+    def run_processor(self, image):
+        openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
+        processed_image = openpose_processor(image,
+                                             detect_resolution=self.detect_resolution,
+                                             image_resolution=self.image_resolution,
+                                             hand_and_face=self.hand_and_face,
+                                             )
+        return processed_image
+
+
+class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies Midas depth processing to image"""
+    # fmt: off
+    type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
+    # Inputs
+    a_mult: float = Field(default=2.0, ge=0, description="Midas parameter a = amult * PI")
+    bg_th: float = Field(default=0.1, ge=0, description="Midas parameter bg_th")
+    # depth_and_normal not supported in controlnet_aux v0.0.3
+    # depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode")
+    # fmt: on
+
+    def run_processor(self, image):
+        midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
+        processed_image = midas_processor(image,
+                                          a=np.pi * self.a_mult,
+                                          bg_th=self.bg_th,
+                                          # dept_and_normal not supported in controlnet_aux v0.0.3
+                                          # depth_and_normal=self.depth_and_normal,
+                                          )
+        return processed_image
+
+
+class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies NormalBae processing to image"""
+    # fmt: off
+    type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
+    # Inputs
+    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+    # fmt: on
+
+    def run_processor(self, image):
+        normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
+        processed_image = normalbae_processor(image,
+                                              detect_resolution=self.detect_resolution,
+                                              image_resolution=self.image_resolution)
+        return processed_image
+
+
+class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies MLSD processing to image"""
+    # fmt: off
+    type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
+    # Inputs
+    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+    thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter thr_v")
+    thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter thr_d")
+    # fmt: on
+
+    def run_processor(self, image):
+        mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
+        processed_image = mlsd_processor(image,
+                                         detect_resolution=self.detect_resolution,
+                                         image_resolution=self.image_resolution,
+                                         thr_v=self.thr_v,
+                                         thr_d=self.thr_d)
+        return processed_image
+
+
+class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies PIDI processing to image"""
+    # fmt: off
+    type: Literal["pidi_image_processor"] = "pidi_image_processor"
+    # Inputs
+    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+    safe: bool = Field(default=False, description="whether to use safe mode")
+    scribble: bool = Field(default=False, description="whether to use scribble mode")
+    # fmt: on
+
+    def run_processor(self, image):
+        pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
+        processed_image = pidi_processor(image,
+                                         detect_resolution=self.detect_resolution,
+                                         image_resolution=self.image_resolution,
+                                         safe=self.safe,
+                                         scribble=self.scribble)
+        return processed_image
+
+
+class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies content shuffle processing to image"""
+    # fmt: off
+    type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
+    # Inputs
+    detect_resolution: int = Field(default=512, ge=0, description="pixel resolution for edge detection")
+    image_resolution: int = Field(default=512, ge=0, description="pixel resolution for output image")
+    h: Union[int | None] = Field(default=512, ge=0, description="content shuffle h parameter")
+    w: Union[int | None] = Field(default=512, ge=0, description="content shuffle w parameter")
+    f: Union[int | None] = Field(default=256, ge=0, description="cont")
+    # fmt: on
+
+    def run_processor(self, image):
+        content_shuffle_processor = ContentShuffleDetector()
+        processed_image = content_shuffle_processor(image,
+                                                    detect_resolution=self.detect_resolution,
+                                                    image_resolution=self.image_resolution,
+                                                    h=self.h,
+                                                    w=self.w,
+                                                    f=self.f
+                                                    )
+        return processed_image
+
+
+# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
+class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies Zoe depth processing to image"""
+    # fmt: off
+    type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
+    # fmt: on
+
+    def run_processor(self, image):
+        zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
+        processed_image = zoe_depth_processor(image)
+        return processed_image
+
+
+class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
+    """Applies mediapipe face processing to image"""
+    # fmt: off
+    type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
+    # Inputs
+    max_faces: int = Field(default=1, ge=1, description="maximum number of faces to detect")
+    min_confidence: float = Field(default=0.5, ge=0, le=1, description="minimum confidence for face detection")
+    # fmt: on
+
+    def run_processor(self, image):
+        mediapipe_face_processor = MediapipeFaceDetector()
+        processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
+        return processed_image
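A standalone sketch (not part of the commit) of what CannyImageProcessorInvocation.run_processor boils down to when controlnet_aux is driven directly; the input and output file names are assumptions, and the package must be installed.

from PIL import Image
from controlnet_aux import CannyDetector

source = Image.open("portrait.png")  # hypothetical input image
canny = CannyDetector()
control = canny(source, 100, 200)    # low/high thresholds match the node defaults
control.save("portrait_canny.png")   # ready to feed a ControlNet as a control image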
invokeai/app/invocations/cv.py

@@ -57,10 +57,11 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
 
         image_dto = context.services.images.create(
             image=image_inpainted,
-            image_type=ImageType.INTERMEDIATE,
+            image_type=ImageType.RESULT,
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )
 
         return ImageOutput(
invokeai/app/invocations/generate.py

@@ -4,7 +4,9 @@ from functools import partial
 from typing import Literal, Optional, Union, get_args
 
 import numpy as np
+from diffusers import ControlNetModel
 from torch import Tensor
+import torch
 
 from pydantic import BaseModel, Field
 
@@ -56,8 +58,11 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
     width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
     height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
     cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
-    scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
+    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
     model: str = Field(default="", description="The model to use (currently ignored)")
+    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
+    control_model: Optional[str] = Field(default=None, description="The control model to use")
+    control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
     # fmt: on
 
     # TODO: pass this an emitter method or something? or a session for dispatching?
@@ -78,17 +83,35 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
         # Handle invalid model parameter
         model = choose_model(context.services.model_manager, self.model)
 
+        # loading controlnet image (currently requires pre-processed image)
+        control_image = (
+            None if self.control_image is None
+            else context.services.images.get(
+                self.control_image.image_type, self.control_image.image_name
+            )
+        )
+        # loading controlnet model
+        if (self.control_model is None or self.control_model==''):
+            control_model = None
+        else:
+            # FIXME: change this to dropdown menu?
+            # FIXME: generalize so don't have to hardcode torch_dtype and device
+            control_model = ControlNetModel.from_pretrained(self.control_model,
+                                                            torch_dtype=torch.float16).to("cuda")
+
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
             context.graph_execution_state_id
         )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
-        outputs = Txt2Img(model).generate(
+        txt2img = Txt2Img(model, control_model=control_model)
+        outputs = txt2img.generate(
             prompt=self.prompt,
             step_callback=partial(self.dispatch_progress, context, source_node_id),
+            control_image=control_image,
             **self.dict(
-                exclude={"prompt"}
+                exclude={"prompt", "control_image" }
             ),  # Shorthand for passing all of the parameters above manually
         )
         # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
@@ -101,6 +124,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
             image_category=ImageCategory.GENERAL,
             session_id=context.graph_execution_state_id,
             node_id=self.id,
+            is_intermediate=self.is_intermediate,
         )
 
         return ImageOutput(
@@ -181,6 +205,7 @@ class ImageToImageInvocation(TextToImageInvocation):
             image_category=ImageCategory.GENERAL,
             session_id=context.graph_execution_state_id,
             node_id=self.id,
+            is_intermediate=self.is_intermediate,
         )
 
         return ImageOutput(
@@ -296,6 +321,7 @@ class InpaintInvocation(ImageToImageInvocation):
             image_category=ImageCategory.GENERAL,
             session_id=context.graph_execution_state_id,
             node_id=self.id,
+            is_intermediate=self.is_intermediate,
         )
 
         return ImageOutput(
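The FIXME in the hunk above notes that torch_dtype and device are hardcoded to float16 and "cuda". One portable alternative, sketched here with a helper name of my own rather than anything from the commit:

from typing import Optional

import torch
from diffusers import ControlNetModel

def load_control_model(name: Optional[str]) -> Optional[ControlNetModel]:
    # treat None and "" the same way the invocation does
    if not name:
        return None
    # pick dtype/device from what is actually available instead of hardcoding
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    return ControlNetModel.from_pretrained(name, torch_dtype=dtype).to(device)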
invokeai/app/invocations/image.py

@@ -143,6 +143,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )
 
         return ImageOutput(
@@ -204,6 +205,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )
 
         return ImageOutput(
@@ -242,6 +244,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.MASK,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )
 
         return MaskOutput(
@@ -280,6 +283,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )
 
         return ImageOutput(
@@ -318,6 +322,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )
 
         return ImageOutput(
@@ -356,6 +361,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )
 
         return ImageOutput(
@@ -397,6 +403,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )
 
         return ImageOutput(
@@ -437,6 +444,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )
 
         return ImageOutput(
@@ -482,6 +490,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )
 
         return ImageOutput(
invokeai/app/invocations/infill.py

@@ -149,6 +149,7 @@ class InfillColorInvocation(BaseInvocation):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )
 
         return ImageOutput(
@@ -193,6 +194,7 @@ class InfillTileInvocation(BaseInvocation):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )
 
         return ImageOutput(
@@ -230,6 +232,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
         )
 
         return ImageOutput(
|
@ -1,8 +1,11 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from typing import Literal, Optional, Union
|
|
||||||
import einops
|
import einops
|
||||||
|
from typing import Literal, Optional, Union, List
|
||||||
|
|
||||||
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, Field, validator
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -11,14 +14,18 @@ from invokeai.app.models.image import ImageCategory
|
|||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
|
from .controlnet_image_processors import ControlField
|
||||||
|
|
||||||
from ...backend.model_management.model_manager import ModelManager
|
from ...backend.model_management.model_manager import ModelManager
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
from ...backend.image_util.seamless import configure_model_padding
|
from ...backend.image_util.seamless import configure_model_padding
|
||||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||||
|
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
|
from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ..services.image_file_storage import ImageType
|
from ..services.image_file_storage import ImageType
|
||||||
@ -28,7 +35,7 @@ from .compel import ConditioningField
|
|||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
import diffusers
|
import diffusers
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline, ControlNetModel
|
||||||
|
|
||||||
|
|
||||||
class LatentsField(BaseModel):
|
class LatentsField(BaseModel):
|
||||||
@ -167,8 +174,9 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
|
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||||
|
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
@ -179,7 +187,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
"ui": {
|
"ui": {
|
||||||
"tags": ["latents", "image"],
|
"tags": ["latents", "image"],
|
||||||
"type_hints": {
|
"type_hints": {
|
||||||
"model": "model"
|
"model": "model",
|
||||||
|
"control": "control",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -238,6 +247,81 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
|
).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
|
||||||
return conditioning_data
|
return conditioning_data
|
||||||
|
|
||||||
|
def prep_control_data(self,
|
||||||
|
context: InvocationContext,
|
||||||
|
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
|
||||||
|
control_input: List[ControlField],
|
||||||
|
latents_shape: List[int],
|
||||||
|
do_classifier_free_guidance: bool = True,
|
||||||
|
) -> List[ControlNetData]:
|
||||||
|
# assuming fixed dimensional scaling of 8:1 for image:latents
|
||||||
|
control_height_resize = latents_shape[2] * 8
|
||||||
|
control_width_resize = latents_shape[3] * 8
|
||||||
|
if control_input is None:
|
||||||
|
# print("control input is None")
|
||||||
|
control_list = None
|
||||||
|
elif isinstance(control_input, list) and len(control_input) == 0:
|
||||||
|
# print("control input is empty list")
|
||||||
|
control_list = None
|
||||||
|
elif isinstance(control_input, ControlField):
|
||||||
|
# print("control input is ControlField")
|
||||||
|
control_list = [control_input]
|
||||||
|
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
|
||||||
|
# print("control input is list[ControlField]")
|
||||||
|
control_list = control_input
|
||||||
|
else:
|
||||||
|
# print("input control is unrecognized:", type(self.control))
|
||||||
|
control_list = None
|
||||||
|
if (control_list is None):
|
||||||
|
control_data = None
|
||||||
|
# from above handling, any control that is not None should now be of type list[ControlField]
|
||||||
|
else:
|
||||||
|
# FIXME: add checks to skip entry if model or image is None
|
||||||
|
# and if weight is None, populate with default 1.0?
|
||||||
|
control_data = []
|
||||||
|
control_models = []
|
||||||
|
for control_info in control_list:
|
||||||
|
# handle control models
|
||||||
|
if ("," in control_info.control_model):
|
||||||
|
control_model_split = control_info.control_model.split(",")
|
||||||
|
control_name = control_model_split[0]
|
||||||
|
control_subfolder = control_model_split[1]
|
||||||
|
print("Using HF model subfolders")
|
||||||
|
print(" control_name: ", control_name)
|
||||||
|
print(" control_subfolder: ", control_subfolder)
|
||||||
|
control_model = ControlNetModel.from_pretrained(control_name,
|
||||||
|
subfolder=control_subfolder,
|
||||||
|
torch_dtype=model.unet.dtype).to(model.device)
|
||||||
|
else:
|
||||||
|
control_model = ControlNetModel.from_pretrained(control_info.control_model,
|
||||||
|
torch_dtype=model.unet.dtype).to(model.device)
|
||||||
|
control_models.append(control_model)
|
||||||
|
control_image_field = control_info.image
|
||||||
|
input_image = context.services.images.get_pil_image(control_image_field.image_type,
|
||||||
|
control_image_field.image_name)
|
||||||
|
# self.image.image_type, self.image.image_name
|
||||||
|
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||||
|
# and add in batch_size, num_images_per_prompt?
|
||||||
|
# and do real check for classifier_free_guidance?
|
||||||
|
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
||||||
|
control_image = model.prepare_control_image(
|
||||||
|
image=input_image,
|
||||||
|
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||||
|
width=control_width_resize,
|
||||||
|
height=control_height_resize,
|
||||||
|
# batch_size=batch_size * num_images_per_prompt,
|
||||||
|
# num_images_per_prompt=num_images_per_prompt,
|
||||||
|
device=control_model.device,
|
||||||
|
dtype=control_model.dtype,
|
||||||
|
)
|
||||||
|
control_item = ControlNetData(model=control_model,
|
||||||
|
image_tensor=control_image,
|
||||||
|
weight=control_info.control_weight,
|
||||||
|
begin_step_percent=control_info.begin_step_percent,
|
||||||
|
end_step_percent=control_info.end_step_percent)
|
||||||
|
control_data.append(control_item)
|
||||||
|
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||||
|
return control_data
|
||||||
|
|
||||||
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        noise = context.services.latents.get(self.noise.latents_name)

@ -252,14 +336,19 @@ class TextToLatentsInvocation(BaseInvocation):
        model = self.get_model(context.services.model_manager)
        conditioning_data = self.get_conditioning_data(context, model)

        # TODO: Verify the noise is the right size
        print("type of control input: ", type(self.control))
        control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
                                              latents_shape=noise.shape,
                                              do_classifier_free_guidance=(self.cfg_scale >= 1.0))

        # TODO: Verify the noise is the right size
        result_latents, result_attention_map_saver = model.latents_from_embeddings(
            latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
            noise=noise,
            num_inference_steps=self.steps,
            conditioning_data=conditioning_data,
            callback=step_callback
            control_data=control_data,  # list[ControlNetData]
            callback=step_callback,
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
@ -285,7 +374,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
            "ui": {
                "tags": ["latents"],
                "type_hints": {
                    "model": "model"
                    "model": "model",
                    "control": "control",
                }
            },
        }

@ -304,6 +394,11 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
        model = self.get_model(context.services.model_manager)
        conditioning_data = self.get_conditioning_data(context, model)

        print("type of control input: ", type(self.control))
        control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
                                              latents_shape=noise.shape,
                                              do_classifier_free_guidance=(self.cfg_scale >= 1.0))

        # TODO: Verify the noise is the right size

        initial_latents = latent if self.strength < 1.0 else torch.zeros_like(

@ -318,6 +413,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
            noise=noise,
            num_inference_steps=self.steps,
            conditioning_data=conditioning_data,
            control_data=control_data,  # list[ControlNetData]
            callback=step_callback
        )
@ -362,14 +458,21 @@ class LatentsToImageInvocation(BaseInvocation):
            np_image = model.decode_latents(latents)
            image = model.numpy_to_pil(np_image)[0]

            # what happened to metadata?
            # metadata = context.services.metadata.build_metadata(
            #     session_id=context.graph_execution_state_id, node=self

            torch.cuda.empty_cache()

            # new (post Image service refactor) way of using services to save image
            # and generate a unique image_name
            image_dto = context.services.images.create(
                image=image,
                image_type=ImageType.RESULT,
                image_category=ImageCategory.GENERAL,
                session_id=context.graph_execution_state_id,
                node_id=self.id,
                is_intermediate=self.is_intermediate
            )

            return ImageOutput(
@ -413,6 +516,7 @@ class ResizeLatentsInvocation(BaseInvocation):
        torch.cuda.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
        # context.services.latents.set(name, resized_latents)
        context.services.latents.save(name, resized_latents)
        return build_latents_output(latents_name=name, latents=resized_latents)

@ -443,6 +547,7 @@ class ScaleLatentsInvocation(BaseInvocation):
        torch.cuda.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
        # context.services.latents.set(name, resized_latents)
        context.services.latents.save(name, resized_latents)
        return build_latents_output(latents_name=name, latents=resized_latents)

@ -467,6 +572,9 @@ class ImageToLatentsInvocation(BaseInvocation):

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        # image = context.services.images.get(
        #     self.image.image_type, self.image.image_name
        # )
        image = context.services.images.get_pil_image(
            self.image.image_type, self.image.image_name
        )

@ -487,6 +595,6 @@ class ImageToLatentsInvocation(BaseInvocation):
        )

        name = f"{context.graph_execution_state_id}__{self.id}"
        # context.services.latents.set(name, latents)
        context.services.latents.save(name, latents)
        return build_latents_output(latents_name=name, latents=latents)
@ -34,6 +34,15 @@ class IntOutput(BaseInvocationOutput):
    # fmt: on


class FloatOutput(BaseInvocationOutput):
    """A float output"""

    # fmt: off
    type: Literal["float_output"] = "float_output"
    param: float = Field(default=None, description="The output float")
    # fmt: on


class AddInvocation(BaseInvocation, MathInvocationConfig):
    """Adds two numbers"""
@ -3,7 +3,7 @@
from typing import Literal
from pydantic import Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from .math import IntOutput
from .math import IntOutput, FloatOutput

# Pass-through parameter nodes - used by subgraphs

@ -16,3 +16,13 @@ class ParamIntInvocation(BaseInvocation):

    def invoke(self, context: InvocationContext) -> IntOutput:
        return IntOutput(a=self.a)


class ParamFloatInvocation(BaseInvocation):
    """A float parameter"""

    # fmt: off
    type: Literal["param_float"] = "param_float"
    param: float = Field(default=0.0, description="The float value")
    # fmt: on

    def invoke(self, context: InvocationContext) -> FloatOutput:
        return FloatOutput(param=self.param)
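Pass-through behavior in one line (illustrative; the `id` field comes from BaseInvocation, and the context argument is unused by this node, so None stands in for it here):

node = ParamFloatInvocation(id="cfg_scale", param=7.5)
assert node.invoke(None).param == 7.5  # FloatOutput simply echoes the stored parameter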
@ -43,10 +43,11 @@ class RestoreFaceInvocation(BaseInvocation):
        # TODO: can this return multiple results?
        image_dto = context.services.images.create(
            image=results[0][0],
            image_type=ImageType.INTERMEDIATE,
            image_type=ImageType.RESULT,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        return ImageOutput(

@ -49,6 +49,7 @@ class UpscaleInvocation(BaseInvocation):
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        return ImageOutput(

@ -10,7 +10,6 @@ class ImageType(str, Enum, metaclass=MetaEnum):

    RESULT = "results"
    UPLOAD = "uploads"
    INTERMEDIATE = "intermediates"


class InvalidImageTypeException(ValueError):
@ -353,6 +353,7 @@ setting environment variables INVOKEAI_<setting>.
    sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
    xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')

    root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
    autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
    conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')

@ -362,6 +363,7 @@ setting environment variables INVOKEAI_<setting>.
    lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
    outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
    from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
    use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')

    model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
    embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')

@ -511,7 +513,7 @@ class PagingArgumentParser(argparse.ArgumentParser):
        text = self.format_help()
        pydoc.pager(text)

def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAISettings:
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAIAppConfig:
    '''
    This returns a singleton InvokeAIAppConfig configuration object.
    '''
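Since these settings fields double as launch options and `INVOKEAI_<setting>` environment variables (per the docstring this hunk sits in), the new flag should also be reachable programmatically — a sketch, assuming the pydantic settings class accepts field overrides as keyword arguments:

from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig(use_memory_db=True)  # or export INVOKEAI_use_memory_db=true
assert config.use_memory_db is True             # image metadata then lives in an in-memory SQLite db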
@ -60,6 +60,35 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
    node_input_field = node_inputs.get(field) or None
    return node_input_field


from typing import Optional, Union, List, get_args


def is_union_subtype(t1, t2):
    t1_args = get_args(t1)
    t2_args = get_args(t2)

    if not t1_args:
        # t1 is a single type
        return t1 in t2_args
    else:
        # t1 is a Union, check that all of its types are in t2_args
        return all(arg in t2_args for arg in t1_args)


def is_list_or_contains_list(t):
    t_args = get_args(t)

    # If the type is a List
    if get_origin(t) is list:
        return True
    # If the type is a Union
    elif t_args:
        # Check if any of the types in the Union is a List
        for arg in t_args:
            if get_origin(arg) is list:
                return True

    return False


def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
    if not from_type:

@ -85,7 +114,8 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
        if to_type in get_args(from_type):
            return True
        if not issubclass(from_type, to_type):
        # if not issubclass(from_type, to_type):
        if not is_union_subtype(from_type, to_type):
            return False
    else:
        return False

@ -694,7 +724,11 @@ class Graph(BaseModel):
        input_root_type = next(t[0] for t in type_degrees if t[1] == 0)  # type: ignore

        # Verify that all outputs are lists
        if not all((get_origin(f) == list for f in output_fields)):
        # if not all((get_origin(f) == list for f in output_fields)):
        #     return False

        # Verify that all outputs are lists
        if not all(is_list_or_contains_list(f) for f in output_fields):
            return False

        # Verify that all outputs match the input type (are a base class or the same class)
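Behavioral check of the two helpers above (illustrative; relies only on standard `typing` semantics):

from typing import List, Optional, Union

assert is_union_subtype(int, Union[int, str])                     # a single type contained in a Union
assert is_union_subtype(Union[int, str], Union[int, str, float])  # a Union contained in a wider Union
assert is_list_or_contains_list(List[int])                        # a bare List
assert is_list_or_contains_list(Union[int, List[str]])            # a Union with a List member
assert not is_list_or_contains_list(Optional[int])                # Union[int, None] has no List member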
@ -12,6 +12,7 @@ from invokeai.app.models.image import (
)
from invokeai.app.services.models.image_record import (
    ImageRecord,
    ImageRecordChanges,
    deserialize_image_record,
)
from invokeai.app.services.item_storage import PaginatedResults

@ -49,6 +50,16 @@ class ImageRecordStorageBase(ABC):
        """Gets an image record."""
        pass

    @abstractmethod
    def update(
        self,
        image_name: str,
        image_type: ImageType,
        changes: ImageRecordChanges,
    ) -> None:
        """Updates an image record."""
        pass

    @abstractmethod
    def get_many(
        self,

@ -78,6 +89,7 @@ class ImageRecordStorageBase(ABC):
        session_id: Optional[str],
        node_id: Optional[str],
        metadata: Optional[ImageMetadata],
        is_intermediate: bool = False,
    ) -> datetime:
        """Saves an image record."""
        pass
@ -125,6 +137,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
                session_id TEXT,
                node_id TEXT,
                metadata TEXT,
                is_intermediate BOOLEAN DEFAULT FALSE,
                created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
                -- Updated via trigger
                updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
@ -193,6 +206,42 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):

        return deserialize_image_record(dict(result))

    def update(
        self,
        image_name: str,
        image_type: ImageType,
        changes: ImageRecordChanges,
    ) -> None:
        try:
            self._lock.acquire()
            # Change the category of the image
            if changes.image_category is not None:
                self._cursor.execute(
                    f"""--sql
                    UPDATE images
                    SET image_category = ?
                    WHERE image_name = ?;
                    """,
                    (changes.image_category, image_name),
                )

            # Change the session associated with the image
            if changes.session_id is not None:
                self._cursor.execute(
                    f"""--sql
                    UPDATE images
                    SET session_id = ?
                    WHERE image_name = ?;
                    """,
                    (changes.session_id, image_name),
                )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise ImageRecordSaveException from e
        finally:
            self._lock.release()

    def get_many(
        self,
        image_type: ImageType,
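Call pattern this enables (sketch; `storage` is assumed to be an already-constructed SqliteImageRecordStorage):

changes = ImageRecordChanges(image_category=ImageCategory.GENERAL)
storage.update("some_image.png", ImageType.RESULT, changes)  # only the non-None fields are written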
@ -265,6 +314,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
        height: int,
        node_id: Optional[str],
        metadata: Optional[ImageMetadata],
        is_intermediate: bool = False,
    ) -> datetime:
        try:
            metadata_json = (

@ -281,9 +331,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
                    height,
                    node_id,
                    session_id,
                    metadata
                    metadata,
                    is_intermediate
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?);
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
                """,
                (
                    image_name,

@ -294,6 +345,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
                    node_id,
                    session_id,
                    metadata_json,
                    is_intermediate,
                ),
            )
            self._conn.commit()
@ -20,6 +20,7 @@ from invokeai.app.services.image_record_storage import (
from invokeai.app.services.models.image_record import (
    ImageRecord,
    ImageDTO,
    ImageRecordChanges,
    image_record_to_dto,
)
from invokeai.app.services.image_file_storage import (

@ -31,7 +32,6 @@ from invokeai.app.services.image_file_storage import (
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.misc import get_iso_timestamp

if TYPE_CHECKING:
    from invokeai.app.services.graph import GraphExecutionState

@ -48,11 +48,21 @@ class ImageServiceABC(ABC):
        image_category: ImageCategory,
        node_id: Optional[str] = None,
        session_id: Optional[str] = None,
        metadata: Optional[ImageMetadata] = None,
        intermediate: bool = False,
    ) -> ImageDTO:
        """Creates an image, storing the file and its metadata."""
        pass

    @abstractmethod
    def update(
        self,
        image_type: ImageType,
        image_name: str,
        changes: ImageRecordChanges,
    ) -> ImageDTO:
        """Updates an image."""
        pass

    @abstractmethod
    def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
        """Gets an image as a PIL image."""
@ -157,6 +167,7 @@ class ImageService(ImageServiceABC):
        image_category: ImageCategory,
        node_id: Optional[str] = None,
        session_id: Optional[str] = None,
        is_intermediate: bool = False,
    ) -> ImageDTO:
        if image_type not in ImageType:
            raise InvalidImageTypeException

@ -184,6 +195,8 @@ class ImageService(ImageServiceABC):
                image_category=image_category,
                width=width,
                height=height,
                # Meta fields
                is_intermediate=is_intermediate,
                # Nullable fields
                node_id=node_id,
                session_id=session_id,

@ -217,6 +230,7 @@ class ImageService(ImageServiceABC):
                created_at=created_at,
                updated_at=created_at,  # this is always the same as the created_at at this time
                deleted_at=None,
                is_intermediate=is_intermediate,
                # Extra non-nullable fields for DTO
                image_url=image_url,
                thumbnail_url=thumbnail_url,
@ -231,6 +245,23 @@ class ImageService(ImageServiceABC):
            self._services.logger.error("Problem saving image record and file")
            raise e

    def update(
        self,
        image_type: ImageType,
        image_name: str,
        changes: ImageRecordChanges,
    ) -> ImageDTO:
        try:
            self._services.records.update(image_name, image_type, changes)
            return self.get_dto(image_type, image_name)
        except ImageRecordSaveException:
            self._services.logger.error("Failed to update image record")
            raise
        except Exception as e:
            self._services.logger.error("Problem updating image record")
            raise e

    def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
        try:
            return self._services.files.get(image_type, image_name)
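From a node's perspective the call chain reads (sketch; `context` is the InvocationContext handed to invoke):

changes = ImageRecordChanges(session_id=context.graph_execution_state_id)
updated = context.services.images.update(ImageType.RESULT, "some_image.png", changes)
# updated is a fresh ImageDTO reflecting the new session association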
@ -1,18 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from __future__ import annotations
from typing import TYPE_CHECKING
from logging import Logger
from invokeai.app.services.images import ImageService
from invokeai.backend import ModelManager
from .events import EventServiceBase
from .latent_storage import LatentsStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
from .config import InvokeAISettings

if TYPE_CHECKING:
    from logging import Logger
    from invokeai.app.services.images import ImageService
    from invokeai.backend import ModelManager
    from invokeai.app.services.events import EventServiceBase
    from invokeai.app.services.latent_storage import LatentsStorageBase
    from invokeai.app.services.restoration_services import RestorationServices
    from invokeai.app.services.invocation_queue import InvocationQueueABC
    from invokeai.app.services.item_storage import ItemStorageABC
    from invokeai.app.services.config import InvokeAISettings
    from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
    from invokeai.app.services.invoker import InvocationProcessorABC
@ -20,32 +19,33 @@ if TYPE_CHECKING:
class InvocationServices:
    """Services that can be used by invocations"""

    events: EventServiceBase
    latents: LatentsStorageBase
    queue: InvocationQueueABC
    model_manager: ModelManager
    restoration: RestorationServices
    configuration: InvokeAISettings
    images: ImageService
    # TODO: Just forward-declared everything due to circular dependencies. Fix structure.
    events: "EventServiceBase"
    latents: "LatentsStorageBase"
    queue: "InvocationQueueABC"
    model_manager: "ModelManager"
    restoration: "RestorationServices"
    configuration: "InvokeAISettings"
    images: "ImageService"

    # NOTE: we must forward-declare any types that include invocations, since invocations can use services
    graph_library: ItemStorageABC["LibraryGraph"]
    graph_library: "ItemStorageABC"["LibraryGraph"]
    graph_execution_manager: ItemStorageABC["GraphExecutionState"]
    graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
    processor: "InvocationProcessorABC"

    def __init__(
        self,
        model_manager: ModelManager,
        model_manager: "ModelManager",
        events: EventServiceBase,
        events: "EventServiceBase",
        logger: Logger,
        logger: "Logger",
        latents: LatentsStorageBase,
        latents: "LatentsStorageBase",
        images: ImageService,
        images: "ImageService",
        queue: InvocationQueueABC,
        queue: "InvocationQueueABC",
        graph_library: ItemStorageABC["LibraryGraph"],
        graph_library: "ItemStorageABC"["LibraryGraph"],
        graph_execution_manager: ItemStorageABC["GraphExecutionState"],
        graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
        processor: "InvocationProcessorABC",
        restoration: RestorationServices,
        restoration: "RestorationServices",
        configuration: InvokeAISettings = None,
        configuration: "InvokeAISettings",
    ):
        self.model_manager = model_manager
        self.events = events
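The forward-declaration pattern adopted above, shown in miniature (illustrative; `heavy_module` is hypothetical):

from __future__ import annotations
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from heavy_module import HeavyService  # seen only by type checkers, never imported at runtime

class Holder:
    service: "HeavyService"  # string annotation, so no runtime (or circular) import is needed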
@ -1,6 +1,6 @@
import datetime
from typing import Optional, Union
from pydantic import BaseModel, Field
from pydantic import BaseModel, Extra, Field, StrictStr
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp

@ -31,6 +31,8 @@ class ImageRecord(BaseModel):
        description="The deleted timestamp of the image."
    )
    """The deleted timestamp of the image."""
    is_intermediate: bool = Field(description="Whether this is an intermediate image.")
    """Whether this is an intermediate image."""
    session_id: Optional[str] = Field(
        default=None,
        description="The session ID that generated this image, if it is a generated image.",
@ -48,6 +50,25 @@ class ImageRecord(BaseModel):
    """A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""


class ImageRecordChanges(BaseModel, extra=Extra.forbid):
    """A set of changes to apply to an image record.

    Only limited changes are valid:
    - `image_category`: change the category of an image
    - `session_id`: change the session associated with an image
    """

    image_category: Optional[ImageCategory] = Field(
        description="The image's new category."
    )
    """The image's new category."""
    session_id: Optional[StrictStr] = Field(
        default=None,
        description="The image's new session ID.",
    )
    """The image's new session ID."""


class ImageUrlsDTO(BaseModel):
    """The URLs for an image and its thumbnail."""
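A minimal check of the `Extra.forbid` declaration above (pydantic v1 semantics assumed):

from pydantic import ValidationError

ImageRecordChanges(session_id="abc123")          # fine: a declared field
try:
    ImageRecordChanges(image_name="nope")        # not in the declared set
except ValidationError:
    pass                                         # extra=Extra.forbid rejects unknown fields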
@ -95,6 +116,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
    created_at = image_dict.get("created_at", get_iso_timestamp())
    updated_at = image_dict.get("updated_at", get_iso_timestamp())
    deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
    is_intermediate = image_dict.get("is_intermediate", False)

    raw_metadata = image_dict.get("metadata")

@ -115,4 +137,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
        created_at=created_at,
        updated_at=updated_at,
        deleted_at=deleted_at,
        is_intermediate=is_intermediate,
    )

@ -1,10 +1,7 @@
import time
import traceback
from threading import Event, Thread, BoundedSemaphore
from typing import Any, TypeGuard

from invokeai.app.invocations.image import ImageOutput
from invokeai.app.models.image import ImageType
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
@ -75,9 +75,11 @@ class InvokeAIGenerator(metaclass=ABCMeta):
    def __init__(self,
                 model_info: dict,
                 params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
                 **kwargs,
                 ):
        self.model_info=model_info
        self.params=params
        self.kwargs = kwargs

    def generate(self,
                 prompt: str='',

@ -118,9 +120,12 @@ class InvokeAIGenerator(metaclass=ABCMeta):
                model=model,
                scheduler_name=generator_args.get('scheduler')
            )
            uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
            # get conditioning from prompt via Compel package
            uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt, model=model)

            gen_class = self._generator_class()
            generator = gen_class(model, self.params.precision)
            generator = gen_class(model, self.params.precision, **self.kwargs)
            if self.params.variation_amount > 0:
                generator.set_variation(generator_args.get('seed'),
                                        generator_args.get('variation_amount'),

@ -276,7 +281,7 @@ class Generator:
    precision: str
    model: DiffusionPipeline

    def __init__(self, model: DiffusionPipeline, precision: str):
    def __init__(self, model: DiffusionPipeline, precision: str, **kwargs):
        self.model = model
        self.precision = precision
        self.seed = None
@ -4,6 +4,10 @@ invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
import PIL.Image
import torch

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel

from ..stable_diffusion import (
    ConditioningData,
    PostprocessingSettings,

@ -13,8 +17,13 @@ from .base import Generator


class Txt2Img(Generator):
    def __init__(self, model, precision):
        super().__init__(model, precision)
    def __init__(self, model, precision,
                 control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
                 **kwargs):
        self.control_model = control_model
        if isinstance(self.control_model, list):
            self.control_model = MultiControlNetModel(self.control_model)
        super().__init__(model, precision, **kwargs)

    @torch.no_grad()
    def get_make_image(
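Constructor usage sketch (the model handle and precision values are assumed; a list of ControlNets is wrapped automatically):

cn = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
single = Txt2Img(model, "float16", control_model=cn)        # plain ControlNetModel
multi = Txt2Img(model, "float16", control_model=[cn, cn])   # becomes a MultiControlNetModel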
@ -42,9 +51,12 @@ class Txt2Img(Generator):
        kwargs are 'width' and 'height'
        """
        self.perlin = perlin
        control_image = kwargs.get("control_image", None)
        do_classifier_free_guidance = cfg_scale > 1.0

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.control_model = self.control_model
        pipeline.scheduler = sampler

        uc, c, extra_conditioning_info = conditioning

@ -61,6 +73,37 @@ class Txt2Img(Generator):
            ),
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        # FIXME: still need to test with different widths, heights, devices, dtypes
        #        and add in batch_size, num_images_per_prompt?
        if control_image is not None:
            if isinstance(self.control_model, ControlNetModel):
                control_image = pipeline.prepare_control_image(
                    image=control_image,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    width=width,
                    height=height,
                    # batch_size=batch_size * num_images_per_prompt,
                    # num_images_per_prompt=num_images_per_prompt,
                    device=self.control_model.device,
                    dtype=self.control_model.dtype,
                )
            elif isinstance(self.control_model, MultiControlNetModel):
                images = []
                for image_ in control_image:
                    image_ = self.model.prepare_control_image(
                        image=image_,
                        do_classifier_free_guidance=do_classifier_free_guidance,
                        width=width,
                        height=height,
                        # batch_size=batch_size * num_images_per_prompt,
                        # num_images_per_prompt=num_images_per_prompt,
                        device=self.control_model.device,
                        dtype=self.control_model.dtype,
                    )
                    images.append(image_)
                control_image = images
            kwargs["control_image"] = control_image

        def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
            pipeline_output = pipeline.image_from_embeddings(
                latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),

@ -68,6 +111,7 @@ class Txt2Img(Generator):
                num_inference_steps=steps,
                conditioning_data=conditioning_data,
                callback=step_callback,
                **kwargs,
            )

            if (
@ -2,23 +2,29 @@ from __future__ import annotations

import dataclasses
import inspect
import math
import secrets
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel, Field

import einops
import PIL.Image
import numpy as np
from accelerate.utils import set_seed
import psutil
import torch
import torchvision.transforms as T
from compel import EmbeddingsProvider
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    StableDiffusionImg2ImgPipeline,
)

@ -27,6 +33,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils import PIL_INTERPOLATION
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize
@ -207,6 +214,13 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
        raise AssertionError("why was that an empty generator?")
    return result


@dataclass
class ControlNetData:
    model: ControlNetModel = Field(default=None)
    image_tensor: torch.Tensor = Field(default=None)
    weight: float = Field(default=1.0)
    begin_step_percent: float = Field(default=0.0)
    end_step_percent: float = Field(default=1.0)


@dataclass(frozen=True)
class ConditioningData:
@ -302,6 +316,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        feature_extractor: Optional[CLIPFeatureExtractor],
        requires_safety_checker: bool = False,
        precision: str = "float32",
        control_model: ControlNetModel = None,
    ):
        super().__init__(
            vae,

@ -322,6 +337,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            # FIXME: can't currently register control module
            # control_model=control_model,
        )
        self.invokeai_diffuser = InvokeAIDiffuserComponent(
            self.unet, self._unet_forward, is_running_diffusers=True

@ -341,6 +358,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

        self._model_group = FullyLoadedModelGroup(self.unet.device)
        self._model_group.install(*self._submodels)
        self.control_model = control_model

    def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
        """
@ -463,6 +481,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        noise: torch.Tensor,
        callback: Callable[[PipelineIntermediateState], None] = None,
        run_id=None,
        **kwargs,
    ) -> InvokeAIStableDiffusionPipelineOutput:
        r"""
        Function invoked when calling the pipeline for generation.

@ -483,6 +502,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            noise=noise,
            run_id=run_id,
            callback=callback,
            **kwargs,
        )
        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

@ -507,6 +527,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        additional_guidance: List[Callable] = None,
        run_id=None,
        callback: Callable[[PipelineIntermediateState], None] = None,
        control_data: List[ControlNetData] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
        if self.scheduler.config.get("cpu_only", False):
            scheduler_device = torch.device('cpu')

@ -527,6 +549,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            additional_guidance=additional_guidance,
            run_id=run_id,
            callback=callback,
            control_data=control_data,
            **kwargs,
        )
        return result.latents, result.attention_map_saver

@ -539,6 +563,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        noise: torch.Tensor,
        run_id: str = None,
        additional_guidance: List[Callable] = None,
        control_data: List[ControlNetData] = None,
        **kwargs,
    ):
        self._adjust_memory_efficient_attention(latents)
        if run_id is None:
@ -568,7 +594,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            latents = self.scheduler.add_noise(latents, noise, batched_t)

        attention_map_saver: Optional[AttentionMapSaver] = None
        # print("timesteps:", timesteps)
        for i, t in enumerate(self.progress_bar(timesteps)):
            batched_t.fill_(t)
            step_output = self.step(

@ -578,6 +604,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                step_index=i,
                total_step_count=len(timesteps),
                additional_guidance=additional_guidance,
                control_data=control_data,
                **kwargs,
            )
            latents = step_output.prev_sample

@ -618,10 +646,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        step_index: int,
        total_step_count: int,
        additional_guidance: List[Callable] = None,
        control_data: List[ControlNetData] = None,
        **kwargs,
    ):
        # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
        timestep = t[0]

        if additional_guidance is None:
            additional_guidance = []
@ -629,6 +658,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        # i.e. before or after passing it to InvokeAIDiffuserComponent
        latent_model_input = self.scheduler.scale_model_input(latents, timestep)

        # default is no controlnet, so set controlnet processing output to None
        down_block_res_samples, mid_block_res_sample = None, None

        if control_data is not None:
            if conditioning_data.guidance_scale > 1.0:
                # expand the latents input to control model if doing classifier free guidance
                # (which I think for now is always true, there is conditional elsewhere that stops execution if
                #  classifier_free_guidance is <= 1.0 ?)
                latent_control_input = torch.cat([latent_model_input] * 2)
            else:
                latent_control_input = latent_model_input
            # control_data should be type List[ControlNetData]
            # this loop covers both ControlNet (one ControlNetData in list)
            # and MultiControlNet (multiple ControlNetData in list)
            for i, control_datum in enumerate(control_data):
                # print("controlnet", i, "==>", type(control_datum))
                first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
                last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
                # only apply controlnet if current step is within the controlnet's begin/end step range
                if step_index >= first_control_step and step_index <= last_control_step:
                    # print("running controlnet", i, "for step", step_index)
                    down_samples, mid_sample = control_datum.model(
                        sample=latent_control_input,
                        timestep=timestep,
                        encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
                                                         conditioning_data.text_embeddings]),
                        controlnet_cond=control_datum.image_tensor,
                        conditioning_scale=control_datum.weight,
                        # cross_attention_kwargs,
                        guess_mode=False,
                        return_dict=False,
                    )
                    if down_block_res_samples is None and mid_block_res_sample is None:
                        down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
                    else:
                        # add controlnet outputs together if have multiple controlnets
                        down_block_res_samples = [
                            samples_prev + samples_curr
                            for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                        ]
                        mid_block_res_sample += mid_sample

        # predict the noise residual
        noise_pred = self.invokeai_diffuser.do_diffusion_step(
            latent_model_input,

@ -638,6 +709,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
            conditioning_data.guidance_scale,
            step_index=step_index,
            total_step_count=total_step_count,
            down_block_additional_residuals=down_block_res_samples,  # from controlnet(s)
            mid_block_additional_residual=mid_block_res_sample,  # from controlnet(s)
        )

        # compute the previous noisy sample x_t -> x_t-1
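Worked check of the step-window arithmetic above (illustrative numbers):

import math

total_step_count = 30
first_control_step = math.floor(0.2 * total_step_count)  # begin_step_percent=0.2 -> 6
last_control_step = math.ceil(0.8 * total_step_count)    # end_step_percent=0.8  -> 24
# residuals are injected only while 6 <= step_index <= 24; outside that window
# the UNet runs as if no controlnet were attached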
@ -659,6 +732,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        t,
        text_embeddings,
        cross_attention_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ):
        """predict the noise residual"""
        if is_inpainting_model(self.unet) and latents.size(1) == 4:

@ -678,7 +752,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

        # First three args should be positional, not keywords, so torch hooks can see them.
        return self.unet(
            latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs
            latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs,
            **kwargs,
        ).sample

    def img2img_from_embeddings(
|
|||||||
debug_image(
|
debug_image(
|
||||||
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
|
img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
|
||||||
|
# Returns torch.Tensor of shape (batch_size, 3, height, width)
|
||||||
|
def prepare_control_image(
|
||||||
|
self,
|
||||||
|
image,
|
||||||
|
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
|
||||||
|
# latents,
|
||||||
|
width=512, # should be 8 * latent.shape[3]
|
||||||
|
height=512, # should be 8 * latent height[2]
|
||||||
|
batch_size=1,
|
||||||
|
num_images_per_prompt=1,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.float16,
|
||||||
|
do_classifier_free_guidance=True,
|
||||||
|
):
|
||||||
|
|
||||||
|
if not isinstance(image, torch.Tensor):
|
||||||
|
if isinstance(image, PIL.Image.Image):
|
||||||
|
image = [image]
|
||||||
|
|
||||||
|
if isinstance(image[0], PIL.Image.Image):
|
||||||
|
images = []
|
||||||
|
for image_ in image:
|
||||||
|
image_ = image_.convert("RGB")
|
||||||
|
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||||
|
image_ = np.array(image_)
|
||||||
|
image_ = image_[None, :]
|
||||||
|
images.append(image_)
|
||||||
|
image = images
|
||||||
|
image = np.concatenate(image, axis=0)
|
||||||
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
image = image.transpose(0, 3, 1, 2)
|
||||||
|
image = torch.from_numpy(image)
|
||||||
|
elif isinstance(image[0], torch.Tensor):
|
||||||
|
image = torch.cat(image, dim=0)
|
||||||
|
|
||||||
|
image_batch_size = image.shape[0]
|
||||||
|
if image_batch_size == 1:
|
||||||
|
repeat_by = batch_size
|
||||||
|
else:
|
||||||
|
# image batch size is the same as prompt batch size
|
||||||
|
repeat_by = num_images_per_prompt
|
||||||
|
image = image.repeat_interleave(repeat_by, dim=0)
|
||||||
|
image = image.to(device=device, dtype=dtype)
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
image = torch.cat([image] * 2)
|
||||||
|
return image
|
||||||
|
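
Two details of `prepare_control_image` above are worth calling out: a lone control image is repeated to match the batch, and the whole batch is then doubled when classifier-free guidance is enabled, because the UNet runs the unconditioned and conditioned halves in a single forward pass. A hedged sketch of just that batching logic, generic over the image type (not the actual tensor code):

```ts
// Expand a list of control images to the effective batch the UNet will see.
function expandControlBatch<T>(
  images: T[],
  batchSize: number,
  numImagesPerPrompt: number,
  doClassifierFreeGuidance: boolean
): T[] {
  // A single image is repeated once per batch item; a full batch is repeated
  // per requested image per prompt, matching repeat_interleave above.
  const repeatBy = images.length === 1 ? batchSize : numImagesPerPrompt;
  const repeated = images.flatMap((img) => Array(repeatBy).fill(img) as T[]);
  // CFG needs the control input present for both the unconditioned and the
  // conditioned halves of the doubled batch.
  return doClassifierFreeGuidance ? [...repeated, ...repeated] : repeated;
}
```
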
@@ -181,6 +181,7 @@ class InvokeAIDiffuserComponent:
        unconditional_guidance_scale: float,
        step_index: Optional[int] = None,
        total_step_count: Optional[int] = None,
+        **kwargs,
    ):
        """
        :param x: current latents
@@ -209,7 +210,7 @@ class InvokeAIDiffuserComponent:

        if wants_hybrid_conditioning:
            unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning(
-                x, sigma, unconditioning, conditioning
+                x, sigma, unconditioning, conditioning, **kwargs,
            )
        elif wants_cross_attention_control:
            (
@@ -221,13 +222,14 @@ class InvokeAIDiffuserComponent:
                unconditioning,
                conditioning,
                cross_attention_control_types_to_do,
+                **kwargs,
            )
        elif self.sequential_guidance:
            (
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_standard_conditioning_sequentially(
-                x, sigma, unconditioning, conditioning
+                x, sigma, unconditioning, conditioning, **kwargs,
            )

        else:
@@ -235,7 +237,7 @@ class InvokeAIDiffuserComponent:
                unconditioned_next_x,
                conditioned_next_x,
            ) = self._apply_standard_conditioning(
-                x, sigma, unconditioning, conditioning
+                x, sigma, unconditioning, conditioning, **kwargs,
            )

        combined_next_x = self._combine(
@@ -282,13 +284,13 @@ class InvokeAIDiffuserComponent:

    # methods below are called from do_diffusion_step and should be considered private to this class.

-    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
+    def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
        # fast batched path
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)
        both_conditionings = torch.cat([unconditioning, conditioning])
        both_results = self.model_forward_callback(
-            x_twice, sigma_twice, both_conditionings
+            x_twice, sigma_twice, both_conditionings, **kwargs,
        )
        unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
        if conditioned_next_x.device.type == "mps":
@@ -302,16 +304,17 @@ class InvokeAIDiffuserComponent:
        sigma,
        unconditioning: torch.Tensor,
        conditioning: torch.Tensor,
+        **kwargs,
    ):
        # low-memory sequential path
-        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
-        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning)
+        unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
+        conditioned_next_x = self.model_forward_callback(x, sigma, conditioning, **kwargs)
        if conditioned_next_x.device.type == "mps":
            # prevent a result filled with zeros. seems to be a torch bug.
            conditioned_next_x = conditioned_next_x.clone()
        return unconditioned_next_x, conditioned_next_x

-    def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning):
+    def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
        assert isinstance(conditioning, dict)
        assert isinstance(unconditioning, dict)
        x_twice = torch.cat([x] * 2)
@@ -326,7 +329,7 @@ class InvokeAIDiffuserComponent:
        else:
            both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]])
        unconditioned_next_x, conditioned_next_x = self.model_forward_callback(
-            x_twice, sigma_twice, both_conditionings
+            x_twice, sigma_twice, both_conditionings, **kwargs,
        ).chunk(2)
        return unconditioned_next_x, conditioned_next_x

@@ -337,6 +340,7 @@ class InvokeAIDiffuserComponent:
        unconditioning,
        conditioning,
        cross_attention_control_types_to_do,
+        **kwargs,
    ):
        if self.is_running_diffusers:
            return self._apply_cross_attention_controlled_conditioning__diffusers(
@@ -345,6 +349,7 @@ class InvokeAIDiffuserComponent:
                unconditioning,
                conditioning,
                cross_attention_control_types_to_do,
+                **kwargs,
            )
        else:
            return self._apply_cross_attention_controlled_conditioning__compvis(
@@ -353,6 +358,7 @@ class InvokeAIDiffuserComponent:
                unconditioning,
                conditioning,
                cross_attention_control_types_to_do,
+                **kwargs,
            )

    def _apply_cross_attention_controlled_conditioning__diffusers(
@@ -362,6 +368,7 @@ class InvokeAIDiffuserComponent:
        unconditioning,
        conditioning,
        cross_attention_control_types_to_do,
+        **kwargs,
    ):
        context: Context = self.cross_attention_control_context

@@ -377,6 +384,7 @@ class InvokeAIDiffuserComponent:
            sigma,
            unconditioning,
            {"swap_cross_attn_context": cross_attn_processor_context},
+            **kwargs,
        )

        # do requested cross attention types for conditioning (positive prompt)
@@ -388,6 +396,7 @@ class InvokeAIDiffuserComponent:
            sigma,
            conditioning,
            {"swap_cross_attn_context": cross_attn_processor_context},
+            **kwargs,
        )
        return unconditioned_next_x, conditioned_next_x

@@ -398,6 +407,7 @@ class InvokeAIDiffuserComponent:
        unconditioning,
        conditioning,
        cross_attention_control_types_to_do,
+        **kwargs,
    ):
        # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
        # slower non-batched path (20% slower on mac MPS)
@@ -411,13 +421,13 @@ class InvokeAIDiffuserComponent:
        context: Context = self.cross_attention_control_context

        try:
-            unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
+            unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)

            # process x using the original prompt, saving the attention maps
            # print("saving attention maps for", cross_attention_control_types_to_do)
            for ca_type in cross_attention_control_types_to_do:
                context.request_save_attention_maps(ca_type)
-            _ = self.model_forward_callback(x, sigma, conditioning)
+            _ = self.model_forward_callback(x, sigma, conditioning, **kwargs)
            context.clear_requests(cleanup=False)

            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
@@ -428,7 +438,7 @@ class InvokeAIDiffuserComponent:
            self.conditioning.cross_attention_control_args.edited_conditioning
        )
        conditioned_next_x = self.model_forward_callback(
-            x, sigma, edited_conditioning
+            x, sigma, edited_conditioning, **kwargs,
        )
        context.clear_requests(cleanup=True)

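Every change in this file is the same mechanical move: accept `**kwargs` at `do_diffusion_step` and forward it untouched through each conditioning path to `model_forward_callback`, so the new ControlNet residuals reach the UNet without widening any intermediate signature. The same pass-through idea sketched in TypeScript with a rest options object (names are illustrative only, not the app's API):

```ts
// Extra per-step options, forwarded opaquely; only the UNet consumes them.
interface StepExtras {
  downBlockAdditionalResiduals?: number[][];
  midBlockAdditionalResidual?: number[];
}

type ForwardCallback = (x: number[], sigma: number, extras?: StepExtras) => number[];

// Each conditioning path forwards `extras` verbatim, like `**kwargs` above,
// so adding a new extra never requires touching this layer again.
function applyStandardConditioning(
  forward: ForwardCallback,
  x: number[],
  sigma: number,
  extras: StepExtras = {}
): number[] {
  return forward(x, sigma, extras);
}
```
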
@@ -8,9 +8,16 @@ import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit';

import type { RootState, AppDispatch } from '../../store';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
-import { addImageResultReceivedListener } from './listeners/invocationComplete';
-import { addImageUploadedListener } from './listeners/imageUploaded';
-import { addRequestedImageDeletionListener } from './listeners/imageDeleted';
+import {
+  addImageUploadedFulfilledListener,
+  addImageUploadedRejectedListener,
+} from './listeners/imageUploaded';
+import {
+  addImageDeletedFulfilledListener,
+  addImageDeletedPendingListener,
+  addImageDeletedRejectedListener,
+  addRequestedImageDeletionListener,
+} from './listeners/imageDeleted';
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
@@ -19,6 +26,47 @@ import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGaller
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasMergedListener } from './listeners/canvasMerged';
+import { addGeneratorProgressListener } from './listeners/socketio/generatorProgress';
+import { addGraphExecutionStateCompleteListener } from './listeners/socketio/graphExecutionStateComplete';
+import { addInvocationCompleteListener } from './listeners/socketio/invocationComplete';
+import { addInvocationErrorListener } from './listeners/socketio/invocationError';
+import { addInvocationStartedListener } from './listeners/socketio/invocationStarted';
+import { addSocketConnectedListener } from './listeners/socketio/socketConnected';
+import { addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
+import { addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
+import { addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
+import { addSessionReadyToInvokeListener } from './listeners/sessionReadyToInvoke';
+import {
+  addImageMetadataReceivedFulfilledListener,
+  addImageMetadataReceivedRejectedListener,
+} from './listeners/imageMetadataReceived';
+import {
+  addImageUrlsReceivedFulfilledListener,
+  addImageUrlsReceivedRejectedListener,
+} from './listeners/imageUrlsReceived';
+import {
+  addSessionCreatedFulfilledListener,
+  addSessionCreatedPendingListener,
+  addSessionCreatedRejectedListener,
+} from './listeners/sessionCreated';
+import {
+  addSessionInvokedFulfilledListener,
+  addSessionInvokedPendingListener,
+  addSessionInvokedRejectedListener,
+} from './listeners/sessionInvoked';
+import {
+  addSessionCanceledFulfilledListener,
+  addSessionCanceledPendingListener,
+  addSessionCanceledRejectedListener,
+} from './listeners/sessionCanceled';
+import {
+  addReceivedResultImagesPageFulfilledListener,
+  addReceivedResultImagesPageRejectedListener,
+} from './listeners/receivedResultImagesPage';
+import {
+  addReceivedUploadImagesPageFulfilledListener,
+  addReceivedUploadImagesPageRejectedListener,
+} from './listeners/receivedUploadImagesPage';

export const listenerMiddleware = createListenerMiddleware();

@@ -38,17 +86,67 @@ export type AppListenerEffect = ListenerEffect<
  AppDispatch
>;

-addImageUploadedListener();
-addInitialImageSelectedListener();
-addImageResultReceivedListener();
-addRequestedImageDeletionListener();
+// Image uploaded
+addImageUploadedFulfilledListener();
+addImageUploadedRejectedListener();
+
+addInitialImageSelectedListener();
+
+// Image deleted
+addRequestedImageDeletionListener();
+addImageDeletedPendingListener();
+addImageDeletedFulfilledListener();
+addImageDeletedRejectedListener();
+
+// Image metadata
+addImageMetadataReceivedFulfilledListener();
+addImageMetadataReceivedRejectedListener();
+
+// Image URLs
+addImageUrlsReceivedFulfilledListener();
+addImageUrlsReceivedRejectedListener();
+
+// User Invoked
addUserInvokedCanvasListener();
addUserInvokedNodesListener();
addUserInvokedTextToImageListener();
addUserInvokedImageToImageListener();
+addSessionReadyToInvokeListener();

+// Canvas actions
addCanvasSavedToGalleryListener();
addCanvasDownloadedAsImageListener();
addCanvasCopiedToClipboardListener();
addCanvasMergedListener();
+
+// socketio
+addGeneratorProgressListener();
+addGraphExecutionStateCompleteListener();
+addInvocationCompleteListener();
+addInvocationErrorListener();
+addInvocationStartedListener();
+addSocketConnectedListener();
+addSocketDisconnectedListener();
+addSocketSubscribedListener();
+addSocketUnsubscribedListener();
+
+// Session Created
+addSessionCreatedPendingListener();
+addSessionCreatedFulfilledListener();
+addSessionCreatedRejectedListener();
+
+// Session Invoked
+addSessionInvokedPendingListener();
+addSessionInvokedFulfilledListener();
+addSessionInvokedRejectedListener();
+
+// Session Canceled
+addSessionCanceledPendingListener();
+addSessionCanceledFulfilledListener();
+addSessionCanceledRejectedListener();
+
+// Gallery pages
+addReceivedResultImagesPageFulfilledListener();
+addReceivedResultImagesPageRejectedListener();
+addReceivedUploadImagesPageFulfilledListener();
+addReceivedUploadImagesPageRejectedListener();
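
The index above follows one convention throughout: each listener module exports an `addXListener()` function that wires itself up through a shared, typed `startAppListening`, and the index imports and calls them in named groups. A stripped-down, self-contained version of that wiring with @reduxjs/toolkit (the action here is hypothetical):

```ts
import { createAction, createListenerMiddleware } from '@reduxjs/toolkit';

export const listenerMiddleware = createListenerMiddleware();

// In the real app this is cast to TypedStartListening<RootState, AppDispatch>.
const startAppListening = listenerMiddleware.startListening;

const somethingHappened = createAction<string>('app/somethingHappened');

export const addSomethingHappenedListener = () => {
  startAppListening({
    actionCreator: somethingHappened,
    effect: (action, { dispatch, getState }) => {
      console.log('payload:', action.payload);
    },
  });
};

// Index-style registration: importing the module does nothing until called.
addSomethingHappenedListener();
```
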
@@ -52,7 +52,6 @@ export const addCanvasMergedListener = () => {

      dispatch(
        imageUploaded({
-          imageType: 'intermediates',
          formData: {
            file: new File([blob], filename, { type: 'image/png' }),
          },
@@ -65,7 +64,7 @@ export const addCanvasMergedListener = () => {
          action.meta.arg.formData.file.name === filename
      );

-      const mergedCanvasImage = payload.response;
+      const mergedCanvasImage = payload;

      dispatch(
        setMergedCanvas({
@@ -29,7 +29,6 @@ export const addCanvasSavedToGalleryListener = () => {

      dispatch(
        imageUploaded({
-          imageType: 'results',
          formData: {
            file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }),
          },
@@ -4,9 +4,14 @@ import { imageDeleted } from 'services/thunks/image';
import { log } from 'app/logging/useLogger';
import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice';
+import { uploadsAdapter } from 'features/gallery/store/uploadsSlice';
+import { resultsAdapter } from 'features/gallery/store/resultsSlice';

const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });

+/**
+ * Called when the user requests an image deletion
+ */
export const addRequestedImageDeletionListener = () => {
  startAppListening({
    actionCreator: requestedImageDeletion,
@@ -19,11 +24,6 @@ export const addRequestedImageDeletionListener = () => {

      const { image_name, image_type } = image;

-      if (image_type !== 'uploads' && image_type !== 'results') {
-        moduleLog.warn({ data: image }, `Invalid image type ${image_type}`);
-        return;
-      }
-
      const selectedImageName = getState().gallery.selectedImage?.image_name;

      if (selectedImageName === image_name) {
@@ -57,3 +57,49 @@ export const addRequestedImageDeletionListener = () => {
    },
  });
};
+
+/**
+ * Called when the actual delete request is sent to the server
+ */
+export const addImageDeletedPendingListener = () => {
+  startAppListening({
+    actionCreator: imageDeleted.pending,
+    effect: (action, { dispatch, getState }) => {
+      const { imageName, imageType } = action.meta.arg;
+      // Preemptively remove the image from the gallery
+      if (imageType === 'uploads') {
+        uploadsAdapter.removeOne(getState().uploads, imageName);
+      }
+      if (imageType === 'results') {
+        resultsAdapter.removeOne(getState().results, imageName);
+      }
+    },
+  });
+};
+
+/**
+ * Called on successful delete
+ */
+export const addImageDeletedFulfilledListener = () => {
+  startAppListening({
+    actionCreator: imageDeleted.fulfilled,
+    effect: (action, { dispatch, getState }) => {
+      moduleLog.debug({ data: { image: action.meta.arg } }, 'Image deleted');
+    },
+  });
+};
+
+/**
+ * Called on failed delete
+ */
+export const addImageDeletedRejectedListener = () => {
+  startAppListening({
+    actionCreator: imageDeleted.rejected,
+    effect: (action, { dispatch, getState }) => {
+      moduleLog.debug(
+        { data: { image: action.meta.arg } },
+        'Unable to delete image'
+      );
+    },
+  });
+};
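
The three new listeners map one-to-one onto `createAsyncThunk`'s lifecycle actions: `pending` removes the image optimistically, while `fulfilled` and `rejected` only log. A self-contained sketch of that trio against a hypothetical thunk (note the deletion key travels in `action.meta.arg`, exactly as above):

```ts
import { createAsyncThunk, createListenerMiddleware } from '@reduxjs/toolkit';

const imageDeleted = createAsyncThunk(
  'api/imageDeleted',
  async (arg: { imageName: string }) => {
    // ...issue the DELETE request here...
    return arg.imageName;
  }
);

const middleware = createListenerMiddleware();

middleware.startListening({
  actionCreator: imageDeleted.pending,
  effect: (action) => {
    // Optimistic removal: the thunk argument is available before any response.
    console.log('removing', action.meta.arg.imageName);
  },
});

middleware.startListening({
  actionCreator: imageDeleted.rejected,
  effect: (action) => {
    // A real app would restore the optimistically removed entry here.
    console.warn('delete failed for', action.meta.arg.imageName);
  },
});
```
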
@@ -0,0 +1,43 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { imageMetadataReceived } from 'services/thunks/image';
+import {
+  ResultsImageDTO,
+  resultUpserted,
+} from 'features/gallery/store/resultsSlice';
+import {
+  UploadsImageDTO,
+  uploadUpserted,
+} from 'features/gallery/store/uploadsSlice';
+
+const moduleLog = log.child({ namespace: 'image' });
+
+export const addImageMetadataReceivedFulfilledListener = () => {
+  startAppListening({
+    actionCreator: imageMetadataReceived.fulfilled,
+    effect: (action, { getState, dispatch }) => {
+      const image = action.payload;
+      moduleLog.debug({ data: { image } }, 'Image metadata received');
+
+      if (image.image_type === 'results') {
+        dispatch(resultUpserted(action.payload as ResultsImageDTO));
+      }
+
+      if (image.image_type === 'uploads') {
+        dispatch(uploadUpserted(action.payload as UploadsImageDTO));
+      }
+    },
+  });
+};
+
+export const addImageMetadataReceivedRejectedListener = () => {
+  startAppListening({
+    actionCreator: imageMetadataReceived.rejected,
+    effect: (action, { getState, dispatch }) => {
+      moduleLog.debug(
+        { data: { image: action.meta.arg } },
+        'Problem receiving image metadata'
+      );
+    },
+  });
+};
@@ -1,25 +1,31 @@
import { startAppListening } from '..';
-import { uploadAdded } from 'features/gallery/store/uploadsSlice';
+import { uploadUpserted } from 'features/gallery/store/uploadsSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imageUploaded } from 'services/thunks/image';
import { addToast } from 'features/system/store/systemSlice';
import { initialImageSelected } from 'features/parameters/store/actions';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
-import { resultAdded } from 'features/gallery/store/resultsSlice';
+import { resultUpserted } from 'features/gallery/store/resultsSlice';
import { isResultsImageDTO, isUploadsImageDTO } from 'services/types/guards';
+import { log } from 'app/logging/useLogger';
+
+const moduleLog = log.child({ namespace: 'image' });

-export const addImageUploadedListener = () => {
+export const addImageUploadedFulfilledListener = () => {
  startAppListening({
    predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
      imageUploaded.fulfilled.match(action) &&
-      action.payload.response.image_type !== 'intermediates',
+      action.payload.is_intermediate === false,
    effect: (action, { dispatch, getState }) => {
-      const { response: image } = action.payload;
+      const image = action.payload;
+
+      moduleLog.debug({ arg: '<Blob>', image }, 'Image uploaded');

      const state = getState();

+      // Handle uploads
      if (isUploadsImageDTO(image)) {
-        dispatch(uploadAdded(image));
+        dispatch(uploadUpserted(image));

        dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));

@@ -36,9 +42,26 @@ export const addImageUploadedListener = () => {
        }
      }

+      // Handle results
+      // TODO: Can this ever happen? I don't think so...
      if (isResultsImageDTO(image)) {
-        dispatch(resultAdded(image));
+        dispatch(resultUpserted(image));
      }
    },
  });
};
+
+export const addImageUploadedRejectedListener = () => {
+  startAppListening({
+    actionCreator: imageUploaded.rejected,
+    effect: (action, { dispatch }) => {
+      dispatch(
+        addToast({
+          title: 'Image Upload Failed',
+          description: action.error.message,
+          status: 'error',
+        })
+      );
+    },
+  });
+};
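
The fulfilled listener above uses a `predicate` instead of `actionCreator` so it can match the thunk's fulfilled action and filter out intermediate images in one step; writing the predicate as a type guard means `action.payload` is fully typed inside `effect`. A compact, self-contained version of the pattern (thunk and payload shape are hypothetical):

```ts
import { createAsyncThunk, createListenerMiddleware } from '@reduxjs/toolkit';

const imageUploaded = createAsyncThunk(
  'api/imageUploaded',
  async (file: File) => ({ image_name: file.name, is_intermediate: false })
);

const middleware = createListenerMiddleware();

middleware.startListening({
  // The type guard narrows `action` for the effect below.
  predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
    imageUploaded.fulfilled.match(action) &&
    action.payload.is_intermediate === false,
  effect: (action) => {
    // `action.payload` is known to be the fulfilled payload here.
    console.log('uploaded:', action.payload.image_name);
  },
});
```
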
@@ -0,0 +1,51 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { imageUrlsReceived } from 'services/thunks/image';
+import { resultsAdapter } from 'features/gallery/store/resultsSlice';
+import { uploadsAdapter } from 'features/gallery/store/uploadsSlice';
+
+const moduleLog = log.child({ namespace: 'image' });
+
+export const addImageUrlsReceivedFulfilledListener = () => {
+  startAppListening({
+    actionCreator: imageUrlsReceived.fulfilled,
+    effect: (action, { getState, dispatch }) => {
+      const image = action.payload;
+      moduleLog.debug({ data: { image } }, 'Image URLs received');
+
+      const { image_type, image_name, image_url, thumbnail_url } = image;
+
+      if (image_type === 'results') {
+        resultsAdapter.updateOne(getState().results, {
+          id: image_name,
+          changes: {
+            image_url,
+            thumbnail_url,
+          },
+        });
+      }
+
+      if (image_type === 'uploads') {
+        uploadsAdapter.updateOne(getState().uploads, {
+          id: image_name,
+          changes: {
+            image_url,
+            thumbnail_url,
+          },
+        });
+      }
+    },
+  });
+};
+
+export const addImageUrlsReceivedRejectedListener = () => {
+  startAppListening({
+    actionCreator: imageUrlsReceived.rejected,
+    effect: (action, { getState, dispatch }) => {
+      moduleLog.debug(
+        { data: { image: action.meta.arg } },
+        'Problem getting image URLs'
+      );
+    },
+  });
+};
@@ -1,62 +0,0 @@
-import { invocationComplete } from 'services/events/actions';
-import { isImageOutput } from 'services/types/guards';
-import {
-  imageMetadataReceived,
-  imageUrlsReceived,
-} from 'services/thunks/image';
-import { startAppListening } from '..';
-import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
-
-const nodeDenylist = ['dataURL_image'];
-
-export const addImageResultReceivedListener = () => {
-  startAppListening({
-    predicate: (action) => {
-      if (
-        invocationComplete.match(action) &&
-        isImageOutput(action.payload.data.result)
-      ) {
-        return true;
-      }
-      return false;
-    },
-    effect: async (action, { getState, dispatch, take }) => {
-      if (!invocationComplete.match(action)) {
-        return;
-      }
-
-      const { data } = action.payload;
-      const { result, node, graph_execution_state_id } = data;
-
-      if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
-        const { image_name, image_type } = result.image;
-
-        dispatch(
-          imageUrlsReceived({ imageName: image_name, imageType: image_type })
-        );
-
-        dispatch(
-          imageMetadataReceived({
-            imageName: image_name,
-            imageType: image_type,
-          })
-        );
-
-        // Handle canvas image
-        if (
-          graph_execution_state_id ===
-          getState().canvas.layerState.stagingArea.sessionId
-        ) {
-          const [{ payload: image }] = await take(
-            (
-              action
-            ): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
-              imageMetadataReceived.fulfilled.match(action) &&
-              action.payload.image_name === image_name
-          );
-          dispatch(addImageToStagingArea(image));
-        }
-      }
-    },
-  });
-};
@@ -0,0 +1,33 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { receivedResultImagesPage } from 'services/thunks/gallery';
+import { serializeError } from 'serialize-error';
+
+const moduleLog = log.child({ namespace: 'gallery' });
+
+export const addReceivedResultImagesPageFulfilledListener = () => {
+  startAppListening({
+    actionCreator: receivedResultImagesPage.fulfilled,
+    effect: (action, { getState, dispatch }) => {
+      const page = action.payload;
+      moduleLog.debug(
+        { data: { page } },
+        `Received ${page.items.length} results`
+      );
+    },
+  });
+};
+
+export const addReceivedResultImagesPageRejectedListener = () => {
+  startAppListening({
+    actionCreator: receivedResultImagesPage.rejected,
+    effect: (action, { getState, dispatch }) => {
+      if (action.payload) {
+        moduleLog.debug(
+          { data: { error: serializeError(action.payload.error) } },
+          'Problem receiving results'
+        );
+      }
+    },
+  });
+};
@@ -0,0 +1,33 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { receivedUploadImagesPage } from 'services/thunks/gallery';
+import { serializeError } from 'serialize-error';
+
+const moduleLog = log.child({ namespace: 'gallery' });
+
+export const addReceivedUploadImagesPageFulfilledListener = () => {
+  startAppListening({
+    actionCreator: receivedUploadImagesPage.fulfilled,
+    effect: (action, { getState, dispatch }) => {
+      const page = action.payload;
+      moduleLog.debug(
+        { data: { page } },
+        `Received ${page.items.length} uploads`
+      );
+    },
+  });
+};
+
+export const addReceivedUploadImagesPageRejectedListener = () => {
+  startAppListening({
+    actionCreator: receivedUploadImagesPage.rejected,
+    effect: (action, { getState, dispatch }) => {
+      if (action.payload) {
+        moduleLog.debug(
+          { data: { error: serializeError(action.payload.error) } },
+          'Problem receiving uploads'
+        );
+      }
+    },
+  });
+};
@@ -0,0 +1,48 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { sessionCanceled } from 'services/thunks/session';
+import { serializeError } from 'serialize-error';
+
+const moduleLog = log.child({ namespace: 'session' });
+
+export const addSessionCanceledPendingListener = () => {
+  startAppListening({
+    actionCreator: sessionCanceled.pending,
+    effect: (action, { getState, dispatch }) => {
+      //
+    },
+  });
+};
+
+export const addSessionCanceledFulfilledListener = () => {
+  startAppListening({
+    actionCreator: sessionCanceled.fulfilled,
+    effect: (action, { getState, dispatch }) => {
+      const { sessionId } = action.meta.arg;
+      moduleLog.debug(
+        { data: { sessionId } },
+        `Session canceled (${sessionId})`
+      );
+    },
+  });
+};
+
+export const addSessionCanceledRejectedListener = () => {
+  startAppListening({
+    actionCreator: sessionCanceled.rejected,
+    effect: (action, { getState, dispatch }) => {
+      if (action.payload) {
+        const { arg, error } = action.payload;
+        moduleLog.error(
+          {
+            data: {
+              arg,
+              error: serializeError(error),
+            },
+          },
+          `Problem canceling session`
+        );
+      }
+    },
+  });
+};
@@ -0,0 +1,45 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { sessionCreated } from 'services/thunks/session';
+import { serializeError } from 'serialize-error';
+
+const moduleLog = log.child({ namespace: 'session' });
+
+export const addSessionCreatedPendingListener = () => {
+  startAppListening({
+    actionCreator: sessionCreated.pending,
+    effect: (action, { getState, dispatch }) => {
+      //
+    },
+  });
+};
+
+export const addSessionCreatedFulfilledListener = () => {
+  startAppListening({
+    actionCreator: sessionCreated.fulfilled,
+    effect: (action, { getState, dispatch }) => {
+      const session = action.payload;
+      moduleLog.debug({ data: { session } }, `Session created (${session.id})`);
+    },
+  });
+};
+
+export const addSessionCreatedRejectedListener = () => {
+  startAppListening({
+    actionCreator: sessionCreated.rejected,
+    effect: (action, { getState, dispatch }) => {
+      if (action.payload) {
+        const { arg, error } = action.payload;
+        moduleLog.error(
+          {
+            data: {
+              arg,
+              error: serializeError(error),
+            },
+          },
+          `Problem creating session`
+        );
+      }
+    },
+  });
+};
@@ -0,0 +1,48 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { sessionInvoked } from 'services/thunks/session';
+import { serializeError } from 'serialize-error';
+
+const moduleLog = log.child({ namespace: 'session' });
+
+export const addSessionInvokedPendingListener = () => {
+  startAppListening({
+    actionCreator: sessionInvoked.pending,
+    effect: (action, { getState, dispatch }) => {
+      //
+    },
+  });
+};
+
+export const addSessionInvokedFulfilledListener = () => {
+  startAppListening({
+    actionCreator: sessionInvoked.fulfilled,
+    effect: (action, { getState, dispatch }) => {
+      const { sessionId } = action.meta.arg;
+      moduleLog.debug(
+        { data: { sessionId } },
+        `Session invoked (${sessionId})`
+      );
+    },
+  });
+};
+
+export const addSessionInvokedRejectedListener = () => {
+  startAppListening({
+    actionCreator: sessionInvoked.rejected,
+    effect: (action, { getState, dispatch }) => {
+      if (action.payload) {
+        const { arg, error } = action.payload;
+        moduleLog.error(
+          {
+            data: {
+              arg,
+              error: serializeError(error),
+            },
+          },
+          `Problem invoking session`
+        );
+      }
+    },
+  });
+};
@@ -0,0 +1,22 @@
+import { startAppListening } from '..';
+import { sessionInvoked } from 'services/thunks/session';
+import { log } from 'app/logging/useLogger';
+import { sessionReadyToInvoke } from 'features/system/store/actions';
+
+const moduleLog = log.child({ namespace: 'session' });
+
+export const addSessionReadyToInvokeListener = () => {
+  startAppListening({
+    actionCreator: sessionReadyToInvoke,
+    effect: (action, { getState, dispatch }) => {
+      const { sessionId } = getState().system;
+      if (sessionId) {
+        moduleLog.debug(
+          { sessionId },
+          `Session ready to invoke (${sessionId})`
+        );
+        dispatch(sessionInvoked({ sessionId }));
+      }
+    },
+  });
+};
@@ -0,0 +1,28 @@
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import { generatorProgress } from 'services/events/actions';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addGeneratorProgressListener = () => {
+  startAppListening({
+    actionCreator: generatorProgress,
+    effect: (action, { dispatch, getState }) => {
+      if (
+        getState().system.canceledSession ===
+        action.payload.data.graph_execution_state_id
+      ) {
+        moduleLog.trace(
+          action.payload,
+          'Ignored generator progress for canceled session'
+        );
+        return;
+      }
+
+      moduleLog.trace(
+        action.payload,
+        `Generator progress (${action.payload.data.node.type})`
+      );
+    },
+  });
+};
@@ -0,0 +1,17 @@
+import { log } from 'app/logging/useLogger';
+import { graphExecutionStateComplete } from 'services/events/actions';
+import { startAppListening } from '../..';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addGraphExecutionStateCompleteListener = () => {
+  startAppListening({
+    actionCreator: graphExecutionStateComplete,
+    effect: (action, { dispatch, getState }) => {
+      moduleLog.debug(
+        action.payload,
+        `Session invocation complete (${action.payload.data.graph_execution_state_id})`
+      );
+    },
+  });
+};
@@ -0,0 +1,74 @@
+import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import { invocationComplete } from 'services/events/actions';
+import { imageMetadataReceived } from 'services/thunks/image';
+import { sessionCanceled } from 'services/thunks/session';
+import { isImageOutput } from 'services/types/guards';
+import { progressImageSet } from 'features/system/store/systemSlice';
+import { imageSelected } from 'features/gallery/store/gallerySlice';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+const nodeDenylist = ['dataURL_image'];
+
+export const addInvocationCompleteListener = () => {
+  startAppListening({
+    actionCreator: invocationComplete,
+    effect: async (action, { dispatch, getState, take }) => {
+      moduleLog.debug(
+        action.payload,
+        `Invocation complete (${action.payload.data.node.type})`
+      );
+
+      const sessionId = action.payload.data.graph_execution_state_id;
+
+      const { cancelType, isCancelScheduled } = getState().system;
+
+      // Handle scheduled cancelation
+      if (cancelType === 'scheduled' && isCancelScheduled) {
+        dispatch(sessionCanceled({ sessionId }));
+      }
+
+      const { data } = action.payload;
+      const { result, node, graph_execution_state_id } = data;
+
+      // This complete event has an associated image output
+      if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
+        const { image_name, image_type } = result.image;
+
+        // Get its metadata
+        dispatch(
+          imageMetadataReceived({
+            imageName: image_name,
+            imageType: image_type,
+          })
+        );
+
+        const [{ payload: imageDTO }] = await take(
+          imageMetadataReceived.fulfilled.match
+        );
+
+        if (getState().gallery.shouldAutoSwitchToNewImages) {
+          dispatch(imageSelected(imageDTO));
+        }
+
+        // Handle canvas image
+        if (
+          graph_execution_state_id ===
+          getState().canvas.layerState.stagingArea.sessionId
+        ) {
+          const [{ payload: image }] = await take(
+            (
+              action
+            ): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
+              imageMetadataReceived.fulfilled.match(action) &&
+              action.payload.image_name === image_name
+          );
+          dispatch(addImageToStagingArea(image));
+        }
+
+        dispatch(progressImageSet(null));
+      }
+    },
+  });
+};
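
The effect above leans on the listener API's `take`: it dispatches `imageMetadataReceived` and then suspends until the matching fulfilled action flows through the store, so the resulting DTO can be used in the same effect without a second listener. A minimal, self-contained sketch of that dispatch-then-take handshake (thunk and trigger action are hypothetical):

```ts
import {
  createAction,
  createAsyncThunk,
  createListenerMiddleware,
} from '@reduxjs/toolkit';

const fetchMetadata = createAsyncThunk(
  'api/fetchMetadata',
  async (imageName: string) => ({ imageName })
);

const imageReady = createAction<string>('app/imageReady');

const middleware = createListenerMiddleware();

middleware.startListening({
  actionCreator: imageReady,
  effect: async (action, { dispatch, take }) => {
    dispatch(fetchMetadata(action.payload));
    // Suspend until the thunk's fulfilled action arrives, then use its payload.
    const [fulfilled] = await take(fetchMetadata.fulfilled.match);
    console.log('metadata for', fulfilled.payload.imageName);
  },
});
```
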
@@ -0,0 +1,17 @@
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import { invocationError } from 'services/events/actions';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addInvocationErrorListener = () => {
+  startAppListening({
+    actionCreator: invocationError,
+    effect: (action, { dispatch, getState }) => {
+      moduleLog.error(
+        action.payload,
+        `Invocation error (${action.payload.data.node.type})`
+      );
+    },
+  });
+};
@@ -0,0 +1,28 @@
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import { invocationStarted } from 'services/events/actions';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addInvocationStartedListener = () => {
+  startAppListening({
+    actionCreator: invocationStarted,
+    effect: (action, { dispatch, getState }) => {
+      if (
+        getState().system.canceledSession ===
+        action.payload.data.graph_execution_state_id
+      ) {
+        moduleLog.trace(
+          action.payload,
+          'Ignored invocation started for canceled session'
+        );
+        return;
+      }
+
+      moduleLog.debug(
+        action.payload,
+        `Invocation started (${action.payload.data.node.type})`
+      );
+    },
+  });
+};
@@ -0,0 +1,43 @@
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import { socketConnected } from 'services/events/actions';
+import {
+  receivedResultImagesPage,
+  receivedUploadImagesPage,
+} from 'services/thunks/gallery';
+import { receivedModels } from 'services/thunks/model';
+import { receivedOpenAPISchema } from 'services/thunks/schema';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addSocketConnectedListener = () => {
+  startAppListening({
+    actionCreator: socketConnected,
+    effect: (action, { dispatch, getState }) => {
+      const { timestamp } = action.payload;
+
+      moduleLog.debug({ timestamp }, 'Connected');
+
+      const { results, uploads, models, nodes, config } = getState();
+
+      const { disabledTabs } = config;
+
+      // These thunks need to be dispatched in middleware; cannot handle in a reducer
+      if (!results.ids.length) {
+        dispatch(receivedResultImagesPage());
+      }
+
+      if (!uploads.ids.length) {
+        dispatch(receivedUploadImagesPage());
+      }
+
+      if (!models.ids.length) {
+        dispatch(receivedModels());
+      }
+
+      if (!nodes.schema && !disabledTabs.includes('nodes')) {
+        dispatch(receivedOpenAPISchema());
+      }
+    },
+  });
+};
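
On (re)connect, the listener above only fetches what is missing: each dispatch is guarded by an emptiness check on the corresponding slice, so a socket reconnect does not refetch data the client already holds. The same guard in a self-contained sketch (store shape and actions are hypothetical):

```ts
import { createAction, createListenerMiddleware } from '@reduxjs/toolkit';

interface AppState {
  results: { ids: string[] };
  uploads: { ids: string[] };
}

const socketConnected = createAction('socket/connected');
const loadResults = createAction('gallery/loadResults');
const loadUploads = createAction('gallery/loadUploads');

const middleware = createListenerMiddleware<AppState>();

middleware.startListening({
  actionCreator: socketConnected,
  effect: (action, { dispatch, getState }) => {
    const { results, uploads } = getState();
    // Only fetch what the client does not already hold, so a reconnect
    // is effectively idempotent.
    if (!results.ids.length) dispatch(loadResults());
    if (!uploads.ids.length) dispatch(loadUploads());
  },
});
```
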
@@ -0,0 +1,14 @@
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import { socketDisconnected } from 'services/events/actions';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addSocketDisconnectedListener = () => {
+  startAppListening({
+    actionCreator: socketDisconnected,
+    effect: (action, { dispatch, getState }) => {
+      moduleLog.debug(action.payload, 'Disconnected');
+    },
+  });
+};
@@ -0,0 +1,17 @@
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import { socketSubscribed } from 'services/events/actions';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addSocketSubscribedListener = () => {
+  startAppListening({
+    actionCreator: socketSubscribed,
+    effect: (action, { dispatch, getState }) => {
+      moduleLog.debug(
+        action.payload,
+        `Subscribed (${action.payload.sessionId})`
+      );
+    },
+  });
+};
@@ -0,0 +1,17 @@
+import { startAppListening } from '../..';
+import { log } from 'app/logging/useLogger';
+import { socketUnsubscribed } from 'services/events/actions';
+
+const moduleLog = log.child({ namespace: 'socketio' });
+
+export const addSocketUnsubscribedListener = () => {
+  startAppListening({
+    actionCreator: socketUnsubscribed,
+    effect: (action, { dispatch, getState }) => {
+      moduleLog.debug(
+        action.payload,
+        `Unsubscribed (${action.payload.sessionId})`
+      );
+    },
+  });
+};
@@ -1,9 +1,9 @@
import { startAppListening } from '..';
-import { sessionCreated, sessionInvoked } from 'services/thunks/session';
+import { sessionCreated } from 'services/thunks/session';
import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
import { log } from 'app/logging/useLogger';
import { canvasGraphBuilt } from 'features/nodes/store/actions';
-import { imageUploaded } from 'services/thunks/image';
+import { imageUpdated, imageUploaded } from 'services/thunks/image';
import { v4 as uuidv4 } from 'uuid';
import { Graph } from 'services/api';
import {
@@ -15,12 +15,22 @@ import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
+import { sessionReadyToInvoke } from 'features/system/store/actions';

const moduleLog = log.child({ namespace: 'invoke' });

/**
- * This listener is responsible for building the canvas graph and blobs when the user invokes the canvas.
- * It is also responsible for uploading the base and mask layers to the server.
+ * This listener is responsible for invoking the canvas. This involves a number of steps:
+ *
+ * 1. Generate image blobs from the canvas layers
+ * 2. Determine the generation mode from the layers (txt2img, img2img, inpaint)
+ * 3. Build the canvas graph
+ * 4. Create the session with the graph
+ * 5. Upload the init image if necessary
+ * 6. Upload the mask image if necessary
+ * 7. Update the init and mask images with the session ID
+ * 8. Initialize the staging area if not yet initialized
+ * 9. Dispatch the sessionReadyToInvoke action to invoke the session
 */
export const addUserInvokedCanvasListener = () => {
  startAppListening({
@@ -70,63 +80,7 @@ export const addUserInvokedCanvasListener = () => {

      const { rangeNode, iterateNode, baseNode, edges } = graphComponents;

-      // Upload the base layer, to be used as init image
-      const baseFilename = `${uuidv4()}.png`;
-
-      dispatch(
-        imageUploaded({
-          imageType: 'intermediates',
-          formData: {
-            file: new File([baseBlob], baseFilename, { type: 'image/png' }),
-          },
-        })
-      );
-
-      if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
-        const [{ payload: basePayload }] = await take(
-          (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
-            imageUploaded.fulfilled.match(action) &&
-            action.meta.arg.formData.file.name === baseFilename
-        );
-
-        const { image_name: baseName, image_type: baseType } =
-          basePayload.response;
-
-        baseNode.image = {
-          image_name: baseName,
-          image_type: baseType,
-        };
-      }
-
-      // Upload the mask layer image
-      const maskFilename = `${uuidv4()}.png`;
-
-      if (baseNode.type === 'inpaint') {
-        dispatch(
-          imageUploaded({
-            imageType: 'intermediates',
-            formData: {
-              file: new File([maskBlob], maskFilename, { type: 'image/png' }),
-            },
-          })
-        );
-
-        const [{ payload: maskPayload }] = await take(
-          (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
-            imageUploaded.fulfilled.match(action) &&
-            action.meta.arg.formData.file.name === maskFilename
-        );
-
-        const { image_name: maskName, image_type: maskType } =
-          maskPayload.response;
-
-        baseNode.mask = {
-          image_name: maskName,
-          image_type: maskType,
-        };
-      }
-
-      // Assemble!
+      // Assemble! Note that this graph *does not have the init or mask image set yet!*
      const nodes: Graph['nodes'] = {
        [rangeNode.id]: rangeNode,
        [iterateNode.id]: iterateNode,
@@ -136,15 +90,90 @@ export const addUserInvokedCanvasListener = () => {
      const graph = { nodes, edges };

      dispatch(canvasGraphBuilt(graph));
-      moduleLog({ data: graph }, 'Canvas graph built');
-
-      // Actually create the session
+      moduleLog.debug({ data: graph }, 'Canvas graph built');
+
+      // If we are generating img2img or inpaint, we need to upload the init images
+      if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
+        const baseFilename = `${uuidv4()}.png`;
+        dispatch(
+          imageUploaded({
+            formData: {
+              file: new File([baseBlob], baseFilename, { type: 'image/png' }),
+            },
+            isIntermediate: true,
+          })
+        );
+
+        // Wait for the image to be uploaded
+        const [{ payload: baseImageDTO }] = await take(
+          (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
+            imageUploaded.fulfilled.match(action) &&
+            action.meta.arg.formData.file.name === baseFilename
+        );
+
+        // Update the base node with the image name and type
+        baseNode.image = {
+          image_name: baseImageDTO.image_name,
+          image_type: baseImageDTO.image_type,
+        };
+      }
+
+      // For inpaint, we also need to upload the mask layer
+      if (baseNode.type === 'inpaint') {
+        const maskFilename = `${uuidv4()}.png`;
+        dispatch(
+          imageUploaded({
+            formData: {
+              file: new File([maskBlob], maskFilename, { type: 'image/png' }),
+            },
+            isIntermediate: true,
+          })
+        );
+
+        // Wait for the mask to be uploaded
+        const [{ payload: maskImageDTO }] = await take(
+          (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
+            imageUploaded.fulfilled.match(action) &&
+            action.meta.arg.formData.file.name === maskFilename
+        );
+
+        // Update the base node with the image name and type
+        baseNode.mask = {
+          image_name: maskImageDTO.image_name,
+          image_type: maskImageDTO.image_type,
+        };
+      }
+
+      // Create the session and wait for response
      dispatch(sessionCreated({ graph }));
+      const [sessionCreatedAction] = await take(sessionCreated.fulfilled.match);
+      const sessionId = sessionCreatedAction.payload.id;

-      // Wait for the session to be invoked (this is just the HTTP request to start processing)
-      const [{ meta }] = await take(sessionInvoked.fulfilled.match);
+      // Associate the init image with the session, now that we have the session ID
+      if (
+        (baseNode.type === 'img2img' || baseNode.type === 'inpaint') &&
+        baseNode.image
+      ) {
+        dispatch(
+          imageUpdated({
+            imageName: baseNode.image.image_name,
+            imageType: baseNode.image.image_type,
|
||||||
|
requestBody: { session_id: sessionId },
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
const { sessionId } = meta.arg;
|
// Associate the mask image with the session, now that we have the session ID
|
||||||
|
if (baseNode.type === 'inpaint' && baseNode.mask) {
|
||||||
|
dispatch(
|
||||||
|
imageUpdated({
|
||||||
|
imageName: baseNode.mask.image_name,
|
||||||
|
imageType: baseNode.mask.image_type,
|
||||||
|
requestBody: { session_id: sessionId },
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
||||||
dispatch(
|
dispatch(
|
||||||
@ -158,7 +187,11 @@ export const addUserInvokedCanvasListener = () => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Flag the session with the canvas session ID
|
||||||
dispatch(canvasSessionIdChanged(sessionId));
|
dispatch(canvasSessionIdChanged(sessionId));
|
||||||
|
|
||||||
|
// We are ready to invoke the session!
|
||||||
|
dispatch(sessionReadyToInvoke());
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
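For readers following the listener changes: the create-then-invoke handshake above relies on Redux Toolkit's listener middleware, whose `take` effect lets an async `effect` suspend until a matching action is dispatched. A minimal, self-contained sketch of the pattern — `userInvoked`, `sessionCreated`, and `sessionReadyToInvoke` here are hypothetical stand-ins for the app's real actions and thunks:

import {
  createAction,
  createAsyncThunk,
  createListenerMiddleware,
} from '@reduxjs/toolkit';

// Hypothetical stand-ins for the app's real actions/thunks.
const userInvoked = createAction<string>('app/userInvoked');
const sessionReadyToInvoke = createAction('system/sessionReadyToInvoke');
const sessionCreated = createAsyncThunk(
  'session/sessionCreated',
  async (arg: { graph: Record<string, unknown> }) => {
    // The real thunk POSTs the graph to the backend; fake the response here.
    return { id: 'fake-session-id' };
  }
);

const listenerMiddleware = createListenerMiddleware();
// Attach with: configureStore({ ..., middleware: (gDM) => gDM().prepend(listenerMiddleware.middleware) })

listenerMiddleware.startListening({
  predicate: (action): action is ReturnType<typeof userInvoked> =>
    userInvoked.match(action) && action.payload === 'txt2img',
  effect: async (_action, { dispatch, take }) => {
    // Kick off session creation...
    dispatch(sessionCreated({ graph: {} }));

    // ...then suspend until the thunk settles. `take` resolves with a
    // [action, currentState, previousState] tuple.
    const [createdAction] = await take(sessionCreated.fulfilled.match);
    const sessionId = createdAction.payload.id;

    // Only now is it safe to associate images with the session and invoke it.
    console.log(`session ${sessionId} ready`);
    dispatch(sessionReadyToInvoke());
  },
});

Waiting on `sessionCreated.fulfilled` rather than firing the invoke request immediately guarantees the session ID exists before any image is associated with it.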
@@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session';
 import { log } from 'app/logging/useLogger';
 import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
 import { userInvoked } from 'app/store/actions';
+import { sessionReadyToInvoke } from 'features/system/store/actions';
 
 const moduleLog = log.child({ namespace: 'invoke' });
 
@@ -11,14 +12,18 @@ export const addUserInvokedImageToImageListener = () => {
   startAppListening({
     predicate: (action): action is ReturnType<typeof userInvoked> =>
       userInvoked.match(action) && action.payload === 'img2img',
-    effect: (action, { getState, dispatch }) => {
+    effect: async (action, { getState, dispatch, take }) => {
       const state = getState();
 
       const graph = buildImageToImageGraph(state);
       dispatch(imageToImageGraphBuilt(graph));
-      moduleLog({ data: graph }, 'Image to Image graph built');
+      moduleLog.debug({ data: graph }, 'Image to Image graph built');
 
       dispatch(sessionCreated({ graph }));
 
+      await take(sessionCreated.fulfilled.match);
+
+      dispatch(sessionReadyToInvoke());
     },
   });
 };
@@ -4,6 +4,7 @@ import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGra
 import { log } from 'app/logging/useLogger';
 import { nodesGraphBuilt } from 'features/nodes/store/actions';
 import { userInvoked } from 'app/store/actions';
+import { sessionReadyToInvoke } from 'features/system/store/actions';
 
 const moduleLog = log.child({ namespace: 'invoke' });
 
@@ -11,14 +12,18 @@ export const addUserInvokedNodesListener = () => {
   startAppListening({
     predicate: (action): action is ReturnType<typeof userInvoked> =>
       userInvoked.match(action) && action.payload === 'nodes',
-    effect: (action, { getState, dispatch }) => {
+    effect: async (action, { getState, dispatch, take }) => {
       const state = getState();
 
       const graph = buildNodesGraph(state);
       dispatch(nodesGraphBuilt(graph));
-      moduleLog({ data: graph }, 'Nodes graph built');
+      moduleLog.debug({ data: graph }, 'Nodes graph built');
 
       dispatch(sessionCreated({ graph }));
 
+      await take(sessionCreated.fulfilled.match);
+
+      dispatch(sessionReadyToInvoke());
     },
   });
 };
@@ -4,6 +4,7 @@ import { sessionCreated } from 'services/thunks/session';
 import { log } from 'app/logging/useLogger';
 import { textToImageGraphBuilt } from 'features/nodes/store/actions';
 import { userInvoked } from 'app/store/actions';
+import { sessionReadyToInvoke } from 'features/system/store/actions';
 
 const moduleLog = log.child({ namespace: 'invoke' });
 
@@ -11,14 +12,20 @@ export const addUserInvokedTextToImageListener = () => {
   startAppListening({
     predicate: (action): action is ReturnType<typeof userInvoked> =>
       userInvoked.match(action) && action.payload === 'txt2img',
-    effect: (action, { getState, dispatch }) => {
+    effect: async (action, { getState, dispatch, take }) => {
       const state = getState();
 
       const graph = buildTextToImageGraph(state);
 
       dispatch(textToImageGraphBuilt(graph));
-      moduleLog({ data: graph }, 'Text to Image graph built');
+
+      moduleLog.debug({ data: graph }, 'Text to Image graph built');
 
       dispatch(sessionCreated({ graph }));
 
+      await take(sessionCreated.fulfilled.match);
+
+      dispatch(sessionReadyToInvoke());
     },
   });
 };
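Alongside the flow change, all three listeners switch from calling `moduleLog(...)` directly to `moduleLog.debug(...)`, which moves the per-graph dumps behind level filtering. A sketch using roarr directly — the app actually wraps it in `app/logging/useLogger`, so treat the import as an assumption:

import { Roarr as log } from 'roarr';

// Namespaced child logger, as the listeners above create.
const moduleLog = log.child({ namespace: 'invoke' });

// A leveled method like `.debug` tags the message with a severity, so a
// minimum-level setting (e.g. the ROARR_LOG environment toggle plus a level
// filter) can drop these large graph dumps from production logs.
moduleLog.debug({ data: { nodeCount: 3 } }, 'Text to Image graph built');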
@@ -16,6 +16,7 @@ import lightboxReducer from 'features/lightbox/store/lightboxSlice';
 import generationReducer from 'features/parameters/store/generationSlice';
 import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
 import systemReducer from 'features/system/store/systemSlice';
+// import sessionReducer from 'features/system/store/sessionSlice';
 import configReducer from 'features/system/store/configSlice';
 import uiReducer from 'features/ui/store/uiSlice';
 import hotkeysReducer from 'features/ui/store/hotkeysSlice';
@@ -46,6 +47,7 @@ const allReducers = {
   ui: uiReducer,
   uploads: uploadsReducer,
   hotkeys: hotkeysReducer,
+  // session: sessionReducer,
 };
 
 const rootReducer = combineReducers(allReducers);
@@ -68,7 +68,6 @@ const ImageUploader = (props: ImageUploaderProps) => {
     async (file: File) => {
       dispatch(
         imageUploaded({
-          imageType: 'uploads',
           formData: { file },
           activeTabName,
         })
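As with the canvas listener, the upload call no longer passes an `imageType`; the server decides how to categorize the image. For orientation, a sketch of what a thunk like `imageUploaded` might do with that `formData` — the endpoint, field name, and response shape are assumptions, since the real client is generated from the OpenAPI schema:

import { createAsyncThunk } from '@reduxjs/toolkit';

// Hypothetical upload thunk; the real one is generated from the OpenAPI schema.
export const imageUploaded = createAsyncThunk(
  'api/imageUploaded',
  async (arg: { formData: { file: File } }) => {
    const body = new FormData();
    body.append('file', arg.formData.file);

    // Assumed endpoint for illustration only.
    const res = await fetch('/api/v1/images/', { method: 'POST', body });
    if (!res.ok) {
      throw new Error(`upload failed: ${res.status}`);
    }
    // The response describes the stored image (an ImageDTO in this codebase).
    return (await res.json()) as { image_name: string; image_type: string };
  }
);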
@@ -1,5 +1,10 @@
 import { forEach, size } from 'lodash-es';
-import { ImageField, LatentsField, ConditioningField } from 'services/api';
+import {
+  ImageField,
+  LatentsField,
+  ConditioningField,
+  ControlField,
+} from 'services/api';
 
 const OBJECT_TYPESTRING = '[object Object]';
 const STRING_TYPESTRING = '[object String]';
@@ -98,6 +103,24 @@ const parseConditioningField = (
   };
 };
 
+const parseControlField = (controlField: unknown): ControlField | undefined => {
+  // Must be an object
+  if (!isObject(controlField)) {
+    return;
+  }
+
+  // A ControlField must have a `control`
+  if (!('control' in controlField)) {
+    return;
+  }
+  // console.log(typeof controlField.control);
+
+  // Build a valid ControlField
+  return {
+    control: controlField.control,
+  };
+};
+
 type NodeMetadata = {
   [key: string]:
     | string
@@ -105,7 +128,8 @@ type NodeMetadata = {
     | boolean
     | ImageField
     | LatentsField
-    | ConditioningField;
+    | ConditioningField
+    | ControlField;
 };
 
 type InvokeAIMetadata = {
@@ -131,7 +155,7 @@ export const parseNodeMetadata = (
       return;
     }
 
-    // the only valid object types are ImageField, LatentsField and ConditioningField
+    // the only valid object types are ImageField, LatentsField, ConditioningField, ControlField
    if (isObject(nodeItem)) {
      if ('image_name' in nodeItem || 'image_type' in nodeItem) {
        const imageField = parseImageField(nodeItem);
@@ -156,6 +180,14 @@ export const parseNodeMetadata = (
        }
        return;
      }
 
+      if ('control' in nodeItem) {
+        const controlField = parseControlField(nodeItem);
+        if (controlField) {
+          parsed[nodeKey] = controlField;
+        }
+        return;
+      }
    }
 
    // otherwise we accept any string, number or boolean
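`parseControlField` follows the file's defensive pattern for `unknown` JSON: narrow with `isObject`/`in` checks, then rebuild a fresh object instead of casting. A standalone sketch of the same idea with sample inputs — the local `ControlField` type here is a simplification, not the generated API type:

type ControlField = { control: unknown };

const isObject = (x: unknown): x is Record<string, unknown> =>
  typeof x === 'object' && x !== null;

// Narrow unknown metadata to a ControlField, returning undefined on any
// mismatch instead of throwing or blindly casting.
const parseControlField = (value: unknown): ControlField | undefined => {
  if (!isObject(value) || !('control' in value)) {
    return undefined;
  }
  // Rebuild the object so no extra properties leak through.
  return { control: value.control };
};

console.log(parseControlField({ control: { image: 'a.png' } })); // { control: { image: 'a.png' } }
console.log(parseControlField('not an object')); // undefined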
@@ -109,8 +109,9 @@ const currentImageButtonsSelector = createSelector(
       isLightboxOpen,
       shouldHidePreview,
       image: selectedImage,
-      seed: selectedImage?.metadata?.invokeai?.node?.seed,
-      prompt: selectedImage?.metadata?.invokeai?.node?.prompt,
+      seed: selectedImage?.metadata?.seed,
+      prompt: selectedImage?.metadata?.positive_conditioning,
+      negativePrompt: selectedImage?.metadata?.negative_conditioning,
     };
   },
   {
@@ -245,13 +246,16 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
   );
 
   const handleUseSeed = useCallback(() => {
-    recallSeed(image?.metadata?.invokeai?.node?.seed);
+    recallSeed(image?.metadata?.seed);
   }, [image, recallSeed]);
 
   useHotkeys('s', handleUseSeed, [image]);
 
   const handleUsePrompt = useCallback(() => {
-    recallPrompt(image?.metadata?.invokeai?.node?.prompt);
+    recallPrompt(
+      image?.metadata?.positive_conditioning,
+      image?.metadata?.negative_conditioning
+    );
   }, [image, recallPrompt]);
 
   useHotkeys('p', handleUsePrompt, [image]);
@@ -454,7 +458,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
           {t('parameters.copyImageToLink')}
         </IAIButton>
 
-        <Link download={true} href={getUrl(image?.url ?? '')}>
+        <Link download={true} href={getUrl(image?.image_url ?? '')}>
           <IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
             {t('parameters.downloadImage')}
           </IAIButton>
@@ -500,7 +504,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
           icon={<FaQuoteRight />}
           tooltip={`${t('parameters.usePrompt')} (P)`}
           aria-label={`${t('parameters.usePrompt')} (P)`}
-          isDisabled={!image?.metadata?.invokeai?.node?.prompt}
+          isDisabled={!image?.metadata?.positive_conditioning}
           onClick={handleUsePrompt}
         />
 
@@ -508,7 +512,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
           icon={<FaSeedling />}
           tooltip={`${t('parameters.useSeed')} (S)`}
           aria-label={`${t('parameters.useSeed')} (S)`}
-          isDisabled={!image?.metadata?.invokeai?.node?.seed}
+          isDisabled={!image?.metadata?.seed}
           onClick={handleUseSeed}
         />
 
@@ -517,9 +521,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
           tooltip={`${t('parameters.useAll')} (A)`}
           aria-label={`${t('parameters.useAll')} (A)`}
           isDisabled={
-            !['txt2img', 'img2img', 'inpaint'].includes(
-              String(image?.metadata?.invokeai?.node?.type)
-            )
+            // not sure what this list should be
+            !['t2l', 'l2l', 'inpaint'].includes(String(image?.metadata?.type))
           }
           onClick={handleClickUseAllParameters}
         />
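These selector changes reflect the flattened metadata shape: fields that used to live at `image.metadata.invokeai.node.*` are now read from `image.metadata.*`. A small sketch of the new access pattern, with the DTO shape assumed from the fields used in this diff:

// Hypothetical shape of the flattened metadata after this change.
type ImageMetadata = {
  type?: string;
  seed?: number;
  positive_conditioning?: string;
  negative_conditioning?: string;
};
type ImageDTO = { image_url: string; metadata?: ImageMetadata };

const image: ImageDTO | undefined = {
  image_url: 'https://example.com/foo.png',
  metadata: { type: 't2l', seed: 42, positive_conditioning: 'a cat' },
};

// Optional chaining keeps every read safe whether or not metadata exists.
const seed = image?.metadata?.seed; // 42
const prompt = image?.metadata?.positive_conditioning; // 'a cat'
const canRecallAll = ['t2l', 'l2l', 'inpaint'].includes(
  String(image?.metadata?.type)
); // true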
@@ -155,7 +155,10 @@ const HoverableImage = memo((props: HoverableImageProps) => {
 
   // Recall parameters handlers
   const handleRecallPrompt = useCallback(() => {
-    recallPrompt(image.metadata?.positive_conditioning);
+    recallPrompt(
+      image.metadata?.positive_conditioning,
+      image.metadata?.negative_conditioning
+    );
   }, [image, recallPrompt]);
 
   const handleRecallSeed = useCallback(() => {
@@ -248,7 +251,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
           icon={<IoArrowUndoCircleOutline />}
           onClickCapture={handleUseAllParameters}
           isDisabled={
-            !['txt2img', 'img2img', 'inpaint'].includes(
+            // what should these be
+            !['t2l', 'l2l', 'inpaint'].includes(
               String(image?.metadata?.type)
             )
           }
@@ -8,29 +8,20 @@ import {
   Text,
   Tooltip,
 } from '@chakra-ui/react';
-import * as InvokeAI from 'app/types/invokeai';
 import { useAppDispatch } from 'app/store/storeHooks';
 import { useGetUrl } from 'common/util/getUrl';
 import promptToString from 'common/util/promptToString';
-import { seedWeightsToString } from 'common/util/seedWeightPairs';
-import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
 import {
   setCfgScale,
   setHeight,
   setImg2imgStrength,
   setNegativePrompt,
-  setPerlin,
   setPositivePrompt,
   setScheduler,
-  setSeamless,
   setSeed,
-  setSeedWeights,
-  setShouldFitToWidthHeight,
   setSteps,
-  setThreshold,
   setWidth,
 } from 'features/parameters/store/generationSlice';
-import { setHiresFix } from 'features/parameters/store/postprocessingSlice';
 import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
 import { memo } from 'react';
 import { useHotkeys } from 'react-hotkeys-hook';
@@ -39,7 +30,6 @@ import { FaCopy } from 'react-icons/fa';
 import { IoArrowUndoCircleOutline } from 'react-icons/io5';
 import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
 import { ImageDTO } from 'services/api';
-import { filter } from 'lodash-es';
 import { Scheduler } from 'app/constants';
 
 type MetadataItemProps = {
@@ -126,8 +116,6 @@ const memoEqualityCheck = (
 const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
   const dispatch = useAppDispatch();
 
-  const setBothPrompts = useSetBothPrompts();
-
   useHotkeys('esc', () => {
     dispatch(setShouldShowImageDetails(false));
   });
@@ -33,7 +33,7 @@ export const nextPrevImageButtonsSelector = createSelector(
     }
 
     const currentImageIndex = state[currentCategory].ids.findIndex(
-      (i) => i === selectedImage.name
+      (i) => i === selectedImage.image_name
     );
 
     const nextImageIndex = clamp(
@@ -1,14 +1,13 @@
-import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
+import {
+  PayloadAction,
+  createEntityAdapter,
+  createSlice,
+} from '@reduxjs/toolkit';
 import { RootState } from 'app/store/store';
 import {
   receivedResultImagesPage,
   IMAGES_PER_PAGE,
 } from 'services/thunks/gallery';
-import {
-  imageDeleted,
-  imageMetadataReceived,
-  imageUrlsReceived,
-} from 'services/thunks/image';
 import { ImageDTO } from 'services/api';
 import { dateComparator } from 'common/util/dateComparator';
 
@@ -26,6 +25,7 @@ type AdditionalResultsState = {
   pages: number;
   isLoading: boolean;
   nextPage: number;
+  upsertedImageCount: number;
 };
 
 export const initialResultsState =
@@ -34,6 +34,7 @@ export const initialResultsState =
     pages: 0,
     isLoading: false,
     nextPage: 0,
+    upsertedImageCount: 0,
   });
 
 export type ResultsState = typeof initialResultsState;
@@ -42,7 +43,10 @@ const resultsSlice = createSlice({
   name: 'results',
   initialState: initialResultsState,
   reducers: {
-    resultAdded: resultsAdapter.upsertOne,
+    resultUpserted: (state, action: PayloadAction<ResultsImageDTO>) => {
+      resultsAdapter.upsertOne(state, action.payload);
+      state.upsertedImageCount += 1;
+    },
   },
   extraReducers: (builder) => {
     /**
@@ -68,47 +72,6 @@ const resultsSlice = createSlice({
       state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
       state.isLoading = false;
     });
-
-    /**
-     * Image Metadata Received - FULFILLED
-     */
-    builder.addCase(imageMetadataReceived.fulfilled, (state, action) => {
-      const { image_type } = action.payload;
-
-      if (image_type === 'results') {
-        resultsAdapter.upsertOne(state, action.payload as ResultsImageDTO);
-      }
-    });
-
-    /**
-     * Image URLs Received - FULFILLED
-     */
-    builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
-      const { image_name, image_type, image_url, thumbnail_url } =
-        action.payload;
-
-      if (image_type === 'results') {
-        resultsAdapter.updateOne(state, {
-          id: image_name,
-          changes: {
-            image_url: image_url,
-            thumbnail_url: thumbnail_url,
-          },
-        });
-      }
-    });
-
-    /**
-     * Delete Image - PENDING
-     * Pre-emptively remove the image from the gallery
-     */
-    builder.addCase(imageDeleted.pending, (state, action) => {
-      const { imageType, imageName } = action.meta.arg;
-
-      if (imageType === 'results') {
-        resultsAdapter.removeOne(state, imageName);
-      }
-    });
   },
 });
 
@@ -120,6 +83,6 @@ export const {
   selectTotal: selectResultsTotal,
 } = resultsAdapter.getSelectors<RootState>((state) => state.results);
 
-export const { resultAdded } = resultsSlice.actions;
+export const { resultUpserted } = resultsSlice.actions;
 
 export default resultsSlice.reducer;
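`resultUpserted` wraps the entity adapter's `upsertOne` so the slice can also count upserts. A minimal sketch of the same construction, with a simplified image record standing in for `ResultsImageDTO`:

import {
  PayloadAction,
  createEntityAdapter,
  createSlice,
} from '@reduxjs/toolkit';

// Hypothetical minimal image record keyed by image_name.
type ImageRecord = { image_name: string; image_url: string };

const adapter = createEntityAdapter<ImageRecord>({
  selectId: (image) => image.image_name,
});

const slice = createSlice({
  name: 'results',
  initialState: adapter.getInitialState({ upsertedImageCount: 0 }),
  reducers: {
    // upsertOne inserts or shallow-merges by id; the counter tracks how many
    // upserts have happened, which a gallery can use to decide when to refetch.
    resultUpserted: (state, action: PayloadAction<ImageRecord>) => {
      adapter.upsertOne(state, action.payload);
      state.upsertedImageCount += 1;
    },
  },
});

export const { resultUpserted } = slice.actions;
export default slice.reducer;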
@@ -1,11 +1,14 @@
-import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
+import {
+  PayloadAction,
+  createEntityAdapter,
+  createSlice,
+} from '@reduxjs/toolkit';
+
 import { RootState } from 'app/store/store';
 import {
   receivedUploadImagesPage,
   IMAGES_PER_PAGE,
 } from 'services/thunks/gallery';
-import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
 import { ImageDTO } from 'services/api';
 import { dateComparator } from 'common/util/dateComparator';
 
@@ -23,6 +26,7 @@ type AdditionalUploadsState = {
   pages: number;
   isLoading: boolean;
   nextPage: number;
+  upsertedImageCount: number;
 };
 
 export const initialUploadsState =
@@ -31,6 +35,7 @@ export const initialUploadsState =
     pages: 0,
     nextPage: 0,
     isLoading: false,
+    upsertedImageCount: 0,
   });
 
 export type UploadsState = typeof initialUploadsState;
@@ -39,7 +44,10 @@ const uploadsSlice = createSlice({
   name: 'uploads',
   initialState: initialUploadsState,
   reducers: {
-    uploadAdded: uploadsAdapter.upsertOne,
+    uploadUpserted: (state, action: PayloadAction<UploadsImageDTO>) => {
+      uploadsAdapter.upsertOne(state, action.payload);
+      state.upsertedImageCount += 1;
+    },
   },
   extraReducers: (builder) => {
     /**
@@ -65,36 +73,6 @@ const uploadsSlice = createSlice({
       state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
       state.isLoading = false;
     });
-
-    /**
-     * Image URLs Received - FULFILLED
-     */
-    builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
-      const { image_name, image_type, image_url, thumbnail_url } =
-        action.payload;
-
-      if (image_type === 'uploads') {
-        uploadsAdapter.updateOne(state, {
-          id: image_name,
-          changes: {
-            image_url: image_url,
-            thumbnail_url: thumbnail_url,
-          },
-        });
-      }
-    });
-
-    /**
-     * Delete Image - pending
-     * Pre-emptively remove the image from the gallery
-     */
-    builder.addCase(imageDeleted.pending, (state, action) => {
-      const { imageType, imageName } = action.meta.arg;
-
-      if (imageType === 'uploads') {
-        uploadsAdapter.removeOne(state, imageName);
-      }
-    });
   },
 });
 
@@ -106,6 +84,6 @@ export const {
   selectTotal: selectUploadsTotal,
 } = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
 
-export const { uploadAdded } = uploadsSlice.actions;
+export const { uploadUpserted } = uploadsSlice.actions;
 
 export default uploadsSlice.reducer;
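Both gallery slices expose adapter-generated selectors next to the new counter. A runnable sketch wiring a simplified uploads slice into a store and reading both — the shapes are assumed, not copied from the commit:

import {
  PayloadAction,
  configureStore,
  createEntityAdapter,
  createSlice,
} from '@reduxjs/toolkit';

type Upload = { image_name: string };

const uploadsAdapter = createEntityAdapter<Upload>({
  selectId: (u) => u.image_name,
});

const uploadsSlice = createSlice({
  name: 'uploads',
  initialState: uploadsAdapter.getInitialState({ upsertedImageCount: 0 }),
  reducers: {
    uploadUpserted: (state, action: PayloadAction<Upload>) => {
      uploadsAdapter.upsertOne(state, action.payload);
      state.upsertedImageCount += 1;
    },
  },
});

const store = configureStore({ reducer: { uploads: uploadsSlice.reducer } });
type RootState = ReturnType<typeof store.getState>;

// Adapter selectors are generated once against the slice's location in the tree.
const { selectTotal: selectUploadsTotal } =
  uploadsAdapter.getSelectors<RootState>((state) => state.uploads);

store.dispatch(uploadsSlice.actions.uploadUpserted({ image_name: 'a.png' }));
console.log(selectUploadsTotal(store.getState())); // 1
console.log(store.getState().uploads.upsertedImageCount); // 1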
@@ -7,6 +7,7 @@ import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
 import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
 import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
 import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
+import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
 import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
 import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
 import StringInputFieldComponent from './fields/StringInputFieldComponent';
@@ -97,6 +98,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
     );
   }
 
+  if (type === 'control' && template.type === 'control') {
+    return (
+      <ControlInputFieldComponent
+        nodeId={nodeId}
+        field={field}
+        template={template}
+      />
+    );
+  }
+
   if (type === 'model' && template.type === 'model') {
     return (
       <ModelInputFieldComponent
@@ -0,0 +1,16 @@
+import {
+  ControlInputFieldTemplate,
+  ControlInputFieldValue,
+} from 'features/nodes/types/types';
+import { memo } from 'react';
+import { FieldComponentProps } from './types';
+
+const ControlInputFieldComponent = (
+  props: FieldComponentProps<ControlInputFieldValue, ControlInputFieldTemplate>
+) => {
+  const { nodeId, field } = props;
+
+  return null;
+};
+
+export default memo(ControlInputFieldComponent);
@@ -4,6 +4,7 @@ export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
 
 export const FIELD_TYPE_MAP: Record<string, FieldType> = {
   integer: 'integer',
+  float: 'float',
   number: 'float',
   string: 'string',
   boolean: 'boolean',
@@ -15,6 +16,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
   array: 'array',
   item: 'item',
   ColorField: 'color',
+  ControlField: 'control',
+  control: 'control',
 };
 
 const COLOR_TOKEN_VALUE = 500;
@@ -22,6 +25,9 @@ const COLOR_TOKEN_VALUE = 500;
 const getColorTokenCssVariable = (color: string) =>
   `var(--invokeai-colors-${color}-${COLOR_TOKEN_VALUE})`;
 
+// @ts-ignore
+// @ts-ignore
+// @ts-ignore
 export const FIELDS: Record<FieldType, FieldUIConfig> = {
   integer: {
     color: 'red',
@@ -71,6 +77,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
     title: 'Conditioning',
     description: 'Conditioning may be passed between nodes.',
   },
+  control: {
+    color: 'cyan',
+    colorCssVar: getColorTokenCssVariable('cyan'), // TODO: no free color left
+    title: 'Control',
+    description: 'Control info passed between nodes.',
+  },
   model: {
     color: 'teal',
     colorCssVar: getColorTokenCssVariable('teal'),
@@ -61,6 +61,7 @@ export type FieldType =
   | 'image'
   | 'latents'
   | 'conditioning'
+  | 'control'
   | 'model'
   | 'array'
   | 'item'
@@ -82,6 +83,7 @@ export type InputFieldValue =
   | ImageInputFieldValue
   | LatentsInputFieldValue
   | ConditioningInputFieldValue
+  | ControlInputFieldValue
   | EnumInputFieldValue
   | ModelInputFieldValue
   | ArrayInputFieldValue
@@ -102,6 +104,7 @@ export type InputFieldTemplate =
   | ImageInputFieldTemplate
   | LatentsInputFieldTemplate
   | ConditioningInputFieldTemplate
+  | ControlInputFieldTemplate
   | EnumInputFieldTemplate
   | ModelInputFieldTemplate
   | ArrayInputFieldTemplate
@@ -177,6 +180,11 @@ export type LatentsInputFieldValue = FieldValueBase & {
 
 export type ConditioningInputFieldValue = FieldValueBase & {
   type: 'conditioning';
+  value?: string;
+};
+
+export type ControlInputFieldValue = FieldValueBase & {
+  type: 'control';
   value?: undefined;
 };
 
@@ -262,6 +270,11 @@ export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
   type: 'conditioning';
 };
 
+export type ControlInputFieldTemplate = InputFieldTemplateBase & {
+  default: undefined;
+  type: 'control';
+};
+
 export type EnumInputFieldTemplate = InputFieldTemplateBase & {
   default: string | number;
   type: 'enum';
@ -10,6 +10,7 @@ import {
|
|||||||
IntegerInputFieldTemplate,
|
IntegerInputFieldTemplate,
|
||||||
LatentsInputFieldTemplate,
|
LatentsInputFieldTemplate,
|
||||||
ConditioningInputFieldTemplate,
|
ConditioningInputFieldTemplate,
|
||||||
|
ControlInputFieldTemplate,
|
||||||
StringInputFieldTemplate,
|
StringInputFieldTemplate,
|
||||||
ModelInputFieldTemplate,
|
ModelInputFieldTemplate,
|
||||||
ArrayInputFieldTemplate,
|
ArrayInputFieldTemplate,
|
||||||
@ -215,6 +216,21 @@ const buildConditioningInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildControlInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): ControlInputFieldTemplate => {
|
||||||
|
const template: ControlInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'control',
|
||||||
|
inputRequirement: 'always',
|
||||||
|
inputKind: 'connection',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildEnumInputFieldTemplate = ({
|
const buildEnumInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -286,9 +302,20 @@ export const getFieldType = (
|
|||||||
if (typeHints && name in typeHints) {
|
if (typeHints && name in typeHints) {
|
||||||
rawFieldType = typeHints[name];
|
rawFieldType = typeHints[name];
|
||||||
} else if (!schemaObject.type) {
|
} else if (!schemaObject.type) {
|
||||||
|
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
|
||||||
|
if (schemaObject.allOf) {
|
||||||
rawFieldType = refObjectToFieldType(
|
rawFieldType = refObjectToFieldType(
|
||||||
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
|
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
|
||||||
);
|
);
|
||||||
|
} else if (schemaObject.anyOf) {
|
||||||
|
rawFieldType = refObjectToFieldType(
|
||||||
|
schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
|
||||||
|
);
|
||||||
|
} else if (schemaObject.oneOf) {
|
||||||
|
rawFieldType = refObjectToFieldType(
|
||||||
|
schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
|
||||||
|
);
|
||||||
|
}
|
||||||
} else if (schemaObject.enum) {
|
} else if (schemaObject.enum) {
|
||||||
rawFieldType = 'enum';
|
rawFieldType = 'enum';
|
||||||
} else if (schemaObject.type) {
|
} else if (schemaObject.type) {
|
||||||
@ -331,6 +358,9 @@ export const buildInputFieldTemplate = (
|
|||||||
if (['conditioning'].includes(fieldType)) {
|
if (['conditioning'].includes(fieldType)) {
|
||||||
return buildConditioningInputFieldTemplate({ schemaObject, baseField });
|
return buildConditioningInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
if (['control'].includes(fieldType)) {
|
||||||
|
return buildControlInputFieldTemplate({ schemaObject, baseField });
|
||||||
|
}
|
||||||
if (['model'].includes(fieldType)) {
|
if (['model'].includes(fieldType)) {
|
||||||
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
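`getFieldType` now falls back to the first `$ref` under `allOf`, `anyOf`, or `oneOf` when a schema object has no explicit `type`. A sketch of how such a ref can be resolved to a UI field type — `refObjectToFieldType`'s real implementation may differ, and the map entries are copied from `FIELD_TYPE_MAP` above:

import { OpenAPIV3 } from 'openapi-types';

// Subset of FIELD_TYPE_MAP from the constants module above.
const FIELD_TYPE_MAP: Record<string, string> = {
  ImageField: 'image',
  ConditioningField: 'conditioning',
  ControlField: 'control',
};

// '#/components/schemas/ControlField' -> 'ControlField' -> 'control'
const refObjectToFieldType = (ref: OpenAPIV3.ReferenceObject): string => {
  const name = ref.$ref.split('/').pop() ?? '';
  return FIELD_TYPE_MAP[name] ?? name;
};

const schema: OpenAPIV3.SchemaObject = {
  allOf: [{ $ref: '#/components/schemas/ControlField' }],
};

const first = schema.allOf?.[0];
if (first && '$ref' in first) {
  console.log(refObjectToFieldType(first)); // 'control'
}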
@@ -52,6 +52,10 @@ export const buildInputFieldValue = (
     fieldValue.value = undefined;
   }
 
+  if (template.type === 'control') {
+    fieldValue.value = undefined;
+  }
+
   if (template.type === 'model') {
     fieldValue.value = undefined;
   }
@@ -11,7 +11,7 @@ import { addNoiseNodes } from '../nodeBuilders/addNoiseNodes';
 const POSITIVE_CONDITIONING = 'positive_conditioning';
 const NEGATIVE_CONDITIONING = 'negative_conditioning';
 const TEXT_TO_LATENTS = 'text_to_latents';
-const LATENTS_TO_IMAGE = 'latnets_to_image';
+const LATENTS_TO_IMAGE = 'latents_to_image';
 
 /**
  * Builds the Text to Image tab graph.
@@ -13,7 +13,9 @@ import {
   buildOutputFieldTemplates,
 } from './fieldTemplateBuilders';
 
-const invocationDenylist = ['Graph'];
+const RESERVED_FIELD_NAMES = ['id', 'type', 'meta'];
+
+const invocationDenylist = ['Graph', 'InvocationMeta'];
 
 export const parseSchema = (openAPI: OpenAPIV3.Document) => {
   // filter out non-invocation schemas, plus some tricky invocations for now
@@ -73,7 +75,7 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
         (inputsAccumulator, property, propertyName) => {
           if (
             // `type` and `id` are not valid inputs/outputs
-            !['type', 'id'].includes(propertyName) &&
+            !RESERVED_FIELD_NAMES.includes(propertyName) &&
             isSchemaObject(property)
           ) {
             const field: InputFieldTemplate | undefined =
@@ -21,8 +21,8 @@ export const useParameters = () => {
    * Sets prompt with toast
    */
   const recallPrompt = useCallback(
-    (prompt: unknown) => {
-      if (!isString(prompt)) {
+    (prompt: unknown, negativePrompt?: unknown) => {
+      if (!isString(prompt) || !isString(negativePrompt)) {
         toaster({
           title: t('toast.promptNotSet'),
           description: t('toast.promptNotSetDesc'),
@@ -33,7 +33,7 @@ export const useParameters = () => {
         return;
       }
 
-      setBothPrompts(prompt);
+      setBothPrompts(prompt, negativePrompt);
       toaster({
         title: t('toast.promptSet'),
         status: 'info',
@@ -112,12 +112,13 @@ export const useParameters = () => {
   const recallAllParameters = useCallback(
     (image: ImageDTO | undefined) => {
       const type = image?.metadata?.type;
-      if (['txt2img', 'img2img', 'inpaint'].includes(String(type))) {
+      // not sure what this list should be
+      if (['t2l', 'l2l', 'inpaint'].includes(String(type))) {
         dispatch(allParametersSet(image));
 
-        if (image?.metadata?.type === 'img2img') {
+        if (image?.metadata?.type === 'l2l') {
           dispatch(setActiveTab('img2img'));
-        } else if (image?.metadata?.type === 'txt2img') {
+        } else if (image?.metadata?.type === 't2l') {
           dispatch(setActiveTab('txt2img'));
         }
 
@@ -12,15 +12,8 @@ const useSetBothPrompts = () => {
   const dispatch = useAppDispatch();
 
   return useCallback(
-    (inputPrompt: InvokeAI.Prompt) => {
-      const promptString =
-        typeof inputPrompt === 'string'
-          ? inputPrompt
-          : promptToString(inputPrompt);
-
-      const [prompt, negativePrompt] = getPromptAndNegative(promptString);
-
-      dispatch(setPositivePrompt(prompt));
+    (inputPrompt: InvokeAI.Prompt, negativePrompt: InvokeAI.Prompt) => {
+      dispatch(setPositivePrompt(inputPrompt));
       dispatch(setNegativePrompt(negativePrompt));
     },
     [dispatch]
@@ -52,7 +52,7 @@ export const initialGenerationState: GenerationState = {
   perlin: 0,
   positivePrompt: '',
   negativePrompt: '',
-  scheduler: 'lms',
+  scheduler: 'euler',
   seamBlur: 16,
   seamSize: 96,
   seamSteps: 30,
@@ -7,19 +7,29 @@ export const setAllParametersReducer = (
   state: Draft<GenerationState>,
   action: PayloadAction<ImageDTO | undefined>
 ) => {
-  const node = action.payload?.metadata.invokeai?.node;
+  const metadata = action.payload?.metadata;
 
-  if (!node) {
+  if (!metadata) {
     return;
   }
 
+  // not sure what this list should be
   if (
-    node.type === 'txt2img' ||
-    node.type === 'img2img' ||
-    node.type === 'inpaint'
+    metadata.type === 't2l' ||
+    metadata.type === 'l2l' ||
+    metadata.type === 'inpaint'
   ) {
-    const { cfg_scale, height, model, prompt, scheduler, seed, steps, width } =
-      node;
+    const {
+      cfg_scale,
+      height,
+      model,
+      positive_conditioning,
+      negative_conditioning,
+      scheduler,
+      seed,
+      steps,
+      width,
+    } = metadata;
 
     if (cfg_scale !== undefined) {
       state.cfgScale = Number(cfg_scale);
@@ -30,8 +40,11 @@ export const setAllParametersReducer = (
     if (model !== undefined) {
       state.model = String(model);
     }
-    if (prompt !== undefined) {
-      state.positivePrompt = String(prompt);
+    if (positive_conditioning !== undefined) {
+      state.positivePrompt = String(positive_conditioning);
+    }
+    if (negative_conditioning !== undefined) {
+      state.negativePrompt = String(negative_conditioning);
     }
     if (scheduler !== undefined) {
       const schedulerString = String(scheduler);
@@ -51,8 +64,8 @@ export const setAllParametersReducer = (
     }
   }
 
-  if (node.type === 'img2img') {
-    const { fit, image } = node as ImageToImageInvocation;
+  if (metadata.type === 'l2l') {
+    const { fit, image } = metadata as ImageToImageInvocation;
 
     if (fit !== undefined) {
       state.shouldFitToWidthHeight = Boolean(fit);
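`setAllParametersReducer` guards every metadata field with an `!== undefined` check plus an explicit `Number()`/`String()` coercion before writing to state, since recalled metadata is loosely typed. The same pattern in isolation — the state shape is simplified, and in the real slice `state` is an Immer draft:

type GenerationState = {
  cfgScale: number;
  positivePrompt: string;
  negativePrompt: string;
};

// Metadata arrives loosely typed, so each field is checked and coerced
// individually rather than copied wholesale.
const applyMetadata = (
  state: GenerationState,
  metadata: Record<string, unknown>
): void => {
  const { cfg_scale, positive_conditioning, negative_conditioning } = metadata;
  if (cfg_scale !== undefined) {
    state.cfgScale = Number(cfg_scale);
  }
  if (positive_conditioning !== undefined) {
    state.positivePrompt = String(positive_conditioning);
  }
  if (negative_conditioning !== undefined) {
    state.negativePrompt = String(negative_conditioning);
  }
};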
@@ -0,0 +1,3 @@
+import { createAction } from '@reduxjs/toolkit';
+
+export const sessionReadyToInvoke = createAction('system/sessionReadyToInvoke');
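`sessionReadyToInvoke` is a payload-less signal: the graph-building listeners dispatch it, and a separate listener owns the actual invocation. A sketch of that consuming side — the consumer shown here is assumed, not copied from the commit:

import { createAction, createListenerMiddleware } from '@reduxjs/toolkit';

const sessionReadyToInvoke = createAction('system/sessionReadyToInvoke');

const listenerMiddleware = createListenerMiddleware();

// Hypothetical consumer: when any feature signals readiness, fire the
// HTTP request that actually starts processing the current session.
listenerMiddleware.startListening({
  actionCreator: sessionReadyToInvoke,
  effect: async () => {
    // In the real app: dispatch(sessionInvoked({ sessionId }))
    console.log('invoking current session');
  },
});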
@@ -0,0 +1,62 @@
+// TODO: split system slice into this
+
+// import type { PayloadAction } from '@reduxjs/toolkit';
+// import { createSlice } from '@reduxjs/toolkit';
+// import { socketSubscribed, socketUnsubscribed } from 'services/events/actions';
+
+// export type SessionState = {
+//   /**
+//    * The current socket session id
+//    */
+//   sessionId: string;
+//   /**
+//    * Whether the current session is a canvas session. Needed to manage the staging area.
+//    */
+//   isCanvasSession: boolean;
+//   /**
+//    * When a session is canceled, its ID is stored here until a new session is created.
+//    */
+//   canceledSessionId: string;
+// };
+
+// export const initialSessionState: SessionState = {
+//   sessionId: '',
+//   isCanvasSession: false,
+//   canceledSessionId: '',
+// };
+
+// export const sessionSlice = createSlice({
+//   name: 'session',
+//   initialState: initialSessionState,
+//   reducers: {
+//     sessionIdChanged: (state, action: PayloadAction<string>) => {
+//       state.sessionId = action.payload;
+//     },
+//     isCanvasSessionChanged: (state, action: PayloadAction<boolean>) => {
+//       state.isCanvasSession = action.payload;
+//     },
+//   },
+//   extraReducers: (builder) => {
+//     /**
+//      * Socket Subscribed
+//      */
+//     builder.addCase(socketSubscribed, (state, action) => {
+//       state.sessionId = action.payload.sessionId;
+//       state.canceledSessionId = '';
+//     });
+
+//     /**
+//      * Socket Unsubscribed
+//      */
+//     builder.addCase(socketUnsubscribed, (state) => {
+//       state.sessionId = '';
+//     });
+//   },
+// });
+
+// export const { sessionIdChanged, isCanvasSessionChanged } =
+//   sessionSlice.actions;
+
+// export default sessionSlice.reducer;
+
+export default {};
@ -1,5 +1,5 @@
|
|||||||
import { UseToastOptions } from '@chakra-ui/react';
|
import { UseToastOptions } from '@chakra-ui/react';
|
||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import { PayloadAction, isAnyOf } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import * as InvokeAI from 'app/types/invokeai';
|
import * as InvokeAI from 'app/types/invokeai';
|
||||||
import {
|
import {
|
||||||
@ -16,7 +16,11 @@ import {
|
|||||||
|
|
||||||
import { ProgressImage } from 'services/events/types';
|
import { ProgressImage } from 'services/events/types';
|
||||||
import { makeToast } from '../../../app/components/Toaster';
|
import { makeToast } from '../../../app/components/Toaster';
|
||||||
import { sessionCanceled, sessionInvoked } from 'services/thunks/session';
|
import {
|
||||||
|
sessionCanceled,
|
||||||
|
sessionCreated,
|
||||||
|
sessionInvoked,
|
||||||
|
} from 'services/thunks/session';
|
||||||
import { receivedModels } from 'services/thunks/model';
|
import { receivedModels } from 'services/thunks/model';
|
||||||
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
|
 import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
 import { LogLevelName } from 'roarr';
@@ -215,6 +219,9 @@ export const systemSlice = createSlice({
     languageChanged: (state, action: PayloadAction<keyof typeof LANGUAGES>) => {
       state.language = action.payload;
     },
+    progressImageSet(state, action: PayloadAction<ProgressImage | null>) {
+      state.progressImage = action.payload;
+    },
   },
   extraReducers(builder) {
     /**
@@ -305,7 +312,6 @@ export const systemSlice = createSlice({
       state.currentStep = 0;
       state.totalSteps = 0;
       state.statusTranslationKey = 'common.statusProcessingComplete';
-      state.progressImage = null;
 
       if (state.canceledSession === data.graph_execution_state_id) {
         state.isProcessing = false;
@@ -343,15 +349,8 @@ export const systemSlice = createSlice({
       state.statusTranslationKey = 'common.statusPreparing';
     });
 
-    builder.addCase(sessionInvoked.rejected, (state, action) => {
-      const error = action.payload as string | undefined;
-      state.toastQueue.push(
-        makeToast({ title: error || t('toast.serverError'), status: 'error' })
-      );
-    });
-
     /**
-     * Session Canceled
+     * Session Canceled - FULFILLED
      */
     builder.addCase(sessionCanceled.fulfilled, (state, action) => {
       state.canceledSession = action.meta.arg.sessionId;
@@ -414,6 +413,26 @@ export const systemSlice = createSlice({
     builder.addCase(imageUploaded.fulfilled, (state) => {
       state.isUploading = false;
     });
 
+    // *** Matchers - must be after all cases ***
+
+    /**
+     * Session Invoked - REJECTED
+     * Session Created - REJECTED
+     */
+    builder.addMatcher(isAnySessionRejected, (state, action) => {
+      state.isProcessing = false;
+      state.isCancelable = false;
+      state.isCancelScheduled = false;
+      state.currentStep = 0;
+      state.totalSteps = 0;
+      state.statusTranslationKey = 'common.statusConnected';
+      state.progressImage = null;
+
+      state.toastQueue.push(
+        makeToast({ title: t('toast.serverError'), status: 'error' })
+      );
+    });
   },
 });
@@ -438,6 +457,12 @@ export const {
   isPersistedChanged,
   shouldAntialiasProgressImageChanged,
   languageChanged,
+  progressImageSet,
 } = systemSlice.actions;
 
 export default systemSlice.reducer;
+
+const isAnySessionRejected = isAnyOf(
+  sessionCreated.rejected,
+  sessionInvoked.rejected
+);
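Note: the diff above replaces a per-thunk rejected handler with a single isAnyOf matcher. A minimal, self-contained sketch of that Redux Toolkit pattern follows; the thunks and state shape are simplified stand-ins, not the real InvokeAI definitions.

// Sketch of the matcher pattern used above, under simplified assumptions.
import { createAsyncThunk, createSlice, isAnyOf } from '@reduxjs/toolkit';

// Hypothetical thunks standing in for sessionCreated / sessionInvoked.
const createSession = createAsyncThunk('api/createSession', async () => ({}));
const invokeSession = createAsyncThunk('api/invokeSession', async () => ({}));

// One matcher covers every rejection, replacing per-thunk addCase handlers.
const isAnyRejected = isAnyOf(createSession.rejected, invokeSession.rejected);

const slice = createSlice({
  name: 'system',
  initialState: { isProcessing: false, toastQueue: [] as string[] },
  reducers: {},
  extraReducers(builder) {
    builder.addCase(invokeSession.fulfilled, (state) => {
      state.isProcessing = true;
    });
    // RTK enforces the ordering noted in the diff: addMatcher may only be
    // called after all addCase calls, so matchers must come last.
    builder.addMatcher(isAnyRejected, (state) => {
      state.isProcessing = false;
      state.toastQueue.push('Server Error');
    });
  },
});

The payoff of the matcher is that new session thunks can be added to the isAnyOf list without duplicating the shared failure handling.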
@@ -7,7 +7,6 @@ export { OpenAPI } from './core/OpenAPI';
 export type { OpenAPIConfig } from './core/OpenAPI';
 
 export type { AddInvocation } from './models/AddInvocation';
-export type { BlurInvocation } from './models/BlurInvocation';
 export type { Body_upload_image } from './models/Body_upload_image';
 export type { CkptModelInfo } from './models/CkptModelInfo';
 export type { CollectInvocation } from './models/CollectInvocation';
@@ -17,7 +16,6 @@ export type { CompelInvocation } from './models/CompelInvocation';
 export type { CompelOutput } from './models/CompelOutput';
 export type { ConditioningField } from './models/ConditioningField';
 export type { CreateModelRequest } from './models/CreateModelRequest';
-export type { CropImageInvocation } from './models/CropImageInvocation';
 export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
 export type { DiffusersModelInfo } from './models/DiffusersModelInfo';
 export type { DivideInvocation } from './models/DivideInvocation';
@@ -28,11 +26,20 @@ export type { GraphExecutionState } from './models/GraphExecutionState';
 export type { GraphInvocation } from './models/GraphInvocation';
 export type { GraphInvocationOutput } from './models/GraphInvocationOutput';
 export type { HTTPValidationError } from './models/HTTPValidationError';
+export type { ImageBlurInvocation } from './models/ImageBlurInvocation';
 export type { ImageCategory } from './models/ImageCategory';
+export type { ImageChannelInvocation } from './models/ImageChannelInvocation';
+export type { ImageConvertInvocation } from './models/ImageConvertInvocation';
+export type { ImageCropInvocation } from './models/ImageCropInvocation';
 export type { ImageDTO } from './models/ImageDTO';
 export type { ImageField } from './models/ImageField';
+export type { ImageInverseLerpInvocation } from './models/ImageInverseLerpInvocation';
+export type { ImageLerpInvocation } from './models/ImageLerpInvocation';
 export type { ImageMetadata } from './models/ImageMetadata';
+export type { ImageMultiplyInvocation } from './models/ImageMultiplyInvocation';
 export type { ImageOutput } from './models/ImageOutput';
+export type { ImagePasteInvocation } from './models/ImagePasteInvocation';
+export type { ImageRecordChanges } from './models/ImageRecordChanges';
 export type { ImageToImageInvocation } from './models/ImageToImageInvocation';
 export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation';
 export type { ImageType } from './models/ImageType';
@@ -43,14 +50,12 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation';
 export type { InpaintInvocation } from './models/InpaintInvocation';
 export type { IntCollectionOutput } from './models/IntCollectionOutput';
 export type { IntOutput } from './models/IntOutput';
-export type { InverseLerpInvocation } from './models/InverseLerpInvocation';
 export type { IterateInvocation } from './models/IterateInvocation';
 export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
 export type { LatentsField } from './models/LatentsField';
 export type { LatentsOutput } from './models/LatentsOutput';
 export type { LatentsToImageInvocation } from './models/LatentsToImageInvocation';
 export type { LatentsToLatentsInvocation } from './models/LatentsToLatentsInvocation';
-export type { LerpInvocation } from './models/LerpInvocation';
 export type { LoadImageInvocation } from './models/LoadImageInvocation';
 export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
 export type { MaskOutput } from './models/MaskOutput';
@@ -61,7 +66,6 @@ export type { NoiseOutput } from './models/NoiseOutput';
 export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
 export type { PaginatedResults_ImageDTO_ } from './models/PaginatedResults_ImageDTO_';
 export type { ParamIntInvocation } from './models/ParamIntInvocation';
-export type { PasteImageInvocation } from './models/PasteImageInvocation';
 export type { PromptOutput } from './models/PromptOutput';
 export type { RandomIntInvocation } from './models/RandomIntInvocation';
 export type { RandomRangeInvocation } from './models/RandomRangeInvocation';
@@ -10,6 +10,10 @@ export type AddInvocation = {
    * The id of this node. Must be unique among all nodes.
    */
   id: string;
+  /**
+   * Whether or not this node is an intermediate node.
+   */
+  is_intermediate?: boolean;
   type?: 'add';
   /**
    * The first number
@@ -0,0 +1,29 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Canny edge detection for ControlNet
+ */
+export type CannyImageProcessorInvocation = {
+  /**
+   * The id of this node. Must be unique among all nodes.
+   */
+  id: string;
+  type?: 'canny_image_processor';
+  /**
+   * image to process
+   */
+  image?: ImageField;
+  /**
+   * low threshold of Canny pixel gradient
+   */
+  low_threshold?: number;
+  /**
+   * high threshold of Canny pixel gradient
+   */
+  high_threshold?: number;
+};
+
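The generated type above can be used to construct graph nodes on the client. A hypothetical node literal follows; the 100/200 thresholds are common Canny defaults used purely as illustration, and the optional input image is left to arrive over a graph edge.

import type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation';

// Hypothetical Canny preprocessor node; the image input is supplied at
// runtime through a graph edge rather than inline, so it is omitted here.
const cannyNode: CannyImageProcessorInvocation = {
  id: 'canny_1', // hypothetical node id
  type: 'canny_image_processor',
  low_threshold: 100,  // gradients below this are never edges
  high_threshold: 200, // gradients above this are always edges
};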
@@ -10,6 +10,10 @@ export type CollectInvocation = {
    * The id of this node. Must be unique among all nodes.
    */
   id: string;
+  /**
+   * Whether or not this node is an intermediate node.
+   */
+  is_intermediate?: boolean;
   type?: 'collect';
   /**
    * The item to collect (all inputs must be of the same type)
@@ -10,6 +10,10 @@ export type CompelInvocation = {
    * The id of this node. Must be unique among all nodes.
    */
   id: string;
+  /**
+   * Whether or not this node is an intermediate node.
+   */
+  is_intermediate?: boolean;
   type?: 'compel';
   /**
    * Prompt
@@ -0,0 +1,41 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies content shuffle processing to image
+ */
+export type ContentShuffleImageProcessorInvocation = {
+  /**
+   * The id of this node. Must be unique among all nodes.
+   */
+  id: string;
+  type?: 'content_shuffle_image_processor';
+  /**
+   * image to process
+   */
+  image?: ImageField;
+  /**
+   * pixel resolution for edge detection
+   */
+  detect_resolution?: number;
+  /**
+   * pixel resolution for output image
+   */
+  image_resolution?: number;
+  /**
+   * content shuffle h parameter
+   */
+  'h'?: number;
+  /**
+   * content shuffle w parameter
+   */
+  'w'?: number;
+  /**
+   * cont
+   */
+  'f'?: number;
+};
+
@@ -0,0 +1,29 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+export type ControlField = {
+  /**
+   * processed image
+   */
+  image: ImageField;
+  /**
+   * control model used
+   */
+  control_model: string;
+  /**
+   * weight given to controlnet
+   */
+  control_weight: number;
+  /**
+   * % of total steps at which controlnet is first applied
+   */
+  begin_step_percent: number;
+  /**
+   * % of total steps at which controlnet is last applied
+   */
+  end_step_percent: number;
+};
+
@@ -0,0 +1,37 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Collects ControlNet info to pass to other nodes
+ */
+export type ControlNetInvocation = {
+  /**
+   * The id of this node. Must be unique among all nodes.
+   */
+  id: string;
+  type?: 'controlnet';
+  /**
+   * image to process
+   */
+  image?: ImageField;
+  /**
+   * control model used
+   */
+  control_model?: 'lllyasviel/sd-controlnet-canny' | 'lllyasviel/sd-controlnet-depth' | 'lllyasviel/sd-controlnet-hed' | 'lllyasviel/sd-controlnet-seg' | 'lllyasviel/sd-controlnet-openpose' | 'lllyasviel/sd-controlnet-scribble' | 'lllyasviel/sd-controlnet-normal' | 'lllyasviel/sd-controlnet-mlsd' | 'lllyasviel/control_v11p_sd15_canny' | 'lllyasviel/control_v11p_sd15_openpose' | 'lllyasviel/control_v11p_sd15_seg' | 'lllyasviel/control_v11f1p_sd15_depth' | 'lllyasviel/control_v11p_sd15_normalbae' | 'lllyasviel/control_v11p_sd15_scribble' | 'lllyasviel/control_v11p_sd15_mlsd' | 'lllyasviel/control_v11p_sd15_softedge' | 'lllyasviel/control_v11p_sd15s2_lineart_anime' | 'lllyasviel/control_v11p_sd15_lineart' | 'lllyasviel/control_v11p_sd15_inpaint' | 'lllyasviel/control_v11e_sd15_shuffle' | 'lllyasviel/control_v11e_sd15_ip2p' | 'lllyasviel/control_v11f1e_sd15_tile' | 'thibaud/controlnet-sd21-openpose-diffusers' | 'thibaud/controlnet-sd21-canny-diffusers' | 'thibaud/controlnet-sd21-depth-diffusers' | 'thibaud/controlnet-sd21-scribble-diffusers' | 'thibaud/controlnet-sd21-hed-diffusers' | 'thibaud/controlnet-sd21-zoedepth-diffusers' | 'thibaud/controlnet-sd21-color-diffusers' | 'thibaud/controlnet-sd21-openposev2-diffusers' | 'thibaud/controlnet-sd21-lineart-diffusers' | 'thibaud/controlnet-sd21-normalbae-diffusers' | 'thibaud/controlnet-sd21-ade20k-diffusers' | 'CrucibleAI/ControlNetMediaPipeFace';
+  /**
+   * weight given to controlnet
+   */
+  control_weight?: number;
+  /**
+   * % of total steps at which controlnet is first applied
+   */
+  begin_step_percent?: number;
+  /**
+   * % of total steps at which controlnet is last applied
+   */
+  end_step_percent?: number;
+};
+
@@ -0,0 +1,17 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ControlField } from './ControlField';
+
+/**
+ * node output for ControlNet info
+ */
+export type ControlOutput = {
+  type?: 'control_output';
+  /**
+   * The control info dict
+   */
+  control?: ControlField;
+};
+
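Taken together, ControlNetInvocation, ControlField and ControlOutput describe how ControlNet conditioning flows through a graph: the invocation collects the settings, and its output exposes a ControlField for downstream generation nodes. A hypothetical node literal, assuming the processed image arrives over a graph edge (for example from a canny_image_processor node):

import type { ControlNetInvocation } from './models/ControlNetInvocation';

// Hypothetical node: a Canny ControlNet applied over the middle 80% of the
// denoising steps. The begin/end percentages bound when the conditioning is
// active, per the field docs above.
const controlNode: ControlNetInvocation = {
  id: 'controlnet_1',
  type: 'controlnet',
  control_model: 'lllyasviel/sd-controlnet-canny',
  control_weight: 1.0,
  begin_step_percent: 0.1, // first applied after 10% of steps
  end_step_percent: 0.9,   // last applied at 90% of steps
};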
@@ -12,6 +12,10 @@ export type CvInpaintInvocation = {
    * The id of this node. Must be unique among all nodes.
    */
   id: string;
+  /**
+   * Whether or not this node is an intermediate node.
+   */
+  is_intermediate?: boolean;
   type?: 'cv_inpaint';
   /**
    * The image to inpaint
@@ -10,6 +10,10 @@ export type DivideInvocation = {
    * The id of this node. Must be unique among all nodes.
    */
   id: string;
+  /**
+   * Whether or not this node is an intermediate node.
+   */
+  is_intermediate?: boolean;
   type?: 'div';
   /**
    * The first number
@@ -3,31 +3,34 @@
 /* eslint-disable */
 
 import type { AddInvocation } from './AddInvocation';
-import type { BlurInvocation } from './BlurInvocation';
 import type { CollectInvocation } from './CollectInvocation';
 import type { CompelInvocation } from './CompelInvocation';
-import type { CropImageInvocation } from './CropImageInvocation';
 import type { CvInpaintInvocation } from './CvInpaintInvocation';
 import type { DivideInvocation } from './DivideInvocation';
 import type { Edge } from './Edge';
 import type { GraphInvocation } from './GraphInvocation';
+import type { ImageBlurInvocation } from './ImageBlurInvocation';
+import type { ImageChannelInvocation } from './ImageChannelInvocation';
+import type { ImageConvertInvocation } from './ImageConvertInvocation';
+import type { ImageCropInvocation } from './ImageCropInvocation';
+import type { ImageInverseLerpInvocation } from './ImageInverseLerpInvocation';
+import type { ImageLerpInvocation } from './ImageLerpInvocation';
+import type { ImageMultiplyInvocation } from './ImageMultiplyInvocation';
+import type { ImagePasteInvocation } from './ImagePasteInvocation';
 import type { ImageToImageInvocation } from './ImageToImageInvocation';
 import type { ImageToLatentsInvocation } from './ImageToLatentsInvocation';
 import type { InfillColorInvocation } from './InfillColorInvocation';
 import type { InfillPatchMatchInvocation } from './InfillPatchMatchInvocation';
 import type { InfillTileInvocation } from './InfillTileInvocation';
 import type { InpaintInvocation } from './InpaintInvocation';
-import type { InverseLerpInvocation } from './InverseLerpInvocation';
 import type { IterateInvocation } from './IterateInvocation';
 import type { LatentsToImageInvocation } from './LatentsToImageInvocation';
 import type { LatentsToLatentsInvocation } from './LatentsToLatentsInvocation';
-import type { LerpInvocation } from './LerpInvocation';
 import type { LoadImageInvocation } from './LoadImageInvocation';
 import type { MaskFromAlphaInvocation } from './MaskFromAlphaInvocation';
 import type { MultiplyInvocation } from './MultiplyInvocation';
 import type { NoiseInvocation } from './NoiseInvocation';
 import type { ParamIntInvocation } from './ParamIntInvocation';
-import type { PasteImageInvocation } from './PasteImageInvocation';
 import type { RandomIntInvocation } from './RandomIntInvocation';
 import type { RandomRangeInvocation } from './RandomRangeInvocation';
 import type { RangeInvocation } from './RangeInvocation';
@@ -49,7 +52,7 @@ export type Graph = {
   /**
    * The nodes in this graph
    */
-  nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
+  nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
   /**
    * The connections between nodes and their fields in this graph
    */
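Because the nodes record is a union discriminated by each invocation's literal type tag, client code can narrow a node back to its concrete invocation type. A minimal sketch against the generated Graph type:

import type { Graph } from './Graph';

// Any value stored in Graph['nodes'] is one member of the big union above.
type GraphNode = NonNullable<Graph['nodes']>[string];

function describe(node: GraphNode): string {
  switch (node.type) {
    case 'img_crop':
      return `crop node ${node.id}`; // node is ImageCropInvocation here
    case 'img_blur':
      return `blur node ${node.id}`; // node is ImageBlurInvocation here
    default:
      return `node ${node.id} (${node.type ?? 'unknown'})`;
  }
}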
@@ -5,14 +5,17 @@
 import type { Graph } from './Graph';
 
 /**
- * A node to process inputs and produce outputs.
- * May use dependency injection in __init__ to receive providers.
+ * Execute a graph
  */
 export type GraphInvocation = {
   /**
    * The id of this node. Must be unique among all nodes.
    */
   id: string;
+  /**
+   * Whether or not this node is an intermediate node.
+   */
+  is_intermediate?: boolean;
   type?: 'graph';
   /**
    * The graph to run
@@ -0,0 +1,33 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies HED edge detection to image
+ */
+export type HedImageprocessorInvocation = {
+  /**
+   * The id of this node. Must be unique among all nodes.
+   */
+  id: string;
+  type?: 'hed_image_processor';
+  /**
+   * image to process
+   */
+  image?: ImageField;
+  /**
+   * pixel resolution for edge detection
+   */
+  detect_resolution?: number;
+  /**
+   * pixel resolution for output image
+   */
+  image_resolution?: number;
+  /**
+   * whether to use scribble mode
+   */
+  scribble?: boolean;
+};
+
@@ -7,12 +7,16 @@ import type { ImageField } from './ImageField';
 /**
  * Blurs an image
  */
-export type BlurInvocation = {
+export type ImageBlurInvocation = {
   /**
    * The id of this node. Must be unique among all nodes.
    */
   id: string;
-  type?: 'blur';
+  /**
+   * Whether or not this node is an intermediate node.
+   */
+  is_intermediate?: boolean;
+  type?: 'img_blur';
   /**
    * The image to blur
    */
@@ -5,4 +5,4 @@
 /**
  * The category of an image. Use ImageCategory.OTHER for non-default categories.
  */
-export type ImageCategory = 'general' | 'control' | 'other';
+export type ImageCategory = 'general' | 'control' | 'mask' | 'other';
@@ -0,0 +1,29 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Gets a channel from an image.
+ */
+export type ImageChannelInvocation = {
+  /**
+   * The id of this node. Must be unique among all nodes.
+   */
+  id: string;
+  /**
+   * Whether or not this node is an intermediate node.
+   */
+  is_intermediate?: boolean;
+  type?: 'img_chan';
+  /**
+   * The image to get the channel from
+   */
+  image?: ImageField;
+  /**
+   * The channel to get
+   */
+  channel?: 'A' | 'R' | 'G' | 'B';
+};
+
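A hypothetical use of the new channel node, assuming is_intermediate keeps utility outputs out of the gallery as elsewhere in this changeset:

import type { ImageChannelInvocation } from './models/ImageChannelInvocation';

// Hypothetical node that extracts the alpha channel of its input (supplied
// via a graph edge), e.g. as a first step toward building a mask.
const alphaNode: ImageChannelInvocation = {
  id: 'get_alpha',
  type: 'img_chan',
  is_intermediate: true, // assumed: intermediate results are not surfaced
  channel: 'A',
};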
@@ -0,0 +1,29 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Converts an image to a different mode.
+ */
+export type ImageConvertInvocation = {
+  /**
+   * The id of this node. Must be unique among all nodes.
+   */
+  id: string;
+  /**
+   * Whether or not this node is an intermediate node.
+   */
+  is_intermediate?: boolean;
+  type?: 'img_conv';
+  /**
+   * The image to convert
+   */
+  image?: ImageField;
+  /**
+   * The mode to convert to
+   */
+  mode?: 'L' | 'RGB' | 'RGBA' | 'CMYK' | 'YCbCr' | 'LAB' | 'HSV' | 'I' | 'F';
+};
+
@@ -7,12 +7,16 @@ import type { ImageField } from './ImageField';
 /**
  * Crops an image to a specified box. The box can be outside of the image.
  */
-export type CropImageInvocation = {
+export type ImageCropInvocation = {
   /**
    * The id of this node. Must be unique among all nodes.
    */
   id: string;
-  type?: 'crop';
+  /**
+   * Whether or not this node is an intermediate node.
+   */
+  is_intermediate?: boolean;
+  type?: 'img_crop';
   /**
    * The image to crop
    */
@@ -50,6 +50,10 @@ export type ImageDTO = {
    * The deleted timestamp of the image.
    */
   deleted_at?: string;
+  /**
+   * Whether this is an intermediate image.
+   */
+  is_intermediate: boolean;
   /**
    * The session ID that generated this image, if it is a generated image.
    */
@@ -7,12 +7,16 @@ import type { ImageField } from './ImageField';
 /**
  * Inverse linear interpolation of all pixels of an image
  */
-export type InverseLerpInvocation = {
+export type ImageInverseLerpInvocation = {
   /**
    * The id of this node. Must be unique among all nodes.
    */
   id: string;
-  type?: 'ilerp';
+  /**
+   * Whether or not this node is an intermediate node.
+   */
+  is_intermediate?: boolean;
+  type?: 'img_ilerp';
   /**
    * The image to lerp
    */
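For reference, the per-pixel math behind the renamed lerp/ilerp pair (a sketch of the semantics, not the server implementation):

// lerp maps t in [0, 1] to [min, max]; inverse lerp maps a value in
// [min, max] back to [0, 1] (scaled to 0-255 when applied to image pixels).
const lerp = (min: number, max: number, t: number) => min + (max - min) * t;
const inverseLerp = (min: number, max: number, v: number) =>
  Math.min(Math.max((v - min) / (max - min), 0), 1);

// Example: with min=64 and max=192, a pixel value of 128 inverse-lerps to 0.5.
console.log(inverseLerp(64, 192, 128)); // 0.5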
Some files were not shown because too many files have changed in this diff.