Compare commits

...

81 Commits

Author SHA1 Message Date
2b1762d8da feat: add 'project' workflow category 2023-12-05 18:39:53 +11:00
b056a9c181 chore: bump ruff 2023-12-05 14:56:19 +11:00
f56c47b550 fix(db): remove extraneous lock 2023-12-05 14:49:00 +11:00
f8e35aec7b Merge remote-tracking branch 'origin/main' into feat/workflow-saving 2023-12-05 13:31:47 +11:00
0463541d99 dont set socketURL until socket is initialized (#5229)
* dont set socketURL until socket is initialized

* cleanup

* feat(ui): simplify `socketUrl` memo

no need to mutate the string; just return early if using baseUrl

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2023-12-04 21:01:49 +00:00
10cf10c16c fix(db): fix bug with releasing without lock in db.clean() 2023-12-05 06:48:53 +11:00
e45704833e if response for bulk download, dont close toast 2023-12-05 06:02:01 +11:00
0fdcc0af65 feat(nodes): add index and total to iterate output 2023-12-04 14:11:32 +11:00
4fc2ed7195 Added full-version endpoint (#5206)
* Added get_app_deps endpoint

* Use importlib.version & added deps
2023-12-04 02:57:39 +00:00
d0464a5793 Tiny grammar fix 2023-12-03 08:13:40 -08:00
7d4a78e470 fix(ui): fix circular dependency 2023-12-03 20:27:40 +11:00
37c87affd0 fix(ui): fix save/save-as workflow naming 2023-12-03 20:21:47 +11:00
3863bd9da3 fix(ui): fix workflow loading
- Different handling for loading from library vs external
- Fix bug where only nodes and edges loaded
2023-12-03 20:13:44 +11:00
4b2e3aa54d fix(ui): remove commented out property 2023-12-03 19:56:15 +11:00
d699efa5bc feat(ui): use custom hook in current image buttons
Already in use elsewhere, forgot to use it here.
2023-12-03 19:54:23 +11:00
b9a1374b8f feat(backend): workflow_records.get_many arg "filter_text" -> "query" 2023-12-03 19:22:16 +11:00
411ea75861 fix(backend): revert unrelated service organisational changes 2023-12-03 19:03:29 +11:00
375c9a1c20 fix: tidy up unused files, unrelated changes 2023-12-03 19:00:01 +11:00
907340b1e1 docs: update default workflows README 2023-12-03 18:54:42 +11:00
0f32d260b7 feat(ui): split out workflow redux state
The `nodes` slice is a rather complicated slice. Removing `workflow` makes it a bit more reasonable.

Also helps to flatten state out a bit.
2023-12-02 23:06:34 +11:00
92bc04dc87 feat(ui): tweak reset workflow editor translations 2023-12-02 21:49:50 +11:00
929b1f4a41 fix(db): fix mis-ordered db cleanup step
It was happening before pruning queue items - should happen afterwards, else you have to restart the app again to free disk space made available by the pruning.
2023-12-02 21:40:01 +11:00
6d7b4b8e8a feat(ui): clean up workflow library hooks 2023-12-02 21:01:18 +11:00
4a14ee0e01 fix(tests): fix tests 2023-12-02 20:00:07 +11:00
f268ea4e39 fix(workflow_records): typo 2023-12-02 19:54:15 +11:00
78face3481 feat(ui): refine workflow list UI 2023-12-02 19:52:17 +11:00
5a0e8261bf feat(workflows): update default workflows
- Update TextToImage_SD15
- Add TextToImage_SDXL
- Add README
2023-12-02 19:51:55 +11:00
0447fa2dcb Merge remote-tracking branch 'origin/main' into feat/workflow-saving 2023-12-02 13:55:19 +11:00
fb9b471150 feat(backend): move logic to clear latents to method 2023-12-01 17:44:07 -08:00
3f0e0af177 feat(backend): only log pruned queue items / db freed space if > 0 2023-12-01 17:44:07 -08:00
0228aba06f feat(backend): display freed space when cleaning DB 2023-12-01 17:44:07 -08:00
1fd6666682 feat(backend): clear latents files on startup
Adds logic to `DiskLatentsStorage.start()` to empty the latents folder on startup.

Adds start and stop methods to `ForwardCacheLatentsStorage`. This is required for `DiskLatentsStorage.start()` to be called, due to how this particular service breaks the direct DI pattern, wrapping the underlying storage with a cache.
2023-12-01 17:44:07 -08:00
4fd163698c feat: simplify default workflows
- Rename "system" -> "default"
- Simplify syncing logic
- Update UI to match
2023-12-02 11:41:05 +11:00
cff6600ded translationBot(ui): update translation (Italian)
Currently translated at 94.4% (1248 of 1321 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2023-12-02 07:45:14 +11:00
04ddcf53f3 Set minimum numpy version to ensure that np.testing.assert_array_equal() supports the 'strict' argument. 2023-12-01 07:30:47 -08:00
0539a64569 Add support for SDXL textual inversion/embeddings (#5213)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [X] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [X] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [X] Yes
- [ ] No


## Description

This adds support for at least some of the SDXL embeddings currently
available on Civitai. The embeddings I have tested include:

- https://civitai.com/models/154898/marblingtixl?modelVersionId=173668
- https://civitai.com/models/148131?modelVersionId=167640
-
https://civitai.com/models/123485/hannah-ferguson-or-sdxl-or-comfyui-only-or-embedding?modelVersionId=134674
(said to be "comfyui only")
-
https://civitai.com/models/185938/kendall-jenner-sdxl-embedding?modelVersionId=208785

I am _not entirely sure_ that I have implemented support in the most
elegant way. The issue is that these embeddings have two weight tensors,
`clip_g` and `clip_l`, which correspond to `text_encoder` and
`text_encoder_2` in the main model. When the patcher calls the
ModelPatcher's `apply_ti()` method, I simply check the dimensions of the
incoming text encoder and choose the weights that match the dimensions
of the encoder.

While writing this, I also ran into a possible issue with the Compel
library's `get_pooled_embeddings()` call. It pads the input token list
to the model's max token length and then calls the TI manager to add the
additional tokens from the embedding. However, this ends up making the
input token list longer than the max length, and CLIPTextEncoder crashes
with a tensor size mismatch. I worked around this behavior by making the
TI manager's `expand_textual_inversion_token_ids_if_necessary()` method
remove the excess pads at the end of the token list.

Also note that I have made similar changes to `apply_ti()` in the
ONNXModelPatcher, but haven't tested them yet.

## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #4401 

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [X] No : We need to create tests for model patching...

## [optional] Are there any post deployment tasks we need to perform?
2023-12-01 09:17:01 -05:00
224438a108 fix: merge conflicts 2023-12-01 23:05:52 +11:00
81d2d5abae Merge remote-tracking branch 'origin/main' into feat/workflow-saving 2023-12-01 23:03:01 +11:00
734e871e8f feat(backend): sync system workflows to db 2023-12-01 22:45:17 +11:00
b0350e9bc8 feat: workflow library - system graphs - wip 2023-12-01 19:36:38 +11:00
5a3f1f2b22 fix ruff github format errors 2023-12-01 01:59:26 -05:00
f95ce1870c fix ruff format check 2023-12-01 01:46:12 -05:00
0719a46372 add support for SDXL textual inversion/embeddings 2023-12-01 01:28:28 -05:00
0a25efd054 feat: workflow library WIP
- Save to library
- Duplicate
- Filter/sort
- UI/queries
2023-12-01 15:24:22 +11:00
a8ef4e5be8 fix(ui): fix types and storage prefix 2023-12-01 09:11:48 +11:00
e6fe2540b8 dynamically create indexedDB store using unique store key if available 2023-12-01 09:11:48 +11:00
aadcde3edd feat(ui): use IndexedDB for persistence
IndexedDB has a much larger storage limit than LocalStorage, and is widely supported.

Implemented as a custom storage driver for `redux-remember` via `idb-keyval`. `idb-keyval` is a simple wrapper for IndexedDB that allows it to be used easily as a key-value store.

The logic to clear persisted storage has been updated throughout the app.
2023-12-01 09:11:48 +11:00
984e609c61 (minor) Tweak field ordering and field names for tiling nodes. 2023-11-30 07:53:27 -08:00
57e70aaf50 Change input field ordering of CropLatentsCoreInvocation to match ImageCropInvocation. 2023-11-30 07:53:27 -08:00
bfdef120d1 Re-organize merge_tiles_with_linear_blending(...) to merge rows horizontally first and then vertically. This change achieves slightly more natural blending on the corners where 4 tiles overlap. 2023-11-30 07:53:27 -08:00
32da359ba5 Infer a tight-fitting output image size from the passed tiles in MergeTilesToImageInvocation. 2023-11-30 07:53:27 -08:00
b19ed36b43 Add width and height fields to TileToPropertiesInvocation output to avoid having to calculate them with math nodes. 2023-11-30 07:53:27 -08:00
e5a212b5c8 Update tiling nodes to use width-before-height field ordering convention. 2023-11-30 07:53:27 -08:00
9b863fb9bc Rename CropLatentsInvocation -> CropLatentsCoreInvocation to prevent conflict with custom node. And other minor tidying. 2023-11-30 07:53:27 -08:00
7cab51745b Improve documentation of CropLatentsInvocation. 2023-11-30 07:53:27 -08:00
18c6ff427e Use LATENT_SCALE_FACTOR = 8 constant in CropLatentsInvocation. 2023-11-30 07:53:27 -08:00
843f2d71d6 Copy CropLatentsInvocation from 74647fa9c1/images_to_grids.py (L1117C1-L1167C80). 2023-11-30 07:53:27 -08:00
67540c9ee0 (minor) Add 'Invocation' suffix to all tiling node classes. 2023-11-30 07:53:27 -08:00
7f816c9243 Tidy up tiles invocations, add documentation. 2023-11-30 07:53:27 -08:00
76b888de17 Add unit tests for merge_tiles_with_linear_blending(...). 2023-11-30 07:53:27 -08:00
65a16be299 Add unit tests for calc_tiles_with_overlap(...) and fix a bug in its implementation. 2023-11-30 07:53:27 -08:00
1c8ff0ae66 Add unit tests for tile paste(...) util function. 2023-11-30 07:53:27 -08:00
29eade4880 Add nodes for tile splitting and merging. The main motivation for these nodes is for use in tiled upscaling workflows. 2023-11-30 07:53:27 -08:00
86fd1d5b22 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2023-12-01 00:40:48 +11:00
909b78a1cb fix(ui): fix missing images not handled
- Reset init image, control adapter images, and node image fields when their selected image fails to load
- Only do this if the app is connected via socket (this indicates that the image is "really" gone, and there isn't just a transient network issue)

It's possible for image parameters/nodes/states to have reference a deleted image. For example, a resize image node might have an image set on it, and the workflow saved. The workflow contains a hard reference to that image.

The image is deleted and the workflow loaded again later. The deleted image is still in that workflow, but the app doesn't detect that. The result is that the workflow/graph appears to be valid, but will fail on invoke.

This creates a really confusing user experience, where when somebody shares a workflow with an image baked into it, and another person opens it, everything *looks* ok, but the workflow fails with a mysterious error about a missing image.

The problem affects node images, control adapter images and the img2img init image. Resetting the image when it fails to load *and* socket is connected resolves this in a simple way.

The problem also affects canvas images, but we have handle that by displaying an error fallback image, so no change is made there.
2023-12-01 00:35:06 +11:00
2f81f9fb22 fix(ui): add missing star image translation key 2023-12-01 00:33:04 +11:00
a6d4e4ed57 fix(ui): fix enum parsing for optional enums
Closes #5121

- Parse `anyOf` for enums (present when they are optional)
- Consolidate `FieldTypeParseError` and `UnsupportedFieldTypeError` into `FieldParseError` (there was no difference in handling and it simplifies things a bit)
2023-11-30 05:01:29 -08:00
46905175a9 wip 2023-11-30 19:19:58 +11:00
11085783ef feat(ui): workflow library pagination button styles 2023-11-30 16:13:41 +11:00
3d57c14bb3 feat(ui): add workflow loading, deleting to workflow library UI 2023-11-29 23:39:10 +11:00
18f3190857 chore(ui): typegen 2023-11-29 23:38:39 +11:00
fcc056fe6a feat(backend): add WorkflowRecordListItemDTO
This is the id, name, description, created at and updated at workflow columns/attrs. Used to display lists of workflowsl
2023-11-29 23:37:11 +11:00
c1bfc1f47b fix(nodes): fix get_workflow from queue item dict func 2023-11-29 15:36:28 +11:00
14bf87e5e7 feat(nodes): restore WithWorkflow as no-op class
This class is deprecated and no longer needed. Set its workflow attr value to None (meaning it is now a no-op), and issue a warning when an invocation subclasses it.
2023-11-29 13:32:24 +11:00
715ce8538b feat(nodes): do not overwrite custom node module names
Use a different, simpler method to detect if a node is custom.
2023-11-29 13:30:52 +11:00
1987bc9cc5 feat: revert SQLiteMigrator class
Will pursue this in a separate PR.
2023-11-29 13:29:19 +11:00
0b079df4ae feat(ui): updated workflow handling (WIP)
Clientside updates for the backend workflow changes.

Includes roughed-out workflow library UI.
2023-11-29 12:42:10 +11:00
a514c9e28b feat(backend): update workflows handling
Update workflows handling for Workflow Library.

**Updated Workflow Storage**

"Embedded Workflows" are workflows associated with images, and are now only stored in the image files. "Library Workflows" are not associated with images, and are stored only in DB.

This works out nicely. We have always saved workflows to files, but recently began saving them to the DB in addition to in image files. When that happened, we stopped reading workflows from files, so all the workflows that only existed in images were inaccessible. With this change, access to those workflows is restored, and no workflows are lost.

**Updated Workflow Handling in Nodes**

Prior to this change, workflows were embedded in images by passing the whole workflow JSON to a special workflow field on a node. In the node's `invoke()` function, the node was able to access this workflow and save it with the image. This (inaccurately) models workflows as a property of an image and is rather awkward technically.

A workflow is now a property of a batch/session queue item. It is available in the InvocationContext and therefore available to all nodes during `invoke()`.

**Database Migrations**

Added a `SQLiteMigrator` class to handle database migrations. Migrations were needed to accomodate the DB-related changes in this PR. See the code for details.

The `images`, `workflows` and `session_queue` tables required migrations for this PR, and are using the new migrator. Other tables/services are still creating tables themselves. A followup PR will adapt them to use the migrator.

**Other/Support Changes**

- Add a `has_workflow` column to `images` table to indicate that the image has an embedded workflow.
- Add handling for retrieving the workflow from an image in python. The image file must be fetched, the workflow extracted, and then sent to client, avoiding needing the browser to parse the image file. With the `has_workflow` column, the UI knows if there is a workflow to be fetched, and only fetches when the user requests to load the workflow.
- Add route to get the workflow from an image
- Add CRUD service/routes for the library workflows
- `workflow_images` table and services removed (no longer needed now that embedded workflows are not in the DB)
2023-11-29 12:42:10 +11:00
8cf2806489 fix(workflow_records): fix SQLite workflow insertion to ignore duplicates 2023-11-29 12:42:10 +11:00
eb446471cc fix(ui): exclude public/en.json from prettier config 2023-11-29 12:42:10 +11:00
7392d07331 chore: bump pydantic to 2.5.2
This release fixes pydantic/pydantic#8175 and allows us to use `JsonValue`
2023-11-29 12:42:10 +11:00
152 changed files with 6732 additions and 1336 deletions

View File

@ -120,7 +120,7 @@ Generate an image with a given prompt, record the seed of the image, and then
use the `prompt2prompt` syntax to substitute words in the original prompt for
words in a new prompt. This works for `img2img` as well.
For example, consider the prompt `a cat.swap(dog) playing with a ball in the forest`. Normally, because of the word words interact with each other when doing a stable diffusion image generation, these two prompts would generate different compositions:
For example, consider the prompt `a cat.swap(dog) playing with a ball in the forest`. Normally, because the words interact with each other when doing a stable diffusion image generation, these two prompts would generate different compositions:
- `a cat playing with a ball in the forest`
- `a dog playing with a ball in the forest`

View File

@ -2,7 +2,6 @@
from logging import Logger
from invokeai.app.services.workflow_image_records.workflow_image_records_sqlite import SqliteWorkflowImageRecordsStorage
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@ -30,7 +29,7 @@ from ..services.session_processor.session_processor_default import DefaultSessio
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.shared.default_graphs import create_system_graphs
from ..services.shared.graph import GraphExecutionState, LibraryGraph
from ..services.shared.sqlite import SqliteDatabase
from ..services.shared.sqlite.sqlite_database import SqliteDatabase
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
@ -94,7 +93,6 @@ class ApiDependencies:
session_processor = DefaultSessionProcessor()
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
workflow_image_records = SqliteWorkflowImageRecordsStorage(db=db)
workflow_records = SqliteWorkflowRecordsStorage(db=db)
services = InvocationServices(
@ -121,14 +119,12 @@ class ApiDependencies:
session_processor=session_processor,
session_queue=session_queue,
urls=urls,
workflow_image_records=workflow_image_records,
workflow_records=workflow_records,
)
create_system_graphs(services.graph_library)
ApiDependencies.invoker = Invoker(services)
db.clean()
@staticmethod

View File

@ -1,7 +1,11 @@
import typing
from enum import Enum
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from platform import python_version
from typing import Optional
import torch
from fastapi import Body
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
@ -40,6 +44,24 @@ class AppVersion(BaseModel):
version: str = Field(description="App version")
class AppDependencyVersions(BaseModel):
"""App depencency Versions Response"""
accelerate: str = Field(description="accelerate version")
compel: str = Field(description="compel version")
cuda: Optional[str] = Field(description="CUDA version")
diffusers: str = Field(description="diffusers version")
numpy: str = Field(description="Numpy version")
opencv: str = Field(description="OpenCV version")
onnx: str = Field(description="ONNX version")
pillow: str = Field(description="Pillow (PIL) version")
python: str = Field(description="Python version")
torch: str = Field(description="PyTorch version")
torchvision: str = Field(description="PyTorch Vision version")
transformers: str = Field(description="transformers version")
xformers: Optional[str] = Field(description="xformers version")
class AppConfig(BaseModel):
"""App Config Response"""
@ -54,6 +76,29 @@ async def get_version() -> AppVersion:
return AppVersion(version=__version__)
@app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=AppDependencyVersions)
async def get_app_deps() -> AppDependencyVersions:
try:
xformers = version("xformers")
except PackageNotFoundError:
xformers = None
return AppDependencyVersions(
accelerate=version("accelerate"),
compel=version("compel"),
cuda=torch.version.cuda,
diffusers=version("diffusers"),
numpy=version("numpy"),
opencv=version("opencv-python"),
onnx=version("onnx"),
pillow=version("pillow"),
python=python_version(),
torch=torch.version.__version__,
torchvision=version("torchvision"),
transformers=version("transformers"),
xformers=xformers,
)
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
async def get_config() -> AppConfig:
infill_methods = ["tile", "lama", "cv2"]

View File

@ -8,10 +8,11 @@ from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, Field, ValidationError
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator, WorkflowFieldValidator
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
from ..dependencies import ApiDependencies
@ -73,7 +74,7 @@ async def upload_image(
workflow_raw = pil_image.info.get("invokeai_workflow", None)
if workflow_raw is not None:
try:
workflow = WorkflowFieldValidator.validate_json(workflow_raw)
workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw)
except ValidationError:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
pass
@ -184,6 +185,18 @@ async def get_image_metadata(
raise HTTPException(status_code=404)
@images_router.get(
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=Optional[WorkflowWithoutID]
)
async def get_image_workflow(
image_name: str = Path(description="The name of image whose workflow to get"),
) -> Optional[WorkflowWithoutID]:
try:
return ApiDependencies.invoker.services.images.get_workflow(image_name)
except Exception:
raise HTTPException(status_code=404)
@images_router.api_route(
"/i/{image_name}/full",
methods=["GET", "HEAD"],

View File

@ -141,7 +141,7 @@ async def del_model_record(
status_code=201,
)
async def add_model_record(
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
) -> AnyModelConfig:
"""
Add a model using the configuration information appropriate for its type.

View File

@ -1,7 +1,19 @@
from fastapi import APIRouter, Path
from typing import Optional
from fastapi import APIRouter, Body, HTTPException, Path, Query
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.baseinvocation import WorkflowField
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.workflow_records.workflow_records_common import (
Workflow,
WorkflowCategory,
WorkflowNotFoundError,
WorkflowRecordDTO,
WorkflowRecordListItemDTO,
WorkflowRecordOrderBy,
WorkflowWithoutID,
)
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
@ -10,11 +22,76 @@ workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
"/i/{workflow_id}",
operation_id="get_workflow",
responses={
200: {"model": WorkflowField},
200: {"model": WorkflowRecordDTO},
},
)
async def get_workflow(
workflow_id: str = Path(description="The workflow to get"),
) -> WorkflowField:
) -> WorkflowRecordDTO:
"""Gets a workflow"""
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
try:
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
@workflows_router.patch(
"/i/{workflow_id}",
operation_id="update_workflow",
responses={
200: {"model": WorkflowRecordDTO},
},
)
async def update_workflow(
workflow: Workflow = Body(description="The updated workflow", embed=True),
) -> WorkflowRecordDTO:
"""Updates a workflow"""
return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow)
@workflows_router.delete(
"/i/{workflow_id}",
operation_id="delete_workflow",
)
async def delete_workflow(
workflow_id: str = Path(description="The workflow to delete"),
) -> None:
"""Deletes a workflow"""
ApiDependencies.invoker.services.workflow_records.delete(workflow_id)
@workflows_router.post(
"/",
operation_id="create_workflow",
responses={
200: {"model": WorkflowRecordDTO},
},
)
async def create_workflow(
workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True),
) -> WorkflowRecordDTO:
"""Creates a workflow"""
return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow)
@workflows_router.get(
"/",
operation_id="list_workflows",
responses={
200: {"model": PaginatedResults[WorkflowRecordListItemDTO]},
},
)
async def list_workflows(
page: int = Query(default=0, description="The page to get"),
per_page: int = Query(default=10, description="The number of workflows per page"),
order_by: WorkflowRecordOrderBy = Query(
default=WorkflowRecordOrderBy.Name, description="The attribute to order by"
),
direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
category: WorkflowCategory = Query(default=WorkflowCategory.User, description="The category of workflow to get"),
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets a page of workflows"""
return ApiDependencies.invoker.services.workflow_records.get_many(
page=page, per_page=per_page, order_by=order_by, direction=direction, query=query, category=category
)

View File

@ -16,6 +16,7 @@ from pydantic.fields import FieldInfo, _Unset
from pydantic_core import PydanticUndefined
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import uuid_string
@ -452,6 +453,7 @@ class InvocationContext:
queue_id: str
queue_item_id: int
queue_batch_id: str
workflow: Optional[WorkflowWithoutID]
def __init__(
self,
@ -460,12 +462,14 @@ class InvocationContext:
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
workflow: Optional[WorkflowWithoutID],
):
self.services = services
self.graph_execution_state_id = graph_execution_state_id
self.queue_id = queue_id
self.queue_item_id = queue_item_id
self.queue_batch_id = queue_batch_id
self.workflow = workflow
class BaseInvocationOutput(BaseModel):
@ -807,9 +811,9 @@ def invocation(
cls.UIConfig.category = category
# Grab the node pack's name from the module name, if it's a custom node
module_name = cls.__module__.split(".")[0]
if module_name.endswith(CUSTOM_NODE_PACK_SUFFIX):
cls.UIConfig.node_pack = module_name.split(CUSTOM_NODE_PACK_SUFFIX)[0]
is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
if is_custom_node:
cls.UIConfig.node_pack = cls.__module__.split(".")[0]
else:
cls.UIConfig.node_pack = None
@ -903,24 +907,6 @@ def invocation_output(
return wrapper
class WorkflowField(RootModel):
"""
Pydantic model for workflows with custom root of type dict[str, Any].
Workflows are stored without a strict schema.
"""
root: dict[str, Any] = Field(description="The workflow")
WorkflowFieldValidator = TypeAdapter(WorkflowField)
class WithWorkflow(BaseModel):
workflow: Optional[WorkflowField] = Field(
default=None, description=FieldDescriptions.workflow, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
)
class MetadataField(RootModel):
"""
Pydantic model for metadata with custom root of type dict[str, Any].
@ -943,3 +929,13 @@ class WithMetadata(BaseModel):
orig_required=False,
).model_dump(exclude_none=True),
)
class WithWorkflow:
workflow = None
def __init_subclass__(cls) -> None:
logger.warn(
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
)
super().__init_subclass__()

View File

@ -39,7 +39,6 @@ from .baseinvocation import (
InvocationContext,
OutputField,
WithMetadata,
WithWorkflow,
invocation,
invocation_output,
)
@ -129,7 +128,7 @@ class ControlNetInvocation(BaseInvocation):
# This invocation exists for other invocations to subclass it - do not register with @invocation!
class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
class ImageProcessorInvocation(BaseInvocation, WithMetadata):
"""Base class for invocations that preprocess images for ControlNet"""
image: ImageField = InputField(description="The image to process")
@ -153,7 +152,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
node_id=self.id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
"""Builds an ImageOutput and its ImageField"""
@ -173,7 +172,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
title="Canny Processor",
tags=["controlnet", "canny"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet"""
@ -196,7 +195,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
title="HED (softedge) Processor",
tags=["controlnet", "hed", "softedge"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class HedImageProcessorInvocation(ImageProcessorInvocation):
"""Applies HED edge detection to image"""
@ -225,7 +224,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
title="Lineart Processor",
tags=["controlnet", "lineart"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class LineartImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art processing to image"""
@ -247,7 +246,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
title="Lineart Anime Processor",
tags=["controlnet", "lineart", "anime"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies line art anime processing to image"""
@ -270,7 +269,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
title="Openpose Processor",
tags=["controlnet", "openpose", "pose"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Openpose processing to image"""
@ -295,7 +294,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
title="Midas Depth Processor",
tags=["controlnet", "midas"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image"""
@ -322,7 +321,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
title="Normal BAE Processor",
tags=["controlnet"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
"""Applies NormalBae processing to image"""
@ -339,7 +338,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
@invocation(
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.1.0"
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.0"
)
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
"""Applies MLSD processing to image"""
@ -362,7 +361,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
@invocation(
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.1.0"
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.0"
)
class PidiImageProcessorInvocation(ImageProcessorInvocation):
"""Applies PIDI processing to image"""
@ -389,7 +388,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
title="Content Shuffle Processor",
tags=["controlnet", "contentshuffle"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
"""Applies content shuffle processing to image"""
@ -419,7 +418,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
title="Zoe (Depth) Processor",
tags=["controlnet", "zoe", "depth"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
@ -435,7 +434,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
title="Mediapipe Face Processor",
tags=["controlnet", "mediapipe", "face"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image"""
@ -458,7 +457,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
title="Leres (Depth) Processor",
tags=["controlnet", "leres", "depth"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class LeresImageProcessorInvocation(ImageProcessorInvocation):
"""Applies leres processing to image"""
@ -487,7 +486,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
title="Tile Resample Processor",
tags=["controlnet", "tile"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
"""Tile resampler processor"""
@ -527,7 +526,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
title="Segment Anything Processor",
tags=["controlnet", "segmentanything"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image"""
@ -569,7 +568,7 @@ class SamDetectorReproducibleColors(SamDetector):
title="Color Map Processor",
tags=["controlnet"],
category="controlnet",
version="1.1.0",
version="1.2.0",
)
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
"""Generates a color map from the provided image"""

View File

@ -6,7 +6,6 @@ import sys
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from invokeai.app.invocations.baseinvocation import CUSTOM_NODE_PACK_SUFFIX
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
@ -34,7 +33,7 @@ for d in Path(__file__).parent.iterdir():
continue
# load the module, appending adding a suffix to identify it as a custom node pack
spec = spec_from_file_location(f"{module_name}{CUSTOM_NODE_PACK_SUFFIX}", init.absolute())
spec = spec_from_file_location(module_name, init.absolute())
if spec is None or spec.loader is None:
logger.warn(f"Could not load {init}")

View File

@ -8,11 +8,11 @@ from PIL import Image, ImageOps
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.1.0")
class CvInpaintInvocation(BaseInvocation, WithMetadata, WithWorkflow):
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0")
class CvInpaintInvocation(BaseInvocation, WithMetadata):
"""Simple inpaint using opencv."""
image: ImageField = InputField(description="The image to inpaint")
@ -41,7 +41,7 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata, WithWorkflow):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(

View File

@ -17,7 +17,6 @@ from invokeai.app.invocations.baseinvocation import (
InvocationContext,
OutputField,
WithMetadata,
WithWorkflow,
invocation,
invocation_output,
)
@ -438,8 +437,8 @@ def get_faces_list(
return all_faces
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.1.0")
class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.0")
class FaceOffInvocation(BaseInvocation, WithMetadata):
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
image: ImageField = InputField(description="Image for face detection")
@ -508,7 +507,7 @@ class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
workflow=context.workflow,
)
mask_dto = context.services.images.create(
@ -532,8 +531,8 @@ class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
return output
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.1.0")
class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.0")
class FaceMaskInvocation(BaseInvocation, WithMetadata):
"""Face mask creation using mediapipe face detection"""
image: ImageField = InputField(description="Image to face detect")
@ -627,7 +626,7 @@ class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
workflow=context.workflow,
)
mask_dto = context.services.images.create(
@ -650,9 +649,9 @@ class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.1.0"
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.0"
)
class FaceIdentifierInvocation(BaseInvocation, WithWorkflow, WithMetadata):
class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
image: ImageField = InputField(description="Image to face detect")
@ -716,7 +715,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithWorkflow, WithMetadata):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(

View File

@ -13,7 +13,7 @@ from invokeai.app.shared.fields import FieldDescriptions
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
from .baseinvocation import BaseInvocation, Input, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
from .baseinvocation import BaseInvocation, Input, InputField, InvocationContext, WithMetadata, invocation
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
@ -36,8 +36,14 @@ class ShowImageInvocation(BaseInvocation):
)
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.1.0")
class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
@invocation(
"blank_image",
title="Blank Image",
tags=["image"],
category="image",
version="1.2.0",
)
class BlankImageInvocation(BaseInvocation, WithMetadata):
"""Creates a blank image and forwards it to the pipeline"""
width: int = InputField(default=512, description="The width of the image")
@ -56,7 +62,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -66,8 +72,14 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
)
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.1.0")
class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_crop",
title="Crop Image",
tags=["image", "crop"],
category="image",
version="1.2.0",
)
class ImageCropInvocation(BaseInvocation, WithMetadata):
"""Crops an image to a specified box. The box can be outside of the image."""
image: ImageField = InputField(description="The image to crop")
@ -90,7 +102,7 @@ class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -101,11 +113,11 @@ class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
invocation_type="img_pad_crop",
title="Center Pad or Crop Image",
"img_paste",
title="Paste Image",
tags=["image", "paste"],
category="image",
tags=["image", "pad", "crop"],
version="1.0.0",
version="1.2.0",
)
class CenterPadCropInvocation(BaseInvocation):
"""Pad or crop an image's sides from the center by specified pixels. Positive values are outside of the image."""
@ -155,8 +167,14 @@ class CenterPadCropInvocation(BaseInvocation):
)
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.1.0")
class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
invocation_type="img_pad_crop",
title="Center Pad or Crop Image",
category="image",
tags=["image", "pad", "crop"],
version="1.0.0",
)
class ImagePasteInvocation(BaseInvocation, WithMetadata):
"""Pastes an image into another image."""
base_image: ImageField = InputField(description="The base image")
@ -199,7 +217,7 @@ class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -209,8 +227,14 @@ class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.1.0")
class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"tomask",
title="Mask from Alpha",
tags=["image", "mask"],
category="image",
version="1.2.0",
)
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata):
"""Extracts the alpha channel of an image as a mask."""
image: ImageField = InputField(description="The image to create the mask from")
@ -231,7 +255,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -241,8 +265,14 @@ class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.1.0")
class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_mul",
title="Multiply Images",
tags=["image", "multiply"],
category="image",
version="1.2.0",
)
class ImageMultiplyInvocation(BaseInvocation, WithMetadata):
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
image1: ImageField = InputField(description="The first image to multiply")
@ -262,7 +292,7 @@ class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -275,8 +305,14 @@ class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.1.0")
class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_chan",
title="Extract Image Channel",
tags=["image", "channel"],
category="image",
version="1.2.0",
)
class ImageChannelInvocation(BaseInvocation, WithMetadata):
"""Gets a channel from an image."""
image: ImageField = InputField(description="The image to get the channel from")
@ -295,7 +331,7 @@ class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -308,8 +344,14 @@ class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.1.0")
class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_conv",
title="Convert Image Mode",
tags=["image", "convert"],
category="image",
version="1.2.0",
)
class ImageConvertInvocation(BaseInvocation, WithMetadata):
"""Converts an image to a different mode."""
image: ImageField = InputField(description="The image to convert")
@ -328,7 +370,7 @@ class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -338,8 +380,14 @@ class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.1.0")
class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_blur",
title="Blur Image",
tags=["image", "blur"],
category="image",
version="1.2.0",
)
class ImageBlurInvocation(BaseInvocation, WithMetadata):
"""Blurs an image"""
image: ImageField = InputField(description="The image to blur")
@ -363,7 +411,7 @@ class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -393,8 +441,14 @@ PIL_RESAMPLING_MAP = {
}
@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.1.0")
class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
@invocation(
"img_resize",
title="Resize Image",
tags=["image", "resize"],
category="image",
version="1.2.0",
)
class ImageResizeInvocation(BaseInvocation, WithMetadata):
"""Resizes an image to specific dimensions"""
image: ImageField = InputField(description="The image to resize")
@ -420,7 +474,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -430,8 +484,14 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
)
@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.1.0")
class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
@invocation(
"img_scale",
title="Scale Image",
tags=["image", "scale"],
category="image",
version="1.2.0",
)
class ImageScaleInvocation(BaseInvocation, WithMetadata):
"""Scales an image by a factor"""
image: ImageField = InputField(description="The image to scale")
@ -462,7 +522,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -472,8 +532,14 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
)
@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.1.0")
class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_lerp",
title="Lerp Image",
tags=["image", "lerp"],
category="image",
version="1.2.0",
)
class ImageLerpInvocation(BaseInvocation, WithMetadata):
"""Linear interpolation of all pixels of an image"""
image: ImageField = InputField(description="The image to lerp")
@ -496,7 +562,7 @@ class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -506,8 +572,14 @@ class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.1.0")
class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_ilerp",
title="Inverse Lerp Image",
tags=["image", "ilerp"],
category="image",
version="1.2.0",
)
class ImageInverseLerpInvocation(BaseInvocation, WithMetadata):
"""Inverse linear interpolation of all pixels of an image"""
image: ImageField = InputField(description="The image to lerp")
@ -530,7 +602,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -540,8 +612,14 @@ class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.1.0")
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
@invocation(
"img_nsfw",
title="Blur NSFW Image",
tags=["image", "nsfw"],
category="image",
version="1.2.0",
)
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata):
"""Add blur to NSFW-flagged images"""
image: ImageField = InputField(description="The image to check")
@ -566,7 +644,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -587,9 +665,9 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
title="Add Invisible Watermark",
tags=["image", "watermark"],
category="image",
version="1.1.0",
version="1.2.0",
)
class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
class ImageWatermarkInvocation(BaseInvocation, WithMetadata):
"""Add an invisible watermark to an image"""
image: ImageField = InputField(description="The image to check")
@ -606,7 +684,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -616,8 +694,14 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
)
@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.1.0")
class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"mask_edge",
title="Mask Edge",
tags=["image", "mask", "inpaint"],
category="image",
version="1.2.0",
)
class MaskEdgeInvocation(BaseInvocation, WithMetadata):
"""Applies an edge mask to an image"""
image: ImageField = InputField(description="The image to apply the mask to")
@ -652,7 +736,7 @@ class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -667,9 +751,9 @@ class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
title="Combine Masks",
tags=["image", "mask", "multiply"],
category="image",
version="1.1.0",
version="1.2.0",
)
class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
class MaskCombineInvocation(BaseInvocation, WithMetadata):
"""Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
mask1: ImageField = InputField(description="The first mask to combine")
@ -689,7 +773,7 @@ class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -699,8 +783,14 @@ class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.1.0")
class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"color_correct",
title="Color Correct",
tags=["image", "color"],
category="image",
version="1.2.0",
)
class ColorCorrectInvocation(BaseInvocation, WithMetadata):
"""
Shifts the colors of a target image to match the reference image, optionally
using a mask to only color-correct certain regions of the target image.
@ -800,7 +890,7 @@ class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -810,8 +900,14 @@ class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.1.0")
class ImageHueAdjustmentInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"img_hue_adjust",
title="Adjust Image Hue",
tags=["image", "hue"],
category="image",
version="1.2.0",
)
class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata):
"""Adjusts the Hue of an image."""
image: ImageField = InputField(description="The image to adjust")
@ -840,7 +936,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithWorkflow, WithMetadata):
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -913,9 +1009,9 @@ CHANNEL_FORMATS = {
"value",
],
category="image",
version="1.1.0",
version="1.2.0",
)
class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata):
"""Add or subtract a value from a specific color channel of an image."""
image: ImageField = InputField(description="The image to adjust")
@ -950,7 +1046,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -984,9 +1080,9 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
"value",
],
category="image",
version="1.1.0",
version="1.2.0",
)
class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata):
"""Scale a specific color channel of an image."""
image: ImageField = InputField(description="The image to adjust")
@ -1025,7 +1121,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata)
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
workflow=self.workflow,
workflow=context.workflow,
metadata=self.metadata,
)
@ -1043,10 +1139,10 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata)
title="Save Image",
tags=["primitives", "image"],
category="primitives",
version="1.1.0",
version="1.2.0",
use_cache=False,
)
class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
class SaveImageInvocation(BaseInvocation, WithMetadata):
"""Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
image: ImageField = InputField(description=FieldDescriptions.image)
@ -1064,7 +1160,7 @@ class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -1082,7 +1178,7 @@ class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
version="1.0.1",
use_cache=False,
)
class LinearUIOutputInvocation(BaseInvocation, WithWorkflow, WithMetadata):
class LinearUIOutputInvocation(BaseInvocation, WithMetadata):
"""Handles Linear UI Image Outputting tasks."""
image: ImageField = InputField(description=FieldDescriptions.image)

View File

@ -13,7 +13,7 @@ from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
from invokeai.backend.image_util.lama import LaMA
from invokeai.backend.image_util.patchmatch import PatchMatch
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
@ -118,8 +118,8 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
return si
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
class InfillColorInvocation(BaseInvocation, WithMetadata):
"""Infills transparent areas of an image with a solid color"""
image: ImageField = InputField(description="The image to infill")
@ -144,7 +144,7 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -154,8 +154,8 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.1")
class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
class InfillTileInvocation(BaseInvocation, WithMetadata):
"""Infills transparent areas of an image with tiles of the image"""
image: ImageField = InputField(description="The image to infill")
@ -181,7 +181,7 @@ class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -192,9 +192,9 @@ class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation(
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0"
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0"
)
class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
"""Infills transparent areas of an image using the PatchMatch algorithm"""
image: ImageField = InputField(description="The image to infill")
@ -235,7 +235,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -245,8 +245,8 @@ class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
class LaMaInfillInvocation(BaseInvocation, WithMetadata):
"""Infills transparent areas of an image using the LaMa model"""
image: ImageField = InputField(description="The image to infill")
@ -264,7 +264,7 @@ class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -274,8 +274,8 @@ class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
)
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
class CV2InfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
class CV2InfillInvocation(BaseInvocation, WithMetadata):
"""Infills transparent areas of an image using OpenCV Inpainting"""
image: ImageField = InputField(description="The image to infill")
@ -293,7 +293,7 @@ class CV2InfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(

View File

@ -64,7 +64,6 @@ from .baseinvocation import (
OutputField,
UIType,
WithMetadata,
WithWorkflow,
invocation,
invocation_output,
)
@ -79,6 +78,12 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale
# factor is hard-coded to a literal '8' rather than using this constant.
# The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
LATENT_SCALE_FACTOR = 8
@invocation_output("scheduler_output")
class SchedulerOutput(BaseInvocationOutput):
@ -394,9 +399,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:
# assuming fixed dimensional scaling of 8:1 for image:latents
control_height_resize = latents_shape[2] * 8
control_width_resize = latents_shape[3] * 8
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
if control_input is None:
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
@ -796,9 +801,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.1.0",
version="1.2.0",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
class LatentsToImageInvocation(BaseInvocation, WithMetadata):
"""Generates an image from latents."""
latents: LatentsField = InputField(
@ -880,7 +885,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(
@ -909,12 +914,12 @@ class ResizeLatentsInvocation(BaseInvocation):
)
width: int = InputField(
ge=64,
multiple_of=8,
multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.width,
)
height: int = InputField(
ge=64,
multiple_of=8,
multiple_of=LATENT_SCALE_FACTOR,
description=FieldDescriptions.width,
)
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
@ -928,7 +933,7 @@ class ResizeLatentsInvocation(BaseInvocation):
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
size=(self.height // 8, self.width // 8),
size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
@ -1166,3 +1171,60 @@ class BlendLatentsInvocation(BaseInvocation):
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, blended_latents)
return build_latents_output(latents_name=name, latents=blended_latents)
# The Crop Latents node was copied from @skunkworxdark's implementation here:
# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80
@invocation(
"crop_latents",
title="Crop Latents",
tags=["latents", "crop"],
category="latents",
version="1.0.0",
)
# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
# Currently, if the class names conflict then 'GET /openapi.json' fails.
class CropLatentsCoreInvocation(BaseInvocation):
"""Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be
divisible by the latent scale factor of 8.
"""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
x: int = InputField(
ge=0,
multiple_of=LATENT_SCALE_FACTOR,
description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
y: int = InputField(
ge=0,
multiple_of=LATENT_SCALE_FACTOR,
description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
width: int = InputField(
ge=1,
multiple_of=LATENT_SCALE_FACTOR,
description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
height: int = InputField(
ge=1,
multiple_of=LATENT_SCALE_FACTOR,
description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
)
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
x1 = self.x // LATENT_SCALE_FACTOR
y1 = self.y // LATENT_SCALE_FACTOR
x2 = x1 + (self.width // LATENT_SCALE_FACTOR)
y2 = y1 + (self.height // LATENT_SCALE_FACTOR)
cropped_latents = latents[..., y1:y2, x1:x2]
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, cropped_latents)
return build_latents_output(latents_name=name, latents=cropped_latents)

View File

@ -31,7 +31,6 @@ from .baseinvocation import (
UIComponent,
UIType,
WithMetadata,
WithWorkflow,
invocation,
invocation_output,
)
@ -326,9 +325,9 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
title="ONNX Latents to Image",
tags=["latents", "image", "vae", "onnx"],
category="image",
version="1.1.0",
version="1.2.0",
)
class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
"""Generates an image from latents."""
latents: LatentsField = InputField(
@ -378,7 +377,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(

View File

@ -0,0 +1,180 @@
import numpy as np
from PIL import Image
from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InputField,
InvocationContext,
OutputField,
WithMetadata,
invocation,
invocation_output,
)
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.utils import Tile
class TileWithImage(BaseModel):
tile: Tile
image: ImageField
@invocation_output("calculate_image_tiles_output")
class CalculateImageTilesOutput(BaseInvocationOutput):
tiles: list[Tile] = OutputField(description="The tiles coordinates that cover a particular image shape.")
@invocation("calculate_image_tiles", title="Calculate Image Tiles", tags=["tiles"], category="tiles", version="1.0.0")
class CalculateImageTilesInvocation(BaseInvocation):
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
image_height: int = InputField(
ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
)
tile_width: int = InputField(ge=1, default=576, description="The tile width, in pixels.")
tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.")
overlap: int = InputField(
ge=0,
default=128,
description="The target overlap, in pixels, between adjacent tiles. Adjacent tiles will overlap by at least this amount",
)
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
tiles = calc_tiles_with_overlap(
image_height=self.image_height,
image_width=self.image_width,
tile_height=self.tile_height,
tile_width=self.tile_width,
overlap=self.overlap,
)
return CalculateImageTilesOutput(tiles=tiles)
@invocation_output("tile_to_properties_output")
class TileToPropertiesOutput(BaseInvocationOutput):
coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.")
coords_right: int = OutputField(description="Right coordinate of the tile relative to its parent image.")
coords_top: int = OutputField(description="Top coordinate of the tile relative to its parent image.")
coords_bottom: int = OutputField(description="Bottom coordinate of the tile relative to its parent image.")
# HACK: The width and height fields are 'meta' fields that can easily be calculated from the other fields on this
# object. Including redundant fields that can cheaply/easily be re-calculated goes against conventional API design
# principles. These fields are included, because 1) they are often useful in tiled workflows, and 2) they are
# difficult to calculate in a workflow (even though it's just a couple of subtraction nodes the graph gets
# surprisingly complicated).
width: int = OutputField(description="The width of the tile. Equal to coords_right - coords_left.")
height: int = OutputField(description="The height of the tile. Equal to coords_bottom - coords_top.")
overlap_top: int = OutputField(description="Overlap between this tile and its top neighbor.")
overlap_bottom: int = OutputField(description="Overlap between this tile and its bottom neighbor.")
overlap_left: int = OutputField(description="Overlap between this tile and its left neighbor.")
overlap_right: int = OutputField(description="Overlap between this tile and its right neighbor.")
@invocation("tile_to_properties", title="Tile to Properties", tags=["tiles"], category="tiles", version="1.0.0")
class TileToPropertiesInvocation(BaseInvocation):
"""Split a Tile into its individual properties."""
tile: Tile = InputField(description="The tile to split into properties.")
def invoke(self, context: InvocationContext) -> TileToPropertiesOutput:
return TileToPropertiesOutput(
coords_left=self.tile.coords.left,
coords_right=self.tile.coords.right,
coords_top=self.tile.coords.top,
coords_bottom=self.tile.coords.bottom,
width=self.tile.coords.right - self.tile.coords.left,
height=self.tile.coords.bottom - self.tile.coords.top,
overlap_top=self.tile.overlap.top,
overlap_bottom=self.tile.overlap.bottom,
overlap_left=self.tile.overlap.left,
overlap_right=self.tile.overlap.right,
)
@invocation_output("pair_tile_image_output")
class PairTileImageOutput(BaseInvocationOutput):
tile_with_image: TileWithImage = OutputField(description="A tile description with its corresponding image.")
@invocation("pair_tile_image", title="Pair Tile with Image", tags=["tiles"], category="tiles", version="1.0.0")
class PairTileImageInvocation(BaseInvocation):
"""Pair an image with its tile properties."""
# TODO(ryand): The only reason that PairTileImage is needed is because the iterate/collect nodes don't preserve
# order. Can this be fixed?
image: ImageField = InputField(description="The tile image.")
tile: Tile = InputField(description="The tile properties.")
def invoke(self, context: InvocationContext) -> PairTileImageOutput:
return PairTileImageOutput(
tile_with_image=TileWithImage(
tile=self.tile,
image=self.image,
)
)
@invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.1.0")
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
"""Merge multiple tile images into a single image."""
# Inputs
tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.")
blend_amount: int = InputField(
ge=0,
description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
images = [twi.image for twi in self.tiles_with_images]
tiles = [twi.tile for twi in self.tiles_with_images]
# Infer the output image dimensions from the max/min tile limits.
height = 0
width = 0
for tile in tiles:
height = max(height, tile.coords.bottom)
width = max(width, tile.coords.right)
# Get all tile images for processing.
# TODO(ryand): It pains me that we spend time PNG decoding each tile from disk when they almost certainly
# existed in memory at an earlier point in the graph.
tile_np_images: list[np.ndarray] = []
for image in images:
pil_image = context.services.images.get_pil_image(image.image_name)
pil_image = pil_image.convert("RGB")
tile_np_images.append(np.array(pil_image))
# Prepare the output image buffer.
# Check the first tile to determine how many image channels are expected in the output.
channels = tile_np_images[0].shape[-1]
dtype = tile_np_images[0].dtype
np_image = np.zeros(shape=(height, width, channels), dtype=dtype)
merge_tiles_with_linear_blending(
dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
)
pil_image = Image.fromarray(np_image)
image_dto = context.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=context.workflow,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -14,7 +14,7 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import choose_torch_device
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
# TODO: Populate this from disk?
# TODO: Use model manager to load?
@ -29,8 +29,8 @@ if choose_torch_device() == torch.device("mps"):
from torch import mps
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.2.0")
class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.0")
class ESRGANInvocation(BaseInvocation, WithMetadata):
"""Upscales an image using RealESRGAN."""
image: ImageField = InputField(description="The input image")
@ -118,7 +118,7 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=self.workflow,
workflow=context.workflow,
)
return ImageOutput(

View File

@ -4,7 +4,7 @@ from typing import Optional, cast
from invokeai.app.services.image_records.image_records_common import ImageRecord, deserialize_image_record
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .board_image_records_base import BoardImageRecordStorageBase

View File

@ -3,7 +3,7 @@ import threading
from typing import Union, cast
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.util.misc import uuid_string
from .board_records_base import BoardRecordStorageBase

View File

@ -4,7 +4,8 @@ from typing import Optional
from PIL.Image import Image as PILImageType
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageFileStorageBase(ABC):
@ -33,7 +34,7 @@ class ImageFileStorageBase(ABC):
image: PILImageType,
image_name: str,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None,
workflow: Optional[WorkflowWithoutID] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -43,3 +44,8 @@ class ImageFileStorageBase(ABC):
def delete(self, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
"""Gets the workflow of an image."""
pass

View File

@ -7,8 +7,9 @@ from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
from .image_files_base import ImageFileStorageBase
@ -56,7 +57,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image: PILImageType,
image_name: str,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None,
workflow: Optional[WorkflowWithoutID] = None,
thumbnail_size: int = 256,
) -> None:
try:
@ -64,12 +65,19 @@ class DiskImageFileStorage(ImageFileStorageBase):
image_path = self.get_path(image_name)
pnginfo = PngImagePlugin.PngInfo()
info_dict = {}
if metadata is not None:
pnginfo.add_text("invokeai_metadata", metadata.model_dump_json())
metadata_json = metadata.model_dump_json()
info_dict["invokeai_metadata"] = metadata_json
pnginfo.add_text("invokeai_metadata", metadata_json)
if workflow is not None:
pnginfo.add_text("invokeai_workflow", workflow.model_dump_json())
workflow_json = workflow.model_dump_json()
info_dict["invokeai_workflow"] = workflow_json
pnginfo.add_text("invokeai_workflow", workflow_json)
# When saving the image, the image object's info field is not populated. We need to set it
image.info = info_dict
image.save(
image_path,
"PNG",
@ -121,6 +129,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def get_workflow(self, image_name: str) -> WorkflowWithoutID | None:
image = self.get(image_name)
workflow = image.info.get("invokeai_workflow", None)
if workflow is not None:
return WorkflowWithoutID.model_validate_json(workflow)
return None
def __validate_storage_folders(self) -> None:
"""Checks if the required output folders exist and create them if they don't"""
folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]

