mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
2 Commits
feat/batch
...
feat/cleau
Author | SHA1 | Date | |
---|---|---|---|
edc8f5fb6f | |||
6bb657b3f3 |
36
.github/CODEOWNERS
vendored
36
.github/CODEOWNERS
vendored
@ -1,34 +1,34 @@
|
||||
# continuous integration
|
||||
/.github/workflows/ @lstein @blessedcoolant @hipsterusername
|
||||
/.github/workflows/ @lstein @blessedcoolant
|
||||
|
||||
# documentation
|
||||
/docs/ @lstein @blessedcoolant @hipsterusername @Millu
|
||||
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @Millu
|
||||
/mkdocs.yml @lstein @blessedcoolant
|
||||
|
||||
# nodes
|
||||
/invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising @hipsterusername
|
||||
/invokeai/app/ @Kyle0654 @blessedcoolant @psychedelicious @brandonrising
|
||||
|
||||
# installation and configuration
|
||||
/pyproject.toml @lstein @blessedcoolant @hipsterusername
|
||||
/docker/ @lstein @blessedcoolant @hipsterusername
|
||||
/scripts/ @ebr @lstein @hipsterusername
|
||||
/installer/ @lstein @ebr @hipsterusername
|
||||
/invokeai/assets @lstein @ebr @hipsterusername
|
||||
/invokeai/configs @lstein @hipsterusername
|
||||
/invokeai/version @lstein @blessedcoolant @hipsterusername
|
||||
/pyproject.toml @lstein @blessedcoolant
|
||||
/docker/ @lstein @blessedcoolant
|
||||
/scripts/ @ebr @lstein
|
||||
/installer/ @lstein @ebr
|
||||
/invokeai/assets @lstein @ebr
|
||||
/invokeai/configs @lstein
|
||||
/invokeai/version @lstein @blessedcoolant
|
||||
|
||||
# web ui
|
||||
/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
|
||||
/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp
|
||||
|
||||
# generation, model management, postprocessing
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising @ryanjdick @hipsterusername
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @gregghelt2 @StAlKeR7779 @brandonrising @ryanjdick
|
||||
|
||||
# front ends
|
||||
/invokeai/frontend/CLI @lstein @hipsterusername
|
||||
/invokeai/frontend/install @lstein @ebr @hipsterusername
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp @hipsterusername
|
||||
/invokeai/frontend/CLI @lstein
|
||||
/invokeai/frontend/install @lstein @ebr
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant
|
||||
/invokeai/frontend/training @lstein @blessedcoolant
|
||||
/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp
|
||||
|
||||
|
||||
|
@ -244,12 +244,8 @@ copy-paste the template above.
|
||||
We can use the `@invocation` decorator to provide some additional info to the
|
||||
UI, like a custom title, tags and category.
|
||||
|
||||
We also encourage providing a version. This must be a
|
||||
[semver](https://semver.org/) version string ("$MAJOR.$MINOR.$PATCH"). The UI
|
||||
will let users know if their workflow is using a mismatched version of the node.
|
||||
|
||||
```python
|
||||
@invocation("resize", title="My Resizer", tags=["resize", "image"], category="My Invocations", version="1.0.0")
|
||||
@invocation("resize", title="My Resizer", tags=["resize", "image"], category="My Invocations")
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
"""Resizes an image"""
|
||||
|
||||
@ -283,6 +279,8 @@ take a look a at our [contributing nodes overview](contributingNodes).
|
||||
|
||||
## Advanced
|
||||
|
||||
-->
|
||||
|
||||
### Custom Output Types
|
||||
|
||||
Like with custom inputs, sometimes you might find yourself needing custom
|
||||
|
@ -22,26 +22,12 @@ To use a community node graph, download the the `.json` node graph file and load
|
||||

|
||||

|
||||
|
||||
--------------------------------
|
||||
### Ideal Size
|
||||
|
||||
**Description:** This node calculates an ideal image size for a first pass of a multi-pass upscaling. The aim is to avoid duplication that results from choosing a size larger than the model is capable of.
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/ideal-size-node
|
||||
|
||||
--------------------------------
|
||||
### Film Grain
|
||||
|
||||
**Description:** This node adds a film grain effect to the input image based on the weights, seeds, and blur radii parameters. It works with RGB input images only.
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/film-grain-node
|
||||
|
||||
--------------------------------
|
||||
### Image Picker
|
||||
|
||||
**Description:** This InvokeAI node takes in a collection of images and randomly chooses one. This can be useful when you have a number of poses to choose from for a ControlNet node, or a number of input images for another purpose.
|
||||
|
||||
**Node Link:** https://github.com/JPPhoto/image-picker-node
|
||||
|
||||
--------------------------------
|
||||
### Retroize
|
||||
@ -109,91 +95,6 @@ a Text-Generation-Webui instance (might work remotely too, but I never tried it)
|
||||
|
||||
This node works best with SDXL models, especially as the style can be described independantly of the LLM's output.
|
||||
|
||||
--------------------------------
|
||||
### Depth Map from Wavefront OBJ
|
||||
|
||||
**Description:** Render depth maps from Wavefront .obj files (triangulated) using this simple 3D renderer utilizing numpy and matplotlib to compute and color the scene. There are simple parameters to change the FOV, camera position, and model orientation.
|
||||
|
||||
To be imported, an .obj must use triangulated meshes, so make sure to enable that option if exporting from a 3D modeling program. This renderer makes each triangle a solid color based on its average depth, so it will cause anomalies if your .obj has large triangles. In Blender, the Remesh modifier can be helpful to subdivide a mesh into small pieces that work well given these limitations.
|
||||
|
||||
**Node Link:** https://github.com/dwringer/depth-from-obj-node
|
||||
|
||||
**Example Usage:**
|
||||

|
||||
|
||||
--------------------------------
|
||||
### Enhance Image (simple adjustments)
|
||||
|
||||
**Description:** Boost or reduce color saturation, contrast, brightness, sharpness, or invert colors of any image at any stage with this simple wrapper for pillow [PIL]'s ImageEnhance module.
|
||||
|
||||
Color inversion is toggled with a simple switch, while each of the four enhancer modes are activated by entering a value other than 1 in each corresponding input field. Values less than 1 will reduce the corresponding property, while values greater than 1 will enhance it.
|
||||
|
||||
**Node Link:** https://github.com/dwringer/image-enhance-node
|
||||
|
||||
**Example Usage:**
|
||||

|
||||
|
||||
--------------------------------
|
||||
### Generative Grammar-Based Prompt Nodes
|
||||
|
||||
**Description:** This set of 3 nodes generates prompts from simple user-defined grammar rules (loaded from custom files - examples provided below). The prompts are made by recursively expanding a special template string, replacing nonterminal "parts-of-speech" until no more nonterminal terms remain in the string.
|
||||
|
||||
This includes 3 Nodes:
|
||||
- *Lookup Table from File* - loads a YAML file "prompt" section (or of a whole folder of YAML's) into a JSON-ified dictionary (Lookups output)
|
||||
- *Lookups Entry from Prompt* - places a single entry in a new Lookups output under the specified heading
|
||||
- *Prompt from Lookup Table* - uses a Collection of Lookups as grammar rules from which to randomly generate prompts.
|
||||
|
||||
**Node Link:** https://github.com/dwringer/generative-grammar-prompt-nodes
|
||||
|
||||
**Example Usage:**
|
||||

|
||||
|
||||
--------------------------------
|
||||
### Image and Mask Composition Pack
|
||||
|
||||
**Description:** This is a pack of nodes for composing masks and images, including a simple text mask creator and both image and latent offset nodes. The offsets wrap around, so these can be used in conjunction with the Seamless node to progressively generate centered on different parts of the seamless tiling.
|
||||
|
||||
This includes 4 Nodes:
|
||||
- *Text Mask (simple 2D)* - create and position a white on black (or black on white) line of text using any font locally available to Invoke.
|
||||
- *Image Compositor* - Take a subject from an image with a flat backdrop and layer it on another image using a chroma key or flood select background removal.
|
||||
- *Offset Latents* - Offset a latents tensor in the vertical and/or horizontal dimensions, wrapping it around.
|
||||
- *Offset Image* - Offset an image in the vertical and/or horizontal dimensions, wrapping it around.
|
||||
|
||||
**Node Link:** https://github.com/dwringer/composition-nodes
|
||||
|
||||
**Example Usage:**
|
||||

|
||||
|
||||
--------------------------------
|
||||
### Size Stepper Nodes
|
||||
|
||||
**Description:** This is a set of nodes for calculating the necessary size increments for doing upscaling workflows. Use the *Final Size & Orientation* node to enter your full size dimensions and orientation (portrait/landscape/random), then plug that and your initial generation dimensions into the *Ideal Size Stepper* and get 1, 2, or 3 intermediate pairs of dimensions for upscaling. Note this does not output the initial size or full size dimensions: the 1, 2, or 3 outputs of this node are only the intermediate sizes.
|
||||
|
||||
A third node is included, *Random Switch (Integers)*, which is just a generic version of Final Size with no orientation selection.
|
||||
|
||||
**Node Link:** https://github.com/dwringer/size-stepper-nodes
|
||||
|
||||
**Example Usage:**
|
||||

|
||||
|
||||
--------------------------------
|
||||
|
||||
### Text font to Image
|
||||
|
||||
**Description:** text font to text image node for InvokeAI, download a font to use (or if in font cache uses it from there), the text is always resized to the image size, but can control that with padding, optional 2nd line
|
||||
|
||||
**Node Link:** https://github.com/mickr777/textfontimage
|
||||
|
||||
**Output Examples**
|
||||
|
||||

|
||||
|
||||
Results after using the depth controlnet
|
||||
|
||||

|
||||

|
||||

|
||||
|
||||
--------------------------------
|
||||
|
||||
### Example Node Template
|
||||
|
@ -35,13 +35,13 @@ The table below contains a list of the default nodes shipped with InvokeAI and t
|
||||
|Inverse Lerp Image | Inverse linear interpolation of all pixels of an image|
|
||||
|Image Primitive | An image primitive value|
|
||||
|Lerp Image | Linear interpolation of all pixels of an image|
|
||||
|Offset Image Channel | Add to or subtract from an image color channel by a uniform value.|
|
||||
|Multiply Image Channel | Multiply or Invert an image color channel by a scalar value.|
|
||||
|Image Luminosity Adjustment | Adjusts the Luminosity (Value) of an image.|
|
||||
|Multiply Images | Multiplies two images together using `PIL.ImageChops.multiply()`.|
|
||||
|Blur NSFW Image | Add blur to NSFW-flagged images|
|
||||
|Paste Image | Pastes an image into another image.|
|
||||
|ImageProcessor | Base class for invocations that preprocess images for ControlNet|
|
||||
|Resize Image | Resizes an image to specific dimensions|
|
||||
|Image Saturation Adjustment | Adjusts the Saturation of an image.|
|
||||
|Scale Image | Scales an image by a factor|
|
||||
|Image to Latents | Encodes an image into latents.|
|
||||
|Add Invisible Watermark | Add an invisible watermark to an image|
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from logging import Logger
|
||||
import sqlite3
|
||||
from invokeai.app.services.board_image_record_storage import (
|
||||
SqliteBoardImageRecordStorage,
|
||||
)
|
||||
@ -29,8 +28,6 @@ from ..services.invoker import Invoker
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.model_manager_service import ModelManagerService
|
||||
from ..services.batch_manager import BatchManager
|
||||
from ..services.batch_manager_storage import SqliteBatchProcessStorage
|
||||
from ..services.invocation_stats import InvocationStatsService
|
||||
from .events import FastAPIEventService
|
||||
|
||||
@ -74,18 +71,18 @@ class ApiDependencies:
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
db_location = str(db_path)
|
||||
|
||||
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
|
||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||
|
||||
boards = BoardService(
|
||||
services=BoardServiceDependencies(
|
||||
@ -119,19 +116,15 @@ class ApiDependencies:
|
||||
)
|
||||
)
|
||||
|
||||
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
||||
batch_manager = BatchManager(batch_manager_storage)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=ModelManagerService(config, logger),
|
||||
events=events,
|
||||
latents=latents,
|
||||
images=images,
|
||||
batch_manager=batch_manager,
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
configuration=config,
|
||||
|
@ -1,19 +1,19 @@
|
||||
import typing
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import Body
|
||||
from fastapi.routing import APIRouter
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
from invokeai.backend.util.logging import logging
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
|
||||
from invokeai.version import __version__
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend.util.logging import logging
|
||||
|
||||
|
||||
class LogLevel(int, Enum):
|
||||
@ -55,7 +55,7 @@ async def get_version() -> AppVersion:
|
||||
|
||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||
async def get_config() -> AppConfig:
|
||||
infill_methods = ["tile", "lama", "cv2"]
|
||||
infill_methods = ["tile", "lama"]
|
||||
if PatchMatch.patchmatch_available():
|
||||
infill_methods.append("patchmatch")
|
||||
|
||||
|
@ -1,106 +0,0 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from fastapi import Body, HTTPException, Path, Response
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from invokeai.app.services.batch_manager_storage import BatchSession, BatchSessionNotFoundException
|
||||
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from ...invocations import * # noqa: F401 F403
|
||||
from ...services.batch_manager import Batch, BatchProcessResponse
|
||||
from ...services.graph import Graph
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
batches_router = APIRouter(prefix="/v1/batches", tags=["sessions"])
|
||||
|
||||
|
||||
@batches_router.post(
|
||||
"/",
|
||||
operation_id="create_batch",
|
||||
responses={
|
||||
200: {"model": BatchProcessResponse},
|
||||
400: {"description": "Invalid json"},
|
||||
},
|
||||
)
|
||||
async def create_batch(
|
||||
graph: Graph = Body(description="The graph to initialize the session with"),
|
||||
batch: Batch = Body(description="Batch config to apply to the given graph"),
|
||||
) -> BatchProcessResponse:
|
||||
"""Creates a batch process"""
|
||||
return ApiDependencies.invoker.services.batch_manager.create_batch_process(batch, graph)
|
||||
|
||||
|
||||
@batches_router.put(
|
||||
"/b/{batch_process_id}/invoke",
|
||||
operation_id="start_batch",
|
||||
responses={
|
||||
202: {"description": "Batch process started"},
|
||||
404: {"description": "Batch session not found"},
|
||||
},
|
||||
)
|
||||
async def start_batch(
|
||||
batch_process_id: str = Path(description="ID of Batch to start"),
|
||||
) -> Response:
|
||||
"""Executes a batch process"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_process_id)
|
||||
return Response(status_code=202)
|
||||
except BatchSessionNotFoundException:
|
||||
raise HTTPException(status_code=404, detail="Batch session not found")
|
||||
|
||||
|
||||
@batches_router.delete(
|
||||
"/b/{batch_process_id}",
|
||||
operation_id="cancel_batch",
|
||||
responses={202: {"description": "The batch is canceled"}},
|
||||
)
|
||||
async def cancel_batch(
|
||||
batch_process_id: str = Path(description="The id of the batch process to cancel"),
|
||||
) -> Response:
|
||||
"""Cancels a batch process"""
|
||||
ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
|
||||
return Response(status_code=202)
|
||||
|
||||
|
||||
@batches_router.get(
|
||||
"/incomplete",
|
||||
operation_id="list_incomplete_batches",
|
||||
responses={200: {"model": list[BatchProcessResponse]}},
|
||||
)
|
||||
async def list_incomplete_batches() -> list[BatchProcessResponse]:
|
||||
"""Lists incomplete batch processes"""
|
||||
return ApiDependencies.invoker.services.batch_manager.get_incomplete_batch_processes()
|
||||
|
||||
|
||||
@batches_router.get(
|
||||
"/",
|
||||
operation_id="list_batches",
|
||||
responses={200: {"model": list[BatchProcessResponse]}},
|
||||
)
|
||||
async def list_batches() -> list[BatchProcessResponse]:
|
||||
"""Lists all batch processes"""
|
||||
return ApiDependencies.invoker.services.batch_manager.get_batch_processes()
|
||||
|
||||
|
||||
@batches_router.get(
|
||||
"/b/{batch_process_id}",
|
||||
operation_id="get_batch",
|
||||
responses={200: {"model": BatchProcessResponse}},
|
||||
)
|
||||
async def get_batch(
|
||||
batch_process_id: str = Path(description="The id of the batch process to get"),
|
||||
) -> BatchProcessResponse:
|
||||
"""Gets a Batch Process"""
|
||||
return ApiDependencies.invoker.services.batch_manager.get_batch(batch_process_id)
|
||||
|
||||
|
||||
@batches_router.get(
|
||||
"/b/{batch_process_id}/sessions",
|
||||
operation_id="get_batch_sessions",
|
||||
responses={200: {"model": list[BatchSession]}},
|
||||
)
|
||||
async def get_batch_sessions(
|
||||
batch_process_id: str = Path(description="The id of the batch process to get"),
|
||||
) -> list[BatchSession]:
|
||||
"""Gets a list of batch sessions for a given batch process"""
|
||||
return ApiDependencies.invoker.services.batch_manager.get_sessions(batch_process_id)
|
@ -9,7 +9,13 @@ from pydantic.fields import Field
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from ...invocations import * # noqa: F401 F403
|
||||
from ...invocations.baseinvocation import BaseInvocation
|
||||
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
||||
from ...services.graph import (
|
||||
Edge,
|
||||
EdgeConnection,
|
||||
Graph,
|
||||
GraphExecutionState,
|
||||
NodeAlreadyExecutedError,
|
||||
)
|
||||
from ...services.item_storage import PaginatedResults
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
|
@ -13,15 +13,11 @@ class SocketIO:
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self.__sio = SocketManager(app=app)
|
||||
self.__sio.on("subscribe", handler=self._handle_sub)
|
||||
self.__sio.on("unsubscribe", handler=self._handle_unsub)
|
||||
|
||||
self.__sio.on("subscribe_session", handler=self._handle_sub_session)
|
||||
self.__sio.on("unsubscribe_session", handler=self._handle_unsub_session)
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
|
||||
|
||||
self.__sio.on("subscribe_batch", handler=self._handle_sub_batch)
|
||||
self.__sio.on("unsubscribe_batch", handler=self._handle_unsub_batch)
|
||||
local_handler.register(event_name=EventServiceBase.batch_event, _func=self._handle_batch_event)
|
||||
|
||||
async def _handle_session_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
@ -29,25 +25,12 @@ class SocketIO:
|
||||
room=event[1]["data"]["graph_execution_state_id"],
|
||||
)
|
||||
|
||||
async def _handle_sub_session(self, sid, data, *args, **kwargs):
|
||||
async def _handle_sub(self, sid, data, *args, **kwargs):
|
||||
if "session" in data:
|
||||
self.__sio.enter_room(sid, data["session"])
|
||||
|
||||
async def _handle_unsub_session(self, sid, data, *args, **kwargs):
|
||||
# @app.sio.on('unsubscribe')
|
||||
|
||||
async def _handle_unsub(self, sid, data, *args, **kwargs):
|
||||
if "session" in data:
|
||||
self.__sio.leave_room(sid, data["session"])
|
||||
|
||||
async def _handle_batch_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["batch_id"],
|
||||
)
|
||||
|
||||
async def _handle_sub_batch(self, sid, data, *args, **kwargs):
|
||||
if "batch_id" in data:
|
||||
self.__sio.enter_room(sid, data["batch_id"])
|
||||
|
||||
async def _handle_unsub_batch(self, sid, data, *args, **kwargs):
|
||||
if "batch_id" in data:
|
||||
self.__sio.enter_room(sid, data["batch_id"])
|
||||
|
@ -24,7 +24,7 @@ import invokeai.frontend.web as web_dir
|
||||
import mimetypes
|
||||
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import sessions, batches, models, images, boards, board_images, app_info
|
||||
from .api.routers import sessions, models, images, boards, board_images, app_info
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase
|
||||
|
||||
@ -90,8 +90,6 @@ async def shutdown_event():
|
||||
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(batches.batches_router, prefix="/api")
|
||||
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
|
@ -5,7 +5,6 @@ import re
|
||||
import shlex
|
||||
import sys
|
||||
import time
|
||||
import sqlite3
|
||||
from typing import Union, get_type_hints, Optional
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
@ -30,8 +29,6 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.app.services.batch_manager import BatchManager
|
||||
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
@ -255,18 +252,19 @@ def invoke_cli():
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](conn=db_conn, table_name="graph_executions")
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
)
|
||||
|
||||
urls = LocalUrlService()
|
||||
image_record_storage = SqliteImageRecordStorage(conn=db_conn)
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(conn=db_conn)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn)
|
||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||
|
||||
boards = BoardService(
|
||||
services=BoardServiceDependencies(
|
||||
@ -300,19 +298,15 @@ def invoke_cli():
|
||||
)
|
||||
)
|
||||
|
||||
batch_manager_storage = SqliteBatchProcessStorage(conn=db_conn)
|
||||
batch_manager = BatchManager(batch_manager_storage)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
||||
images=images,
|
||||
boards=boards,
|
||||
batch_manager=batch_manager,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs"),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
|
@ -26,16 +26,11 @@ from typing import (
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic.fields import Undefined, ModelField
|
||||
from pydantic.typing import NoArgAnyCallable
|
||||
import semver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.invocation_services import InvocationServices
|
||||
|
||||
|
||||
class InvalidVersionError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class FieldDescriptions:
|
||||
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||
@ -110,39 +105,24 @@ class UIType(str, Enum):
|
||||
"""
|
||||
|
||||
# region Primitives
|
||||
Integer = "integer"
|
||||
Float = "float"
|
||||
Boolean = "boolean"
|
||||
Color = "ColorField"
|
||||
String = "string"
|
||||
Array = "array"
|
||||
Image = "ImageField"
|
||||
Latents = "LatentsField"
|
||||
Conditioning = "ConditioningField"
|
||||
Control = "ControlField"
|
||||
Float = "float"
|
||||
Image = "ImageField"
|
||||
Integer = "integer"
|
||||
Latents = "LatentsField"
|
||||
String = "string"
|
||||
# endregion
|
||||
|
||||
# region Collection Primitives
|
||||
BooleanCollection = "BooleanCollection"
|
||||
ColorCollection = "ColorCollection"
|
||||
ConditioningCollection = "ConditioningCollection"
|
||||
ControlCollection = "ControlCollection"
|
||||
FloatCollection = "FloatCollection"
|
||||
Color = "ColorField"
|
||||
ImageCollection = "ImageCollection"
|
||||
IntegerCollection = "IntegerCollection"
|
||||
ConditioningCollection = "ConditioningCollection"
|
||||
ColorCollection = "ColorCollection"
|
||||
LatentsCollection = "LatentsCollection"
|
||||
IntegerCollection = "IntegerCollection"
|
||||
FloatCollection = "FloatCollection"
|
||||
StringCollection = "StringCollection"
|
||||
# endregion
|
||||
|
||||
# region Polymorphic Primitives
|
||||
BooleanPolymorphic = "BooleanPolymorphic"
|
||||
ColorPolymorphic = "ColorPolymorphic"
|
||||
ConditioningPolymorphic = "ConditioningPolymorphic"
|
||||
ControlPolymorphic = "ControlPolymorphic"
|
||||
FloatPolymorphic = "FloatPolymorphic"
|
||||
ImagePolymorphic = "ImagePolymorphic"
|
||||
IntegerPolymorphic = "IntegerPolymorphic"
|
||||
LatentsPolymorphic = "LatentsPolymorphic"
|
||||
StringPolymorphic = "StringPolymorphic"
|
||||
BooleanCollection = "BooleanCollection"
|
||||
# endregion
|
||||
|
||||
# region Models
|
||||
@ -196,7 +176,6 @@ class _InputField(BaseModel):
|
||||
ui_type: Optional[UIType]
|
||||
ui_component: Optional[UIComponent]
|
||||
ui_order: Optional[int]
|
||||
item_default: Optional[Any]
|
||||
|
||||
|
||||
class _OutputField(BaseModel):
|
||||
@ -244,7 +223,6 @@ def InputField(
|
||||
ui_component: Optional[UIComponent] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
item_default: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
@ -271,11 +249,6 @@ def InputField(
|
||||
For this case, you could provide `UIComponent.Textarea`.
|
||||
|
||||
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||
|
||||
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
|
||||
: param bool item_default: [None] Specifies the default item value, if this is a collection input. \
|
||||
Ignored for non-collection fields..
|
||||
"""
|
||||
return Field(
|
||||
*args,
|
||||
@ -309,7 +282,6 @@ def InputField(
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
item_default=item_default,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -360,8 +332,6 @@ def OutputField(
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
: param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||
|
||||
: param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
"""
|
||||
return Field(
|
||||
*args,
|
||||
@ -406,9 +376,6 @@ class UIConfigBase(BaseModel):
|
||||
tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
|
||||
title: Optional[str] = Field(default=None, description="The node's display name")
|
||||
category: Optional[str] = Field(default=None, description="The node's category")
|
||||
version: Optional[str] = Field(
|
||||
default=None, description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".'
|
||||
)
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
@ -507,8 +474,6 @@ class BaseInvocation(ABC, BaseModel):
|
||||
schema["tags"] = uiconfig.tags
|
||||
if uiconfig and hasattr(uiconfig, "category"):
|
||||
schema["category"] = uiconfig.category
|
||||
if uiconfig and hasattr(uiconfig, "version"):
|
||||
schema["version"] = uiconfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = list()
|
||||
schema["required"].extend(["type", "id"])
|
||||
@ -577,11 +542,7 @@ GenericBaseInvocation = TypeVar("GenericBaseInvocation", bound=BaseInvocation)
|
||||
|
||||
|
||||
def invocation(
|
||||
invocation_type: str,
|
||||
title: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
category: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
invocation_type: str, title: Optional[str] = None, tags: Optional[list[str]] = None, category: Optional[str] = None
|
||||
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
|
||||
"""
|
||||
Adds metadata to an invocation.
|
||||
@ -608,12 +569,6 @@ def invocation(
|
||||
cls.UIConfig.tags = tags
|
||||
if category is not None:
|
||||
cls.UIConfig.category = category
|
||||
if version is not None:
|
||||
try:
|
||||
semver.Version.parse(version)
|
||||
except ValueError as e:
|
||||
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
|
||||
cls.UIConfig.version = version
|
||||
|
||||
# Add the invocation type to the pydantic model of the invocation
|
||||
invocation_type_annotation = Literal[invocation_type] # type: ignore
|
||||
@ -625,9 +580,8 @@ def invocation(
|
||||
config=cls.__config__,
|
||||
)
|
||||
cls.__fields__.update({"type": invocation_type_field})
|
||||
# to support 3.9, 3.10 and 3.11, as described in https://docs.python.org/3/howto/annotations.html
|
||||
if annotations := cls.__dict__.get("__annotations__", None):
|
||||
annotations.update({"type": invocation_type_annotation})
|
||||
cls.__annotations__.update({"type": invocation_type_annotation})
|
||||
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
@ -661,10 +615,7 @@ def invocation_output(
|
||||
config=cls.__config__,
|
||||
)
|
||||
cls.__fields__.update({"type": output_type_field})
|
||||
|
||||
# to support 3.9, 3.10 and 3.11, as described in https://docs.python.org/3/howto/annotations.html
|
||||
if annotations := cls.__dict__.get("__annotations__", None):
|
||||
annotations.update({"type": output_type_annotation})
|
||||
cls.__annotations__.update({"type": output_type_annotation})
|
||||
|
||||
return cls
|
||||
|
||||
|
@ -10,9 +10,7 @@ from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
|
||||
|
||||
@invocation(
|
||||
"range", title="Integer Range", tags=["collection", "integer", "range"], category="collections", version="1.0.0"
|
||||
)
|
||||
@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="collections")
|
||||
class RangeInvocation(BaseInvocation):
|
||||
"""Creates a range of numbers from start to stop with step"""
|
||||
|
||||
@ -35,7 +33,6 @@ class RangeInvocation(BaseInvocation):
|
||||
title="Integer Range of Size",
|
||||
tags=["collection", "integer", "size", "range"],
|
||||
category="collections",
|
||||
version="1.0.0",
|
||||
)
|
||||
class RangeOfSizeInvocation(BaseInvocation):
|
||||
"""Creates a range from start to start + size with step"""
|
||||
@ -53,7 +50,6 @@ class RangeOfSizeInvocation(BaseInvocation):
|
||||
title="Random Range",
|
||||
tags=["range", "integer", "random", "collection"],
|
||||
category="collections",
|
||||
version="1.0.0",
|
||||
)
|
||||
class RandomRangeInvocation(BaseInvocation):
|
||||
"""Creates a collection of random numbers"""
|
||||
|
@ -44,7 +44,7 @@ class ConditioningFieldData:
|
||||
# PerpNeg = "perp_neg"
|
||||
|
||||
|
||||
@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning", version="1.0.0")
|
||||
@invocation("compel", title="Prompt", tags=["prompt", "compel"], category="conditioning")
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
|
||||
@ -267,7 +267,6 @@ class SDXLPromptInvocationBase:
|
||||
title="SDXL Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@ -280,8 +279,8 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
crop_left: int = InputField(default=0, description="")
|
||||
target_width: int = InputField(default=1024, description="")
|
||||
target_height: int = InputField(default=1024, description="")
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
@ -352,7 +351,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
title="SDXL Refiner Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@ -405,7 +403,7 @@ class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning", version="1.0.0")
|
||||
@invocation("clip_skip", title="CLIP Skip", tags=["clipskip", "clip", "skip"], category="conditioning")
|
||||
class ClipSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
|
||||
|
@ -95,12 +95,14 @@ class ControlOutput(BaseInvocationOutput):
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.0.0")
|
||||
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||
control_model: ControlNetModelField = InputField(
|
||||
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
|
||||
)
|
||||
@ -127,9 +129,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet", version="1.0.0"
|
||||
)
|
||||
@invocation("image_processor", title="Base Image Processor", tags=["controlnet"], category="controlnet")
|
||||
class ImageProcessorInvocation(BaseInvocation):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
@ -173,7 +173,6 @@ class ImageProcessorInvocation(BaseInvocation):
|
||||
title="Canny Processor",
|
||||
tags=["controlnet", "canny"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
@ -196,7 +195,6 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="HED (softedge) Processor",
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
@ -225,7 +223,6 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Lineart Processor",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
@ -247,7 +244,6 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Lineart Anime Processor",
|
||||
tags=["controlnet", "lineart", "anime"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art anime processing to image"""
|
||||
@ -270,7 +266,6 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Openpose Processor",
|
||||
tags=["controlnet", "openpose", "pose"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Openpose processing to image"""
|
||||
@ -295,7 +290,6 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Midas Depth Processor",
|
||||
tags=["controlnet", "midas"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Midas depth processing to image"""
|
||||
@ -322,7 +316,6 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Normal BAE Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies NormalBae processing to image"""
|
||||
@ -338,9 +331,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.0.0"
|
||||
)
|
||||
@invocation("mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet")
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies MLSD processing to image"""
|
||||
|
||||
@ -361,9 +352,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.0.0"
|
||||
)
|
||||
@invocation("pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet")
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies PIDI processing to image"""
|
||||
|
||||
@ -389,7 +378,6 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Content Shuffle Processor",
|
||||
tags=["controlnet", "contentshuffle"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies content shuffle processing to image"""
|
||||
@ -419,7 +407,6 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Zoe (Depth) Processor",
|
||||
tags=["controlnet", "zoe", "depth"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
@ -435,7 +422,6 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Mediapipe Face Processor",
|
||||
tags=["controlnet", "mediapipe", "face"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
@ -458,7 +444,6 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Leres (Depth) Processor",
|
||||
tags=["controlnet", "leres", "depth"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies leres processing to image"""
|
||||
@ -487,7 +472,6 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Tile Resample Processor",
|
||||
tags=["controlnet", "tile"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Tile resampler processor"""
|
||||
@ -527,7 +511,6 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Segment Anything Processor",
|
||||
tags=["controlnet", "segmentanything"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies segment anything processing to image"""
|
||||
|
@ -10,7 +10,12 @@ from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
|
||||
|
||||
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.0.0")
|
||||
@invocation(
|
||||
"cv_inpaint",
|
||||
title="OpenCV Inpaint",
|
||||
tags=["opencv", "inpaint"],
|
||||
category="inpaint",
|
||||
)
|
||||
class CvInpaintInvocation(BaseInvocation):
|
||||
"""Simple inpaint using opencv."""
|
||||
|
||||
|
@ -16,7 +16,7 @@ from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
|
||||
|
||||
|
||||
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
|
||||
@invocation("show_image", title="Show Image", tags=["image"], category="image")
|
||||
class ShowImageInvocation(BaseInvocation):
|
||||
"""Displays a provided image using the OS image viewer, and passes it forward in the pipeline."""
|
||||
|
||||
@ -36,7 +36,7 @@ class ShowImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.0.0")
|
||||
@invocation("blank_image", title="Blank Image", tags=["image"], category="image")
|
||||
class BlankImageInvocation(BaseInvocation):
|
||||
"""Creates a blank image and forwards it to the pipeline"""
|
||||
|
||||
@ -65,7 +65,7 @@ class BlankImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.0.0")
|
||||
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image")
|
||||
class ImageCropInvocation(BaseInvocation):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
|
||||
@ -98,7 +98,7 @@ class ImageCropInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.0.0")
|
||||
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image")
|
||||
class ImagePasteInvocation(BaseInvocation):
|
||||
"""Pastes an image into another image."""
|
||||
|
||||
@ -146,7 +146,7 @@ class ImagePasteInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.0.0")
|
||||
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image")
|
||||
class MaskFromAlphaInvocation(BaseInvocation):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
|
||||
@ -177,7 +177,7 @@ class MaskFromAlphaInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.0.0")
|
||||
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image")
|
||||
class ImageMultiplyInvocation(BaseInvocation):
|
||||
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
@ -210,7 +210,7 @@ class ImageMultiplyInvocation(BaseInvocation):
|
||||
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||
|
||||
|
||||
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.0.0")
|
||||
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image")
|
||||
class ImageChannelInvocation(BaseInvocation):
|
||||
"""Gets a channel from an image."""
|
||||
|
||||
@ -242,7 +242,7 @@ class ImageChannelInvocation(BaseInvocation):
|
||||
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||
|
||||
|
||||
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.0.0")
|
||||
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image")
|
||||
class ImageConvertInvocation(BaseInvocation):
|
||||
"""Converts an image to a different mode."""
|
||||
|
||||
@ -271,7 +271,7 @@ class ImageConvertInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.0.0")
|
||||
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image")
|
||||
class ImageBlurInvocation(BaseInvocation):
|
||||
"""Blurs an image"""
|
||||
|
||||
@ -325,7 +325,7 @@ PIL_RESAMPLING_MAP = {
|
||||
}
|
||||
|
||||
|
||||
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.0.0")
|
||||
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image")
|
||||
class ImageResizeInvocation(BaseInvocation):
|
||||
"""Resizes an image to specific dimensions"""
|
||||
|
||||
@ -365,7 +365,7 @@ class ImageResizeInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.0.0")
|
||||
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image")
|
||||
class ImageScaleInvocation(BaseInvocation):
|
||||
"""Scales an image by a factor"""
|
||||
|
||||
@ -406,7 +406,7 @@ class ImageScaleInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.0.0")
|
||||
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image")
|
||||
class ImageLerpInvocation(BaseInvocation):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
|
||||
@ -439,7 +439,7 @@ class ImageLerpInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.0.0")
|
||||
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image")
|
||||
class ImageInverseLerpInvocation(BaseInvocation):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
|
||||
@ -472,7 +472,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.0.0")
|
||||
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image")
|
||||
class ImageNSFWBlurInvocation(BaseInvocation):
|
||||
"""Add blur to NSFW-flagged images"""
|
||||
|
||||
@ -517,9 +517,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
|
||||
return caution.resize((caution.width // 2, caution.height // 2))
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image", version="1.0.0"
|
||||
)
|
||||
@invocation("img_watermark", title="Add Invisible Watermark", tags=["image", "watermark"], category="image")
|
||||
class ImageWatermarkInvocation(BaseInvocation):
|
||||
"""Add an invisible watermark to an image"""
|
||||
|
||||
@ -550,7 +548,7 @@ class ImageWatermarkInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.0.0")
|
||||
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image")
|
||||
class MaskEdgeInvocation(BaseInvocation):
|
||||
"""Applies an edge mask to an image"""
|
||||
|
||||
@ -563,7 +561,7 @@ class MaskEdgeInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask = context.services.images.get_pil_image(self.image.image_name).convert("L")
|
||||
mask = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
npimg = numpy.asarray(mask, dtype=numpy.uint8)
|
||||
npgradient = numpy.uint8(255 * (1.0 - numpy.floor(numpy.abs(0.5 - numpy.float32(npimg) / 255.0) * 2.0)))
|
||||
@ -595,9 +593,7 @@ class MaskEdgeInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image", version="1.0.0"
|
||||
)
|
||||
@invocation("mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], category="image")
|
||||
class MaskCombineInvocation(BaseInvocation):
|
||||
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
|
||||
|
||||
@ -627,7 +623,7 @@ class MaskCombineInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.0.0")
|
||||
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image")
|
||||
class ColorCorrectInvocation(BaseInvocation):
|
||||
"""
|
||||
Shifts the colors of a target image to match the reference image, optionally
|
||||
@ -700,13 +696,8 @@ class ColorCorrectInvocation(BaseInvocation):
|
||||
# Blur the mask out (into init image) by specified amount
|
||||
if self.mask_blur_radius > 0:
|
||||
nm = numpy.asarray(pil_init_mask, dtype=numpy.uint8)
|
||||
inverted_nm = 255 - nm
|
||||
dilation_size = int(round(self.mask_blur_radius) + 20)
|
||||
dilating_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation_size, dilation_size))
|
||||
inverted_dilated_nm = cv2.dilate(inverted_nm, dilating_kernel)
|
||||
dilated_nm = 255 - inverted_dilated_nm
|
||||
nmd = cv2.erode(
|
||||
dilated_nm,
|
||||
nm,
|
||||
kernel=numpy.ones((3, 3), dtype=numpy.uint8),
|
||||
iterations=int(self.mask_blur_radius / 2),
|
||||
)
|
||||
@ -737,7 +728,7 @@ class ColorCorrectInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.0.0")
|
||||
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image")
|
||||
class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Hue of an image."""
|
||||
|
||||
@ -778,95 +769,38 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
COLOR_CHANNELS = Literal[
|
||||
"Red (RGBA)",
|
||||
"Green (RGBA)",
|
||||
"Blue (RGBA)",
|
||||
"Alpha (RGBA)",
|
||||
"Cyan (CMYK)",
|
||||
"Magenta (CMYK)",
|
||||
"Yellow (CMYK)",
|
||||
"Black (CMYK)",
|
||||
"Hue (HSV)",
|
||||
"Saturation (HSV)",
|
||||
"Value (HSV)",
|
||||
"Luminosity (LAB)",
|
||||
"A (LAB)",
|
||||
"B (LAB)",
|
||||
"Y (YCbCr)",
|
||||
"Cb (YCbCr)",
|
||||
"Cr (YCbCr)",
|
||||
]
|
||||
|
||||
CHANNEL_FORMATS = {
|
||||
"Red (RGBA)": ("RGBA", 0),
|
||||
"Green (RGBA)": ("RGBA", 1),
|
||||
"Blue (RGBA)": ("RGBA", 2),
|
||||
"Alpha (RGBA)": ("RGBA", 3),
|
||||
"Cyan (CMYK)": ("CMYK", 0),
|
||||
"Magenta (CMYK)": ("CMYK", 1),
|
||||
"Yellow (CMYK)": ("CMYK", 2),
|
||||
"Black (CMYK)": ("CMYK", 3),
|
||||
"Hue (HSV)": ("HSV", 0),
|
||||
"Saturation (HSV)": ("HSV", 1),
|
||||
"Value (HSV)": ("HSV", 2),
|
||||
"Luminosity (LAB)": ("LAB", 0),
|
||||
"A (LAB)": ("LAB", 1),
|
||||
"B (LAB)": ("LAB", 2),
|
||||
"Y (YCbCr)": ("YCbCr", 0),
|
||||
"Cb (YCbCr)": ("YCbCr", 1),
|
||||
"Cr (YCbCr)": ("YCbCr", 2),
|
||||
}
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_channel_offset",
|
||||
title="Offset Image Channel",
|
||||
tags=[
|
||||
"image",
|
||||
"offset",
|
||||
"red",
|
||||
"green",
|
||||
"blue",
|
||||
"alpha",
|
||||
"cyan",
|
||||
"magenta",
|
||||
"yellow",
|
||||
"black",
|
||||
"hue",
|
||||
"saturation",
|
||||
"luminosity",
|
||||
"value",
|
||||
],
|
||||
"img_luminosity_adjust",
|
||||
title="Adjust Image Luminosity",
|
||||
tags=["image", "luminosity", "hsl"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageChannelOffsetInvocation(BaseInvocation):
|
||||
"""Add or subtract a value from a specific color channel of an image."""
|
||||
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Luminosity (Value) of an image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
channel: COLOR_CHANNELS = InputField(description="Which channel to adjust")
|
||||
offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by")
|
||||
luminosity: float = InputField(
|
||||
default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)"
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
# extract the channel and mode from the input and reference tuple
|
||||
mode = CHANNEL_FORMATS[self.channel][0]
|
||||
channel_number = CHANNEL_FORMATS[self.channel][1]
|
||||
# Convert PIL image to OpenCV format (numpy array), note color channel
|
||||
# ordering is changed from RGB to BGR
|
||||
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
|
||||
|
||||
# Convert PIL image to new format
|
||||
converted_image = numpy.array(pil_image.convert(mode)).astype(int)
|
||||
image_channel = converted_image[:, :, channel_number]
|
||||
# Convert image to HSV color space
|
||||
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||
|
||||
# Adjust the value, clipping to 0..255
|
||||
image_channel = numpy.clip(image_channel + self.offset, 0, 255)
|
||||
# Adjust the luminosity (value)
|
||||
hsv_image[:, :, 2] = numpy.clip(hsv_image[:, :, 2] * self.luminosity, 0, 255)
|
||||
|
||||
# Put the channel back into the image
|
||||
converted_image[:, :, channel_number] = image_channel
|
||||
# Convert image back to BGR color space
|
||||
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
|
||||
|
||||
# Convert back to RGBA format and output
|
||||
pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA")
|
||||
# Convert back to PIL format and to original color mode
|
||||
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
@ -888,60 +822,35 @@ class ImageChannelOffsetInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"img_channel_multiply",
|
||||
title="Multiply Image Channel",
|
||||
tags=[
|
||||
"image",
|
||||
"invert",
|
||||
"scale",
|
||||
"multiply",
|
||||
"red",
|
||||
"green",
|
||||
"blue",
|
||||
"alpha",
|
||||
"cyan",
|
||||
"magenta",
|
||||
"yellow",
|
||||
"black",
|
||||
"hue",
|
||||
"saturation",
|
||||
"luminosity",
|
||||
"value",
|
||||
],
|
||||
"img_saturation_adjust",
|
||||
title="Adjust Image Saturation",
|
||||
tags=["image", "saturation", "hsl"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageChannelMultiplyInvocation(BaseInvocation):
|
||||
"""Scale a specific color channel of an image."""
|
||||
class ImageSaturationAdjustmentInvocation(BaseInvocation):
|
||||
"""Adjusts the Saturation of an image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
channel: COLOR_CHANNELS = InputField(description="Which channel to adjust")
|
||||
scale: float = InputField(default=1.0, ge=0.0, description="The amount to scale the channel by.")
|
||||
invert_channel: bool = InputField(default=False, description="Invert the channel after scaling")
|
||||
saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
# extract the channel and mode from the input and reference tuple
|
||||
mode = CHANNEL_FORMATS[self.channel][0]
|
||||
channel_number = CHANNEL_FORMATS[self.channel][1]
|
||||
# Convert PIL image to OpenCV format (numpy array), note color channel
|
||||
# ordering is changed from RGB to BGR
|
||||
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
|
||||
|
||||
# Convert PIL image to new format
|
||||
converted_image = numpy.array(pil_image.convert(mode)).astype(float)
|
||||
image_channel = converted_image[:, :, channel_number]
|
||||
# Convert image to HSV color space
|
||||
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||
|
||||
# Adjust the value, clipping to 0..255
|
||||
image_channel = numpy.clip(image_channel * self.scale, 0, 255)
|
||||
# Adjust the saturation
|
||||
hsv_image[:, :, 1] = numpy.clip(hsv_image[:, :, 1] * self.saturation, 0, 255)
|
||||
|
||||
# Invert the channel if requested
|
||||
if self.invert_channel:
|
||||
image_channel = 255 - image_channel
|
||||
# Convert image back to BGR color space
|
||||
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
|
||||
|
||||
# Put the channel back into the image
|
||||
converted_image[:, :, channel_number] = image_channel
|
||||
|
||||
# Convert back to RGBA format and output
|
||||
pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA")
|
||||
# Convert back to PIL format and to original color mode
|
||||
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
|
@ -8,17 +8,19 @@ from PIL import Image, ImageOps
|
||||
|
||||
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.lama import LaMA
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||
|
||||
|
||||
def infill_methods() -> list[str]:
|
||||
methods = ["tile", "solid", "lama", "cv2"]
|
||||
methods = [
|
||||
"tile",
|
||||
"solid",
|
||||
"lama",
|
||||
]
|
||||
if PatchMatch.patchmatch_available():
|
||||
methods.insert(0, "patchmatch")
|
||||
return methods
|
||||
@ -47,10 +49,6 @@ def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||
return im_patched
|
||||
|
||||
|
||||
def infill_cv2(im: Image.Image) -> Image.Image:
|
||||
return cv2_inpaint(im)
|
||||
|
||||
|
||||
def get_tile_images(image: np.ndarray, width=8, height=8):
|
||||
_nrows, _ncols, depth = image.shape
|
||||
_strides = image.strides
|
||||
@ -118,7 +116,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
||||
return si
|
||||
|
||||
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint")
|
||||
class InfillColorInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
@ -153,7 +151,7 @@ class InfillColorInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint")
|
||||
class InfillTileInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
@ -189,42 +187,20 @@ class InfillTileInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0"
|
||||
)
|
||||
@invocation("infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint")
|
||||
class InfillPatchMatchInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name).convert("RGBA")
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
infill_image = image.copy()
|
||||
width = int(image.width / self.downscale)
|
||||
height = int(image.height / self.downscale)
|
||||
infill_image = infill_image.resize(
|
||||
(width, height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
if PatchMatch.patchmatch_available():
|
||||
infilled = infill_patchmatch(infill_image)
|
||||
infilled = infill_patchmatch(image.copy())
|
||||
else:
|
||||
raise ValueError("PatchMatch is not available on this system")
|
||||
|
||||
infilled = infilled.resize(
|
||||
(image.width, image.height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
infilled.paste(image, (0, 0), mask=image.split()[-1])
|
||||
# image.paste(infilled, (0, 0), mask=image.split()[-1])
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
@ -242,7 +218,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.0.0")
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint")
|
||||
class LaMaInfillInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
@ -267,30 +243,3 @@ class LaMaInfillInvocation(BaseInvocation):
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint")
|
||||
class CV2InfillInvocation(BaseInvocation):
|
||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
infilled = infill_cv2(image.copy())
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
@ -74,7 +74,7 @@ class SchedulerOutput(BaseInvocationOutput):
|
||||
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
||||
|
||||
|
||||
@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents", version="1.0.0")
|
||||
@invocation("scheduler", title="Scheduler", tags=["scheduler"], category="latents")
|
||||
class SchedulerInvocation(BaseInvocation):
|
||||
"""Selects a scheduler."""
|
||||
|
||||
@ -86,9 +86,7 @@ class SchedulerInvocation(BaseInvocation):
|
||||
return SchedulerOutput(scheduler=self.scheduler)
|
||||
|
||||
|
||||
@invocation(
|
||||
"create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents", version="1.0.0"
|
||||
)
|
||||
@invocation("create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], category="latents")
|
||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
|
||||
@ -188,7 +186,6 @@ def get_scheduler(
|
||||
title="Denoise Latents",
|
||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Denoises noisy latents to decodable images"""
|
||||
@ -211,14 +208,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection, title="UNet", ui_order=2)
|
||||
control: Union[ControlField, list[ControlField]] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.control,
|
||||
input=Input.Connection,
|
||||
ui_order=5,
|
||||
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
|
||||
)
|
||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6
|
||||
default=None,
|
||||
description=FieldDescriptions.mask,
|
||||
)
|
||||
|
||||
@validator("cfg_scale")
|
||||
@ -291,30 +286,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
unet,
|
||||
scheduler,
|
||||
) -> StableDiffusionGeneratorPipeline:
|
||||
# TODO:
|
||||
# configure_model_padding(
|
||||
# unet,
|
||||
# self.seamless,
|
||||
# self.seamless_axes,
|
||||
# )
|
||||
|
||||
class FakeVae:
|
||||
class FakeVaeConfig:
|
||||
def __init__(self):
|
||||
self.block_out_channels = [0]
|
||||
|
||||
def __init__(self):
|
||||
self.config = FakeVae.FakeVaeConfig()
|
||||
|
||||
return StableDiffusionGeneratorPipeline(
|
||||
vae=FakeVae(), # TODO: oh...
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
|
||||
def prep_control_data(
|
||||
@ -322,7 +296,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context: InvocationContext,
|
||||
# really only need model for dtype and device
|
||||
model: StableDiffusionGeneratorPipeline,
|
||||
control_input: Union[ControlField, List[ControlField]],
|
||||
control_input: List[ControlField],
|
||||
latents_shape: List[int],
|
||||
exit_stack: ExitStack,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
@ -547,9 +521,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
||||
|
||||
|
||||
@invocation(
|
||||
"l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", version="1.0.0"
|
||||
)
|
||||
@invocation("l2i", title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents")
|
||||
class LatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
@ -646,7 +618,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||
|
||||
|
||||
@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
|
||||
@invocation("lresize", title="Resize Latents", tags=["latents", "resize"], category="latents")
|
||||
class ResizeLatentsInvocation(BaseInvocation):
|
||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||
|
||||
@ -690,7 +662,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents", version="1.0.0")
|
||||
@invocation("lscale", title="Scale Latents", tags=["latents", "resize"], category="latents")
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
"""Scales latents by a given factor."""
|
||||
|
||||
@ -726,9 +698,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@invocation(
|
||||
"i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", version="1.0.0"
|
||||
)
|
||||
@invocation("i2l", title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents")
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
|
||||
@ -808,7 +778,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
|
||||
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents", version="1.0.0")
|
||||
@invocation("lblend", title="Blend Latents", tags=["latents", "blend"], category="latents")
|
||||
class BlendLatentsInvocation(BaseInvocation):
|
||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||
|
||||
|
@ -7,7 +7,7 @@ from invokeai.app.invocations.primitives import IntegerOutput
|
||||
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
|
||||
|
||||
|
||||
@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0")
|
||||
@invocation("add", title="Add Integers", tags=["math", "add"], category="math")
|
||||
class AddInvocation(BaseInvocation):
|
||||
"""Adds two numbers"""
|
||||
|
||||
@ -18,7 +18,7 @@ class AddInvocation(BaseInvocation):
|
||||
return IntegerOutput(value=self.a + self.b)
|
||||
|
||||
|
||||
@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math", version="1.0.0")
|
||||
@invocation("sub", title="Subtract Integers", tags=["math", "subtract"], category="math")
|
||||
class SubtractInvocation(BaseInvocation):
|
||||
"""Subtracts two numbers"""
|
||||
|
||||
@ -29,7 +29,7 @@ class SubtractInvocation(BaseInvocation):
|
||||
return IntegerOutput(value=self.a - self.b)
|
||||
|
||||
|
||||
@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math", version="1.0.0")
|
||||
@invocation("mul", title="Multiply Integers", tags=["math", "multiply"], category="math")
|
||||
class MultiplyInvocation(BaseInvocation):
|
||||
"""Multiplies two numbers"""
|
||||
|
||||
@ -40,7 +40,7 @@ class MultiplyInvocation(BaseInvocation):
|
||||
return IntegerOutput(value=self.a * self.b)
|
||||
|
||||
|
||||
@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math", version="1.0.0")
|
||||
@invocation("div", title="Divide Integers", tags=["math", "divide"], category="math")
|
||||
class DivideInvocation(BaseInvocation):
|
||||
"""Divides two numbers"""
|
||||
|
||||
@ -51,7 +51,7 @@ class DivideInvocation(BaseInvocation):
|
||||
return IntegerOutput(value=int(self.a / self.b))
|
||||
|
||||
|
||||
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math", version="1.0.0")
|
||||
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math")
|
||||
class RandomIntInvocation(BaseInvocation):
|
||||
"""Outputs a single random integer."""
|
||||
|
||||
|
@ -72,10 +72,10 @@ class CoreMetadata(BaseModelExcludeNull):
|
||||
)
|
||||
refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner")
|
||||
refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner")
|
||||
refiner_positive_aesthetic_score: Optional[float] = Field(
|
||||
refiner_positive_aesthetic_store: Optional[float] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
)
|
||||
refiner_negative_aesthetic_score: Optional[float] = Field(
|
||||
refiner_negative_aesthetic_store: Optional[float] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
)
|
||||
refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising")
|
||||
@ -98,9 +98,7 @@ class MetadataAccumulatorOutput(BaseInvocationOutput):
|
||||
metadata: CoreMetadata = OutputField(description="The core metadata for the image")
|
||||
|
||||
|
||||
@invocation(
|
||||
"metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata", version="1.0.0"
|
||||
)
|
||||
@invocation("metadata_accumulator", title="Metadata Accumulator", tags=["metadata"], category="metadata")
|
||||
class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
"""Outputs a Core Metadata Object"""
|
||||
|
||||
@ -162,11 +160,11 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
default=None,
|
||||
description="The scheduler used for the refiner",
|
||||
)
|
||||
refiner_positive_aesthetic_score: Optional[float] = InputField(
|
||||
refiner_positive_aesthetic_store: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The aesthetic score used for the refiner",
|
||||
)
|
||||
refiner_negative_aesthetic_score: Optional[float] = InputField(
|
||||
refiner_negative_aesthetic_store: Optional[float] = InputField(
|
||||
default=None,
|
||||
description="The aesthetic score used for the refiner",
|
||||
)
|
||||
|
@ -73,7 +73,7 @@ class LoRAModelField(BaseModel):
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model", version="1.0.0")
|
||||
@invocation("main_model_loader", title="Main Model", tags=["model"], category="model")
|
||||
class MainModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
@ -173,7 +173,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
|
||||
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0")
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model")
|
||||
class LoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
@ -244,19 +244,19 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||
clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2")
|
||||
|
||||
|
||||
@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model", version="1.0.0")
|
||||
@invocation("sdxl_lora_loader", title="SDXL LoRA", tags=["lora", "model"], category="model")
|
||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet"
|
||||
weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
unet: Optional[UNetField] = Field(
|
||||
default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET"
|
||||
)
|
||||
clip: Optional[ClipField] = InputField(
|
||||
clip: Optional[ClipField] = Field(
|
||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1"
|
||||
)
|
||||
clip2: Optional[ClipField] = InputField(
|
||||
clip2: Optional[ClipField] = Field(
|
||||
default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2"
|
||||
)
|
||||
|
||||
@ -338,7 +338,7 @@ class VaeLoaderOutput(BaseInvocationOutput):
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model")
|
||||
class VaeLoaderInvocation(BaseInvocation):
|
||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||
|
||||
@ -376,7 +376,7 @@ class SeamlessModeOutput(BaseInvocationOutput):
|
||||
vae: Optional[VaeField] = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model", version="1.0.0")
|
||||
@invocation("seamless", title="Seamless", tags=["seamless", "model"], category="model")
|
||||
class SeamlessModeInvocation(BaseInvocation):
|
||||
"""Applies the seamless transformation to the Model UNet and VAE."""
|
||||
|
||||
|
@ -78,7 +78,7 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
|
||||
)
|
||||
|
||||
|
||||
@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents", version="1.0.0")
|
||||
@invocation("noise", title="Noise", tags=["latents", "noise"], category="latents")
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
|
@ -56,7 +56,7 @@ ORT_TO_NP_TYPE = {
|
||||
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
||||
|
||||
|
||||
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
|
||||
@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning")
|
||||
class ONNXPromptInvocation(BaseInvocation):
|
||||
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
@ -143,7 +143,6 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
title="ONNX Text to Latents",
|
||||
tags=["latents", "inference", "txt2img", "onnx"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from conditionings."""
|
||||
@ -320,7 +319,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
title="ONNX Latents to Image",
|
||||
tags=["latents", "image", "vae", "onnx"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ONNXLatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
@ -405,7 +403,7 @@ class OnnxModelField(BaseModel):
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
|
||||
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
|
||||
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model")
|
||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a main model, outputting its submodels."""
|
||||
|
||||
|
@ -45,7 +45,7 @@ from invokeai.app.invocations.primitives import FloatCollectionOutput
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
|
||||
|
||||
@invocation("float_range", title="Float Range", tags=["math", "range"], category="math", version="1.0.0")
|
||||
@invocation("float_range", title="Float Range", tags=["math", "range"], category="math")
|
||||
class FloatLinearRangeInvocation(BaseInvocation):
|
||||
"""Creates a range"""
|
||||
|
||||
@ -96,7 +96,7 @@ EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
||||
|
||||
|
||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||
@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step", version="1.0.0")
|
||||
@invocation("step_param_easing", title="Step Param Easing", tags=["step", "easing"], category="step")
|
||||
class StepParamEasingInvocation(BaseInvocation):
|
||||
"""Experimental per-step parameter easing for denoising steps"""
|
||||
|
||||
|
@ -14,6 +14,7 @@ from .baseinvocation import (
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@ -39,14 +40,10 @@ class BooleanOutput(BaseInvocationOutput):
|
||||
class BooleanCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of booleans"""
|
||||
|
||||
collection: list[bool] = OutputField(
|
||||
description="The output boolean collection",
|
||||
)
|
||||
collection: list[bool] = OutputField(description="The output boolean collection", ui_type=UIType.BooleanCollection)
|
||||
|
||||
|
||||
@invocation(
|
||||
"boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives", version="1.0.0"
|
||||
)
|
||||
@invocation("boolean", title="Boolean Primitive", tags=["primitives", "boolean"], category="primitives")
|
||||
class BooleanInvocation(BaseInvocation):
|
||||
"""A boolean primitive value"""
|
||||
|
||||
@ -61,12 +58,13 @@ class BooleanInvocation(BaseInvocation):
|
||||
title="Boolean Collection Primitive",
|
||||
tags=["primitives", "boolean", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class BooleanCollectionInvocation(BaseInvocation):
|
||||
"""A collection of boolean primitive values"""
|
||||
|
||||
collection: list[bool] = InputField(default_factory=list, description="The collection of boolean values")
|
||||
collection: list[bool] = InputField(
|
||||
default_factory=list, description="The collection of boolean values", ui_type=UIType.BooleanCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
||||
return BooleanCollectionOutput(collection=self.collection)
|
||||
@ -88,14 +86,10 @@ class IntegerOutput(BaseInvocationOutput):
|
||||
class IntegerCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of integers"""
|
||||
|
||||
collection: list[int] = OutputField(
|
||||
description="The int collection",
|
||||
)
|
||||
collection: list[int] = OutputField(description="The int collection", ui_type=UIType.IntegerCollection)
|
||||
|
||||
|
||||
@invocation(
|
||||
"integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives", version="1.0.0"
|
||||
)
|
||||
@invocation("integer", title="Integer Primitive", tags=["primitives", "integer"], category="primitives")
|
||||
class IntegerInvocation(BaseInvocation):
|
||||
"""An integer primitive value"""
|
||||
|
||||
@ -110,12 +104,13 @@ class IntegerInvocation(BaseInvocation):
|
||||
title="Integer Collection Primitive",
|
||||
tags=["primitives", "integer", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class IntegerCollectionInvocation(BaseInvocation):
|
||||
"""A collection of integer primitive values"""
|
||||
|
||||
collection: list[int] = InputField(default_factory=list, description="The collection of integer values")
|
||||
collection: list[int] = InputField(
|
||||
default_factory=list, description="The collection of integer values", ui_type=UIType.IntegerCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
return IntegerCollectionOutput(collection=self.collection)
|
||||
@ -137,12 +132,10 @@ class FloatOutput(BaseInvocationOutput):
|
||||
class FloatCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of floats"""
|
||||
|
||||
collection: list[float] = OutputField(
|
||||
description="The float collection",
|
||||
)
|
||||
collection: list[float] = OutputField(description="The float collection", ui_type=UIType.FloatCollection)
|
||||
|
||||
|
||||
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives", version="1.0.0")
|
||||
@invocation("float", title="Float Primitive", tags=["primitives", "float"], category="primitives")
|
||||
class FloatInvocation(BaseInvocation):
|
||||
"""A float primitive value"""
|
||||
|
||||
@ -157,12 +150,13 @@ class FloatInvocation(BaseInvocation):
|
||||
title="Float Collection Primitive",
|
||||
tags=["primitives", "float", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FloatCollectionInvocation(BaseInvocation):
|
||||
"""A collection of float primitive values"""
|
||||
|
||||
collection: list[float] = InputField(default_factory=list, description="The collection of float values")
|
||||
collection: list[float] = InputField(
|
||||
default_factory=list, description="The collection of float values", ui_type=UIType.FloatCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
return FloatCollectionOutput(collection=self.collection)
|
||||
@ -184,12 +178,10 @@ class StringOutput(BaseInvocationOutput):
|
||||
class StringCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of strings"""
|
||||
|
||||
collection: list[str] = OutputField(
|
||||
description="The output strings",
|
||||
)
|
||||
collection: list[str] = OutputField(description="The output strings", ui_type=UIType.StringCollection)
|
||||
|
||||
|
||||
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives", version="1.0.0")
|
||||
@invocation("string", title="String Primitive", tags=["primitives", "string"], category="primitives")
|
||||
class StringInvocation(BaseInvocation):
|
||||
"""A string primitive value"""
|
||||
|
||||
@ -204,12 +196,13 @@ class StringInvocation(BaseInvocation):
|
||||
title="String Collection Primitive",
|
||||
tags=["primitives", "string", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class StringCollectionInvocation(BaseInvocation):
|
||||
"""A collection of string primitive values"""
|
||||
|
||||
collection: list[str] = InputField(default_factory=list, description="The collection of string values")
|
||||
collection: list[str] = InputField(
|
||||
default_factory=list, description="The collection of string values", ui_type=UIType.StringCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
return StringCollectionOutput(collection=self.collection)
|
||||
@ -239,12 +232,10 @@ class ImageOutput(BaseInvocationOutput):
|
||||
class ImageCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of images"""
|
||||
|
||||
collection: list[ImageField] = OutputField(
|
||||
description="The output images",
|
||||
)
|
||||
collection: list[ImageField] = OutputField(description="The output images", ui_type=UIType.ImageCollection)
|
||||
|
||||
|
||||
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0")
|
||||
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives")
|
||||
class ImageInvocation(BaseInvocation):
|
||||
"""An image primitive value"""
|
||||
|
||||
@ -265,12 +256,13 @@ class ImageInvocation(BaseInvocation):
|
||||
title="Image Collection Primitive",
|
||||
tags=["primitives", "image", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageCollectionInvocation(BaseInvocation):
|
||||
"""A collection of image primitive values"""
|
||||
|
||||
collection: list[ImageField] = InputField(description="The collection of image values")
|
||||
collection: list[ImageField] = InputField(
|
||||
default_factory=list, description="The collection of image values", ui_type=UIType.ImageCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||
return ImageCollectionOutput(collection=self.collection)
|
||||
@ -324,12 +316,11 @@ class LatentsCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
collection: list[LatentsField] = OutputField(
|
||||
description=FieldDescriptions.latents,
|
||||
ui_type=UIType.LatentsCollection,
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.0"
|
||||
)
|
||||
@invocation("latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives")
|
||||
class LatentsInvocation(BaseInvocation):
|
||||
"""A latents tensor primitive value"""
|
||||
|
||||
@ -346,13 +337,12 @@ class LatentsInvocation(BaseInvocation):
|
||||
title="Latents Collection Primitive",
|
||||
tags=["primitives", "latents", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LatentsCollectionInvocation(BaseInvocation):
|
||||
"""A collection of latents tensor primitive values"""
|
||||
|
||||
collection: list[LatentsField] = InputField(
|
||||
description="The collection of latents tensors",
|
||||
description="The collection of latents tensors", ui_type=UIType.LatentsCollection
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
|
||||
@ -395,12 +385,10 @@ class ColorOutput(BaseInvocationOutput):
|
||||
class ColorCollectionOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a collection of colors"""
|
||||
|
||||
collection: list[ColorField] = OutputField(
|
||||
description="The output colors",
|
||||
)
|
||||
collection: list[ColorField] = OutputField(description="The output colors", ui_type=UIType.ColorCollection)
|
||||
|
||||
|
||||
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives", version="1.0.0")
|
||||
@invocation("color", title="Color Primitive", tags=["primitives", "color"], category="primitives")
|
||||
class ColorInvocation(BaseInvocation):
|
||||
"""A color primitive value"""
|
||||
|
||||
@ -434,6 +422,7 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
collection: list[ConditioningField] = OutputField(
|
||||
description="The output conditioning tensors",
|
||||
ui_type=UIType.ConditioningCollection,
|
||||
)
|
||||
|
||||
|
||||
@ -442,7 +431,6 @@ class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||
title="Conditioning Primitive",
|
||||
tags=["primitives", "conditioning"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ConditioningInvocation(BaseInvocation):
|
||||
"""A conditioning tensor primitive value"""
|
||||
@ -458,7 +446,6 @@ class ConditioningInvocation(BaseInvocation):
|
||||
title="Conditioning Collection Primitive",
|
||||
tags=["primitives", "conditioning", "collection"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ConditioningCollectionInvocation(BaseInvocation):
|
||||
"""A collection of conditioning tensor primitive values"""
|
||||
@ -466,6 +453,7 @@ class ConditioningCollectionInvocation(BaseInvocation):
|
||||
collection: list[ConditioningField] = InputField(
|
||||
default_factory=list,
|
||||
description="The collection of conditioning tensors",
|
||||
ui_type=UIType.ConditioningCollection,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
|
||||
|
@ -10,7 +10,7 @@ from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
|
||||
|
||||
|
||||
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt", version="1.0.0")
|
||||
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt")
|
||||
class DynamicPromptInvocation(BaseInvocation):
|
||||
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
|
||||
|
||||
@ -29,7 +29,7 @@ class DynamicPromptInvocation(BaseInvocation):
|
||||
return StringCollectionOutput(collection=prompts)
|
||||
|
||||
|
||||
@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt", version="1.0.0")
|
||||
@invocation("prompt_from_file", title="Prompts from File", tags=["prompt", "file"], category="prompt")
|
||||
class PromptsFromFileInvocation(BaseInvocation):
|
||||
"""Loads prompts from a text file"""
|
||||
|
||||
|
@ -33,7 +33,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.0")
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
@ -119,7 +119,6 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
title="SDXL Refiner Model",
|
||||
tags=["model", "sdxl", "refiner"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
|
@ -23,7 +23,7 @@ ESRGAN_MODELS = Literal[
|
||||
]
|
||||
|
||||
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.0.0")
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan")
|
||||
class ESRGANInvocation(BaseInvocation):
|
||||
"""Upscales an image using RealESRGAN."""
|
||||
|
||||
|
@ -1,215 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import product
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.batch_manager_storage import (
|
||||
Batch,
|
||||
BatchProcess,
|
||||
BatchProcessStorageBase,
|
||||
BatchSession,
|
||||
BatchSessionChanges,
|
||||
BatchSessionNotFoundException,
|
||||
)
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
|
||||
class BatchProcessResponse(BaseModel):
|
||||
batch_id: str = Field(description="ID for the batch")
|
||||
session_ids: list[str] = Field(description="List of session IDs created for this batch")
|
||||
|
||||
|
||||
class BatchManagerBase(ABC):
|
||||
@abstractmethod
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
"""Starts the BatchManager service"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
|
||||
"""Creates a batch process"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run_batch_process(self, batch_id: str) -> None:
|
||||
"""Runs a batch process"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_batch_process(self, batch_process_id: str) -> None:
|
||||
"""Cancels a batch process"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_batch(self, batch_id: str) -> BatchProcessResponse:
|
||||
"""Gets a batch process"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_batch_processes(self) -> list[BatchProcessResponse]:
|
||||
"""Gets all batch processes"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
|
||||
"""Gets all incomplete batch processes"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_sessions(self, batch_id: str) -> list[BatchSession]:
|
||||
"""Gets the sessions associated with a batch"""
|
||||
pass
|
||||
|
||||
|
||||
class BatchManager(BatchManagerBase):
|
||||
"""Responsible for managing currently running and scheduled batch jobs"""
|
||||
|
||||
__invoker: Invoker
|
||||
__batch_process_storage: BatchProcessStorageBase
|
||||
|
||||
def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
|
||||
super().__init__()
|
||||
self.__batch_process_storage = batch_process_storage
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
|
||||
|
||||
async def on_event(self, event: Event):
|
||||
event_name = event[1]["event"]
|
||||
|
||||
match event_name:
|
||||
case "graph_execution_state_complete":
|
||||
await self._process(event, False)
|
||||
case "invocation_error":
|
||||
await self._process(event, True)
|
||||
|
||||
return event
|
||||
|
||||
async def _process(self, event: Event, err: bool) -> None:
|
||||
data = event[1]["data"]
|
||||
try:
|
||||
batch_session = self.__batch_process_storage.get_session_by_session_id(data["graph_execution_state_id"])
|
||||
except BatchSessionNotFoundException:
|
||||
return None
|
||||
changes = BatchSessionChanges(state="error" if err else "completed")
|
||||
batch_session = self.__batch_process_storage.update_session_state(
|
||||
batch_session.batch_id,
|
||||
batch_session.session_id,
|
||||
changes,
|
||||
)
|
||||
sessions = self.get_sessions(batch_session.batch_id)
|
||||
batch_process = self.__batch_process_storage.get(batch_session.batch_id)
|
||||
if not batch_process.canceled:
|
||||
self.run_batch_process(batch_process.batch_id)
|
||||
|
||||
def _create_graph_execution_state(
|
||||
self, batch_process: BatchProcess, batch_indices: tuple[int, ...]
|
||||
) -> GraphExecutionState:
|
||||
graph = batch_process.graph.copy(deep=True)
|
||||
batch = batch_process.batch
|
||||
for index, bdl in enumerate(batch.data):
|
||||
for bd in bdl:
|
||||
node = graph.get_node(bd.node_path)
|
||||
if node is None:
|
||||
continue
|
||||
batch_index = batch_indices[index]
|
||||
datum = bd.items[batch_index]
|
||||
key = bd.field_name
|
||||
node.__dict__[key] = datum
|
||||
graph.update_node(bd.node_path, node)
|
||||
|
||||
return GraphExecutionState(graph=graph)
|
||||
|
||||
def run_batch_process(self, batch_id: str) -> None:
|
||||
self.__batch_process_storage.start(batch_id)
|
||||
batch_process = self.__batch_process_storage.get(batch_id)
|
||||
next_batch_index = self._get_batch_index_tuple(batch_process)
|
||||
if next_batch_index is None:
|
||||
# finished with current run
|
||||
if batch_process.current_run >= (batch_process.batch.runs - 1):
|
||||
# finished with all runs
|
||||
return
|
||||
batch_process.current_batch_index = 0
|
||||
batch_process.current_run += 1
|
||||
next_batch_index = self._get_batch_index_tuple(batch_process)
|
||||
if next_batch_index is None:
|
||||
# shouldn't happen; satisfy types
|
||||
return
|
||||
# remember to increment the batch index
|
||||
batch_process.current_batch_index += 1
|
||||
self.__batch_process_storage.save(batch_process)
|
||||
ges = self._create_graph_execution_state(batch_process=batch_process, batch_indices=next_batch_index)
|
||||
next_session = self.__batch_process_storage.create_session(
|
||||
BatchSession(
|
||||
batch_id=batch_id,
|
||||
session_id=str(uuid4()),
|
||||
state="uninitialized",
|
||||
batch_index=batch_process.current_batch_index,
|
||||
)
|
||||
)
|
||||
ges.id = next_session.session_id
|
||||
self.__invoker.services.graph_execution_manager.set(ges)
|
||||
self.__batch_process_storage.update_session_state(
|
||||
batch_id=next_session.batch_id,
|
||||
session_id=next_session.session_id,
|
||||
changes=BatchSessionChanges(state="in_progress"),
|
||||
)
|
||||
self.__invoker.services.events.emit_batch_session_created(next_session.batch_id, next_session.session_id)
|
||||
self.__invoker.invoke(ges, invoke_all=True)
|
||||
|
||||
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
|
||||
batch_process = BatchProcess(
|
||||
batch=batch,
|
||||
graph=graph,
|
||||
)
|
||||
batch_process = self.__batch_process_storage.save(batch_process)
|
||||
return BatchProcessResponse(
|
||||
batch_id=batch_process.batch_id,
|
||||
session_ids=[],
|
||||
)
|
||||
|
||||
def get_sessions(self, batch_id: str) -> list[BatchSession]:
|
||||
return self.__batch_process_storage.get_sessions_by_batch_id(batch_id)
|
||||
|
||||
def get_batch(self, batch_id: str) -> BatchProcess:
|
||||
return self.__batch_process_storage.get(batch_id)
|
||||
|
||||
def get_batch_processes(self) -> list[BatchProcessResponse]:
|
||||
bps = self.__batch_process_storage.get_all()
|
||||
return self._get_batch_process_responses(bps)
|
||||
|
||||
def get_incomplete_batch_processes(self) -> list[BatchProcessResponse]:
|
||||
bps = self.__batch_process_storage.get_incomplete()
|
||||
return self._get_batch_process_responses(bps)
|
||||
|
||||
def cancel_batch_process(self, batch_process_id: str) -> None:
|
||||
self.__batch_process_storage.cancel(batch_process_id)
|
||||
|
||||
def _get_batch_process_responses(self, batch_processes: list[BatchProcess]) -> list[BatchProcessResponse]:
|
||||
sessions = list()
|
||||
res: list[BatchProcessResponse] = list()
|
||||
for bp in batch_processes:
|
||||
sessions = self.__batch_process_storage.get_sessions_by_batch_id(bp.batch_id)
|
||||
res.append(
|
||||
BatchProcessResponse(
|
||||
batch_id=bp.batch_id,
|
||||
session_ids=[session.session_id for session in sessions],
|
||||
)
|
||||
)
|
||||
return res
|
||||
|
||||
def _get_batch_index_tuple(self, batch_process: BatchProcess) -> Optional[tuple[int, ...]]:
|
||||
batch_indices = list()
|
||||
for batchdata in batch_process.batch.data:
|
||||
batch_indices.append(list(range(len(batchdata[0].items))))
|
||||
try:
|
||||
return list(product(*batch_indices))[batch_process.current_batch_index]
|
||||
except IndexError:
|
||||
return None
|
@ -1,707 +0,0 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Literal, Optional, Union, cast
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr, parse_raw_as, validator
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.services.graph import Graph
|
||||
|
||||
BatchDataType = Union[StrictStr, StrictInt, StrictFloat, ImageField]
|
||||
|
||||
|
||||
class BatchData(BaseModel):
|
||||
"""
|
||||
A batch data collection.
|
||||
"""
|
||||
|
||||
node_path: str = Field(description="The node into which this batch data collection will be substituted.")
|
||||
field_name: str = Field(description="The field into which this batch data collection will be substituted.")
|
||||
items: list[BatchDataType] = Field(
|
||||
default_factory=list, description="The list of items to substitute into the node/field."
|
||||
)
|
||||
|
||||
|
||||
class Batch(BaseModel):
|
||||
"""
|
||||
A batch, consisting of a list of a list of batch data collections.
|
||||
|
||||
First, each inner list[BatchData] is zipped into a single batch data collection.
|
||||
|
||||
Then, the final batch collection is created by taking the Cartesian product of all batch data collections.
|
||||
"""
|
||||
|
||||
data: list[list[BatchData]] = Field(default_factory=list, description="The list of batch data collections.")
|
||||
runs: int = Field(default=1, description="Int stating how many times to iterate through all possible batch indices")
|
||||
|
||||
@validator("runs")
|
||||
def validate_positive_runs(cls, r: int):
|
||||
if r < 1:
|
||||
raise ValueError("runs must be a positive integer")
|
||||
return r
|
||||
|
||||
@validator("data")
|
||||
def validate_len(cls, v: list[list[BatchData]]):
|
||||
for batch_data in v:
|
||||
if any(len(batch_data[0].items) != len(i.items) for i in batch_data):
|
||||
raise ValueError("Zipped batch items must have all have same length")
|
||||
return v
|
||||
|
||||
@validator("data")
|
||||
def validate_types(cls, v: list[list[BatchData]]):
|
||||
for batch_data in v:
|
||||
for datum in batch_data:
|
||||
for item in datum.items:
|
||||
if not all(isinstance(item, type(i)) for i in datum.items):
|
||||
raise TypeError("All items in a batch must have have same type")
|
||||
return v
|
||||
|
||||
@validator("data")
|
||||
def validate_unique_field_mappings(cls, v: list[list[BatchData]]):
|
||||
paths: set[tuple[str, str]] = set()
|
||||
count: int = 0
|
||||
for batch_data in v:
|
||||
for datum in batch_data:
|
||||
paths.add((datum.node_path, datum.field_name))
|
||||
count += 1
|
||||
if len(paths) != count:
|
||||
raise ValueError("Each batch data must have unique node_id and field_name")
|
||||
return v
|
||||
|
||||
|
||||
def uuid_string():
|
||||
res = uuid.uuid4()
|
||||
return str(res)
|
||||
|
||||
|
||||
BATCH_SESSION_STATE = Literal["uninitialized", "in_progress", "completed", "error"]
|
||||
|
||||
|
||||
class BatchSession(BaseModel):
|
||||
batch_id: str = Field(defaultdescription="The Batch to which this BatchSession is attached.")
|
||||
session_id: str = Field(
|
||||
default_factory=uuid_string, description="The Session to which this BatchSession is attached."
|
||||
)
|
||||
batch_index: int = Field(description="The index of this batch session in its parent batch process")
|
||||
state: BATCH_SESSION_STATE = Field(default="uninitialized", description="The state of this BatchSession")
|
||||
|
||||
|
||||
class BatchProcess(BaseModel):
|
||||
batch_id: str = Field(default_factory=uuid_string, description="Identifier for this batch.")
|
||||
batch: Batch = Field(description="The Batch to apply to this session.")
|
||||
current_batch_index: int = Field(default=0, description="The last executed batch index")
|
||||
current_run: int = Field(default=0, description="The current run of the batch")
|
||||
canceled: bool = Field(description="Whether or not to run sessions from this batch.", default=False)
|
||||
graph: Graph = Field(description="The graph into which batch data will be inserted before being executed.")
|
||||
|
||||
|
||||
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
|
||||
state: BATCH_SESSION_STATE = Field(description="The state of this BatchSession")
|
||||
|
||||
|
||||
class BatchProcessNotFoundException(Exception):
|
||||
"""Raised when an Batch Process record is not found."""
|
||||
|
||||
def __init__(self, message="BatchProcess record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchProcessSaveException(Exception):
|
||||
"""Raised when an Batch Process record cannot be saved."""
|
||||
|
||||
def __init__(self, message="BatchProcess record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchProcessDeleteException(Exception):
|
||||
"""Raised when an Batch Process record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="BatchProcess record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchSessionNotFoundException(Exception):
|
||||
"""Raised when an Batch Session record is not found."""
|
||||
|
||||
def __init__(self, message="BatchSession record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchSessionSaveException(Exception):
|
||||
"""Raised when an Batch Session record cannot be saved."""
|
||||
|
||||
def __init__(self, message="BatchSession record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchSessionDeleteException(Exception):
|
||||
"""Raised when an Batch Session record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="BatchSession record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BatchProcessStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the Batch Process record store."""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, batch_id: str) -> None:
|
||||
"""Deletes a BatchProcess record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
batch_process: BatchProcess,
|
||||
) -> BatchProcess:
|
||||
"""Saves a BatchProcess record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
batch_id: str,
|
||||
) -> BatchProcess:
|
||||
"""Gets a BatchProcess record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BatchProcess]:
|
||||
"""Gets a BatchProcess record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_incomplete(
|
||||
self,
|
||||
) -> list[BatchProcess]:
|
||||
"""Gets a BatchProcess record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(
|
||||
self,
|
||||
batch_id: str,
|
||||
) -> None:
|
||||
"""'Starts' a BatchProcess record by marking its `canceled` attribute to False."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel(
|
||||
self,
|
||||
batch_id: str,
|
||||
) -> None:
|
||||
"""'Cancels' a BatchProcess record by setting its `canceled` attribute to True."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_session(
|
||||
self,
|
||||
session: BatchSession,
|
||||
) -> BatchSession:
|
||||
"""Creates a BatchSession attached to a BatchProcess."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_sessions(
|
||||
self,
|
||||
sessions: list[BatchSession],
|
||||
) -> list[BatchSession]:
|
||||
"""Creates many BatchSessions attached to a BatchProcess."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_session_by_session_id(self, session_id: str) -> BatchSession:
|
||||
"""Gets a BatchSession by session_id"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
|
||||
"""Gets all BatchSession's for a given BatchProcess id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
|
||||
"""Gets all BatchSession's for a given list of session ids."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_next_session(self, batch_id: str) -> BatchSession:
|
||||
"""Gets the next BatchSession with state `uninitialized`, for a given BatchProcess id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_session_state(
|
||||
self,
|
||||
batch_id: str,
|
||||
session_id: str,
|
||||
changes: BatchSessionChanges,
|
||||
) -> BatchSession:
|
||||
"""Updates the state of a BatchSession record."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||
super().__init__()
|
||||
self._conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the `batch_process` table and `batch_session` junction table."""
|
||||
|
||||
# Create the `batch_process` table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS batch_process (
|
||||
batch_id TEXT NOT NULL PRIMARY KEY,
|
||||
batch TEXT NOT NULL,
|
||||
graph TEXT NOT NULL,
|
||||
current_batch_index NUMBER NOT NULL,
|
||||
current_run NUMBER NOT NULL,
|
||||
canceled BOOLEAN NOT NULL DEFAULT(0),
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
|
||||
AFTER UPDATE
|
||||
ON batch_process FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE batch_process SET updated_at = current_timestamp
|
||||
WHERE batch_id = old.batch_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
# Create the `batch_session` junction table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS batch_session (
|
||||
batch_id TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
state TEXT NOT NULL,
|
||||
batch_index NUMBER NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
-- enforce one-to-many relationship between batch_process and batch_session using PK
|
||||
-- (we can extend this to many-to-many later)
|
||||
PRIMARY KEY (batch_id,session_id),
|
||||
FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for batch id
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for batch id, sorted by created_at
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
|
||||
AFTER UPDATE
|
||||
ON batch_session FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE batch_id = old.batch_id AND session_id = old.session_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def delete(self, batch_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM batch_process
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessDeleteException from e
|
||||
except Exception as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def save(
|
||||
self,
|
||||
batch_process: BatchProcess,
|
||||
) -> BatchProcess:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR REPLACE INTO batch_process (batch_id, batch, graph, current_batch_index, current_run)
|
||||
VALUES (?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
batch_process.batch_id,
|
||||
batch_process.batch.json(),
|
||||
batch_process.graph.json(),
|
||||
batch_process.current_batch_index,
|
||||
batch_process.current_run,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(batch_process.batch_id)
|
||||
|
||||
def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
|
||||
"""Deserializes a batch session."""
|
||||
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
batch_id = session_dict.get("batch_id", "unknown")
|
||||
batch_raw = session_dict.get("batch", "unknown")
|
||||
graph_raw = session_dict.get("graph", "unknown")
|
||||
current_batch_index = session_dict.get("current_batch_index", 0)
|
||||
current_run = session_dict.get("current_run", 0)
|
||||
canceled = session_dict.get("canceled", 0)
|
||||
return BatchProcess(
|
||||
batch_id=batch_id,
|
||||
batch=parse_raw_as(Batch, batch_raw),
|
||||
graph=parse_raw_as(Graph, graph_raw),
|
||||
current_batch_index=current_batch_index,
|
||||
current_run=current_run,
|
||||
canceled=canceled == 1,
|
||||
)
|
||||
|
||||
def get(
|
||||
self,
|
||||
batch_id: str,
|
||||
) -> BatchProcess:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_process
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BatchProcessNotFoundException
|
||||
return self._deserialize_batch_process(dict(result))
|
||||
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BatchProcess]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_process
|
||||
"""
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
return list()
|
||||
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
|
||||
|
||||
def get_incomplete(
|
||||
self,
|
||||
) -> list[BatchProcess]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT bp.*
|
||||
FROM batch_process bp
|
||||
WHERE bp.batch_id IN
|
||||
(
|
||||
SELECT batch_id
|
||||
FROM batch_session bs
|
||||
WHERE state IN ('uninitialized', 'in_progress')
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchProcessNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
return list()
|
||||
return list(map(lambda r: self._deserialize_batch_process(dict(r)), result))
|
||||
|
||||
def start(
|
||||
self,
|
||||
batch_id: str,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE batch_process
|
||||
SET canceled = 0
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def cancel(
|
||||
self,
|
||||
batch_id: str,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE batch_process
|
||||
SET canceled = 1
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
session: BatchSession,
|
||||
) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
(session.batch_id, session.session_id, session.state, session.batch_index),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_session_by_session_id(session.session_id)
|
||||
|
||||
def create_sessions(
|
||||
self,
|
||||
sessions: list[BatchSession],
|
||||
) -> list[BatchSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
session_data = [(session.batch_id, session.session_id, session.state) for session in sessions]
|
||||
self._cursor.executemany(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state, batch_index)
|
||||
VALUES (?, ?, ?, ?);
|
||||
""",
|
||||
session_data,
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_sessions_by_session_ids([session.session_id for session in sessions])
|
||||
|
||||
def get_session_by_session_id(self, session_id: str) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_session
|
||||
WHERE session_id= ?;
|
||||
""",
|
||||
(session_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BatchSessionNotFoundException
|
||||
return self._deserialize_batch_session(dict(result))
|
||||
|
||||
def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
|
||||
"""Deserializes a batch session."""
|
||||
|
||||
return BatchSession.parse_obj(session_dict)
|
||||
|
||||
def get_next_session(self, batch_id: str) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_session
|
||||
WHERE batch_id = ? AND state = 'uninitialized';
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BatchSessionNotFoundException
|
||||
session = self._deserialize_batch_session(dict(result))
|
||||
return session
|
||||
|
||||
def get_sessions_by_batch_id(self, batch_id: str) -> List[BatchSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM batch_session
|
||||
WHERE batch_id = ?;
|
||||
""",
|
||||
(batch_id,),
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BatchSessionNotFoundException
|
||||
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||
return sessions
|
||||
|
||||
def get_sessions_by_session_ids(self, session_ids: list[str]) -> List[BatchSession]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
placeholders = ",".join("?" * len(session_ids))
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT * FROM batch_session
|
||||
WHERE session_id
|
||||
IN ({placeholders})
|
||||
""",
|
||||
tuple(session_ids),
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BatchSessionNotFoundException
|
||||
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||
return sessions
|
||||
|
||||
def update_session_state(
|
||||
self,
|
||||
batch_id: str,
|
||||
session_id: str,
|
||||
changes: BatchSessionChanges,
|
||||
) -> BatchSession:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Change the state of a batch session
|
||||
if changes.state is not None:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE batch_session
|
||||
SET state = ?
|
||||
WHERE batch_id = ? AND session_id = ?;
|
||||
""",
|
||||
(changes.state, batch_id, session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BatchSessionSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_session_by_session_id(session_id)
|
@ -56,13 +56,15 @@ class BoardImageRecordStorageBase(ABC):
|
||||
|
||||
|
||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._conn = conn
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
|
@ -89,13 +89,15 @@ class BoardRecordStorageBase(ABC):
|
||||
|
||||
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._conn = conn
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
|
@ -13,7 +13,6 @@ from invokeai.app.services.model_manager_service import (
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = "session_event"
|
||||
batch_event: str = "batch_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
|
||||
@ -21,21 +20,12 @@ class EventServiceBase:
|
||||
pass
|
||||
|
||||
def __emit_session_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Session events are emitted to a room with the session_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.session_event,
|
||||
payload=dict(event=event_name, data=payload),
|
||||
)
|
||||
|
||||
def __emit_batch_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Batch events are emitted to a room with the batch_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.batch_event,
|
||||
payload=dict(event=event_name, data=payload),
|
||||
)
|
||||
|
||||
# Define events here for every event in the system.
|
||||
# This will make them easier to integrate until we find a schema generator.
|
||||
def emit_generator_progress(
|
||||
@ -197,14 +187,3 @@ class EventServiceBase:
|
||||
error=error,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_batch_session_created(
|
||||
self,
|
||||
batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
) -> None:
|
||||
"""Emitted when a batch session is created"""
|
||||
self.__emit_batch_event(
|
||||
event_name="batch_session_created",
|
||||
payload=dict(batch_id=batch_id, graph_execution_state_id=graph_execution_state_id),
|
||||
)
|
||||
|
@ -112,10 +112,6 @@ def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
||||
if to_type in get_args(from_type):
|
||||
return True
|
||||
|
||||
# allow int -> float, pydantic will cast for us
|
||||
if from_type is int and to_type is float:
|
||||
return True
|
||||
|
||||
# if not issubclass(from_type, to_type):
|
||||
if not is_union_subtype(from_type, to_type):
|
||||
return False
|
||||
@ -182,7 +178,7 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
@invocation("iterate", version="1.0.0")
|
||||
@invocation("iterate")
|
||||
class IterateInvocation(BaseInvocation):
|
||||
"""Iterates over a list of items"""
|
||||
|
||||
@ -203,7 +199,7 @@ class CollectInvocationOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@invocation("collect", version="1.0.0")
|
||||
@invocation("collect")
|
||||
class CollectInvocation(BaseInvocation):
|
||||
"""Collects values into a collection"""
|
||||
|
||||
|
@ -152,13 +152,15 @@ class ImageRecordStorageBase(ABC):
|
||||
|
||||
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection) -> None:
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._conn = conn
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
|
@ -4,7 +4,6 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
from invokeai.app.services.batch_manager import BatchManagerBase
|
||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||
from invokeai.app.services.boards import BoardServiceABC
|
||||
from invokeai.app.services.images import ImageServiceABC
|
||||
@ -23,7 +22,6 @@ class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
|
||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||
batch_manager: "BatchManagerBase"
|
||||
board_images: "BoardImagesServiceABC"
|
||||
boards: "BoardServiceABC"
|
||||
configuration: "InvokeAIAppConfig"
|
||||
@ -40,7 +38,6 @@ class InvocationServices:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_manager: "BatchManagerBase",
|
||||
board_images: "BoardImagesServiceABC",
|
||||
boards: "BoardServiceABC",
|
||||
configuration: "InvokeAIAppConfig",
|
||||
@ -55,7 +52,6 @@ class InvocationServices:
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
):
|
||||
self.batch_manager = batch_manager
|
||||
self.board_images = board_images
|
||||
self.boards = boards
|
||||
self.boards = boards
|
||||
|
@ -12,19 +12,23 @@ sqlite_memory = ":memory:"
|
||||
|
||||
|
||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
_filename: str
|
||||
_table_name: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_id_field: str
|
||||
_lock: Lock
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection, table_name: str, id_field: str = "id"):
|
||||
def __init__(self, filename: str, table_name: str, id_field: str = "id"):
|
||||
super().__init__()
|
||||
|
||||
self._filename = filename
|
||||
self._table_name = table_name
|
||||
self._id_field = id_field # TODO: validate that T has this field
|
||||
self._lock = Lock()
|
||||
self._conn = conn
|
||||
self._conn = sqlite3.connect(
|
||||
self._filename, check_same_thread=False
|
||||
) # TODO: figure out a better threading solution
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
self._create_table()
|
||||
@ -45,7 +49,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
item_type = get_args(self.__orig_class__)[0]
|
||||
return parse_raw_as(item_type, item)
|
||||
parsed = parse_raw_as(item_type, item)
|
||||
return parsed
|
||||
|
||||
def set(self, item: T):
|
||||
try:
|
||||
|
@ -1,20 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def cv2_inpaint(image: Image.Image) -> Image.Image:
|
||||
# Prepare Image
|
||||
image_array = np.array(image.convert("RGB"))
|
||||
image_cv = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)
|
||||
|
||||
# Prepare Mask From Alpha Channel
|
||||
mask = image.split()[3].convert("RGB")
|
||||
mask_array = np.array(mask)
|
||||
mask_cv = cv2.cvtColor(mask_array, cv2.COLOR_BGR2GRAY)
|
||||
mask_inv = cv2.bitwise_not(mask_cv)
|
||||
|
||||
# Inpaint Image
|
||||
inpainted_result = cv2.inpaint(image_cv, mask_inv, 3, cv2.INPAINT_TELEA)
|
||||
inpainted_image = Image.fromarray(cv2.cvtColor(inpainted_result, cv2.COLOR_BGR2RGB))
|
||||
return inpainted_image
|
@ -5,7 +5,6 @@ import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import get_invokeai_config
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
@ -20,7 +19,7 @@ def norm_img(np_img):
|
||||
|
||||
def load_jit_model(url_or_path, device):
|
||||
model_path = url_or_path
|
||||
logger.info(f"Loading model from: {model_path}")
|
||||
print(f"Loading model from: {model_path}")
|
||||
model = torch.jit.load(model_path, map_location="cpu").to(device)
|
||||
model.eval()
|
||||
return model
|
||||
@ -53,6 +52,5 @@ class LaMA:
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return infilled_image
|
||||
|
@ -290,20 +290,9 @@ def download_realesrgan():
|
||||
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_lama():
|
||||
logger.info("Installing lama infill model")
|
||||
download_with_progress_bar(
|
||||
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
config.models_path / "core/misc/lama/lama.pt",
|
||||
"lama infill model",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_support_models():
|
||||
download_realesrgan()
|
||||
download_lama()
|
||||
download_conversion_models()
|
||||
|
||||
|
||||
@ -507,7 +496,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
scroll_exit=True,
|
||||
)
|
||||
else:
|
||||
self.vram = DummyWidgetValue.zero
|
||||
self.vram_cache_size = DummyWidgetValue.zero
|
||||
self.nextrely += 1
|
||||
self.outdir = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
@ -605,8 +594,7 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
"vram",
|
||||
"outdir",
|
||||
]:
|
||||
if hasattr(self, attr):
|
||||
setattr(new_opts, attr, getattr(self, attr).value)
|
||||
setattr(new_opts, attr, getattr(self, attr).value)
|
||||
|
||||
for attr in self.autoimport_dirs:
|
||||
directory = Path(self.autoimport_dirs[attr].value)
|
||||
|
@ -50,7 +50,6 @@ class ModelProbe(object):
|
||||
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||
"StableDiffusionXLPipeline": ModelType.Main,
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.Vae,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
}
|
||||
|
@ -12,19 +12,12 @@ import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.controlnet import ControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.outputs import BaseOutput
|
||||
from pydantic import Field
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from .diffusion import (
|
||||
@ -34,6 +27,7 @@ from .diffusion import (
|
||||
BasicConditioningInfo,
|
||||
)
|
||||
from ..util import normalize_device, auto_detect_slice_size
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -205,145 +199,80 @@ class ConditioningData:
|
||||
return dataclasses.replace(self, scheduler_args=scheduler_args)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
|
||||
r"""
|
||||
Output class for InvokeAI's Stable Diffusion pipeline.
|
||||
|
||||
Args:
|
||||
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
|
||||
after generation completes. Optional.
|
||||
"""
|
||||
attention_map_saver: Optional[AttentionMapSaver]
|
||||
|
||||
|
||||
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Implementation note: This class started as a refactored copy of diffusers.StableDiffusionPipeline.
|
||||
Hopefully future versions of diffusers provide access to more of these functions so that we don't
|
||||
need to duplicate them here: https://github.com/huggingface/diffusers/issues/551#issuecomment-1281508384
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
class StableDiffusionGeneratorPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: Optional[StableDiffusionSafetyChecker],
|
||||
feature_extractor: Optional[CLIPFeatureExtractor],
|
||||
requires_safety_checker: bool = False,
|
||||
control_model: ControlNetModel = None,
|
||||
):
|
||||
super().__init__(
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
scheduler,
|
||||
safety_checker,
|
||||
feature_extractor,
|
||||
requires_safety_checker,
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
# FIXME: can't currently register control module
|
||||
# control_model=control_model,
|
||||
)
|
||||
self.unet = unet
|
||||
self.scheduler = scheduler
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
|
||||
self.control_model = control_model
|
||||
|
||||
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
|
||||
def set_use_memory_efficient_attention_xformers(
|
||||
self, module: torch.nn.Module, valid: bool, attention_op: Optional[Callable] = None
|
||||
) -> None:
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
||||
# gets the message
|
||||
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
||||
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
||||
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_mem_eff(child)
|
||||
|
||||
fn_recursive_set_mem_eff(module)
|
||||
|
||||
def set_attention_slice(self, module: torch.nn.Module, slice_size: Optional[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size)
|
||||
|
||||
def _adjust_memory_efficient_attention(self, model, latents: torch.Tensor):
|
||||
"""
|
||||
if xformers is available, use it, otherwise use sliced attention.
|
||||
"""
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
if config.attention_type == "xformers":
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
return
|
||||
self.set_use_memory_efficient_attention_xformers(model, True)
|
||||
|
||||
elif config.attention_type == "sliced":
|
||||
slice_size = config.attention_slice_size
|
||||
if slice_size == "auto":
|
||||
slice_size = auto_detect_slice_size(latents)
|
||||
elif slice_size == "balanced":
|
||||
|
||||
if slice_size == "balanced":
|
||||
slice_size = "auto"
|
||||
self.enable_attention_slicing(slice_size=slice_size)
|
||||
return
|
||||
self.set_attention_slice(model, slice_size=slice_size)
|
||||
|
||||
elif config.attention_type == "normal":
|
||||
self.disable_attention_slicing()
|
||||
return
|
||||
self.set_attention_slice(model, slice_size=None)
|
||||
|
||||
elif config.attention_type == "torch-sdp":
|
||||
raise Exception("torch-sdp attention slicing not yet implemented")
|
||||
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
raise Exception("torch-sdp requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
model.set_attn_processor(AttnProcessor2_0())
|
||||
|
||||
# the remainder if this code is called when attention_type=='auto'
|
||||
if self.unet.device.type == "cuda":
|
||||
if is_xformers_available() and not config.disable_xformers:
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
return
|
||||
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
# diffusers enable sdp automatically
|
||||
return
|
||||
else: # auto
|
||||
if model.device.type == "cuda":
|
||||
if is_xformers_available() and not config.disable_xformers:
|
||||
self.set_use_memory_efficient_attention_xformers(model, True)
|
||||
|
||||
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
|
||||
mem_free = psutil.virtual_memory().free
|
||||
elif self.unet.device.type == "cuda":
|
||||
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
|
||||
else:
|
||||
raise ValueError(f"unrecognized device {self.unet.device}")
|
||||
# input tensor of [1, 4, h/8, w/8]
|
||||
# output tensor of [16, (h/8 * w/8), (h/8 * w/8)]
|
||||
bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
|
||||
max_size_required_for_baddbmm = (
|
||||
16
|
||||
* latents.size(dim=2)
|
||||
* latents.size(dim=3)
|
||||
* latents.size(dim=2)
|
||||
* latents.size(dim=3)
|
||||
* bytes_per_element_needed_for_baddbmm_duplication
|
||||
)
|
||||
if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code
|
||||
self.enable_attention_slicing(slice_size="max")
|
||||
elif torch.backends.mps.is_available():
|
||||
# diffusers recommends always enabling for mps
|
||||
self.enable_attention_slicing(slice_size="max")
|
||||
else:
|
||||
self.disable_attention_slicing()
|
||||
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
model.set_attn_processor(AttnProcessor2_0())
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
|
||||
raise Exception("Should not be called")
|
||||
else:
|
||||
if model.device.type == "cpu" or model.device.type == "mps":
|
||||
mem_free = psutil.virtual_memory().free
|
||||
elif model.device.type == "cuda":
|
||||
mem_free, _ = torch.cuda.mem_get_info(normalize_device(model.device))
|
||||
else:
|
||||
raise ValueError(f"unrecognized device {model.device}")
|
||||
|
||||
slice_size = auto_detect_slice_size(latents)
|
||||
if slice_size == "balanced":
|
||||
slice_size = "auto"
|
||||
self.set_attention_slice(model, slice_size=slice_size)
|
||||
|
||||
def latents_from_embeddings(
|
||||
self,
|
||||
@ -429,7 +358,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
control_data: List[ControlNetData] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
):
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
self._adjust_memory_efficient_attention(self.unet, latents)
|
||||
if control_data is not None:
|
||||
for control in control_data:
|
||||
self._adjust_memory_efficient_attention(control.model, latents)
|
||||
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
@ -457,7 +390,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
)
|
||||
|
||||
# print("timesteps:", timesteps)
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
batched_t = t.expand(batch_size)
|
||||
step_output = self.step(
|
||||
batched_t,
|
||||
|
@ -265,7 +265,7 @@ class InvokeAICrossAttentionMixin:
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
slice_size = math.floor(2 ** 30 / (q.shape[0] * q.shape[1]))
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v):
|
||||
|
@ -215,10 +215,7 @@ class InvokeAIDiffuserComponent:
|
||||
dim=0,
|
||||
),
|
||||
}
|
||||
(
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = self._concat_conditionings_for_batch(
|
||||
(encoder_hidden_states, encoder_attention_mask,) = self._concat_conditionings_for_batch(
|
||||
conditioning_data.unconditioned_embeddings.embeds,
|
||||
conditioning_data.text_embeddings.embeds,
|
||||
)
|
||||
@ -280,10 +277,7 @@ class InvokeAIDiffuserComponent:
|
||||
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
||||
|
||||
if wants_cross_attention_control:
|
||||
(
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_cross_attention_controlled_conditioning(
|
||||
(unconditioned_next_x, conditioned_next_x,) = self._apply_cross_attention_controlled_conditioning(
|
||||
sample,
|
||||
timestep,
|
||||
conditioning_data,
|
||||
@ -291,10 +285,7 @@ class InvokeAIDiffuserComponent:
|
||||
**kwargs,
|
||||
)
|
||||
elif self.sequential_guidance:
|
||||
(
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_standard_conditioning_sequentially(
|
||||
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning_sequentially(
|
||||
sample,
|
||||
timestep,
|
||||
conditioning_data,
|
||||
@ -302,10 +293,7 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
|
||||
else:
|
||||
(
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_standard_conditioning(
|
||||
(unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning(
|
||||
sample,
|
||||
timestep,
|
||||
conditioning_data,
|
||||
|
@ -0,0 +1,6 @@
|
||||
from ldm.modules.image_degradation.bsrgan import ( # noqa: F401
|
||||
degradation_bsrgan_variant as degradation_fn_bsr,
|
||||
)
|
||||
from ldm.modules.image_degradation.bsrgan_light import ( # noqa: F401
|
||||
degradation_bsrgan_variant as degradation_fn_bsr_light,
|
||||
)
|
794
invokeai/backend/stable_diffusion/image_degradation/bsrgan.py
Normal file
794
invokeai/backend/stable_diffusion/image_degradation/bsrgan.py
Normal file
@ -0,0 +1,794 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# Super-Resolution
|
||||
# --------------------------------------------
|
||||
#
|
||||
# Kai Zhang (cskaizhang@gmail.com)
|
||||
# https://github.com/cszn
|
||||
# From 2019/03--2021/08
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import albumentations
|
||||
import cv2
|
||||
import ldm.modules.image_degradation.utils_image as util
|
||||
import numpy as np
|
||||
import scipy
|
||||
import scipy.stats as ss
|
||||
import torch
|
||||
from scipy import ndimage
|
||||
from scipy.interpolate import interp2d
|
||||
from scipy.linalg import orth
|
||||
|
||||
|
||||
def modcrop_np(img, sf):
|
||||
"""
|
||||
Args:
|
||||
img: numpy image, WxH or WxHxC
|
||||
sf: scale factor
|
||||
Return:
|
||||
cropped image
|
||||
"""
|
||||
w, h = img.shape[:2]
|
||||
im = np.copy(img)
|
||||
return im[: w - w % sf, : h - h % sf, ...]
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# anisotropic Gaussian kernels
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def analytic_kernel(k):
|
||||
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
|
||||
k_size = k.shape[0]
|
||||
# Calculate the big kernels size
|
||||
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
|
||||
# Loop over the small kernel to fill the big one
|
||||
for r in range(k_size):
|
||||
for c in range(k_size):
|
||||
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k
|
||||
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
|
||||
crop = k_size // 2
|
||||
cropped_big_k = big_k[crop:-crop, crop:-crop]
|
||||
# Normalize to 1
|
||||
return cropped_big_k / cropped_big_k.sum()
|
||||
|
||||
|
||||
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
|
||||
"""generate an anisotropic Gaussian kernel
|
||||
Args:
|
||||
ksize : e.g., 15, kernel size
|
||||
theta : [0, pi], rotation angle range
|
||||
l1 : [0.1,50], scaling of eigenvalues
|
||||
l2 : [0.1,l1], scaling of eigenvalues
|
||||
If l1 = l2, will get an isotropic Gaussian kernel.
|
||||
Returns:
|
||||
k : kernel
|
||||
"""
|
||||
|
||||
v = np.dot(
|
||||
np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]),
|
||||
np.array([1.0, 0.0]),
|
||||
)
|
||||
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
|
||||
D = np.array([[l1, 0], [0, l2]])
|
||||
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
|
||||
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
||||
|
||||
return k
|
||||
|
||||
|
||||
def gm_blur_kernel(mean, cov, size=15):
|
||||
center = size / 2.0 + 0.5
|
||||
k = np.zeros([size, size])
|
||||
for y in range(size):
|
||||
for x in range(size):
|
||||
cy = y - center + 1
|
||||
cx = x - center + 1
|
||||
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
|
||||
|
||||
k = k / np.sum(k)
|
||||
return k
|
||||
|
||||
|
||||
def shift_pixel(x, sf, upper_left=True):
|
||||
"""shift pixel for super-resolution with different scale factors
|
||||
Args:
|
||||
x: WxHxC or WxH
|
||||
sf: scale factor
|
||||
upper_left: shift direction
|
||||
"""
|
||||
h, w = x.shape[:2]
|
||||
shift = (sf - 1) * 0.5
|
||||
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
|
||||
if upper_left:
|
||||
x1 = xv + shift
|
||||
y1 = yv + shift
|
||||
else:
|
||||
x1 = xv - shift
|
||||
y1 = yv - shift
|
||||
|
||||
x1 = np.clip(x1, 0, w - 1)
|
||||
y1 = np.clip(y1, 0, h - 1)
|
||||
|
||||
if x.ndim == 2:
|
||||
x = interp2d(xv, yv, x)(x1, y1)
|
||||
if x.ndim == 3:
|
||||
for i in range(x.shape[-1]):
|
||||
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def blur(x, k):
|
||||
"""
|
||||
x: image, NxcxHxW
|
||||
k: kernel, Nx1xhxw
|
||||
"""
|
||||
n, c = x.shape[:2]
|
||||
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
|
||||
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate")
|
||||
k = k.repeat(1, c, 1, 1)
|
||||
k = k.view(-1, 1, k.shape[2], k.shape[3])
|
||||
x = x.view(1, -1, x.shape[2], x.shape[3])
|
||||
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
|
||||
x = x.view(n, c, x.shape[2], x.shape[3])
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def gen_kernel(
|
||||
k_size=np.array([15, 15]),
|
||||
scale_factor=np.array([4, 4]),
|
||||
min_var=0.6,
|
||||
max_var=10.0,
|
||||
noise_level=0,
|
||||
):
|
||||
""" "
|
||||
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
|
||||
# Kai Zhang
|
||||
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
|
||||
# max_var = 2.5 * sf
|
||||
"""
|
||||
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
|
||||
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
|
||||
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
|
||||
theta = np.random.rand() * np.pi # random theta
|
||||
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
|
||||
|
||||
# Set COV matrix using Lambdas and Theta
|
||||
LAMBDA = np.diag([lambda_1, lambda_2])
|
||||
Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
|
||||
SIGMA = Q @ LAMBDA @ Q.T
|
||||
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
|
||||
|
||||
# Set expectation position (shifting kernel for aligned image)
|
||||
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
|
||||
MU = MU[None, None, :, None]
|
||||
|
||||
# Create meshgrid for Gaussian
|
||||
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
|
||||
Z = np.stack([X, Y], 2)[:, :, :, None]
|
||||
|
||||
# Calcualte Gaussian for every pixel of the kernel
|
||||
ZZ = Z - MU
|
||||
ZZ_t = ZZ.transpose(0, 1, 3, 2)
|
||||
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
|
||||
|
||||
# shift the kernel so it will be centered
|
||||
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
|
||||
|
||||
# Normalize the kernel and return
|
||||
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
|
||||
kernel = raw_kernel / np.sum(raw_kernel)
|
||||
return kernel
|
||||
|
||||
|
||||
def fspecial_gaussian(hsize, sigma):
|
||||
hsize = [hsize, hsize]
|
||||
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
|
||||
std = sigma
|
||||
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
|
||||
arg = -(x * x + y * y) / (2 * std * std)
|
||||
h = np.exp(arg)
|
||||
h[h < scipy.finfo(float).eps * h.max()] = 0
|
||||
sumh = h.sum()
|
||||
if sumh != 0:
|
||||
h = h / sumh
|
||||
return h
|
||||
|
||||
|
||||
def fspecial_laplacian(alpha):
|
||||
alpha = max([0, min([alpha, 1])])
|
||||
h1 = alpha / (alpha + 1)
|
||||
h2 = (1 - alpha) / (alpha + 1)
|
||||
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
|
||||
h = np.array(h)
|
||||
return h
|
||||
|
||||
|
||||
def fspecial(filter_type, *args, **kwargs):
|
||||
"""
|
||||
python code from:
|
||||
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
|
||||
"""
|
||||
if filter_type == "gaussian":
|
||||
return fspecial_gaussian(*args, **kwargs)
|
||||
if filter_type == "laplacian":
|
||||
return fspecial_laplacian(*args, **kwargs)
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# degradation models
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def bicubic_degradation(x, sf=3):
|
||||
"""
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
bicubicly downsampled LR image
|
||||
"""
|
||||
x = util.imresize_np(x, scale=1 / sf)
|
||||
return x
|
||||
|
||||
|
||||
def srmd_degradation(x, k, sf=3):
|
||||
"""blur + bicubic downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2018learning,
|
||||
title={Learning a single convolutional super-resolution network for multiple degradations},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={3262--3271},
|
||||
year={2018}
|
||||
}
|
||||
"""
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror'
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
return x
|
||||
|
||||
|
||||
def dpsr_degradation(x, k, sf=3):
|
||||
"""bicubic downsampling + blur
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2019deep,
|
||||
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={1671--1681},
|
||||
year={2019}
|
||||
}
|
||||
"""
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
|
||||
return x
|
||||
|
||||
|
||||
def classical_degradation(x, k, sf=3):
|
||||
"""blur + downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]/[0, 255]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
"""
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
|
||||
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
||||
st = 0
|
||||
return x[st::sf, st::sf, ...]
|
||||
|
||||
|
||||
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
|
||||
"""USM sharpening. borrowed from real-ESRGAN
|
||||
Input image: I; Blurry image: B.
|
||||
1. K = I + weight * (I - B)
|
||||
2. Mask = 1 if abs(I - B) > threshold, else: 0
|
||||
3. Blur mask:
|
||||
4. Out = Mask * K + (1 - Mask) * I
|
||||
Args:
|
||||
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
|
||||
weight (float): Sharp weight. Default: 1.
|
||||
radius (float): Kernel size of Gaussian blur. Default: 50.
|
||||
threshold (int):
|
||||
"""
|
||||
if radius % 2 == 0:
|
||||
radius += 1
|
||||
blur = cv2.GaussianBlur(img, (radius, radius), 0)
|
||||
residual = img - blur
|
||||
mask = np.abs(residual) * 255 > threshold
|
||||
mask = mask.astype("float32")
|
||||
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
|
||||
|
||||
K = img + weight * residual
|
||||
K = np.clip(K, 0, 1)
|
||||
return soft_mask * K + (1 - soft_mask) * img
|
||||
|
||||
|
||||
def add_blur(img, sf=4):
|
||||
wd2 = 4.0 + sf
|
||||
wd = 2.0 + 0.2 * sf
|
||||
if random.random() < 0.5:
|
||||
l1 = wd2 * random.random()
|
||||
l2 = wd2 * random.random()
|
||||
k = anisotropic_Gaussian(
|
||||
ksize=2 * random.randint(2, 11) + 3,
|
||||
theta=random.random() * np.pi,
|
||||
l1=l1,
|
||||
l2=l2,
|
||||
)
|
||||
else:
|
||||
k = fspecial("gaussian", 2 * random.randint(2, 11) + 3, wd * random.random())
|
||||
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode="mirror")
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def add_resize(img, sf=4):
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.8: # up
|
||||
sf1 = random.uniform(1, 2)
|
||||
elif rnum < 0.7: # down
|
||||
sf1 = random.uniform(0.5 / sf, 1)
|
||||
else:
|
||||
sf1 = 1.0
|
||||
img = cv2.resize(
|
||||
img,
|
||||
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]),
|
||||
)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
# noise_level = random.randint(noise_level1, noise_level2)
|
||||
# rnum = np.random.rand()
|
||||
# if rnum > 0.6: # add color Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
# elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
# else: # add noise
|
||||
# L = noise_level2 / 255.
|
||||
# D = np.diag(np.random.rand(3))
|
||||
# U = orth(np.random.rand(3, 3))
|
||||
# conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
# img = np.clip(img, 0.0, 1.0)
|
||||
# return img
|
||||
|
||||
|
||||
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.6: # add color Gaussian noise
|
||||
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else: # add noise
|
||||
L = noise_level2 / 255.0
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
rnum = random.random()
|
||||
if rnum > 0.6:
|
||||
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4:
|
||||
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else:
|
||||
L = noise_level2 / 255.0
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_Poisson_noise(img):
|
||||
img = np.clip((img * 255.0).round(), 0, 255) / 255.0
|
||||
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
|
||||
if random.random() < 0.5:
|
||||
img = np.random.poisson(img * vals).astype(np.float32) / vals
|
||||
else:
|
||||
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
|
||||
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
|
||||
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
||||
img += noise_gray[:, :, np.newaxis]
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_JPEG_noise(img):
|
||||
quality_factor = random.randint(30, 95)
|
||||
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
|
||||
result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
|
||||
img = cv2.imdecode(encimg, 1)
|
||||
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
def random_crop(lq, hq, sf=4, lq_patchsize=64):
|
||||
h, w = lq.shape[:2]
|
||||
rnd_h = random.randint(0, h - lq_patchsize)
|
||||
rnd_w = random.randint(0, w - lq_patchsize)
|
||||
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
|
||||
|
||||
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
|
||||
hq = hq[
|
||||
rnd_h_H : rnd_h_H + lq_patchsize * sf,
|
||||
rnd_w_H : rnd_w_H + lq_patchsize * sf,
|
||||
:,
|
||||
]
|
||||
return lq, hq
|
||||
|
||||
|
||||
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
|
||||
"""
|
||||
This is the degradation model of BSRGAN from the paper
|
||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||
----------
|
||||
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
|
||||
sf: scale factor
|
||||
isp_model: camera ISP model
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
||||
sf_ori = sf
|
||||
|
||||
h1, w1 = img.shape[:2]
|
||||
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
|
||||
h, w = img.shape[:2]
|
||||
|
||||
if h < lq_patchsize * sf or w < lq_patchsize * sf:
|
||||
raise ValueError(f"img size ({h1}X{w1}) is too small!")
|
||||
|
||||
hq = img.copy()
|
||||
|
||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
||||
if np.random.rand() < 0.5:
|
||||
img = cv2.resize(
|
||||
img,
|
||||
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]),
|
||||
)
|
||||
else:
|
||||
img = util.imresize_np(img, 1 / 2, True)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
sf = 2
|
||||
|
||||
shuffle_order = random.sample(range(7), 7)
|
||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
||||
if idx1 > idx2: # keep downsample3 last
|
||||
shuffle_order[idx1], shuffle_order[idx2] = (
|
||||
shuffle_order[idx2],
|
||||
shuffle_order[idx1],
|
||||
)
|
||||
|
||||
for i in shuffle_order:
|
||||
if i == 0:
|
||||
img = add_blur(img, sf=sf)
|
||||
|
||||
elif i == 1:
|
||||
img = add_blur(img, sf=sf)
|
||||
|
||||
elif i == 2:
|
||||
a, b = img.shape[1], img.shape[0]
|
||||
# downsample2
|
||||
if random.random() < 0.75:
|
||||
sf1 = random.uniform(1, 2 * sf)
|
||||
img = cv2.resize(
|
||||
img,
|
||||
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]),
|
||||
)
|
||||
else:
|
||||
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror")
|
||||
img = img[0::sf, 0::sf, ...] # nearest downsampling
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
elif i == 3:
|
||||
# downsample3
|
||||
img = cv2.resize(
|
||||
img,
|
||||
(int(1 / sf * a), int(1 / sf * b)),
|
||||
interpolation=random.choice([1, 2, 3]),
|
||||
)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
elif i == 4:
|
||||
# add Gaussian noise
|
||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
||||
|
||||
elif i == 5:
|
||||
# add JPEG noise
|
||||
if random.random() < jpeg_prob:
|
||||
img = add_JPEG_noise(img)
|
||||
|
||||
elif i == 6:
|
||||
# add processed camera sensor noise
|
||||
if random.random() < isp_prob and isp_model is not None:
|
||||
with torch.no_grad():
|
||||
img, hq = isp_model.forward(img.copy(), hq)
|
||||
|
||||
# add final JPEG compression noise
|
||||
img = add_JPEG_noise(img)
|
||||
|
||||
# random crop
|
||||
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
|
||||
|
||||
return img, hq
|
||||
|
||||
|
||||
# todo no isp_model?
|
||||
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
||||
"""
|
||||
This is the degradation model of BSRGAN from the paper
|
||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||
----------
|
||||
sf: scale factor
|
||||
isp_model: camera ISP model
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
image = util.uint2single(image)
|
||||
jpeg_prob, scale2_prob = 0.9, 0.25
|
||||
# isp_prob = 0.25 # uncomment with `if i== 6` block below
|
||||
# sf_ori = sf # uncomment with `if i== 6` block below
|
||||
|
||||
h1, w1 = image.shape[:2]
|
||||
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
|
||||
h, w = image.shape[:2]
|
||||
|
||||
# hq = image.copy() # uncomment with `if i== 6` block below
|
||||
|
||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
||||
if np.random.rand() < 0.5:
|
||||
image = cv2.resize(
|
||||
image,
|
||||
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]),
|
||||
)
|
||||
else:
|
||||
image = util.imresize_np(image, 1 / 2, True)
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
sf = 2
|
||||
|
||||
shuffle_order = random.sample(range(7), 7)
|
||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
||||
if idx1 > idx2: # keep downsample3 last
|
||||
shuffle_order[idx1], shuffle_order[idx2] = (
|
||||
shuffle_order[idx2],
|
||||
shuffle_order[idx1],
|
||||
)
|
||||
|
||||
for i in shuffle_order:
|
||||
if i == 0:
|
||||
image = add_blur(image, sf=sf)
|
||||
|
||||
elif i == 1:
|
||||
image = add_blur(image, sf=sf)
|
||||
|
||||
elif i == 2:
|
||||
a, b = image.shape[1], image.shape[0]
|
||||
# downsample2
|
||||
if random.random() < 0.75:
|
||||
sf1 = random.uniform(1, 2 * sf)
|
||||
image = cv2.resize(
|
||||
image,
|
||||
(
|
||||
int(1 / sf1 * image.shape[1]),
|
||||
int(1 / sf1 * image.shape[0]),
|
||||
),
|
||||
interpolation=random.choice([1, 2, 3]),
|
||||
)
|
||||
else:
|
||||
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror")
|
||||
image = image[0::sf, 0::sf, ...] # nearest downsampling
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
|
||||
elif i == 3:
|
||||
# downsample3
|
||||
image = cv2.resize(
|
||||
image,
|
||||
(int(1 / sf * a), int(1 / sf * b)),
|
||||
interpolation=random.choice([1, 2, 3]),
|
||||
)
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
|
||||
elif i == 4:
|
||||
# add Gaussian noise
|
||||
image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
|
||||
|
||||
elif i == 5:
|
||||
# add JPEG noise
|
||||
if random.random() < jpeg_prob:
|
||||
image = add_JPEG_noise(image)
|
||||
|
||||
# elif i == 6:
|
||||
# # add processed camera sensor noise
|
||||
# if random.random() < isp_prob and isp_model is not None:
|
||||
# with torch.no_grad():
|
||||
# img, hq = isp_model.forward(img.copy(), hq)
|
||||
|
||||
# add final JPEG compression noise
|
||||
image = add_JPEG_noise(image)
|
||||
image = util.single2uint(image)
|
||||
example = {"image": image}
|
||||
return example
|
||||
|
||||
|
||||
# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
|
||||
def degradation_bsrgan_plus(
|
||||
img,
|
||||
sf=4,
|
||||
shuffle_prob=0.5,
|
||||
use_sharp=True,
|
||||
lq_patchsize=64,
|
||||
isp_model=None,
|
||||
):
|
||||
"""
|
||||
This is an extended degradation model by combining
|
||||
the degradation models of BSRGAN and Real-ESRGAN
|
||||
----------
|
||||
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
|
||||
sf: scale factor
|
||||
use_shuffle: the degradation shuffle
|
||||
use_sharp: sharpening the img
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
|
||||
h1, w1 = img.shape[:2]
|
||||
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
|
||||
h, w = img.shape[:2]
|
||||
|
||||
if h < lq_patchsize * sf or w < lq_patchsize * sf:
|
||||
raise ValueError(f"img size ({h1}X{w1}) is too small!")
|
||||
|
||||
if use_sharp:
|
||||
img = add_sharpening(img)
|
||||
hq = img.copy()
|
||||
|
||||
if random.random() < shuffle_prob:
|
||||
shuffle_order = random.sample(range(13), 13)
|
||||
else:
|
||||
shuffle_order = list(range(13))
|
||||
# local shuffle for noise, JPEG is always the last one
|
||||
shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
|
||||
shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
|
||||
|
||||
poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
|
||||
|
||||
for i in shuffle_order:
|
||||
if i == 0:
|
||||
img = add_blur(img, sf=sf)
|
||||
elif i == 1:
|
||||
img = add_resize(img, sf=sf)
|
||||
elif i == 2:
|
||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
||||
elif i == 3:
|
||||
if random.random() < poisson_prob:
|
||||
img = add_Poisson_noise(img)
|
||||
elif i == 4:
|
||||
if random.random() < speckle_prob:
|
||||
img = add_speckle_noise(img)
|
||||
elif i == 5:
|
||||
if random.random() < isp_prob and isp_model is not None:
|
||||
with torch.no_grad():
|
||||
img, hq = isp_model.forward(img.copy(), hq)
|
||||
elif i == 6:
|
||||
img = add_JPEG_noise(img)
|
||||
elif i == 7:
|
||||
img = add_blur(img, sf=sf)
|
||||
elif i == 8:
|
||||
img = add_resize(img, sf=sf)
|
||||
elif i == 9:
|
||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
|
||||
elif i == 10:
|
||||
if random.random() < poisson_prob:
|
||||
img = add_Poisson_noise(img)
|
||||
elif i == 11:
|
||||
if random.random() < speckle_prob:
|
||||
img = add_speckle_noise(img)
|
||||
elif i == 12:
|
||||
if random.random() < isp_prob and isp_model is not None:
|
||||
with torch.no_grad():
|
||||
img, hq = isp_model.forward(img.copy(), hq)
|
||||
else:
|
||||
print("check the shuffle!")
|
||||
|
||||
# resize to desired size
|
||||
img = cv2.resize(
|
||||
img,
|
||||
(int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]),
|
||||
)
|
||||
|
||||
# add final JPEG compression noise
|
||||
img = add_JPEG_noise(img)
|
||||
|
||||
# random crop
|
||||
img, hq = random_crop(img, hq, sf, lq_patchsize)
|
||||
|
||||
return img, hq
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("hey")
|
||||
img = util.imread_uint("utils/test.png", 3)
|
||||
print(img)
|
||||
img = util.uint2single(img)
|
||||
print(img)
|
||||
img = img[:448, :448]
|
||||
h = img.shape[0] // 4
|
||||
print("resizing to", h)
|
||||
sf = 4
|
||||
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
|
||||
for i in range(20):
|
||||
print(i)
|
||||
img_lq = deg_fn(img)
|
||||
print(img_lq)
|
||||
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
|
||||
print(img_lq.shape)
|
||||
print("bicubic", img_lq_bicubic.shape)
|
||||
# print(img_hq.shape)
|
||||
lq_nearest = cv2.resize(
|
||||
util.single2uint(img_lq),
|
||||
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||
interpolation=0,
|
||||
)
|
||||
lq_bicubic_nearest = cv2.resize(
|
||||
util.single2uint(img_lq_bicubic),
|
||||
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||
interpolation=0,
|
||||
)
|
||||
# img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
|
||||
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest], axis=1)
|
||||
util.imsave(img_concat, str(i) + ".png")
|
@ -0,0 +1,704 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import albumentations
|
||||
import cv2
|
||||
import ldm.modules.image_degradation.utils_image as util
|
||||
import numpy as np
|
||||
import scipy
|
||||
import scipy.stats as ss
|
||||
import torch
|
||||
from scipy import ndimage
|
||||
from scipy.interpolate import interp2d
|
||||
from scipy.linalg import orth
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# Super-Resolution
|
||||
# --------------------------------------------
|
||||
#
|
||||
# Kai Zhang (cskaizhang@gmail.com)
|
||||
# https://github.com/cszn
|
||||
# From 2019/03--2021/08
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def modcrop_np(img, sf):
|
||||
"""
|
||||
Args:
|
||||
img: numpy image, WxH or WxHxC
|
||||
sf: scale factor
|
||||
Return:
|
||||
cropped image
|
||||
"""
|
||||
w, h = img.shape[:2]
|
||||
im = np.copy(img)
|
||||
return im[: w - w % sf, : h - h % sf, ...]
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# anisotropic Gaussian kernels
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def analytic_kernel(k):
|
||||
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
|
||||
k_size = k.shape[0]
|
||||
# Calculate the big kernels size
|
||||
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
|
||||
# Loop over the small kernel to fill the big one
|
||||
for r in range(k_size):
|
||||
for c in range(k_size):
|
||||
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k
|
||||
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
|
||||
crop = k_size // 2
|
||||
cropped_big_k = big_k[crop:-crop, crop:-crop]
|
||||
# Normalize to 1
|
||||
return cropped_big_k / cropped_big_k.sum()
|
||||
|
||||
|
||||
def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
|
||||
"""generate an anisotropic Gaussian kernel
|
||||
Args:
|
||||
ksize : e.g., 15, kernel size
|
||||
theta : [0, pi], rotation angle range
|
||||
l1 : [0.1,50], scaling of eigenvalues
|
||||
l2 : [0.1,l1], scaling of eigenvalues
|
||||
If l1 = l2, will get an isotropic Gaussian kernel.
|
||||
Returns:
|
||||
k : kernel
|
||||
"""
|
||||
|
||||
v = np.dot(
|
||||
np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]),
|
||||
np.array([1.0, 0.0]),
|
||||
)
|
||||
V = np.array([[v[0], v[1]], [v[1], -v[0]]])
|
||||
D = np.array([[l1, 0], [0, l2]])
|
||||
Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
|
||||
k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
|
||||
|
||||
return k
|
||||
|
||||
|
||||
def gm_blur_kernel(mean, cov, size=15):
|
||||
center = size / 2.0 + 0.5
|
||||
k = np.zeros([size, size])
|
||||
for y in range(size):
|
||||
for x in range(size):
|
||||
cy = y - center + 1
|
||||
cx = x - center + 1
|
||||
k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
|
||||
|
||||
k = k / np.sum(k)
|
||||
return k
|
||||
|
||||
|
||||
def shift_pixel(x, sf, upper_left=True):
|
||||
"""shift pixel for super-resolution with different scale factors
|
||||
Args:
|
||||
x: WxHxC or WxH
|
||||
sf: scale factor
|
||||
upper_left: shift direction
|
||||
"""
|
||||
h, w = x.shape[:2]
|
||||
shift = (sf - 1) * 0.5
|
||||
xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
|
||||
if upper_left:
|
||||
x1 = xv + shift
|
||||
y1 = yv + shift
|
||||
else:
|
||||
x1 = xv - shift
|
||||
y1 = yv - shift
|
||||
|
||||
x1 = np.clip(x1, 0, w - 1)
|
||||
y1 = np.clip(y1, 0, h - 1)
|
||||
|
||||
if x.ndim == 2:
|
||||
x = interp2d(xv, yv, x)(x1, y1)
|
||||
if x.ndim == 3:
|
||||
for i in range(x.shape[-1]):
|
||||
x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def blur(x, k):
|
||||
"""
|
||||
x: image, NxcxHxW
|
||||
k: kernel, Nx1xhxw
|
||||
"""
|
||||
n, c = x.shape[:2]
|
||||
p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
|
||||
x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate")
|
||||
k = k.repeat(1, c, 1, 1)
|
||||
k = k.view(-1, 1, k.shape[2], k.shape[3])
|
||||
x = x.view(1, -1, x.shape[2], x.shape[3])
|
||||
x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
|
||||
x = x.view(n, c, x.shape[2], x.shape[3])
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def gen_kernel(
|
||||
k_size=np.array([15, 15]),
|
||||
scale_factor=np.array([4, 4]),
|
||||
min_var=0.6,
|
||||
max_var=10.0,
|
||||
noise_level=0,
|
||||
):
|
||||
""" "
|
||||
# modified version of https://github.com/assafshocher/BlindSR_dataset_generator
|
||||
# Kai Zhang
|
||||
# min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
|
||||
# max_var = 2.5 * sf
|
||||
"""
|
||||
# Set random eigen-vals (lambdas) and angle (theta) for COV matrix
|
||||
lambda_1 = min_var + np.random.rand() * (max_var - min_var)
|
||||
lambda_2 = min_var + np.random.rand() * (max_var - min_var)
|
||||
theta = np.random.rand() * np.pi # random theta
|
||||
noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
|
||||
|
||||
# Set COV matrix using Lambdas and Theta
|
||||
LAMBDA = np.diag([lambda_1, lambda_2])
|
||||
Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
|
||||
SIGMA = Q @ LAMBDA @ Q.T
|
||||
INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
|
||||
|
||||
# Set expectation position (shifting kernel for aligned image)
|
||||
MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
|
||||
MU = MU[None, None, :, None]
|
||||
|
||||
# Create meshgrid for Gaussian
|
||||
[X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
|
||||
Z = np.stack([X, Y], 2)[:, :, :, None]
|
||||
|
||||
# Calcualte Gaussian for every pixel of the kernel
|
||||
ZZ = Z - MU
|
||||
ZZ_t = ZZ.transpose(0, 1, 3, 2)
|
||||
raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
|
||||
|
||||
# shift the kernel so it will be centered
|
||||
# raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
|
||||
|
||||
# Normalize the kernel and return
|
||||
# kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
|
||||
kernel = raw_kernel / np.sum(raw_kernel)
|
||||
return kernel
|
||||
|
||||
|
||||
def fspecial_gaussian(hsize, sigma):
|
||||
hsize = [hsize, hsize]
|
||||
siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
|
||||
std = sigma
|
||||
[x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
|
||||
arg = -(x * x + y * y) / (2 * std * std)
|
||||
h = np.exp(arg)
|
||||
h[h < scipy.finfo(float).eps * h.max()] = 0
|
||||
sumh = h.sum()
|
||||
if sumh != 0:
|
||||
h = h / sumh
|
||||
return h
|
||||
|
||||
|
||||
def fspecial_laplacian(alpha):
|
||||
alpha = max([0, min([alpha, 1])])
|
||||
h1 = alpha / (alpha + 1)
|
||||
h2 = (1 - alpha) / (alpha + 1)
|
||||
h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
|
||||
h = np.array(h)
|
||||
return h
|
||||
|
||||
|
||||
def fspecial(filter_type, *args, **kwargs):
|
||||
"""
|
||||
python code from:
|
||||
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
|
||||
"""
|
||||
if filter_type == "gaussian":
|
||||
return fspecial_gaussian(*args, **kwargs)
|
||||
if filter_type == "laplacian":
|
||||
return fspecial_laplacian(*args, **kwargs)
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# degradation models
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def bicubic_degradation(x, sf=3):
|
||||
"""
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
bicubicly downsampled LR image
|
||||
"""
|
||||
x = util.imresize_np(x, scale=1 / sf)
|
||||
return x
|
||||
|
||||
|
||||
def srmd_degradation(x, k, sf=3):
|
||||
"""blur + bicubic downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2018learning,
|
||||
title={Learning a single convolutional super-resolution network for multiple degradations},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={3262--3271},
|
||||
year={2018}
|
||||
}
|
||||
"""
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror'
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
return x
|
||||
|
||||
|
||||
def dpsr_degradation(x, k, sf=3):
|
||||
"""bicubic downsampling + blur
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
Reference:
|
||||
@inproceedings{zhang2019deep,
|
||||
title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
|
||||
author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
|
||||
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={1671--1681},
|
||||
year={2019}
|
||||
}
|
||||
"""
|
||||
x = bicubic_degradation(x, sf=sf)
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
|
||||
return x
|
||||
|
||||
|
||||
def classical_degradation(x, k, sf=3):
|
||||
"""blur + downsampling
|
||||
Args:
|
||||
x: HxWxC image, [0, 1]/[0, 255]
|
||||
k: hxw, double
|
||||
sf: down-scale factor
|
||||
Return:
|
||||
downsampled LR image
|
||||
"""
|
||||
x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap")
|
||||
# x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
|
||||
st = 0
|
||||
return x[st::sf, st::sf, ...]
|
||||
|
||||
|
||||
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
|
||||
"""USM sharpening. borrowed from real-ESRGAN
|
||||
Input image: I; Blurry image: B.
|
||||
1. K = I + weight * (I - B)
|
||||
2. Mask = 1 if abs(I - B) > threshold, else: 0
|
||||
3. Blur mask:
|
||||
4. Out = Mask * K + (1 - Mask) * I
|
||||
Args:
|
||||
img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
|
||||
weight (float): Sharp weight. Default: 1.
|
||||
radius (float): Kernel size of Gaussian blur. Default: 50.
|
||||
threshold (int):
|
||||
"""
|
||||
if radius % 2 == 0:
|
||||
radius += 1
|
||||
blur = cv2.GaussianBlur(img, (radius, radius), 0)
|
||||
residual = img - blur
|
||||
mask = np.abs(residual) * 255 > threshold
|
||||
mask = mask.astype("float32")
|
||||
soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
|
||||
|
||||
K = img + weight * residual
|
||||
K = np.clip(K, 0, 1)
|
||||
return soft_mask * K + (1 - soft_mask) * img
|
||||
|
||||
|
||||
def add_blur(img, sf=4):
|
||||
wd2 = 4.0 + sf
|
||||
wd = 2.0 + 0.2 * sf
|
||||
|
||||
wd2 = wd2 / 4
|
||||
wd = wd / 4
|
||||
|
||||
if random.random() < 0.5:
|
||||
l1 = wd2 * random.random()
|
||||
l2 = wd2 * random.random()
|
||||
k = anisotropic_Gaussian(
|
||||
ksize=random.randint(2, 11) + 3,
|
||||
theta=random.random() * np.pi,
|
||||
l1=l1,
|
||||
l2=l2,
|
||||
)
|
||||
else:
|
||||
k = fspecial("gaussian", random.randint(2, 4) + 3, wd * random.random())
|
||||
img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode="mirror")
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def add_resize(img, sf=4):
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.8: # up
|
||||
sf1 = random.uniform(1, 2)
|
||||
elif rnum < 0.7: # down
|
||||
sf1 = random.uniform(0.5 / sf, 1)
|
||||
else:
|
||||
sf1 = 1.0
|
||||
img = cv2.resize(
|
||||
img,
|
||||
(int(sf1 * img.shape[1]), int(sf1 * img.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]),
|
||||
)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
# noise_level = random.randint(noise_level1, noise_level2)
|
||||
# rnum = np.random.rand()
|
||||
# if rnum > 0.6: # add color Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
# elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
# else: # add noise
|
||||
# L = noise_level2 / 255.
|
||||
# D = np.diag(np.random.rand(3))
|
||||
# U = orth(np.random.rand(3, 3))
|
||||
# conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
# img = np.clip(img, 0.0, 1.0)
|
||||
# return img
|
||||
|
||||
|
||||
def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
rnum = np.random.rand()
|
||||
if rnum > 0.6: # add color Gaussian noise
|
||||
img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4: # add grayscale Gaussian noise
|
||||
img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else: # add noise
|
||||
L = noise_level2 / 255.0
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_speckle_noise(img, noise_level1=2, noise_level2=25):
|
||||
noise_level = random.randint(noise_level1, noise_level2)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
rnum = random.random()
|
||||
if rnum > 0.6:
|
||||
img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
|
||||
elif rnum < 0.4:
|
||||
img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
|
||||
else:
|
||||
L = noise_level2 / 255.0
|
||||
D = np.diag(np.random.rand(3))
|
||||
U = orth(np.random.rand(3, 3))
|
||||
conv = np.dot(np.dot(np.transpose(U), D), U)
|
||||
img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_Poisson_noise(img):
|
||||
img = np.clip((img * 255.0).round(), 0, 255) / 255.0
|
||||
vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
|
||||
if random.random() < 0.5:
|
||||
img = np.random.poisson(img * vals).astype(np.float32) / vals
|
||||
else:
|
||||
img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
|
||||
img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0
|
||||
noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
|
||||
img += noise_gray[:, :, np.newaxis]
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
return img
|
||||
|
||||
|
||||
def add_JPEG_noise(img):
|
||||
quality_factor = random.randint(80, 95)
|
||||
img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
|
||||
result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
|
||||
img = cv2.imdecode(encimg, 1)
|
||||
img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
def random_crop(lq, hq, sf=4, lq_patchsize=64):
|
||||
h, w = lq.shape[:2]
|
||||
rnd_h = random.randint(0, h - lq_patchsize)
|
||||
rnd_w = random.randint(0, w - lq_patchsize)
|
||||
lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :]
|
||||
|
||||
rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
|
||||
hq = hq[
|
||||
rnd_h_H : rnd_h_H + lq_patchsize * sf,
|
||||
rnd_w_H : rnd_w_H + lq_patchsize * sf,
|
||||
:,
|
||||
]
|
||||
return lq, hq
|
||||
|
||||
|
||||
def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
|
||||
"""
|
||||
This is the degradation model of BSRGAN from the paper
|
||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||
----------
|
||||
img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
|
||||
sf: scale factor
|
||||
isp_model: camera ISP model
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
|
||||
sf_ori = sf
|
||||
|
||||
h1, w1 = img.shape[:2]
|
||||
img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
|
||||
h, w = img.shape[:2]
|
||||
|
||||
if h < lq_patchsize * sf or w < lq_patchsize * sf:
|
||||
raise ValueError(f"img size ({h1}X{w1}) is too small!")
|
||||
|
||||
hq = img.copy()
|
||||
|
||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
||||
if np.random.rand() < 0.5:
|
||||
img = cv2.resize(
|
||||
img,
|
||||
(int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]),
|
||||
)
|
||||
else:
|
||||
img = util.imresize_np(img, 1 / 2, True)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
sf = 2
|
||||
|
||||
shuffle_order = random.sample(range(7), 7)
|
||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
||||
if idx1 > idx2: # keep downsample3 last
|
||||
shuffle_order[idx1], shuffle_order[idx2] = (
|
||||
shuffle_order[idx2],
|
||||
shuffle_order[idx1],
|
||||
)
|
||||
|
||||
for i in shuffle_order:
|
||||
if i == 0:
|
||||
img = add_blur(img, sf=sf)
|
||||
|
||||
elif i == 1:
|
||||
img = add_blur(img, sf=sf)
|
||||
|
||||
elif i == 2:
|
||||
a, b = img.shape[1], img.shape[0]
|
||||
# downsample2
|
||||
if random.random() < 0.75:
|
||||
sf1 = random.uniform(1, 2 * sf)
|
||||
img = cv2.resize(
|
||||
img,
|
||||
(int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]),
|
||||
)
|
||||
else:
|
||||
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror")
|
||||
img = img[0::sf, 0::sf, ...] # nearest downsampling
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
elif i == 3:
|
||||
# downsample3
|
||||
img = cv2.resize(
|
||||
img,
|
||||
(int(1 / sf * a), int(1 / sf * b)),
|
||||
interpolation=random.choice([1, 2, 3]),
|
||||
)
|
||||
img = np.clip(img, 0.0, 1.0)
|
||||
|
||||
elif i == 4:
|
||||
# add Gaussian noise
|
||||
img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
|
||||
|
||||
elif i == 5:
|
||||
# add JPEG noise
|
||||
if random.random() < jpeg_prob:
|
||||
img = add_JPEG_noise(img)
|
||||
|
||||
elif i == 6:
|
||||
# add processed camera sensor noise
|
||||
if random.random() < isp_prob and isp_model is not None:
|
||||
with torch.no_grad():
|
||||
img, hq = isp_model.forward(img.copy(), hq)
|
||||
|
||||
# add final JPEG compression noise
|
||||
img = add_JPEG_noise(img)
|
||||
|
||||
# random crop
|
||||
img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
|
||||
|
||||
return img, hq
|
||||
|
||||
|
||||
# todo no isp_model?
|
||||
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
|
||||
"""
|
||||
This is the degradation model of BSRGAN from the paper
|
||||
"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
|
||||
----------
|
||||
sf: scale factor
|
||||
isp_model: camera ISP model
|
||||
Returns
|
||||
-------
|
||||
img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
|
||||
hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
|
||||
"""
|
||||
image = util.uint2single(image)
|
||||
jpeg_prob, scale2_prob = 0.9, 0.25
|
||||
# isp_prob = 0.25 # uncomment with `if i== 6` block below
|
||||
# sf_ori = sf # uncomment with `if i== 6` block below
|
||||
|
||||
h1, w1 = image.shape[:2]
|
||||
image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop
|
||||
h, w = image.shape[:2]
|
||||
|
||||
# hq = image.copy() # uncomment with `if i== 6` block below
|
||||
|
||||
if sf == 4 and random.random() < scale2_prob: # downsample1
|
||||
if np.random.rand() < 0.5:
|
||||
image = cv2.resize(
|
||||
image,
|
||||
(int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
|
||||
interpolation=random.choice([1, 2, 3]),
|
||||
)
|
||||
else:
|
||||
image = util.imresize_np(image, 1 / 2, True)
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
sf = 2
|
||||
|
||||
shuffle_order = random.sample(range(7), 7)
|
||||
idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
|
||||
if idx1 > idx2: # keep downsample3 last
|
||||
shuffle_order[idx1], shuffle_order[idx2] = (
|
||||
shuffle_order[idx2],
|
||||
shuffle_order[idx1],
|
||||
)
|
||||
|
||||
for i in shuffle_order:
|
||||
if i == 0:
|
||||
image = add_blur(image, sf=sf)
|
||||
|
||||
# elif i == 1:
|
||||
# image = add_blur(image, sf=sf)
|
||||
|
||||
if i == 0:
|
||||
pass
|
||||
|
||||
elif i == 2:
|
||||
a, b = image.shape[1], image.shape[0]
|
||||
# downsample2
|
||||
if random.random() < 0.8:
|
||||
sf1 = random.uniform(1, 2 * sf)
|
||||
image = cv2.resize(
|
||||
image,
|
||||
(
|
||||
int(1 / sf1 * image.shape[1]),
|
||||
int(1 / sf1 * image.shape[0]),
|
||||
),
|
||||
interpolation=random.choice([1, 2, 3]),
|
||||
)
|
||||
else:
|
||||
k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf))
|
||||
k_shifted = shift_pixel(k, sf)
|
||||
k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
|
||||
image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror")
|
||||
image = image[0::sf, 0::sf, ...] # nearest downsampling
|
||||
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
|
||||
elif i == 3:
|
||||
# downsample3
|
||||
image = cv2.resize(
|
||||
image,
|
||||
(int(1 / sf * a), int(1 / sf * b)),
|
||||
interpolation=random.choice([1, 2, 3]),
|
||||
)
|
||||
image = np.clip(image, 0.0, 1.0)
|
||||
|
||||
elif i == 4:
|
||||
# add Gaussian noise
|
||||
image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
|
||||
|
||||
elif i == 5:
|
||||
# add JPEG noise
|
||||
if random.random() < jpeg_prob:
|
||||
image = add_JPEG_noise(image)
|
||||
#
|
||||
# elif i == 6:
|
||||
# # add processed camera sensor noise
|
||||
# if random.random() < isp_prob and isp_model is not None:
|
||||
# with torch.no_grad():
|
||||
# img, hq = isp_model.forward(img.copy(), hq)
|
||||
|
||||
# add final JPEG compression noise
|
||||
image = add_JPEG_noise(image)
|
||||
image = util.single2uint(image)
|
||||
example = {"image": image}
|
||||
return example
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("hey")
|
||||
img = util.imread_uint("utils/test.png", 3)
|
||||
img = img[:448, :448]
|
||||
h = img.shape[0] // 4
|
||||
print("resizing to", h)
|
||||
sf = 4
|
||||
deg_fn = partial(degradation_bsrgan_variant, sf=sf)
|
||||
for i in range(20):
|
||||
print(i)
|
||||
img_hq = img
|
||||
img_lq = deg_fn(img)["image"]
|
||||
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
|
||||
print(img_lq)
|
||||
img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[
|
||||
"image"
|
||||
]
|
||||
print(img_lq.shape)
|
||||
print("bicubic", img_lq_bicubic.shape)
|
||||
print(img_hq.shape)
|
||||
lq_nearest = cv2.resize(
|
||||
util.single2uint(img_lq),
|
||||
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||
interpolation=0,
|
||||
)
|
||||
lq_bicubic_nearest = cv2.resize(
|
||||
util.single2uint(img_lq_bicubic),
|
||||
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
|
||||
interpolation=0,
|
||||
)
|
||||
img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
|
||||
util.imsave(img_concat, str(i) + ".png")
|
Binary file not shown.
After Width: | Height: | Size: 431 KiB |
@ -0,0 +1,980 @@
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision.utils import make_grid
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# Kai Zhang (github: https://github.com/cszn)
|
||||
# 03/Mar/2019
|
||||
# --------------------------------------------
|
||||
# https://github.com/twhui/SRGAN-pyTorch
|
||||
# https://github.com/xinntao/BasicSR
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
IMG_EXTENSIONS = [
|
||||
".jpg",
|
||||
".JPG",
|
||||
".jpeg",
|
||||
".JPEG",
|
||||
".png",
|
||||
".PNG",
|
||||
".ppm",
|
||||
".PPM",
|
||||
".bmp",
|
||||
".BMP",
|
||||
".tif",
|
||||
]
|
||||
|
||||
|
||||
def is_image_file(filename):
|
||||
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
||||
|
||||
|
||||
def get_timestamp():
|
||||
return datetime.now().strftime("%y%m%d-%H%M%S")
|
||||
|
||||
|
||||
def imshow(x, title=None, cbar=False, figsize=None):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.figure(figsize=figsize)
|
||||
plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray")
|
||||
if title:
|
||||
plt.title(title)
|
||||
if cbar:
|
||||
plt.colorbar()
|
||||
plt.show()
|
||||
|
||||
|
||||
def surf(Z, cmap="rainbow", figsize=None):
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.figure(figsize=figsize)
|
||||
ax3 = plt.axes(projection="3d")
|
||||
|
||||
w, h = Z.shape[:2]
|
||||
xx = np.arange(0, w, 1)
|
||||
yy = np.arange(0, h, 1)
|
||||
X, Y = np.meshgrid(xx, yy)
|
||||
ax3.plot_surface(X, Y, Z, cmap=cmap)
|
||||
# ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
|
||||
plt.show()
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# get image pathes
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def get_image_paths(dataroot):
|
||||
paths = None # return None if dataroot is None
|
||||
if dataroot is not None:
|
||||
paths = sorted(_get_paths_from_images(dataroot))
|
||||
return paths
|
||||
|
||||
|
||||
def _get_paths_from_images(path):
|
||||
assert os.path.isdir(path), "{:s} is not a valid directory".format(path)
|
||||
images = []
|
||||
for dirpath, _, fnames in sorted(os.walk(path, followlinks=True)):
|
||||
for fname in sorted(fnames):
|
||||
if is_image_file(fname):
|
||||
img_path = os.path.join(dirpath, fname)
|
||||
images.append(img_path)
|
||||
assert images, "{:s} has no valid image file".format(path)
|
||||
return images
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# split large images into small images
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
|
||||
w, h = img.shape[:2]
|
||||
patches = []
|
||||
if w > p_max and h > p_max:
|
||||
w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int))
|
||||
h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int))
|
||||
w1.append(w - p_size)
|
||||
h1.append(h - p_size)
|
||||
# print(w1)
|
||||
# print(h1)
|
||||
for i in w1:
|
||||
for j in h1:
|
||||
patches.append(img[i : i + p_size, j : j + p_size, :])
|
||||
else:
|
||||
patches.append(img)
|
||||
|
||||
return patches
|
||||
|
||||
|
||||
def imssave(imgs, img_path):
|
||||
"""
|
||||
imgs: list, N images of size WxHxC
|
||||
"""
|
||||
img_name, ext = os.path.splitext(os.path.basename(img_path))
|
||||
|
||||
for i, img in enumerate(imgs):
|
||||
if img.ndim == 3:
|
||||
img = img[:, :, [2, 1, 0]]
|
||||
new_path = os.path.join(
|
||||
os.path.dirname(img_path),
|
||||
img_name + str("_s{:04d}".format(i)) + ".png",
|
||||
)
|
||||
cv2.imwrite(new_path, img)
|
||||
|
||||
|
||||
def split_imageset(
|
||||
original_dataroot,
|
||||
taget_dataroot,
|
||||
n_channels=3,
|
||||
p_size=800,
|
||||
p_overlap=96,
|
||||
p_max=1000,
|
||||
):
|
||||
"""
|
||||
split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
|
||||
and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
|
||||
will be splitted.
|
||||
Args:
|
||||
original_dataroot:
|
||||
taget_dataroot:
|
||||
p_size: size of small images
|
||||
p_overlap: patch size in training is a good choice
|
||||
p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
|
||||
"""
|
||||
paths = get_image_paths(original_dataroot)
|
||||
for img_path in paths:
|
||||
# img_name, ext = os.path.splitext(os.path.basename(img_path))
|
||||
img = imread_uint(img_path, n_channels=n_channels)
|
||||
patches = patches_from_image(img, p_size, p_overlap, p_max)
|
||||
imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path)))
|
||||
# if original_dataroot == taget_dataroot:
|
||||
# del img_path
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# makedir
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def mkdir(path):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
def mkdirs(paths):
|
||||
if isinstance(paths, str):
|
||||
mkdir(paths)
|
||||
else:
|
||||
for path in paths:
|
||||
mkdir(path)
|
||||
|
||||
|
||||
def mkdir_and_rename(path):
|
||||
if os.path.exists(path):
|
||||
new_name = path + "_archived_" + get_timestamp()
|
||||
logger.error("Path already exists. Rename it to [{:s}]".format(new_name))
|
||||
os.replace(path, new_name)
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# read image from path
|
||||
# opencv is fast, but read BGR numpy image
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# get uint8 image of size HxWxn_channles (RGB)
|
||||
# --------------------------------------------
|
||||
def imread_uint(path, n_channels=3):
|
||||
# input: path
|
||||
# output: HxWx3(RGB or GGG), or HxWx1 (G)
|
||||
if n_channels == 1:
|
||||
img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
|
||||
img = np.expand_dims(img, axis=2) # HxWx1
|
||||
elif n_channels == 3:
|
||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
|
||||
if img.ndim == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
|
||||
else:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
|
||||
return img
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# matlab's imwrite
|
||||
# --------------------------------------------
|
||||
def imsave(img, img_path):
|
||||
img = np.squeeze(img)
|
||||
if img.ndim == 3:
|
||||
img = img[:, :, [2, 1, 0]]
|
||||
cv2.imwrite(img_path, img)
|
||||
|
||||
|
||||
def imwrite(img, img_path):
|
||||
img = np.squeeze(img)
|
||||
if img.ndim == 3:
|
||||
img = img[:, :, [2, 1, 0]]
|
||||
cv2.imwrite(img_path, img)
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# get single image of size HxWxn_channles (BGR)
|
||||
# --------------------------------------------
|
||||
def read_img(path):
|
||||
# read image by cv2
|
||||
# return: Numpy float32, HWC, BGR, [0,1]
|
||||
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
|
||||
img = img.astype(np.float32) / 255.0
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
# some images have 4 channels
|
||||
if img.shape[2] > 3:
|
||||
img = img[:, :, :3]
|
||||
return img
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# image format conversion
|
||||
# --------------------------------------------
|
||||
# numpy(single) <---> numpy(unit)
|
||||
# numpy(single) <---> tensor
|
||||
# numpy(unit) <---> tensor
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# numpy(single) [0, 1] <---> numpy(unit)
|
||||
# --------------------------------------------
|
||||
|
||||
|
||||
def uint2single(img):
|
||||
return np.float32(img / 255.0)
|
||||
|
||||
|
||||
def single2uint(img):
|
||||
return np.uint8((img.clip(0, 1) * 255.0).round())
|
||||
|
||||
|
||||
def uint162single(img):
|
||||
return np.float32(img / 65535.0)
|
||||
|
||||
|
||||
def single2uint16(img):
|
||||
return np.uint16((img.clip(0, 1) * 65535.0).round())
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# numpy(unit) (HxWxC or HxW) <---> tensor
|
||||
# --------------------------------------------
|
||||
|
||||
|
||||
# convert uint to 4-dimensional torch tensor
|
||||
def uint2tensor4(img):
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0).unsqueeze(0)
|
||||
|
||||
|
||||
# convert uint to 3-dimensional torch tensor
|
||||
def uint2tensor3(img):
|
||||
if img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0)
|
||||
|
||||
|
||||
# convert 2/3/4-dimensional torch tensor to uint
|
||||
def tensor2uint(img):
|
||||
img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
|
||||
if img.ndim == 3:
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
return np.uint8((img * 255.0).round())
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# numpy(single) (HxWxC) <---> tensor
|
||||
# --------------------------------------------
|
||||
|
||||
|
||||
# convert single (HxWxC) to 3-dimensional torch tensor
|
||||
def single2tensor3(img):
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
|
||||
|
||||
|
||||
# convert single (HxWxC) to 4-dimensional torch tensor
|
||||
def single2tensor4(img):
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
|
||||
|
||||
|
||||
# convert torch tensor to single
|
||||
def tensor2single(img):
|
||||
img = img.data.squeeze().float().cpu().numpy()
|
||||
if img.ndim == 3:
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
|
||||
return img
|
||||
|
||||
|
||||
# convert torch tensor to single
|
||||
def tensor2single3(img):
|
||||
img = img.data.squeeze().float().cpu().numpy()
|
||||
if img.ndim == 3:
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
elif img.ndim == 2:
|
||||
img = np.expand_dims(img, axis=2)
|
||||
return img
|
||||
|
||||
|
||||
def single2tensor5(img):
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
|
||||
|
||||
|
||||
def single32tensor5(img):
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
|
||||
|
||||
|
||||
def single42tensor4(img):
|
||||
return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
|
||||
|
||||
|
||||
# from skimage.io import imread, imsave
|
||||
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
|
||||
"""
|
||||
Converts a torch Tensor into an image Numpy array of BGR channel order
|
||||
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
|
||||
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
|
||||
"""
|
||||
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
|
||||
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
|
||||
n_dim = tensor.dim()
|
||||
if n_dim == 4:
|
||||
n_img = len(tensor)
|
||||
img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
|
||||
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
||||
elif n_dim == 3:
|
||||
img_np = tensor.numpy()
|
||||
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
|
||||
elif n_dim == 2:
|
||||
img_np = tensor.numpy()
|
||||
else:
|
||||
raise TypeError("Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format(n_dim))
|
||||
if out_type == np.uint8:
|
||||
img_np = (img_np * 255.0).round()
|
||||
# Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
|
||||
return img_np.astype(out_type)
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# Augmentation, flipe and/or rotate
|
||||
# --------------------------------------------
|
||||
# The following two are enough.
|
||||
# (1) augmet_img: numpy image of WxHxC or WxH
|
||||
# (2) augment_img_tensor4: tensor image 1xCxWxH
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def augment_img(img, mode=0):
|
||||
"""Kai Zhang (github: https://github.com/cszn)"""
|
||||
if mode == 0:
|
||||
return img
|
||||
elif mode == 1:
|
||||
return np.flipud(np.rot90(img))
|
||||
elif mode == 2:
|
||||
return np.flipud(img)
|
||||
elif mode == 3:
|
||||
return np.rot90(img, k=3)
|
||||
elif mode == 4:
|
||||
return np.flipud(np.rot90(img, k=2))
|
||||
elif mode == 5:
|
||||
return np.rot90(img)
|
||||
elif mode == 6:
|
||||
return np.rot90(img, k=2)
|
||||
elif mode == 7:
|
||||
return np.flipud(np.rot90(img, k=3))
|
||||
|
||||
|
||||
def augment_img_tensor4(img, mode=0):
|
||||
"""Kai Zhang (github: https://github.com/cszn)"""
|
||||
if mode == 0:
|
||||
return img
|
||||
elif mode == 1:
|
||||
return img.rot90(1, [2, 3]).flip([2])
|
||||
elif mode == 2:
|
||||
return img.flip([2])
|
||||
elif mode == 3:
|
||||
return img.rot90(3, [2, 3])
|
||||
elif mode == 4:
|
||||
return img.rot90(2, [2, 3]).flip([2])
|
||||
elif mode == 5:
|
||||
return img.rot90(1, [2, 3])
|
||||
elif mode == 6:
|
||||
return img.rot90(2, [2, 3])
|
||||
elif mode == 7:
|
||||
return img.rot90(3, [2, 3]).flip([2])
|
||||
|
||||
|
||||
def augment_img_tensor(img, mode=0):
|
||||
"""Kai Zhang (github: https://github.com/cszn)"""
|
||||
img_size = img.size()
|
||||
img_np = img.data.cpu().numpy()
|
||||
if len(img_size) == 3:
|
||||
img_np = np.transpose(img_np, (1, 2, 0))
|
||||
elif len(img_size) == 4:
|
||||
img_np = np.transpose(img_np, (2, 3, 1, 0))
|
||||
img_np = augment_img(img_np, mode=mode)
|
||||
img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
|
||||
if len(img_size) == 3:
|
||||
img_tensor = img_tensor.permute(2, 0, 1)
|
||||
elif len(img_size) == 4:
|
||||
img_tensor = img_tensor.permute(3, 2, 0, 1)
|
||||
|
||||
return img_tensor.type_as(img)
|
||||
|
||||
|
||||
def augment_img_np3(img, mode=0):
|
||||
if mode == 0:
|
||||
return img
|
||||
elif mode == 1:
|
||||
return img.transpose(1, 0, 2)
|
||||
elif mode == 2:
|
||||
return img[::-1, :, :]
|
||||
elif mode == 3:
|
||||
img = img[::-1, :, :]
|
||||
img = img.transpose(1, 0, 2)
|
||||
return img
|
||||
elif mode == 4:
|
||||
return img[:, ::-1, :]
|
||||
elif mode == 5:
|
||||
img = img[:, ::-1, :]
|
||||
img = img.transpose(1, 0, 2)
|
||||
return img
|
||||
elif mode == 6:
|
||||
img = img[:, ::-1, :]
|
||||
img = img[::-1, :, :]
|
||||
return img
|
||||
elif mode == 7:
|
||||
img = img[:, ::-1, :]
|
||||
img = img[::-1, :, :]
|
||||
img = img.transpose(1, 0, 2)
|
||||
return img
|
||||
|
||||
|
||||
def augment_imgs(img_list, hflip=True, rot=True):
|
||||
# horizontal flip OR rotate
|
||||
hflip = hflip and random.random() < 0.5
|
||||
vflip = rot and random.random() < 0.5
|
||||
rot90 = rot and random.random() < 0.5
|
||||
|
||||
def _augment(img):
|
||||
if hflip:
|
||||
img = img[:, ::-1, :]
|
||||
if vflip:
|
||||
img = img[::-1, :, :]
|
||||
if rot90:
|
||||
img = img.transpose(1, 0, 2)
|
||||
return img
|
||||
|
||||
return [_augment(img) for img in img_list]
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# modcrop and shave
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def modcrop(img_in, scale):
|
||||
# img_in: Numpy, HWC or HW
|
||||
img = np.copy(img_in)
|
||||
if img.ndim == 2:
|
||||
H, W = img.shape
|
||||
H_r, W_r = H % scale, W % scale
|
||||
img = img[: H - H_r, : W - W_r]
|
||||
elif img.ndim == 3:
|
||||
H, W, C = img.shape
|
||||
H_r, W_r = H % scale, W % scale
|
||||
img = img[: H - H_r, : W - W_r, :]
|
||||
else:
|
||||
raise ValueError("Wrong img ndim: [{:d}].".format(img.ndim))
|
||||
return img
|
||||
|
||||
|
||||
def shave(img_in, border=0):
|
||||
# img_in: Numpy, HWC or HW
|
||||
img = np.copy(img_in)
|
||||
h, w = img.shape[:2]
|
||||
img = img[border : h - border, border : w - border]
|
||||
return img
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# image processing process on numpy image
|
||||
# channel_convert(in_c, tar_type, img_list):
|
||||
# rgb2ycbcr(img, only_y=True):
|
||||
# bgr2ycbcr(img, only_y=True):
|
||||
# ycbcr2rgb(img):
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def rgb2ycbcr(img, only_y=True):
|
||||
"""same as matlab rgb2ycbcr
|
||||
only_y: only return Y channel
|
||||
Input:
|
||||
uint8, [0, 255]
|
||||
float, [0, 1]
|
||||
"""
|
||||
in_img_type = img.dtype
|
||||
img.astype(np.float32)
|
||||
if in_img_type != np.uint8:
|
||||
img *= 255.0
|
||||
# convert
|
||||
if only_y:
|
||||
rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
|
||||
else:
|
||||
rlt = (
|
||||
np.matmul(
|
||||
img,
|
||||
[
|
||||
[65.481, -37.797, 112.0],
|
||||
[128.553, -74.203, -93.786],
|
||||
[24.966, 112.0, -18.214],
|
||||
],
|
||||
)
|
||||
/ 255.0
|
||||
+ [16, 128, 128]
|
||||
)
|
||||
if in_img_type == np.uint8:
|
||||
rlt = rlt.round()
|
||||
else:
|
||||
rlt /= 255.0
|
||||
return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
def ycbcr2rgb(img):
|
||||
"""same as matlab ycbcr2rgb
|
||||
Input:
|
||||
uint8, [0, 255]
|
||||
float, [0, 1]
|
||||
"""
|
||||
in_img_type = img.dtype
|
||||
img.astype(np.float32)
|
||||
if in_img_type != np.uint8:
|
||||
img *= 255.0
|
||||
# convert
|
||||
rlt = (
|
||||
np.matmul(
|
||||
img,
|
||||
[
|
||||
[0.00456621, 0.00456621, 0.00456621],
|
||||
[0, -0.00153632, 0.00791071],
|
||||
[0.00625893, -0.00318811, 0],
|
||||
],
|
||||
)
|
||||
* 255.0
|
||||
+ [-222.921, 135.576, -276.836]
|
||||
)
|
||||
if in_img_type == np.uint8:
|
||||
rlt = rlt.round()
|
||||
else:
|
||||
rlt /= 255.0
|
||||
return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
def bgr2ycbcr(img, only_y=True):
|
||||
"""bgr version of rgb2ycbcr
|
||||
only_y: only return Y channel
|
||||
Input:
|
||||
uint8, [0, 255]
|
||||
float, [0, 1]
|
||||
"""
|
||||
in_img_type = img.dtype
|
||||
img.astype(np.float32)
|
||||
if in_img_type != np.uint8:
|
||||
img *= 255.0
|
||||
# convert
|
||||
if only_y:
|
||||
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
|
||||
else:
|
||||
rlt = (
|
||||
np.matmul(
|
||||
img,
|
||||
[
|
||||
[24.966, 112.0, -18.214],
|
||||
[128.553, -74.203, -93.786],
|
||||
[65.481, -37.797, 112.0],
|
||||
],
|
||||
)
|
||||
/ 255.0
|
||||
+ [16, 128, 128]
|
||||
)
|
||||
if in_img_type == np.uint8:
|
||||
rlt = rlt.round()
|
||||
else:
|
||||
rlt /= 255.0
|
||||
return rlt.astype(in_img_type)
|
||||
|
||||
|
||||
def channel_convert(in_c, tar_type, img_list):
|
||||
# conversion among BGR, gray and y
|
||||
if in_c == 3 and tar_type == "gray": # BGR to gray
|
||||
gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
|
||||
return [np.expand_dims(img, axis=2) for img in gray_list]
|
||||
elif in_c == 3 and tar_type == "y": # BGR to y
|
||||
y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
|
||||
return [np.expand_dims(img, axis=2) for img in y_list]
|
||||
elif in_c == 1 and tar_type == "RGB": # gray/y to BGR
|
||||
return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
|
||||
else:
|
||||
return img_list
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# metric, PSNR and SSIM
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# PSNR
|
||||
# --------------------------------------------
|
||||
def calculate_psnr(img1, img2, border=0):
|
||||
# img1 and img2 have range [0, 255]
|
||||
# img1 = img1.squeeze()
|
||||
# img2 = img2.squeeze()
|
||||
if not img1.shape == img2.shape:
|
||||
raise ValueError("Input images must have the same dimensions.")
|
||||
h, w = img1.shape[:2]
|
||||
img1 = img1[border : h - border, border : w - border]
|
||||
img2 = img2[border : h - border, border : w - border]
|
||||
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
mse = np.mean((img1 - img2) ** 2)
|
||||
if mse == 0:
|
||||
return float("inf")
|
||||
return 20 * math.log10(255.0 / math.sqrt(mse))
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# SSIM
|
||||
# --------------------------------------------
|
||||
def calculate_ssim(img1, img2, border=0):
|
||||
"""calculate SSIM
|
||||
the same outputs as MATLAB's
|
||||
img1, img2: [0, 255]
|
||||
"""
|
||||
# img1 = img1.squeeze()
|
||||
# img2 = img2.squeeze()
|
||||
if not img1.shape == img2.shape:
|
||||
raise ValueError("Input images must have the same dimensions.")
|
||||
h, w = img1.shape[:2]
|
||||
img1 = img1[border : h - border, border : w - border]
|
||||
img2 = img2[border : h - border, border : w - border]
|
||||
|
||||
if img1.ndim == 2:
|
||||
return ssim(img1, img2)
|
||||
elif img1.ndim == 3:
|
||||
if img1.shape[2] == 3:
|
||||
ssims = []
|
||||
for i in range(3):
|
||||
ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
|
||||
return np.array(ssims).mean()
|
||||
elif img1.shape[2] == 1:
|
||||
return ssim(np.squeeze(img1), np.squeeze(img2))
|
||||
else:
|
||||
raise ValueError("Wrong input image dimensions.")
|
||||
|
||||
|
||||
def ssim(img1, img2):
|
||||
C1 = (0.01 * 255) ** 2
|
||||
C2 = (0.03 * 255) ** 2
|
||||
|
||||
img1 = img1.astype(np.float64)
|
||||
img2 = img2.astype(np.float64)
|
||||
kernel = cv2.getGaussianKernel(11, 1.5)
|
||||
window = np.outer(kernel, kernel.transpose())
|
||||
|
||||
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
||||
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
||||
mu1_sq = mu1 ** 2
|
||||
mu2_sq = mu2 ** 2
|
||||
mu1_mu2 = mu1 * mu2
|
||||
sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
|
||||
sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
|
||||
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
||||
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
||||
return ssim_map.mean()
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# matlab's bicubic imresize (numpy and torch) [0, 1]
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
# matlab 'imresize' function, now only support 'bicubic'
|
||||
def cubic(x):
|
||||
absx = torch.abs(x)
|
||||
absx2 = absx ** 2
|
||||
absx3 = absx ** 3
|
||||
return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
|
||||
-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
|
||||
) * (((absx > 1) * (absx <= 2)).type_as(absx))
|
||||
|
||||
|
||||
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
|
||||
if (scale < 1) and (antialiasing):
|
||||
# Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
|
||||
kernel_width = kernel_width / scale
|
||||
|
||||
# Output-space coordinates
|
||||
x = torch.linspace(1, out_length, out_length)
|
||||
|
||||
# Input-space coordinates. Calculate the inverse mapping such that 0.5
|
||||
# in output space maps to 0.5 in input space, and 0.5+scale in output
|
||||
# space maps to 1.5 in input space.
|
||||
u = x / scale + 0.5 * (1 - 1 / scale)
|
||||
|
||||
# What is the left-most pixel that can be involved in the computation?
|
||||
left = torch.floor(u - kernel_width / 2)
|
||||
|
||||
# What is the maximum number of pixels that can be involved in the
|
||||
# computation? Note: it's OK to use an extra pixel here; if the
|
||||
# corresponding weights are all zero, it will be eliminated at the end
|
||||
# of this function.
|
||||
P = math.ceil(kernel_width) + 2
|
||||
|
||||
# The indices of the input pixels involved in computing the k-th output
|
||||
# pixel are in row k of the indices matrix.
|
||||
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand(
|
||||
out_length, P
|
||||
)
|
||||
|
||||
# The weights used to compute the k-th output pixel are in row k of the
|
||||
# weights matrix.
|
||||
distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
|
||||
# apply cubic kernel
|
||||
if (scale < 1) and (antialiasing):
|
||||
weights = scale * cubic(distance_to_center * scale)
|
||||
else:
|
||||
weights = cubic(distance_to_center)
|
||||
# Normalize the weights matrix so that each row sums to 1.
|
||||
weights_sum = torch.sum(weights, 1).view(out_length, 1)
|
||||
weights = weights / weights_sum.expand(out_length, P)
|
||||
|
||||
# If a column in weights is all zero, get rid of it. only consider the first and last column.
|
||||
weights_zero_tmp = torch.sum((weights == 0), 0)
|
||||
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
|
||||
indices = indices.narrow(1, 1, P - 2)
|
||||
weights = weights.narrow(1, 1, P - 2)
|
||||
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
|
||||
indices = indices.narrow(1, 0, P - 2)
|
||||
weights = weights.narrow(1, 0, P - 2)
|
||||
weights = weights.contiguous()
|
||||
indices = indices.contiguous()
|
||||
sym_len_s = -indices.min() + 1
|
||||
sym_len_e = indices.max() - in_length
|
||||
indices = indices + sym_len_s - 1
|
||||
return weights, indices, int(sym_len_s), int(sym_len_e)
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# imresize for tensor image [0, 1]
|
||||
# --------------------------------------------
|
||||
def imresize(img, scale, antialiasing=True):
|
||||
# Now the scale should be the same for H and W
|
||||
# input: img: pytorch tensor, CHW or HW [0,1]
|
||||
# output: CHW or HW [0,1] w/o round
|
||||
need_squeeze = True if img.dim() == 2 else False
|
||||
if need_squeeze:
|
||||
img.unsqueeze_(0)
|
||||
in_C, in_H, in_W = img.size()
|
||||
out_C, out_H, out_W = (
|
||||
in_C,
|
||||
math.ceil(in_H * scale),
|
||||
math.ceil(in_W * scale),
|
||||
)
|
||||
kernel_width = 4
|
||||
kernel = "cubic"
|
||||
|
||||
# Return the desired dimension order for performing the resize. The
|
||||
# strategy is to perform the resize first along the dimension with the
|
||||
# smallest scale factor.
|
||||
# Now we do not support this.
|
||||
|
||||
# get weights and indices
|
||||
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
||||
in_H, out_H, scale, kernel, kernel_width, antialiasing
|
||||
)
|
||||
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
||||
in_W, out_W, scale, kernel, kernel_width, antialiasing
|
||||
)
|
||||
# process H dimension
|
||||
# symmetric copying
|
||||
img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
|
||||
img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
|
||||
|
||||
sym_patch = img[:, :sym_len_Hs, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = img[:, -sym_len_He:, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
||||
|
||||
out_1 = torch.FloatTensor(in_C, out_H, in_W)
|
||||
kernel_width = weights_H.size(1)
|
||||
for i in range(out_H):
|
||||
idx = int(indices_H[i][0])
|
||||
for j in range(out_C):
|
||||
out_1[j, i, :] = img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
|
||||
|
||||
# process W dimension
|
||||
# symmetric copying
|
||||
out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
|
||||
out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
|
||||
|
||||
sym_patch = out_1[:, :, :sym_len_Ws]
|
||||
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
||||
out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = out_1[:, :, -sym_len_We:]
|
||||
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(2, inv_idx)
|
||||
out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
||||
|
||||
out_2 = torch.FloatTensor(in_C, out_H, out_W)
|
||||
kernel_width = weights_W.size(1)
|
||||
for i in range(out_W):
|
||||
idx = int(indices_W[i][0])
|
||||
for j in range(out_C):
|
||||
out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(weights_W[i])
|
||||
if need_squeeze:
|
||||
out_2.squeeze_()
|
||||
return out_2
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
# imresize for numpy image [0, 1]
|
||||
# --------------------------------------------
|
||||
def imresize_np(img, scale, antialiasing=True):
|
||||
# Now the scale should be the same for H and W
|
||||
# input: img: Numpy, HWC or HW [0,1]
|
||||
# output: HWC or HW [0,1] w/o round
|
||||
img = torch.from_numpy(img)
|
||||
need_squeeze = True if img.dim() == 2 else False
|
||||
if need_squeeze:
|
||||
img.unsqueeze_(2)
|
||||
|
||||
in_H, in_W, in_C = img.size()
|
||||
out_C, out_H, out_W = (
|
||||
in_C,
|
||||
math.ceil(in_H * scale),
|
||||
math.ceil(in_W * scale),
|
||||
)
|
||||
kernel_width = 4
|
||||
kernel = "cubic"
|
||||
|
||||
# Return the desired dimension order for performing the resize. The
|
||||
# strategy is to perform the resize first along the dimension with the
|
||||
# smallest scale factor.
|
||||
# Now we do not support this.
|
||||
|
||||
# get weights and indices
|
||||
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
|
||||
in_H, out_H, scale, kernel, kernel_width, antialiasing
|
||||
)
|
||||
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
|
||||
in_W, out_W, scale, kernel, kernel_width, antialiasing
|
||||
)
|
||||
# process H dimension
|
||||
# symmetric copying
|
||||
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
|
||||
img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
|
||||
|
||||
sym_patch = img[:sym_len_Hs, :, :]
|
||||
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
||||
img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = img[-sym_len_He:, :, :]
|
||||
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
||||
img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
||||
|
||||
out_1 = torch.FloatTensor(out_H, in_W, in_C)
|
||||
kernel_width = weights_H.size(1)
|
||||
for i in range(out_H):
|
||||
idx = int(indices_H[i][0])
|
||||
for j in range(out_C):
|
||||
out_1[i, :, j] = img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
|
||||
|
||||
# process W dimension
|
||||
# symmetric copying
|
||||
out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
|
||||
out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
|
||||
|
||||
sym_patch = out_1[:, :sym_len_Ws, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
|
||||
|
||||
sym_patch = out_1[:, -sym_len_We:, :]
|
||||
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
||||
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
||||
out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
||||
|
||||
out_2 = torch.FloatTensor(out_H, out_W, in_C)
|
||||
kernel_width = weights_W.size(1)
|
||||
for i in range(out_W):
|
||||
idx = int(indices_W[i][0])
|
||||
for j in range(out_C):
|
||||
out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(weights_W[i])
|
||||
if need_squeeze:
|
||||
out_2.squeeze_()
|
||||
|
||||
return out_2.numpy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("---")
|
||||
# img = imread_uint('test.bmp', 3)
|
||||
# img = uint2single(img)
|
||||
# img_bicubic = imresize_np(img, 1/4)
|
@ -475,10 +475,7 @@ class TextualInversionDataset(Dataset):
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
(
|
||||
h,
|
||||
w,
|
||||
) = (
|
||||
(h, w,) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
|
@ -10,6 +10,7 @@ from .devices import ( # noqa: F401
|
||||
normalize_device,
|
||||
torch_dtype,
|
||||
)
|
||||
from .log import write_log # noqa: F401
|
||||
from .util import ( # noqa: F401
|
||||
ask_user,
|
||||
download_with_resume,
|
||||
|
@ -1,7 +1,7 @@
|
||||
import math
|
||||
|
||||
import diffusers
|
||||
import torch
|
||||
import diffusers
|
||||
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
torch.empty = torch.zeros
|
||||
@ -203,7 +203,7 @@ class ChunkedSlicedAttnProcessor:
|
||||
if attn.upcast_attention:
|
||||
out_item_size = 4
|
||||
|
||||
chunk_size = 2**29
|
||||
chunk_size = 2 ** 29
|
||||
|
||||
out_size = query.shape[1] * key.shape[1] * out_item_size
|
||||
chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
|
||||
|
@ -207,7 +207,7 @@ def parallel_data_prefetch(
|
||||
return gather_res
|
||||
|
||||
|
||||
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
||||
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
|
||||
delta = (res[0] / shape[0], res[1] / shape[1])
|
||||
d = (shape[0] // res[0], shape[1] // res[1])
|
||||
|
||||
|
@ -54,15 +54,13 @@ def welcome(versions: dict):
|
||||
def text():
|
||||
yield f"InvokeAI Version: [bold yellow]{__version__}"
|
||||
yield ""
|
||||
yield "This script will update InvokeAI to the latest release, or to the development version of your choice."
|
||||
yield ""
|
||||
yield "When updating to an arbitrary tag or branch, be aware that the front end may be mismatched to the backend,"
|
||||
yield "making the web frontend unusable. Please downgrade to the latest release if this happens."
|
||||
yield "This script will update InvokeAI to the latest release, or to a development version of your choice."
|
||||
yield ""
|
||||
yield "[bold yellow]Options:"
|
||||
yield f"""[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])
|
||||
[2] Manually enter the [bold]tag name[/bold] for the version you wish to update to
|
||||
[3] Manually enter the [bold]branch name[/bold] for the version you wish to update to"""
|
||||
[2] Update to the bleeding-edge development version ([italic]main[/italic])
|
||||
[3] Manually enter the [bold]tag name[/bold] for the version you wish to update to
|
||||
[4] Manually enter the [bold]branch name[/bold] for the version you wish to update to"""
|
||||
|
||||
console.rule()
|
||||
print(
|
||||
@ -106,11 +104,11 @@ def main():
|
||||
if choice == "1":
|
||||
release = versions[0]["tag_name"]
|
||||
elif choice == "2":
|
||||
while not tag:
|
||||
tag = Prompt.ask("Enter an InvokeAI tag name")
|
||||
release = "main"
|
||||
elif choice == "3":
|
||||
while not branch:
|
||||
branch = Prompt.ask("Enter an InvokeAI branch name")
|
||||
tag = Prompt.ask("Enter an InvokeAI tag name")
|
||||
elif choice == "4":
|
||||
branch = Prompt.ask("Enter an InvokeAI branch name")
|
||||
|
||||
extras = get_extras()
|
||||
|
||||
|
@ -75,7 +75,6 @@
|
||||
"@reduxjs/toolkit": "^1.9.5",
|
||||
"@roarr/browser-log-writer": "^1.1.5",
|
||||
"@stevebel/png": "^1.5.1",
|
||||
"compare-versions": "^6.1.0",
|
||||
"dateformat": "^5.0.3",
|
||||
"formik": "^2.4.3",
|
||||
"framer-motion": "^10.16.1",
|
||||
|
@ -511,7 +511,6 @@
|
||||
"maskBlur": "Blur",
|
||||
"maskBlurMethod": "Blur Method",
|
||||
"coherencePassHeader": "Coherence Pass",
|
||||
"coherenceMode": "Mode",
|
||||
"coherenceSteps": "Steps",
|
||||
"coherenceStrength": "Strength",
|
||||
"seamLowThreshold": "Low",
|
||||
@ -521,7 +520,6 @@
|
||||
"scaledHeight": "Scaled H",
|
||||
"infillMethod": "Infill Method",
|
||||
"tileSize": "Tile Size",
|
||||
"patchmatchDownScaleSize": "Downscale",
|
||||
"boundingBoxHeader": "Bounding Box",
|
||||
"seamCorrectionHeader": "Seam Correction",
|
||||
"infillScalingHeader": "Infill and Scaling",
|
||||
|
@ -84,7 +84,6 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -203,9 +202,6 @@ addBoardIdSelectedListener();
|
||||
// Node schemas
|
||||
addReceivedOpenAPISchemaListener();
|
||||
|
||||
// Workflows
|
||||
addWorkflowLoadedListener();
|
||||
|
||||
// DND
|
||||
addImageDroppedListener();
|
||||
|
||||
|
@ -1,55 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { validateWorkflow } from 'features/nodes/util/validateWorkflow';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addWorkflowLoadedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: workflowLoadRequested,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('nodes');
|
||||
const workflow = action.payload;
|
||||
const nodeTemplates = getState().nodes.nodeTemplates;
|
||||
|
||||
const { workflow: validatedWorkflow, errors } = validateWorkflow(
|
||||
workflow,
|
||||
nodeTemplates
|
||||
);
|
||||
|
||||
dispatch(workflowLoaded(validatedWorkflow));
|
||||
|
||||
if (!errors.length) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Workflow Loaded',
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Workflow Loaded with Warnings',
|
||||
status: 'warning',
|
||||
})
|
||||
)
|
||||
);
|
||||
errors.forEach(({ message, ...rest }) => {
|
||||
log.warn(rest, message);
|
||||
});
|
||||
}
|
||||
|
||||
dispatch(setActiveTab('nodes'));
|
||||
requestAnimationFrame(() => {
|
||||
$flow.get()?.fitView();
|
||||
});
|
||||
},
|
||||
});
|
||||
};
|
@ -31,54 +31,48 @@ const selector = createSelector(
|
||||
reasons.push('No initial image selected');
|
||||
}
|
||||
|
||||
if (activeTabName === 'nodes') {
|
||||
if (nodes.shouldValidateGraph) {
|
||||
if (!nodes.nodes.length) {
|
||||
reasons.push('No nodes in graph');
|
||||
if (activeTabName === 'nodes' && nodes.shouldValidateGraph) {
|
||||
if (!nodes.nodes.length) {
|
||||
reasons.push('No nodes in graph');
|
||||
}
|
||||
|
||||
nodes.nodes.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
nodes.nodes.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
const nodeTemplate = nodes.nodeTemplates[node.data.type];
|
||||
|
||||
if (!nodeTemplate) {
|
||||
// Node type not found
|
||||
reasons.push('Missing node template');
|
||||
return;
|
||||
}
|
||||
|
||||
const connectedEdges = getConnectedEdges([node], nodes.edges);
|
||||
|
||||
forEach(node.data.inputs, (field) => {
|
||||
const fieldTemplate = nodeTemplate.inputs[field.name];
|
||||
const hasConnection = connectedEdges.some(
|
||||
(edge) =>
|
||||
edge.target === node.id && edge.targetHandle === field.name
|
||||
);
|
||||
|
||||
if (!fieldTemplate) {
|
||||
reasons.push('Missing field template');
|
||||
return;
|
||||
}
|
||||
|
||||
const nodeTemplate = nodes.nodeTemplates[node.data.type];
|
||||
|
||||
if (!nodeTemplate) {
|
||||
// Node type not found
|
||||
reasons.push('Missing node template');
|
||||
return;
|
||||
}
|
||||
|
||||
const connectedEdges = getConnectedEdges([node], nodes.edges);
|
||||
|
||||
forEach(node.data.inputs, (field) => {
|
||||
const fieldTemplate = nodeTemplate.inputs[field.name];
|
||||
const hasConnection = connectedEdges.some(
|
||||
(edge) =>
|
||||
edge.target === node.id && edge.targetHandle === field.name
|
||||
if (fieldTemplate.required && !field.value && !hasConnection) {
|
||||
reasons.push(
|
||||
`${node.data.label || nodeTemplate.title} -> ${
|
||||
field.label || fieldTemplate.title
|
||||
} missing input`
|
||||
);
|
||||
|
||||
if (!fieldTemplate) {
|
||||
reasons.push('Missing field template');
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
fieldTemplate.required &&
|
||||
field.value === undefined &&
|
||||
!hasConnection
|
||||
) {
|
||||
reasons.push(
|
||||
`${node.data.label || nodeTemplate.title} -> ${
|
||||
field.label || fieldTemplate.title
|
||||
} missing input`
|
||||
);
|
||||
return;
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
} else {
|
||||
if (!model) {
|
||||
reasons.push('No model selected');
|
||||
|
@ -1,2 +1,2 @@
|
||||
export const colorTokenToCssVar = (colorToken: string) =>
|
||||
`var(--invokeai-colors-${colorToken.split('.').join('-')})`;
|
||||
`var(--invokeai-colors-${colorToken.split('.').join('-')}`;
|
||||
|
@ -118,11 +118,7 @@ const IAICanvasToolChooserOptions = () => {
|
||||
useHotkeys(
|
||||
['BracketLeft'],
|
||||
() => {
|
||||
if (brushSize - 5 <= 5) {
|
||||
dispatch(setBrushSize(Math.max(brushSize - 1, 1)));
|
||||
} else {
|
||||
dispatch(setBrushSize(Math.max(brushSize - 5, 1)));
|
||||
}
|
||||
dispatch(setBrushSize(Math.max(brushSize - 5, 5)));
|
||||
},
|
||||
{
|
||||
enabled: () => !isStaging,
|
||||
|
@ -235,18 +235,10 @@ export const canvasSlice = createSlice({
|
||||
state.boundingBoxDimensions.width,
|
||||
state.boundingBoxDimensions.height,
|
||||
];
|
||||
const [currScaledWidth, currScaledHeight] = [
|
||||
state.scaledBoundingBoxDimensions.width,
|
||||
state.scaledBoundingBoxDimensions.height,
|
||||
];
|
||||
state.boundingBoxDimensions = {
|
||||
width: currHeight,
|
||||
height: currWidth,
|
||||
};
|
||||
state.scaledBoundingBoxDimensions = {
|
||||
width: currScaledHeight,
|
||||
height: currScaledWidth,
|
||||
};
|
||||
},
|
||||
setBoundingBoxCoordinates: (state, action: PayloadAction<Vector2d>) => {
|
||||
state.boundingBoxCoordinates = floorCoordinates(action.payload);
|
||||
@ -796,10 +788,6 @@ export const canvasSlice = createSlice({
|
||||
state.boundingBoxDimensions.width / ratio,
|
||||
64
|
||||
);
|
||||
state.scaledBoundingBoxDimensions.height = roundToMultiple(
|
||||
state.scaledBoundingBoxDimensions.width / ratio,
|
||||
64
|
||||
);
|
||||
}
|
||||
});
|
||||
},
|
||||
|
@ -104,22 +104,22 @@ const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
|
||||
]);
|
||||
|
||||
const handleSetControlImageToDimensions = useCallback(() => {
|
||||
if (!controlImage) {
|
||||
if (!processedControlImage) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (activeTabName === 'unifiedCanvas') {
|
||||
dispatch(
|
||||
setBoundingBoxDimensions({
|
||||
width: controlImage.width,
|
||||
height: controlImage.height,
|
||||
width: processedControlImage.width,
|
||||
height: processedControlImage.height,
|
||||
})
|
||||
);
|
||||
} else {
|
||||
dispatch(setWidth(controlImage.width));
|
||||
dispatch(setHeight(controlImage.height));
|
||||
dispatch(setWidth(processedControlImage.width));
|
||||
dispatch(setHeight(processedControlImage.height));
|
||||
}
|
||||
}, [controlImage, activeTabName, dispatch]);
|
||||
}, [processedControlImage, activeTabName, dispatch]);
|
||||
|
||||
const handleMouseEnter = useCallback(() => {
|
||||
setIsMouseOverImage(true);
|
||||
|
@ -17,13 +17,16 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
|
||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import {
|
||||
setActiveTab,
|
||||
setShouldShowImageDetails,
|
||||
setShouldShowProgressInViewer,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
@ -107,7 +110,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
);
|
||||
|
||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||
lastSelectedImage ?? skipToken,
|
||||
lastSelectedImage?.image_name ?? skipToken,
|
||||
{
|
||||
selectFromResult: (res) => ({
|
||||
isLoading: res.isFetching,
|
||||
@ -121,7 +124,16 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
if (!workflow) {
|
||||
return;
|
||||
}
|
||||
dispatch(workflowLoadRequested(workflow));
|
||||
dispatch(workflowLoaded(workflow));
|
||||
dispatch(setActiveTab('nodes'));
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Workflow Loaded',
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
}, [dispatch, workflow]);
|
||||
|
||||
const handleClickUseAllParameters = useCallback(() => {
|
||||
|
@ -7,9 +7,12 @@ import {
|
||||
isModalOpenChanged,
|
||||
} from 'features/changeBoardModal/store/slice';
|
||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
@ -33,7 +36,6 @@ import {
|
||||
} from 'services/api/endpoints/images';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
|
||||
import { workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
|
||||
type SingleSelectionMenuItemsProps = {
|
||||
imageDTO: ImageDTO;
|
||||
@ -50,7 +52,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||
|
||||
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
|
||||
imageDTO,
|
||||
imageDTO.image_name,
|
||||
{
|
||||
selectFromResult: (res) => ({
|
||||
isLoading: res.isFetching,
|
||||
@ -100,7 +102,16 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
if (!workflow) {
|
||||
return;
|
||||
}
|
||||
dispatch(workflowLoadRequested(workflow));
|
||||
dispatch(workflowLoaded(workflow));
|
||||
dispatch(setActiveTab('nodes'));
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Workflow Loaded',
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
}, [dispatch, workflow]);
|
||||
|
||||
const handleSendToImageToImage = useCallback(() => {
|
||||
|
@ -101,15 +101,13 @@ const ImageMetadataActions = (props: Props) => {
|
||||
onClick={handleRecallSeed}
|
||||
/>
|
||||
)}
|
||||
{metadata.model !== undefined &&
|
||||
metadata.model !== null &&
|
||||
metadata.model.model_name && (
|
||||
<ImageMetadataItem
|
||||
label="Model"
|
||||
value={metadata.model.model_name}
|
||||
onClick={handleRecallModel}
|
||||
/>
|
||||
)}
|
||||
{metadata.model !== undefined && metadata.model !== null && (
|
||||
<ImageMetadataItem
|
||||
label="Model"
|
||||
value={metadata.model.model_name}
|
||||
onClick={handleRecallModel}
|
||||
/>
|
||||
)}
|
||||
{metadata.width && (
|
||||
<ImageMetadataItem
|
||||
label="Width"
|
||||
|
@ -27,12 +27,15 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
// dispatch(setShouldShowImageDetails(false));
|
||||
// });
|
||||
|
||||
const { metadata, workflow } = useGetImageMetadataFromFileQuery(image, {
|
||||
selectFromResult: (res) => ({
|
||||
metadata: res?.currentData?.metadata,
|
||||
workflow: res?.currentData?.workflow,
|
||||
}),
|
||||
});
|
||||
const { metadata, workflow } = useGetImageMetadataFromFileQuery(
|
||||
image.image_name,
|
||||
{
|
||||
selectFromResult: (res) => ({
|
||||
metadata: res?.currentData?.metadata,
|
||||
workflow: res?.currentData?.workflow,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
|
@ -3,7 +3,6 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { contextMenusClosed } from 'features/ui/store/uiSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
@ -14,7 +13,6 @@ import {
|
||||
OnConnectStart,
|
||||
OnEdgesChange,
|
||||
OnEdgesDelete,
|
||||
OnInit,
|
||||
OnMoveEnd,
|
||||
OnNodesChange,
|
||||
OnNodesDelete,
|
||||
@ -149,11 +147,6 @@ export const Flow = () => {
|
||||
dispatch(contextMenusClosed());
|
||||
}, [dispatch]);
|
||||
|
||||
const onInit: OnInit = useCallback((flow) => {
|
||||
$flow.set(flow);
|
||||
flow.fitView();
|
||||
}, []);
|
||||
|
||||
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
|
||||
e.preventDefault();
|
||||
dispatch(selectionCopied());
|
||||
@ -177,7 +170,6 @@ export const Flow = () => {
|
||||
edgeTypes={edgeTypes}
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
onInit={onInit}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onEdgesDelete={onEdgesDelete}
|
||||
|
@ -12,7 +12,6 @@ import {
|
||||
Tooltip,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { compare } from 'compare-versions';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
@ -21,7 +20,6 @@ import { isInvocationNodeData } from 'features/nodes/types/types';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { FaInfoCircle } from 'react-icons/fa';
|
||||
import NotesTextarea from './NotesTextarea';
|
||||
import { useDoNodeVersionsMatch } from 'features/nodes/hooks/useDoNodeVersionsMatch';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
@ -31,7 +29,6 @@ const InvocationNodeNotes = ({ nodeId }: Props) => {
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const label = useNodeLabel(nodeId);
|
||||
const title = useNodeTemplateTitle(nodeId);
|
||||
const doVersionsMatch = useDoNodeVersionsMatch(nodeId);
|
||||
|
||||
return (
|
||||
<>
|
||||
@ -53,11 +50,7 @@ const InvocationNodeNotes = ({ nodeId }: Props) => {
|
||||
>
|
||||
<Icon
|
||||
as={FaInfoCircle}
|
||||
sx={{
|
||||
boxSize: 4,
|
||||
w: 8,
|
||||
color: doVersionsMatch ? 'base.400' : 'error.400',
|
||||
}}
|
||||
sx={{ boxSize: 4, w: 8, color: 'base.400' }}
|
||||
/>
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
@ -99,59 +92,16 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
return 'Unknown Node';
|
||||
}, [data, nodeTemplate]);
|
||||
|
||||
const versionComponent = useMemo(() => {
|
||||
if (!isInvocationNodeData(data) || !nodeTemplate) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!data.version) {
|
||||
return (
|
||||
<Text as="span" sx={{ color: 'error.500' }}>
|
||||
Version unknown
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
if (!nodeTemplate.version) {
|
||||
return (
|
||||
<Text as="span" sx={{ color: 'error.500' }}>
|
||||
Version {data.version} (unknown template)
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
if (compare(data.version, nodeTemplate.version, '<')) {
|
||||
return (
|
||||
<Text as="span" sx={{ color: 'error.500' }}>
|
||||
Version {data.version} (update node)
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
if (compare(data.version, nodeTemplate.version, '>')) {
|
||||
return (
|
||||
<Text as="span" sx={{ color: 'error.500' }}>
|
||||
Version {data.version} (update app)
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
return <Text as="span">Version {data.version}</Text>;
|
||||
}, [data, nodeTemplate]);
|
||||
|
||||
if (!isInvocationNodeData(data)) {
|
||||
return <Text sx={{ fontWeight: 600 }}>Unknown Node</Text>;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex sx={{ flexDir: 'column' }}>
|
||||
<Text as="span" sx={{ fontWeight: 600 }}>
|
||||
{title}
|
||||
</Text>
|
||||
<Text sx={{ fontWeight: 600 }}>{title}</Text>
|
||||
<Text sx={{ opacity: 0.7, fontStyle: 'oblique 5deg' }}>
|
||||
{nodeTemplate?.description}
|
||||
</Text>
|
||||
{versionComponent}
|
||||
{data?.notes && <Text>{data.notes}</Text>}
|
||||
</Flex>
|
||||
);
|
||||
|
@ -1,11 +1,8 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import {
|
||||
COLLECTION_TYPES,
|
||||
FIELDS,
|
||||
HANDLE_TOOLTIP_OPEN_DELAY,
|
||||
MODEL_TYPES,
|
||||
POLYMORPHIC_TYPES,
|
||||
} from 'features/nodes/types/constants';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
@ -21,7 +18,6 @@ export const handleBaseStyles: CSSProperties = {
|
||||
borderWidth: 0,
|
||||
zIndex: 1,
|
||||
};
|
||||
``;
|
||||
|
||||
export const inputHandleStyles: CSSProperties = {
|
||||
left: '-1rem',
|
||||
@ -48,25 +44,15 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
connectionError,
|
||||
} = props;
|
||||
const { name, type } = fieldTemplate;
|
||||
const { color: typeColor, title } = FIELDS[type];
|
||||
const { color, title } = FIELDS[type];
|
||||
|
||||
const styles: CSSProperties = useMemo(() => {
|
||||
const isCollectionType = COLLECTION_TYPES.includes(type);
|
||||
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
|
||||
const isModelType = MODEL_TYPES.includes(type);
|
||||
const color = colorTokenToCssVar(typeColor);
|
||||
const s: CSSProperties = {
|
||||
backgroundColor:
|
||||
isCollectionType || isPolymorphicType
|
||||
? 'var(--invokeai-colors-base-900)'
|
||||
: color,
|
||||
backgroundColor: colorTokenToCssVar(color),
|
||||
position: 'absolute',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
borderWidth: isCollectionType || isPolymorphicType ? 4 : 0,
|
||||
borderStyle: 'solid',
|
||||
borderColor: color,
|
||||
borderRadius: isModelType ? 4 : '100%',
|
||||
borderWidth: 0,
|
||||
zIndex: 1,
|
||||
};
|
||||
|
||||
@ -92,12 +78,11 @@ const FieldHandle = (props: FieldHandleProps) => {
|
||||
|
||||
return s;
|
||||
}, [
|
||||
color,
|
||||
connectionError,
|
||||
handleType,
|
||||
isConnectionInProgress,
|
||||
isConnectionStartField,
|
||||
type,
|
||||
typeColor,
|
||||
]);
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
|
@ -75,7 +75,6 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
sx={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
h: 'full',
|
||||
mb: 0,
|
||||
px: 1,
|
||||
gap: 2,
|
||||
|
@ -3,10 +3,18 @@ import { useFieldData } from 'features/nodes/hooks/useFieldData';
|
||||
import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
|
||||
import { memo } from 'react';
|
||||
import BooleanInputField from './inputs/BooleanInputField';
|
||||
import ClipInputField from './inputs/ClipInputField';
|
||||
import CollectionInputField from './inputs/CollectionInputField';
|
||||
import CollectionItemInputField from './inputs/CollectionItemInputField';
|
||||
import ColorInputField from './inputs/ColorInputField';
|
||||
import ConditioningInputField from './inputs/ConditioningInputField';
|
||||
import ControlInputField from './inputs/ControlInputField';
|
||||
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
|
||||
import DenoiseMaskInputField from './inputs/DenoiseMaskInputField';
|
||||
import EnumInputField from './inputs/EnumInputField';
|
||||
import ImageCollectionInputField from './inputs/ImageCollectionInputField';
|
||||
import ImageInputField from './inputs/ImageInputField';
|
||||
import LatentsInputField from './inputs/LatentsInputField';
|
||||
import LoRAModelInputField from './inputs/LoRAModelInputField';
|
||||
import MainModelInputField from './inputs/MainModelInputField';
|
||||
import NumberInputField from './inputs/NumberInputField';
|
||||
@ -14,6 +22,8 @@ import RefinerModelInputField from './inputs/RefinerModelInputField';
|
||||
import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
|
||||
import SchedulerInputField from './inputs/SchedulerInputField';
|
||||
import StringInputField from './inputs/StringInputField';
|
||||
import UnetInputField from './inputs/UnetInputField';
|
||||
import VaeInputField from './inputs/VaeInputField';
|
||||
import VaeModelInputField from './inputs/VaeModelInputField';
|
||||
|
||||
type InputFieldProps = {
|
||||
@ -21,6 +31,7 @@ type InputFieldProps = {
|
||||
fieldName: string;
|
||||
};
|
||||
|
||||
// build an individual input element based on the schema
|
||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
const field = useFieldData(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input');
|
||||
@ -82,6 +93,88 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'LatentsField' &&
|
||||
fieldTemplate?.type === 'LatentsField'
|
||||
) {
|
||||
return (
|
||||
<LatentsInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'DenoiseMaskField' &&
|
||||
fieldTemplate?.type === 'DenoiseMaskField'
|
||||
) {
|
||||
return (
|
||||
<DenoiseMaskInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'ConditioningField' &&
|
||||
fieldTemplate?.type === 'ConditioningField'
|
||||
) {
|
||||
return (
|
||||
<ConditioningInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'UNetField' && fieldTemplate?.type === 'UNetField') {
|
||||
return (
|
||||
<UnetInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'ClipField' && fieldTemplate?.type === 'ClipField') {
|
||||
return (
|
||||
<ClipInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'VaeField' && fieldTemplate?.type === 'VaeField') {
|
||||
return (
|
||||
<VaeInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'ControlField' &&
|
||||
fieldTemplate?.type === 'ControlField'
|
||||
) {
|
||||
return (
|
||||
<ControlInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'MainModelField' &&
|
||||
fieldTemplate?.type === 'MainModelField'
|
||||
@ -147,6 +240,29 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'Collection' && fieldTemplate?.type === 'Collection') {
|
||||
return (
|
||||
<CollectionInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'CollectionItem' &&
|
||||
fieldTemplate?.type === 'CollectionItem'
|
||||
) {
|
||||
return (
|
||||
<CollectionItemInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
||||
return (
|
||||
<ColorInputField
|
||||
@ -157,6 +273,19 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'ImageCollection' &&
|
||||
fieldTemplate?.type === 'ImageCollection'
|
||||
) {
|
||||
return (
|
||||
<ImageCollectionInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'SDXLMainModelField' &&
|
||||
fieldTemplate?.type === 'SDXLMainModelField'
|
||||
@ -180,11 +309,6 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (field && fieldTemplate) {
|
||||
// Fallback for when there is no component for the type
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Box p={1}>
|
||||
<Text
|
||||
|
@ -1,17 +1,12 @@
|
||||
import {
|
||||
ControlInputFieldTemplate,
|
||||
ControlInputFieldValue,
|
||||
ControlPolymorphicInputFieldTemplate,
|
||||
ControlPolymorphicInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
const ControlInputFieldComponent = (
|
||||
_props: FieldComponentProps<
|
||||
ControlInputFieldValue | ControlPolymorphicInputFieldValue,
|
||||
ControlInputFieldTemplate | ControlPolymorphicInputFieldTemplate
|
||||
>
|
||||
_props: FieldComponentProps<ControlInputFieldValue, ControlInputFieldTemplate>
|
||||
) => {
|
||||
return null;
|
||||
};
|
||||
|
@ -9,9 +9,9 @@ import {
|
||||
} from 'features/dnd/types';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
ImageInputFieldTemplate,
|
||||
ImageInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { FaUndo } from 'react-icons/fa';
|
||||
|
@ -2,16 +2,11 @@ import {
|
||||
LatentsInputFieldTemplate,
|
||||
LatentsInputFieldValue,
|
||||
FieldComponentProps,
|
||||
LatentsPolymorphicInputFieldValue,
|
||||
LatentsPolymorphicInputFieldTemplate,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo } from 'react';
|
||||
|
||||
const LatentsInputFieldComponent = (
|
||||
_props: FieldComponentProps<
|
||||
LatentsInputFieldValue | LatentsPolymorphicInputFieldValue,
|
||||
LatentsInputFieldTemplate | LatentsPolymorphicInputFieldTemplate
|
||||
>
|
||||
_props: FieldComponentProps<LatentsInputFieldValue, LatentsInputFieldTemplate>
|
||||
) => {
|
||||
return null;
|
||||
};
|
||||
|
@ -9,11 +9,11 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { numberStringRegex } from 'common/components/IAINumberInput';
|
||||
import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FieldComponentProps,
|
||||
FloatInputFieldTemplate,
|
||||
FloatInputFieldValue,
|
||||
IntegerInputFieldTemplate,
|
||||
IntegerInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { memo, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
|
@ -9,20 +9,13 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
|
||||
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
DRAG_HANDLE_CLASSNAME,
|
||||
NODE_WIDTH,
|
||||
} from 'features/nodes/types/constants';
|
||||
import { NodeStatus } from 'features/nodes/types/types';
|
||||
import { contextMenusClosed } from 'features/ui/store/uiSlice';
|
||||
import {
|
||||
MouseEvent,
|
||||
PropsWithChildren,
|
||||
memo,
|
||||
useCallback,
|
||||
useMemo,
|
||||
} from 'react';
|
||||
import { PropsWithChildren, memo, useCallback, useMemo } from 'react';
|
||||
|
||||
type NodeWrapperProps = PropsWithChildren & {
|
||||
nodeId: string;
|
||||
@ -64,15 +57,9 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
|
||||
const opacity = useAppSelector((state) => state.nodes.nodeOpacity);
|
||||
|
||||
const handleClick = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
if (!e.ctrlKey && !e.altKey && !e.metaKey && !e.shiftKey) {
|
||||
dispatch(nodeExclusivelySelected(nodeId));
|
||||
}
|
||||
dispatch(contextMenusClosed());
|
||||
},
|
||||
[dispatch, nodeId]
|
||||
);
|
||||
const handleClick = useCallback(() => {
|
||||
dispatch(contextMenusClosed());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Box
|
||||
|
@ -138,14 +138,13 @@ export const useBuildNodeData = () => {
|
||||
data: {
|
||||
id: nodeId,
|
||||
type,
|
||||
version: template.version,
|
||||
label: '',
|
||||
notes: '',
|
||||
isOpen: true,
|
||||
embedWorkflow: false,
|
||||
isIntermediate: true,
|
||||
inputs,
|
||||
outputs,
|
||||
isOpen: true,
|
||||
label: '',
|
||||
notes: '',
|
||||
embedWorkflow: false,
|
||||
isIntermediate: true,
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -1,33 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { compareVersions } from 'compare-versions';
|
||||
import { useMemo } from 'react';
|
||||
import { isInvocationNode } from '../types/types';
|
||||
|
||||
export const useDoNodeVersionsMatch = (nodeId: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
stateSelector,
|
||||
({ nodes }) => {
|
||||
const node = nodes.nodes.find((node) => node.id === nodeId);
|
||||
if (!isInvocationNode(node)) {
|
||||
return false;
|
||||
}
|
||||
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
|
||||
if (!nodeTemplate?.version || !node.data?.version) {
|
||||
return false;
|
||||
}
|
||||
return compareVersions(nodeTemplate.version, node.data.version) === 0;
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const nodeTemplate = useAppSelector(selector);
|
||||
|
||||
return nodeTemplate;
|
||||
};
|
@ -15,7 +15,7 @@ export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node?.data.inputs[fieldName]?.value !== undefined;
|
||||
return Boolean(node?.data.inputs[fieldName]?.value);
|
||||
},
|
||||
defaultSelectorOptions
|
||||
),
|
||||
|
@ -3,19 +3,9 @@ import graphlib from '@dagrejs/graphlib';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import { Connection, Edge, Node, useReactFlow } from 'reactflow';
|
||||
import {
|
||||
COLLECTION_MAP,
|
||||
COLLECTION_TYPES,
|
||||
POLYMORPHIC_TO_SINGLE_MAP,
|
||||
POLYMORPHIC_TYPES,
|
||||
} from '../types/constants';
|
||||
import { COLLECTION_TYPES } from '../types/constants';
|
||||
import { InvocationNodeData } from '../types/types';
|
||||
|
||||
/**
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts`
|
||||
* TODO: Figure out how to do this without duplicating all the logic
|
||||
*/
|
||||
|
||||
export const useIsValidConnection = () => {
|
||||
const flow = useReactFlow();
|
||||
const shouldValidateGraph = useAppSelector(
|
||||
@ -52,19 +42,6 @@ export const useIsValidConnection = () => {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (
|
||||
edges
|
||||
.filter((edge) => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
})
|
||||
.find((edge) => {
|
||||
edge.source === source && edge.sourceHandle === sourceHandle;
|
||||
})
|
||||
) {
|
||||
// We already have a connection from this source to this target
|
||||
return false;
|
||||
}
|
||||
|
||||
// Connection is invalid if target already has a connection
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
@ -76,62 +53,21 @@ export const useIsValidConnection = () => {
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Connection types must be the same for a connection, with exceptions:
|
||||
* - CollectionItem can connect to any non-Collection
|
||||
* - Non-Collections can connect to CollectionItem
|
||||
* - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
|
||||
* - Generic Collection can connect to any other Collection or Polymorphic
|
||||
* - Any Collection can connect to a Generic Collection
|
||||
*/
|
||||
|
||||
if (sourceType !== targetType) {
|
||||
const isCollectionItemToNonCollection =
|
||||
sourceType === 'CollectionItem' &&
|
||||
!COLLECTION_TYPES.includes(targetType);
|
||||
|
||||
const isNonCollectionToCollectionItem =
|
||||
targetType === 'CollectionItem' &&
|
||||
!COLLECTION_TYPES.includes(sourceType) &&
|
||||
!POLYMORPHIC_TYPES.includes(sourceType);
|
||||
|
||||
const isAnythingToPolymorphicOfSameBaseType =
|
||||
POLYMORPHIC_TYPES.includes(targetType) &&
|
||||
(() => {
|
||||
if (!POLYMORPHIC_TYPES.includes(targetType)) {
|
||||
return false;
|
||||
}
|
||||
const baseType =
|
||||
POLYMORPHIC_TO_SINGLE_MAP[
|
||||
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
|
||||
];
|
||||
|
||||
const collectionType =
|
||||
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
|
||||
|
||||
return sourceType === baseType || sourceType === collectionType;
|
||||
})();
|
||||
|
||||
const isGenericCollectionToAnyCollectionOrPolymorphic =
|
||||
sourceType === 'Collection' &&
|
||||
(COLLECTION_TYPES.includes(targetType) ||
|
||||
POLYMORPHIC_TYPES.includes(targetType));
|
||||
|
||||
const isCollectionToGenericCollection =
|
||||
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
|
||||
|
||||
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
|
||||
|
||||
return (
|
||||
isCollectionItemToNonCollection ||
|
||||
isNonCollectionToCollectionItem ||
|
||||
isAnythingToPolymorphicOfSameBaseType ||
|
||||
isGenericCollectionToAnyCollectionOrPolymorphic ||
|
||||
isCollectionToGenericCollection ||
|
||||
isIntToFloat
|
||||
);
|
||||
// Connection types must be the same for a connection
|
||||
if (
|
||||
sourceType !== targetType &&
|
||||
sourceType !== 'CollectionItem' &&
|
||||
targetType !== 'CollectionItem'
|
||||
) {
|
||||
if (
|
||||
!(
|
||||
COLLECTION_TYPES.includes(targetType) &&
|
||||
COLLECTION_TYPES.includes(sourceType)
|
||||
)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Graphs much be acyclic (no loops!)
|
||||
return getIsGraphAcyclic(source, target, nodes, edges);
|
||||
},
|
||||
|
@ -2,13 +2,13 @@ import { ListItem, Text, UnorderedList } from '@chakra-ui/react';
|
||||
import { useLogger } from 'app/logging/useLogger';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { zWorkflow } from 'features/nodes/types/types';
|
||||
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
|
||||
import { zValidatedWorkflow } from 'features/nodes/types/types';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { ZodError } from 'zod';
|
||||
import { fromZodError, fromZodIssue } from 'zod-validation-error';
|
||||
import { workflowLoadRequested } from '../store/actions';
|
||||
|
||||
export const useLoadWorkflowFromFile = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
@ -24,7 +24,7 @@ export const useLoadWorkflowFromFile = () => {
|
||||
|
||||
try {
|
||||
const parsedJSON = JSON.parse(String(rawJSON));
|
||||
const result = zWorkflow.safeParse(parsedJSON);
|
||||
const result = zValidatedWorkflow.safeParse(parsedJSON);
|
||||
|
||||
if (!result.success) {
|
||||
const { message } = fromZodError(result.error, {
|
||||
@ -45,8 +45,32 @@ export const useLoadWorkflowFromFile = () => {
|
||||
reader.abort();
|
||||
return;
|
||||
}
|
||||
dispatch(workflowLoaded(result.data.workflow));
|
||||
|
||||
dispatch(workflowLoadRequested(result.data));
|
||||
if (!result.data.warnings.length) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Workflow Loaded',
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
reader.abort();
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Workflow Loaded with Warnings',
|
||||
status: 'warning',
|
||||
})
|
||||
)
|
||||
);
|
||||
result.data.warnings.forEach(({ message, ...rest }) => {
|
||||
logger.warn(rest, message);
|
||||
});
|
||||
|
||||
reader.abort();
|
||||
} catch {
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { createAction, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { Graph } from 'services/api/types';
|
||||
import { Workflow } from '../types/types';
|
||||
|
||||
export const textToImageGraphBuilt = createAction<Graph>(
|
||||
'nodes/textToImageGraphBuilt'
|
||||
@ -17,7 +16,3 @@ export const isAnyGraphBuilt = isAnyOf(
|
||||
canvasGraphBuilt,
|
||||
nodesGraphBuilt
|
||||
);
|
||||
|
||||
export const workflowLoadRequested = createAction<Workflow>(
|
||||
'nodes/workflowLoadRequested'
|
||||
);
|
||||
|
@ -443,17 +443,6 @@ const nodesSlice = createSlice({
|
||||
}
|
||||
node.data.notes = notes;
|
||||
},
|
||||
nodeExclusivelySelected: (state, action: PayloadAction<string>) => {
|
||||
const nodeId = action.payload;
|
||||
state.nodes = applyNodeChanges(
|
||||
state.nodes.map((n) => ({
|
||||
id: n.id,
|
||||
type: 'select',
|
||||
selected: n.id === nodeId ? true : false,
|
||||
})),
|
||||
state.nodes
|
||||
);
|
||||
},
|
||||
selectedNodesChanged: (state, action: PayloadAction<string[]>) => {
|
||||
state.selectedNodes = action.payload;
|
||||
},
|
||||
@ -903,7 +892,6 @@ export const {
|
||||
nodeEmbedWorkflowChanged,
|
||||
nodeIsIntermediateChanged,
|
||||
mouseOverNodeChanged,
|
||||
nodeExclusivelySelected,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
export default nodesSlice.reducer;
|
||||
|
@ -1,4 +0,0 @@
|
||||
import { atom } from 'nanostores';
|
||||
import { ReactFlowInstance } from 'reactflow';
|
||||
|
||||
export const $flow = atom<ReactFlowInstance | null>(null);
|
@ -1,20 +1,10 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection';
|
||||
import {
|
||||
COLLECTION_MAP,
|
||||
COLLECTION_TYPES,
|
||||
POLYMORPHIC_TO_SINGLE_MAP,
|
||||
POLYMORPHIC_TYPES,
|
||||
} from 'features/nodes/types/constants';
|
||||
import { COLLECTION_TYPES } from 'features/nodes/types/constants';
|
||||
import { FieldType } from 'features/nodes/types/types';
|
||||
import { HandleType } from 'reactflow';
|
||||
|
||||
/**
|
||||
* NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts`
|
||||
* TODO: Figure out how to do this without duplicating all the logic
|
||||
*/
|
||||
|
||||
export const makeConnectionErrorSelector = (
|
||||
nodeId: string,
|
||||
fieldName: string,
|
||||
@ -29,6 +19,11 @@ export const makeConnectionErrorSelector = (
|
||||
const { currentConnectionFieldType, connectionStartParams, nodes, edges } =
|
||||
state.nodes;
|
||||
|
||||
if (!state.nodes.shouldValidateGraph) {
|
||||
// manual override!
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!connectionStartParams || !currentConnectionFieldType) {
|
||||
return 'No connection in progress';
|
||||
}
|
||||
@ -43,9 +38,9 @@ export const makeConnectionErrorSelector = (
|
||||
return 'No connection data';
|
||||
}
|
||||
|
||||
const targetType =
|
||||
const targetFieldType =
|
||||
handleType === 'target' ? fieldType : currentConnectionFieldType;
|
||||
const sourceType =
|
||||
const sourceFieldType =
|
||||
handleType === 'source' ? fieldType : currentConnectionFieldType;
|
||||
|
||||
if (nodeId === connectionNodeId) {
|
||||
@ -60,73 +55,30 @@ export const makeConnectionErrorSelector = (
|
||||
}
|
||||
|
||||
if (
|
||||
fieldType !== currentConnectionFieldType &&
|
||||
fieldType !== 'CollectionItem' &&
|
||||
currentConnectionFieldType !== 'CollectionItem'
|
||||
) {
|
||||
if (
|
||||
!(
|
||||
COLLECTION_TYPES.includes(targetFieldType) &&
|
||||
COLLECTION_TYPES.includes(sourceFieldType)
|
||||
)
|
||||
) {
|
||||
// except for collection items, field types must match
|
||||
return 'Field types must match';
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
handleType === 'target' &&
|
||||
edges.find((edge) => {
|
||||
return edge.target === nodeId && edge.targetHandle === fieldName;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetType !== 'CollectionItem'
|
||||
targetFieldType !== 'CollectionItem'
|
||||
) {
|
||||
return 'Input may only have one connection';
|
||||
}
|
||||
|
||||
/**
|
||||
* Connection types must be the same for a connection, with exceptions:
|
||||
* - CollectionItem can connect to any non-Collection
|
||||
* - Non-Collections can connect to CollectionItem
|
||||
* - Anything (non-Collections, Collections, Polymorphics) can connect to Polymorphics of the same base type
|
||||
* - Generic Collection can connect to any other Collection or Polymorphic
|
||||
* - Any Collection can connect to a Generic Collection
|
||||
*/
|
||||
|
||||
if (sourceType !== targetType) {
|
||||
const isCollectionItemToNonCollection =
|
||||
sourceType === 'CollectionItem' &&
|
||||
!COLLECTION_TYPES.includes(targetType);
|
||||
|
||||
const isNonCollectionToCollectionItem =
|
||||
targetType === 'CollectionItem' &&
|
||||
!COLLECTION_TYPES.includes(sourceType) &&
|
||||
!POLYMORPHIC_TYPES.includes(sourceType);
|
||||
|
||||
const isAnythingToPolymorphicOfSameBaseType =
|
||||
POLYMORPHIC_TYPES.includes(targetType) &&
|
||||
(() => {
|
||||
if (!POLYMORPHIC_TYPES.includes(targetType)) {
|
||||
return false;
|
||||
}
|
||||
const baseType =
|
||||
POLYMORPHIC_TO_SINGLE_MAP[
|
||||
targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP
|
||||
];
|
||||
|
||||
const collectionType =
|
||||
COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP];
|
||||
|
||||
return sourceType === baseType || sourceType === collectionType;
|
||||
})();
|
||||
|
||||
const isGenericCollectionToAnyCollectionOrPolymorphic =
|
||||
sourceType === 'Collection' &&
|
||||
(COLLECTION_TYPES.includes(targetType) ||
|
||||
POLYMORPHIC_TYPES.includes(targetType));
|
||||
|
||||
const isCollectionToGenericCollection =
|
||||
targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType);
|
||||
|
||||
const isIntToFloat = sourceType === 'integer' && targetType === 'float';
|
||||
|
||||
if (
|
||||
!(
|
||||
isCollectionItemToNonCollection ||
|
||||
isNonCollectionToCollectionItem ||
|
||||
isAnythingToPolymorphicOfSameBaseType ||
|
||||
isGenericCollectionToAnyCollectionOrPolymorphic ||
|
||||
isCollectionToGenericCollection ||
|
||||
isIntToFloat
|
||||
)
|
||||
) {
|
||||
return 'Field types must match';
|
||||
}
|
||||
return 'Inputs may only have one connection';
|
||||
}
|
||||
|
||||
const isGraphAcyclic = getIsGraphAcyclic(
|
||||
|
@ -17,297 +17,176 @@ export const KIND_MAP = {
|
||||
export const COLLECTION_TYPES: FieldType[] = [
|
||||
'Collection',
|
||||
'IntegerCollection',
|
||||
'BooleanCollection',
|
||||
'FloatCollection',
|
||||
'StringCollection',
|
||||
'BooleanCollection',
|
||||
'ImageCollection',
|
||||
'LatentsCollection',
|
||||
'ConditioningCollection',
|
||||
'ControlCollection',
|
||||
'ColorCollection',
|
||||
];
|
||||
|
||||
export const POLYMORPHIC_TYPES = [
|
||||
'IntegerPolymorphic',
|
||||
'BooleanPolymorphic',
|
||||
'FloatPolymorphic',
|
||||
'StringPolymorphic',
|
||||
'ImagePolymorphic',
|
||||
'LatentsPolymorphic',
|
||||
'ConditioningPolymorphic',
|
||||
'ControlPolymorphic',
|
||||
'ColorPolymorphic',
|
||||
];
|
||||
|
||||
export const MODEL_TYPES = [
|
||||
'ControlNetModelField',
|
||||
'LoRAModelField',
|
||||
'MainModelField',
|
||||
'ONNXModelField',
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'VaeModelField',
|
||||
'UNetField',
|
||||
'VaeField',
|
||||
'ClipField',
|
||||
];
|
||||
|
||||
export const COLLECTION_MAP = {
|
||||
integer: 'IntegerCollection',
|
||||
boolean: 'BooleanCollection',
|
||||
number: 'FloatCollection',
|
||||
float: 'FloatCollection',
|
||||
string: 'StringCollection',
|
||||
ImageField: 'ImageCollection',
|
||||
LatentsField: 'LatentsCollection',
|
||||
ConditioningField: 'ConditioningCollection',
|
||||
ControlField: 'ControlCollection',
|
||||
ColorField: 'ColorCollection',
|
||||
};
|
||||
export const isCollectionItemType = (
|
||||
itemType: string | undefined
|
||||
): itemType is keyof typeof COLLECTION_MAP =>
|
||||
Boolean(itemType && itemType in COLLECTION_MAP);
|
||||
|
||||
export const SINGLE_TO_POLYMORPHIC_MAP = {
|
||||
integer: 'IntegerPolymorphic',
|
||||
boolean: 'BooleanPolymorphic',
|
||||
number: 'FloatPolymorphic',
|
||||
float: 'FloatPolymorphic',
|
||||
string: 'StringPolymorphic',
|
||||
ImageField: 'ImagePolymorphic',
|
||||
LatentsField: 'LatentsPolymorphic',
|
||||
ConditioningField: 'ConditioningPolymorphic',
|
||||
ControlField: 'ControlPolymorphic',
|
||||
ColorField: 'ColorPolymorphic',
|
||||
};
|
||||
|
||||
export const POLYMORPHIC_TO_SINGLE_MAP = {
|
||||
IntegerPolymorphic: 'integer',
|
||||
BooleanPolymorphic: 'boolean',
|
||||
FloatPolymorphic: 'float',
|
||||
StringPolymorphic: 'string',
|
||||
ImagePolymorphic: 'ImageField',
|
||||
LatentsPolymorphic: 'LatentsField',
|
||||
ConditioningPolymorphic: 'ConditioningField',
|
||||
ControlPolymorphic: 'ControlField',
|
||||
ColorPolymorphic: 'ColorField',
|
||||
};
|
||||
|
||||
export const isPolymorphicItemType = (
|
||||
itemType: string | undefined
|
||||
): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP =>
|
||||
Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP);
|
||||
|
||||
export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
integer: {
|
||||
title: 'Integer',
|
||||
description: 'Integers are whole numbers, without a decimal point.',
|
||||
color: 'red.500',
|
||||
},
|
||||
float: {
|
||||
title: 'Float',
|
||||
description: 'Floats are numbers with a decimal point.',
|
||||
color: 'orange.500',
|
||||
},
|
||||
string: {
|
||||
title: 'String',
|
||||
description: 'Strings are text.',
|
||||
color: 'yellow.500',
|
||||
},
|
||||
boolean: {
|
||||
title: 'Boolean',
|
||||
color: 'green.500',
|
||||
description: 'Booleans are true or false.',
|
||||
title: 'Boolean',
|
||||
},
|
||||
BooleanCollection: {
|
||||
color: 'green.500',
|
||||
description: 'A collection of booleans.',
|
||||
title: 'Boolean Collection',
|
||||
enum: {
|
||||
title: 'Enum',
|
||||
description: 'Enums are values that may be one of a number of options.',
|
||||
color: 'blue.500',
|
||||
},
|
||||
BooleanPolymorphic: {
|
||||
color: 'green.500',
|
||||
description: 'A collection of booleans.',
|
||||
title: 'Boolean Polymorphic',
|
||||
},
|
||||
ClipField: {
|
||||
color: 'green.500',
|
||||
description: 'Tokenizer and text_encoder submodels.',
|
||||
title: 'Clip',
|
||||
},
|
||||
Collection: {
|
||||
array: {
|
||||
title: 'Array',
|
||||
description: 'Enums are values that may be one of a number of options.',
|
||||
color: 'base.500',
|
||||
description: 'TODO',
|
||||
title: 'Collection',
|
||||
},
|
||||
CollectionItem: {
|
||||
ImageField: {
|
||||
title: 'Image',
|
||||
description: 'Images may be passed between nodes.',
|
||||
color: 'purple.500',
|
||||
},
|
||||
DenoiseMaskField: {
|
||||
title: 'Denoise Mask',
|
||||
description: 'Denoise Mask may be passed between nodes',
|
||||
color: 'base.500',
|
||||
description: 'TODO',
|
||||
title: 'Collection Item',
|
||||
},
|
||||
ColorCollection: {
|
||||
color: 'pink.300',
|
||||
description: 'A collection of colors.',
|
||||
title: 'Color Collection',
|
||||
LatentsField: {
|
||||
title: 'Latents',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
color: 'pink.500',
|
||||
},
|
||||
ColorField: {
|
||||
color: 'pink.300',
|
||||
description: 'A RGBA color.',
|
||||
title: 'Color',
|
||||
},
|
||||
ColorPolymorphic: {
|
||||
color: 'pink.300',
|
||||
description: 'A collection of colors.',
|
||||
title: 'Color Polymorphic',
|
||||
},
|
||||
ConditioningCollection: {
|
||||
color: 'cyan.500',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
title: 'Conditioning Collection',
|
||||
LatentsCollection: {
|
||||
title: 'Latents Collection',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
color: 'pink.500',
|
||||
},
|
||||
ConditioningField: {
|
||||
color: 'cyan.500',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
title: 'Conditioning',
|
||||
},
|
||||
ConditioningPolymorphic: {
|
||||
color: 'cyan.500',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
title: 'Conditioning Polymorphic',
|
||||
},
|
||||
ControlCollection: {
|
||||
color: 'teal.500',
|
||||
description: 'Control info passed between nodes.',
|
||||
title: 'Control Collection',
|
||||
},
|
||||
ControlField: {
|
||||
color: 'teal.500',
|
||||
description: 'Control info passed between nodes.',
|
||||
title: 'Control',
|
||||
},
|
||||
ControlNetModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'ControlNet',
|
||||
},
|
||||
ControlPolymorphic: {
|
||||
color: 'teal.500',
|
||||
description: 'Control info passed between nodes.',
|
||||
title: 'Control Polymorphic',
|
||||
},
|
||||
DenoiseMaskField: {
|
||||
color: 'blue.300',
|
||||
description: 'Denoise Mask may be passed between nodes',
|
||||
title: 'Denoise Mask',
|
||||
},
|
||||
enum: {
|
||||
color: 'blue.500',
|
||||
description: 'Enums are values that may be one of a number of options.',
|
||||
title: 'Enum',
|
||||
},
|
||||
float: {
|
||||
color: 'orange.500',
|
||||
description: 'Floats are numbers with a decimal point.',
|
||||
title: 'Float',
|
||||
},
|
||||
FloatCollection: {
|
||||
color: 'orange.500',
|
||||
description: 'A collection of floats.',
|
||||
title: 'Float Collection',
|
||||
},
|
||||
FloatPolymorphic: {
|
||||
color: 'orange.500',
|
||||
description: 'A collection of floats.',
|
||||
title: 'Float Polymorphic',
|
||||
ConditioningCollection: {
|
||||
color: 'cyan.500',
|
||||
title: 'Conditioning Collection',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
},
|
||||
ImageCollection: {
|
||||
color: 'purple.500',
|
||||
description: 'A collection of images.',
|
||||
title: 'Image Collection',
|
||||
},
|
||||
ImageField: {
|
||||
color: 'purple.500',
|
||||
description: 'Images may be passed between nodes.',
|
||||
title: 'Image',
|
||||
},
|
||||
ImagePolymorphic: {
|
||||
color: 'purple.500',
|
||||
description: 'A collection of images.',
|
||||
title: 'Image Polymorphic',
|
||||
},
|
||||
integer: {
|
||||
color: 'red.500',
|
||||
description: 'Integers are whole numbers, without a decimal point.',
|
||||
title: 'Integer',
|
||||
},
|
||||
IntegerCollection: {
|
||||
color: 'red.500',
|
||||
description: 'A collection of integers.',
|
||||
title: 'Integer Collection',
|
||||
},
|
||||
IntegerPolymorphic: {
|
||||
color: 'red.500',
|
||||
description: 'A collection of integers.',
|
||||
title: 'Integer Polymorphic',
|
||||
},
|
||||
LatentsCollection: {
|
||||
color: 'pink.500',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
title: 'Latents Collection',
|
||||
},
|
||||
LatentsField: {
|
||||
color: 'pink.500',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
title: 'Latents',
|
||||
},
|
||||
LatentsPolymorphic: {
|
||||
color: 'pink.500',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
title: 'Latents Polymorphic',
|
||||
},
|
||||
LoRAModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'LoRA',
|
||||
},
|
||||
MainModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'Model',
|
||||
},
|
||||
ONNXModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'ONNX model field.',
|
||||
title: 'ONNX Model',
|
||||
},
|
||||
Scheduler: {
|
||||
color: 'base.500',
|
||||
description: 'TODO',
|
||||
title: 'Scheduler',
|
||||
},
|
||||
SDXLMainModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'SDXL model field.',
|
||||
title: 'SDXL Model',
|
||||
},
|
||||
SDXLRefinerModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'Refiner Model',
|
||||
},
|
||||
string: {
|
||||
color: 'yellow.500',
|
||||
description: 'Strings are text.',
|
||||
title: 'String',
|
||||
},
|
||||
StringCollection: {
|
||||
color: 'yellow.500',
|
||||
description: 'A collection of strings.',
|
||||
title: 'String Collection',
|
||||
},
|
||||
StringPolymorphic: {
|
||||
color: 'yellow.500',
|
||||
description: 'A collection of strings.',
|
||||
title: 'String Polymorphic',
|
||||
color: 'base.300',
|
||||
},
|
||||
UNetField: {
|
||||
color: 'red.500',
|
||||
description: 'UNet submodel.',
|
||||
title: 'UNet',
|
||||
description: 'UNet submodel.',
|
||||
},
|
||||
ClipField: {
|
||||
color: 'green.500',
|
||||
title: 'Clip',
|
||||
description: 'Tokenizer and text_encoder submodels.',
|
||||
},
|
||||
VaeField: {
|
||||
color: 'blue.500',
|
||||
description: 'Vae submodel.',
|
||||
title: 'Vae',
|
||||
description: 'Vae submodel.',
|
||||
},
|
||||
ControlField: {
|
||||
color: 'cyan.500',
|
||||
title: 'Control',
|
||||
description: 'Control info passed between nodes.',
|
||||
},
|
||||
MainModelField: {
|
||||
color: 'teal.500',
|
||||
title: 'Model',
|
||||
description: 'TODO',
|
||||
},
|
||||
SDXLRefinerModelField: {
|
||||
color: 'teal.500',
|
||||
title: 'Refiner Model',
|
||||
description: 'TODO',
|
||||
},
|
||||
VaeModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'VAE',
|
||||
description: 'TODO',
|
||||
},
|
||||
LoRAModelField: {
|
||||
color: 'teal.500',
|
||||
title: 'LoRA',
|
||||
description: 'TODO',
|
||||
},
|
||||
ControlNetModelField: {
|
||||
color: 'teal.500',
|
||||
title: 'ControlNet',
|
||||
description: 'TODO',
|
||||
},
|
||||
Scheduler: {
|
||||
color: 'base.500',
|
||||
title: 'Scheduler',
|
||||
description: 'TODO',
|
||||
},
|
||||
Collection: {
|
||||
color: 'base.500',
|
||||
title: 'Collection',
|
||||
description: 'TODO',
|
||||
},
|
||||
CollectionItem: {
|
||||
color: 'base.500',
|
||||
title: 'Collection Item',
|
||||
description: 'TODO',
|
||||
},
|
||||
ColorField: {
|
||||
title: 'Color',
|
||||
description: 'A RGBA color.',
|
||||
color: 'base.500',
|
||||
},
|
||||
BooleanCollection: {
|
||||
title: 'Boolean Collection',
|
||||
description: 'A collection of booleans.',
|
||||
color: 'green.500',
|
||||
},
|
||||
IntegerCollection: {
|
||||
title: 'Integer Collection',
|
||||
description: 'A collection of integers.',
|
||||
color: 'red.500',
|
||||
},
|
||||
FloatCollection: {
|
||||
color: 'orange.500',
|
||||
title: 'Float Collection',
|
||||
description: 'A collection of floats.',
|
||||
},
|
||||
ColorCollection: {
|
||||
color: 'base.500',
|
||||
title: 'Color Collection',
|
||||
description: 'A collection of colors.',
|
||||
},
|
||||
ONNXModelField: {
|
||||
color: 'base.500',
|
||||
title: 'ONNX Model',
|
||||
description: 'ONNX model field.',
|
||||
},
|
||||
SDXLMainModelField: {
|
||||
color: 'base.500',
|
||||
title: 'SDXL Model',
|
||||
description: 'SDXL model field.',
|
||||
},
|
||||
StringCollection: {
|
||||
color: 'yellow.500',
|
||||
title: 'String Collection',
|
||||
description: 'A collection of strings.',
|
||||
},
|
||||
};
|
||||
|
@ -1,9 +1,8 @@
|
||||
import { store } from 'app/store/store';
|
||||
import {
|
||||
SchedulerParam,
|
||||
zBaseModel,
|
||||
zMainModel,
|
||||
zMainOrOnnxModel,
|
||||
zOnnxModel,
|
||||
zSDXLRefinerModel,
|
||||
zScheduler,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
@ -11,14 +10,14 @@ import { keyBy } from 'lodash-es';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import { RgbaColor } from 'react-colorful';
|
||||
import { Node } from 'reactflow';
|
||||
import { Graph, _InputField, _OutputField } from 'services/api/types';
|
||||
import { JsonObject } from 'type-fest';
|
||||
import { Graph, ImageDTO, _InputField, _OutputField } from 'services/api/types';
|
||||
import {
|
||||
AnyInvocationType,
|
||||
AnyResult,
|
||||
ProgressImage,
|
||||
} from 'services/events/types';
|
||||
import { O } from 'ts-toolbelt';
|
||||
import { JsonObject } from 'type-fest';
|
||||
import { z } from 'zod';
|
||||
|
||||
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
|
||||
@ -52,10 +51,6 @@ export type InvocationTemplate = {
|
||||
* The type of this node's output
|
||||
*/
|
||||
outputType: string; // TODO: generate a union of output types
|
||||
/**
|
||||
* The invocation's version.
|
||||
*/
|
||||
version?: string;
|
||||
};
|
||||
|
||||
export type FieldUIConfig = {
|
||||
@ -66,48 +61,50 @@ export type FieldUIConfig = {
|
||||
|
||||
// TODO: Get this from the OpenAPI schema? may be tricky...
|
||||
export const zFieldType = z.enum([
|
||||
'boolean',
|
||||
'BooleanCollection',
|
||||
'BooleanPolymorphic',
|
||||
'ClipField',
|
||||
'Collection',
|
||||
'CollectionItem',
|
||||
'ColorCollection',
|
||||
'ColorField',
|
||||
'ColorPolymorphic',
|
||||
'ConditioningCollection',
|
||||
'ConditioningField',
|
||||
'ConditioningPolymorphic',
|
||||
'ControlCollection',
|
||||
'ControlField',
|
||||
'ControlNetModelField',
|
||||
'ControlPolymorphic',
|
||||
'DenoiseMaskField',
|
||||
'enum',
|
||||
'float',
|
||||
'FloatCollection',
|
||||
'FloatPolymorphic',
|
||||
'ImageCollection',
|
||||
'ImageField',
|
||||
'ImagePolymorphic',
|
||||
// region Primitives
|
||||
'integer',
|
||||
'IntegerCollection',
|
||||
'IntegerPolymorphic',
|
||||
'LatentsCollection',
|
||||
'float',
|
||||
'boolean',
|
||||
'string',
|
||||
'array',
|
||||
'ImageField',
|
||||
'DenoiseMaskField',
|
||||
'LatentsField',
|
||||
'LatentsPolymorphic',
|
||||
'LoRAModelField',
|
||||
'ConditioningField',
|
||||
'ControlField',
|
||||
'ColorField',
|
||||
'ImageCollection',
|
||||
'ConditioningCollection',
|
||||
'ColorCollection',
|
||||
'LatentsCollection',
|
||||
'IntegerCollection',
|
||||
'FloatCollection',
|
||||
'StringCollection',
|
||||
'BooleanCollection',
|
||||
// endregion
|
||||
|
||||
// region Models
|
||||
'MainModelField',
|
||||
'ONNXModelField',
|
||||
'Scheduler',
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'string',
|
||||
'StringCollection',
|
||||
'StringPolymorphic',
|
||||
'ONNXModelField',
|
||||
'VaeModelField',
|
||||
'LoRAModelField',
|
||||
'ControlNetModelField',
|
||||
'UNetField',
|
||||
'VaeField',
|
||||
'VaeModelField',
|
||||
'ClipField',
|
||||
// endregion
|
||||
|
||||
// region Iterate/Collect
|
||||
'Collection',
|
||||
'CollectionItem',
|
||||
// endregion
|
||||
|
||||
// region Misc
|
||||
'enum',
|
||||
'Scheduler',
|
||||
// endregion
|
||||
]);
|
||||
|
||||
export type FieldType = z.infer<typeof zFieldType>;
|
||||
@ -124,6 +121,38 @@ export const isFieldType = (value: unknown): value is FieldType =>
|
||||
zFieldType.safeParse(value).success ||
|
||||
zReservedFieldType.safeParse(value).success;
|
||||
|
||||
/**
|
||||
* An input field template is generated on each page load from the OpenAPI schema.
|
||||
*
|
||||
* The template provides the field type and other field metadata (e.g. title, description,
|
||||
* maximum length, pattern to match, etc).
|
||||
*/
|
||||
export type InputFieldTemplate =
|
||||
| IntegerInputFieldTemplate
|
||||
| FloatInputFieldTemplate
|
||||
| StringInputFieldTemplate
|
||||
| BooleanInputFieldTemplate
|
||||
| ImageInputFieldTemplate
|
||||
| DenoiseMaskInputFieldTemplate
|
||||
| LatentsInputFieldTemplate
|
||||
| ConditioningInputFieldTemplate
|
||||
| UNetInputFieldTemplate
|
||||
| ClipInputFieldTemplate
|
||||
| VaeInputFieldTemplate
|
||||
| ControlInputFieldTemplate
|
||||
| EnumInputFieldTemplate
|
||||
| MainModelInputFieldTemplate
|
||||
| SDXLMainModelInputFieldTemplate
|
||||
| SDXLRefinerModelInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate
|
||||
| LoRAModelInputFieldTemplate
|
||||
| ControlNetModelInputFieldTemplate
|
||||
| CollectionInputFieldTemplate
|
||||
| CollectionItemInputFieldTemplate
|
||||
| ColorInputFieldTemplate
|
||||
| ImageCollectionInputFieldTemplate
|
||||
| SchedulerInputFieldTemplate;
|
||||
|
||||
/**
|
||||
* Indicates the kind of input(s) this field may have.
|
||||
*/
|
||||
@ -202,88 +231,24 @@ export const zIntegerInputFieldValue = zInputFieldValueBase.extend({
|
||||
});
|
||||
export type IntegerInputFieldValue = z.infer<typeof zIntegerInputFieldValue>;
|
||||
|
||||
export const zIntegerCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('IntegerCollection'),
|
||||
value: z.array(z.number().int()).optional(),
|
||||
});
|
||||
export type IntegerCollectionInputFieldValue = z.infer<
|
||||
typeof zIntegerCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('IntegerPolymorphic'),
|
||||
value: z.union([z.number().int(), z.array(z.number().int())]).optional(),
|
||||
});
|
||||
export type IntegerPolymorphicInputFieldValue = z.infer<
|
||||
typeof zIntegerPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zFloatInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('float'),
|
||||
value: z.number().optional(),
|
||||
});
|
||||
export type FloatInputFieldValue = z.infer<typeof zFloatInputFieldValue>;
|
||||
|
||||
export const zFloatCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('FloatCollection'),
|
||||
value: z.array(z.number()).optional(),
|
||||
});
|
||||
export type FloatCollectionInputFieldValue = z.infer<
|
||||
typeof zFloatCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('FloatPolymorphic'),
|
||||
value: z.union([z.number(), z.array(z.number())]).optional(),
|
||||
});
|
||||
export type FloatPolymorphicInputFieldValue = z.infer<
|
||||
typeof zFloatPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zStringInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('string'),
|
||||
value: z.string().optional(),
|
||||
});
|
||||
export type StringInputFieldValue = z.infer<typeof zStringInputFieldValue>;
|
||||
|
||||
export const zStringCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('StringCollection'),
|
||||
value: z.array(z.string()).optional(),
|
||||
});
|
||||
export type StringCollectionInputFieldValue = z.infer<
|
||||
typeof zStringCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('StringPolymorphic'),
|
||||
value: z.union([z.string(), z.array(z.string())]).optional(),
|
||||
});
|
||||
export type StringPolymorphicInputFieldValue = z.infer<
|
||||
typeof zStringPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zBooleanInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('boolean'),
|
||||
value: z.boolean().optional(),
|
||||
});
|
||||
export type BooleanInputFieldValue = z.infer<typeof zBooleanInputFieldValue>;
|
||||
|
||||
export const zBooleanCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('BooleanCollection'),
|
||||
value: z.array(z.boolean()).optional(),
|
||||
});
|
||||
export type BooleanCollectionInputFieldValue = z.infer<
|
||||
typeof zBooleanCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('BooleanPolymorphic'),
|
||||
value: z.union([z.boolean(), z.array(z.boolean())]).optional(),
|
||||
});
|
||||
export type BooleanPolymorphicInputFieldValue = z.infer<
|
||||
typeof zBooleanPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zEnumInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('enum'),
|
||||
value: z.union([z.string(), z.number()]).optional(),
|
||||
@ -296,22 +261,6 @@ export const zLatentsInputFieldValue = zInputFieldValueBase.extend({
|
||||
});
|
||||
export type LatentsInputFieldValue = z.infer<typeof zLatentsInputFieldValue>;
|
||||
|
||||
export const zLatentsCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('LatentsCollection'),
|
||||
value: z.array(zLatentsField).optional(),
|
||||
});
|
||||
export type LatentsCollectionInputFieldValue = z.infer<
|
||||
typeof zLatentsCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zLatentsPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('LatentsPolymorphic'),
|
||||
value: z.union([zLatentsField, z.array(zLatentsField)]).optional(),
|
||||
});
|
||||
export type LatentsPolymorphicInputFieldValue = z.infer<
|
||||
typeof zLatentsPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('DenoiseMaskField'),
|
||||
value: zDenoiseMaskField.optional(),
|
||||
@ -328,26 +277,6 @@ export type ConditioningInputFieldValue = z.infer<
|
||||
typeof zConditioningInputFieldValue
|
||||
>;
|
||||
|
||||
export const zConditioningCollectionInputFieldValue =
|
||||
zInputFieldValueBase.extend({
|
||||
type: z.literal('ConditioningCollection'),
|
||||
value: z.array(zConditioningField).optional(),
|
||||
});
|
||||
export type ConditioningCollectionInputFieldValue = z.infer<
|
||||
typeof zConditioningCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zConditioningPolymorphicInputFieldValue =
|
||||
zInputFieldValueBase.extend({
|
||||
type: z.literal('ConditioningPolymorphic'),
|
||||
value: z
|
||||
.union([zConditioningField, z.array(zConditioningField)])
|
||||
.optional(),
|
||||
});
|
||||
export type ConditioningPolymorphicInputFieldValue = z.infer<
|
||||
typeof zConditioningPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zControlNetModel = zModelIdentifier;
|
||||
export type ControlNetModel = z.infer<typeof zControlNetModel>;
|
||||
|
||||
@ -372,22 +301,6 @@ export const zControlInputFieldValue = zInputFieldValueBase.extend({
|
||||
});
|
||||
export type ControlInputFieldValue = z.infer<typeof zControlInputFieldValue>;
|
||||
|
||||
export const zControlPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ControlPolymorphic'),
|
||||
value: z.union([zControlField, z.array(zControlField)]).optional(),
|
||||
});
|
||||
export type ControlPolymorphicInputFieldValue = z.infer<
|
||||
typeof zControlPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zControlCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ControlCollection'),
|
||||
value: z.array(zControlField).optional(),
|
||||
});
|
||||
export type ControlCollectionInputFieldValue = z.infer<
|
||||
typeof zControlCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zModelType = z.enum([
|
||||
'onnx',
|
||||
'main',
|
||||
@ -467,14 +380,6 @@ export const zImageInputFieldValue = zInputFieldValueBase.extend({
|
||||
});
|
||||
export type ImageInputFieldValue = z.infer<typeof zImageInputFieldValue>;
|
||||
|
||||
export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ImagePolymorphic'),
|
||||
value: z.union([zImageField, z.array(zImageField)]).optional(),
|
||||
});
|
||||
export type ImagePolymorphicInputFieldValue = z.infer<
|
||||
typeof zImagePolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zImageCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ImageCollection'),
|
||||
value: z.array(zImageField).optional(),
|
||||
@ -567,22 +472,6 @@ export const zColorInputFieldValue = zInputFieldValueBase.extend({
|
||||
});
|
||||
export type ColorInputFieldValue = z.infer<typeof zColorInputFieldValue>;
|
||||
|
||||
export const zColorCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ColorCollection'),
|
||||
value: z.array(zColorField).optional(),
|
||||
});
|
||||
export type ColorCollectionInputFieldValue = z.infer<
|
||||
typeof zColorCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zColorPolymorphicInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('ColorPolymorphic'),
|
||||
value: z.union([zColorField, z.array(zColorField)]).optional(),
|
||||
});
|
||||
export type ColorPolymorphicInputFieldValue = z.infer<
|
||||
typeof zColorPolymorphicInputFieldValue
|
||||
>;
|
||||
|
||||
export const zSchedulerInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('Scheduler'),
|
||||
value: zScheduler.optional(),
|
||||
@ -592,47 +481,30 @@ export type SchedulerInputFieldValue = z.infer<
|
||||
>;
|
||||
|
||||
export const zInputFieldValue = z.discriminatedUnion('type', [
|
||||
zBooleanCollectionInputFieldValue,
|
||||
zIntegerInputFieldValue,
|
||||
zFloatInputFieldValue,
|
||||
zStringInputFieldValue,
|
||||
zBooleanInputFieldValue,
|
||||
zBooleanPolymorphicInputFieldValue,
|
||||
zImageInputFieldValue,
|
||||
zLatentsInputFieldValue,
|
||||
zDenoiseMaskInputFieldValue,
|
||||
zConditioningInputFieldValue,
|
||||
zUNetInputFieldValue,
|
||||
zClipInputFieldValue,
|
||||
zVaeInputFieldValue,
|
||||
zControlInputFieldValue,
|
||||
zEnumInputFieldValue,
|
||||
zMainModelInputFieldValue,
|
||||
zSDXLMainModelInputFieldValue,
|
||||
zSDXLRefinerModelInputFieldValue,
|
||||
zVaeModelInputFieldValue,
|
||||
zLoRAModelInputFieldValue,
|
||||
zControlNetModelInputFieldValue,
|
||||
zCollectionInputFieldValue,
|
||||
zCollectionItemInputFieldValue,
|
||||
zColorInputFieldValue,
|
||||
zColorCollectionInputFieldValue,
|
||||
zColorPolymorphicInputFieldValue,
|
||||
zConditioningInputFieldValue,
|
||||
zConditioningCollectionInputFieldValue,
|
||||
zConditioningPolymorphicInputFieldValue,
|
||||
zControlInputFieldValue,
|
||||
zControlNetModelInputFieldValue,
|
||||
zControlCollectionInputFieldValue,
|
||||
zControlPolymorphicInputFieldValue,
|
||||
zDenoiseMaskInputFieldValue,
|
||||
zEnumInputFieldValue,
|
||||
zFloatCollectionInputFieldValue,
|
||||
zFloatInputFieldValue,
|
||||
zFloatPolymorphicInputFieldValue,
|
||||
zImageCollectionInputFieldValue,
|
||||
zImagePolymorphicInputFieldValue,
|
||||
zImageInputFieldValue,
|
||||
zIntegerCollectionInputFieldValue,
|
||||
zIntegerPolymorphicInputFieldValue,
|
||||
zIntegerInputFieldValue,
|
||||
zLatentsInputFieldValue,
|
||||
zLatentsCollectionInputFieldValue,
|
||||
zLatentsPolymorphicInputFieldValue,
|
||||
zLoRAModelInputFieldValue,
|
||||
zMainModelInputFieldValue,
|
||||
zSchedulerInputFieldValue,
|
||||
zSDXLMainModelInputFieldValue,
|
||||
zSDXLRefinerModelInputFieldValue,
|
||||
zStringCollectionInputFieldValue,
|
||||
zStringPolymorphicInputFieldValue,
|
||||
zStringInputFieldValue,
|
||||
zUNetInputFieldValue,
|
||||
zVaeInputFieldValue,
|
||||
zVaeModelInputFieldValue,
|
||||
]);
|
||||
|
||||
export type InputFieldValue = z.infer<typeof zInputFieldValue>;
|
||||
@ -641,6 +513,7 @@ export type InputFieldTemplateBase = {
|
||||
name: string;
|
||||
title: string;
|
||||
description: string;
|
||||
type: FieldType;
|
||||
required: boolean;
|
||||
fieldKind: 'input';
|
||||
} & _InputField;
|
||||
@ -655,19 +528,6 @@ export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
|
||||
exclusiveMinimum?: boolean;
|
||||
};
|
||||
|
||||
export type IntegerCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'IntegerCollection';
|
||||
default: number[];
|
||||
item_default?: number;
|
||||
};
|
||||
|
||||
export type IntegerPolymorphicInputFieldTemplate = Omit<
|
||||
IntegerInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'IntegerPolymorphic';
|
||||
};
|
||||
|
||||
export type FloatInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'float';
|
||||
default: number;
|
||||
@ -678,19 +538,6 @@ export type FloatInputFieldTemplate = InputFieldTemplateBase & {
|
||||
exclusiveMinimum?: boolean;
|
||||
};
|
||||
|
||||
export type FloatCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'FloatCollection';
|
||||
default: number[];
|
||||
item_default?: number;
|
||||
};
|
||||
|
||||
export type FloatPolymorphicInputFieldTemplate = Omit<
|
||||
FloatInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'FloatPolymorphic';
|
||||
};
|
||||
|
||||
export type StringInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'string';
|
||||
default: string;
|
||||
@ -699,53 +546,19 @@ export type StringInputFieldTemplate = InputFieldTemplateBase & {
|
||||
pattern?: string;
|
||||
};
|
||||
|
||||
export type StringCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'StringCollection';
|
||||
default: string[];
|
||||
item_default?: string;
|
||||
};
|
||||
|
||||
export type StringPolymorphicInputFieldTemplate = Omit<
|
||||
StringInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'StringPolymorphic';
|
||||
};
|
||||
|
||||
export type BooleanInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: boolean;
|
||||
type: 'boolean';
|
||||
};
|
||||
|
||||
export type BooleanCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'BooleanCollection';
|
||||
default: boolean[];
|
||||
item_default?: boolean;
|
||||
};
|
||||
|
||||
export type BooleanPolymorphicInputFieldTemplate = Omit<
|
||||
BooleanInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'BooleanPolymorphic';
|
||||
};
|
||||
|
||||
export type ImageInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: ImageField;
|
||||
default: ImageDTO;
|
||||
type: 'ImageField';
|
||||
};
|
||||
|
||||
export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: ImageField[];
|
||||
type: 'ImageCollection';
|
||||
item_default?: ImageField;
|
||||
};
|
||||
|
||||
export type ImagePolymorphicInputFieldTemplate = Omit<
|
||||
ImageInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'ImagePolymorphic';
|
||||
};
|
||||
|
||||
export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
|
||||
@ -754,40 +567,15 @@ export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & {
|
||||
};
|
||||
|
||||
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: LatentsField;
|
||||
default: string;
|
||||
type: 'LatentsField';
|
||||
};
|
||||
|
||||
export type LatentsCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: LatentsField[];
|
||||
type: 'LatentsCollection';
|
||||
item_default?: LatentsField;
|
||||
};
|
||||
|
||||
export type LatentsPolymorphicInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: LatentsField;
|
||||
type: 'LatentsPolymorphic';
|
||||
};
|
||||
|
||||
export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'ConditioningField';
|
||||
};
|
||||
|
||||
export type ConditioningCollectionInputFieldTemplate =
|
||||
InputFieldTemplateBase & {
|
||||
default: ConditioningField[];
|
||||
type: 'ConditioningCollection';
|
||||
item_default?: ConditioningField;
|
||||
};
|
||||
|
||||
export type ConditioningPolymorphicInputFieldTemplate = Omit<
|
||||
ConditioningInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'ConditioningPolymorphic';
|
||||
};
|
||||
|
||||
export type UNetInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'UNetField';
|
||||
@ -808,19 +596,6 @@ export type ControlInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'ControlField';
|
||||
};
|
||||
|
||||
export type ControlCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'ControlCollection';
|
||||
item_default?: ControlField;
|
||||
};
|
||||
|
||||
export type ControlPolymorphicInputFieldTemplate = Omit<
|
||||
ControlInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'ControlPolymorphic';
|
||||
};
|
||||
|
||||
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string | number;
|
||||
type: 'enum';
|
||||
@ -873,18 +648,6 @@ export type ColorInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'ColorField';
|
||||
};
|
||||
|
||||
export type ColorPolymorphicInputFieldTemplate = Omit<
|
||||
ColorInputFieldTemplate,
|
||||
'type'
|
||||
> & {
|
||||
type: 'ColorPolymorphic';
|
||||
};
|
||||
|
||||
export type ColorCollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: [];
|
||||
type: 'ColorCollection';
|
||||
};
|
||||
|
||||
export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: SchedulerParam;
|
||||
type: 'Scheduler';
|
||||
@ -895,55 +658,6 @@ export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'WorkflowField';
|
||||
};
|
||||
|
||||
/**
|
||||
* An input field template is generated on each page load from the OpenAPI schema.
|
||||
*
|
||||
* The template provides the field type and other field metadata (e.g. title, description,
|
||||
* maximum length, pattern to match, etc).
|
||||
*/
|
||||
export type InputFieldTemplate =
|
||||
| BooleanCollectionInputFieldTemplate
|
||||
| BooleanPolymorphicInputFieldTemplate
|
||||
| BooleanInputFieldTemplate
|
||||
| ClipInputFieldTemplate
|
||||
| CollectionInputFieldTemplate
|
||||
| CollectionItemInputFieldTemplate
|
||||
| ColorInputFieldTemplate
|
||||
| ColorCollectionInputFieldTemplate
|
||||
| ColorPolymorphicInputFieldTemplate
|
||||
| ConditioningInputFieldTemplate
|
||||
| ConditioningCollectionInputFieldTemplate
|
||||
| ConditioningPolymorphicInputFieldTemplate
|
||||
| ControlInputFieldTemplate
|
||||
| ControlCollectionInputFieldTemplate
|
||||
| ControlNetModelInputFieldTemplate
|
||||
| ControlPolymorphicInputFieldTemplate
|
||||
| DenoiseMaskInputFieldTemplate
|
||||
| EnumInputFieldTemplate
|
||||
| FloatCollectionInputFieldTemplate
|
||||
| FloatInputFieldTemplate
|
||||
| FloatPolymorphicInputFieldTemplate
|
||||
| ImageCollectionInputFieldTemplate
|
||||
| ImagePolymorphicInputFieldTemplate
|
||||
| ImageInputFieldTemplate
|
||||
| IntegerCollectionInputFieldTemplate
|
||||
| IntegerPolymorphicInputFieldTemplate
|
||||
| IntegerInputFieldTemplate
|
||||
| LatentsInputFieldTemplate
|
||||
| LatentsCollectionInputFieldTemplate
|
||||
| LatentsPolymorphicInputFieldTemplate
|
||||
| LoRAModelInputFieldTemplate
|
||||
| MainModelInputFieldTemplate
|
||||
| SchedulerInputFieldTemplate
|
||||
| SDXLMainModelInputFieldTemplate
|
||||
| SDXLRefinerModelInputFieldTemplate
|
||||
| StringCollectionInputFieldTemplate
|
||||
| StringPolymorphicInputFieldTemplate
|
||||
| StringInputFieldTemplate
|
||||
| UNetInputFieldTemplate
|
||||
| VaeInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate;
|
||||
|
||||
export const isInputFieldValue = (
|
||||
field?: InputFieldValue | OutputFieldValue
|
||||
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
|
||||
@ -966,7 +680,6 @@ export type InvocationSchemaExtra = {
|
||||
title: string;
|
||||
category?: string;
|
||||
tags?: string[];
|
||||
version?: string;
|
||||
properties: Omit<
|
||||
NonNullable<OpenAPIV3.SchemaObject['properties']> &
|
||||
(_InputField | _OutputField),
|
||||
@ -1017,22 +730,8 @@ export type InvocationSchemaObject = (
|
||||
) & { class: 'invocation' };
|
||||
|
||||
export const isSchemaObject = (
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||
): obj is OpenAPIV3.SchemaObject => Boolean(obj && !('$ref' in obj));
|
||||
|
||||
export const isArraySchemaObject = (
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||
): obj is OpenAPIV3.ArraySchemaObject =>
|
||||
Boolean(obj && !('$ref' in obj) && obj.type === 'array');
|
||||
|
||||
export const isNonArraySchemaObject = (
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||
): obj is OpenAPIV3.NonArraySchemaObject =>
|
||||
Boolean(obj && !('$ref' in obj) && obj.type !== 'array');
|
||||
|
||||
export const isRefObject = (
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject | undefined
|
||||
): obj is OpenAPIV3.ReferenceObject => Boolean(obj && '$ref' in obj);
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
|
||||
): obj is OpenAPIV3.SchemaObject => !('$ref' in obj);
|
||||
|
||||
export const isInvocationSchemaObject = (
|
||||
obj:
|
||||
@ -1071,14 +770,12 @@ export const zCoreMetadata = z
|
||||
steps: z.number().int().nullish(),
|
||||
scheduler: z.string().nullish(),
|
||||
clip_skip: z.number().int().nullish(),
|
||||
model: z
|
||||
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
|
||||
.nullish(),
|
||||
controlnets: z.array(zControlField.deepPartial()).nullish(),
|
||||
model: zMainOrOnnxModel.nullish(),
|
||||
controlnets: z.array(zControlField).nullish(),
|
||||
loras: z
|
||||
.array(
|
||||
z.object({
|
||||
lora: zLoRAModelField.deepPartial(),
|
||||
lora: zLoRAModelField,
|
||||
weight: z.number(),
|
||||
})
|
||||
)
|
||||
@ -1088,41 +785,18 @@ export const zCoreMetadata = z
|
||||
init_image: z.string().nullish(),
|
||||
positive_style_prompt: z.string().nullish(),
|
||||
negative_style_prompt: z.string().nullish(),
|
||||
refiner_model: zSDXLRefinerModel.deepPartial().nullish(),
|
||||
refiner_model: zSDXLRefinerModel.nullish(),
|
||||
refiner_cfg_scale: z.number().nullish(),
|
||||
refiner_steps: z.number().int().nullish(),
|
||||
refiner_scheduler: z.string().nullish(),
|
||||
refiner_positive_aesthetic_score: z.number().nullish(),
|
||||
refiner_negative_aesthetic_score: z.number().nullish(),
|
||||
refiner_positive_aesthetic_store: z.number().nullish(),
|
||||
refiner_negative_aesthetic_store: z.number().nullish(),
|
||||
refiner_start: z.number().nullish(),
|
||||
})
|
||||
.passthrough();
|
||||
.catchall(z.record(z.any()));
|
||||
|
||||
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
|
||||
|
||||
export const zSemVer = z.string().refine((val) => {
|
||||
const [major, minor, patch] = val.split('.');
|
||||
return (
|
||||
major !== undefined &&
|
||||
Number.isInteger(Number(major)) &&
|
||||
minor !== undefined &&
|
||||
Number.isInteger(Number(minor)) &&
|
||||
patch !== undefined &&
|
||||
Number.isInteger(Number(patch))
|
||||
);
|
||||
});
|
||||
|
||||
export const zParsedSemver = zSemVer.transform((val) => {
|
||||
const [major, minor, patch] = val.split('.');
|
||||
return {
|
||||
major: Number(major),
|
||||
minor: Number(minor),
|
||||
patch: Number(patch),
|
||||
};
|
||||
});
|
||||
|
||||
export type SemVer = z.infer<typeof zSemVer>;
|
||||
|
||||
export const zInvocationNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
// no easy way to build this dynamically, and we don't want to anyways, because this will be used
|
||||
@ -1135,7 +809,6 @@ export const zInvocationNodeData = z.object({
|
||||
notes: z.string(),
|
||||
embedWorkflow: z.boolean(),
|
||||
isIntermediate: z.boolean(),
|
||||
version: zSemVer.optional(),
|
||||
});
|
||||
|
||||
// Massage this to get better type safety while developing
|
||||
@ -1224,6 +897,20 @@ export const zFieldIdentifier = z.object({
|
||||
|
||||
export type FieldIdentifier = z.infer<typeof zFieldIdentifier>;
|
||||
|
||||
export const zSemVer = z.string().refine((val) => {
|
||||
const [major, minor, patch] = val.split('.');
|
||||
return (
|
||||
major !== undefined &&
|
||||
minor !== undefined &&
|
||||
patch !== undefined &&
|
||||
Number.isInteger(Number(major)) &&
|
||||
Number.isInteger(Number(minor)) &&
|
||||
Number.isInteger(Number(patch))
|
||||
);
|
||||
});
|
||||
|
||||
export type SemVer = z.infer<typeof zSemVer>;
|
||||
|
||||
export type WorkflowWarning = {
|
||||
message: string;
|
||||
issues: string[];
|
||||
@ -1249,10 +936,22 @@ export const zWorkflow = z.object({
|
||||
});
|
||||
|
||||
export const zValidatedWorkflow = zWorkflow.transform((workflow) => {
|
||||
const nodeTemplates = store.getState().nodes.nodeTemplates;
|
||||
const { nodes, edges } = workflow;
|
||||
const warnings: WorkflowWarning[] = [];
|
||||
const invocationNodes = nodes.filter(isWorkflowInvocationNode);
|
||||
const keyedNodes = keyBy(invocationNodes, 'id');
|
||||
invocationNodes.forEach((node, i) => {
|
||||
const nodeTemplate = nodeTemplates[node.data.type];
|
||||
if (!nodeTemplate) {
|
||||
warnings.push({
|
||||
message: `Node "${node.data.label || node.data.id}" skipped`,
|
||||
issues: [`Unable to find template for type "${node.data.type}"`],
|
||||
data: node,
|
||||
});
|
||||
delete nodes[i];
|
||||
}
|
||||
});
|
||||
edges.forEach((edge, i) => {
|
||||
const sourceNode = keyedNodes[edge.source];
|
||||
const targetNode = keyedNodes[edge.target];
|
||||
|
@ -1,14 +1,5 @@
|
||||
import { isBoolean, isInteger, isNumber, isString } from 'lodash-es';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import {
|
||||
COLLECTION_MAP,
|
||||
POLYMORPHIC_TYPES,
|
||||
SINGLE_TO_POLYMORPHIC_MAP,
|
||||
isCollectionItemType,
|
||||
isPolymorphicItemType,
|
||||
} from '../types/constants';
|
||||
import {
|
||||
BooleanCollectionInputFieldTemplate,
|
||||
BooleanInputFieldTemplate,
|
||||
ClipInputFieldTemplate,
|
||||
CollectionInputFieldTemplate,
|
||||
@ -20,13 +11,10 @@ import {
|
||||
DenoiseMaskInputFieldTemplate,
|
||||
EnumInputFieldTemplate,
|
||||
FieldType,
|
||||
FloatCollectionInputFieldTemplate,
|
||||
FloatPolymorphicInputFieldTemplate,
|
||||
FloatInputFieldTemplate,
|
||||
ImageCollectionInputFieldTemplate,
|
||||
ImageInputFieldTemplate,
|
||||
InputFieldTemplateBase,
|
||||
IntegerCollectionInputFieldTemplate,
|
||||
IntegerInputFieldTemplate,
|
||||
InvocationFieldSchema,
|
||||
InvocationSchemaObject,
|
||||
@ -36,32 +24,11 @@ import {
|
||||
SDXLMainModelInputFieldTemplate,
|
||||
SDXLRefinerModelInputFieldTemplate,
|
||||
SchedulerInputFieldTemplate,
|
||||
StringCollectionInputFieldTemplate,
|
||||
StringInputFieldTemplate,
|
||||
UNetInputFieldTemplate,
|
||||
VaeInputFieldTemplate,
|
||||
VaeModelInputFieldTemplate,
|
||||
isArraySchemaObject,
|
||||
isNonArraySchemaObject,
|
||||
isRefObject,
|
||||
isSchemaObject,
|
||||
ControlPolymorphicInputFieldTemplate,
|
||||
ColorPolymorphicInputFieldTemplate,
|
||||
ColorCollectionInputFieldTemplate,
|
||||
IntegerPolymorphicInputFieldTemplate,
|
||||
StringPolymorphicInputFieldTemplate,
|
||||
BooleanPolymorphicInputFieldTemplate,
|
||||
ImagePolymorphicInputFieldTemplate,
|
||||
LatentsPolymorphicInputFieldTemplate,
|
||||
LatentsCollectionInputFieldTemplate,
|
||||
ConditioningPolymorphicInputFieldTemplate,
|
||||
ConditioningCollectionInputFieldTemplate,
|
||||
ControlCollectionInputFieldTemplate,
|
||||
ImageField,
|
||||
LatentsField,
|
||||
ConditioningField,
|
||||
} from '../types/types';
|
||||
import { ControlField } from 'services/api/types';
|
||||
|
||||
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
||||
|
||||
@ -78,8 +45,15 @@ export type BuildInputFieldArg = {
|
||||
* @example
|
||||
* refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
|
||||
*/
|
||||
export const refObjectToSchemaName = (refObject: OpenAPIV3.ReferenceObject) =>
|
||||
refObject.$ref.split('/').slice(-1)[0];
|
||||
export const refObjectToFieldType = (
|
||||
refObject: OpenAPIV3.ReferenceObject
|
||||
): FieldType => {
|
||||
const name = refObject.$ref.split('/').slice(-1)[0];
|
||||
if (!name) {
|
||||
throw `Unknown field type: ${name}`;
|
||||
}
|
||||
return name as FieldType;
|
||||
};
|
||||
|
||||
const buildIntegerInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
@ -114,57 +88,6 @@ const buildIntegerInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildIntegerPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): IntegerPolymorphicInputFieldTemplate => {
|
||||
const template: IntegerPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'IntegerPolymorphic',
|
||||
default: schemaObject.default ?? 0,
|
||||
};
|
||||
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
template.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
template.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMaximum !== undefined) {
|
||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
template.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMinimum !== undefined) {
|
||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildIntegerCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): IntegerCollectionInputFieldTemplate => {
|
||||
const item_default =
|
||||
isNumber(schemaObject.item_default) && isInteger(schemaObject.item_default)
|
||||
? schemaObject.item_default
|
||||
: 0;
|
||||
const template: IntegerCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'IntegerCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -198,54 +121,6 @@ const buildFloatInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): FloatPolymorphicInputFieldTemplate => {
|
||||
const template: FloatPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'FloatPolymorphic',
|
||||
default: schemaObject.default ?? 0,
|
||||
};
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
template.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
template.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMaximum !== undefined) {
|
||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
template.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMinimum !== undefined) {
|
||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): FloatCollectionInputFieldTemplate => {
|
||||
const item_default = isNumber(schemaObject.item_default)
|
||||
? schemaObject.item_default
|
||||
: 0;
|
||||
const template: FloatCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'FloatCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -271,48 +146,6 @@ const buildStringInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): StringPolymorphicInputFieldTemplate => {
|
||||
const template: StringPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'StringPolymorphic',
|
||||
default: schemaObject.default ?? '',
|
||||
};
|
||||
|
||||
if (schemaObject.minLength !== undefined) {
|
||||
template.minLength = schemaObject.minLength;
|
||||
}
|
||||
|
||||
if (schemaObject.maxLength !== undefined) {
|
||||
template.maxLength = schemaObject.maxLength;
|
||||
}
|
||||
|
||||
if (schemaObject.pattern !== undefined) {
|
||||
template.pattern = schemaObject.pattern;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): StringCollectionInputFieldTemplate => {
|
||||
const item_default = isString(schemaObject.item_default)
|
||||
? schemaObject.item_default
|
||||
: '';
|
||||
const template: StringCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'StringCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBooleanInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -326,37 +159,6 @@ const buildBooleanInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBooleanPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): BooleanPolymorphicInputFieldTemplate => {
|
||||
const template: BooleanPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'BooleanPolymorphic',
|
||||
default: schemaObject.default ?? false,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBooleanCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): BooleanCollectionInputFieldTemplate => {
|
||||
const item_default =
|
||||
schemaObject.item_default && isBoolean(schemaObject.item_default)
|
||||
? schemaObject.item_default
|
||||
: false;
|
||||
const template: BooleanCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'BooleanCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildMainModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -448,19 +250,6 @@ const buildImageInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImagePolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ImagePolymorphicInputFieldTemplate => {
|
||||
const template: ImagePolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ImagePolymorphic',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImageCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -468,8 +257,7 @@ const buildImageCollectionInputFieldTemplate = ({
|
||||
const template: ImageCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ImageCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default: (schemaObject.item_default as ImageField) ?? undefined,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
@ -501,33 +289,6 @@ const buildLatentsInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildLatentsPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): LatentsPolymorphicInputFieldTemplate => {
|
||||
const template: LatentsPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'LatentsPolymorphic',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildLatentsCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): LatentsCollectionInputFieldTemplate => {
|
||||
const template: LatentsCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'LatentsCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default: (schemaObject.item_default as LatentsField) ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildConditioningInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -541,33 +302,6 @@ const buildConditioningInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildConditioningPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ConditioningPolymorphicInputFieldTemplate => {
|
||||
const template: ConditioningPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ConditioningPolymorphic',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildConditioningCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ConditioningCollectionInputFieldTemplate => {
|
||||
const template: ConditioningCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ConditioningCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default: (schemaObject.item_default as ConditioningField) ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildUNetInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -621,33 +355,6 @@ const buildControlInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildControlPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ControlPolymorphicInputFieldTemplate => {
|
||||
const template: ControlPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ControlPolymorphic',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildControlCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ControlCollectionInputFieldTemplate => {
|
||||
const template: ControlCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ControlCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
item_default: (schemaObject.item_default as ControlField) ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildEnumInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -701,32 +408,6 @@ const buildColorInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildColorPolymorphicInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ColorPolymorphicInputFieldTemplate => {
|
||||
const template: ColorPolymorphicInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ColorPolymorphic',
|
||||
default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 },
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildColorCollectionInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ColorCollectionInputFieldTemplate => {
|
||||
const template: ColorCollectionInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'ColorCollection',
|
||||
default: schemaObject.default ?? [],
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildSchedulerInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -740,138 +421,45 @@ const buildSchedulerInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
export const getFieldType = (
|
||||
schemaObject: InvocationFieldSchema
|
||||
): string | undefined => {
|
||||
if (schemaObject?.ui_type) {
|
||||
return schemaObject.ui_type;
|
||||
export const getFieldType = (schemaObject: InvocationFieldSchema): string => {
|
||||
let fieldType = '';
|
||||
|
||||
const { ui_type } = schemaObject;
|
||||
if (ui_type) {
|
||||
fieldType = ui_type;
|
||||
} else if (!schemaObject.type) {
|
||||
// console.log('refObject', schemaObject);
|
||||
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
|
||||
|
||||
if (schemaObject.allOf) {
|
||||
const allOf = schemaObject.allOf;
|
||||
if (allOf && allOf[0] && isRefObject(allOf[0])) {
|
||||
return refObjectToSchemaName(allOf[0]);
|
||||
}
|
||||
fieldType = refObjectToFieldType(
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
|
||||
);
|
||||
} else if (schemaObject.anyOf) {
|
||||
const anyOf = schemaObject.anyOf;
|
||||
/**
|
||||
* Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is:
|
||||
* - an `anyOf` with two items
|
||||
* - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items`
|
||||
* - the other is a `SchemaObject` or `ReferenceObject` of type T
|
||||
*
|
||||
* Any other cases we ignore.
|
||||
*/
|
||||
|
||||
let firstType: string | undefined;
|
||||
let secondType: string | undefined;
|
||||
|
||||
if (isArraySchemaObject(anyOf[0])) {
|
||||
// first is array, second is not
|
||||
const first = anyOf[0].items;
|
||||
const second = anyOf[1];
|
||||
if (isRefObject(first) && isRefObject(second)) {
|
||||
firstType = refObjectToSchemaName(first);
|
||||
secondType = refObjectToSchemaName(second);
|
||||
} else if (
|
||||
isNonArraySchemaObject(first) &&
|
||||
isNonArraySchemaObject(second)
|
||||
) {
|
||||
firstType = first.type;
|
||||
secondType = second.type;
|
||||
}
|
||||
} else if (isArraySchemaObject(anyOf[1])) {
|
||||
// first is not array, second is
|
||||
const first = anyOf[0];
|
||||
const second = anyOf[1].items;
|
||||
if (isRefObject(first) && isRefObject(second)) {
|
||||
firstType = refObjectToSchemaName(first);
|
||||
secondType = refObjectToSchemaName(second);
|
||||
} else if (
|
||||
isNonArraySchemaObject(first) &&
|
||||
isNonArraySchemaObject(second)
|
||||
) {
|
||||
firstType = first.type;
|
||||
secondType = second.type;
|
||||
}
|
||||
}
|
||||
if (firstType === secondType && isPolymorphicItemType(firstType)) {
|
||||
return SINGLE_TO_POLYMORPHIC_MAP[firstType];
|
||||
}
|
||||
fieldType = refObjectToFieldType(
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
|
||||
);
|
||||
} else if (schemaObject.oneOf) {
|
||||
fieldType = refObjectToFieldType(
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
|
||||
);
|
||||
}
|
||||
} else if (schemaObject.enum) {
|
||||
return 'enum';
|
||||
fieldType = 'enum';
|
||||
} else if (schemaObject.type) {
|
||||
if (schemaObject.type === 'number') {
|
||||
// floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them
|
||||
return 'float';
|
||||
} else if (schemaObject.type === 'array') {
|
||||
const itemType = isSchemaObject(schemaObject.items)
|
||||
? schemaObject.items.type
|
||||
: refObjectToSchemaName(schemaObject.items);
|
||||
|
||||
if (isCollectionItemType(itemType)) {
|
||||
return COLLECTION_MAP[itemType];
|
||||
}
|
||||
|
||||
return;
|
||||
// floats are "number" in OpenAPI, while ints are "integer"
|
||||
fieldType = 'float';
|
||||
} else {
|
||||
return schemaObject.type;
|
||||
fieldType = schemaObject.type;
|
||||
}
|
||||
}
|
||||
return;
|
||||
};
|
||||
|
||||
const TEMPLATE_BUILDER_MAP = {
|
||||
boolean: buildBooleanInputFieldTemplate,
|
||||
BooleanCollection: buildBooleanCollectionInputFieldTemplate,
|
||||
BooleanPolymorphic: buildBooleanPolymorphicInputFieldTemplate,
|
||||
ClipField: buildClipInputFieldTemplate,
|
||||
Collection: buildCollectionInputFieldTemplate,
|
||||
CollectionItem: buildCollectionItemInputFieldTemplate,
|
||||
ColorCollection: buildColorCollectionInputFieldTemplate,
|
||||
ColorField: buildColorInputFieldTemplate,
|
||||
ColorPolymorphic: buildColorPolymorphicInputFieldTemplate,
|
||||
ConditioningCollection: buildConditioningCollectionInputFieldTemplate,
|
||||
ConditioningField: buildConditioningInputFieldTemplate,
|
||||
ConditioningPolymorphic: buildConditioningPolymorphicInputFieldTemplate,
|
||||
ControlCollection: buildControlCollectionInputFieldTemplate,
|
||||
ControlField: buildControlInputFieldTemplate,
|
||||
ControlNetModelField: buildControlNetModelInputFieldTemplate,
|
||||
ControlPolymorphic: buildControlPolymorphicInputFieldTemplate,
|
||||
DenoiseMaskField: buildDenoiseMaskInputFieldTemplate,
|
||||
enum: buildEnumInputFieldTemplate,
|
||||
float: buildFloatInputFieldTemplate,
|
||||
FloatCollection: buildFloatCollectionInputFieldTemplate,
|
||||
FloatPolymorphic: buildFloatPolymorphicInputFieldTemplate,
|
||||
ImageCollection: buildImageCollectionInputFieldTemplate,
|
||||
ImageField: buildImageInputFieldTemplate,
|
||||
ImagePolymorphic: buildImagePolymorphicInputFieldTemplate,
|
||||
integer: buildIntegerInputFieldTemplate,
|
||||
IntegerCollection: buildIntegerCollectionInputFieldTemplate,
|
||||
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
|
||||
LatentsCollection: buildLatentsCollectionInputFieldTemplate,
|
||||
LatentsField: buildLatentsInputFieldTemplate,
|
||||
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
|
||||
LoRAModelField: buildLoRAModelInputFieldTemplate,
|
||||
MainModelField: buildMainModelInputFieldTemplate,
|
||||
Scheduler: buildSchedulerInputFieldTemplate,
|
||||
SDXLMainModelField: buildSDXLMainModelInputFieldTemplate,
|
||||
SDXLRefinerModelField: buildRefinerModelInputFieldTemplate,
|
||||
string: buildStringInputFieldTemplate,
|
||||
StringCollection: buildStringCollectionInputFieldTemplate,
|
||||
StringPolymorphic: buildStringPolymorphicInputFieldTemplate,
|
||||
UNetField: buildUNetInputFieldTemplate,
|
||||
VaeField: buildVaeInputFieldTemplate,
|
||||
VaeModelField: buildVaeModelInputFieldTemplate,
|
||||
return fieldType;
|
||||
};
|
||||
|
||||
const isTemplatedFieldType = (
|
||||
fieldType: string | undefined
|
||||
): fieldType is keyof typeof TEMPLATE_BUILDER_MAP =>
|
||||
Boolean(fieldType && fieldType in TEMPLATE_BUILDER_MAP);
|
||||
|
||||
/**
|
||||
* Builds an input field from an invocation schema property.
|
||||
* @param fieldSchema The schema object
|
||||
@ -886,8 +474,7 @@ export const buildInputFieldTemplate = (
|
||||
const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema;
|
||||
|
||||
const extra = {
|
||||
// TODO: Can we support polymorphic inputs in the UI?
|
||||
input: POLYMORPHIC_TYPES.includes(fieldType) ? 'connection' : input,
|
||||
input,
|
||||
ui_hidden,
|
||||
ui_component,
|
||||
ui_type,
|
||||
@ -903,12 +490,146 @@ export const buildInputFieldTemplate = (
|
||||
...extra,
|
||||
};
|
||||
|
||||
if (!isTemplatedFieldType(fieldType)) {
|
||||
return;
|
||||
if (fieldType === 'ImageField') {
|
||||
return buildImageInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
|
||||
return TEMPLATE_BUILDER_MAP[fieldType]({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
if (fieldType === 'ImageCollection') {
|
||||
return buildImageCollectionInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'DenoiseMaskField') {
|
||||
return buildDenoiseMaskInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'LatentsField') {
|
||||
return buildLatentsInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'ConditioningField') {
|
||||
return buildConditioningInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'UNetField') {
|
||||
return buildUNetInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'ClipField') {
|
||||
return buildClipInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'VaeField') {
|
||||
return buildVaeInputFieldTemplate({ schemaObject: fieldSchema, baseField });
|
||||
}
|
||||
if (fieldType === 'ControlField') {
|
||||
return buildControlInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'MainModelField') {
|
||||
return buildMainModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'SDXLRefinerModelField') {
|
||||
return buildRefinerModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'SDXLMainModelField') {
|
||||
return buildSDXLMainModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'VaeModelField') {
|
||||
return buildVaeModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'LoRAModelField') {
|
||||
return buildLoRAModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'ControlNetModelField') {
|
||||
return buildControlNetModelInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'enum') {
|
||||
return buildEnumInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'integer') {
|
||||
return buildIntegerInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'float') {
|
||||
return buildFloatInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'string') {
|
||||
return buildStringInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'boolean') {
|
||||
return buildBooleanInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'Collection') {
|
||||
return buildCollectionInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'CollectionItem') {
|
||||
return buildCollectionItemInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'ColorField') {
|
||||
return buildColorInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
if (fieldType === 'Scheduler') {
|
||||
return buildSchedulerInputFieldTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
});
|
||||
}
|
||||
return;
|
||||
};
|
||||
|
@ -1,79 +1,104 @@
|
||||
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||
|
||||
const FIELD_VALUE_FALLBACK_MAP = {
|
||||
'enum.number': 0,
|
||||
'enum.string': '',
|
||||
boolean: false,
|
||||
BooleanCollection: [],
|
||||
BooleanPolymorphic: false,
|
||||
ClipField: undefined,
|
||||
Collection: [],
|
||||
CollectionItem: undefined,
|
||||
ColorCollection: [],
|
||||
ColorField: undefined,
|
||||
ColorPolymorphic: undefined,
|
||||
ConditioningCollection: [],
|
||||
ConditioningField: undefined,
|
||||
ConditioningPolymorphic: undefined,
|
||||
ControlCollection: [],
|
||||
ControlField: undefined,
|
||||
ControlNetModelField: undefined,
|
||||
ControlPolymorphic: undefined,
|
||||
DenoiseMaskField: undefined,
|
||||
float: 0,
|
||||
FloatCollection: [],
|
||||
FloatPolymorphic: 0,
|
||||
ImageCollection: [],
|
||||
ImageField: undefined,
|
||||
ImagePolymorphic: undefined,
|
||||
integer: 0,
|
||||
IntegerCollection: [],
|
||||
IntegerPolymorphic: 0,
|
||||
LatentsCollection: [],
|
||||
LatentsField: undefined,
|
||||
LatentsPolymorphic: undefined,
|
||||
LoRAModelField: undefined,
|
||||
MainModelField: undefined,
|
||||
ONNXModelField: undefined,
|
||||
Scheduler: 'euler',
|
||||
SDXLMainModelField: undefined,
|
||||
SDXLRefinerModelField: undefined,
|
||||
string: '',
|
||||
StringCollection: [],
|
||||
StringPolymorphic: '',
|
||||
UNetField: undefined,
|
||||
VaeField: undefined,
|
||||
VaeModelField: undefined,
|
||||
};
|
||||
|
||||
export const buildInputFieldValue = (
|
||||
id: string,
|
||||
template: InputFieldTemplate
|
||||
): InputFieldValue => {
|
||||
// TODO: this should be `fieldValue: InputFieldValue`, but that introduces a TS issue I couldn't
|
||||
// resolve - for some reason, it doesn't like `template.type`, which is the discriminant for both
|
||||
// `InputFieldTemplate` union. It is (type-structurally) equal to the discriminant for the
|
||||
// `InputFieldValue` union, but TS doesn't seem to like it...
|
||||
const fieldValue = {
|
||||
const fieldValue: InputFieldValue = {
|
||||
id,
|
||||
name: template.name,
|
||||
type: template.type,
|
||||
label: '',
|
||||
fieldKind: 'input',
|
||||
} as InputFieldValue;
|
||||
};
|
||||
|
||||
if (template.type === 'string') {
|
||||
fieldValue.value = template.default ?? '';
|
||||
}
|
||||
|
||||
if (template.type === 'integer') {
|
||||
fieldValue.value = template.default ?? 0;
|
||||
}
|
||||
|
||||
if (template.type === 'float') {
|
||||
fieldValue.value = template.default ?? 0;
|
||||
}
|
||||
|
||||
if (template.type === 'boolean') {
|
||||
fieldValue.value = template.default ?? false;
|
||||
}
|
||||
|
||||
if (template.type === 'enum') {
|
||||
if (template.enumType === 'number') {
|
||||
fieldValue.value =
|
||||
template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.number'];
|
||||
fieldValue.value = template.default ?? 0;
|
||||
}
|
||||
if (template.enumType === 'string') {
|
||||
fieldValue.value =
|
||||
template.default ?? FIELD_VALUE_FALLBACK_MAP['enum.string'];
|
||||
fieldValue.value = template.default ?? '';
|
||||
}
|
||||
} else {
|
||||
fieldValue.value =
|
||||
template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type];
|
||||
}
|
||||
|
||||
if (template.type === 'Collection') {
|
||||
fieldValue.value = template.default ?? 1;
|
||||
}
|
||||
|
||||
if (template.type === 'ImageField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'ImageCollection') {
|
||||
fieldValue.value = [];
|
||||
}
|
||||
|
||||
if (template.type === 'DenoiseMaskField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'LatentsField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'ConditioningField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'UNetField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'ClipField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'VaeField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'ControlField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'MainModelField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'SDXLRefinerModelField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'VaeModelField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'LoRAModelField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'ControlNetModelField') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'Scheduler') {
|
||||
fieldValue.value = 'euler';
|
||||
}
|
||||
|
||||
return fieldValue;
|
||||
|
@ -1,6 +1,4 @@
|
||||
import * as png from '@stevebel/png';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import {
|
||||
ImageMetadataAndWorkflow,
|
||||
zCoreMetadata,
|
||||
@ -20,11 +18,6 @@ export const getMetadataAndWorkflowFromImageBlob = async (
|
||||
const metadataResult = zCoreMetadata.safeParse(JSON.parse(rawMetadata));
|
||||
if (metadataResult.success) {
|
||||
data.metadata = metadataResult.data;
|
||||
} else {
|
||||
logger('system').error(
|
||||
{ error: parseify(metadataResult.error) },
|
||||
'Problem reading metadata from image'
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -33,11 +26,6 @@ export const getMetadataAndWorkflowFromImageBlob = async (
|
||||
const workflowResult = zWorkflow.safeParse(JSON.parse(rawWorkflow));
|
||||
if (workflowResult.success) {
|
||||
data.workflow = workflowResult.data;
|
||||
} else {
|
||||
logger('system').error(
|
||||
{ error: parseify(workflowResult.error) },
|
||||
'Problem reading workflow from image'
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -10,8 +10,7 @@ import {
|
||||
CANVAS_OUTPUT,
|
||||
INPAINT_IMAGE_RESIZE_UP,
|
||||
LATENTS_TO_IMAGE,
|
||||
MASK_COMBINE,
|
||||
MASK_RESIZE_UP,
|
||||
MASK_BLUR,
|
||||
METADATA_ACCUMULATOR,
|
||||
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||
SDXL_CANVAS_INPAINT_GRAPH,
|
||||
@ -47,8 +46,6 @@ export const addSDXLRefinerToGraph = (
|
||||
const { seamlessXAxis, seamlessYAxis, vaePrecision } = state.generation;
|
||||
const { boundingBoxScaleMethod } = state.canvas;
|
||||
|
||||
const fp32 = vaePrecision === 'fp32';
|
||||
|
||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||
boundingBoxScaleMethod
|
||||
);
|
||||
@ -63,9 +60,9 @@ export const addSDXLRefinerToGraph = (
|
||||
|
||||
if (metadataAccumulator) {
|
||||
metadataAccumulator.refiner_model = refinerModel;
|
||||
metadataAccumulator.refiner_positive_aesthetic_score =
|
||||
metadataAccumulator.refiner_positive_aesthetic_store =
|
||||
refinerPositiveAestheticScore;
|
||||
metadataAccumulator.refiner_negative_aesthetic_score =
|
||||
metadataAccumulator.refiner_negative_aesthetic_store =
|
||||
refinerNegativeAestheticScore;
|
||||
metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
|
||||
metadataAccumulator.refiner_scheduler = refinerScheduler;
|
||||
@ -234,7 +231,7 @@ export const addSDXLRefinerToGraph = (
|
||||
type: 'create_denoise_mask',
|
||||
id: SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
};
|
||||
|
||||
if (isUsingScaledDimensions) {
|
||||
@ -260,7 +257,7 @@ export const addSDXLRefinerToGraph = (
|
||||
graph.edges.push(
|
||||
{
|
||||
source: {
|
||||
node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE,
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
|
@ -2,7 +2,6 @@ import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { MetadataAccumulatorInvocation } from 'services/api/types';
|
||||
import {
|
||||
CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||
CANVAS_INPAINT_GRAPH,
|
||||
CANVAS_OUTPAINT_GRAPH,
|
||||
@ -32,7 +31,7 @@ export const addVAEToGraph = (
|
||||
graph: NonNullableGraph,
|
||||
modelLoaderNodeId: string = MAIN_MODEL_LOADER
|
||||
): void => {
|
||||
const { vae, canvasCoherenceMode } = state.generation;
|
||||
const { vae } = state.generation;
|
||||
const { boundingBoxScaleMethod } = state.canvas;
|
||||
const { shouldUseSDXLRefiner } = state.sdxl;
|
||||
|
||||
@ -147,20 +146,6 @@ export const addVAEToGraph = (
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
// Handle Coherence Mode
|
||||
if (canvasCoherenceMode !== 'unmasked') {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||
field: isAutoVae && isOnnxModel ? 'vae_decoder' : 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'vae',
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (shouldUseSDXLRefiner) {
|
||||
|
@ -59,8 +59,6 @@ export const buildCanvasImageToImageGraph = (
|
||||
shouldAutoSave,
|
||||
} = state.canvas;
|
||||
|
||||
const fp32 = vaePrecision === 'fp32';
|
||||
|
||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||
boundingBoxScaleMethod
|
||||
);
|
||||
@ -247,7 +245,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
id: LATENTS_TO_IMAGE,
|
||||
type: 'l2i',
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
};
|
||||
graph.nodes[CANVAS_OUTPUT] = {
|
||||
id: CANVAS_OUTPUT,
|
||||
@ -294,7 +292,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
type: 'l2i',
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate: !shouldAutoSave,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
};
|
||||
|
||||
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image =
|
||||
|
@ -6,7 +6,6 @@ import {
|
||||
ImageBlurInvocation,
|
||||
ImageDTO,
|
||||
ImageToLatentsInvocation,
|
||||
MaskEdgeInvocation,
|
||||
NoiseInvocation,
|
||||
RandomIntInvocation,
|
||||
RangeOfSizeInvocation,
|
||||
@ -19,8 +18,6 @@ import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
CANVAS_COHERENCE_MASK_EDGE,
|
||||
CANVAS_COHERENCE_NOISE,
|
||||
CANVAS_COHERENCE_NOISE_INCREMENT,
|
||||
CANVAS_INPAINT_GRAPH,
|
||||
@ -70,7 +67,6 @@ export const buildCanvasInpaintGraph = (
|
||||
shouldUseCpuNoise,
|
||||
maskBlur,
|
||||
maskBlurMethod,
|
||||
canvasCoherenceMode,
|
||||
canvasCoherenceSteps,
|
||||
canvasCoherenceStrength,
|
||||
clipSkip,
|
||||
@ -93,12 +89,6 @@ export const buildCanvasInpaintGraph = (
|
||||
shouldAutoSave,
|
||||
} = state.canvas;
|
||||
|
||||
const fp32 = vaePrecision === 'fp32';
|
||||
|
||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||
boundingBoxScaleMethod
|
||||
);
|
||||
|
||||
let modelLoaderNodeId = MAIN_MODEL_LOADER;
|
||||
|
||||
const use_cpu = shouldUseNoiseSettings
|
||||
@ -143,7 +133,13 @@ export const buildCanvasInpaintGraph = (
|
||||
type: 'i2l',
|
||||
id: INPAINT_IMAGE,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[INPAINT_CREATE_MASK]: {
|
||||
type: 'create_denoise_mask',
|
||||
id: INPAINT_CREATE_MASK,
|
||||
is_intermediate: true,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[NOISE]: {
|
||||
type: 'noise',
|
||||
@ -151,12 +147,6 @@ export const buildCanvasInpaintGraph = (
|
||||
use_cpu,
|
||||
is_intermediate: true,
|
||||
},
|
||||
[INPAINT_CREATE_MASK]: {
|
||||
type: 'create_denoise_mask',
|
||||
id: INPAINT_CREATE_MASK,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
},
|
||||
[DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
id: DENOISE_LATENTS,
|
||||
@ -181,7 +171,7 @@ export const buildCanvasInpaintGraph = (
|
||||
},
|
||||
[CANVAS_COHERENCE_DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
id: DENOISE_LATENTS,
|
||||
is_intermediate: true,
|
||||
steps: canvasCoherenceSteps,
|
||||
cfg_scale: cfg_scale,
|
||||
@ -193,7 +183,7 @@ export const buildCanvasInpaintGraph = (
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[CANVAS_OUTPUT]: {
|
||||
type: 'color_correct',
|
||||
@ -428,7 +418,7 @@ export const buildCanvasInpaintGraph = (
|
||||
};
|
||||
|
||||
// Handle Scale Before Processing
|
||||
if (isUsingScaledDimensions) {
|
||||
if (['auto', 'manual'].includes(boundingBoxScaleMethod)) {
|
||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||
const scaledHeight: number = scaledBoundingBoxDimensions.height;
|
||||
|
||||
@ -591,116 +581,6 @@ export const buildCanvasInpaintGraph = (
|
||||
);
|
||||
}
|
||||
|
||||
// Handle Coherence Mode
|
||||
if (canvasCoherenceMode !== 'unmasked') {
|
||||
// Create Mask If Coherence Mode Is Not Full
|
||||
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||
type: 'create_denoise_mask',
|
||||
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
};
|
||||
|
||||
// Handle Image Input For Mask Creation
|
||||
if (isUsingScaledDimensions) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||
...(graph.nodes[
|
||||
CANVAS_COHERENCE_INPAINT_CREATE_MASK
|
||||
] as CreateDenoiseMaskInvocation),
|
||||
image: canvasInitImage,
|
||||
};
|
||||
}
|
||||
|
||||
// Create Mask If Coherence Mode Is Mask
|
||||
if (canvasCoherenceMode === 'mask') {
|
||||
if (isUsingScaledDimensions) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||
...(graph.nodes[
|
||||
CANVAS_COHERENCE_INPAINT_CREATE_MASK
|
||||
] as CreateDenoiseMaskInvocation),
|
||||
mask: canvasMaskImage,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Create Mask Edge If Coherence Mode Is Edge
|
||||
if (canvasCoherenceMode === 'edge') {
|
||||
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
|
||||
type: 'mask_edge',
|
||||
id: CANVAS_COHERENCE_MASK_EDGE,
|
||||
is_intermediate: true,
|
||||
edge_blur: maskBlur,
|
||||
edge_size: maskBlur * 2,
|
||||
low_threshold: 100,
|
||||
high_threshold: 200,
|
||||
};
|
||||
|
||||
// Handle Scaled Dimensions For Mask Edge
|
||||
if (isUsingScaledDimensions) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
|
||||
...(graph.nodes[CANVAS_COHERENCE_MASK_EDGE] as MaskEdgeInvocation),
|
||||
image: canvasMaskImage,
|
||||
};
|
||||
}
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Plug Denoise Mask To Coherence Denoise Latents
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Handle Seed
|
||||
if (shouldRandomizeSeed) {
|
||||
// Random int node to generate the starting seed
|
||||
|
@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
ImageBlurInvocation,
|
||||
ImageDTO,
|
||||
ImageToLatentsInvocation,
|
||||
InfillPatchMatchInvocation,
|
||||
@ -18,8 +19,6 @@ import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
CANVAS_COHERENCE_MASK_EDGE,
|
||||
CANVAS_COHERENCE_NOISE,
|
||||
CANVAS_COHERENCE_NOISE_INCREMENT,
|
||||
CANVAS_OUTPAINT_GRAPH,
|
||||
@ -35,6 +34,7 @@ import {
|
||||
ITERATE,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
MASK_BLUR,
|
||||
MASK_COMBINE,
|
||||
MASK_FROM_ALPHA,
|
||||
MASK_RESIZE_DOWN,
|
||||
@ -71,11 +71,10 @@ export const buildCanvasOutpaintGraph = (
|
||||
shouldUseNoiseSettings,
|
||||
shouldUseCpuNoise,
|
||||
maskBlur,
|
||||
canvasCoherenceMode,
|
||||
maskBlurMethod,
|
||||
canvasCoherenceSteps,
|
||||
canvasCoherenceStrength,
|
||||
infillTileSize,
|
||||
infillPatchmatchDownscaleSize,
|
||||
tileSize,
|
||||
infillMethod,
|
||||
clipSkip,
|
||||
seamlessXAxis,
|
||||
@ -97,12 +96,6 @@ export const buildCanvasOutpaintGraph = (
|
||||
shouldAutoSave,
|
||||
} = state.canvas;
|
||||
|
||||
const fp32 = vaePrecision === 'fp32';
|
||||
|
||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||
boundingBoxScaleMethod
|
||||
);
|
||||
|
||||
let modelLoaderNodeId = MAIN_MODEL_LOADER;
|
||||
|
||||
const use_cpu = shouldUseNoiseSettings
|
||||
@ -148,11 +141,18 @@ export const buildCanvasOutpaintGraph = (
|
||||
is_intermediate: true,
|
||||
mask2: canvasMaskImage,
|
||||
},
|
||||
[MASK_BLUR]: {
|
||||
type: 'img_blur',
|
||||
id: MASK_BLUR,
|
||||
is_intermediate: true,
|
||||
radius: maskBlur,
|
||||
blur_type: maskBlurMethod,
|
||||
},
|
||||
[INPAINT_IMAGE]: {
|
||||
type: 'i2l',
|
||||
id: INPAINT_IMAGE,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[NOISE]: {
|
||||
type: 'noise',
|
||||
@ -164,7 +164,7 @@ export const buildCanvasOutpaintGraph = (
|
||||
type: 'create_denoise_mask',
|
||||
id: INPAINT_CREATE_MASK,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
@ -202,7 +202,7 @@ export const buildCanvasOutpaintGraph = (
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[CANVAS_OUTPUT]: {
|
||||
type: 'color_correct',
|
||||
@ -333,7 +333,7 @@ export const buildCanvasOutpaintGraph = (
|
||||
// Create Inpaint Mask
|
||||
{
|
||||
source: {
|
||||
node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE,
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
@ -443,16 +443,6 @@ export const buildCanvasOutpaintGraph = (
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_INFILL,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Decode the result from Inpaint
|
||||
{
|
||||
source: {
|
||||
@ -473,7 +463,6 @@ export const buildCanvasOutpaintGraph = (
|
||||
type: 'infill_patchmatch',
|
||||
id: INPAINT_INFILL,
|
||||
is_intermediate: true,
|
||||
downscale: infillPatchmatchDownscaleSize,
|
||||
};
|
||||
}
|
||||
|
||||
@ -485,25 +474,17 @@ export const buildCanvasOutpaintGraph = (
|
||||
};
|
||||
}
|
||||
|
||||
if (infillMethod === 'cv2') {
|
||||
graph.nodes[INPAINT_INFILL] = {
|
||||
type: 'infill_cv2',
|
||||
id: INPAINT_INFILL,
|
||||
is_intermediate: true,
|
||||
};
|
||||
}
|
||||
|
||||
if (infillMethod === 'tile') {
|
||||
graph.nodes[INPAINT_INFILL] = {
|
||||
type: 'infill_tile',
|
||||
id: INPAINT_INFILL,
|
||||
is_intermediate: true,
|
||||
tile_size: infillTileSize,
|
||||
tile_size: tileSize,
|
||||
};
|
||||
}
|
||||
|
||||
// Handle Scale Before Processing
|
||||
if (isUsingScaledDimensions) {
|
||||
if (['auto', 'manual'].includes(boundingBoxScaleMethod)) {
|
||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||
const scaledHeight: number = scaledBoundingBoxDimensions.height;
|
||||
|
||||
@ -565,6 +546,16 @@ export const buildCanvasOutpaintGraph = (
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_INFILL,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Take combined mask and resize and then blur
|
||||
{
|
||||
source: {
|
||||
@ -576,7 +567,16 @@ export const buildCanvasOutpaintGraph = (
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Resize Results Down
|
||||
{
|
||||
source: {
|
||||
@ -658,8 +658,32 @@ export const buildCanvasOutpaintGraph = (
|
||||
...(graph.nodes[INPAINT_IMAGE] as ImageToLatentsInvocation),
|
||||
image: canvasInitImage,
|
||||
};
|
||||
graph.nodes[MASK_BLUR] = {
|
||||
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
|
||||
};
|
||||
|
||||
graph.edges.push(
|
||||
// Take combined mask and plug it to blur
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_COMBINE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_INFILL,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Color Correct The Inpainted Result
|
||||
{
|
||||
source: {
|
||||
@ -683,7 +707,7 @@ export const buildCanvasOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_COMBINE,
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
@ -694,115 +718,6 @@ export const buildCanvasOutpaintGraph = (
|
||||
);
|
||||
}
|
||||
|
||||
// Handle Coherence Mode
|
||||
if (canvasCoherenceMode !== 'unmasked') {
|
||||
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||
type: 'create_denoise_mask',
|
||||
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
};
|
||||
|
||||
// Handle Image Input For Mask Creation
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: INPAINT_INFILL,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
|
||||
// Create Mask If Coherence Mode Is Mask
|
||||
if (canvasCoherenceMode === 'mask') {
|
||||
if (isUsingScaledDimensions) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MASK_COMBINE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (canvasCoherenceMode === 'edge') {
|
||||
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
|
||||
type: 'mask_edge',
|
||||
id: CANVAS_COHERENCE_MASK_EDGE,
|
||||
is_intermediate: true,
|
||||
edge_blur: maskBlur,
|
||||
edge_size: maskBlur * 2,
|
||||
low_threshold: 100,
|
||||
high_threshold: 200,
|
||||
};
|
||||
|
||||
// Handle Scaled Dimensions For Mask Edge
|
||||
if (isUsingScaledDimensions) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MASK_COMBINE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Plug Denoise Mask To Coherence Denoise Latents
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Handle Seed
|
||||
if (shouldRandomizeSeed) {
|
||||
// Random int node to generate the starting seed
|
||||
|
@ -67,8 +67,6 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
shouldAutoSave,
|
||||
} = state.canvas;
|
||||
|
||||
const fp32 = vaePrecision === 'fp32';
|
||||
|
||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||
boundingBoxScaleMethod
|
||||
);
|
||||
@ -135,7 +133,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
type: 'i2l',
|
||||
id: IMAGE_TO_LATENTS,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[SDXL_DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
@ -260,7 +258,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
id: LATENTS_TO_IMAGE,
|
||||
type: 'l2i',
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
};
|
||||
graph.nodes[CANVAS_OUTPUT] = {
|
||||
id: CANVAS_OUTPUT,
|
||||
@ -307,7 +305,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
||||
type: 'l2i',
|
||||
id: CANVAS_OUTPUT,
|
||||
is_intermediate: !shouldAutoSave,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
};
|
||||
|
||||
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image =
|
||||
|
@ -6,7 +6,6 @@ import {
|
||||
ImageBlurInvocation,
|
||||
ImageDTO,
|
||||
ImageToLatentsInvocation,
|
||||
MaskEdgeInvocation,
|
||||
NoiseInvocation,
|
||||
RandomIntInvocation,
|
||||
RangeOfSizeInvocation,
|
||||
@ -20,8 +19,6 @@ import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
CANVAS_COHERENCE_MASK_EDGE,
|
||||
CANVAS_COHERENCE_NOISE,
|
||||
CANVAS_COHERENCE_NOISE_INCREMENT,
|
||||
CANVAS_OUTPUT,
|
||||
@ -71,7 +68,6 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
shouldUseCpuNoise,
|
||||
maskBlur,
|
||||
maskBlurMethod,
|
||||
canvasCoherenceMode,
|
||||
canvasCoherenceSteps,
|
||||
canvasCoherenceStrength,
|
||||
seamlessXAxis,
|
||||
@ -100,12 +96,6 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
shouldAutoSave,
|
||||
} = state.canvas;
|
||||
|
||||
const fp32 = vaePrecision === 'fp32';
|
||||
|
||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||
boundingBoxScaleMethod
|
||||
);
|
||||
|
||||
let modelLoaderNodeId = SDXL_MODEL_LOADER;
|
||||
|
||||
const use_cpu = shouldUseNoiseSettings
|
||||
@ -147,7 +137,7 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
type: 'i2l',
|
||||
id: INPAINT_IMAGE,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[NOISE]: {
|
||||
type: 'noise',
|
||||
@ -159,7 +149,7 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
type: 'create_denoise_mask',
|
||||
id: INPAINT_CREATE_MASK,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[SDXL_DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
@ -187,7 +177,7 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
},
|
||||
[CANVAS_COHERENCE_DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
id: SDXL_DENOISE_LATENTS,
|
||||
is_intermediate: true,
|
||||
steps: canvasCoherenceSteps,
|
||||
cfg_scale: cfg_scale,
|
||||
@ -199,7 +189,7 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[CANVAS_OUTPUT]: {
|
||||
type: 'color_correct',
|
||||
@ -443,7 +433,7 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
};
|
||||
|
||||
// Handle Scale Before Processing
|
||||
if (isUsingScaledDimensions) {
|
||||
if (['auto', 'manual'].includes(boundingBoxScaleMethod)) {
|
||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||
const scaledHeight: number = scaledBoundingBoxDimensions.height;
|
||||
|
||||
@ -606,116 +596,6 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
);
|
||||
}
|
||||
|
||||
// Handle Coherence Mode
|
||||
if (canvasCoherenceMode !== 'unmasked') {
|
||||
// Create Mask If Coherence Mode Is Not Full
|
||||
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||
type: 'create_denoise_mask',
|
||||
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
};
|
||||
|
||||
// Handle Image Input For Mask Creation
|
||||
if (isUsingScaledDimensions) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||
...(graph.nodes[
|
||||
CANVAS_COHERENCE_INPAINT_CREATE_MASK
|
||||
] as CreateDenoiseMaskInvocation),
|
||||
image: canvasInitImage,
|
||||
};
|
||||
}
|
||||
|
||||
// Create Mask If Coherence Mode Is Mask
|
||||
if (canvasCoherenceMode === 'mask') {
|
||||
if (isUsingScaledDimensions) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||
...(graph.nodes[
|
||||
CANVAS_COHERENCE_INPAINT_CREATE_MASK
|
||||
] as CreateDenoiseMaskInvocation),
|
||||
mask: canvasMaskImage,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Create Mask Edge If Coherence Mode Is Edge
|
||||
if (canvasCoherenceMode === 'edge') {
|
||||
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
|
||||
type: 'mask_edge',
|
||||
id: CANVAS_COHERENCE_MASK_EDGE,
|
||||
is_intermediate: true,
|
||||
edge_blur: maskBlur,
|
||||
edge_size: maskBlur * 2,
|
||||
low_threshold: 100,
|
||||
high_threshold: 200,
|
||||
};
|
||||
|
||||
// Handle Scaled Dimensions For Mask Edge
|
||||
if (isUsingScaledDimensions) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
|
||||
...(graph.nodes[CANVAS_COHERENCE_MASK_EDGE] as MaskEdgeInvocation),
|
||||
image: canvasMaskImage,
|
||||
};
|
||||
}
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Plug Denoise Mask To Coherence Denoise Latents
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Handle Seed
|
||||
if (shouldRandomizeSeed) {
|
||||
// Random int node to generate the starting seed
|
||||
|
@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import {
|
||||
ImageBlurInvocation,
|
||||
ImageDTO,
|
||||
ImageToLatentsInvocation,
|
||||
InfillPatchMatchInvocation,
|
||||
@ -19,8 +20,6 @@ import { addVAEToGraph } from './addVAEToGraph';
|
||||
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
|
||||
import {
|
||||
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
CANVAS_COHERENCE_MASK_EDGE,
|
||||
CANVAS_COHERENCE_NOISE,
|
||||
CANVAS_COHERENCE_NOISE_INCREMENT,
|
||||
CANVAS_OUTPUT,
|
||||
@ -32,6 +31,7 @@ import {
|
||||
INPAINT_INFILL_RESIZE_DOWN,
|
||||
ITERATE,
|
||||
LATENTS_TO_IMAGE,
|
||||
MASK_BLUR,
|
||||
MASK_COMBINE,
|
||||
MASK_FROM_ALPHA,
|
||||
MASK_RESIZE_DOWN,
|
||||
@ -72,11 +72,10 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
shouldUseNoiseSettings,
|
||||
shouldUseCpuNoise,
|
||||
maskBlur,
|
||||
canvasCoherenceMode,
|
||||
maskBlurMethod,
|
||||
canvasCoherenceSteps,
|
||||
canvasCoherenceStrength,
|
||||
infillTileSize,
|
||||
infillPatchmatchDownscaleSize,
|
||||
tileSize,
|
||||
infillMethod,
|
||||
seamlessXAxis,
|
||||
seamlessYAxis,
|
||||
@ -104,12 +103,6 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
shouldAutoSave,
|
||||
} = state.canvas;
|
||||
|
||||
const fp32 = vaePrecision === 'fp32';
|
||||
|
||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(
|
||||
boundingBoxScaleMethod
|
||||
);
|
||||
|
||||
let modelLoaderNodeId = SDXL_MODEL_LOADER;
|
||||
|
||||
const use_cpu = shouldUseNoiseSettings
|
||||
@ -152,11 +145,18 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
is_intermediate: true,
|
||||
mask2: canvasMaskImage,
|
||||
},
|
||||
[MASK_BLUR]: {
|
||||
type: 'img_blur',
|
||||
id: MASK_BLUR,
|
||||
is_intermediate: true,
|
||||
radius: maskBlur,
|
||||
blur_type: maskBlurMethod,
|
||||
},
|
||||
[INPAINT_IMAGE]: {
|
||||
type: 'i2l',
|
||||
id: INPAINT_IMAGE,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[NOISE]: {
|
||||
type: 'noise',
|
||||
@ -168,7 +168,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
type: 'create_denoise_mask',
|
||||
id: INPAINT_CREATE_MASK,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[SDXL_DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
@ -208,7 +208,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[CANVAS_OUTPUT]: {
|
||||
type: 'color_correct',
|
||||
@ -348,7 +348,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
// Create Inpaint Mask
|
||||
{
|
||||
source: {
|
||||
node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE,
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
@ -410,7 +410,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
@ -458,16 +458,6 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_INFILL,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Decode inpainted latents to image
|
||||
{
|
||||
source: {
|
||||
@ -483,12 +473,12 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
};
|
||||
|
||||
// Add Infill Nodes
|
||||
|
||||
if (infillMethod === 'patchmatch') {
|
||||
graph.nodes[INPAINT_INFILL] = {
|
||||
type: 'infill_patchmatch',
|
||||
id: INPAINT_INFILL,
|
||||
is_intermediate: true,
|
||||
downscale: infillPatchmatchDownscaleSize,
|
||||
};
|
||||
}
|
||||
|
||||
@ -500,25 +490,17 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
};
|
||||
}
|
||||
|
||||
if (infillMethod === 'cv2') {
|
||||
graph.nodes[INPAINT_INFILL] = {
|
||||
type: 'infill_cv2',
|
||||
id: INPAINT_INFILL,
|
||||
is_intermediate: true,
|
||||
};
|
||||
}
|
||||
|
||||
if (infillMethod === 'tile') {
|
||||
graph.nodes[INPAINT_INFILL] = {
|
||||
type: 'infill_tile',
|
||||
id: INPAINT_INFILL,
|
||||
is_intermediate: true,
|
||||
tile_size: infillTileSize,
|
||||
tile_size: tileSize,
|
||||
};
|
||||
}
|
||||
|
||||
// Handle Scale Before Processing
|
||||
if (isUsingScaledDimensions) {
|
||||
if (['auto', 'manual'].includes(boundingBoxScaleMethod)) {
|
||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||
const scaledHeight: number = scaledBoundingBoxDimensions.height;
|
||||
|
||||
@ -580,7 +562,16 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_INFILL,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Take combined mask and resize and then blur
|
||||
{
|
||||
source: {
|
||||
@ -592,7 +583,16 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Resize Results Down
|
||||
{
|
||||
source: {
|
||||
@ -674,8 +674,32 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
...(graph.nodes[INPAINT_IMAGE] as ImageToLatentsInvocation),
|
||||
image: canvasInitImage,
|
||||
};
|
||||
graph.nodes[MASK_BLUR] = {
|
||||
...(graph.nodes[MASK_BLUR] as ImageBlurInvocation),
|
||||
};
|
||||
|
||||
graph.edges.push(
|
||||
// Take combined mask and plug it to blur
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_COMBINE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_INFILL,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Color Correct The Inpainted Result
|
||||
{
|
||||
source: {
|
||||
@ -699,7 +723,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: MASK_COMBINE,
|
||||
node_id: MASK_BLUR,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
@ -710,116 +734,7 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
);
|
||||
}
|
||||
|
||||
// Handle Coherence Mode
|
||||
if (canvasCoherenceMode !== 'unmasked') {
|
||||
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
|
||||
type: 'create_denoise_mask',
|
||||
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
is_intermediate: true,
|
||||
fp32,
|
||||
};
|
||||
|
||||
// Handle Image Input For Mask Creation
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: INPAINT_INFILL,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
|
||||
// Create Mask If Coherence Mode Is Mask
|
||||
if (canvasCoherenceMode === 'mask') {
|
||||
if (isUsingScaledDimensions) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MASK_COMBINE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (canvasCoherenceMode === 'edge') {
|
||||
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
|
||||
type: 'mask_edge',
|
||||
id: CANVAS_COHERENCE_MASK_EDGE,
|
||||
is_intermediate: true,
|
||||
edge_blur: maskBlur,
|
||||
edge_size: maskBlur * 2,
|
||||
low_threshold: 100,
|
||||
high_threshold: 200,
|
||||
};
|
||||
|
||||
// Handle Scaled Dimensions For Mask Edge
|
||||
if (isUsingScaledDimensions) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MASK_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: MASK_COMBINE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: CANVAS_COHERENCE_MASK_EDGE,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'mask',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Plug Denoise Mask To Coherence Denoise Latents
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
destination: {
|
||||
node_id: CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||
field: 'denoise_mask',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Handle Seed
|
||||
// Handle seed
|
||||
if (shouldRandomizeSeed) {
|
||||
// Random int node to generate the starting seed
|
||||
const randomIntNode: RandomIntInvocation = {
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user