View File

@ -75,6 +75,7 @@ class ImageRecordStorageBase(ABC):
image_category: ImageCategory,
width: int,
height: int,
has_workflow: bool,
is_intermediate: Optional[bool] = False,
starred: Optional[bool] = False,
session_id: Optional[str] = None,

View File

@ -100,6 +100,7 @@ IMAGE_DTO_COLS = ", ".join(
"height",
"session_id",
"node_id",
"has_workflow",
"is_intermediate",
"created_at",
"updated_at",
@ -145,6 +146,7 @@ class ImageRecord(BaseModelExcludeNull):
"""The node ID that generated this image, if it is a generated image."""
starred: bool = Field(description="Whether this image is starred.")
"""Whether this image is starred."""
has_workflow: bool = Field(description="Whether this image has a workflow.")
class ImageRecordChanges(BaseModelExcludeNull, extra="allow"):
@ -188,6 +190,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False)
starred = image_dict.get("starred", False)
has_workflow = image_dict.get("has_workflow", False)
return ImageRecord(
image_name=image_name,
@ -202,4 +205,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
deleted_at=deleted_at,
is_intermediate=is_intermediate,
starred=starred,
has_workflow=has_workflow,
)

View File

@ -5,7 +5,7 @@ from typing import Optional, Union, cast
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .image_records_base import ImageRecordStorageBase
from .image_records_common import (
@ -117,6 +117,16 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)
self._cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in self._cursor.fetchall()]
if "has_workflow" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images
ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;
"""
)
def get(self, image_name: str) -> ImageRecord:
try:
self._lock.acquire()
@ -408,6 +418,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_category: ImageCategory,
width: int,
height: int,
has_workflow: bool,
is_intermediate: Optional[bool] = False,
starred: Optional[bool] = False,
session_id: Optional[str] = None,
@ -429,9 +440,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id,
metadata,
is_intermediate,
starred
starred,
has_workflow
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
@ -444,6 +456,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
metadata_json,
is_intermediate,
starred,
has_workflow,
),
)
self._conn.commit()

View File

@ -3,7 +3,7 @@ from typing import Callable, Optional
from PIL.Image import Image as PILImageType
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecord,
@ -12,6 +12,7 @@ from invokeai.app.services.image_records.image_records_common import (
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageServiceABC(ABC):
@ -51,7 +52,7 @@ class ImageServiceABC(ABC):
board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None,
workflow: Optional[WorkflowWithoutID] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@ -85,6 +86,11 @@ class ImageServiceABC(ABC):
"""Gets an image's metadata."""
pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
"""Gets an image's workflow."""
pass
@abstractmethod
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's path."""

View File

@ -24,11 +24,6 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
default=None, description="The id of the board the image belongs to, if one exists."
)
"""The id of the board the image belongs to, if one exists."""
workflow_id: Optional[str] = Field(
default=None,
description="The workflow that generated this image.",
)
"""The workflow that generated this image."""
def image_record_to_dto(
@ -36,7 +31,6 @@ def image_record_to_dto(
image_url: str,
thumbnail_url: str,
board_id: Optional[str],
workflow_id: Optional[str],
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(
@ -44,5 +38,4 @@ def image_record_to_dto(
image_url=image_url,
thumbnail_url=thumbnail_url,
board_id=board_id,
workflow_id=workflow_id,
)

View File

@ -2,9 +2,10 @@ from typing import Optional
from PIL.Image import Image as PILImageType
from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from ..image_files.image_files_common import (
ImageFileDeleteException,
@ -42,7 +43,7 @@ class ImageService(ImageServiceABC):
board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None,
workflow: Optional[WorkflowWithoutID] = None,
) -> ImageDTO:
if image_origin not in ResourceOrigin:
raise InvalidOriginException
@ -55,12 +56,6 @@ class ImageService(ImageServiceABC):
(width, height) = image.size
try:
if workflow is not None:
created_workflow = self.__invoker.services.workflow_records.create(workflow)
workflow_id = created_workflow.model_dump()["id"]
else:
workflow_id = None
# TODO: Consider using a transaction here to ensure consistency between storage and database
self.__invoker.services.image_records.save(
# Non-nullable fields
@ -69,6 +64,7 @@ class ImageService(ImageServiceABC):
image_category=image_category,
width=width,
height=height,
has_workflow=workflow is not None,
# Meta fields
is_intermediate=is_intermediate,
# Nullable fields
@ -78,8 +74,6 @@ class ImageService(ImageServiceABC):
)
if board_id is not None:
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
if workflow_id is not None:
self.__invoker.services.workflow_image_records.create(workflow_id=workflow_id, image_name=image_name)
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow
)
@ -143,7 +137,6 @@ class ImageService(ImageServiceABC):
image_url=self.__invoker.services.urls.get_image_url(image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name),
)
return image_dto
@ -164,18 +157,15 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image DTO")
raise e
def get_workflow(self, image_name: str) -> Optional[WorkflowField]:
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
try:
workflow_id = self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name)
if workflow_id is None:
return None
return self.__invoker.services.workflow_records.get(workflow_id)
except ImageRecordNotFoundException:
self.__invoker.services.logger.error("Image record not found")
return self.__invoker.services.image_files.get_workflow(image_name)
except ImageFileNotFoundException:
self.__invoker.services.logger.error("Image file not found")
raise
except Exception:
self.__invoker.services.logger.error("Problem getting image workflow")
raise
except Exception as e:
self.__invoker.services.logger.error("Problem getting image DTO")
raise e
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
@ -223,7 +213,6 @@ class ImageService(ImageServiceABC):
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
)
for r in results.items
]

View File

@ -108,6 +108,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
queue_batch_id=queue_item.session_queue_batch_id,
workflow=queue_item.workflow,
)
)
@ -178,6 +179,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
session_queue_item_id=queue_item.session_queue_item_id,
session_queue_id=queue_item.session_queue_id,
graph_execution_state=graph_execution_state,
workflow=queue_item.workflow,
invoke_all=True,
)
except Exception as e:

View File

@ -1,9 +1,12 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import time
from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
@ -15,5 +18,6 @@ class InvocationQueueItem(BaseModel):
session_queue_batch_id: str = Field(
description="The ID of the session batch from which this invocation queue item came"
)
workflow: Optional[WorkflowWithoutID] = Field(description="The workflow associated with this queue item")
invoke_all: bool = Field(default=False)
timestamp: float = Field(default_factory=time.time)

View File

@ -28,7 +28,6 @@ if TYPE_CHECKING:
from .session_queue.session_queue_base import SessionQueueBase
from .shared.graph import GraphExecutionState, LibraryGraph
from .urls.urls_base import UrlServiceBase
from .workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
@ -59,7 +58,6 @@ class InvocationServices:
invocation_cache: "InvocationCacheBase"
names: "NameServiceBase"
urls: "UrlServiceBase"
workflow_image_records: "WorkflowImageRecordsStorageBase"
workflow_records: "WorkflowRecordsStorageBase"
def __init__(
@ -87,7 +85,6 @@ class InvocationServices:
invocation_cache: "InvocationCacheBase",
names: "NameServiceBase",
urls: "UrlServiceBase",
workflow_image_records: "WorkflowImageRecordsStorageBase",
workflow_records: "WorkflowRecordsStorageBase",
):
self.board_images = board_images
@ -113,5 +110,4 @@ class InvocationServices:
self.invocation_cache = invocation_cache
self.names = names
self.urls = urls
self.workflow_image_records = workflow_image_records
self.workflow_records = workflow_records

View File

@ -2,6 +2,8 @@
from typing import Optional
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from .invocation_queue.invocation_queue_common import InvocationQueueItem
from .invocation_services import InvocationServices
from .shared.graph import Graph, GraphExecutionState
@ -22,6 +24,7 @@ class Invoker:
session_queue_item_id: int,
session_queue_batch_id: str,
graph_execution_state: GraphExecutionState,
workflow: Optional[WorkflowWithoutID] = None,
invoke_all: bool = False,
) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed.
@ -43,6 +46,7 @@ class Invoker:
session_queue_batch_id=session_queue_batch_id,
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
workflow=workflow,
invoke_all=invoke_all,
)
)

View File

@ -5,7 +5,7 @@ from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, TypeAdapter
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .item_storage_base import ItemStorageABC

View File

@ -5,6 +5,8 @@ from typing import Union
import torch
from invokeai.app.services.invoker import Invoker
from .latents_storage_base import LatentsStorageBase
@ -17,6 +19,10 @@ class DiskLatentsStorage(LatentsStorageBase):
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder.mkdir(parents=True, exist_ok=True)
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
self._delete_all_latents()
def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name)
return torch.load(latent_path)
@ -32,3 +38,21 @@ class DiskLatentsStorage(LatentsStorageBase):
def get_path(self, name: str) -> Path:
return self.__output_folder / name
def _delete_all_latents(self) -> None:
"""
Deletes all latents from disk.
Must be called after we have access to `self._invoker` (e.g. in `start()`).
"""
deleted_latents_count = 0
freed_space = 0
for latents_file in Path(self.__output_folder).glob("*"):
if latents_file.is_file():
freed_space += latents_file.stat().st_size
deleted_latents_count += 1
latents_file.unlink()
if deleted_latents_count > 0:
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
self._invoker.services.logger.info(
f"Deleted {deleted_latents_count} latents files (freed {freed_space_in_mb}MB)"
)

View File

@ -5,6 +5,8 @@ from typing import Dict, Optional
import torch
from invokeai.app.services.invoker import Invoker
from .latents_storage_base import LatentsStorageBase
@ -23,6 +25,18 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
self.__cache_ids = Queue()
self.__max_cache_size = max_cache_size
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
start_op = getattr(self.__underlying_storage, "start", None)
if callable(start_op):
start_op(invoker)
def stop(self, invoker: Invoker) -> None:
self._invoker = invoker
stop_op = getattr(self.__underlying_storage, "stop", None)
if callable(stop_op):
stop_op(invoker)
def get(self, name: str) -> torch.Tensor:
cache_item = self.__get_cache(name)
if cache_item is not None:

View File

@ -52,7 +52,7 @@ from invokeai.backend.model_manager.config import (
ModelType,
)
from ..shared.sqlite import SqliteDatabase
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (
CONFIG_FILE_VERSION,
DuplicateModelException,

View File

@ -114,6 +114,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
session_queue_id=queue_item.queue_id,
session_queue_item_id=queue_item.item_id,
graph_execution_state=queue_item.session,
workflow=queue_item.workflow,
invoke_all=True,
)
queue_item = None

View File

@ -8,6 +8,10 @@ from pydantic_core import to_jsonable_python
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
from invokeai.app.services.workflow_records.workflow_records_common import (
WorkflowWithoutID,
WorkflowWithoutIDValidator,
)
from invokeai.app.util.misc import uuid_string
# region Errors
@ -66,6 +70,9 @@ class Batch(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
graph: Graph = Field(description="The graph to initialize the session with")
workflow: Optional[WorkflowWithoutID] = Field(
default=None, description="The workflow to initialize the session with"
)
runs: int = Field(
default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
)
@ -164,6 +171,14 @@ def get_session(queue_item_dict: dict) -> GraphExecutionState:
return session
def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
workflow_raw = queue_item_dict.get("workflow", None)
if workflow_raw is not None:
workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw, strict=False)
return workflow
return None
class SessionQueueItemWithoutGraph(BaseModel):
"""Session queue item without the full graph. Used for serialization."""
@ -213,12 +228,16 @@ class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
class SessionQueueItem(SessionQueueItemWithoutGraph):
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
workflow: Optional[WorkflowWithoutID] = Field(
default=None, description="The workflow associated with this queue item"
)
@classmethod
def queue_item_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
# must parse these manually
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
queue_item_dict["session"] = get_session(queue_item_dict)
queue_item_dict["workflow"] = get_workflow(queue_item_dict)
return SessionQueueItem(**queue_item_dict)
model_config = ConfigDict(
@ -334,7 +353,7 @@ def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) ->
def create_session_nfv_tuples(
batch: Batch, maximum: int
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue]], None, None]:
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue], Optional[WorkflowWithoutID]], None, None]:
"""
Create all graph permutations from the given batch data and graph. Yields tuples
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
@ -365,7 +384,7 @@ def create_session_nfv_tuples(
return
flat_node_field_values = list(chain.from_iterable(d))
graph = populate_graph(batch.graph, flat_node_field_values)
yield (GraphExecutionState(graph=graph), flat_node_field_values)
yield (GraphExecutionState(graph=graph), flat_node_field_values, batch.workflow)
count += 1
@ -391,12 +410,14 @@ def calc_session_count(batch: Batch) -> int:
class SessionQueueValueToInsert(NamedTuple):
"""A tuple of values to insert into the session_queue table"""
# Careful with the ordering of this - it must match the insert statement
queue_id: str # queue_id
session: str # session json
session_id: str # session_id
batch_id: str # batch_id
field_values: Optional[str] # field_values json
priority: int # priority
workflow: Optional[str] # workflow json
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
@ -404,7 +425,7 @@ ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
values_to_insert: ValuesToInsert = []
for session, field_values in create_session_nfv_tuples(batch, max_new_queue_items):
for session, field_values, workflow in create_session_nfv_tuples(batch, max_new_queue_items):
# sessions must have unique id
session.id = uuid_string()
values_to_insert.append(
@ -416,6 +437,7 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
# must use pydantic_encoder bc field_values is a list of models
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
priority, # priority
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
)
)
return values_to_insert

View File

@ -28,7 +28,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
prepare_values_to_insert,
)
from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteSessionQueue(SessionQueueBase):
@ -42,7 +42,8 @@ class SqliteSessionQueue(SessionQueueBase):
self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
@ -198,6 +199,15 @@ class SqliteSessionQueue(SessionQueueBase):
"""
)
self.__cursor.execute("PRAGMA table_info(session_queue)")
columns = [column[1] for column in self.__cursor.fetchall()]
if "workflow" not in columns:
self.__cursor.execute(
"""--sql
ALTER TABLE session_queue ADD COLUMN workflow TEXT;
"""
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
@ -280,8 +290,8 @@ class SqliteSessionQueue(SessionQueueBase):
self.__cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority)
VALUES (?, ?, ?, ?, ?, ?)
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)

View File

@ -207,10 +207,12 @@ class IterateInvocationOutput(BaseInvocationOutput):
item: Any = OutputField(
description="The item being iterated over", title="Collection Item", ui_type=UIType._CollectionItem
)
index: int = OutputField(description="The index of the item", title="Index")
total: int = OutputField(description="The total number of items", title="Total")
# TODO: Fill this out and move to invocations
@invocation("iterate", version="1.0.0")
@invocation("iterate", version="1.1.0")
class IterateInvocation(BaseInvocation):
"""Iterates over a list of items"""
@ -221,7 +223,7 @@ class IterateInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
"""Produces the outputs as values"""
return IterateInvocationOutput(item=self.collection[self.index])
return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection))
@invocation_output("collect_output")

View File

@ -1,48 +0,0 @@
import sqlite3
import threading
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
sqlite_memory = ":memory:"
class SqliteDatabase:
conn: sqlite3.Connection
lock: threading.RLock
_logger: Logger
_config: InvokeAIAppConfig
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
self._logger = logger
self._config = config
if self._config.use_memory_db:
location = sqlite_memory
logger.info("Using in-memory database")
else:
db_path = self._config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
location = str(db_path)
self._logger.info(f"Using database at {location}")
self.conn = sqlite3.connect(location, check_same_thread=False)
self.lock = threading.RLock()
self.conn.row_factory = sqlite3.Row
if self._config.log_sql:
self.conn.set_trace_callback(self._logger.debug)
self.conn.execute("PRAGMA foreign_keys = ON;")
def clean(self) -> None:
try:
self.lock.acquire()
self.conn.execute("VACUUM;")
self.conn.commit()
self._logger.info("Cleaned database")
except Exception as e:
self._logger.error(f"Error cleaning database: {e}")
raise e
finally:
self.lock.release()

View File

@ -0,0 +1,10 @@
from enum import Enum
from invokeai.app.util.metaenum import MetaEnum
sqlite_memory = ":memory:"
class SQLiteDirection(str, Enum, metaclass=MetaEnum):
Ascending = "ASC"
Descending = "DESC"

View File

@ -0,0 +1,47 @@
import sqlite3
import threading
from logging import Logger
from pathlib import Path
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
class SqliteDatabase:
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
self._logger = logger
self._config = config
if self._config.use_memory_db:
self.db_path = sqlite_memory
logger.info("Using in-memory database")
else:
db_path = self._config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
self.db_path = str(db_path)
self._logger.info(f"Using database at {self.db_path}")
self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
self.lock = threading.RLock()
self.conn.row_factory = sqlite3.Row
if self._config.log_sql:
self.conn.set_trace_callback(self._logger.debug)
self.conn.execute("PRAGMA foreign_keys = ON;")
def clean(self) -> None:
with self.lock:
try:
if self.db_path == sqlite_memory:
return
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self._logger.error(f"Error cleaning database: {e}")
raise

View File

@ -1,23 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional
class WorkflowImageRecordsStorageBase(ABC):
"""Abstract base class for the one-to-many workflow-image relationship record storage."""
@abstractmethod
def create(
self,
workflow_id: str,
image_name: str,
) -> None:
"""Creates a workflow-image record."""
pass
@abstractmethod
def get_workflow_for_image(
self,
image_name: str,
) -> Optional[str]:
"""Gets an image's workflow id, if it has one."""
pass

View File

@ -1,122 +0,0 @@
import sqlite3
import threading
from typing import Optional, cast
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
class SqliteWorkflowImageRecordsStorage(WorkflowImageRecordsStorageBase):
"""SQLite implementation of WorkflowImageRecordsStorageBase."""
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
# Create the `workflow_images` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS workflow_images (
workflow_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between workflows and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
)
# Add index for workflow id
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);
"""
)
# Add index for workflow id, sorted by created_at
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
AFTER UPDATE
ON workflow_images FOR EACH ROW
BEGIN
UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
END;
"""
)
def create(
self,
workflow_id: str,
image_name: str,
) -> None:
"""Creates a workflow-image record."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT INTO workflow_images (workflow_id, image_name)
VALUES (?, ?);
""",
(workflow_id, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_workflow_for_image(
self,
image_name: str,
) -> Optional[str]:
"""Gets an image's workflow id, if it has one."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT workflow_id
FROM workflow_images
WHERE image_name = ?;
""",
(image_name,),
)
result = self._cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()

View File

@ -0,0 +1,17 @@
# Default Workflows
Workflows placed in this directory will be synced to the `workflow_library` as
_default workflows_ on app startup.
- Default workflows are not editable by users. If they are loaded and saved,
they will save as a copy of the default workflow.
- Default workflows must have the `meta.category` property set to `"default"`.
An exception will be raised during sync if this is not set correctly.
- Default workflows appear on the "Default Workflows" tab of the Workflow
Library.
After adding or updating default workflows, you **must** start the app up and
load them to ensure:
- The workflow loads without warning or errors
- The workflow runs successfully

View File

@ -0,0 +1,798 @@
{
"name": "Text to Image - SD1.5",
"author": "InvokeAI",
"description": "Sample text to image workflow for Stable Diffusion 1.5/2",
"version": "1.1.0",
"contact": "invoke@invoke.ai",
"tags": "text2image, SD1.5, SD2, default",
"notes": "",
"exposedFields": [
{
"nodeId": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"fieldName": "model"
},
{
"nodeId": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
"fieldName": "prompt"
},
{
"nodeId": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"fieldName": "prompt"
},
{
"nodeId": "55705012-79b9-4aac-9f26-c0b10309785b",
"fieldName": "width"
},
{
"nodeId": "55705012-79b9-4aac-9f26-c0b10309785b",
"fieldName": "height"
}
],
"meta": {
"category": "default",
"version": "2.0.0"
},
"nodes": [
{
"id": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"type": "invocation",
"data": {
"id": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"type": "compel",
"label": "Negative Compel Prompt",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"prompt": {
"id": "7739aff6-26cb-4016-8897-5a1fb2305e4e",
"name": "prompt",
"fieldKind": "input",
"label": "Negative Prompt",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": ""
},
"clip": {
"id": "48d23dce-a6ae-472a-9f8c-22a714ea5ce0",
"name": "clip",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
},
"outputs": {
"conditioning": {
"id": "37cf3a9d-f6b7-4b64-8ff6-2558c5ecc447",
"name": "conditioning",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
}
}
},
"width": 320,
"height": 259,
"position": {
"x": 1000,
"y": 350
}
},
{
"id": "55705012-79b9-4aac-9f26-c0b10309785b",
"type": "invocation",
"data": {
"id": "55705012-79b9-4aac-9f26-c0b10309785b",
"type": "noise",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.1",
"nodePack": "invokeai",
"inputs": {
"seed": {
"id": "6431737c-918a-425d-a3b4-5d57e2f35d4d",
"name": "seed",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 0
},
"width": {
"id": "38fc5b66-fe6e-47c8-bba9-daf58e454ed7",
"name": "width",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 512
},
"height": {
"id": "16298330-e2bf-4872-a514-d6923df53cbb",
"name": "height",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 512
},
"use_cpu": {
"id": "c7c436d3-7a7a-4e76-91e4-c6deb271623c",
"name": "use_cpu",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": true
}
},
"outputs": {
"noise": {
"id": "50f650dc-0184-4e23-a927-0497a96fe954",
"name": "noise",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"width": {
"id": "bb8a452b-133d-42d1-ae4a-3843d7e4109a",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "35cfaa12-3b8b-4b7a-a884-327ff3abddd9",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 388,
"position": {
"x": 600,
"y": 325
}
},
{
"id": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"type": "invocation",
"data": {
"id": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"type": "main_model_loader",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"model": {
"id": "993eabd2-40fd-44fe-bce7-5d0c7075ddab",
"name": "model",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "MainModelField"
},
"value": {
"model_name": "stable-diffusion-v1-5",
"base_model": "sd-1",
"model_type": "main"
}
}
},
"outputs": {
"unet": {
"id": "5c18c9db-328d-46d0-8cb9-143391c410be",
"name": "unet",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
},
"clip": {
"id": "6effcac0-ec2f-4bf5-a49e-a2c29cf921f4",
"name": "clip",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
},
"vae": {
"id": "57683ba3-f5f5-4f58-b9a2-4b83dacad4a1",
"name": "vae",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "VaeField"
}
}
}
},
"width": 320,
"height": 226,
"position": {
"x": 600,
"y": 25
}
},
{
"id": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
"type": "invocation",
"data": {
"id": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
"type": "compel",
"label": "Positive Compel Prompt",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"prompt": {
"id": "7739aff6-26cb-4016-8897-5a1fb2305e4e",
"name": "prompt",
"fieldKind": "input",
"label": "Positive Prompt",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "StringField"
},
"value": "Super cute tiger cub, national geographic award-winning photograph"
},
"clip": {
"id": "48d23dce-a6ae-472a-9f8c-22a714ea5ce0",
"name": "clip",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ClipField"
}
}
},
"outputs": {
"conditioning": {
"id": "37cf3a9d-f6b7-4b64-8ff6-2558c5ecc447",
"name": "conditioning",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
}
}
},
"width": 320,
"height": 259,
"position": {
"x": 1000,
"y": 25
}
},
{
"id": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
"type": "invocation",
"data": {
"id": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
"type": "rand_int",
"label": "Random Seed",
"isOpen": false,
"notes": "",
"isIntermediate": true,
"useCache": false,
"version": "1.0.0",
"nodePack": "invokeai",
"inputs": {
"low": {
"id": "3ec65a37-60ba-4b6c-a0b2-553dd7a84b84",
"name": "low",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 0
},
"high": {
"id": "085f853a-1a5f-494d-8bec-e4ba29a3f2d1",
"name": "high",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 2147483647
}
},
"outputs": {
"value": {
"id": "812ade4d-7699-4261-b9fc-a6c9d2ab55ee",
"name": "value",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 32,
"position": {
"x": 600,
"y": 275
}
},
{
"id": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "invocation",
"data": {
"id": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "denoise_latents",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": true,
"useCache": true,
"version": "1.5.0",
"nodePack": "invokeai",
"inputs": {
"positive_conditioning": {
"id": "90b7f4f8-ada7-4028-8100-d2e54f192052",
"name": "positive_conditioning",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
},
"negative_conditioning": {
"id": "9393779e-796c-4f64-b740-902a1177bf53",
"name": "negative_conditioning",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ConditioningField"
}
},
"noise": {
"id": "8e17f1e5-4f98-40b1-b7f4-86aeeb4554c1",
"name": "noise",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"steps": {
"id": "9b63302d-6bd2-42c9-ac13-9b1afb51af88",
"name": "steps",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
},
"value": 50
},
"cfg_scale": {
"id": "87dd04d3-870e-49e1-98bf-af003a810109",
"name": "cfg_scale",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "FloatField"
},
"value": 7.5
},
"denoising_start": {
"id": "f369d80f-4931-4740-9bcd-9f0620719fab",
"name": "denoising_start",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0
},
"denoising_end": {
"id": "747d10e5-6f02-445c-994c-0604d814de8c",
"name": "denoising_end",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 1
},
"scheduler": {
"id": "1de84a4e-3a24-4ec8-862b-16ce49633b9b",
"name": "scheduler",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "SchedulerField"
},
"value": "unipc"
},
"unet": {
"id": "ffa6fef4-3ce2-4bdb-9296-9a834849489b",
"name": "unet",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "UNetField"
}
},
"control": {
"id": "077b64cb-34be-4fcc-83f2-e399807a02bd",
"name": "control",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "ControlField"
}
},
"ip_adapter": {
"id": "1d6948f7-3a65-4a65-a20c-768b287251aa",
"name": "ip_adapter",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "IPAdapterField"
}
},
"t2i_adapter": {
"id": "75e67b09-952f-4083-aaf4-6b804d690412",
"name": "t2i_adapter",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": true,
"name": "T2IAdapterField"
}
},
"cfg_rescale_multiplier": {
"id": "9101f0a6-5fe0-4826-b7b3-47e5d506826c",
"name": "cfg_rescale_multiplier",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "FloatField"
},
"value": 0
},
"latents": {
"id": "334d4ba3-5a99-4195-82c5-86fb3f4f7d43",
"name": "latents",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"denoise_mask": {
"id": "0d3dbdbf-b014-4e95-8b18-ff2ff9cb0bfa",
"name": "denoise_mask",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "DenoiseMaskField"
}
}
},
"outputs": {
"latents": {
"id": "70fa5bbc-0c38-41bb-861a-74d6d78d2f38",
"name": "latents",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"width": {
"id": "98ee0e6c-82aa-4e8f-8be5-dc5f00ee47f0",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "e8cb184a-5e1a-47c8-9695-4b8979564f5d",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 703,
"position": {
"x": 1400,
"y": 25
}
},
{
"id": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
"type": "invocation",
"data": {
"id": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
"type": "l2i",
"label": "",
"isOpen": true,
"notes": "",
"isIntermediate": false,
"useCache": true,
"version": "1.2.0",
"nodePack": "invokeai",
"inputs": {
"metadata": {
"id": "ab375f12-0042-4410-9182-29e30db82c85",
"name": "metadata",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "MetadataField"
}
},
"latents": {
"id": "3a7e7efd-bff5-47d7-9d48-615127afee78",
"name": "latents",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "LatentsField"
}
},
"vae": {
"id": "a1f5f7a1-0795-4d58-b036-7820c0b0ef2b",
"name": "vae",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "VaeField"
}
},
"tiled": {
"id": "da52059a-0cee-4668-942f-519aa794d739",
"name": "tiled",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": false
},
"fp32": {
"id": "c4841df3-b24e-4140-be3b-ccd454c2522c",
"name": "fp32",
"fieldKind": "input",
"label": "",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "BooleanField"
},
"value": true
}
},
"outputs": {
"image": {
"id": "72d667d0-cf85-459d-abf2-28bd8b823fe7",
"name": "image",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "ImageField"
}
},
"width": {
"id": "c8c907d8-1066-49d1-b9a6-83bdcd53addc",
"name": "width",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
},
"height": {
"id": "230f359c-b4ea-436c-b372-332d7dcdca85",
"name": "height",
"fieldKind": "output",
"type": {
"isCollection": false,
"isCollectionOrScalar": false,
"name": "IntegerField"
}
}
}
},
"width": 320,
"height": 266,
"position": {
"x": 1800,
"y": 25
}
}
],
"edges": [
{
"id": "reactflow__edge-ea94bc37-d995-4a83-aa99-4af42479f2f2value-55705012-79b9-4aac-9f26-c0b10309785bseed",
"source": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
"target": "55705012-79b9-4aac-9f26-c0b10309785b",
"type": "default",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8clip-7d8bf987-284f-413a-b2fd-d825445a5d6cclip",
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"target": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8clip-93dc02a4-d05b-48ed-b99c-c9b616af3402clip",
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"target": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"type": "default",
"sourceHandle": "clip",
"targetHandle": "clip"
},
{
"id": "reactflow__edge-55705012-79b9-4aac-9f26-c0b10309785bnoise-eea2702a-19fb-45b5-9d75-56b4211ec03cnoise",
"source": "55705012-79b9-4aac-9f26-c0b10309785b",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "default",
"sourceHandle": "noise",
"targetHandle": "noise"
},
{
"id": "reactflow__edge-7d8bf987-284f-413a-b2fd-d825445a5d6cconditioning-eea2702a-19fb-45b5-9d75-56b4211ec03cpositive_conditioning",
"source": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-93dc02a4-d05b-48ed-b99c-c9b616af3402conditioning-eea2702a-19fb-45b5-9d75-56b4211ec03cnegative_conditioning",
"source": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "default",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
},
{
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8unet-eea2702a-19fb-45b5-9d75-56b4211ec03cunet",
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"type": "default",
"sourceHandle": "unet",
"targetHandle": "unet"
},
{
"id": "reactflow__edge-eea2702a-19fb-45b5-9d75-56b4211ec03clatents-58c957f5-0d01-41fc-a803-b2bbf0413d4flatents",
"source": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
"target": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
"type": "default",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8vae-58c957f5-0d01-41fc-a803-b2bbf0413d4fvae",
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
"target": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
"type": "default",
"sourceHandle": "vae",
"targetHandle": "vae"
}
]
}

File diff suppressed because it is too large Load Diff

View File

@ -1,17 +1,50 @@
from abc import ABC, abstractmethod
from typing import Optional
from invokeai.app.invocations.baseinvocation import WorkflowField
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.workflow_records.workflow_records_common import (
Workflow,
WorkflowCategory,
WorkflowRecordDTO,
WorkflowRecordListItemDTO,
WorkflowRecordOrderBy,
WorkflowWithoutID,
)
class WorkflowRecordsStorageBase(ABC):
"""Base class for workflow storage services."""
@abstractmethod
def get(self, workflow_id: str) -> WorkflowField:
def get(self, workflow_id: str) -> WorkflowRecordDTO:
"""Get workflow by id."""
pass
@abstractmethod
def create(self, workflow: WorkflowField) -> WorkflowField:
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
"""Creates a workflow."""
pass
@abstractmethod
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
"""Updates a workflow."""
pass
@abstractmethod
def delete(self, workflow_id: str) -> None:
"""Deletes a workflow."""
pass
@abstractmethod
def get_many(
self,
page: int,
per_page: int,
order_by: WorkflowRecordOrderBy,
direction: SQLiteDirection,
category: WorkflowCategory,
query: Optional[str],
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets many workflows."""
pass

View File

@ -1,2 +1,104 @@
import datetime
from enum import Enum
from typing import Any, Union
import semver
from pydantic import BaseModel, Field, JsonValue, TypeAdapter, field_validator
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import uuid_string
__workflow_meta_version__ = semver.Version.parse("1.0.0")
class ExposedField(BaseModel):
nodeId: str
fieldName: str
class WorkflowNotFoundError(Exception):
"""Raised when a workflow is not found"""
class WorkflowRecordOrderBy(str, Enum, metaclass=MetaEnum):
"""The order by options for workflow records"""
CreatedAt = "created_at"
UpdatedAt = "updated_at"
OpenedAt = "opened_at"
Name = "name"
class WorkflowCategory(str, Enum, metaclass=MetaEnum):
User = "user"
Default = "default"
Project = "project"
class WorkflowMeta(BaseModel):
version: str = Field(description="The version of the workflow schema.")
category: WorkflowCategory = Field(description="The category of the workflow (user or default).")
@field_validator("version")
def validate_version(cls, version: str):
try:
semver.Version.parse(version)
return version
except Exception:
raise ValueError(f"Invalid workflow meta version: {version}")
def to_semver(self) -> semver.Version:
return semver.Version.parse(self.version)
class WorkflowWithoutID(BaseModel):
name: str = Field(description="The name of the workflow.")
author: str = Field(description="The author of the workflow.")
description: str = Field(description="The description of the workflow.")
version: str = Field(description="The version of the workflow.")
contact: str = Field(description="The contact of the workflow.")
tags: str = Field(description="The tags of the workflow.")
notes: str = Field(description="The notes of the workflow.")
exposedFields: list[ExposedField] = Field(description="The exposed fields of the workflow.")
meta: WorkflowMeta = Field(description="The meta of the workflow.")
# TODO: nodes and edges are very loosely typed
nodes: list[dict[str, JsonValue]] = Field(description="The nodes of the workflow.")
edges: list[dict[str, JsonValue]] = Field(description="The edges of the workflow.")
WorkflowWithoutIDValidator = TypeAdapter(WorkflowWithoutID)
class Workflow(WorkflowWithoutID):
id: str = Field(default_factory=uuid_string, description="The id of the workflow.")
WorkflowValidator = TypeAdapter(Workflow)
class WorkflowRecordDTOBase(BaseModel):
workflow_id: str = Field(description="The id of the workflow.")
name: str = Field(description="The name of the workflow.")
created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the workflow.")
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the workflow.")
opened_at: Union[datetime.datetime, str] = Field(description="The opened timestamp of the workflow.")
class WorkflowRecordDTO(WorkflowRecordDTOBase):
workflow: Workflow = Field(description="The workflow.")
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRecordDTO":
data["workflow"] = WorkflowValidator.validate_json(data.get("workflow", ""))
return WorkflowRecordDTOValidator.validate_python(data)
WorkflowRecordDTOValidator = TypeAdapter(WorkflowRecordDTO)
class WorkflowRecordListItemDTO(WorkflowRecordDTOBase):
description: str = Field(description="The description of the workflow.")
category: WorkflowCategory = Field(description="The description of the workflow.")
WorkflowRecordListItemDTOValidator = TypeAdapter(WorkflowRecordListItemDTO)

View File

@ -1,20 +1,25 @@
import sqlite3
import threading
from pathlib import Path
from typing import Optional
from invokeai.app.invocations.baseinvocation import WorkflowField, WorkflowFieldValidator
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowNotFoundError
from invokeai.app.util.misc import uuid_string
from invokeai.app.services.workflow_records.workflow_records_common import (
Workflow,
WorkflowCategory,
WorkflowNotFoundError,
WorkflowRecordDTO,
WorkflowRecordListItemDTO,
WorkflowRecordListItemDTOValidator,
WorkflowRecordOrderBy,
WorkflowValidator,
WorkflowWithoutID,
)
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
_invoker: Invoker
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
@ -24,14 +29,25 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
self._sync_default_workflows()
def get(self, workflow_id: str) -> WorkflowField:
def get(self, workflow_id: str) -> WorkflowRecordDTO:
"""Gets a workflow by ID. Updates the opened_at column."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT workflow
FROM workflows
UPDATE workflow_library
SET opened_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = ?;
""",
(workflow_id,),
)
self._conn.commit()
self._cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
@ -39,25 +55,28 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
row = self._cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowFieldValidator.validate_json(row[0])
return WorkflowRecordDTO.from_dict(dict(row))
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def create(self, workflow: WorkflowField) -> WorkflowField:
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
try:
# workflows do not have ids until they are saved
workflow_id = uuid_string()
workflow.root["id"] = workflow_id
# Only user workflows may be created by this method
assert workflow.meta.category is WorkflowCategory.User
workflow_with_id = WorkflowValidator.validate_python(workflow.model_dump())
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT INTO workflows(workflow)
VALUES (?);
INSERT OR IGNORE INTO workflow_library (
workflow_id,
workflow
)
VALUES (?, ?);
""",
(workflow.model_dump_json(),),
(workflow_with_id.id, workflow_with_id.model_dump_json()),
)
self._conn.commit()
except Exception:
@ -65,35 +84,231 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
raise
finally:
self._lock.release()
return self.get(workflow_id)
return self.get(workflow_with_id.id)
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
UPDATE workflow_library
SET workflow = ?
WHERE workflow_id = ? AND category = 'user';
""",
(workflow.model_dump_json(), workflow.id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(workflow.id)
def delete(self, workflow_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE from workflow_library
WHERE workflow_id = ? AND category = 'user';
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def get_many(
self,
page: int,
per_page: int,
order_by: WorkflowRecordOrderBy,
direction: SQLiteDirection,
category: WorkflowCategory,
query: Optional[str] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
try:
self._lock.acquire()
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
assert category in WorkflowCategory
count_query = "SELECT COUNT(*) FROM workflow_library WHERE category = ?"
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at
FROM workflow_library
WHERE category = ?
"""
main_params: list[int | str] = [category.value]
count_params: list[int | str] = [category.value]
stripped_query = query.strip() if query else None
if stripped_query:
wildcard_query = "%" + stripped_query + "%"
main_query += " AND name LIKE ? OR description LIKE ? "
count_query += " AND name LIKE ? OR description LIKE ?;"
main_params.extend([wildcard_query, wildcard_query])
count_params.extend([wildcard_query, wildcard_query])
main_query += f" ORDER BY {order_by.value} {direction.value} LIMIT ? OFFSET ?;"
main_params.extend([per_page, page * per_page])
self._cursor.execute(main_query, main_params)
rows = self._cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
self._cursor.execute(count_query, count_params)
total = self._cursor.fetchone()[0]
pages = int(total / per_page) + 1
return PaginatedResults(
items=workflows,
page=page,
per_page=per_page,
pages=pages,
total=total,
)
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def _sync_default_workflows(self) -> None:
"""Syncs default workflows to the database. Internal use only."""
"""
An enhancement might be to only update workflows that have changed. This would require stable
default workflow IDs, and properly incrementing the workflow version.
It's much simpler to just replace them all with whichever workflows are in the directory.
The downside is that the `updated_at` and `opened_at` timestamps for default workflows are
meaningless, as they are overwritten every time the server starts.
"""
try:
self._lock.acquire()
workflows: list[Workflow] = []
workflows_dir = Path(__file__).parent / Path("default_workflows")
workflow_paths = workflows_dir.glob("*.json")
for path in workflow_paths:
bytes_ = path.read_bytes()
workflow = WorkflowValidator.validate_json(bytes_)
workflows.append(workflow)
# Only default workflows may be managed by this method
assert all(w.meta.category is WorkflowCategory.Default for w in workflows)
self._cursor.execute(
"""--sql
DELETE FROM workflow_library
WHERE category = 'default';
"""
)
for w in workflows:
self._cursor.execute(
"""--sql
INSERT OR REPLACE INTO workflow_library (
workflow_id,
workflow
)
VALUES (?, ?);
""",
(w.id, w.model_dump_json()),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
def _create_tables(self) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS workflows (
CREATE TABLE IF NOT EXISTS workflow_library (
workflow_id TEXT NOT NULL PRIMARY KEY,
workflow TEXT NOT NULL,
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated manually when retrieving workflow
opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Generated columns, needed for indexing and searching
category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
);
"""
)
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
AFTER UPDATE
ON workflows FOR EACH ROW
ON workflow_library FOR EACH ROW
BEGIN
UPDATE workflows
UPDATE workflow_library
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = old.workflow_id;
END;
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);
"""
)
# We do not need the original `workflows` table or `workflow_images` junction table.
self._cursor.execute(
"""--sql
DROP TABLE IF EXISTS workflow_images;
"""
)
self._cursor.execute(
"""--sql
DROP TABLE IF EXISTS workflows;
"""
)
self._conn.commit()
except Exception:
self._conn.rollback()

View File

@ -192,20 +192,33 @@ class ModelPatcher:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
def _get_ti_embedding(model_embeddings, ti):
# for SDXL models, select the embedding that matches the text encoder's dimensions
if ti.embedding_2 is not None:
return (
ti.embedding_2
if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0]
else ti.embedding
)
else:
return ti.embedding
# modify tokenizer
new_tokens_added = 0
for ti_name, ti in ti_list:
for i in range(ti.embedding.shape[0]):
ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
for i in range(ti_embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
# modify text_encoder
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list:
for ti_name, _ in ti_list:
ti_tokens = []
for i in range(ti.embedding.shape[0]):
embedding = ti.embedding[i]
for i in range(ti_embedding.shape[0]):
embedding = ti_embedding[i]
trigger = _get_trigger(ti_name, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
@ -273,6 +286,7 @@ class ModelPatcher:
class TextualInversionModel:
embedding: torch.Tensor # [n, 768]|[n, 1280]
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
@classmethod
def from_checkpoint(
@ -296,8 +310,8 @@ class TextualInversionModel:
if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1:
print(
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first'
" token will be used."
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
" token will be used.",
)
result.embedding = next(iter(state_dict["string_to_param"].values()))
@ -306,6 +320,11 @@ class TextualInversionModel:
elif "emb_params" in state_dict:
result.embedding = state_dict["emb_params"]
# v5(sdxl safetensors file)
elif "clip_g" in state_dict and "clip_l" in state_dict:
result.embedding = state_dict["clip_g"]
result.embedding_2 = state_dict["clip_l"]
# v4(diffusers bin files)
else:
result.embedding = next(iter(state_dict.values()))
@ -342,6 +361,13 @@ class TextualInversionManager(BaseTextualInversionManager):
if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id])
# Do not exceed the max model input size
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
# which first removes and then adds back the start and end tokens.
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
if len(new_token_ids) > max_length:
new_token_ids = new_token_ids[0:max_length]
return new_token_ids
@ -490,24 +516,31 @@ class ONNXModelPatcher:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
# modify text_encoder
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
# modify tokenizer
new_tokens_added = 0
for ti_name, ti in ti_list:
for i in range(ti.embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
if ti.embedding_2 is not None:
ti_embedding = (
ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
)
else:
ti_embedding = ti.embedding
# modify text_encoder
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
for i in range(ti_embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
embeddings = np.concatenate(
(np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
axis=0,
)
for ti_name, ti in ti_list:
for ti_name, _ in ti_list:
ti_tokens = []
for i in range(ti.embedding.shape[0]):
embedding = ti.embedding[i].detach().numpy()
for i in range(ti_embedding.shape[0]):
embedding = ti_embedding[i].detach().numpy()
trigger = _get_trigger(ti_name, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)

View File

@ -373,12 +373,16 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
elif "emb_params" in checkpoint:
token_dim = checkpoint["emb_params"].shape[-1]
elif "clip_g" in checkpoint:
token_dim = checkpoint["clip_g"].shape[-1]
else:
token_dim = list(checkpoint.values())[0].shape[0]
if token_dim == 768:
return BaseModelType.StableDiffusion1
elif token_dim == 1024:
return BaseModelType.StableDiffusion2
elif token_dim == 1280:
return BaseModelType.StableDiffusionXL
else:
return None

View File

@ -11,7 +11,7 @@ from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordServiceSQL,
)
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,

View File

View File

@ -0,0 +1,201 @@
import math
from typing import Union
import numpy as np
from invokeai.backend.tiles.utils import TBLR, Tile, paste
def calc_tiles_with_overlap(
image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int = 0
) -> list[Tile]:
"""Calculate the tile coordinates for a given image shape under a simple tiling scheme with overlaps.
Args:
image_height (int): The image height in px.
image_width (int): The image width in px.
tile_height (int): The tile height in px. All tiles will have this height.
tile_width (int): The tile width in px. All tiles will have this width.
overlap (int, optional): The target overlap between adjacent tiles. If the tiles do not evenly cover the image
shape, then the last row/column of tiles will overlap more than this. Defaults to 0.
Returns:
list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom.
"""
assert image_height >= tile_height
assert image_width >= tile_width
assert overlap < tile_height
assert overlap < tile_width
non_overlap_per_tile_height = tile_height - overlap
non_overlap_per_tile_width = tile_width - overlap
num_tiles_y = math.ceil((image_height - overlap) / non_overlap_per_tile_height)
num_tiles_x = math.ceil((image_width - overlap) / non_overlap_per_tile_width)
# tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
tiles: list[Tile] = []
# Calculate tile coordinates. (Ignore overlap values for now.)
for tile_idx_y in range(num_tiles_y):
for tile_idx_x in range(num_tiles_x):
tile = Tile(
coords=TBLR(
top=tile_idx_y * non_overlap_per_tile_height,
bottom=tile_idx_y * non_overlap_per_tile_height + tile_height,
left=tile_idx_x * non_overlap_per_tile_width,
right=tile_idx_x * non_overlap_per_tile_width + tile_width,
),
overlap=TBLR(top=0, bottom=0, left=0, right=0),
)
if tile.coords.bottom > image_height:
# If this tile would go off the bottom of the image, shift it so that it is aligned with the bottom
# of the image.
tile.coords.bottom = image_height
tile.coords.top = image_height - tile_height
if tile.coords.right > image_width:
# If this tile would go off the right edge of the image, shift it so that it is aligned with the
# right edge of the image.
tile.coords.right = image_width
tile.coords.left = image_width - tile_width
tiles.append(tile)
def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]:
if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x:
return None
return tiles[idx_y * num_tiles_x + idx_x]
# Iterate over tiles again and calculate overlaps.
for tile_idx_y in range(num_tiles_y):
for tile_idx_x in range(num_tiles_x):
cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x)
top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x)
left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1)
assert cur_tile is not None
# Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap.
if top_neighbor_tile is not None:
cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top)
top_neighbor_tile.overlap.bottom = cur_tile.overlap.top
# Update cur_tile left-overlap and corresponding left-neighbor right-overlap.
if left_neighbor_tile is not None:
cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left)
left_neighbor_tile.overlap.right = cur_tile.overlap.left
return tiles
def merge_tiles_with_linear_blending(
dst_image: np.ndarray, tiles: list[Tile], tile_images: list[np.ndarray], blend_amount: int
):
"""Merge a set of image tiles into `dst_image` with linear blending between the tiles.
We expect every tile edge to either:
1) have an overlap of 0, because it is aligned with the image edge, or
2) have an overlap >= blend_amount.
If neither of these conditions are satisfied, we raise an exception.
The linear blending is centered at the halfway point of the overlap between adjacent tiles.
Args:
dst_image (np.ndarray): The destination image. Shape: (H, W, C).
tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`.
tile_images (list[np.ndarray]): The tile images to merge into `dst_image`.
blend_amount (int): The amount of blending (in px) between adjacent overlapping tiles.
"""
# Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to
# iterate over tiles left-to-right, top-to-bottom.
tiles_and_images = list(zip(tiles, tile_images, strict=True))
tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.left)
tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.top)
# Organize tiles into rows.
tile_and_image_rows: list[list[tuple[Tile, np.ndarray]]] = []
cur_tile_and_image_row: list[tuple[Tile, np.ndarray]] = []
first_tile_in_cur_row, _ = tiles_and_images[0]
for tile_and_image in tiles_and_images:
tile, _ = tile_and_image
if not (
tile.coords.top == first_tile_in_cur_row.coords.top
and tile.coords.bottom == first_tile_in_cur_row.coords.bottom
):
# Store the previous row, and start a new one.
tile_and_image_rows.append(cur_tile_and_image_row)
cur_tile_and_image_row = []
first_tile_in_cur_row, _ = tile_and_image
cur_tile_and_image_row.append(tile_and_image)
tile_and_image_rows.append(cur_tile_and_image_row)
# Prepare 1D linear gradients for blending.
gradient_left_x = np.linspace(start=0.0, stop=1.0, num=blend_amount)
gradient_top_y = np.linspace(start=0.0, stop=1.0, num=blend_amount)
# Convert shape: (blend_amount, ) -> (blend_amount, 1). The extra dimension enables the gradient to be applied
# to a 2D image via broadcasting. Note that no additional dimension is needed on gradient_left_x for
# broadcasting to work correctly.
gradient_top_y = np.expand_dims(gradient_top_y, axis=1)
for tile_and_image_row in tile_and_image_rows:
first_tile_in_row, _ = tile_and_image_row[0]
row_height = first_tile_in_row.coords.bottom - first_tile_in_row.coords.top
row_image = np.zeros((row_height, dst_image.shape[1], dst_image.shape[2]), dtype=dst_image.dtype)
# Blend the tiles in the row horizontally.
for tile, tile_image in tile_and_image_row:
# We expect the tiles to be ordered left-to-right. For each tile, we construct a mask that applies linear
# blending to the left of the current tile. The inverse linear blending is automatically applied to the
# right of the tiles that have already been pasted by the paste(...) operation.
tile_height, tile_width, _ = tile_image.shape
mask = np.ones(shape=(tile_height, tile_width), dtype=np.float64)
# Left blending:
if tile.overlap.left > 0:
assert tile.overlap.left >= blend_amount
# Center the blending gradient in the middle of the overlap.
blend_start_left = tile.overlap.left // 2 - blend_amount // 2
# The region left of the blending region is masked completely.
mask[:, :blend_start_left] = 0.0
# Apply the blend gradient to the mask.
mask[:, blend_start_left : blend_start_left + blend_amount] = gradient_left_x
# For visual debugging:
# tile_image[:, blend_start_left : blend_start_left + blend_amount] = 0
paste(
dst_image=row_image,
src_image=tile_image,
box=TBLR(
top=0, bottom=tile.coords.bottom - tile.coords.top, left=tile.coords.left, right=tile.coords.right
),
mask=mask,
)
# Blend the row into the dst_image vertically.
# We construct a mask that applies linear blending to the top of the current row. The inverse linear blending is
# automatically applied to the bottom of the tiles that have already been pasted by the paste(...) operation.
mask = np.ones(shape=(row_image.shape[0], row_image.shape[1]), dtype=np.float64)
# Top blending:
# (See comments under 'Left blending' for an explanation of the logic.)
# We assume that the entire row has the same vertical overlaps as the first_tile_in_row.
if first_tile_in_row.overlap.top > 0:
assert first_tile_in_row.overlap.top >= blend_amount
blend_start_top = first_tile_in_row.overlap.top // 2 - blend_amount // 2
mask[:blend_start_top, :] = 0.0
mask[blend_start_top : blend_start_top + blend_amount, :] = gradient_top_y
# For visual debugging:
# row_image[blend_start_top : blend_start_top + blend_amount, :] = 0
paste(
dst_image=dst_image,
src_image=row_image,
box=TBLR(
top=first_tile_in_row.coords.top,
bottom=first_tile_in_row.coords.bottom,
left=0,
right=row_image.shape[1],
),
mask=mask,
)

View File

@ -0,0 +1,47 @@
from typing import Optional
import numpy as np
from pydantic import BaseModel, Field
class TBLR(BaseModel):
top: int
bottom: int
left: int
right: int
def __eq__(self, other):
return (
self.top == other.top
and self.bottom == other.bottom
and self.left == other.left
and self.right == other.right
)
class Tile(BaseModel):
coords: TBLR = Field(description="The coordinates of this tile relative to its parent image.")
overlap: TBLR = Field(description="The amount of overlap with adjacent tiles on each side of this tile.")
def __eq__(self, other):
return self.coords == other.coords and self.overlap == other.overlap
def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optional[np.ndarray] = None):
"""Paste a source image into a destination image.
Args:
dst_image (torch.Tensor): The destination image to paste into. Shape: (H, W, C).
src_image (torch.Tensor): The source image to paste. Shape: (H, W, C). H and W must be compatible with 'box'.
box (TBLR): Box defining the region in the 'dst_image' where 'src_image' will be pasted.
mask (Optional[torch.Tensor]): A mask that defines the blending between 'src_image' and 'dst_image'.
Range: [0.0, 1.0], Shape: (H, W). The output is calculate per-pixel according to
`src * mask + dst * (1 - mask)`.
"""
if mask is None:
dst_image[box.top : box.bottom, box.left : box.right] = src_image
else:
mask = np.expand_dims(mask, -1)
dst_image_box = dst_image[box.top : box.bottom, box.left : box.right]
dst_image[box.top : box.bottom, box.left : box.right] = src_image * mask + dst_image_box * (1.0 - mask)

View File

@ -1,5 +1,6 @@
dist/
public/locales/*.json
!public/locales/en.json
.husky/
node_modules/
patches/

View File

@ -75,6 +75,7 @@
"framer-motion": "^10.16.4",
"i18next": "^23.6.0",
"i18next-http-backend": "^2.3.1",
"idb-keyval": "^6.2.1",
"konva": "^9.2.3",
"lodash-es": "^4.17.21",
"nanostores": "^0.9.4",

View File

@ -803,8 +803,7 @@
"canny": "Canny",
"hedDescription": "Ganzheitlich verschachtelte Kantenerkennung",
"scribble": "Scribble",
"maxFaces": "Maximal Anzahl Gesichter",
"unstarImage": "Markierung aufheben"
"maxFaces": "Maximal Anzahl Gesichter"
},
"queue": {
"status": "Status",

View File

@ -67,7 +67,9 @@
"controlNet": "ControlNet",
"controlAdapter": "Control Adapter",
"data": "Data",
"delete": "Delete",
"details": "Details",
"direction": "Direction",
"ipAdapter": "IP Adapter",
"t2iAdapter": "T2I Adapter",
"darkMode": "Dark Mode",
@ -115,6 +117,7 @@
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
"notInstalled": "Not $t(common.installed)",
"openInNewTab": "Open in New Tab",
"orderBy": "Order By",
"outpaint": "outpaint",
"outputs": "Outputs",
"postProcessDesc1": "Invoke AI offers a wide variety of post processing features. Image Upscaling and Face Restoration are already available in the WebUI. You can access them from the Advanced Options menu of the Text To Image and Image To Image tabs. You can also process images directly, using the image action buttons above the current image display or in the viewer.",
@ -125,6 +128,8 @@
"random": "Random",
"reportBugLabel": "Report Bug",
"safetensors": "Safetensors",
"save": "Save",
"saveAs": "Save As",
"settingsLabel": "Settings",
"simple": "Simple",
"somethingWentWrong": "Something went wrong",
@ -161,7 +166,11 @@
"txt2img": "Text To Image",
"unifiedCanvas": "Unified Canvas",
"unknown": "Unknown",
"upload": "Upload"
"upload": "Upload",
"updated": "Updated",
"created": "Created",
"prevPage": "Previous Page",
"nextPage": "Next Page"
},
"controlnet": {
"controlAdapter_one": "Control Adapter",
@ -243,7 +252,6 @@
"setControlImageDimensions": "Set Control Image Dimensions To W/H",
"showAdvanced": "Show Advanced",
"toggleControlNet": "Toggle this ControlNet",
"unstarImage": "Unstar Image",
"w": "W",
"weight": "Weight",
"enableIPAdapter": "Enable IP Adapter",
@ -378,6 +386,8 @@
"showGenerations": "Show Generations",
"showUploads": "Show Uploads",
"singleColumnLayout": "Single Column Layout",
"starImage": "Star Image",
"unstarImage": "Unstar Image",
"unableToLoad": "Unable to load Gallery",
"uploads": "Uploads",
"deleteSelection": "Delete Selection",
@ -936,9 +946,9 @@
"problemSettingTitle": "Problem Setting Title",
"reloadNodeTemplates": "Reload Node Templates",
"removeLinearView": "Remove from Linear View",
"resetWorkflow": "Reset Workflow",
"resetWorkflowDesc": "Are you sure you want to reset this workflow?",
"resetWorkflowDesc2": "Resetting the workflow will clear all nodes, edges and workflow details.",
"resetWorkflow": "Reset Workflow Editor",
"resetWorkflowDesc": "Are you sure you want to reset the Workflow Editor?",
"resetWorkflowDesc2": "Resetting the Workflow Editor will clear all nodes, edges and workflow details. Saved workflows will not be affected.",
"scheduler": "Scheduler",
"schedulerDescription": "TODO",
"sDXLMainModelField": "SDXL Model",
@ -978,6 +988,7 @@
"unsupportedAnyOfLength": "too many union members ({{count}})",
"unsupportedMismatchedUnion": "mismatched CollectionOrScalar type with base types {{firstType}} and {{secondType}}",
"unableToParseFieldType": "unable to parse field type",
"unableToExtractEnumOptions": "unable to extract enum options",
"uNetField": "UNet",
"uNetFieldDescription": "UNet submodel.",
"unhandledInputProperty": "Unhandled input property",
@ -1264,7 +1275,6 @@
"modelAddedSimple": "Model Added",
"modelAddFailed": "Model Add Failed",
"nodesBrokenConnections": "Cannot load. Some connections are broken.",
"nodesCleared": "Nodes Cleared",
"nodesCorruptedGraph": "Cannot load. Graph seems to be corrupted.",
"nodesLoaded": "Nodes Loaded",
"nodesLoadedFailed": "Failed To Load Nodes",
@ -1313,7 +1323,10 @@
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
"uploadFailedUnableToLoadDesc": "Unable to load file",
"upscalingFailed": "Upscaling Failed",
"workflowLoaded": "Workflow Loaded"
"workflowLoaded": "Workflow Loaded",
"problemRetrievingWorkflow": "Problem Retrieving Workflow",
"workflowDeleted": "Workflow Deleted",
"problemDeletingWorkflow": "Problem Deleting Workflow"
},
"tooltip": {
"feature": {
@ -1608,5 +1621,32 @@
"showIntermediates": "Show Intermediates",
"snapToGrid": "Snap to Grid",
"undo": "Undo"
},
"workflows": {
"workflows": "Workflows",
"workflowLibrary": "Workflow Library",
"userWorkflows": "My Workflows",
"defaultWorkflows": "Defaults",
"projectWorkflows": "Project",
"openWorkflow": "Open Workflow",
"uploadWorkflow": "Upload Workflow",
"deleteWorkflow": "Delete Workflow",
"unnamedWorkflow": "Unnamed Workflow",
"downloadWorkflow": "Download Workflow",
"saveWorkflow": "Save Workflow",
"saveWorkflowAs": "Save Workflow As",
"problemSavingWorkflow": "Problem Saving Workflow",
"workflowSaved": "Workflow Saved",
"noWorkflows": "No Workflows",
"problemLoading": "Problem Loading Workflows",
"loading": "Loading Workflows",
"noDescription": "No description",
"searchWorkflows": "Search Workflows",
"clearWorkflowSearchFilter": "Clear Workflow Search Filter",
"workflowName": "Workflow Name",
"workflowEditorReset": "Workflow Editor Reset"
},
"app": {
"storeNotInitialized": "Store is not initialized"
}
}

View File

@ -91,7 +91,19 @@
"controlNet": "ControlNet",
"auto": "Automatico",
"simple": "Semplice",
"details": "Dettagli"
"details": "Dettagli",
"format": "formato",
"unknown": "Sconosciuto",
"folder": "Cartella",
"error": "Errore",
"installed": "Installato",
"template": "Schema",
"outputs": "Uscite",
"data": "Dati",
"somethingWentWrong": "Qualcosa è andato storto",
"copyError": "$t(gallery.copy) Errore",
"input": "Ingresso",
"notInstalled": "Non $t(common.installed)"
},
"gallery": {
"generations": "Generazioni",
@ -122,7 +134,14 @@
"preparingDownload": "Preparazione del download",
"preparingDownloadFailed": "Problema durante la preparazione del download",
"downloadSelection": "Scarica gli elementi selezionati",
"noImageSelected": "Nessuna immagine selezionata"
"noImageSelected": "Nessuna immagine selezionata",
"deleteSelection": "Elimina la selezione",
"image": "immagine",
"drop": "Rilascia",
"unstarImage": "Rimuovi preferenza immagine",
"dropOrUpload": "$t(gallery.drop) o carica",
"starImage": "Immagine preferita",
"dropToUpload": "$t(gallery.drop) per aggiornare"
},
"hotkeys": {
"keyboardShortcuts": "Tasti rapidi",
@ -477,7 +496,8 @@
"modelType": "Tipo di modello",
"customConfigFileLocation": "Posizione del file di configurazione personalizzato",
"vaePrecision": "Precisione VAE",
"noModelSelected": "Nessun modello selezionato"
"noModelSelected": "Nessun modello selezionato",
"conversionNotSupported": "Conversione non supportata"
},
"parameters": {
"images": "Immagini",
@ -838,7 +858,8 @@
"menu": "Menu",
"showGalleryPanel": "Mostra il pannello Galleria",
"loadMore": "Carica altro",
"mode": "Modalità"
"mode": "Modalità",
"resetUI": "$t(accessibility.reset) l'Interfaccia Utente"
},
"ui": {
"hideProgressImages": "Nascondi avanzamento immagini",
@ -1040,7 +1061,15 @@
"updateAllNodes": "Aggiorna tutti i nodi",
"unableToUpdateNodes_one": "Impossibile aggiornare {{count}} nodo",
"unableToUpdateNodes_many": "Impossibile aggiornare {{count}} nodi",
"unableToUpdateNodes_other": "Impossibile aggiornare {{count}} nodi"
"unableToUpdateNodes_other": "Impossibile aggiornare {{count}} nodi",
"addLinearView": "Aggiungi alla vista Lineare",
"outputFieldInInput": "Campo di uscita in ingresso",
"unableToMigrateWorkflow": "Impossibile migrare il flusso di lavoro",
"unableToUpdateNode": "Impossibile aggiornare nodo",
"unknownErrorValidatingWorkflow": "Errore sconosciuto durante la convalida del flusso di lavoro",
"collectionFieldType": "{{name}} Raccolta",
"collectionOrScalarFieldType": "{{name}} Raccolta|Scalare",
"nodeVersion": "Versione Nodo"
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
@ -1062,7 +1091,10 @@
"deleteBoardOnly": "Elimina solo la Bacheca",
"deleteBoard": "Elimina Bacheca",
"deleteBoardAndImages": "Elimina Bacheca e Immagini",
"deletedBoardsCannotbeRestored": "Le bacheche eliminate non possono essere ripristinate"
"deletedBoardsCannotbeRestored": "Le bacheche eliminate non possono essere ripristinate",
"movingImagesToBoard_one": "Spostare {{count}} immagine nella bacheca:",
"movingImagesToBoard_many": "Spostare {{count}} immagini nella bacheca:",
"movingImagesToBoard_other": "Spostare {{count}} immagini nella bacheca:"
},
"controlnet": {
"contentShuffleDescription": "Rimescola il contenuto di un'immagine",
@ -1136,7 +1168,8 @@
"megaControl": "Mega ControlNet",
"minConfidence": "Confidenza minima",
"scribble": "Scribble",
"amult": "Angolo di illuminazione"
"amult": "Angolo di illuminazione",
"coarse": "Approssimativo"
},
"queue": {
"queueFront": "Aggiungi all'inizio della coda",
@ -1204,7 +1237,8 @@
"embedding": {
"noMatchingEmbedding": "Nessun Incorporamento corrispondente",
"addEmbedding": "Aggiungi Incorporamento",
"incompatibleModel": "Modello base incompatibile:"
"incompatibleModel": "Modello base incompatibile:",
"noEmbeddingsLoaded": "Nessun incorporamento caricato"
},
"models": {
"noMatchingModels": "Nessun modello corrispondente",
@ -1217,7 +1251,8 @@
"noRefinerModelsInstalled": "Nessun modello SDXL Refiner installato",
"noLoRAsInstalled": "Nessun LoRA installato",
"esrganModel": "Modello ESRGAN",
"addLora": "Aggiungi LoRA"
"addLora": "Aggiungi LoRA",
"noLoRAsLoaded": "Nessuna LoRA caricata"
},
"invocationCache": {
"disable": "Disabilita",
@ -1233,7 +1268,8 @@
"enable": "Abilita",
"clear": "Svuota",
"maxCacheSize": "Dimensione max cache",
"cacheSize": "Dimensione cache"
"cacheSize": "Dimensione cache",
"useCache": "Usa Cache"
},
"dynamicPrompts": {
"seedBehaviour": {

View File

@ -1137,8 +1137,7 @@
"openPose": "Openpose",
"controlAdapter_other": "Control Adapters",
"lineartAnime": "Lineart Anime",
"canny": "Canny",
"unstarImage": "取消收藏图像"
"canny": "Canny"
},
"queue": {
"status": "状态",

View File

@ -21,6 +21,7 @@ import GlobalHotkeys from './GlobalHotkeys';
import PreselectedImage from './PreselectedImage';
import Toaster from './Toaster';
import { useSocketIO } from 'app/hooks/useSocketIO';
import { useClearStorage } from 'common/hooks/useClearStorage';
const DEFAULT_CONFIG = {};
@ -36,15 +37,16 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
const language = useAppSelector(languageSelector);
const logger = useLogger('system');
const dispatch = useAppDispatch();
const clearStorage = useClearStorage();
// singleton!
useSocketIO();
const handleReset = useCallback(() => {
localStorage.clear();
clearStorage();
location.reload();
return false;
}, []);
}, [clearStorage]);
useEffect(() => {
i18n.changeLanguage(language);

View File

@ -7,21 +7,23 @@ import { $headerComponent } from 'app/store/nanostores/headerComponent';
import { $isDebugging } from 'app/store/nanostores/isDebugging';
import { $projectId } from 'app/store/nanostores/projectId';
import { $queueId, DEFAULT_QUEUE_ID } from 'app/store/nanostores/queueId';
import { store } from 'app/store/store';
import { $store } from 'app/store/nanostores/store';
import { createStore } from 'app/store/store';
import { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import AppDndContext from 'features/dnd/components/AppDndContext';
import 'i18n';
import React, {
PropsWithChildren,
ReactNode,
lazy,
memo,
useEffect,
useMemo,
} from 'react';
import { Provider } from 'react-redux';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
import { ManagerOptions, SocketOptions } from 'socket.io-client';
import Loading from 'common/components/Loading/Loading';
import AppDndContext from 'features/dnd/components/AppDndContext';
import 'i18n';
const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
@ -137,6 +139,14 @@ const InvokeAIUI = ({
};
}, [isDebugging]);
const store = useMemo(() => {
return createStore(projectId);
}, [projectId]);
useEffect(() => {
$store.set(store);
}, [store]);
return (
<React.StrictMode>
<Provider store={store}>

View File

@ -9,9 +9,9 @@ import { TOAST_OPTIONS, theme as invokeAITheme } from 'theme/theme';
import '@fontsource-variable/inter';
import { MantineProvider } from '@mantine/core';
import { useMantineTheme } from 'mantine-theme/theme';
import 'overlayscrollbars/overlayscrollbars.css';
import 'theme/css/overlayscrollbars.css';
import { useMantineTheme } from 'mantine-theme/theme';
type ThemeLocaleProviderProps = {
children: ReactNode;

View File

@ -3,8 +3,8 @@ import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $isDebugging } from 'app/store/nanostores/isDebugging';
import { useAppDispatch } from 'app/store/storeHooks';
import { MapStore, WritableAtom, atom, map } from 'nanostores';
import { useEffect } from 'react';
import { MapStore, atom, map } from 'nanostores';
import { useEffect, useMemo } from 'react';
import {
ClientToServerEvents,
ServerToClientEvents,
@ -16,57 +16,10 @@ import { ManagerOptions, Socket, SocketOptions, io } from 'socket.io-client';
declare global {
interface Window {
$socketOptions?: MapStore<Partial<ManagerOptions & SocketOptions>>;
$socketUrl?: WritableAtom<string>;
}
}
const makeSocketOptions = (): Partial<ManagerOptions & SocketOptions> => {
const socketOptions: Parameters<typeof io>[0] = {
timeout: 60000,
path: '/ws/socket.io',
autoConnect: false, // achtung! removing this breaks the dynamic middleware
forceNew: true,
};
// if building in package mode, replace socket url with open api base url minus the http protocol
if (['nodes', 'package'].includes(import.meta.env.MODE)) {
const authToken = $authToken.get();
if (authToken) {
// TODO: handle providing jwt to socket.io
socketOptions.auth = { token: authToken };
}
socketOptions.transports = ['websocket', 'polling'];
}
return socketOptions;
};
const makeSocketUrl = (): string => {
const wsProtocol = window.location.protocol === 'https:' ? 'wss' : 'ws';
let socketUrl = `${wsProtocol}://${window.location.host}`;
if (['nodes', 'package'].includes(import.meta.env.MODE)) {
const baseUrl = $baseUrl.get();
if (baseUrl) {
//eslint-disable-next-line
socketUrl = baseUrl.replace(/^https?\:\/\//i, '');
}
}
return socketUrl;
};
const makeSocket = (): Socket<ServerToClientEvents, ClientToServerEvents> => {
const socketOptions = makeSocketOptions();
const socketUrl = $socketUrl.get();
const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io(
socketUrl,
{ ...socketOptions, ...$socketOptions.get() }
);
return socket;
};
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
export const $socketUrl = atom<string>(makeSocketUrl());
export const $isSocketInitialized = atom<boolean>(false);
/**
@ -74,23 +27,50 @@ export const $isSocketInitialized = atom<boolean>(false);
*/
export const useSocketIO = () => {
const dispatch = useAppDispatch();
const socketOptions = useStore($socketOptions);
const socketUrl = useStore($socketUrl);
const baseUrl = useStore($baseUrl);
const authToken = useStore($authToken);
const addlSocketOptions = useStore($socketOptions);
const socketUrl = useMemo(() => {
const wsProtocol = window.location.protocol === 'https:' ? 'wss' : 'ws';
if (baseUrl) {
return baseUrl.replace(/^https?:\/\//i, '');
}
return `${wsProtocol}://${window.location.host}`;
}, [baseUrl]);
const socketOptions = useMemo(() => {
const options: Parameters<typeof io>[0] = {
timeout: 60000,
path: '/ws/socket.io',
autoConnect: false, // achtung! removing this breaks the dynamic middleware
forceNew: true,
};
if (authToken) {
options.auth = { token: authToken };
options.transports = ['websocket', 'polling'];
}
return { ...options, ...addlSocketOptions };
}, [authToken, addlSocketOptions]);
useEffect(() => {
if ($isSocketInitialized.get()) {
// Singleton!
return;
}
const socket = makeSocket();
const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io(
socketUrl,
socketOptions
);
setEventListeners({ dispatch, socket });
socket.connect();
if ($isDebugging.get()) {
window.$socketOptions = $socketOptions;
window.$socketUrl = $socketUrl;
console.log('Socket initialized', socket);
}
@ -99,11 +79,10 @@ export const useSocketIO = () => {
return () => {
if ($isDebugging.get()) {
window.$socketOptions = undefined;
window.$socketUrl = undefined;
console.log('Socket teardown', socket);
}
socket.disconnect();
$isSocketInitialized.set(false);
};
}, [dispatch, socketOptions, socketUrl, baseUrl, authToken]);
}, [dispatch, socketOptions, socketUrl]);
};

View File

@ -1,8 +1 @@
export const LOCALSTORAGE_KEYS = [
'chakra-ui-color-mode',
'i18nextLng',
'ROARR_FILTER',
'ROARR_LOG',
];
export const LOCALSTORAGE_PREFIX = '@@invokeai-';
export const STORAGE_PREFIX = '@@invokeai-';

View File

@ -3,6 +3,7 @@ import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { queueApi } from 'services/api/endpoints/queue';
import { BatchConfig } from 'services/api/types';
import { startAppListening } from '..';
import { buildWorkflow } from 'features/nodes/util/workflow/buildWorkflow';
export const addEnqueueRequestedNodes = () => {
startAppListening({
@ -10,10 +11,18 @@ export const addEnqueueRequestedNodes = () => {
enqueueRequested.match(action) && action.payload.tabName === 'nodes',
effect: async (action, { getState, dispatch }) => {
const state = getState();
const { nodes, edges } = state.nodes;
const workflow = state.workflow;
const graph = buildNodesGraph(state.nodes);
const builtWorkflow = buildWorkflow({
nodes,
edges,
workflow,
});
const batchConfig: BatchConfig = {
batch: {
graph,
workflow: builtWorkflow,
runs: state.generation.iterations,
},
prepend: action.payload.prepend,

View File

@ -11,13 +11,11 @@ import {
TypesafeDroppableData,
} from 'features/dnd/types';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import {
fieldImageValueChanged,
workflowExposedFieldAdded,
} from 'features/nodes/store/nodesSlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '../';
import { workflowExposedFieldAdded } from 'features/nodes/store/workflowSlice';
export const dndDropped = createAction<{
overData: TypesafeDroppableData;

View File

@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { workflowLoadRequested } from 'features/nodes/store/actions';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { workflowLoaded } from 'features/nodes/store/actions';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import {
WorkflowMigrationError,
@ -21,7 +21,7 @@ export const addWorkflowLoadRequestedListener = () => {
actionCreator: workflowLoadRequested,
effect: (action, { dispatch, getState }) => {
const log = logger('nodes');
const workflow = action.payload;
const { workflow, asCopy } = action.payload;
const nodeTemplates = getState().nodes.nodeTemplates;
try {
@ -29,6 +29,12 @@ export const addWorkflowLoadRequestedListener = () => {
workflow,
nodeTemplates
);
if (asCopy) {
// If we're loading a copy, we need to remove the ID so that the backend will create a new workflow
delete validatedWorkflow.id;
}
dispatch(workflowLoaded(validatedWorkflow));
if (!warnings.length) {
dispatch(
@ -99,7 +105,6 @@ export const addWorkflowLoadRequestedListener = () => {
);
} else {
// Some other error occurred
console.log(e);
log.error(
{ error: parseify(e) },
t('nodes.unknownErrorValidatingWorkflow')

View File

@ -1,5 +1,6 @@
import { Store } from '@reduxjs/toolkit';
import { createStore } from 'app/store/store';
import { atom } from 'nanostores';
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const $store = atom<Store<any> | undefined>();
export const $store = atom<
Readonly<ReturnType<typeof createStore>> | undefined
>();

View File

@ -14,6 +14,7 @@ import galleryReducer from 'features/gallery/store/gallerySlice';
import loraReducer from 'features/lora/store/loraSlice';
import modelmanagerReducer from 'features/modelManager/store/modelManagerSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import workflowReducer from 'features/nodes/store/workflowSlice';
import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import queueReducer from 'features/queue/store/queueSlice';
@ -22,17 +23,17 @@ import configReducer from 'features/system/store/configSlice';
import systemReducer from 'features/system/store/systemSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';
import { createStore as createIDBKeyValStore, get, set } from 'idb-keyval';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import { Driver, rememberEnhancer, rememberReducer } from 'redux-remember';
import { api } from 'services/api';
import { LOCALSTORAGE_PREFIX } from './constants';
import { STORAGE_PREFIX } from './constants';
import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { listenerMiddleware } from './middleware/listenerMiddleware';
import { $store } from './nanostores/store';
const allReducers = {
canvas: canvasReducer,
@ -52,6 +53,7 @@ const allReducers = {
modelmanager: modelmanagerReducer,
sdxl: sdxlReducer,
queue: queueReducer,
workflow: workflowReducer,
[api.reducerPath]: api.reducer,
};
@ -65,6 +67,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'generation',
'sdxl',
'nodes',
'workflow',
'postprocessing',
'system',
'ui',
@ -74,57 +77,70 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'modelmanager',
];
export const store = configureStore({
reducer: rememberedRootReducer,
enhancers: (existingEnhancers) => {
return existingEnhancers
.concat(
rememberEnhancer(window.localStorage, rememberedKeys, {
persistDebounce: 300,
serialize,
unserialize,
prefix: LOCALSTORAGE_PREFIX,
})
)
.concat(autoBatchEnhancer());
},
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
serializableCheck: false,
immutableCheck: false,
})
.concat(api.middleware)
.concat(dynamicMiddlewares)
.prepend(listenerMiddleware.middleware),
devTools: {
actionSanitizer,
stateSanitizer,
trace: true,
predicate: (state, action) => {
// TODO: hook up to the log level param in system slice
// manually type state, cannot type the arg
// const typedState = state as ReturnType<typeof rootReducer>;
// Create a custom idb-keyval store (just needed to customize the name)
export const idbKeyValStore = createIDBKeyValStore('invoke', 'invoke-store');
// TODO: doing this breaks the rtk query devtools, commenting out for now
// if (action.type.startsWith('api/')) {
// // don't log api actions, with manual cache updates they are extremely noisy
// return false;
// }
// Create redux-remember driver, wrapping idb-keyval
const idbKeyValDriver: Driver = {
getItem: (key) => get(key, idbKeyValStore),
setItem: (key, value) => set(key, value, idbKeyValStore),
};
if (actionsDenylist.includes(action.type)) {
// don't log other noisy actions
return false;
}
return true;
export const createStore = (uniqueStoreKey?: string) =>
configureStore({
reducer: rememberedRootReducer,
enhancers: (existingEnhancers) => {
return existingEnhancers
.concat(
rememberEnhancer(idbKeyValDriver, rememberedKeys, {
persistDebounce: 300,
serialize,
unserialize,
prefix: uniqueStoreKey
? `${STORAGE_PREFIX}${uniqueStoreKey}-`
: STORAGE_PREFIX,
})
)
.concat(autoBatchEnhancer());
},
},
});
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
serializableCheck: false,
immutableCheck: false,
})
.concat(api.middleware)
.concat(dynamicMiddlewares)
.prepend(listenerMiddleware.middleware),
devTools: {
actionSanitizer,
stateSanitizer,
trace: true,
predicate: (state, action) => {
// TODO: hook up to the log level param in system slice
// manually type state, cannot type the arg
// const typedState = state as ReturnType<typeof rootReducer>;
export type AppGetState = typeof store.getState;
export type RootState = ReturnType<typeof store.getState>;
// TODO: doing this breaks the rtk query devtools, commenting out for now
// if (action.type.startsWith('api/')) {
// // don't log api actions, with manual cache updates they are extremely noisy
// return false;
// }
if (actionsDenylist.includes(action.type)) {
// don't log other noisy actions
return false;
}
return true;
},
},
});
export type AppGetState = ReturnType<
ReturnType<typeof createStore>['getState']
>;
export type RootState = ReturnType<ReturnType<typeof createStore>['getState']>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
export type AppDispatch = typeof store.dispatch;
export type AppDispatch = ReturnType<typeof createStore>['dispatch'];
export const stateSelector = (state: RootState) => state;
$store.set(store);

View File

@ -1,4 +1,10 @@
import { FormControl, FormLabel, Tooltip, forwardRef } from '@chakra-ui/react';
import {
FormControl,
FormControlProps,
FormLabel,
Tooltip,
forwardRef,
} from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { RefObject, memo } from 'react';
@ -13,10 +19,19 @@ export type IAISelectProps = Omit<SelectProps, 'label'> & {
tooltip?: string | null;
inputRef?: RefObject<HTMLInputElement>;
label?: string;
formControlProps?: FormControlProps;
};
const IAIMantineSelect = forwardRef((props: IAISelectProps, ref) => {
const { tooltip, inputRef, label, disabled, required, ...rest } = props;
const {
tooltip,
formControlProps,
inputRef,
label,
disabled,
required,
...rest
} = props;
const styles = useMantineSelectStyles();
@ -28,6 +43,7 @@ const IAIMantineSelect = forwardRef((props: IAISelectProps, ref) => {
isDisabled={disabled}
position="static"
data-testid={`select-${label || props.placeholder}`}
{...formControlProps}
>
<FormLabel>{label}</FormLabel>
<Select disabled={disabled} ref={inputRef} styles={styles} {...rest} />

View File

@ -0,0 +1 @@
export const Nbsp = () => <>{'\u00A0'}</>;

View File

@ -0,0 +1,12 @@
import { idbKeyValStore } from 'app/store/store';
import { clear } from 'idb-keyval';
import { useCallback } from 'react';
export const useClearStorage = () => {
const clearStorage = useCallback(() => {
clear(idbKeyValStore);
localStorage.clear();
}, []);
return clearStorage;
};

View File

@ -5,14 +5,19 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import { useControlAdapterControlImage } from 'features/controlAdapters/hooks/useControlAdapterControlImage';
import { useControlAdapterProcessedControlImage } from 'features/controlAdapters/hooks/useControlAdapterProcessedControlImage';
import { useControlAdapterProcessorType } from 'features/controlAdapters/hooks/useControlAdapterProcessorType';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
import { setHeight, setWidth } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useMemo, useState } from 'react';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { FaRulerVertical, FaSave, FaUndo } from 'react-icons/fa';
import {
@ -22,11 +27,6 @@ import {
useRemoveImageFromBoardMutation,
} from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { useControlAdapterControlImage } from 'features/controlAdapters/hooks/useControlAdapterControlImage';
import { useControlAdapterProcessedControlImage } from 'features/controlAdapters/hooks/useControlAdapterProcessedControlImage';
import { useControlAdapterProcessorType } from 'features/controlAdapters/hooks/useControlAdapterProcessorType';
type Props = {
id: string;
@ -35,13 +35,15 @@ type Props = {
const selector = createSelector(
stateSelector,
({ controlAdapters, gallery }) => {
({ controlAdapters, gallery, system }) => {
const { pendingControlImages } = controlAdapters;
const { autoAddBoardId } = gallery;
const { isConnected } = system;
return {
pendingControlImages,
autoAddBoardId,
isConnected,
};
},
defaultSelectorOptions
@ -55,18 +57,19 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { pendingControlImages, autoAddBoardId } = useAppSelector(selector);
const { pendingControlImages, autoAddBoardId, isConnected } =
useAppSelector(selector);
const activeTabName = useAppSelector(activeTabNameSelector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
const { currentData: controlImage } = useGetImageDTOQuery(
controlImageName ?? skipToken
);
const { currentData: controlImage, isError: isErrorControlImage } =
useGetImageDTOQuery(controlImageName ?? skipToken);
const { currentData: processedControlImage } = useGetImageDTOQuery(
processedControlImageName ?? skipToken
);
const {
currentData: processedControlImage,
isError: isErrorProcessedControlImage,
} = useGetImageDTOQuery(processedControlImageName ?? skipToken);
const [changeIsIntermediate] = useChangeImageIsIntermediateMutation();
const [addToBoard] = useAddImageToBoardMutation();
@ -158,6 +161,17 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
!pendingControlImages.includes(id) &&
processorType !== 'none';
useEffect(() => {
if (isConnected && (isErrorControlImage || isErrorProcessedControlImage)) {
handleResetControlImage();
}
}, [
handleResetControlImage,
isConnected,
isErrorControlImage,
isErrorProcessedControlImage,
]);
return (
<Flex
onMouseEnter={handleMouseEnter}

View File

@ -73,7 +73,13 @@ const BoardContextMenu = ({
addToast({
title: t('gallery.preparingDownload'),
status: 'success',
...(response.response ? { description: response.response } : {}),
...(response.response
? {
description: response.response,
duration: null,
isClosable: true,
}
: {}),
})
);
} catch {

View File

@ -16,7 +16,8 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { workflowLoadRequested } from 'features/nodes/store/actions';
import SingleSelectionMenuItems from 'features/gallery/components/ImageContextMenu/SingleSelectionMenuItems';
import { sentImageToImg2Img } from 'features/gallery/store/actions';
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
@ -27,6 +28,7 @@ import {
setShouldShowImageDetails,
setShouldShowProgressInViewer,
} from 'features/ui/store/uiSlice';
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
@ -41,10 +43,7 @@ import {
import { FaCircleNodes, FaEllipsis } from 'react-icons/fa6';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
import { useDebouncedWorkflow } from 'services/api/hooks/useDebouncedWorkflow';
import { menuListMotionProps } from 'theme/components/menu';
import { sentImageToImg2Img } from 'features/gallery/store/actions';
import SingleSelectionMenuItems from 'features/gallery/components/ImageContextMenu/SingleSelectionMenuItems';
const currentImageButtonsSelector = createSelector(
[stateSelector, activeTabNameSelector],
@ -111,18 +110,17 @@ const CurrentImageButtons = () => {
lastSelectedImage?.image_name
);
const { workflow, isLoading: isLoadingWorkflow } = useDebouncedWorkflow(
lastSelectedImage?.workflow_id
);
const { getAndLoadEmbeddedWorkflow, getAndLoadEmbeddedWorkflowResult } =
useGetAndLoadEmbeddedWorkflow({});
const handleLoadWorkflow = useCallback(() => {
if (!workflow) {
if (!lastSelectedImage || !lastSelectedImage.has_workflow) {
return;
}
dispatch(workflowLoadRequested(workflow));
}, [dispatch, workflow]);
getAndLoadEmbeddedWorkflow(lastSelectedImage.image_name);
}, [getAndLoadEmbeddedWorkflow, lastSelectedImage]);
useHotkeys('w', handleLoadWorkflow, [workflow]);
useHotkeys('w', handleLoadWorkflow, [lastSelectedImage]);
const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(metadata);
@ -255,12 +253,12 @@ const CurrentImageButtons = () => {
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIIconButton
isLoading={isLoadingWorkflow}
icon={<FaCircleNodes />}
tooltip={`${t('nodes.loadWorkflow')} (W)`}
aria-label={`${t('nodes.loadWorkflow')} (W)`}
isDisabled={!workflow}
isDisabled={!imageDTO?.has_workflow}
onClick={handleLoadWorkflow}
isLoading={getAndLoadEmbeddedWorkflowResult.isLoading}
/>
<IAIIconButton
isLoading={isLoadingMetadata}

View File

@ -59,7 +59,13 @@ const MultipleSelectionMenuItems = () => {
addToast({
title: t('gallery.preparingDownload'),
status: 'success',
...(response.response ? { description: response.response } : {}),
...(response.response
? {
description: response.response,
duration: null,
isClosable: true,
}
: {}),
})
);
} catch {

View File

@ -3,18 +3,22 @@ import { useStore } from '@nanostores/react';
import { useAppToaster } from 'app/components/Toaster';
import { $customStarUI } from 'app/store/nanostores/customStarUI';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCopyImageToClipboard } from 'common/hooks/useCopyImageToClipboard';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
imagesToChangeSelected,
isModalOpenChanged,
} from 'features/changeBoardModal/store/slice';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { workflowLoadRequested } from 'features/nodes/store/actions';
import {
sentImageToCanvas,
sentImageToImg2Img,
} from 'features/gallery/store/actions';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useCopyImageToClipboard } from 'common/hooks/useCopyImageToClipboard';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow';
import { memo, useCallback } from 'react';
import { flushSync } from 'react-dom';
import { useTranslation } from 'react-i18next';
@ -36,12 +40,7 @@ import {
useUnstarImagesMutation,
} from 'services/api/endpoints/images';
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
import { useDebouncedWorkflow } from 'services/api/hooks/useDebouncedWorkflow';
import { ImageDTO } from 'services/api/types';
import {
sentImageToCanvas,
sentImageToImg2Img,
} from 'features/gallery/store/actions';
type SingleSelectionMenuItemsProps = {
imageDTO: ImageDTO;
@ -61,9 +60,13 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { metadata, isLoading: isLoadingMetadata } = useDebouncedMetadata(
imageDTO?.image_name
);
const { workflow, isLoading: isLoadingWorkflow } = useDebouncedWorkflow(
imageDTO?.workflow_id
);
const { getAndLoadEmbeddedWorkflow, getAndLoadEmbeddedWorkflowResult } =
useGetAndLoadEmbeddedWorkflow({});
const handleLoadWorkflow = useCallback(() => {
getAndLoadEmbeddedWorkflow(imageDTO.image_name);
}, [getAndLoadEmbeddedWorkflow, imageDTO.image_name]);
const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();
@ -101,13 +104,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
recallSeed(metadata?.seed);
}, [metadata?.seed, recallSeed]);
const handleLoadWorkflow = useCallback(() => {
if (!workflow) {
return;
}
dispatch(workflowLoadRequested(workflow));
}, [dispatch, workflow]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(imageDTO));
@ -179,9 +175,15 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
{t('parameters.downloadImage')}
</MenuItem>
<MenuItem
icon={isLoadingWorkflow ? <SpinnerIcon /> : <FaCircleNodes />}
icon={
getAndLoadEmbeddedWorkflowResult.isLoading ? (
<SpinnerIcon />
) : (
<FaCircleNodes />
)
}
onClickCapture={handleLoadWorkflow}
isDisabled={isLoadingWorkflow || !workflow}
isDisabled={!imageDTO.has_workflow}
>
{t('nodes.loadWorkflow')}
</MenuItem>
@ -234,14 +236,14 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
icon={customStarUi ? customStarUi.off.icon : <MdStar />}
onClickCapture={handleUnstarImage}
>
{customStarUi ? customStarUi.off.text : t('controlnet.unstarImage')}
{customStarUi ? customStarUi.off.text : t('gallery.unstarImage')}
</MenuItem>
) : (
<MenuItem
icon={customStarUi ? customStarUi.on.icon : <MdStarBorder />}
onClickCapture={handleStarImage}
>
{customStarUi ? customStarUi.on.text : `Star Image`}
{customStarUi ? customStarUi.on.text : t('gallery.starImage')}
</MenuItem>
)}
<MenuItem

View File

@ -14,10 +14,10 @@ import ScrollableContent from 'features/nodes/components/sidePanel/ScrollableCon
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
import { useDebouncedWorkflow } from 'services/api/hooks/useDebouncedWorkflow';
import { ImageDTO } from 'services/api/types';
import DataViewer from './DataViewer';
import ImageMetadataActions from './ImageMetadataActions';
import ImageMetadataWorkflowTabContent from './ImageMetadataWorkflowTabContent';
type ImageMetadataViewerProps = {
image: ImageDTO;
@ -32,7 +32,6 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
const { t } = useTranslation();
const { metadata } = useDebouncedMetadata(image.image_name);
const { workflow } = useDebouncedWorkflow(image.workflow_id);
return (
<Flex
@ -67,9 +66,9 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
>
<TabList>
<Tab>{t('metadata.recallParameters')}</Tab>
<Tab>{t('metadata.metadata')}</Tab>
<Tab isDisabled={!metadata}>{t('metadata.metadata')}</Tab>
<Tab>{t('metadata.imageDetails')}</Tab>
<Tab>{t('metadata.workflow')}</Tab>
<Tab isDisabled={!image.has_workflow}>{t('metadata.workflow')}</Tab>
</TabList>
<TabPanels>
@ -97,11 +96,7 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
)}
</TabPanel>
<TabPanel>
{workflow ? (
<DataViewer data={workflow} label={t('metadata.workflow')} />
) : (
<IAINoContentFallback label={t('nodes.noWorkflow')} />
)}
<ImageMetadataWorkflowTabContent image={image} />
</TabPanel>
</TabPanels>
</Tabs>

View File

@ -0,0 +1,23 @@
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetImageWorkflowQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import DataViewer from './DataViewer';
type Props = {
image: ImageDTO;
};
const ImageMetadataWorkflowTabContent = ({ image }: Props) => {
const { t } = useTranslation();
const { currentData: workflow } = useGetImageWorkflowQuery(image.image_name);
if (!workflow) {
return <IAINoContentFallback label={t('nodes.noWorkflow')} />;
}
return <DataViewer data={workflow} label={t('metadata.workflow')} />;
};
export default memo(ImageMetadataWorkflowTabContent);

View File

@ -1,45 +0,0 @@
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEmbedWorkflow } from 'features/nodes/hooks/useEmbedWorkflow';
import { useWithWorkflow } from 'features/nodes/hooks/useWithWorkflow';
import { nodeEmbedWorkflowChanged } from 'features/nodes/store/nodesSlice';
import { ChangeEvent, memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const withWorkflow = useWithWorkflow(nodeId);
const embedWorkflow = useEmbedWorkflow(nodeId);
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
nodeEmbedWorkflowChanged({
nodeId,
embedWorkflow: e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!withWorkflow) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>
{t('metadata.workflow')}
</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChange}
isChecked={embedWorkflow}
/>
</FormControl>
);
};
export default memo(EmbedWorkflowCheckbox);

View File

@ -1,9 +1,8 @@
import { Flex } from '@chakra-ui/react';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { memo } from 'react';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import EmbedWorkflowCheckbox from './EmbedWorkflowCheckbox';
import { memo } from 'react';
import SaveToGalleryCheckbox from './SaveToGalleryCheckbox';
import UseCacheCheckbox from './UseCacheCheckbox';
@ -28,7 +27,6 @@ const InvocationNodeFooter = ({ nodeId }: Props) => {
}}
>
{isCacheEnabled && <UseCacheCheckbox nodeId={nodeId} />}
{hasImageOutput && <EmbedWorkflowCheckbox nodeId={nodeId} />}
{hasImageOutput && <SaveToGalleryCheckbox nodeId={nodeId} />}
</Flex>
);

View File

@ -13,7 +13,7 @@ import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitl
import {
workflowExposedFieldAdded,
workflowExposedFieldRemoved,
} from 'features/nodes/store/nodesSlice';
} from 'features/nodes/store/workflowSlice';
import { MouseEvent, ReactNode, memo, useCallback, useMemo } from 'react';
import { FaMinus, FaPlus } from 'react-icons/fa';
import { menuListMotionProps } from 'theme/components/menu';
@ -41,9 +41,9 @@ const FieldContextMenu = ({ nodeId, fieldName, kind, children }: Props) => {
() =>
createSelector(
stateSelector,
({ nodes }) => {
({ workflow }) => {
const isExposed = Boolean(
nodes.workflow.exposedFields.find(
workflow.exposedFields.find(
(f) => f.nodeId === nodeId && f.fieldName === fieldName
)
);

View File

@ -10,7 +10,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { workflowExposedFieldRemoved } from 'features/nodes/store/nodesSlice';
import { workflowExposedFieldRemoved } from 'features/nodes/store/workflowSlice';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import { memo, useCallback } from 'react';
import { FaInfoCircle, FaTrash } from 'react-icons/fa';

View File

@ -1,6 +1,6 @@
import { Flex, Text } from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import {
@ -13,7 +13,7 @@ import {
ImageFieldInputTemplate,
} from 'features/nodes/types/field';
import { FieldComponentProps } from './types';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
@ -24,8 +24,8 @@ const ImageFieldInputComponent = (
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { currentData: imageDTO } = useGetImageDTOQuery(
const isConnected = useAppSelector((state) => state.system.isConnected);
const { currentData: imageDTO, isError } = useGetImageDTOQuery(
field.value?.image_name ?? skipToken
);
@ -67,6 +67,12 @@ const ImageFieldInputComponent = (
[nodeId, field.name]
);
useEffect(() => {
if (isConnected && isError) {
handleReset();
}
}, [handleReset, isConnected, isError]);
return (
<Flex
className="nodrag"

View File

@ -1,8 +1,10 @@
import { Flex } from '@chakra-ui/layout';
import { memo } from 'react';
import LoadWorkflowButton from './LoadWorkflowButton';
import ResetWorkflowButton from './ResetWorkflowButton';
import DownloadWorkflowButton from './DownloadWorkflowButton';
import DownloadWorkflowButton from 'features/workflowLibrary/components/DownloadWorkflowButton';
import UploadWorkflowButton from 'features/workflowLibrary/components/LoadWorkflowFromFileButton';
import ResetWorkflowEditorButton from 'features/workflowLibrary/components/ResetWorkflowButton';
import SaveWorkflowButton from 'features/workflowLibrary/components/SaveWorkflowButton';
import SaveWorkflowAsButton from 'features/workflowLibrary/components/SaveWorkflowAsButton';
const TopCenterPanel = () => {
return (
@ -16,8 +18,10 @@ const TopCenterPanel = () => {
}}
>
<DownloadWorkflowButton />
<LoadWorkflowButton />
<ResetWorkflowButton />
<UploadWorkflowButton />
<SaveWorkflowButton />
<SaveWorkflowAsButton />
<ResetWorkflowEditorButton />
</Flex>
);
};

View File

@ -1,10 +1,12 @@
import { Flex } from '@chakra-ui/layout';
import { Flex } from '@chakra-ui/react';
import WorkflowLibraryButton from 'features/workflowLibrary/components/WorkflowLibraryButton';
import { memo } from 'react';
import WorkflowEditorSettings from './WorkflowEditorSettings';
const TopRightPanel = () => {
return (
<Flex sx={{ gap: 2, position: 'absolute', top: 2, insetInlineEnd: 2 }}>
<WorkflowLibraryButton />
<WorkflowEditorSettings />
</Flex>
);

View File

@ -11,17 +11,16 @@ import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea';
import ScrollableContent from 'features/nodes/components/sidePanel/ScrollableContent';
import {
InvocationNodeData,
InvocationNode,
InvocationTemplate,
isInvocationNode,
} from 'features/nodes/types/invocation';
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { Node } from 'reactflow';
import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea';
import ScrollableContent from 'features/nodes/components/sidePanel/ScrollableContent';
import EditableNodeTitle from './details/EditableNodeTitle';
const selector = createSelector(
@ -62,7 +61,7 @@ const InspectorDetailsTab = () => {
export default memo(InspectorDetailsTab);
type ContentProps = {
node: Node<InvocationNodeData>;
node: InvocationNode;
template: InvocationTemplate;
};

View File

@ -5,6 +5,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIInput from 'common/components/IAIInput';
import IAITextarea from 'common/components/IAITextarea';
import ScrollableContent from 'features/nodes/components/sidePanel/ScrollableContent';
import {
workflowAuthorChanged,
workflowContactChanged,
@ -13,16 +14,15 @@ import {
workflowNotesChanged,
workflowTagsChanged,
workflowVersionChanged,
} from 'features/nodes/store/nodesSlice';
} from 'features/nodes/store/workflowSlice';
import { ChangeEvent, memo, useCallback } from 'react';
import ScrollableContent from 'features/nodes/components/sidePanel/ScrollableContent';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
stateSelector,
({ nodes }) => {
({ workflow }) => {
const { author, name, description, tags, version, contact, notes } =
nodes.workflow;
workflow;
return {
name,

View File

@ -11,9 +11,9 @@ import { useTranslation } from 'react-i18next';
const selector = createSelector(
stateSelector,
({ nodes }) => {
({ workflow }) => {
return {
fields: nodes.workflow.exposedFields,
fields: workflow.exposedFields,
};
},
defaultSelectorOptions

View File

@ -0,0 +1,17 @@
import { useWorkflow } from 'features/nodes/hooks/useWorkflow';
import { useCallback } from 'react';
export const useDownloadWorkflow = () => {
const workflow = useWorkflow();
const downloadWorkflow = useCallback(() => {
const blob = new Blob([JSON.stringify(workflow, null, 2)]);
const a = document.createElement('a');
a.href = URL.createObjectURL(blob);
a.download = `${workflow.name || 'My Workflow'}.json`;
document.body.appendChild(a);
a.click();
a.remove();
}, [workflow]);
return downloadWorkflow;
};

View File

@ -1,27 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from 'features/nodes/types/invocation';
export const useEmbedWorkflow = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return node.data.embedWorkflow;
},
defaultSelectorOptions
),
[nodeId]
);
const embedWorkflow = useAppSelector(selector);
return embedWorkflow;
};

View File

@ -1,31 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from 'features/nodes/types/invocation';
export const useWithWorkflow = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? ''];
if (!nodeTemplate) {
return false;
}
return nodeTemplate.withWorkflow;
},
defaultSelectorOptions
),
[nodeId]
);
const withWorkflow = useAppSelector(selector);
return withWorkflow;
};

View File

@ -5,12 +5,16 @@ import { useMemo } from 'react';
import { useDebounce } from 'use-debounce';
export const useWorkflow = () => {
const nodes = useAppSelector((state: RootState) => state.nodes);
const [debouncedNodes] = useDebounce(nodes, 300);
const workflow = useMemo(
() => buildWorkflow(debouncedNodes),
[debouncedNodes]
const nodes_ = useAppSelector((state: RootState) => state.nodes.nodes);
const edges_ = useAppSelector((state: RootState) => state.nodes.edges);
const workflow_ = useAppSelector((state: RootState) => state.workflow);
const [nodes] = useDebounce(nodes_, 300);
const [edges] = useDebounce(edges_, 300);
const [workflow] = useDebounce(workflow_, 300);
const builtWorkflow = useMemo(
() => buildWorkflow({ nodes, edges, workflow }),
[nodes, edges, workflow]
);
return workflow;
return builtWorkflow;
};

View File

@ -1,4 +1,5 @@
import { createAction, isAnyOf } from '@reduxjs/toolkit';
import { WorkflowV2 } from 'features/nodes/types/workflow';
import { Graph } from 'services/api/types';
export const textToImageGraphBuilt = createAction<Graph>(
@ -17,10 +18,15 @@ export const isAnyGraphBuilt = isAnyOf(
nodesGraphBuilt
);
export const workflowLoadRequested = createAction<unknown>(
'nodes/workflowLoadRequested'
);
export const workflowLoadRequested = createAction<{
workflow: unknown;
asCopy: boolean;
}>('nodes/workflowLoadRequested');
export const updateAllNodesRequested = createAction(
'nodes/updateAllNodesRequested'
);
export const workflowLoaded = createAction<WorkflowV2>(
'workflow/workflowLoaded'
);

View File

@ -1,33 +1,5 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { cloneDeep, forEach, isEqual, uniqBy } from 'lodash-es';
import {
addEdge,
applyEdgeChanges,
applyNodeChanges,
Connection,
Edge,
EdgeChange,
EdgeRemoveChange,
getConnectedEdges,
getIncomers,
getOutgoers,
Node,
NodeChange,
OnConnectStartParams,
SelectionMode,
updateEdge,
Viewport,
XYPosition,
} from 'reactflow';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import {
appSocketGeneratorProgress,
appSocketInvocationComplete,
appSocketInvocationError,
appSocketInvocationStarted,
appSocketQueueItemStatusChanged,
} from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
import { workflowLoaded } from 'features/nodes/store/actions';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import {
BoardFieldValue,
@ -57,7 +29,35 @@ import {
NodeExecutionState,
zNodeStatus,
} from 'features/nodes/types/invocation';
import { WorkflowV2 } from 'features/nodes/types/workflow';
import { cloneDeep, forEach } from 'lodash-es';
import {
addEdge,
applyEdgeChanges,
applyNodeChanges,
Connection,
Edge,
EdgeChange,
EdgeRemoveChange,
getConnectedEdges,
getIncomers,
getOutgoers,
Node,
NodeChange,
OnConnectStartParams,
SelectionMode,
updateEdge,
Viewport,
XYPosition,
} from 'reactflow';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import {
appSocketGeneratorProgress,
appSocketInvocationComplete,
appSocketInvocationError,
appSocketInvocationStarted,
appSocketQueueItemStatusChanged,
} from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
import { NodesState } from './types';
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
@ -70,20 +70,6 @@ const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
outputs: [],
};
const INITIAL_WORKFLOW: WorkflowV2 = {
name: '',
author: '',
description: '',
version: '',
contact: '',
tags: '',
notes: '',
nodes: [],
edges: [],
exposedFields: [],
meta: { version: '2.0.0' },
};
export const initialNodesState: NodesState = {
nodes: [],
edges: [],
@ -103,7 +89,6 @@ export const initialNodesState: NodesState = {
nodeOpacity: 1,
selectedNodes: [],
selectedEdges: [],
workflow: INITIAL_WORKFLOW,
nodeExecutionStates: {},
viewport: { x: 0, y: 0, zoom: 1 },
mouseOverField: null,
@ -308,23 +293,6 @@ const nodesSlice = createSlice({
}
state.modifyingEdge = false;
},
workflowExposedFieldAdded: (
state,
action: PayloadAction<FieldIdentifier>
) => {
state.workflow.exposedFields = uniqBy(
state.workflow.exposedFields.concat(action.payload),
(field) => `${field.nodeId}-${field.fieldName}`
);
},
workflowExposedFieldRemoved: (
state,
action: PayloadAction<FieldIdentifier>
) => {
state.workflow.exposedFields = state.workflow.exposedFields.filter(
(field) => !isEqual(field, action.payload)
);
},
fieldLabelChanged: (
state,
action: PayloadAction<{
@ -344,20 +312,6 @@ const nodesSlice = createSlice({
}
field.label = label;
},
nodeEmbedWorkflowChanged: (
state,
action: PayloadAction<{ nodeId: string; embedWorkflow: boolean }>
) => {
const { nodeId, embedWorkflow } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
node.data.embedWorkflow = embedWorkflow;
},
nodeUseCacheChanged: (
state,
action: PayloadAction<{ nodeId: string; useCache: boolean }>
@ -522,9 +476,6 @@ const nodesSlice = createSlice({
},
nodesDeleted: (state, action: PayloadAction<AnyNode[]>) => {
action.payload.forEach((node) => {
state.workflow.exposedFields = state.workflow.exposedFields.filter(
(f) => f.nodeId !== node.id
);
if (!isInvocationNode(node)) {
return;
}
@ -687,7 +638,6 @@ const nodesSlice = createSlice({
nodeEditorReset: (state) => {
state.nodes = [];
state.edges = [];
state.workflow = cloneDeep(INITIAL_WORKFLOW);
},
shouldValidateGraphChanged: (state, action: PayloadAction<boolean>) => {
state.shouldValidateGraph = action.payload;
@ -704,56 +654,6 @@ const nodesSlice = createSlice({
nodeOpacityChanged: (state, action: PayloadAction<number>) => {
state.nodeOpacity = action.payload;
},
workflowNameChanged: (state, action: PayloadAction<string>) => {
state.workflow.name = action.payload;
},
workflowDescriptionChanged: (state, action: PayloadAction<string>) => {
state.workflow.description = action.payload;
},
workflowTagsChanged: (state, action: PayloadAction<string>) => {
state.workflow.tags = action.payload;
},
workflowAuthorChanged: (state, action: PayloadAction<string>) => {
state.workflow.author = action.payload;
},
workflowNotesChanged: (state, action: PayloadAction<string>) => {
state.workflow.notes = action.payload;
},
workflowVersionChanged: (state, action: PayloadAction<string>) => {
state.workflow.version = action.payload;
},
workflowContactChanged: (state, action: PayloadAction<string>) => {
state.workflow.contact = action.payload;
},
workflowLoaded: (state, action: PayloadAction<WorkflowV2>) => {
const { nodes, edges, ...workflow } = action.payload;
state.workflow = workflow;
state.nodes = applyNodeChanges(
nodes.map((node) => ({
item: { ...node, ...SHARED_NODE_PROPERTIES },
type: 'add',
})),
[]
);
state.edges = applyEdgeChanges(
edges.map((edge) => ({ item: edge, type: 'add' })),
[]
);
state.nodeExecutionStates = nodes.reduce<
Record<string, NodeExecutionState>
>((acc, node) => {
acc[node.id] = {
nodeId: node.id,
...initialNodeExecutionState,
};
return acc;
}, {});
},
workflowReset: (state) => {
state.workflow = cloneDeep(INITIAL_WORKFLOW);
},
viewportChanged: (state, action: PayloadAction<Viewport>) => {
state.viewport = action.payload;
},
@ -899,6 +799,32 @@ const nodesSlice = createSlice({
builder.addCase(receivedOpenAPISchema.pending, (state) => {
state.isReady = false;
});
builder.addCase(workflowLoaded, (state, action) => {
const { nodes, edges } = action.payload;
state.nodes = applyNodeChanges(
nodes.map((node) => ({
item: { ...node, ...SHARED_NODE_PROPERTIES },
type: 'add',
})),
[]
);
state.edges = applyEdgeChanges(
edges.map((edge) => ({ item: edge, type: 'add' })),
[]
);
state.nodeExecutionStates = nodes.reduce<
Record<string, NodeExecutionState>
>((acc, node) => {
acc[node.id] = {
nodeId: node.id,
...initialNodeExecutionState,
};
return acc;
}, {});
});
builder.addCase(appSocketInvocationStarted, (state, action) => {
const { source_node_id } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
@ -984,7 +910,6 @@ export const {
nodeAdded,
nodeReplaced,
nodeEditorReset,
nodeEmbedWorkflowChanged,
nodeExclusivelySelected,
nodeIsIntermediateChanged,
nodeIsOpenChanged,
@ -1008,16 +933,6 @@ export const {
shouldSnapToGridChanged,
shouldValidateGraphChanged,
viewportChanged,
workflowAuthorChanged,
workflowContactChanged,
workflowDescriptionChanged,
workflowExposedFieldAdded,
workflowExposedFieldRemoved,
workflowLoaded,
workflowNameChanged,
workflowNotesChanged,
workflowTagsChanged,
workflowVersionChanged,
edgeAdded,
} = nodesSlice.actions;

View File

@ -29,7 +29,6 @@ export type NodesState = {
shouldColorEdges: boolean;
selectedNodes: string[];
selectedEdges: string[];
workflow: Omit<WorkflowV2, 'nodes' | 'edges'>;
nodeExecutionStates: Record<string, NodeExecutionState>;
viewport: Viewport;
isReady: boolean;
@ -41,3 +40,5 @@ export type NodesState = {
addNewNodePosition: XYPosition | null;
selectionMode: SelectionMode;
};
export type WorkflowsState = Omit<WorkflowV2, 'nodes' | 'edges'>;

Some files were not shown because too many files have changed in this diff Show More