merged multi-gpu support into new session_processor architecture

This commit is contained in:
Lincoln Stein 2024-06-02 14:10:08 -04:00
commit e26360f85b
449 changed files with 20703 additions and 13428 deletions

View File

@ -18,6 +18,7 @@ help:
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema" @echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "installer-zip Build the installer .zip file for the current version" @echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)" @echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
# Runs ruff, fixing any safely-fixable errors and formatting # Runs ruff, fixing any safely-fixable errors and formatting
ruff: ruff:
@ -70,3 +71,6 @@ installer-zip:
tag-release: tag-release:
cd installer && ./tag_release.sh cd installer && ./tag_release.sh
# Generate the OpenAPI Schema for the app
openapi:
python scripts/generate_openapi_schema.py

View File

@ -64,7 +64,7 @@ GPU_DRIVER=nvidia
Any environment variables supported by InvokeAI can be set here - please see the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail. Any environment variables supported by InvokeAI can be set here - please see the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
## Even Moar Customizing! ## Even More Customizing!
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below. See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.

View File

@ -117,13 +117,13 @@ Stateless fields do not store their value in the node, so their field instances
"Custom" fields will always be treated as stateless fields. "Custom" fields will always be treated as stateless fields.
##### Collection and Scalar Fields ##### Single and Collection Fields
Field types have a name and two flags which may identify it as a **collection** or **collection or scalar** field. Field types have a name and cardinality property which may identify it as a **SINGLE**, **COLLECTION** or **SINGLE_OR_COLLECTION** field.
If a field is annotated in python as a list, its field type is parsed and flagged as a **collection** type (e.g. `list[int]`). - If a field is annotated in python as a singular value or class, its field type is parsed as a **SINGLE** type (e.g. `int`, `ImageField`, `str`).
- If a field is annotated in python as a list, its field type is parsed as a **COLLECTION** type (e.g. `list[int]`).
If it is annotated as a union of a type and list, the type will be flagged as a **collection or scalar** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed). - If it is annotated as a union of a type and list, the type will be parsed as a **SINGLE_OR_COLLECTION** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed).
## Implementation ## Implementation
@ -173,8 +173,7 @@ Field types are represented as structured objects:
```ts ```ts
type FieldType = { type FieldType = {
name: string; name: string;
isCollection: boolean; cardinality: 'SINGLE' | 'COLLECTION' | 'SINGLE_OR_COLLECTION';
isCollectionOrScalar: boolean;
}; };
``` ```
@ -186,7 +185,7 @@ There are 4 general cases for field type parsing.
When a field is annotated as a primitive values (e.g. `int`, `str`, `float`), the field type parsing is fairly straightforward. The field is represented by a simple OpenAPI **schema object**, which has a `type` property. When a field is annotated as a primitive values (e.g. `int`, `str`, `float`), the field type parsing is fairly straightforward. The field is represented by a simple OpenAPI **schema object**, which has a `type` property.
We create a field type name from this `type` string (e.g. `string` -> `StringField`). We create a field type name from this `type` string (e.g. `string` -> `StringField`). The cardinality is `"SINGLE"`.
##### Complex Types ##### Complex Types
@ -200,13 +199,13 @@ We need to **dereference** the schema to pull these out. Dereferencing may requi
When a field is annotated as a list of a single type, the schema object has an `items` property. They may be a schema object or reference object and must be parsed to determine the item type. When a field is annotated as a list of a single type, the schema object has an `items` property. They may be a schema object or reference object and must be parsed to determine the item type.
We use the item type for field type name, adding `isCollection: true` to the field type. We use the item type for field type name. The cardinality is `"COLLECTION"`.
##### Collection or Scalar Types ##### Single or Collection Types
When a field is annotated as a union of a type and list of that type, the schema object has an `anyOf` property, which holds a list of valid types for the union. When a field is annotated as a union of a type and list of that type, the schema object has an `anyOf` property, which holds a list of valid types for the union.
After verifying that the union has two members (a type and list of the same type), we use the type for field type name, adding `isCollectionOrScalar: true` to the field type. After verifying that the union has two members (a type and list of the same type), we use the type for field type name, with cardinality `"SINGLE_OR_COLLECTION"`.
##### Optional Fields ##### Optional Fields

View File

@ -165,7 +165,7 @@ Additionally, each section can be expanded with the "Show Advanced" button in o
There are several ways to install IP-Adapter models with an existing InvokeAI installation: There are several ways to install IP-Adapter models with an existing InvokeAI installation:
1. Through the command line interface launched from the invoke.sh / invoke.bat scripts, option [4] to download models. 1. Through the command line interface launched from the invoke.sh / invoke.bat scripts, option [4] to download models.
2. Through the Model Manager UI with models from the *Tools* section of [www.models.invoke.ai](https://www.models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models. 2. Through the Model Manager UI with models from the *Tools* section of [models.invoke.ai](https://models.invoke.ai). To do this, copy the repo ID from the desired model page, and paste it in the Add Model field of the model manager. **Note** Both the IP-Adapter and the Image Encoder must be installed for IP-Adapter to work. For example, the [SD 1.5 IP-Adapter](https://models.invoke.ai/InvokeAI/ip_adapter_plus_sd15) and [SD1.5 Image Encoder](https://models.invoke.ai/InvokeAI/ip_adapter_sd_image_encoder) must be installed to use IP-Adapter with SD1.5 based models.
3. **Advanced -- Not recommended ** Manually downloading the IP-Adapter and Image Encoder files - Image Encoder folders shouid be placed in the `models\any\clip_vision` folders. IP Adapter Model folders should be placed in the relevant `ip-adapter` folder of relevant base model folder of Invoke root directory. For example, for the SDXL IP-Adapter, files should be added to the `model/sdxl/ip_adapter/` folder. 3. **Advanced -- Not recommended ** Manually downloading the IP-Adapter and Image Encoder files - Image Encoder folders shouid be placed in the `models\any\clip_vision` folders. IP Adapter Model folders should be placed in the relevant `ip-adapter` folder of relevant base model folder of Invoke root directory. For example, for the SDXL IP-Adapter, files should be added to the `model/sdxl/ip_adapter/` folder.
#### Using IP-Adapter #### Using IP-Adapter

View File

@ -154,6 +154,18 @@ This is caused by an invalid setting in the `invokeai.yaml` configuration file.
Check the [configuration docs] for more detail about the settings and how to specify them. Check the [configuration docs] for more detail about the settings and how to specify them.
## `ModuleNotFoundError: No module named 'controlnet_aux'`
`controlnet_aux` is a dependency of Invoke and appears to have been packaged or distributed strangely. Sometimes, it doesn't install correctly. This is outside our control.
If you encounter this error, the solution is to remove the package from the `pip` cache and re-run the Invoke installer so a fresh, working version of `controlnet_aux` can be downloaded and installed:
- Run the Invoke launcher
- Choose the developer console option
- Run this command: `pip cache remove controlnet_aux`
- Close the terminal window
- Download and run the [installer](https://github.com/invoke-ai/InvokeAI/releases/latest), selecting your current install location
## Out of Memory Issues ## Out of Memory Issues
The models are large, VRAM is expensive, and you may find yourself The models are large, VRAM is expensive, and you may find yourself

View File

@ -20,7 +20,7 @@ When you generate an image using text-to-image, multiple steps occur in latent s
4. The VAE decodes the final latent image from latent space into image space. 4. The VAE decodes the final latent image from latent space into image space.
Image-to-image is a similar process, with only step 1 being different: Image-to-image is a similar process, with only step 1 being different:
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how may noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process. 1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how many noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process.
Furthermore, a model provides the CLIP prompt tokenizer, the VAE, and a U-Net (where noise prediction occurs given a prompt and initial noise tensor). Furthermore, a model provides the CLIP prompt tokenizer, the VAE, and a U-Net (where noise prediction occurs given a prompt and initial noise tensor).

View File

@ -98,7 +98,7 @@ Updating is exactly the same as installing - download the latest installer, choo
If you have installation issues, please review the [FAQ]. You can also [create an issue] or ask for help on [discord]. If you have installation issues, please review the [FAQ]. You can also [create an issue] or ask for help on [discord].
[installation requirements]: INSTALLATION.md#installation-requirements [installation requirements]: INSTALL_REQUIREMENTS.md
[FAQ]: ../help/FAQ.md [FAQ]: ../help/FAQ.md
[install some models]: 050_INSTALLING_MODELS.md [install some models]: 050_INSTALLING_MODELS.md
[configuration docs]: ../features/CONFIGURATION.md [configuration docs]: ../features/CONFIGURATION.md

View File

@ -10,7 +10,7 @@ InvokeAI is distributed as a python package on PyPI, installable with `pip`. The
### Requirements ### Requirements
Before you start, go through the [installation requirements]. Before you start, go through the [installation requirements](./INSTALL_REQUIREMENTS.md).
### Installation Walkthrough ### Installation Walkthrough
@ -79,7 +79,7 @@ Before you start, go through the [installation requirements].
1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features. 1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features.
- You may need to provide an [extra index URL]. Select your platform configuration using [this tool on the PyTorch website]. Copy the `--extra-index-url` string from this and append it to your install command. - You may need to provide an [extra index URL](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-extra-index-url). Select your platform configuration using [this tool on the PyTorch website](https://pytorch.org/get-started/locally/). Copy the `--extra-index-url` string from this and append it to your install command.
!!! example "Install with an extra index URL" !!! example "Install with an extra index URL"
@ -116,4 +116,4 @@ Before you start, go through the [installation requirements].
!!! warning !!! warning
If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root_dir \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable. If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable.

View File

@ -37,13 +37,13 @@ Invoke runs best with a dedicated GPU, but will fall back to running on CPU, alb
=== "Nvidia" === "Nvidia"
``` ```
Any GPU with at least 8GB VRAM. Linux only. Any GPU with at least 8GB VRAM.
``` ```
=== "AMD" === "AMD"
``` ```
Any GPU with at least 16GB VRAM. Any GPU with at least 16GB VRAM. Linux only.
``` ```
=== "Mac" === "Mac"

View File

@ -10,8 +10,7 @@ set INVOKEAI_ROOT=.
echo Desired action: echo Desired action:
echo 1. Generate images with the browser-based interface echo 1. Generate images with the browser-based interface
echo 2. Open the developer console echo 2. Open the developer console
echo 3. Run the InvokeAI image database maintenance script echo 3. Command-line help
echo 4. Command-line help
echo Q - Quit echo Q - Quit
echo. echo.
echo To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest. echo To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.
@ -34,9 +33,6 @@ IF /I "%choice%" == "1" (
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment *** echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
call cmd /k call cmd /k
) ELSE IF /I "%choice%" == "3" ( ) ELSE IF /I "%choice%" == "3" (
echo Running the db maintenance script...
python .venv\Scripts\invokeai-db-maintenance.exe
) ELSE IF /I "%choice%" == "4" (
echo Displaying command line help... echo Displaying command line help...
python .venv\Scripts\invokeai-web.exe --help %* python .venv\Scripts\invokeai-web.exe --help %*
pause pause

View File

@ -47,11 +47,6 @@ do_choice() {
bash --init-file "$file_name" bash --init-file "$file_name"
;; ;;
3) 3)
clear
printf "Running the db maintenance script\n"
invokeai-db-maintenance --root ${INVOKEAI_ROOT}
;;
4)
clear clear
printf "Command-line help\n" printf "Command-line help\n"
invokeai-web --help invokeai-web --help
@ -71,8 +66,7 @@ do_line_input() {
printf "What would you like to do?\n" printf "What would you like to do?\n"
printf "1: Generate images using the browser-based interface\n" printf "1: Generate images using the browser-based interface\n"
printf "2: Open the developer console\n" printf "2: Open the developer console\n"
printf "3: Run the InvokeAI image database maintenance script\n" printf "3: Command-line help\n"
printf "4: Command-line help\n"
printf "Q: Quit\n\n" printf "Q: Quit\n\n"
printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.\n\n" printf "To update, download and run the installer from https://github.com/invoke-ai/InvokeAI/releases/latest.\n\n"
read -p "Please enter 1-4, Q: [1] " yn read -p "Please enter 1-4, Q: [1] " yn

View File

@ -19,21 +19,22 @@ from ..services.boards.boards_default import BoardService
from ..services.bulk_download.bulk_download_default import BulkDownloadService from ..services.bulk_download.bulk_download_default import BulkDownloadService
from ..services.config import InvokeAIAppConfig from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService from ..services.download import DownloadQueueService
from ..services.events.events_fastapievents import FastAPIEventService
from ..services.image_files.image_files_disk import DiskImageFileStorage from ..services.image_files.image_files_disk import DiskImageFileStorage
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService from ..services.images.images_default import ImageService
from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker from ..services.invoker import Invoker
from ..services.model_images.model_images_default import ModelImageFileStorageDisk from ..services.model_images.model_images_default import ModelImageFileStorageDisk
from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor from ..services.session_processor.session_processor_default import DefaultSessionProcessor, DefaultSessionRunner
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.urls.urls_default import LocalUrlService from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
# TODO: is there a better way to achieve this? # TODO: is there a better way to achieve this?
@ -101,11 +102,9 @@ class ApiDependencies:
download_queue=download_queue_service, download_queue=download_queue_service,
events=events, events=events,
) )
# horrible hack - remove
invokeai.backend.util.devices.RAM_CACHE = model_manager.load.ram_cache
names = SimpleNameService() names = SimpleNameService()
session_processor = DefaultSessionProcessor() performance_statistics = InvocationStatsService()
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
session_queue = SqliteSessionQueue(db=db) session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService() urls = LocalUrlService()
workflow_records = SqliteWorkflowRecordsStorage(db=db) workflow_records = SqliteWorkflowRecordsStorage(db=db)
@ -127,6 +126,7 @@ class ApiDependencies:
model_manager=model_manager, model_manager=model_manager,
download_queue=download_queue_service, download_queue=download_queue_service,
names=names, names=names,
performance_statistics=performance_statistics,
session_processor=session_processor, session_processor=session_processor,
session_queue=session_queue, session_queue=session_queue,
urls=urls, urls=urls,

View File

@ -1,52 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from typing import Any
from fastapi_events.dispatcher import dispatch
from ..services.events.events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
event_handler_id: int
__queue: Queue
__stop_event: threading.Event
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self.__queue = Queue()
self.__stop_event = threading.Event()
asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self.__stop_event.set()
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
self.__queue.put({"event_name": event_name, "payload": payload})
async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self.__queue.get(block=False)
if not event: # Probably stopping
continue
dispatch(
event.get("event_name"),
payload=event.get("payload"),
middleware_id=self.event_handler_id,
)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@ -13,7 +13,6 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.upscale import ESRGAN_MODELS from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.util.logging import logging from invokeai.backend.util.logging import logging
from invokeai.version import __version__ from invokeai.version import __version__
@ -109,9 +108,7 @@ async def get_config() -> AppConfig:
upscaling_models.append(str(Path(model).stem)) upscaling_models.append(str(Path(model).stem))
upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models) upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models)
nsfw_methods = [] nsfw_methods = ["nsfw_checker"]
if SafetyChecker.safety_checker_available():
nsfw_methods.append("nsfw_checker")
watermarking_methods = ["invisible_watermark"] watermarking_methods = ["invisible_watermark"]

View File

@ -6,13 +6,12 @@ from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request,
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from PIL import Image from PIL import Image
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, Field, JsonValue
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -42,13 +41,17 @@ async def upload_image(
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"), board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"), session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"), crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
metadata: Optional[JsonValue] = Body(
default=None, description="The metadata to associate with the image", embed=True
),
) -> ImageDTO: ) -> ImageDTO:
"""Uploads an image""" """Uploads an image"""
if not file.content_type or not file.content_type.startswith("image"): if not file.content_type or not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image") raise HTTPException(status_code=415, detail="Not an image")
metadata = None _metadata = None
workflow = None _workflow = None
_graph = None
contents = await file.read() contents = await file.read()
try: try:
@ -62,21 +65,27 @@ async def upload_image(
# TODO: retain non-invokeai metadata on upload? # TODO: retain non-invokeai metadata on upload?
# attempt to parse metadata from image # attempt to parse metadata from image
metadata_raw = pil_image.info.get("invokeai_metadata", None) metadata_raw = metadata if isinstance(metadata, str) else pil_image.info.get("invokeai_metadata", None)
if metadata_raw: if isinstance(metadata_raw, str):
try: _metadata = metadata_raw
metadata = MetadataFieldValidator.validate_json(metadata_raw) else:
except ValidationError: ApiDependencies.invoker.services.logger.debug("Failed to parse metadata for uploaded image")
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
pass pass
# attempt to parse workflow from image # attempt to parse workflow from image
workflow_raw = pil_image.info.get("invokeai_workflow", None) workflow_raw = pil_image.info.get("invokeai_workflow", None)
if workflow_raw is not None: if isinstance(workflow_raw, str):
try: _workflow = workflow_raw
workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw) else:
except ValidationError: ApiDependencies.invoker.services.logger.debug("Failed to parse workflow for uploaded image")
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image") pass
# attempt to extract graph from image
graph_raw = pil_image.info.get("invokeai_graph", None)
if isinstance(graph_raw, str):
_graph = graph_raw
else:
ApiDependencies.invoker.services.logger.debug("Failed to parse graph for uploaded image")
pass pass
try: try:
@ -86,8 +95,9 @@ async def upload_image(
image_category=image_category, image_category=image_category,
session_id=session_id, session_id=session_id,
board_id=board_id, board_id=board_id,
metadata=metadata, metadata=_metadata,
workflow=workflow, workflow=_workflow,
graph=_graph,
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
) )
@ -185,14 +195,21 @@ async def get_image_metadata(
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
class WorkflowAndGraphResponse(BaseModel):
workflow: Optional[str] = Field(description="The workflow used to generate the image, as stringified JSON")
graph: Optional[str] = Field(description="The graph used to generate the image, as stringified JSON")
@images_router.get( @images_router.get(
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=Optional[WorkflowWithoutID] "/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse
) )
async def get_image_workflow( async def get_image_workflow(
image_name: str = Path(description="The name of image whose workflow to get"), image_name: str = Path(description="The name of image whose workflow to get"),
) -> Optional[WorkflowWithoutID]: ) -> WorkflowAndGraphResponse:
try: try:
return ApiDependencies.invoker.services.images.get_workflow(image_name) workflow = ApiDependencies.invoker.services.images.get_workflow(image_name)
graph = ApiDependencies.invoker.services.images.get_graph(image_name)
return WorkflowAndGraphResponse(workflow=workflow, graph=graph)
except Exception: except Exception:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)

View File

@ -6,7 +6,7 @@ import pathlib
import shutil import shutil
import traceback import traceback
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Type
from fastapi import Body, Path, Query, Response, UploadFile from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
@ -16,7 +16,8 @@ from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from typing_extensions import Annotated from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import ( from invokeai.app.services.model_records import (
DuplicateModelException, DuplicateModelException,
InvalidModelException, InvalidModelException,
@ -52,6 +53,13 @@ class ModelsList(BaseModel):
model_config = ConfigDict(use_enum_values=True) model_config = ConfigDict(use_enum_values=True)
def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
"""Add a cover image URL to a model configuration."""
cover_image = dependencies.invoker.services.model_images.get_url(config.key)
config.cover_image = cover_image
return config
############################################################################## ##############################################################################
# These are example inputs and outputs that are used in places where Swagger # These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example. # is unable to generate a correct example.
@ -118,8 +126,7 @@ async def list_model_records(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
) )
for model in found_models: for model in found_models:
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key) model = add_cover_image_to_model_config(model, ApiDependencies)
model.cover_image = cover_image
return ModelsList(models=found_models) return ModelsList(models=found_models)
@ -160,12 +167,9 @@ async def get_model_record(
key: str = Path(description="Key of the model record to fetch."), key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig: ) -> AnyModelConfig:
"""Get a model record""" """Get a model record"""
record_store = ApiDependencies.invoker.services.model_manager.store
try: try:
config: AnyModelConfig = record_store.get_model(key) config = ApiDependencies.invoker.services.model_manager.store.get_model(key)
cover_image = ApiDependencies.invoker.services.model_images.get_url(key) return add_cover_image_to_model_config(config, ApiDependencies)
config.cover_image = cover_image
return config
except UnknownModelException as e: except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@ -294,14 +298,15 @@ async def update_model_record(
installer = ApiDependencies.invoker.services.model_manager.install installer = ApiDependencies.invoker.services.model_manager.install
try: try:
record_store.update_model(key, changes=changes) record_store.update_model(key, changes=changes)
model_response: AnyModelConfig = installer.sync_model_path(key) config = installer.sync_model_path(key)
config = add_cover_image_to_model_config(config, ApiDependencies)
logger.info(f"Updated model: {key}") logger.info(f"Updated model: {key}")
except UnknownModelException as e: except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
except ValueError as e: except ValueError as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
return model_response return config
@model_manager_router.get( @model_manager_router.get(
@ -648,6 +653,14 @@ async def convert_model(
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
# Update the model image if the model had one
try:
model_image = ApiDependencies.invoker.services.model_images.get(key)
ApiDependencies.invoker.services.model_images.save(model_image, new_key)
ApiDependencies.invoker.services.model_images.delete(key)
except ModelImageFileNotFoundException:
pass
# delete the original safetensors file # delete the original safetensors file
installer.delete(key) installer.delete(key)
@ -655,7 +668,8 @@ async def convert_model(
shutil.rmtree(cache_path) shutil.rmtree(cache_path)
# return the config record for the new diffusers directory # return the config record for the new diffusers directory
new_config: AnyModelConfig = store.get_model(new_key) new_config = store.get_model(new_key)
new_config = add_cover_image_to_model_config(new_config, ApiDependencies)
return new_config return new_config

View File

@ -203,6 +203,7 @@ async def get_batch_status(
responses={ responses={
200: {"model": SessionQueueItem}, 200: {"model": SessionQueueItem},
}, },
response_model_exclude_none=True,
) )
async def get_queue_item( async def get_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"), queue_id: str = Path(description="The queue id to perform this operation on"),

View File

@ -1,66 +1,125 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from fastapi import FastAPI from fastapi import FastAPI
from fastapi_events.handlers.local import local_handler from pydantic import BaseModel
from fastapi_events.typing import Event
from socketio import ASGIApp, AsyncServer from socketio import ASGIApp, AsyncServer
from ..services.events.events_base import EventServiceBase from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadEventBase,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadEventBase,
DownloadProgressEvent,
DownloadStartedEvent,
FastAPIEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelEventBase,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueEventBase,
QueueItemStatusChangedEvent,
register_events,
)
class QueueSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io queue room.
This is a pydantic model to ensure the data is in the correct format."""
queue_id: str
class BulkDownloadSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io bulk downloads room.
This is a pydantic model to ensure the data is in the correct format."""
bulk_download_id: str
QUEUE_EVENTS = {
InvocationStartedEvent,
InvocationDenoiseProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
QueueItemStatusChangedEvent,
BatchEnqueuedEvent,
QueueClearedEvent,
}
MODEL_EVENTS = {
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
ModelLoadStartedEvent,
ModelLoadCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
ModelInstallCompleteEvent,
ModelInstallCancelledEvent,
ModelInstallErrorEvent,
}
BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
class SocketIO: class SocketIO:
__sio: AsyncServer _sub_queue = "subscribe_queue"
__app: ASGIApp _unsub_queue = "unsubscribe_queue"
__sub_queue: str = "subscribe_queue" _sub_bulk_download = "subscribe_bulk_download"
__unsub_queue: str = "unsubscribe_queue" _unsub_bulk_download = "unsubscribe_bulk_download"
__sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download"
def __init__(self, app: FastAPI): def __init__(self, app: FastAPI):
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
app.mount("/ws", self.__app) app.mount("/ws", self._app)
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue) self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue) self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event) self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event) self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download) register_events(QUEUE_EVENTS, self._handle_queue_event)
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download) register_events(MODEL_EVENTS, self._handle_model_event)
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event) register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
async def _handle_queue_event(self, event: Event): async def _handle_sub_queue(self, sid: str, data: Any) -> None:
await self.__sio.emit( await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["queue_id"],
)
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None: async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
if "queue_id" in data: await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
await self.__sio.enter_room(sid, data["queue_id"])
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None: async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
if "queue_id" in data: await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
await self.__sio.leave_room(sid, data["queue_id"])
async def _handle_model_event(self, event: Event) -> None: async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"]) await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_bulk_download_event(self, event: Event): async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
await self.__sio.emit( await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["bulk_download_id"],
)
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs): async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
if "bulk_download_id" in data: await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
await self.__sio.enter_room(sid, data["bulk_download_id"])
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs): async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
if "bulk_download_id" in data: await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)
await self.__sio.leave_room(sid, data["bulk_download_id"])

View File

@ -3,9 +3,7 @@ import logging
import mimetypes import mimetypes
import socket import socket
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path from pathlib import Path
from typing import Any
import torch import torch
import uvicorn import uvicorn
@ -13,11 +11,9 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available from torch.backends.mps import is_available as is_mps_available
# for PyCharm: # for PyCharm:
@ -25,9 +21,8 @@ from torch.backends.mps import is_available as is_mps_available
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.util.custom_openapi import get_openapi_func
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from ..backend.util.logging import InvokeAILogger from ..backend.util.logging import InvokeAILogger
@ -44,11 +39,6 @@ from .api.routers import (
workflows, workflows,
) )
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import (
BaseInvocation,
UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
app_config = get_config() app_config = get_config()
@ -118,85 +108,7 @@ app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api") app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api") app.include_router(workflows.workflows_router, prefix="/api")
app.openapi = get_openapi_func(app)
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# Add all outputs
all_invocations = BaseInvocation.get_invocations()
output_types = set()
output_type_titles = {}
for invoker in all_invocations:
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)
output_schemas = models_json_schema(
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
)
for schema_key, output_schema in output_schemas[1]["$defs"].items():
# TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]
openapi_schema["components"]["schemas"][schema_key] = output_schema
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
# Some models don't end up in the schemas as standalone definitions
additional_schemas = models_json_schema(
[
(UIConfigBase, "serialization"),
(InputFieldJSONSchemaExtra, "serialization"),
(OutputFieldJSONSchemaExtra, "serialization"),
(ModelIdentifierField, "serialization"),
(ProgressImage, "serialization"),
],
ref_template="#/components/schemas/{model}",
)
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema_json
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
output_type = signature(obj=invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
invoker_schema["class"] = "invocation"
# This code no longer seems to be necessary?
# Leave it here just in case
#
# from invokeai.backend.model_manager import get_model_config_formats
# formats = get_model_config_formats()
# for model_config_name, enum_set in formats.items():
# if model_config_name in openapi_schema["components"]["schemas"]:
# # print(f"Config with name {name} already defined")
# continue
# openapi_schema["components"]["schemas"][model_config_name] = {
# "title": model_config_name,
# "description": "An enumeration.",
# "type": "string",
# "enum": [v.value for v in enum_set],
# }
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
@app.get("/docs", include_in_schema=False) @app.get("/docs", include_in_schema=False)

View File

@ -98,11 +98,13 @@ class BaseInvocationOutput(BaseModel):
_output_classes: ClassVar[set[BaseInvocationOutput]] = set() _output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod @classmethod
def register_output(cls, output: BaseInvocationOutput) -> None: def register_output(cls, output: BaseInvocationOutput) -> None:
"""Registers an invocation output.""" """Registers an invocation output."""
cls._output_classes.add(output) cls._output_classes.add(output)
cls._typeadapter_needs_update = True
@classmethod @classmethod
def get_outputs(cls) -> Iterable[BaseInvocationOutput]: def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
@ -112,11 +114,12 @@ class BaseInvocationOutput(BaseModel):
@classmethod @classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]: def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types.""" """Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter: if not cls._typeadapter or cls._typeadapter_needs_update:
InvocationOutputsUnion = TypeAliasType( AnyInvocationOutput = TypeAliasType(
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")] "AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
) )
cls._typeadapter = TypeAdapter(InvocationOutputsUnion) cls._typeadapter = TypeAdapter(AnyInvocationOutput)
cls._typeadapter_needs_update = False
return cls._typeadapter return cls._typeadapter
@classmethod @classmethod
@ -125,12 +128,13 @@ class BaseInvocationOutput(BaseModel):
return (i.get_type() for i in BaseInvocationOutput.get_outputs()) return (i.get_type() for i in BaseInvocationOutput.get_outputs())
@staticmethod @staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema.""" """Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
# Because we use a pydantic Literal field with default value for the invocation type, # Because we use a pydantic Literal field with default value for the invocation type,
# it will be typed as optional in the OpenAPI schema. Make it required manually. # it will be typed as optional in the OpenAPI schema. Make it required manually.
if "required" not in schema or not isinstance(schema["required"], list): if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = [] schema["required"] = []
schema["class"] = "output"
schema["required"].extend(["type"]) schema["required"].extend(["type"])
@classmethod @classmethod
@ -167,6 +171,7 @@ class BaseInvocation(ABC, BaseModel):
_invocation_classes: ClassVar[set[BaseInvocation]] = set() _invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None _typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod @classmethod
def get_type(cls) -> str: def get_type(cls) -> str:
@ -177,15 +182,17 @@ class BaseInvocation(ABC, BaseModel):
def register_invocation(cls, invocation: BaseInvocation) -> None: def register_invocation(cls, invocation: BaseInvocation) -> None:
"""Registers an invocation.""" """Registers an invocation."""
cls._invocation_classes.add(invocation) cls._invocation_classes.add(invocation)
cls._typeadapter_needs_update = True
@classmethod @classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]: def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types.""" """Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter: if not cls._typeadapter or cls._typeadapter_needs_update:
InvocationsUnion = TypeAliasType( AnyInvocation = TypeAliasType(
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")] "AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
) )
cls._typeadapter = TypeAdapter(InvocationsUnion) cls._typeadapter = TypeAdapter(AnyInvocation)
cls._typeadapter_needs_update = False
return cls._typeadapter return cls._typeadapter
@classmethod @classmethod
@ -221,7 +228,7 @@ class BaseInvocation(ABC, BaseModel):
return signature(cls.invoke).return_annotation return signature(cls.invoke).return_annotation
@staticmethod @staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None: def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema.""" """Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None)) uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
if uiconfig is not None: if uiconfig is not None:
@ -237,6 +244,7 @@ class BaseInvocation(ABC, BaseModel):
schema["version"] = uiconfig.version schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list): if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = [] schema["required"] = []
schema["class"] = "invocation"
schema["required"].extend(["type", "id"]) schema["required"].extend(["type", "id"])
@abstractmethod @abstractmethod
@ -310,7 +318,7 @@ class BaseInvocation(ABC, BaseModel):
protected_namespaces=(), protected_namespaces=(),
validate_assignment=True, validate_assignment=True,
json_schema_extra=json_schema_extra, json_schema_extra=json_schema_extra,
json_schema_serialization_defaults_required=True, json_schema_serialization_defaults_required=False,
coerce_numbers_to_str=True, coerce_numbers_to_str=True,
) )

View File

@ -65,11 +65,7 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer) tokenizer_info = context.models.load(self.clip.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder) text_encoder_info = context.models.load(self.clip.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras: for lora in self.clip.loras:
@ -84,19 +80,21 @@ class CompelInvocation(BaseInvocation):
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context) ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
with ( with (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as ( # apply all patches while the model is on the target device
tokenizer,
ti_manager,
),
text_encoder_info as text_encoder, text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching. tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()), ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers), ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
patched_tokenizer,
ti_manager,
),
): ):
assert isinstance(text_encoder, CLIPTextModel) assert isinstance(text_encoder, CLIPTextModel)
assert isinstance(tokenizer, CLIPTokenizer)
compel = Compel( compel = Compel(
tokenizer=tokenizer, tokenizer=patched_tokenizer,
text_encoder=text_encoder, text_encoder=text_encoder,
textual_inversion_manager=ti_manager, textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype, dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@ -106,7 +104,7 @@ class CompelInvocation(BaseInvocation):
conjunction = Compel.parse_prompt_string(self.prompt) conjunction = Compel.parse_prompt_string(self.prompt)
if context.config.get().log_tokenization: if context.config.get().log_tokenization:
log_tokenization_for_conjunction(conjunction, tokenizer) log_tokenization_for_conjunction(conjunction, patched_tokenizer)
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction) c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@ -136,11 +134,7 @@ class SDXLPromptInvocationBase:
zero_on_empty: bool, zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tokenizer_info = context.models.load(clip_field.tokenizer) tokenizer_info = context.models.load(clip_field.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder) text_encoder_info = context.models.load(clip_field.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
# return zero on empty # return zero on empty
if prompt == "" and zero_on_empty: if prompt == "" and zero_on_empty:
@ -177,20 +171,23 @@ class SDXLPromptInvocationBase:
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context) ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
with ( with (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as ( # apply all patches while the model is on the target device
tokenizer,
ti_manager,
),
text_encoder_info as text_encoder, text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching. tokenizer_info as tokenizer,
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix), ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers), ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
patched_tokenizer,
ti_manager,
),
): ):
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)) assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(tokenizer, CLIPTokenizer)
text_encoder = cast(CLIPTextModel, text_encoder) text_encoder = cast(CLIPTextModel, text_encoder)
compel = Compel( compel = Compel(
tokenizer=tokenizer, tokenizer=patched_tokenizer,
text_encoder=text_encoder, text_encoder=text_encoder,
textual_inversion_manager=ti_manager, textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype, dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@ -203,7 +200,7 @@ class SDXLPromptInvocationBase:
if context.config.get().log_tokenization: if context.config.get().log_tokenization:
# TODO: better logging for and syntax # TODO: better logging for and syntax
log_tokenization_for_conjunction(conjunction, tokenizer) log_tokenization_for_conjunction(conjunction, patched_tokenizer)
# TODO: ask for optimizations? to not run text_encoder twice # TODO: ask for optimizations? to not run text_encoder twice
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction) c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)

View File

@ -24,7 +24,6 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
FieldDescriptions, FieldDescriptions,
ImageField, ImageField,
Input,
InputField, InputField,
OutputField, OutputField,
UIType, UIType,
@ -80,13 +79,13 @@ class ControlOutput(BaseInvocationOutput):
control: ControlField = OutputField(description=FieldDescriptions.control) control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.1") @invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.2")
class ControlNetInvocation(BaseInvocation): class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes""" """Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image") image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField( control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
) )
control_weight: Union[float, List[float]] = InputField( control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet" default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"

View File

@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from pathlib import Path
from typing import Literal, Optional from typing import Literal, Optional
import cv2 import cv2
@ -504,7 +503,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Blur NSFW Image", title="Blur NSFW Image",
tags=["image", "nsfw"], tags=["image", "nsfw"],
category="image", category="image",
version="1.2.2", version="1.2.3",
) )
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add blur to NSFW-flagged images""" """Add blur to NSFW-flagged images"""
@ -516,23 +515,12 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
logger = context.logger logger = context.logger
logger.debug("Running NSFW checker") logger.debug("Running NSFW checker")
if SafetyChecker.has_nsfw_concept(image): image = SafetyChecker.blur_if_nsfw(image)
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = self._get_caution_img()
blurry_image.paste(caution, (0, 0), caution)
image = blurry_image
image_dto = context.images.save(image=image) image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)
def _get_caution_img(self) -> Image.Image:
import invokeai.app.assets.images as image_assets
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
return caution.resize((caution.width // 2, caution.height // 2))
@invocation( @invocation(
"img_watermark", "img_watermark",

View File

@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@ -58,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"} CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0") @invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.1")
class IPAdapterInvocation(BaseInvocation): class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes.""" """Collects IP-Adapter info to pass to other nodes."""
@ -67,7 +67,6 @@ class IPAdapterInvocation(BaseInvocation):
ip_adapter_model: ModelIdentifierField = InputField( ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.", description="The IP-Adapter model.",
title="IP-Adapter Model", title="IP-Adapter Model",
input=Input.Direct,
ui_order=-1, ui_order=-1,
ui_type=UIType.IPAdapterModel, ui_type=UIType.IPAdapterModel,
) )

View File

@ -586,13 +586,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: Scheduler, scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline: ) -> StableDiffusionGeneratorPipeline:
# TODO:
# configure_model_padding(
# unet,
# self.seamless,
# self.seamless_axes,
# )
class FakeVae: class FakeVae:
class FakeVaeConfig: class FakeVaeConfig:
def __init__(self) -> None: def __init__(self) -> None:
@ -937,9 +930,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
assert isinstance(unet_info.model, UNet2DConditionModel) assert isinstance(unet_info.model, UNet2DConditionModel)
with ( with (
ExitStack() as exit_stack, ExitStack() as exit_stack,
ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config),
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
unet_info as unet, unet_info as unet,
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
set_seamless(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching. # Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(unet, _lora_loader()), ModelPatcher.apply_lora_unet(unet, _lora_loader()),
): ):

View File

@ -11,6 +11,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType,
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
Classification,
invocation, invocation,
invocation_output, invocation_output,
) )
@ -93,19 +94,46 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
pass pass
@invocation_output("model_identifier_output")
class ModelIdentifierOutput(BaseInvocationOutput):
"""Model identifier output"""
model: ModelIdentifierField = OutputField(description="Model identifier", title="Model")
@invocation(
"model_identifier",
title="Model identifier",
tags=["model"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class ModelIdentifierInvocation(BaseInvocation):
"""Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as
input for any model, even if the model types don't match. If you connect this to a mismatched input, you'll get an
error."""
model: ModelIdentifierField = InputField(description="The model to select", title="Model")
def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
if not context.models.exists(self.model.key):
raise Exception(f"Unknown model {self.model.key}")
return ModelIdentifierOutput(model=self.model)
@invocation( @invocation(
"main_model_loader", "main_model_loader",
title="Main Model", title="Main Model",
tags=["model"], tags=["model"],
category="model", category="model",
version="1.0.2", version="1.0.3",
) )
class MainModelLoaderInvocation(BaseInvocation): class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels.""" """Loads a main model, outputting its submodels."""
model: ModelIdentifierField = InputField( model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel)
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
)
# TODO: precision? # TODO: precision?
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
@ -134,12 +162,12 @@ class LoRALoaderOutput(BaseInvocationOutput):
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.2") @invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3")
class LoRALoaderInvocation(BaseInvocation): class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField( lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
) )
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField( unet: Optional[UNetField] = InputField(
@ -190,6 +218,75 @@ class LoRALoaderInvocation(BaseInvocation):
return output return output
@invocation_output("lora_selector_output")
class LoRASelectorOutput(BaseInvocationOutput):
"""Model loader output"""
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.1")
class LoRASelectorInvocation(BaseInvocation):
"""Selects a LoRA model and weight."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
def invoke(self, context: InvocationContext) -> LoRASelectorOutput:
return LoRASelectorOutput(lora=LoRAField(lora=self.lora, weight=self.weight))
@invocation("lora_collection_loader", title="LoRA Collection Loader", tags=["model"], category="model", version="1.0.0")
class LoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to the provided UNet and CLIP models."""
loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
unet: Optional[UNetField] = InputField(
default=None,
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
clip: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
)
def invoke(self, context: InvocationContext) -> LoRALoaderOutput:
output = LoRALoaderOutput()
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
for lora in loras:
if lora.lora.key in added_loras:
continue
if not context.models.exists(lora.lora.key):
raise Exception(f"Unknown lora: {lora.lora.key}!")
assert lora.lora.base in (BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2)
added_loras.append(lora.lora.key)
if self.unet is not None:
if output.unet is None:
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(lora)
if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
return output
@invocation_output("sdxl_lora_loader_output") @invocation_output("sdxl_lora_loader_output")
class SDXLLoRALoaderOutput(BaseInvocationOutput): class SDXLLoRALoaderOutput(BaseInvocationOutput):
"""SDXL LoRA Loader Output""" """SDXL LoRA Loader Output"""
@ -204,13 +301,13 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
title="SDXL LoRA", title="SDXL LoRA",
tags=["lora", "model"], tags=["lora", "model"],
category="model", category="model",
version="1.0.2", version="1.0.3",
) )
class SDXLLoRALoaderInvocation(BaseInvocation): class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder.""" """Apply selected lora to unet and text_encoder."""
lora: ModelIdentifierField = InputField( lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
) )
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField( unet: Optional[UNetField] = InputField(
@ -279,12 +376,78 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
return output return output
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2") @invocation(
"sdxl_lora_collection_loader",
title="SDXL LoRA Collection Loader",
tags=["model"],
category="model",
version="1.0.0",
)
class SDXLLoRACollectionLoader(BaseInvocation):
"""Applies a collection of SDXL LoRAs to the provided UNet and CLIP models."""
loras: LoRAField | list[LoRAField] = InputField(
description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
unet: Optional[UNetField] = InputField(
default=None,
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
clip: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
)
clip2: Optional[CLIPField] = InputField(
default=None,
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP 2",
)
def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
output = SDXLLoRALoaderOutput()
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
for lora in loras:
if lora.lora.key in added_loras:
continue
if not context.models.exists(lora.lora.key):
raise Exception(f"Unknown lora: {lora.lora.key}!")
assert lora.lora.base is BaseModelType.StableDiffusionXL
added_loras.append(lora.lora.key)
if self.unet is not None:
if output.unet is None:
output.unet = self.unet.model_copy(deep=True)
output.unet.loras.append(lora)
if self.clip is not None:
if output.clip is None:
output.clip = self.clip.model_copy(deep=True)
output.clip.loras.append(lora)
if self.clip2 is not None:
if output.clip2 is None:
output.clip2 = self.clip2.model_copy(deep=True)
output.clip2.loras.append(lora)
return output
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.3")
class VAELoaderInvocation(BaseInvocation): class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput""" """Loads a VAE model, outputting a VaeLoaderOutput"""
vae_model: ModelIdentifierField = InputField( vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel
) )
def invoke(self, context: InvocationContext) -> VAEOutput: def invoke(self, context: InvocationContext) -> VAEOutput:

View File

@ -1,4 +1,4 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import SubModelType from invokeai.backend.model_manager import SubModelType
@ -30,12 +30,12 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.2") @invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.3")
class SDXLModelLoaderInvocation(BaseInvocation): class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels.""" """Loads an sdxl base model, outputting its submodels."""
model: ModelIdentifierField = InputField( model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel
) )
# TODO: precision? # TODO: precision?
@ -67,13 +67,13 @@ class SDXLModelLoaderInvocation(BaseInvocation):
title="SDXL Refiner Model", title="SDXL Refiner Model",
tags=["model", "sdxl", "refiner"], tags=["model", "sdxl", "refiner"],
category="model", category="model",
version="1.0.2", version="1.0.3",
) )
class SDXLRefinerModelLoaderInvocation(BaseInvocation): class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels.""" """Loads an sdxl refiner model, outputting its submodels."""
model: ModelIdentifierField = InputField( model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel
) )
# TODO: precision? # TODO: precision?

View File

@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation, invocation,
invocation_output, invocation_output,
) )
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
@ -45,7 +45,7 @@ class T2IAdapterOutput(BaseInvocationOutput):
@invocation( @invocation(
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.2" "t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.3"
) )
class T2IAdapterInvocation(BaseInvocation): class T2IAdapterInvocation(BaseInvocation):
"""Collects T2I-Adapter info to pass to other nodes.""" """Collects T2I-Adapter info to pass to other nodes."""
@ -55,7 +55,6 @@ class T2IAdapterInvocation(BaseInvocation):
t2i_adapter_model: ModelIdentifierField = InputField( t2i_adapter_model: ModelIdentifierField = InputField(
description="The T2I-Adapter model.", description="The T2I-Adapter model.",
title="T2I-Adapter Model", title="T2I-Adapter Model",
input=Input.Direct,
ui_order=-1, ui_order=-1,
ui_type=UIType.T2IAdapterModel, ui_type=UIType.T2IAdapterModel,
) )

View File

@ -106,9 +106,7 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker: if self._invoker:
assert bulk_download_id is not None assert bulk_download_id is not None
self._invoker.services.events.emit_bulk_download_started( self._invoker.services.events.emit_bulk_download_started(
bulk_download_id=bulk_download_id, bulk_download_id, bulk_download_item_id, bulk_download_item_name
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
) )
def _signal_job_completed( def _signal_job_completed(
@ -118,10 +116,8 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker: if self._invoker:
assert bulk_download_id is not None assert bulk_download_id is not None
assert bulk_download_item_name is not None assert bulk_download_item_name is not None
self._invoker.services.events.emit_bulk_download_completed( self._invoker.services.events.emit_bulk_download_complete(
bulk_download_id=bulk_download_id, bulk_download_id, bulk_download_item_id, bulk_download_item_name
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
) )
def _signal_job_failed( def _signal_job_failed(
@ -131,11 +127,8 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker: if self._invoker:
assert bulk_download_id is not None assert bulk_download_id is not None
assert exception is not None assert exception is not None
self._invoker.services.events.emit_bulk_download_failed( self._invoker.services.events.emit_bulk_download_error(
bulk_download_id=bulk_download_id, bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=str(exception),
) )
def stop(self, *args, **kwargs): def stop(self, *args, **kwargs):

View File

@ -8,14 +8,13 @@ import time
import traceback import traceback
from pathlib import Path from pathlib import Path
from queue import Empty, PriorityQueue from queue import Empty, PriorityQueue
from typing import Any, Dict, List, Optional, Set from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
import requests import requests
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from requests import HTTPError from requests import HTTPError
from tqdm import tqdm from tqdm import tqdm
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -30,6 +29,9 @@ from .download_base import (
UnknownJobIDException, UnknownJobIDException,
) )
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
# Maximum number of bytes to download during each call to requests.iter_content() # Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000 DOWNLOAD_CHUNK_SIZE = 100000
@ -40,7 +42,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
def __init__( def __init__(
self, self,
max_parallel_dl: int = 5, max_parallel_dl: int = 5,
event_bus: Optional[EventServiceBase] = None, event_bus: Optional["EventServiceBase"] = None,
requests_session: Optional[requests.sessions.Session] = None, requests_session: Optional[requests.sessions.Session] = None,
): ):
""" """
@ -343,8 +345,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}" f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
) )
if self._event_bus: if self._event_bus:
assert job.download_path self._event_bus.emit_download_started(job)
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
def _signal_job_progress(self, job: DownloadJob) -> None: def _signal_job_progress(self, job: DownloadJob) -> None:
if job.on_progress: if job.on_progress:
@ -355,13 +356,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}" f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
) )
if self._event_bus: if self._event_bus:
assert job.download_path self._event_bus.emit_download_progress(job)
self._event_bus.emit_download_progress(
str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
def _signal_job_complete(self, job: DownloadJob) -> None: def _signal_job_complete(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.COMPLETED job.status = DownloadJobStatus.COMPLETED
@ -373,10 +368,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}" f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
) )
if self._event_bus: if self._event_bus:
assert job.download_path self._event_bus.emit_download_complete(job)
self._event_bus.emit_download_complete(
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
)
def _signal_job_cancelled(self, job: DownloadJob) -> None: def _signal_job_cancelled(self, job: DownloadJob) -> None:
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]: if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
@ -390,7 +382,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}" f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
) )
if self._event_bus: if self._event_bus:
self._event_bus.emit_download_cancelled(str(job.source)) self._event_bus.emit_download_cancelled(job)
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None: def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
job.status = DownloadJobStatus.ERROR job.status = DownloadJobStatus.ERROR
@ -403,9 +395,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}" f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
) )
if self._event_bus: if self._event_bus:
assert job.error_type self._event_bus.emit_download_error(job)
assert job.error
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
def _cleanup_cancelled_job(self, job: DownloadJob) -> None: def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}") self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")

View File

@ -1,486 +1,195 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Optional
from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.events.events_common import (
from invokeai.app.services.session_queue.session_queue_common import ( BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
EventBase,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueItemStatusChangedEvent,
)
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus, BatchStatus,
EnqueueBatchResult, EnqueueBatchResult,
SessionQueueItem, SessionQueueItem,
SessionQueueStatus, SessionQueueStatus,
) )
from invokeai.app.util.misc import get_timestamp from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.model_manager import AnyModelConfig
from invokeai.backend.model_manager.config import SubModelType
class EventServiceBase: class EventServiceBase:
queue_event: str = "queue_event"
bulk_download_event: str = "bulk_download_event"
download_event: str = "download_event"
model_event: str = "model_event"
"""Basic event bus, to have an empty stand-in when not needed""" """Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event_name: str, payload: Any) -> None: def dispatch(self, event: "EventBase") -> None:
pass pass
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None: # region: Invocation
"""Bulk download events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.bulk_download_event,
payload={"event": event_name, "data": payload},
)
def __emit_queue_event(self, event_name: str, payload: dict) -> None: def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None:
"""Queue events are emitted to a room with queue_id as the room name""" """Emitted when an invocation is started"""
payload["timestamp"] = get_timestamp() self.dispatch(InvocationStartedEvent.build(queue_item, invocation))
self.dispatch(
event_name=EventServiceBase.queue_event,
payload={"event": event_name, "data": payload},
)
def __emit_download_event(self, event_name: str, payload: dict) -> None: def emit_invocation_denoise_progress(
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.download_event,
payload={"event": event_name, "data": payload},
)
def __emit_model_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.model_event,
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
self, self,
queue_id: str, queue_item: "SessionQueueItem",
queue_item_id: int, invocation: "BaseInvocation",
queue_batch_id: str, intermediate_state: PipelineIntermediateState,
graph_execution_state_id: str, progress_image: "ProgressImage",
node_id: str,
source_node_id: str,
progress_image: Optional[ProgressImage],
step: int,
order: int,
total_steps: int,
) -> None: ) -> None:
"""Emitted when there is generation progress""" """Emitted at each step during denoising of an invocation."""
self.__emit_queue_event( self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))
event_name="generator_progress",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"source_node_id": source_node_id,
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
"step": step,
"order": order,
"total_steps": total_steps,
},
)
def emit_invocation_complete( def emit_invocation_complete(
self, self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
result: dict,
node: dict,
source_node_id: str,
) -> None: ) -> None:
"""Emitted when an invocation has completed""" """Emitted when an invocation is complete"""
self.__emit_queue_event( self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output))
event_name="invocation_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"result": result,
},
)
def emit_invocation_error( def emit_invocation_error(
self, self,
queue_id: str, queue_item: "SessionQueueItem",
queue_item_id: int, invocation: "BaseInvocation",
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
error_type: str, error_type: str,
error: str, error_message: str,
error_traceback: str,
) -> None: ) -> None:
"""Emitted when an invocation has completed""" """Emitted when an invocation encounters an error"""
self.__emit_queue_event( self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error_message, error_traceback))
event_name="invocation_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"error_type": error_type,
"error": error,
},
)
def emit_invocation_started( # endregion
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation has started"""
self.__emit_queue_event(
event_name="invocation_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
},
)
def emit_graph_execution_complete( # region Queue
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
) -> None:
"""Emitted when a session has completed all invocations"""
self.__emit_queue_event(
event_name="graph_execution_state_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_model_load_started(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is requested"""
self.__emit_queue_event(
event_name="model_load_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
def emit_model_load_completed(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
event_name="model_load_completed",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
def emit_session_canceled(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
) -> None:
"""Emitted when a session is canceled"""
self.__emit_queue_event(
event_name="session_canceled",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_queue_item_status_changed( def emit_queue_item_status_changed(
self, self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus"
session_queue_item: SessionQueueItem,
batch_status: BatchStatus,
queue_status: SessionQueueStatus,
) -> None: ) -> None:
"""Emitted when a queue item's status changes""" """Emitted when a queue item's status changes"""
self.__emit_queue_event( self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))
event_name="queue_item_status_changed",
payload={
"queue_id": queue_status.queue_id,
"queue_item": {
"queue_id": session_queue_item.queue_id,
"item_id": session_queue_item.item_id,
"status": session_queue_item.status,
"batch_id": session_queue_item.batch_id,
"session_id": session_queue_item.session_id,
"error": session_queue_item.error,
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
},
"batch_status": batch_status.model_dump(mode="json"),
"queue_status": queue_status.model_dump(mode="json"),
},
)
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None: def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
"""Emitted when a batch is enqueued""" """Emitted when a batch is enqueued"""
self.__emit_queue_event( self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
event_name="batch_enqueued",
payload={
"queue_id": enqueue_result.queue_id,
"batch_id": enqueue_result.batch.batch_id,
"enqueued": enqueue_result.enqueued,
},
)
def emit_queue_cleared(self, queue_id: str) -> None: def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when the queue is cleared""" """Emitted when a queue is cleared"""
self.__emit_queue_event( self.dispatch(QueueClearedEvent.build(queue_id))
event_name="queue_cleared",
payload={"queue_id": queue_id},
)
def emit_download_started(self, source: str, download_path: str) -> None: # endregion
"""
Emit when a download job is started.
:param url: The downloaded url # region Download
"""
self.__emit_download_event(
event_name="download_started",
payload={"source": source, "download_path": download_path},
)
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None: def emit_download_started(self, job: "DownloadJob") -> None:
""" """Emitted when a download is started"""
Emit "download_progress" events at regular intervals during a download job. self.dispatch(DownloadStartedEvent.build(job))
:param source: The downloaded source def emit_download_progress(self, job: "DownloadJob") -> None:
:param download_path: The local downloaded file """Emitted at intervals during a download"""
:param current_bytes: Number of bytes downloaded so far self.dispatch(DownloadProgressEvent.build(job))
:param total_bytes: The size of the file being downloaded (if known)
"""
self.__emit_download_event(
event_name="download_progress",
payload={
"source": source,
"download_path": download_path,
"current_bytes": current_bytes,
"total_bytes": total_bytes,
},
)
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None: def emit_download_complete(self, job: "DownloadJob") -> None:
""" """Emitted when a download is completed"""
Emit a "download_complete" event at the end of a successful download. self.dispatch(DownloadCompleteEvent.build(job))
:param source: Source URL def emit_download_cancelled(self, job: "DownloadJob") -> None:
:param download_path: Path to the locally downloaded file """Emitted when a download is cancelled"""
:param total_bytes: The size of the downloaded file self.dispatch(DownloadCancelledEvent.build(job))
"""
self.__emit_download_event(
event_name="download_complete",
payload={
"source": source,
"download_path": download_path,
"total_bytes": total_bytes,
},
)
def emit_download_cancelled(self, source: str) -> None: def emit_download_error(self, job: "DownloadJob") -> None:
"""Emit a "download_cancelled" event in the event that the download was cancelled by user.""" """Emitted when a download encounters an error"""
self.__emit_download_event( self.dispatch(DownloadErrorEvent.build(job))
event_name="download_cancelled",
payload={
"source": source,
},
)
def emit_download_error(self, source: str, error_type: str, error: str) -> None: # endregion
"""
Emit a "download_error" event when an download job encounters an exception.
:param source: Source URL # region Model loading
:param error_type: The name of the exception that raised the error
:param error: The traceback from this error
"""
self.__emit_download_event(
event_name="download_error",
payload={
"source": source,
"error_type": error_type,
"error": error,
},
)
def emit_model_install_downloading( def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
self, """Emitted when a model load is started."""
source: str, self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
local_path: str,
bytes: int, def emit_model_load_complete(
total_bytes: int, self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
parts: List[Dict[str, Union[str, int]]],
id: int,
) -> None: ) -> None:
""" """Emitted when a model load is complete."""
Emit at intervals while the install job is in progress (remote models only). self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
:param source: Source of the model # endregion
:param local_path: Where model is downloading to
:param parts: Progress of downloading URLs that comprise the model, if any.
:param bytes: Number of bytes downloaded so far.
:param total_bytes: Total size of download, including all files.
This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes".
"""
self.__emit_model_event(
event_name="model_install_downloading",
payload={
"source": source,
"local_path": local_path,
"bytes": bytes,
"total_bytes": total_bytes,
"parts": parts,
"id": id,
},
)
def emit_model_install_downloads_done(self, source: str) -> None: # region Model install
"""
Emit once when all parts are downloaded, but before the probing and registration start.
:param source: Source of the model; local path, repo_id or url def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
""" """Emitted at intervals while the install job is in progress (remote models only)."""
self.__emit_model_event( self.dispatch(ModelInstallDownloadProgressEvent.build(job))
event_name="model_install_downloads_done",
payload={"source": source},
)
def emit_model_install_running(self, source: str) -> None: def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
""" self.dispatch(ModelInstallDownloadsCompleteEvent.build(job))
Emit once when an install job becomes active.
:param source: Source of the model; local path, repo_id or url def emit_model_install_started(self, job: "ModelInstallJob") -> None:
""" """Emitted once when an install job is started (after any download)."""
self.__emit_model_event( self.dispatch(ModelInstallStartedEvent.build(job))
event_name="model_install_running",
payload={"source": source},
)
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None: def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
""" """Emitted when an install job is completed successfully."""
Emit when an install job is completed successfully. self.dispatch(ModelInstallCompleteEvent.build(job))
:param source: Source of the model; local path, repo_id or url def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
:param key: Model config record key """Emitted when an install job is cancelled."""
:param total_bytes: Size of the model (may be None for installation of a local path) self.dispatch(ModelInstallCancelledEvent.build(job))
"""
self.__emit_model_event(
event_name="model_install_completed",
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
)
def emit_model_install_cancelled(self, source: str, id: int) -> None: def emit_model_install_error(self, job: "ModelInstallJob") -> None:
""" """Emitted when an install job encounters an exception."""
Emit when an install job is cancelled. self.dispatch(ModelInstallErrorEvent.build(job))
:param source: Source of the model; local path, repo_id or url # endregion
"""
self.__emit_model_event(
event_name="model_install_cancelled",
payload={"source": source, "id": id},
)
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None: # region Bulk image download
"""
Emit when an install job encounters an exception.
:param source: Source of the model
:param error_type: The name of the exception
:param error: A text description of the exception
"""
self.__emit_model_event(
event_name="model_install_error",
payload={"source": source, "error_type": error_type, "error": error, "id": id},
)
def emit_bulk_download_started( def emit_bulk_download_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None: ) -> None:
"""Emitted when a bulk download starts""" """Emitted when a bulk image download is started"""
self._emit_bulk_download_event( self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
event_name="bulk_download_started",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_completed( def emit_bulk_download_complete(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None: ) -> None:
"""Emitted when a bulk download completes""" """Emitted when a bulk image download is complete"""
self._emit_bulk_download_event( self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
event_name="bulk_download_completed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_failed( def emit_bulk_download_error(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> None: ) -> None:
"""Emitted when a bulk download fails""" """Emitted when a bulk image download has an error"""
self._emit_bulk_download_event( self.dispatch(
event_name="bulk_download_failed", BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
"error": error,
},
) )
# endregion

View File

@ -0,0 +1,592 @@
from math import floor
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
class EventBase(BaseModel):
"""Base class for all events. All events must inherit from this class.
Events must define a class attribute `__event_name__` to identify the event.
All other attributes should be defined as normal for a pydantic model.
A timestamp is automatically added to the event when it is created.
"""
__event_name__: ClassVar[str]
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
@classmethod
def get_events(cls) -> set[type["EventBase"]]:
"""Get a set of all event models."""
event_subclasses: set[type["EventBase"]] = set()
for subclass in cls.__subclasses__():
# We only want to include subclasses that are event models, not intermediary classes
if hasattr(subclass, "__event_name__"):
event_subclasses.add(subclass)
event_subclasses.update(subclass.get_events())
return event_subclasses
TEvent = TypeVar("TEvent", bound=EventBase, contravariant=True)
FastAPIEvent: TypeAlias = tuple[str, TEvent]
"""
A tuple representing a `fastapi-events` event, with the event name and payload.
Provide a generic type to `TEvent` to specify the payload type.
"""
class FastAPIEventFunc(Protocol, Generic[TEvent]):
def __call__(self, event: FastAPIEvent[TEvent]) -> Optional[Coroutine[Any, Any, None]]: ...
def register_events(events: set[type[TEvent]] | type[TEvent], func: FastAPIEventFunc[TEvent]) -> None:
"""Register a function to handle specific events.
:param events: An event or set of events to handle
:param func: The function to handle the events
"""
events = events if isinstance(events, set) else {events}
for event in events:
assert hasattr(event, "__event_name__")
local_handler.register(event_name=event.__event_name__, _func=func) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
class QueueEventBase(EventBase):
"""Base class for queue events"""
queue_id: str = Field(description="The ID of the queue")
class QueueItemEventBase(QueueEventBase):
"""Base class for queue item events"""
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
class InvocationEventBase(QueueItemEventBase):
"""Base class for invocation events"""
session_id: str = Field(description="The ID of the session (aka graph execution state)")
queue_id: str = Field(description="The ID of the queue")
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation: AnyInvocation = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@payload_schema.register
class InvocationStartedEvent(InvocationEventBase):
"""Event model for invocation_started"""
__event_name__ = "invocation_started"
@classmethod
def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
)
@payload_schema.register
class InvocationDenoiseProgressEvent(InvocationEventBase):
"""Event model for invocation_denoise_progress"""
__event_name__ = "invocation_denoise_progress"
progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
step: int = Field(description="The current step of the invocation")
total_steps: int = Field(description="The total number of steps in the invocation")
order: int = Field(description="The order of the invocation in the session")
percentage: float = Field(description="The percentage of completion of the invocation")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: AnyInvocation,
intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage,
) -> "InvocationDenoiseProgressEvent":
step = intermediate_state.step
total_steps = intermediate_state.total_steps
order = intermediate_state.order
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
progress_image=progress_image,
step=step,
total_steps=total_steps,
order=order,
percentage=cls.calc_percentage(step, total_steps, order),
)
@staticmethod
def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
"""Calculate the percentage of completion of denoising."""
if total_steps == 0:
return 0.0
if scheduler_order == 2:
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
# order == 1
return (step + 1 + 1) / (total_steps + 1)
@payload_schema.register
class InvocationCompleteEvent(InvocationEventBase):
"""Event model for invocation_complete"""
__event_name__ = "invocation_complete"
result: AnyInvocationOutput = Field(description="The result of the invocation")
@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
) -> "InvocationCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
result=result,
)
@payload_schema.register
class InvocationErrorEvent(InvocationEventBase):
"""Event model for invocation_error"""
__event_name__ = "invocation_error"
error_type: str = Field(description="The error type")
error_message: str = Field(description="The error message")
error_traceback: str = Field(description="The error traceback")
user_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
project_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: AnyInvocation,
error_type: str,
error_message: str,
error_traceback: str,
) -> "InvocationErrorEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
user_id=getattr(queue_item, "user_id", None),
project_id=getattr(queue_item, "project_id", None),
)
@payload_schema.register
class QueueItemStatusChangedEvent(QueueItemEventBase):
"""Event model for queue_item_status_changed"""
__event_name__ = "queue_item_status_changed"
status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item")
error_type: Optional[str] = Field(default=None, description="The error type, if any")
error_message: Optional[str] = Field(default=None, description="The error message, if any")
error_traceback: Optional[str] = Field(default=None, description="The error traceback, if any")
created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
batch_status: BatchStatus = Field(description="The status of the batch")
queue_status: SessionQueueStatus = Field(description="The status of the queue")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
@classmethod
def build(
cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus
) -> "QueueItemStatusChangedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
status=queue_item.status,
error_type=queue_item.error_type,
error_message=queue_item.error_message,
error_traceback=queue_item.error_traceback,
created_at=str(queue_item.created_at) if queue_item.created_at else None,
updated_at=str(queue_item.updated_at) if queue_item.updated_at else None,
started_at=str(queue_item.started_at) if queue_item.started_at else None,
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,
queue_status=queue_status,
)
@payload_schema.register
class BatchEnqueuedEvent(QueueEventBase):
"""Event model for batch_enqueued"""
__event_name__ = "batch_enqueued"
batch_id: str = Field(description="The ID of the batch")
enqueued: int = Field(description="The number of invocations enqueued")
requested: int = Field(
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
)
priority: int = Field(description="The priority of the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,
)
@payload_schema.register
class QueueClearedEvent(QueueEventBase):
"""Event model for queue_cleared"""
__event_name__ = "queue_cleared"
@classmethod
def build(cls, queue_id: str) -> "QueueClearedEvent":
return cls(queue_id=queue_id)
class DownloadEventBase(EventBase):
"""Base class for events associated with a download"""
source: str = Field(description="The source of the download")
@payload_schema.register
class DownloadStartedEvent(DownloadEventBase):
"""Event model for download_started"""
__event_name__ = "download_started"
download_path: str = Field(description="The local path where the download is saved")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadStartedEvent":
assert job.download_path
return cls(source=str(job.source), download_path=job.download_path.as_posix())
@payload_schema.register
class DownloadProgressEvent(DownloadEventBase):
"""Event model for download_progress"""
__event_name__ = "download_progress"
download_path: str = Field(description="The local path where the download is saved")
current_bytes: int = Field(description="The number of bytes downloaded so far")
total_bytes: int = Field(description="The total number of bytes to be downloaded")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadProgressEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
@payload_schema.register
class DownloadCompleteEvent(DownloadEventBase):
"""Event model for download_complete"""
__event_name__ = "download_complete"
download_path: str = Field(description="The local path where the download is saved")
total_bytes: int = Field(description="The total number of bytes downloaded")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent":
assert job.download_path
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
@payload_schema.register
class DownloadCancelledEvent(DownloadEventBase):
"""Event model for download_cancelled"""
__event_name__ = "download_cancelled"
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent":
return cls(source=str(job.source))
@payload_schema.register
class DownloadErrorEvent(DownloadEventBase):
"""Event model for download_error"""
__event_name__ = "download_error"
error_type: str = Field(description="The type of error")
error: str = Field(description="The error message")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadErrorEvent":
assert job.error_type
assert job.error
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
class ModelEventBase(EventBase):
"""Base class for events associated with a model"""
@payload_schema.register
class ModelLoadStartedEvent(ModelEventBase):
"""Event model for model_load_started"""
__event_name__ = "model_load_started"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelLoadCompleteEvent(ModelEventBase):
"""Event model for model_load_complete"""
__event_name__ = "model_load_complete"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelInstallDownloadProgressEvent(ModelEventBase):
"""Event model for model_install_download_progress"""
__event_name__ = "model_install_download_progress"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
local_path: str = Field(description="Where model is downloading to")
bytes: int = Field(description="Number of bytes downloaded so far")
total_bytes: int = Field(description="Total size of download, including all files")
parts: list[dict[str, int | str]] = Field(
description="Progress of downloading URLs that comprise the model, if any"
)
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent":
parts: list[dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
return cls(
id=job.id,
source=str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
)
@payload_schema.register
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
"""Emitted once when an install job becomes active."""
__event_name__ = "model_install_downloads_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallStartedEvent(ModelEventBase):
"""Event model for model_install_started"""
__event_name__ = "model_install_started"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallCompleteEvent(ModelEventBase):
"""Event model for model_install_complete"""
__event_name__ = "model_install_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
key: str = Field(description="Model config record key")
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
assert job.config_out is not None
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
@payload_schema.register
class ModelInstallCancelledEvent(ModelEventBase):
"""Event model for model_install_cancelled"""
__event_name__ = "model_install_cancelled"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallErrorEvent(ModelEventBase):
"""Event model for model_install_error"""
__event_name__ = "model_install_error"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
error_type: str = Field(description="The name of the exception")
error: str = Field(description="A text description of the exception")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent":
assert job.error_type is not None
assert job.error is not None
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
class BulkDownloadEventBase(EventBase):
"""Base class for events associated with a bulk image download"""
bulk_download_id: str = Field(description="The ID of the bulk image download")
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
@payload_schema.register
class BulkDownloadStartedEvent(BulkDownloadEventBase):
"""Event model for bulk_download_started"""
__event_name__ = "bulk_download_started"
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> "BulkDownloadStartedEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
@payload_schema.register
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
"""Event model for bulk_download_complete"""
__event_name__ = "bulk_download_complete"
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> "BulkDownloadCompleteEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
@payload_schema.register
class BulkDownloadErrorEvent(BulkDownloadEventBase):
"""Event model for bulk_download_error"""
__event_name__ = "bulk_download_error"
error: str = Field(description="The error message")
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> "BulkDownloadErrorEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=error,
)

View File

@ -0,0 +1,47 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_common import (
EventBase,
)
from .events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self._queue = Queue[EventBase | None]()
self._stop_event = threading.Event()
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self._stop_event.set()
self._queue.put(None)
def dispatch(self, event: EventBase) -> None:
self._queue.put(event)
async def _dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self._queue.get(block=False)
if not event: # Probably stopping
continue
# Leave the payloads as live pydantic models
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@ -4,9 +4,6 @@ from typing import Optional
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageFileStorageBase(ABC): class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files.""" """Low-level service responsible for storing and retrieving image files."""
@ -33,8 +30,9 @@ class ImageFileStorageBase(ABC):
self, self,
image: PILImageType, image: PILImageType,
image_name: str, image_name: str,
metadata: Optional[MetadataField] = None, metadata: Optional[str] = None,
workflow: Optional[WorkflowWithoutID] = None, workflow: Optional[str] = None,
graph: Optional[str] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp.""" """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -46,6 +44,11 @@ class ImageFileStorageBase(ABC):
pass pass
@abstractmethod @abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]: def get_workflow(self, image_name: str) -> Optional[str]:
"""Gets the workflow of an image.""" """Gets the workflow of an image."""
pass pass
@abstractmethod
def get_graph(self, image_name: str) -> Optional[str]:
"""Gets the graph of an image."""
pass

View File

@ -7,9 +7,7 @@ from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from send2trash import send2trash from send2trash import send2trash
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
from .image_files_base import ImageFileStorageBase from .image_files_base import ImageFileStorageBase
@ -56,8 +54,9 @@ class DiskImageFileStorage(ImageFileStorageBase):
self, self,
image: PILImageType, image: PILImageType,
image_name: str, image_name: str,
metadata: Optional[MetadataField] = None, metadata: Optional[str] = None,
workflow: Optional[WorkflowWithoutID] = None, workflow: Optional[str] = None,
graph: Optional[str] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
try: try:
@ -68,13 +67,14 @@ class DiskImageFileStorage(ImageFileStorageBase):
info_dict = {} info_dict = {}
if metadata is not None: if metadata is not None:
metadata_json = metadata.model_dump_json() info_dict["invokeai_metadata"] = metadata
info_dict["invokeai_metadata"] = metadata_json pnginfo.add_text("invokeai_metadata", metadata)
pnginfo.add_text("invokeai_metadata", metadata_json)
if workflow is not None: if workflow is not None:
workflow_json = workflow.model_dump_json() info_dict["invokeai_workflow"] = workflow
info_dict["invokeai_workflow"] = workflow_json pnginfo.add_text("invokeai_workflow", workflow)
pnginfo.add_text("invokeai_workflow", workflow_json) if graph is not None:
info_dict["invokeai_graph"] = graph
pnginfo.add_text("invokeai_graph", graph)
# When saving the image, the image object's info field is not populated. We need to set it # When saving the image, the image object's info field is not populated. We need to set it
image.info = info_dict image.info = info_dict
@ -129,11 +129,18 @@ class DiskImageFileStorage(ImageFileStorageBase):
path = path if isinstance(path, Path) else Path(path) path = path if isinstance(path, Path) else Path(path)
return path.exists() return path.exists()
def get_workflow(self, image_name: str) -> WorkflowWithoutID | None: def get_workflow(self, image_name: str) -> str | None:
image = self.get(image_name) image = self.get(image_name)
workflow = image.info.get("invokeai_workflow", None) workflow = image.info.get("invokeai_workflow", None)
if workflow is not None: if isinstance(workflow, str):
return WorkflowWithoutID.model_validate_json(workflow) return workflow
return None
def get_graph(self, image_name: str) -> str | None:
image = self.get(image_name)
graph = image.info.get("invokeai_graph", None)
if isinstance(graph, str):
return graph
return None return None
def __validate_storage_folders(self) -> None: def __validate_storage_folders(self) -> None:

View File

@ -80,7 +80,7 @@ class ImageRecordStorageBase(ABC):
starred: Optional[bool] = False, starred: Optional[bool] = False,
session_id: Optional[str] = None, session_id: Optional[str] = None,
node_id: Optional[str] = None, node_id: Optional[str] = None,
metadata: Optional[MetadataField] = None, metadata: Optional[str] = None,
) -> datetime: ) -> datetime:
"""Saves an image record.""" """Saves an image record."""
pass pass

View File

@ -328,10 +328,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
starred: Optional[bool] = False, starred: Optional[bool] = False,
session_id: Optional[str] = None, session_id: Optional[str] = None,
node_id: Optional[str] = None, node_id: Optional[str] = None,
metadata: Optional[MetadataField] = None, metadata: Optional[str] = None,
) -> datetime: ) -> datetime:
try: try:
metadata_json = metadata.model_dump_json() if metadata is not None else None
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
@ -358,7 +357,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
height, height,
node_id, node_id,
session_id, session_id,
metadata_json, metadata,
is_intermediate, is_intermediate,
starred, starred,
has_workflow, has_workflow,

View File

@ -12,7 +12,6 @@ from invokeai.app.services.image_records.image_records_common import (
) )
from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageServiceABC(ABC): class ImageServiceABC(ABC):
@ -51,8 +50,9 @@ class ImageServiceABC(ABC):
session_id: Optional[str] = None, session_id: Optional[str] = None,
board_id: Optional[str] = None, board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False, is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None, metadata: Optional[str] = None,
workflow: Optional[WorkflowWithoutID] = None, workflow: Optional[str] = None,
graph: Optional[str] = None,
) -> ImageDTO: ) -> ImageDTO:
"""Creates an image, storing the file and its metadata.""" """Creates an image, storing the file and its metadata."""
pass pass
@ -87,7 +87,12 @@ class ImageServiceABC(ABC):
pass pass
@abstractmethod @abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]: def get_workflow(self, image_name: str) -> Optional[str]:
"""Gets an image's workflow."""
pass
@abstractmethod
def get_graph(self, image_name: str) -> Optional[str]:
"""Gets an image's workflow.""" """Gets an image's workflow."""
pass pass

View File

@ -5,7 +5,6 @@ from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from ..image_files.image_files_common import ( from ..image_files.image_files_common import (
ImageFileDeleteException, ImageFileDeleteException,
@ -42,8 +41,9 @@ class ImageService(ImageServiceABC):
session_id: Optional[str] = None, session_id: Optional[str] = None,
board_id: Optional[str] = None, board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False, is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None, metadata: Optional[str] = None,
workflow: Optional[WorkflowWithoutID] = None, workflow: Optional[str] = None,
graph: Optional[str] = None,
) -> ImageDTO: ) -> ImageDTO:
if image_origin not in ResourceOrigin: if image_origin not in ResourceOrigin:
raise InvalidOriginException raise InvalidOriginException
@ -64,7 +64,7 @@ class ImageService(ImageServiceABC):
image_category=image_category, image_category=image_category,
width=width, width=width,
height=height, height=height,
has_workflow=workflow is not None, has_workflow=workflow is not None or graph is not None,
# Meta fields # Meta fields
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
# Nullable fields # Nullable fields
@ -75,7 +75,7 @@ class ImageService(ImageServiceABC):
if board_id is not None: if board_id is not None:
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name) self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
self.__invoker.services.image_files.save( self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
) )
image_dto = self.get_dto(image_name) image_dto = self.get_dto(image_name)
@ -157,7 +157,7 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image metadata") self.__invoker.services.logger.error("Problem getting image metadata")
raise e raise e
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]: def get_workflow(self, image_name: str) -> Optional[str]:
try: try:
return self.__invoker.services.image_files.get_workflow(image_name) return self.__invoker.services.image_files.get_workflow(image_name)
except ImageFileNotFoundException: except ImageFileNotFoundException:
@ -167,6 +167,16 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image workflow") self.__invoker.services.logger.error("Problem getting image workflow")
raise raise
def get_graph(self, image_name: str) -> Optional[str]:
try:
return self.__invoker.services.image_files.get_graph(image_name)
except ImageFileNotFoundException:
self.__invoker.services.logger.error("Image file not found")
raise
except Exception:
self.__invoker.services.logger.error("Problem getting image graph")
raise
def get_path(self, image_name: str, thumbnail: bool = False) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try: try:
return str(self.__invoker.services.image_files.get_path(image_name, thumbnail)) return str(self.__invoker.services.image_files.get_path(image_name, thumbnail))

View File

@ -24,6 +24,7 @@ if TYPE_CHECKING:
from .image_records.image_records_base import ImageRecordStorageBase from .image_records.image_records_base import ImageRecordStorageBase
from .images.images_base import ImageServiceABC from .images.images_base import ImageServiceABC
from .invocation_cache.invocation_cache_base import InvocationCacheBase from .invocation_cache.invocation_cache_base import InvocationCacheBase
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .model_images.model_images_base import ModelImageFileStorageBase from .model_images.model_images_base import ModelImageFileStorageBase
from .model_manager.model_manager_base import ModelManagerServiceBase from .model_manager.model_manager_base import ModelManagerServiceBase
from .names.names_base import NameServiceBase from .names.names_base import NameServiceBase
@ -56,6 +57,7 @@ class InvocationServices:
session_processor: "SessionProcessorBase", session_processor: "SessionProcessorBase",
invocation_cache: "InvocationCacheBase", invocation_cache: "InvocationCacheBase",
names: "NameServiceBase", names: "NameServiceBase",
performance_statistics: "InvocationStatsServiceBase",
urls: "UrlServiceBase", urls: "UrlServiceBase",
workflow_records: "WorkflowRecordsStorageBase", workflow_records: "WorkflowRecordsStorageBase",
tensors: "ObjectSerializerBase[torch.Tensor]", tensors: "ObjectSerializerBase[torch.Tensor]",
@ -79,6 +81,7 @@ class InvocationServices:
self.session_processor = session_processor self.session_processor = session_processor
self.invocation_cache = invocation_cache self.invocation_cache = invocation_cache
self.names = names self.names = names
self.performance_statistics = performance_statistics
self.urls = urls self.urls = urls
self.workflow_records = workflow_records self.workflow_records = workflow_records
self.tensors = tensors self.tensors = tensors

View File

@ -1,11 +1,13 @@
"""Initialization file for model install service package.""" """Initialization file for model install service package."""
from .model_install_base import ( from .model_install_base import (
ModelInstallServiceBase,
)
from .model_install_common import (
HFModelSource, HFModelSource,
InstallStatus, InstallStatus,
LocalModelSource, LocalModelSource,
ModelInstallJob, ModelInstallJob,
ModelInstallServiceBase,
ModelSource, ModelSource,
UnknownInstallJobException, UnknownInstallJobException,
URLModelSource, URLModelSource,

View File

@ -1,244 +1,19 @@
# Copyright 2023 Lincoln D. Stein and the InvokeAI development team # Copyright 2023 Lincoln D. Stein and the InvokeAI development team
"""Baseclass definitions for the model installer.""" """Baseclass definitions for the model installer."""
import re
import traceback
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Union from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources, implements functions that lets the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f":{self.subfolder}"
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: int = Field(
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
class ModelInstallServiceBase(ABC): class ModelInstallServiceBase(ABC):
@ -282,7 +57,7 @@ class ModelInstallServiceBase(ABC):
@property @property
@abstractmethod @abstractmethod
def event_bus(self) -> Optional[EventServiceBase]: def event_bus(self) -> Optional["EventServiceBase"]:
"""Return the event service base object associated with the installer.""" """Return the event service base object associated with the installer."""
@abstractmethod @abstractmethod

View File

@ -0,0 +1,233 @@
import re
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources, implements functions that lets the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f":{self.subfolder}"
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: int = Field(
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]

View File

@ -10,7 +10,7 @@ from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch import torch
import yaml import yaml
@ -20,8 +20,8 @@ from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
@ -45,13 +45,12 @@ from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from .model_install_base import ( from .model_install_common import (
MODEL_SOURCE_TO_TYPE_MAP, MODEL_SOURCE_TO_TYPE_MAP,
HFModelSource, HFModelSource,
InstallStatus, InstallStatus,
LocalModelSource, LocalModelSource,
ModelInstallJob, ModelInstallJob,
ModelInstallServiceBase,
ModelSource, ModelSource,
StringLikeSource, StringLikeSource,
URLModelSource, URLModelSource,
@ -59,6 +58,9 @@ from .model_install_base import (
TMPDIR_PREFIX = "tmpinstall_" TMPDIR_PREFIX = "tmpinstall_"
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
class ModelInstallService(ModelInstallServiceBase): class ModelInstallService(ModelInstallServiceBase):
"""class for InvokeAI model installation.""" """class for InvokeAI model installation."""
@ -68,7 +70,7 @@ class ModelInstallService(ModelInstallServiceBase):
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase, record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase, download_queue: DownloadQueueServiceBase,
event_bus: Optional[EventServiceBase] = None, event_bus: Optional["EventServiceBase"] = None,
session: Optional[Session] = None, session: Optional[Session] = None,
): ):
""" """
@ -104,7 +106,7 @@ class ModelInstallService(ModelInstallServiceBase):
return self._record_store return self._record_store
@property @property
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102 def event_bus(self) -> Optional["EventServiceBase"]: # noqa D102
return self._event_bus return self._event_bus
# make the invoker optional here because we don't need it and it # make the invoker optional here because we don't need it and it
@ -855,35 +857,17 @@ class ModelInstallService(ModelInstallServiceBase):
job.status = InstallStatus.RUNNING job.status = InstallStatus.RUNNING
self._logger.info(f"Model install started: {job.source}") self._logger.info(f"Model install started: {job.source}")
if self._event_bus: if self._event_bus:
self._event_bus.emit_model_install_running(str(job.source)) self._event_bus.emit_model_install_started(job)
def _signal_job_downloading(self, job: ModelInstallJob) -> None: def _signal_job_downloading(self, job: ModelInstallJob) -> None:
if self._event_bus: if self._event_bus:
parts: List[Dict[str, str | int]] = [ self._event_bus.emit_model_install_download_progress(job)
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
assert job.bytes is not None
assert job.total_bytes is not None
self._event_bus.emit_model_install_downloading(
str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
id=job.id,
)
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None: def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.DOWNLOADS_DONE job.status = InstallStatus.DOWNLOADS_DONE
self._logger.info(f"Model download complete: {job.source}") self._logger.info(f"Model download complete: {job.source}")
if self._event_bus: if self._event_bus:
self._event_bus.emit_model_install_downloads_done(str(job.source)) self._event_bus.emit_model_install_downloads_complete(job)
def _signal_job_completed(self, job: ModelInstallJob) -> None: def _signal_job_completed(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.COMPLETED job.status = InstallStatus.COMPLETED
@ -891,24 +875,19 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"Model install complete: {job.source}") self._logger.info(f"Model install complete: {job.source}")
self._logger.debug(f"{job.local_path} registered key {job.config_out.key}") self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
if self._event_bus: if self._event_bus:
assert job.local_path is not None self._event_bus.emit_model_install_complete(job)
assert job.config_out is not None
key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
def _signal_job_errored(self, job: ModelInstallJob) -> None: def _signal_job_errored(self, job: ModelInstallJob) -> None:
self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}") self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
if self._event_bus: if self._event_bus:
error_type = job.error_type assert job.error_type is not None
error = job.error assert job.error is not None
assert error_type is not None self._event_bus.emit_model_install_error(job)
assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
def _signal_job_cancelled(self, job: ModelInstallJob) -> None: def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
self._logger.info(f"Model install canceled: {job.source}") self._logger.info(f"Model install canceled: {job.source}")
if self._event_bus: if self._event_bus:
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id) self._event_bus.emit_model_install_cancelled(job)
@staticmethod @staticmethod
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase: def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:

View File

@ -4,7 +4,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
@ -15,18 +14,12 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader.""" """Wrapper around AnyModelLoader."""
@abstractmethod @abstractmethod
def load_model( def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
""" """
Given a model's configuration, load it and return the LoadedModel object. Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch. :param submodel: For main (pipeline models), the submodel to fetch.
:param context_data: Invocation context data used for event reporting
""" """
@property @property

View File

@ -5,7 +5,6 @@ from typing import Optional, Type
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import ( from invokeai.backend.model_manager.load import (
LoadedModel, LoadedModel,
@ -57,25 +56,18 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the checkpoint convert cache used by this loader.""" """Return the checkpoint convert cache used by this loader."""
return self._convert_cache return self._convert_cache
def load_model( def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
""" """
Given a model's configuration, load it and return the LoadedModel object. Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch. :param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
""" """
if context_data:
self._emit_load_event( # We don't have an invoker during testing
context_data=context_data, # TODO(psyche): Mock this method on the invoker in the tests
model_config=model_config, if hasattr(self, "_invoker"):
submodel_type=submodel_type, self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
)
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model: LoadedModel = implementation( loaded_model: LoadedModel = implementation(
@ -85,40 +77,7 @@ class ModelLoadService(ModelLoadServiceBase):
convert_cache=self._convert_cache, convert_cache=self._convert_cache,
).load_model(model_config, submodel_type) ).load_model(model_config, submodel_type)
if context_data: if hasattr(self, "_invoker"):
self._emit_load_event( self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
loaded=True,
)
return loaded_model return loaded_model
def _emit_load_event(
self,
context_data: InvocationContextData,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
submodel_type: Optional[SubModelType] = None,
) -> None:
if not self._invoker:
return
if not loaded:
self._invoker.services.events.emit_model_load_started(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)
else:
self._invoker.services.events.emit_model_load_completed(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)

View File

@ -1,6 +1,49 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from threading import Event
from typing import Optional, Protocol
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.util.profiler import Profiler
class SessionRunnerBase(ABC):
"""
Base class for session runner.
"""
@abstractmethod
def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None:
"""Starts the session runner.
Args:
services: The invocation services.
cancel_event: The cancel event.
profiler: The profiler to use for session profiling via cProfile. Omit to disable profiling. Basic session
stats will be still be recorded and logged when profiling is disabled.
"""
pass
@abstractmethod
def run(self, queue_item: SessionQueueItem) -> None:
"""Runs a session.
Args:
queue_item: The session to run.
"""
pass
@abstractmethod
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
"""Run a single node in the graph.
Args:
invocation: The invocation to run.
queue_item: The session queue item.
"""
pass
class SessionProcessorBase(ABC): class SessionProcessorBase(ABC):
@ -26,3 +69,85 @@ class SessionProcessorBase(ABC):
def get_status(self) -> SessionProcessorStatus: def get_status(self) -> SessionProcessorStatus:
"""Gets the status of the session processor""" """Gets the status of the session processor"""
pass pass
class OnBeforeRunNode(Protocol):
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
"""Callback to run before executing a node.
Args:
invocation: The invocation that will be executed.
queue_item: The session queue item.
"""
...
class OnAfterRunNode(Protocol):
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput) -> None:
"""Callback to run before executing a node.
Args:
invocation: The invocation that was executed.
queue_item: The session queue item.
"""
...
class OnNodeError(Protocol):
def __call__(
self,
invocation: BaseInvocation,
queue_item: SessionQueueItem,
error_type: str,
error_message: str,
error_traceback: str,
) -> None:
"""Callback to run when a node has an error.
Args:
invocation: The invocation that errored.
queue_item: The session queue item.
error_type: The type of error, e.g. "ValueError".
error_message: The error message, e.g. "Invalid value".
error_traceback: The stringified error traceback.
"""
...
class OnBeforeRunSession(Protocol):
def __call__(self, queue_item: SessionQueueItem) -> None:
"""Callback to run before executing a session.
Args:
queue_item: The session queue item.
"""
...
class OnAfterRunSession(Protocol):
def __call__(self, queue_item: SessionQueueItem) -> None:
"""Callback to run after executing a session.
Args:
queue_item: The session queue item.
"""
...
class OnNonFatalProcessorError(Protocol):
def __call__(
self,
queue_item: Optional[SessionQueueItem],
error_type: str,
error_message: str,
error_traceback: str,
) -> None:
"""Callback to run when a non-fatal error occurs in the processor.
Args:
queue_item: The session queue item, if one was being executed when the error occurred.
error_type: The type of error, e.g. "ValueError".
error_message: The error message, e.g. "Invalid value".
error_traceback: The stringified error traceback.
"""
...

View File

@ -1,33 +1,334 @@
import traceback import traceback
from contextlib import suppress from contextlib import suppress
from queue import Queue from queue import Queue
from threading import BoundedSemaphore, Lock, Thread from threading import BoundedSemaphore, Thread, Lock
from threading import Event as ThreadEvent from threading import Event as ThreadEvent
from typing import Optional, Set from typing import Optional, Set
from fastapi_events.handlers.local import local_handler from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from fastapi_events.typing import Event as FastAPIEvent from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
from invokeai.app.invocations.baseinvocation import BaseInvocation FastAPIEvent,
from invokeai.app.services.events.events_base import EventServiceBase QueueClearedEvent,
QueueItemStatusChangedEvent,
register_events,
)
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService from invokeai.app.services.session_processor.session_processor_base import (
OnAfterRunNode,
OnAfterRunSession,
OnBeforeRunNode,
OnBeforeRunSession,
OnNodeError,
OnNonFatalProcessorError,
)
from invokeai.app.services.session_processor.session_processor_common import CanceledException from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem, SessionQueueItemNotFoundError
from invokeai.app.services.shared.graph import NodeInputError
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler from invokeai.app.util.profiler import Profiler
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from ..invoker import Invoker from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
from .session_processor_common import SessionProcessorStatus from .session_processor_common import SessionProcessorStatus
class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations."""
def __init__(
self,
on_before_run_session_callbacks: Optional[list[OnBeforeRunSession]] = None,
on_before_run_node_callbacks: Optional[list[OnBeforeRunNode]] = None,
on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None,
on_node_error_callbacks: Optional[list[OnNodeError]] = None,
on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
):
"""
Args:
on_before_run_session_callbacks: Callbacks to run before the session starts.
on_before_run_node_callbacks: Callbacks to run before each node starts.
on_after_run_node_callbacks: Callbacks to run after each node completes.
on_node_error_callbacks: Callbacks to run when a node errors.
on_after_run_session_callbacks: Callbacks to run after the session completes.
"""
self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
self._on_before_run_node_callbacks = on_before_run_node_callbacks or []
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
self._on_node_error_callbacks = on_node_error_callbacks or []
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
self._process_lock = Lock()
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None) -> None:
self._services = services
self._cancel_event = cancel_event
self._profiler = profiler
def _is_canceled(self) -> bool:
"""Check if the cancel event is set. This is also passed to the invocation context builder and called during
denoising to check if the session has been canceled."""
return self._cancel_event.is_set()
def run(self, queue_item: SessionQueueItem):
# Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here.
self._on_before_run_session(queue_item=queue_item)
# Loop over invocations until the session is complete or canceled
while True:
try:
with self._process_lock:
invocation = queue_item.session.next()
# Anything other than a `NodeInputError` is handled as a processor error
except NodeInputError as e:
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_node_error(
invocation=e.node,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
break
if invocation is None or self._is_canceled():
break
self.run_node(invocation, queue_item)
# The session is complete if all invocations have been run or there is an error on the session.
# At this time, the queue item may be canceled, but the object itself here won't be updated yet. We must
# use the cancel event to check if the session is canceled.
if (
queue_item.session.is_complete()
or self._is_canceled()
or queue_item.status in ["failed", "canceled", "completed"]
):
break
self._on_after_run_session(queue_item=queue_item)
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
try:
# Any unhandled exception in this scope is an invocation error & will fail the graph
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
self._on_before_run_node(invocation, queue_item)
data = InvocationContextData(
invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
context = build_invocation_context(
data=data,
services=self._services,
is_canceled=self._is_canceled,
)
# Invoke the node
output = invocation.invoke_internal(context=context, services=self._services)
# Save output and history
queue_item.session.complete(invocation.id, output)
self._on_after_run_node(invocation, queue_item, output)
except KeyboardInterrupt:
# TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here?
pass
except CanceledException:
# A CanceledException is raised during the denoising step callback if the cancel event is set. We don't need
# to do any handling here, and no error should be set - just pass and the cancellation will be handled
# correctly in the next iteration of the session runner loop.
#
# See the comment in the processor's `_on_queue_item_status_changed()` method for more details on how we
# handle cancellation.
pass
except Exception as e:
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_node_error(
invocation=invocation,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
"""Called before a session is run.
- Start the profiler if profiling is enabled.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On before run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
)
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=queue_item.session_id)
for callback in self._on_before_run_session_callbacks:
callback(queue_item=queue_item)
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
"""Called after a session is run.
- Stop the profiler if profiling is enabled.
- Update the queue item's session object in the database.
- If not already canceled or failed, complete the queue item.
- Log and reset performance statistics.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On after run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler is not None:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._services.performance_statistics.dump_stats(
graph_execution_state_id=queue_item.session.id, output_path=stats_path
)
try:
# Update the queue item with the completed session. If the queue item has been removed from the queue,
# we'll get a SessionQueueItemNotFoundError and we can ignore it. This can happen if the queue is cleared
# while the session is running.
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
# The queue item may have been canceled or failed while the session was running. We should only complete it
# if it is not already canceled or failed.
if queue_item.status not in ["canceled", "failed"]:
queue_item = self._services.session_queue.complete_queue_item(queue_item.item_id)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._services.performance_statistics.log_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats()
for callback in self._on_after_run_session_callbacks:
callback(queue_item=queue_item)
except SessionQueueItemNotFoundError:
pass
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Called before a node is run.
- Emits an invocation started event.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On before run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
)
# Send starting event
self._services.events.emit_invocation_started(queue_item=queue_item, invocation=invocation)
for callback in self._on_before_run_node_callbacks:
callback(invocation=invocation, queue_item=queue_item)
def _on_after_run_node(
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
):
"""Called after a node is run.
- Emits an invocation complete event.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On after run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
)
# Send complete event on successful runs
self._services.events.emit_invocation_complete(invocation=invocation, queue_item=queue_item, output=output)
for callback in self._on_after_run_node_callbacks:
callback(invocation=invocation, queue_item=queue_item, output=output)
def _on_node_error(
self,
invocation: BaseInvocation,
queue_item: SessionQueueItem,
error_type: str,
error_message: str,
error_traceback: str,
):
"""Called when a node errors. Node errors may occur when running or preparing the node..
- Set the node error on the session object.
- Log the error.
- Fail the queue item.
- Emits an invocation error event.
- Run any callbacks registered for this event.
"""
self._services.logger.debug(
f"On node error: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
)
# Node errors do not get the full traceback. Only the queue item gets the full traceback.
node_error = f"{error_type}: {error_message}"
queue_item.session.set_node_error(invocation.id, node_error)
self._services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {error_message}"
)
self._services.logger.error(error_traceback)
# Fail the queue item
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
queue_item = self._services.session_queue.fail_queue_item(
queue_item.item_id, error_type, error_message, error_traceback
)
# Send error event
self._services.events.emit_invocation_error(
queue_item=queue_item,
invocation=invocation,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
for callback in self._on_node_error_callbacks:
callback(
invocation=invocation,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
class DefaultSessionProcessor(SessionProcessorBase): class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker, polling_interval: int = 1) -> None: def __init__(
self,
session_runner: Optional[SessionRunnerBase] = None,
on_non_fatal_processor_error_callbacks: Optional[list[OnNonFatalProcessorError]] = None,
thread_limit: int = 1,
polling_interval: int = 1,
) -> None:
super().__init__()
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or []
self._thread_limit = thread_limit
self._polling_interval = polling_interval
def start(self, invoker: Invoker) -> None:
self._invoker: Invoker = invoker self._invoker: Invoker = invoker
self._queue_items: Set[int] = set() self._active_queue_items: Set[SessionQueueItem] = set()
self._sessions_to_cancel: Set[int] = set()
self._invocation: Optional[BaseInvocation] = None self._invocation: Optional[BaseInvocation] = None
self._resume_event = ThreadEvent() self._resume_event = ThreadEvent()
@ -35,17 +336,11 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now_event = ThreadEvent() self._poll_now_event = ThreadEvent()
self._cancel_event = ThreadEvent() self._cancel_event = ThreadEvent()
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event) register_events(QueueClearedEvent, self._on_queue_cleared)
register_events(BatchEnqueuedEvent, self._on_batch_enqueued)
register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed)
self._thread_limit = 1
self._thread_semaphore = BoundedSemaphore(self._thread_limit) self._thread_semaphore = BoundedSemaphore(self._thread_limit)
self._polling_interval = polling_interval
self._worker_thread_count = self._invoker.services.configuration.max_threads or len(
TorchDevice.execution_devices()
)
self._session_worker_queue: Queue[SessionQueueItem] = Queue()
self._process_lock = Lock()
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
# the profiler will create a new profile for each session. # the profiler will create a new profile for each session.
@ -59,7 +354,14 @@ class DefaultSessionProcessor(SessionProcessorBase):
else None else None
) )
# main session processor loop - single thread self._worker_thread_count = self._invoker.services.configuration.max_threads or len(
TorchDevice.execution_devices()
)
self._session_worker_queue: Queue[SessionQueueItem] = Queue()
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
# Session processor - singlethreaded
self._thread = Thread( self._thread = Thread(
name="session_processor", name="session_processor",
target=self._process, target=self._process,
@ -82,31 +384,33 @@ class DefaultSessionProcessor(SessionProcessorBase):
) )
worker.start() worker.start()
def stop(self, *args, **kwargs) -> None: def stop(self, *args, **kwargs) -> None:
self._stop_event.set() self._stop_event.set()
def _poll_now(self) -> None: def _poll_now(self) -> None:
self._poll_now_event.set() self._poll_now_event.set()
async def _on_queue_event(self, event: FastAPIEvent) -> None: async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
event_name = event[1]["event"] if any(item.queue_id == event[1].queue_id for item in self._active_queue_items):
self._cancel_event.set()
self._poll_now()
if event_name == "session_canceled" and event[1]["data"]["queue_item_id"] in self._queue_items: async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None:
self._sessions_to_cancel.add(event[1]["data"]["queue_item_id"]) self._poll_now()
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
if self._active_queue_items and event[1].status in ["completed", "failed", "canceled"]:
# When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
# emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
# event, which the session runner checks between invocations. If set, the session runner loop is broken.
#
# Long-running nodes that cannot be interrupted easily present a challenge. `denoise_latents` is one such
# node, but it gets a step callback, called on each step of denoising. This callback checks if the queue item
# is canceled, and if it is, raises a `CanceledException` to stop execution immediately.
if event[1].status == "canceled":
self._cancel_event.set() self._cancel_event.set()
self._poll_now() self._poll_now()
elif event_name == "queue_cleared" and event[1]["data"]["queue_id"] in self._queue_items:
self._sessions_to_cancel.add(event[1]["data"]["queue_item_id"])
self._cancel_event.set()
self._poll_now()
elif event_name == "batch_enqueued":
self._poll_now()
elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [
"completed",
"failed",
"canceled",
]:
self._poll_now()
def resume(self) -> SessionProcessorStatus: def resume(self) -> SessionProcessorStatus:
if not self._resume_event.is_set(): if not self._resume_event.is_set():
@ -121,7 +425,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
def get_status(self) -> SessionProcessorStatus: def get_status(self) -> SessionProcessorStatus:
return SessionProcessorStatus( return SessionProcessorStatus(
is_started=self._resume_event.is_set(), is_started=self._resume_event.is_set(),
is_processing=len(self._queue_items) > 0, is_processing=len(self._active_queue_items) > 0,
) )
def _process( def _process(
@ -130,9 +434,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event: ThreadEvent, poll_now_event: ThreadEvent,
resume_event: ThreadEvent, resume_event: ThreadEvent,
cancel_event: ThreadEvent, cancel_event: ThreadEvent,
) -> None: ):
# Outermost processor try block; any unhandled exception is a fatal processor error
try: try:
# Any unhandled exception in this block is a fatal processor error and will stop the processor.
self._thread_semaphore.acquire() self._thread_semaphore.acquire()
stop_event.clear() stop_event.clear()
resume_event.set() resume_event.set()
@ -140,198 +444,94 @@ class DefaultSessionProcessor(SessionProcessorBase):
while not stop_event.is_set(): while not stop_event.is_set():
poll_now_event.clear() poll_now_event.clear()
try:
# Any unhandled exception in this block is a nonfatal processor error and will be handled.
# If we are paused, wait for resume event
resume_event.wait() resume_event.wait()
# Get the next session to process # Get the next session to process
session = self._invoker.services.session_queue.dequeue() queue_item = self._invoker.services.session_queue.dequeue()
if session is None: if queue_item is None:
# The queue was empty, wait for next polling interval or event to try again # The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event") self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval) poll_now_event.wait(self._polling_interval)
continue continue
self._queue_items.add(session.item_id) self._session_worker_queue.put(queue_item)
self._session_worker_queue.put(session) self._invoker.services.logger.debug(f"Scheduling queue item {queue_item.item_id} to run")
self._invoker.services.logger.debug(f"Executing queue item {session.item_id}")
cancel_event.clear() cancel_event.clear()
except Exception:
# Run the graph
# self.session_runner.run(queue_item=self._queue_item)
except Exception as e:
# Wait for next polling interval or event to try again
poll_now_event.wait(self._polling_interval)
continue
except Exception as e:
# Fatal error in processor, log and pass - we're done here # Fatal error in processor, log and pass - we're done here
self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}") error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._invoker.services.logger.error(f"Fatal Error in session processor {error_type}: {error_message}")
self._invoker.services.logger.error(error_traceback)
pass pass
finally: finally:
stop_event.clear() stop_event.clear()
poll_now_event.clear() poll_now_event.clear()
self._queue_items.clear()
self._thread_semaphore.release() self._thread_semaphore.release()
def _process_next_session(self) -> None: def _process_next_session(self) -> None:
profiler = (
Profiler(
logger=self._invoker.services.logger,
output_dir=self._invoker.services.configuration.profiles_path,
prefix=self._invoker.services.configuration.profile_prefix,
)
if self._invoker.services.configuration.profile_graphs
else None
)
stats_service = InvocationStatsService()
stats_service.start(self._invoker)
while True: while True:
# Outer try block. Any error here is a fatal processor error
try:
self._resume_event.wait() self._resume_event.wait()
session = self._session_worker_queue.get() queue_item = self._session_worker_queue.get()
if queue_item.status == "canceled":
if self._cancel_event.is_set():
if session.item_id in self._sessions_to_cancel:
continue continue
try:
if profiler is not None: self._active_queue_items.add(queue_item)
profiler.start(profile_id=session.session_id)
# reserve a GPU for this session - may block # reserve a GPU for this session - may block
with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device(): with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device():
# Prepare invocations and take the first # Run the session on the reserved GPU
with self._process_lock: self.session_runner.run(queue_item=queue_item)
invocation = session.session.next()
# Loop over invocations until the session is complete or canceled
while invocation is not None and not self._cancel_event.is_set():
self._process_next_invocation(session, invocation, stats_service)
# The session is complete if all invocations are complete or there was an error
if session.session.is_complete():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=session.batch_id,
queue_item_id=session.item_id,
queue_id=session.queue_id,
graph_execution_state_id=session.session.id,
)
# Log stats
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
stats_service.log_stats(session.session.id)
stats_service.reset_stats()
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
stats_service.dump_stats(
graph_execution_state_id=session.session.id, output_path=stats_path
)
self._queue_items.remove(session.item_id)
invocation = None
else:
# Prepare the next invocation
with self._process_lock:
invocation = session.session.next()
except Exception:
# Non-fatal error in processor
self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{traceback.format_exc()}")
# Cancel the queue item
if session is not None:
self._invoker.services.session_queue.cancel_queue_item(
session.item_id, error=traceback.format_exc()
)
finally:
self._session_worker_queue.task_done()
def _process_next_invocation(
self,
session: SessionQueueItem,
invocation: BaseInvocation,
stats_service: InvocationStatsService,
) -> None:
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = session.session.prepared_source_mapping[invocation.id]
self._invoker.services.logger.debug(f"Executing invocation {session.session.id}:{source_invocation_id}")
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=session.batch_id,
queue_item_id=session.item_id,
queue_id=session.queue_id,
graph_execution_state_id=session.session_id,
node=invocation.model_dump(),
source_node_id=source_invocation_id,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=invocation,
source_invocation_id=source_invocation_id,
queue_item=session,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# Invoke the node
# title = invocation.UIConfig.title
with stats_service.collect_stats(invocation, session.session.id):
outputs = invocation.invoke_internal(context=context, services=self._invoker.services)
# Save outputs and history
session.session.complete(invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=session.batch_id,
queue_item_id=session.item_id,
queue_id=session.queue_id,
graph_execution_state_id=session.session.id,
node=invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e: except Exception as e:
error = traceback.format_exc() continue
finally:
self._active_queue_items.remove(queue_item)
# Save error def _on_non_fatal_processor_error(
session.session.set_node_error(invocation.id, error) self,
self._invoker.services.logger.error( queue_item: Optional[SessionQueueItem],
f"Error while invoking session {session.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}" error_type: str,
) error_message: str,
self._invoker.services.logger.error(error) error_traceback: str,
) -> None:
"""Called when a non-fatal error occurs in the processor.
# Send error event - Log the error.
self._invoker.services.events.emit_invocation_error( - If a queue item is provided, update the queue item with the completed session & fail it.
queue_batch_id=session.session_id, - Run any callbacks registered for this event.
queue_item_id=session.item_id, """
queue_id=session.queue_id,
graph_execution_state_id=session.session.id, self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}")
node=invocation.model_dump(), self._invoker.services.logger.error(error_traceback)
source_node_id=source_invocation_id,
error_type=e.__class__.__name__, if queue_item is not None:
error=error, # Update the queue item with the completed session & fail it
queue_item = self._invoker.services.session_queue.set_queue_item_session(
queue_item.item_id, queue_item.session
)
queue_item = self._invoker.services.session_queue.fail_queue_item(
item_id=queue_item.item_id,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
for callback in self._on_non_fatal_processor_error_callbacks:
callback(
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
) )

View File

@ -16,6 +16,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemDTO, SessionQueueItemDTO,
SessionQueueStatus, SessionQueueStatus,
) )
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults from invokeai.app.services.shared.pagination import CursorPaginatedResults
@ -73,10 +74,22 @@ class SessionQueueBase(ABC):
pass pass
@abstractmethod @abstractmethod
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem: def complete_queue_item(self, item_id: int) -> SessionQueueItem:
"""Completes a session queue item"""
pass
@abstractmethod
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
"""Cancels a session queue item""" """Cancels a session queue item"""
pass pass
@abstractmethod
def fail_queue_item(
self, item_id: int, error_type: str, error_message: str, error_traceback: str
) -> SessionQueueItem:
"""Fails a session queue item"""
pass
@abstractmethod @abstractmethod
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult: def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
"""Cancels all queue items with matching batch IDs""" """Cancels all queue items with matching batch IDs"""
@ -103,3 +116,8 @@ class SessionQueueBase(ABC):
def get_queue_item(self, item_id: int) -> SessionQueueItem: def get_queue_item(self, item_id: int) -> SessionQueueItem:
"""Gets a session queue item by ID""" """Gets a session queue item by ID"""
pass pass
@abstractmethod
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
"""Sets the session for a session queue item. Use this to update the session state."""
pass

View File

@ -3,7 +3,16 @@ import json
from itertools import chain, product from itertools import chain, product
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
from pydantic import BaseModel, ConfigDict, Field, StrictStr, TypeAdapter, field_validator, model_validator from pydantic import (
AliasChoices,
BaseModel,
ConfigDict,
Field,
StrictStr,
TypeAdapter,
field_validator,
model_validator,
)
from pydantic_core import to_jsonable_python from pydantic_core import to_jsonable_python
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
@ -189,7 +198,13 @@ class SessionQueueItemWithoutGraph(BaseModel):
session_id: str = Field( session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed." description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
) )
error: Optional[str] = Field(default=None, description="The error message if this queue item errored") error_type: Optional[str] = Field(default=None, description="The error type if this queue item errored")
error_message: Optional[str] = Field(default=None, description="The error message if this queue item errored")
error_traceback: Optional[str] = Field(
default=None,
description="The error traceback if this queue item errored",
validation_alias=AliasChoices("error_traceback", "error"),
)
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created") created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated") updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started") started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
@ -221,6 +236,8 @@ class SessionQueueItemWithoutGraph(BaseModel):
} }
) )
def __hash__(self) -> int:
return self.item_id
class SessionQueueItemDTO(SessionQueueItemWithoutGraph): class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
pass pass

View File

@ -2,10 +2,6 @@ import sqlite3
import threading import threading
from typing import Optional, Union, cast from typing import Optional, Union, cast
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.session_queue.session_queue_common import ( from invokeai.app.services.session_queue.session_queue_common import (
@ -27,6 +23,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
calc_session_count, calc_session_count,
prepare_values_to_insert, prepare_values_to_insert,
) )
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@ -41,7 +38,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__invoker = invoker self.__invoker = invoker
self._set_in_progress_to_canceled() self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID) prune_result = self.prune(DEFAULT_QUEUE_ID)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
if prune_result.deleted > 0: if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items") self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
@ -51,52 +48,6 @@ class SqliteSessionQueue(SessionQueueBase):
self.__conn = db.conn self.__conn = db.conn
self.__cursor = self.__conn.cursor() self.__cursor = self.__conn.cursor()
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
return event[1]["event"] in match_in
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
event_name = event[1]["event"]
# This was a match statement, but match is not supported on python 3.9
if event_name == "graph_execution_state_complete":
await self._handle_complete_event(event)
elif event_name == "invocation_error":
await self._handle_error_event(event)
elif event_name == "session_canceled":
await self._handle_cancel_event(event)
return event
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
# When a queue item has an error, we get an error event, then a completed event.
# Mark the queue item completed only if it isn't already marked completed, e.g.
# by a previously-handled error event.
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
except SessionQueueItemNotFoundError:
return
async def _handle_error_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
error = event[1]["data"]["error"]
queue_item = self.get_queue_item(item_id)
# always set to failed if have an error, even if previously the item was marked completed or canceled
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
except SessionQueueItemNotFoundError:
return
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
except SessionQueueItemNotFoundError:
return
def _set_in_progress_to_canceled(self) -> None: def _set_in_progress_to_canceled(self) -> None:
""" """
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue. Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
@ -271,17 +222,22 @@ class SqliteSessionQueue(SessionQueueBase):
return SessionQueueItem.queue_item_from_dict(dict(result)) return SessionQueueItem.queue_item_from_dict(dict(result))
def _set_queue_item_status( def _set_queue_item_status(
self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None self,
item_id: int,
status: QUEUE_ITEM_STATUS,
error_type: Optional[str] = None,
error_message: Optional[str] = None,
error_traceback: Optional[str] = None,
) -> SessionQueueItem: ) -> SessionQueueItem:
try: try:
self.__lock.acquire() self.__lock.acquire()
self.__cursor.execute( self.__cursor.execute(
"""--sql """--sql
UPDATE session_queue UPDATE session_queue
SET status = ?, error = ? SET status = ?, error_type = ?, error_message = ?, error_traceback = ?
WHERE item_id = ? WHERE item_id = ?
""", """,
(status, error, item_id), (status, error_type, error_message, error_traceback, item_id),
) )
self.__conn.commit() self.__conn.commit()
except Exception: except Exception:
@ -292,11 +248,7 @@ class SqliteSessionQueue(SessionQueueBase):
queue_item = self.get_queue_item(item_id) queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id) batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id) queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
self.__invoker.services.events.emit_queue_item_status_changed( self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status)
session_queue_item=queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
return queue_item return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult: def is_empty(self, queue_id: str) -> IsEmptyResult:
@ -338,26 +290,6 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release() self.__lock.release()
return IsFullResult(is_full=is_full) return IsFullResult(is_full=is_full)
def delete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id=item_id)
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
DELETE FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return queue_item
def clear(self, queue_id: str) -> ClearResult: def clear(self, queue_id: str) -> ClearResult:
try: try:
self.__lock.acquire() self.__lock.acquire()
@ -424,16 +356,27 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release() self.__lock.release()
return PruneResult(deleted=count) return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem: def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id) queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
if queue_item.status not in ["canceled", "failed", "completed"]: return queue_item
status = "failed" if error is not None else "canceled"
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here def complete_queue_item(self, item_id: int) -> SessionQueueItem:
self.__invoker.services.events.emit_session_canceled( queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
queue_item_id=queue_item.item_id, return queue_item
queue_id=queue_item.queue_id,
queue_batch_id=queue_item.batch_id, def fail_queue_item(
graph_execution_state_id=queue_item.session_id, self,
item_id: int,
error_type: str,
error_message: str,
error_traceback: str,
) -> SessionQueueItem:
queue_item = self._set_queue_item_status(
item_id=item_id,
status="failed",
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
) )
return queue_item return queue_item
@ -470,18 +413,10 @@ class SqliteSessionQueue(SessionQueueBase):
) )
self.__conn.commit() self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids: if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id) queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed( self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=current_queue_item, current_queue_item, batch_status, queue_status
batch_status=batch_status,
queue_status=queue_status,
) )
except Exception: except Exception:
self.__conn.rollback() self.__conn.rollback()
@ -521,18 +456,10 @@ class SqliteSessionQueue(SessionQueueBase):
) )
self.__conn.commit() self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id: if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id) queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed( self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=current_queue_item, current_queue_item, batch_status, queue_status
batch_status=batch_status,
queue_status=queue_status,
) )
except Exception: except Exception:
self.__conn.rollback() self.__conn.rollback()
@ -562,6 +489,29 @@ class SqliteSessionQueue(SessionQueueBase):
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}") raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result)) return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
try:
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
session_json = session.model_dump_json(warnings=False, exclude_none=True)
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET session = ?
WHERE item_id = ?
""",
(session_json, item_id),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return self.get_queue_item(item_id)
def list_queue_items( def list_queue_items(
self, self,
queue_id: str, queue_id: str,
@ -578,7 +528,9 @@ class SqliteSessionQueue(SessionQueueBase):
status, status,
priority, priority,
field_values, field_values,
error, error_type,
error_message,
error_traceback,
created_at, created_at,
updated_at, updated_at,
completed_at, completed_at,

View File

@ -2,17 +2,19 @@
import copy import copy
import itertools import itertools
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import networkx as nx import networkx as nx
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
GetCoreSchemaHandler,
GetJsonSchemaHandler, GetJsonSchemaHandler,
ValidationError,
field_validator, field_validator,
) )
from pydantic.fields import Field from pydantic.fields import Field
from pydantic.json_schema import JsonSchemaValue from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema from pydantic_core import core_schema
# Importing * is bad karma but needed here for node detection # Importing * is bad karma but needed here for node detection
from invokeai.app.invocations import * # noqa: F401 F403 from invokeai.app.invocations import * # noqa: F401 F403
@ -190,6 +192,39 @@ class UnknownGraphValidationError(ValueError):
pass pass
class NodeInputError(ValueError):
"""Raised when a node fails preparation. This occurs when a node's inputs are being set from its incomers, but an
input fails validation.
Attributes:
node: The node that failed preparation. Note: only successfully set fields will be accurate. Review the error to
determine which field caused the failure.
"""
def __init__(self, node: BaseInvocation, e: ValidationError):
self.original_error = e
self.node = node
# When preparing a node, we set each input one-at-a-time. We may thus safely assume that the first error
# represents the first input that failed.
self.failed_input = loc_to_dot_sep(e.errors()[0]["loc"])
super().__init__(f"Node {node.id} has invalid incoming input for {self.failed_input}")
def loc_to_dot_sep(loc: tuple[Union[str, int], ...]) -> str:
"""Helper to pretty-print pydantic error locations as dot-separated strings.
Taken from https://docs.pydantic.dev/latest/errors/errors/#customize-error-messages
"""
path = ""
for i, x in enumerate(loc):
if isinstance(x, str):
if i > 0:
path += "."
path += x
else:
path += f"[{x}]"
return path
@invocation_output("iterate_output") @invocation_output("iterate_output")
class IterateInvocationOutput(BaseInvocationOutput): class IterateInvocationOutput(BaseInvocationOutput):
"""Used to connect iteration outputs. Will be expanded to a specific output.""" """Used to connect iteration outputs. Will be expanded to a specific output."""
@ -243,73 +278,58 @@ class CollectInvocation(BaseInvocation):
return CollectInvocationOutput(collection=copy.copy(self.collection)) return CollectInvocationOutput(collection=copy.copy(self.collection))
class AnyInvocation(BaseInvocation):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def validate_invocation(v: Any) -> "AnyInvocation":
return BaseInvocation.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation)
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocation.get_invocations()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
class AnyInvocationOutput(BaseInvocationOutput):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
return BaseInvocationOutput.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation_output)
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
class Graph(BaseModel): class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid_string) id: str = Field(description="The id of this graph", default_factory=uuid_string)
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict) nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
edges: list[Edge] = Field( edges: list[Edge] = Field(
description="The connections between nodes and their fields in this graph", description="The connections between nodes and their fields in this graph",
default_factory=list, default_factory=list,
) )
@field_validator("nodes", mode="plain")
@classmethod
def validate_nodes(cls, v: dict[str, Any]):
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
# Invocations register themselves as their python modules are executed. The union of all invocations is
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
#
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
# invocations will cause a graph to fail if they are used.
#
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
#
# This same pattern is used in `GraphExecutionState`.
nodes: dict[str, BaseInvocation] = {}
typeadapter = BaseInvocation.get_typeadapter()
for node_id, node in v.items():
nodes[node_id] = typeadapter.validate_python(node)
return nodes
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
# the generated schema as options for the `nodes` field.
#
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
# expected.
#
# You might be tempted to do something like this:
#
# ```py
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
# delattr(cloned_model, "validate_nodes")
# cloned_model.model_rebuild(force=True)
# json_schema = handler(cloned_model.__pydantic_core_schema__)
# ```
#
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
#
# This same pattern is used in `GraphExecutionState`.
class Graph(BaseModel):
id: Optional[str] = Field(default=None, description="The id of this graph")
nodes: dict[
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
] = Field(description="The nodes in this graph")
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
json_schema = handler(Graph.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def add_node(self, node: BaseInvocation) -> None: def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph """Adds a node to a graph
@ -740,7 +760,7 @@ class GraphExecutionState(BaseModel):
) )
# The results of executed nodes # The results of executed nodes
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict) results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
# Errors raised when executing nodes # Errors raised when executing nodes
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict) errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
@ -757,52 +777,12 @@ class GraphExecutionState(BaseModel):
default_factory=dict, default_factory=dict,
) )
@field_validator("results", mode="plain")
@classmethod
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
results: dict[str, BaseInvocationOutput] = {}
typeadapter = BaseInvocationOutput.get_typeadapter()
for result_id, result in v.items():
results[result_id] = typeadapter.validate_python(result)
return results
@field_validator("graph") @field_validator("graph")
def graph_is_valid(cls, v: Graph): def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid""" """Validates that the graph is valid"""
v.validate_self() v.validate_self()
return v return v
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state")
graph: Graph = Field(description="The graph being executed")
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
executed: set[str] = Field(description="The set of node ids that have been executed")
executed_history: list[str] = Field(
description="The list of node ids that have been executed, in order of execution"
)
results: dict[
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
] = Field(description="The results of node executions")
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
prepared_source_mapping: dict[str, str] = Field(
description="The map of prepared nodes to original graph nodes"
)
source_prepared_mapping: dict[str, set[str]] = Field(
description="The map of original graph nodes to prepared nodes"
)
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def next(self) -> Optional[BaseInvocation]: def next(self) -> Optional[BaseInvocation]:
"""Gets the next node ready to execute.""" """Gets the next node ready to execute."""
@ -821,7 +801,10 @@ class GraphExecutionState(BaseModel):
# Get values from edges # Get values from edges
if next_node is not None: if next_node is not None:
try:
self._prepare_inputs(next_node) self._prepare_inputs(next_node)
except ValidationError as e:
raise NodeInputError(next_node, e)
# If next is still none, there's no next node, return None # If next is still none, there's no next node, return None
return next_node return next_node

View File

@ -1,7 +1,6 @@
import threading
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Callable, Optional, Union
import torch import torch
from PIL.Image import Image from PIL.Image import Image
@ -190,9 +189,9 @@ class ImagesInterface(InvocationContextInterface):
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None. # If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
metadata_ = None metadata_ = None
if metadata: if metadata:
metadata_ = metadata metadata_ = metadata.model_dump_json()
elif isinstance(self._data.invocation, WithMetadata): elif isinstance(self._data.invocation, WithMetadata) and self._data.invocation.metadata:
metadata_ = self._data.invocation.metadata metadata_ = self._data.invocation.metadata.model_dump_json()
# If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None. # If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None.
board_id_ = None board_id_ = None
@ -201,6 +200,14 @@ class ImagesInterface(InvocationContextInterface):
elif isinstance(self._data.invocation, WithBoard) and self._data.invocation.board: elif isinstance(self._data.invocation, WithBoard) and self._data.invocation.board:
board_id_ = self._data.invocation.board.board_id board_id_ = self._data.invocation.board.board_id
workflow_ = None
if self._data.queue_item.workflow:
workflow_ = self._data.queue_item.workflow.model_dump_json()
graph_ = None
if self._data.queue_item.session.graph:
graph_ = self._data.queue_item.session.graph.model_dump_json()
return self._services.images.create( return self._services.images.create(
image=image, image=image,
is_intermediate=self._data.invocation.is_intermediate, is_intermediate=self._data.invocation.is_intermediate,
@ -208,7 +215,8 @@ class ImagesInterface(InvocationContextInterface):
board_id=board_id_, board_id=board_id_,
metadata=metadata_, metadata=metadata_,
image_origin=ResourceOrigin.INTERNAL, image_origin=ResourceOrigin.INTERNAL,
workflow=self._data.queue_item.workflow, workflow=workflow_,
graph=graph_,
session_id=self._data.queue_item.session_id, session_id=self._data.queue_item.session_id,
node_id=self._data.invocation.id, node_id=self._data.invocation.id,
) )
@ -354,11 +362,11 @@ class ModelsInterface(InvocationContextInterface):
if isinstance(identifier, str): if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier) model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type, self._data) return self._services.model_manager.load.load_model(model, submodel_type)
else: else:
_submodel_type = submodel_type or identifier.submodel_type _submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key) model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type, self._data) return self._services.model_manager.load.load_model(model, _submodel_type)
def load_by_attrs( def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
@ -383,7 +391,7 @@ class ModelsInterface(InvocationContextInterface):
if len(configs) > 1: if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}") raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data) return self._services.model_manager.load.load_model(configs[0], submodel_type)
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig: def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Gets a model's config. """Gets a model's config.
@ -450,10 +458,10 @@ class ConfigInterface(InvocationContextInterface):
class UtilInterface(InvocationContextInterface): class UtilInterface(InvocationContextInterface):
def __init__( def __init__(
self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event self, services: InvocationServices, data: InvocationContextData, is_canceled: Callable[[], bool]
) -> None: ) -> None:
super().__init__(services, data) super().__init__(services, data)
self._cancel_event = cancel_event self._is_canceled = is_canceled
def is_canceled(self) -> bool: def is_canceled(self) -> bool:
"""Checks if the current session has been canceled. """Checks if the current session has been canceled.
@ -461,7 +469,7 @@ class UtilInterface(InvocationContextInterface):
Returns: Returns:
True if the current session has been canceled, False if not. True if the current session has been canceled, False if not.
""" """
return self._cancel_event.is_set() return self._is_canceled()
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None: def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
""" """
@ -558,7 +566,7 @@ class InvocationContext:
def build_invocation_context( def build_invocation_context(
services: InvocationServices, services: InvocationServices,
data: InvocationContextData, data: InvocationContextData,
cancel_event: threading.Event, is_canceled: Callable[[], bool],
) -> InvocationContext: ) -> InvocationContext:
"""Builds the invocation context for a specific invocation execution. """Builds the invocation context for a specific invocation execution.
@ -575,7 +583,7 @@ def build_invocation_context(
tensors = TensorsInterface(services=services, data=data) tensors = TensorsInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data) models = ModelsInterface(services=services, data=data)
config = ConfigInterface(services=services, data=data) config = ConfigInterface(services=services, data=data)
util = UtilInterface(services=services, data=data, cancel_event=cancel_event) util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
conditioning = ConditioningInterface(services=services, data=data) conditioning = ConditioningInterface(services=services, data=data)
boards = BoardsInterface(services=services, data=data) boards = BoardsInterface(services=services, data=data)

View File

@ -12,6 +12,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -41,6 +42,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_7()) migrator.register_migration(build_migration_7())
migrator.register_migration(build_migration_8(app_config=config)) migrator.register_migration(build_migration_8(app_config=config))
migrator.register_migration(build_migration_9()) migrator.register_migration(build_migration_9())
migrator.register_migration(build_migration_10())
migrator.run_migrations() migrator.run_migrations()
return db return db

View File

@ -0,0 +1,35 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration10Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._update_error_cols(cursor)
def _update_error_cols(self, cursor: sqlite3.Cursor) -> None:
"""
- Adds `error_type` and `error_message` columns to the session queue table.
- Renames the `error` column to `error_traceback`.
"""
cursor.execute("ALTER TABLE session_queue ADD COLUMN error_type TEXT;")
cursor.execute("ALTER TABLE session_queue ADD COLUMN error_message TEXT;")
cursor.execute("ALTER TABLE session_queue RENAME COLUMN error TO error_traceback;")
def build_migration_10() -> Migration:
"""
Build the migration from database version 9 to 10.
This migration does the following:
- Adds `error_type` and `error_message` columns to the session queue table.
- Renames the `error` column to `error_traceback`.
"""
migration_10 = Migration(
from_version=9,
to_version=10,
callback=Migration10Callback(),
)
return migration_10

View File

@ -0,0 +1,116 @@
from typing import Any, Callable, Optional
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from pydantic.json_schema import models_json_schema
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
"""Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema
for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and
component_schema."""
defs = component_schema.pop("$defs", {})
for schema_key, json_schema in defs.items():
if schema_key in openapi_schema["components"]["schemas"]:
continue
openapi_schema["components"]["schemas"][schema_key] = json_schema
def get_openapi_func(
app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
) -> Callable[[], dict[str, Any]]:
"""Gets the OpenAPI schema generator function.
Args:
app (FastAPI): The FastAPI app to generate the schema for.
post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the
generated schema before returning it. Defaults to None.
Returns:
Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is
cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour
matches FastAPI's default schema generation caching.
"""
def openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# We'll create a map of invocation type to output schema to make some types simpler on the client.
invocation_output_map_properties: dict[str, Any] = {}
invocation_output_map_required: list[str] = []
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in BaseInvocationOutput.get_outputs():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
# property, so we'll just do it all manually.
for invocation in BaseInvocation.get_invocations():
json_schema = invocation.model_json_schema(
mode="serialization", ref_template="#/components/schemas/{model}"
)
move_defs_to_top_level(openapi_schema, json_schema)
output_title = invocation.get_output_annotation().__name__
outputs_ref = {"$ref": f"#/components/schemas/{output_title}"}
json_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"][invocation.__name__] = json_schema
# Add this invocation and its output to the output map
invocation_type = invocation.get_type()
invocation_output_map_properties[invocation_type] = json_schema["output"]
invocation_output_map_required.append(invocation_type)
# Add the output map to the schema
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": invocation_output_map_properties,
"required": invocation_output_map_required,
}
# Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API.
# We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get
# a schema. This has something to do with schema refs - not totally clear. For whatever reason, using
# `models_json_schema` seems to work fine.
additional_models = [
*EventBase.get_events(),
UIConfigBase,
InputFieldJSONSchemaExtra,
OutputFieldJSONSchemaExtra,
ModelIdentifierField,
ProgressImage,
]
additional_schemas = models_json_schema(
[(m, "serialization") for m in additional_models],
ref_template="#/components/schemas/{model}",
)
# additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
move_defs_to_top_level(openapi_schema, additional_schemas[1])
if post_transform is not None:
openapi_schema = post_transform(openapi_schema)
openapi_schema["components"]["schemas"] = dict(sorted(openapi_schema["components"]["schemas"].items()))
app.openapi_schema = openapi_schema
return app.openapi_schema
return openapi

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Callable from typing import TYPE_CHECKING, Callable, Optional
import torch import torch
from PIL import Image from PIL import Image
@ -13,8 +13,36 @@ if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.app.services.shared.invocation_context import InvocationContextData
# fast latents preview matrix for sdxl
# generated by @StAlKeR7779
SDXL_LATENT_RGB_FACTORS = [
# R G B
[0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
]
SDXL_SMOOTH_MATRIX = [
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
]
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None): # origingally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle
SD1_5_LATENT_RGB_FACTORS = [
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
]
def sample_to_lowres_estimated_image(
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
):
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
if smooth_matrix is not None: if smooth_matrix is not None:
@ -47,64 +75,12 @@ def stable_diffusion_step_callback(
else: else:
sample = intermediate_state.latents sample = intermediate_state.latents
# TODO: This does not seem to be needed any more?
# # txt2img provides a Tensor in the step_callback
# # img2img provides a PipelineIntermediateState
# if isinstance(sample, PipelineIntermediateState):
# # this was an img2img
# print('img2img')
# latents = sample.latents
# step = sample.step
# else:
# print('txt2img')
# latents = sample
# step = intermediate_state.step
# TODO: only output a preview image when requested
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]: if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
# fast latents preview matrix for sdxl sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
# generated by @StAlKeR7779 sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
sdxl_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
],
dtype=sample.dtype,
device=sample.device,
)
sdxl_smooth_matrix = torch.tensor(
[
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
],
dtype=sample.dtype,
device=sample.device,
)
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix) image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
else: else:
# origingally adapted from code by @erucipe and @keturn here: v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle
v1_5_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
],
dtype=sample.dtype,
device=sample.device,
)
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors) image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
(width, height) = image.size (width, height) = image.size
@ -113,15 +89,9 @@ def stable_diffusion_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG") dataURL = image_to_dataURL(image, image_format="JPEG")
events.emit_generator_progress( events.emit_invocation_denoise_progress(
queue_id=context_data.queue_item.queue_id, context_data.queue_item,
queue_item_id=context_data.queue_item.item_id, context_data.invocation,
queue_batch_id=context_data.queue_item.batch_id, intermediate_state,
graph_execution_state_id=context_data.queue_item.session_id, ProgressImage(dataURL=dataURL, width=width, height=height),
node_id=context_data.invocation.id,
source_node_id=context_data.source_invocation_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step,
order=intermediate_state.order,
total_steps=intermediate_state.total_steps,
) )

View File

@ -4,5 +4,4 @@ Initialization file for invokeai.backend.image_util methods.
from .infill_methods.patchmatch import PatchMatch # noqa: F401 from .infill_methods.patchmatch import PatchMatch # noqa: F401
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401 from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
from .seamless import configure_model_padding # noqa: F401
from .util import InitImageResizer, make_grid # noqa: F401 from .util import InitImageResizer, make_grid # noqa: F401

View File

@ -8,7 +8,7 @@ from pathlib import Path
import numpy as np import numpy as np
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from PIL import Image from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor from transformers import AutoFeatureExtractor
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
@ -16,6 +16,7 @@ from invokeai.app.services.config.config_default import get_config
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.util.silence_warnings import SilenceWarnings
repo_id = "CompVis/stable-diffusion-safety-checker"
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker" CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
@ -24,30 +25,30 @@ class SafetyChecker:
Wrapper around SafetyChecker model. Wrapper around SafetyChecker model.
""" """
safety_checker = None
feature_extractor = None feature_extractor = None
tried_load: bool = False safety_checker = None
@classmethod @classmethod
def _load_safety_checker(cls): def _load_safety_checker(cls):
if cls.tried_load: if cls.safety_checker is not None and cls.feature_extractor is not None:
return return
try: try:
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(get_config().models_path / CHECKER_PATH) model_path = get_config().models_path / CHECKER_PATH
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(get_config().models_path / CHECKER_PATH) if model_path.exists():
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(model_path)
else:
model_path.mkdir(parents=True, exist_ok=True)
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
cls.feature_extractor.save_pretrained(model_path, safe_serialization=True)
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(repo_id)
cls.safety_checker.save_pretrained(model_path, safe_serialization=True)
except Exception as e: except Exception as e:
logger.warning(f"Could not load NSFW checker: {str(e)}") logger.warning(f"Could not load NSFW checker: {str(e)}")
cls.tried_load = True
@classmethod
def safety_checker_available(cls) -> bool:
return Path(get_config().models_path, CHECKER_PATH).exists()
@classmethod @classmethod
def has_nsfw_concept(cls, image: Image.Image) -> bool: def has_nsfw_concept(cls, image: Image.Image) -> bool:
if not cls.safety_checker_available() and cls.tried_load:
return False
cls._load_safety_checker() cls._load_safety_checker()
if cls.safety_checker is None or cls.feature_extractor is None: if cls.safety_checker is None or cls.feature_extractor is None:
return False return False
@ -60,3 +61,24 @@ class SafetyChecker:
with SilenceWarnings(): with SilenceWarnings():
checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values) checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
return has_nsfw_concept[0] return has_nsfw_concept[0]
@classmethod
def blur_if_nsfw(cls, image: Image.Image) -> Image.Image:
if cls.has_nsfw_concept(image):
logger.warning("A potentially NSFW image has been detected. Image will be blurred.")
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = cls._get_caution_img()
# Center the caution image on the blurred image
x = (blurry_image.width - caution.width) // 2
y = (blurry_image.height - caution.height) // 2
blurry_image.paste(caution, (x, y), caution)
image = blurry_image
return image
@classmethod
def _get_caution_img(cls) -> Image.Image:
import invokeai.app.assets.images as image_assets
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
return caution.resize((caution.width // 2, caution.height // 2))

View File

@ -1,52 +0,0 @@
import torch.nn as nn
def _conv_forward_asymmetric(self, input, weight, bias):
"""
Patch for Conv2d._conv_forward that supports asymmetric padding
"""
working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
return nn.functional.conv2d(
working,
weight,
bias,
self.stride,
nn.modules.utils._pair(0),
self.dilation,
self.groups,
)
def configure_model_padding(model, seamless, seamless_axes):
"""
Modifies the 2D convolution layers to use a circular padding mode based on
the `seamless` and `seamless_axes` options.
"""
# TODO: get an explicit interface for this in diffusers: https://github.com/huggingface/diffusers/issues/556
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
if seamless:
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
else:
m._conv_forward = nn.Conv2d._conv_forward.__get__(m, nn.Conv2d)
if hasattr(m, "asymmetric_padding_mode"):
del m.asymmetric_padding_mode
if hasattr(m, "asymmetric_padding"):
del m.asymmetric_padding

View File

@ -43,9 +43,26 @@ T = TypeVar("T")
@dataclass @dataclass
class CacheRecord(Generic[T]): class CacheRecord(Generic[T]):
"""Elements of the cache.""" """
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
"""
key: str key: str
model: T
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int size: int
model: T model: T
loaded: bool = False loaded: bool = False

View File

@ -233,7 +233,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
if key in self._cached_models: if key in self._cached_models:
return return
self.make_room(size) self.make_room(size)
cache_record = CacheRecord(key, model=model, size=size) state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
self._cached_models[key] = cache_record self._cached_models[key] = cache_record
self._cache_stack.append(key) self._cache_stack.append(key)
@ -329,17 +330,37 @@ class ModelCache(ModelCacheBase[AnyModel]):
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")): if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
return return
source_device = cache_entry.model.device source_device = cache_entry.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'. # Note: We compare device types only so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU. # This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type: if torch.device(source_device).type == torch.device(target_device).type:
return return
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time() start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot() snapshot_before = self._capture_memory_snapshot()
try: try:
if cache_entry.state_dict is not None:
assert hasattr(cache_entry.model, "load_state_dict")
if target_device == self.storage_device:
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(torch.device(target_device), copy=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device) cache_entry.model.to(target_device)
cache_entry.device = target_device
except Exception as e: # blow away cache entry except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry) self._delete_cache_entry(cache_entry)
raise e raise e
@ -419,43 +440,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos] model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key] cache_entry = self._cached_models[model_key]
refs = sys.getrefcount(cache_entry.model)
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
# https://docs.python.org/3/library/gc.html#gc.get_referrers
# manualy clear local variable references of just finished function calls
# for some reason python don't want to collect it even by gc.collect() immidiately
if refs > 2:
while True:
cleared = False
for referrer in gc.get_referrers(cache_entry.model):
if type(referrer).__name__ == "frame":
# RuntimeError: cannot clear an executing frame
with suppress(RuntimeError):
referrer.clear()
cleared = True
# break
# repeat if referrers changes(due to frame clear), else exit loop
if cleared:
gc.collect()
else:
break
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug( self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}," f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
f" refs: {refs}"
) )
# Expected refs: if not cache_entry.locked:
# 1 from cache_entry
# 1 from getrefcount function
# 1 from onnx runtime object
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
self.logger.debug( self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
) )

View File

@ -80,5 +80,5 @@ class ModelLocker(ModelLockerBase):
self._cache_entry.unlock() self._cache_entry.unlock()
if not self._cache.lazy_offloading: if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size) self._cache.offload_unlocked_models(0)
self._cache.print_cuda_stats() self._cache.print_cuda_stats()

View File

@ -1,89 +1,51 @@
from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from typing import Callable, List, Union from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn import torch.nn as nn
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.lora import LoRACompatibleConv
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
def _conv_forward_asymmetric(self, input, weight, bias):
"""
Patch for Conv2d._conv_forward that supports asymmetric padding
"""
working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
return nn.functional.conv2d(
working,
weight,
bias,
self.stride,
nn.modules.utils._pair(0),
self.dilation,
self.groups,
)
@contextmanager @contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]): def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
if not seamless_axes: if not seamless_axes:
yield yield
return return
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor # override conv_forward
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = [] # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
def _conv_forward_asymmetric(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
return torch.nn.functional.conv2d(
working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
)
original_layers: List[Tuple[nn.Conv2d, Callable]] = []
try: try:
# Hard coded to skip down block layers, allowing for seamless tiling at the expense of prompt adherence x_mode = "circular" if "x" in seamless_axes else "constant"
skipped_layers = 1 y_mode = "circular" if "y" in seamless_axes else "constant"
for m_name, m in model.named_modules():
if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
continue
if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name: conv_layers: List[torch.nn.Conv2d] = []
# down_blocks.1.resnets.1.conv1
_, block_num, _, resnet_num, submodule_name = m_name.split(".")
block_num = int(block_num)
resnet_num = int(resnet_num)
if block_num >= len(model.down_blocks) - skipped_layers: for module in model.modules():
continue if isinstance(module, torch.nn.Conv2d):
conv_layers.append(module)
# Skip the second resnet (could be configurable) for layer in conv_layers:
if resnet_num > 0: if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
continue layer.lora_layer = lambda *x: 0
original_layers.append((layer, layer._conv_forward))
# Skip Conv2d layers (could be configurable) layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
if submodule_name == "conv2":
continue
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
yield yield
finally: finally:
for module, orig_conv_forward in to_restore: for layer, orig_conv_forward in original_layers:
module._conv_forward = orig_conv_forward layer._conv_forward = orig_conv_forward
if hasattr(module, "asymmetric_padding_mode"):
del module.asymmetric_padding_mode
if hasattr(module, "asymmetric_padding"):
del module.asymmetric_padding

View File

@ -1,7 +1,7 @@
"""Textual Inversion wrapper class.""" """Textual Inversion wrapper class."""
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Union from typing import Optional, Union
import torch import torch
from compel.embeddings_provider import BaseTextualInversionManager from compel.embeddings_provider import BaseTextualInversionManager
@ -66,35 +66,52 @@ class TextualInversionModelRaw(RawModel):
return result return result
# no type hints for BaseTextualInversionManager? class TextualInversionManager(BaseTextualInversionManager):
class TextualInversionManager(BaseTextualInversionManager): # type: ignore """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
pad_tokens: Dict[int, List[int]]
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer): def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = {} self.pad_tokens: dict[int, list[int]] = {}
self.tokenizer = tokenizer self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
"""Given a list of tokens ids, expand any TI tokens to their corresponding pad tokens.
For example, suppose we have a `<ti_dog>` TI with 4 vectors that was added to the tokenizer with the following
mapping of tokens to token_ids:
```
<ti_dog>: 49408
<ti_dog-!pad-1>: 49409
<ti_dog-!pad-2>: 49410
<ti_dog-!pad-3>: 49411
```
`self.pad_tokens` would be set to `{49408: [49408, 49409, 49410, 49411]}`.
This function is responsible for expanding `49408` in the token_ids list to `[49408, 49409, 49410, 49411]`.
"""
# Short circuit if there are no pad tokens to save a little time.
if len(self.pad_tokens) == 0: if len(self.pad_tokens) == 0:
return token_ids return token_ids
# This function assumes that compel has not included the BOS and EOS tokens in the token_ids list. We verify
# this assumption here.
if token_ids[0] == self.tokenizer.bos_token_id: if token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("token_ids must not start with bos_token_id") raise ValueError("token_ids must not start with bos_token_id")
if token_ids[-1] == self.tokenizer.eos_token_id: if token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("token_ids must not end with eos_token_id") raise ValueError("token_ids must not end with eos_token_id")
new_token_ids = [] # Expand any TI tokens to their corresponding pad tokens.
new_token_ids: list[int] = []
for token_id in token_ids: for token_id in token_ids:
new_token_ids.append(token_id) new_token_ids.append(token_id)
if token_id in self.pad_tokens: if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id]) new_token_ids.extend(self.pad_tokens[token_id])
# Do not exceed the max model input size # Do not exceed the max model input size. The -2 here is compensating for
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(), # compel.embeddings_provider.get_token_ids(), which first removes and then adds back the start and end tokens.
# which first removes and then adds back the start and end tokens. max_length = self.tokenizer.model_max_length - 2
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
if len(new_token_ids) > max_length: if len(new_token_ids) > max_length:
# HACK: If TI token expansion causes us to exceed the max text encoder input length, we silently discard
# tokens. Token expansion should happen in a way that is compatible with compel's default handling of long
# prompts.
new_token_ids = new_token_ids[0:max_length] new_token_ids = new_token_ids[0:max_length]
return new_token_ids return new_token_ids

View File

@ -10,6 +10,8 @@ module.exports = {
'path/no-relative-imports': ['error', { maxDepth: 0 }], 'path/no-relative-imports': ['error', { maxDepth: 0 }],
// https://github.com/edvardchen/eslint-plugin-i18next/blob/HEAD/docs/rules/no-literal-string.md // https://github.com/edvardchen/eslint-plugin-i18next/blob/HEAD/docs/rules/no-literal-string.md
'i18next/no-literal-string': 'error', 'i18next/no-literal-string': 'error',
// https://eslint.org/docs/latest/rules/no-console
'no-console': 'error',
}, },
overrides: [ overrides: [
/** /**

View File

@ -44,3 +44,4 @@ yalc.lock
# vitest # vitest
tsconfig.vitest-temp.json tsconfig.vitest-temp.json
coverage/

View File

@ -35,6 +35,7 @@
"storybook": "storybook dev -p 6006", "storybook": "storybook dev -p 6006",
"build-storybook": "storybook build", "build-storybook": "storybook build",
"test": "vitest", "test": "vitest",
"test:ui": "vitest --coverage --ui",
"test:no-watch": "vitest --no-watch" "test:no-watch": "vitest --no-watch"
}, },
"madge": { "madge": {
@ -52,48 +53,48 @@
}, },
"dependencies": { "dependencies": {
"@chakra-ui/react-use-size": "^2.1.0", "@chakra-ui/react-use-size": "^2.1.0",
"@dagrejs/dagre": "^1.1.1", "@dagrejs/dagre": "^1.1.2",
"@dagrejs/graphlib": "^2.2.1", "@dagrejs/graphlib": "^2.2.2",
"@dnd-kit/core": "^6.1.0", "@dnd-kit/core": "^6.1.0",
"@dnd-kit/sortable": "^8.0.0", "@dnd-kit/sortable": "^8.0.0",
"@dnd-kit/utilities": "^3.2.2", "@dnd-kit/utilities": "^3.2.2",
"@fontsource-variable/inter": "^5.0.17", "@fontsource-variable/inter": "^5.0.18",
"@invoke-ai/ui-library": "^0.0.25", "@invoke-ai/ui-library": "^0.0.25",
"@nanostores/react": "^0.7.2", "@nanostores/react": "^0.7.2",
"@reduxjs/toolkit": "2.2.2", "@reduxjs/toolkit": "2.2.3",
"@roarr/browser-log-writer": "^1.3.0", "@roarr/browser-log-writer": "^1.3.0",
"chakra-react-select": "^4.7.6", "chakra-react-select": "^4.7.6",
"compare-versions": "^6.1.0", "compare-versions": "^6.1.0",
"dateformat": "^5.0.3", "dateformat": "^5.0.3",
"framer-motion": "^11.0.22", "fracturedjsonjs": "^4.0.1",
"i18next": "^23.10.1", "framer-motion": "^11.1.8",
"i18next-http-backend": "^2.5.0", "i18next": "^23.11.3",
"i18next-http-backend": "^2.5.1",
"idb-keyval": "^6.2.1", "idb-keyval": "^6.2.1",
"jsondiffpatch": "^0.6.0", "jsondiffpatch": "^0.6.0",
"konva": "^9.3.6", "konva": "^9.3.6",
"lodash-es": "^4.17.21", "lodash-es": "^4.17.21",
"nanostores": "^0.10.0", "nanostores": "^0.10.3",
"new-github-issue-url": "^1.0.0", "new-github-issue-url": "^1.0.0",
"overlayscrollbars": "^2.6.1", "overlayscrollbars": "^2.7.3",
"overlayscrollbars-react": "^0.5.5", "overlayscrollbars-react": "^0.5.6",
"query-string": "^9.0.0", "query-string": "^9.0.0",
"react": "^18.2.0", "react": "^18.3.1",
"react-colorful": "^5.6.1", "react-colorful": "^5.6.1",
"react-dom": "^18.2.0", "react-dom": "^18.3.1",
"react-dropzone": "^14.2.3", "react-dropzone": "^14.2.3",
"react-error-boundary": "^4.0.13", "react-error-boundary": "^4.0.13",
"react-hook-form": "^7.51.2", "react-hook-form": "^7.51.4",
"react-hotkeys-hook": "4.5.0", "react-hotkeys-hook": "4.5.0",
"react-i18next": "^14.1.0", "react-i18next": "^14.1.1",
"react-icons": "^5.0.1", "react-icons": "^5.2.0",
"react-konva": "^18.2.10", "react-konva": "^18.2.10",
"react-redux": "9.1.0", "react-redux": "9.1.2",
"react-resizable-panels": "^2.0.16", "react-resizable-panels": "^2.0.19",
"react-rnd": "^10.4.10",
"react-select": "5.8.0", "react-select": "5.8.0",
"react-use": "^17.5.0", "react-use": "^17.5.0",
"react-virtuoso": "^4.7.5", "react-virtuoso": "^4.7.10",
"reactflow": "^11.10.4", "reactflow": "^11.11.3",
"redux-dynamic-middlewares": "^2.2.0", "redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^5.1.0", "redux-remember": "^5.1.0",
"redux-undo": "^1.1.0", "redux-undo": "^1.1.0",
@ -105,8 +106,8 @@
"use-device-pixel-ratio": "^1.1.2", "use-device-pixel-ratio": "^1.1.2",
"use-image": "^1.1.1", "use-image": "^1.1.1",
"uuid": "^9.0.1", "uuid": "^9.0.1",
"zod": "^3.22.4", "zod": "^3.23.6",
"zod-validation-error": "^3.0.3" "zod-validation-error": "^3.2.0"
}, },
"peerDependencies": { "peerDependencies": {
"@chakra-ui/react": "^2.8.2", "@chakra-ui/react": "^2.8.2",
@ -117,40 +118,42 @@
"devDependencies": { "devDependencies": {
"@invoke-ai/eslint-config-react": "^0.0.14", "@invoke-ai/eslint-config-react": "^0.0.14",
"@invoke-ai/prettier-config-react": "^0.0.7", "@invoke-ai/prettier-config-react": "^0.0.7",
"@storybook/addon-essentials": "^8.0.4", "@storybook/addon-essentials": "^8.0.10",
"@storybook/addon-interactions": "^8.0.4", "@storybook/addon-interactions": "^8.0.10",
"@storybook/addon-links": "^8.0.4", "@storybook/addon-links": "^8.0.10",
"@storybook/addon-storysource": "^8.0.4", "@storybook/addon-storysource": "^8.0.10",
"@storybook/manager-api": "^8.0.4", "@storybook/manager-api": "^8.0.10",
"@storybook/react": "^8.0.4", "@storybook/react": "^8.0.10",
"@storybook/react-vite": "^8.0.4", "@storybook/react-vite": "^8.0.10",
"@storybook/theming": "^8.0.4", "@storybook/theming": "^8.0.10",
"@types/dateformat": "^5.0.2", "@types/dateformat": "^5.0.2",
"@types/lodash-es": "^4.17.12", "@types/lodash-es": "^4.17.12",
"@types/node": "^20.11.30", "@types/node": "^20.12.10",
"@types/react": "^18.2.73", "@types/react": "^18.3.1",
"@types/react-dom": "^18.2.22", "@types/react-dom": "^18.3.0",
"@types/uuid": "^9.0.8", "@types/uuid": "^9.0.8",
"@vitejs/plugin-react-swc": "^3.6.0", "@vitejs/plugin-react-swc": "^3.6.0",
"@vitest/coverage-v8": "^1.5.0",
"@vitest/ui": "^1.5.0",
"concurrently": "^8.2.2", "concurrently": "^8.2.2",
"dpdm": "^3.14.0", "dpdm": "^3.14.0",
"eslint": "^8.57.0", "eslint": "^8.57.0",
"eslint-plugin-i18next": "^6.0.3", "eslint-plugin-i18next": "^6.0.3",
"eslint-plugin-path": "^1.3.0", "eslint-plugin-path": "^1.3.0",
"knip": "^5.6.1", "knip": "^5.12.3",
"openapi-types": "^12.1.3", "openapi-types": "^12.1.3",
"openapi-typescript": "^6.7.5", "openapi-typescript": "^6.7.5",
"prettier": "^3.2.5", "prettier": "^3.2.5",
"rollup-plugin-visualizer": "^5.12.0", "rollup-plugin-visualizer": "^5.12.0",
"storybook": "^8.0.4", "storybook": "^8.0.10",
"ts-toolbelt": "^9.6.0", "ts-toolbelt": "^9.6.0",
"tsafe": "^1.6.6", "tsafe": "^1.6.6",
"typescript": "^5.4.3", "typescript": "^5.4.5",
"vite": "^5.2.6", "vite": "^5.2.11",
"vite-plugin-css-injected-by-js": "^3.5.0", "vite-plugin-css-injected-by-js": "^3.5.1",
"vite-plugin-dts": "^3.8.0", "vite-plugin-dts": "^3.9.1",
"vite-plugin-eslint": "^1.8.1", "vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^4.3.2", "vite-tsconfig-paths": "^4.3.2",
"vitest": "^1.4.0" "vitest": "^1.6.0"
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -76,7 +76,9 @@
"aboutHeading": "Nutzen Sie Ihre kreative Energie", "aboutHeading": "Nutzen Sie Ihre kreative Energie",
"toResolve": "Lösen", "toResolve": "Lösen",
"add": "Hinzufügen", "add": "Hinzufügen",
"loglevel": "Protokoll Stufe" "loglevel": "Protokoll Stufe",
"selected": "Ausgewählt",
"beta": "Beta"
}, },
"gallery": { "gallery": {
"galleryImageSize": "Bildgröße", "galleryImageSize": "Bildgröße",
@ -86,7 +88,7 @@
"noImagesInGallery": "Keine Bilder in der Galerie", "noImagesInGallery": "Keine Bilder in der Galerie",
"loading": "Lade", "loading": "Lade",
"deleteImage_one": "Lösche Bild", "deleteImage_one": "Lösche Bild",
"deleteImage_other": "", "deleteImage_other": "Lösche {{count}} Bilder",
"copy": "Kopieren", "copy": "Kopieren",
"download": "Runterladen", "download": "Runterladen",
"setCurrentImage": "Setze aktuelle Bild", "setCurrentImage": "Setze aktuelle Bild",
@ -397,7 +399,14 @@
"cancel": "Stornieren", "cancel": "Stornieren",
"defaultSettingsSaved": "Standardeinstellungen gespeichert", "defaultSettingsSaved": "Standardeinstellungen gespeichert",
"addModels": "Model hinzufügen", "addModels": "Model hinzufügen",
"deleteModelImage": "Lösche Model Bild" "deleteModelImage": "Lösche Model Bild",
"hfTokenInvalidErrorMessage": "Falscher oder fehlender HuggingFace Schlüssel.",
"huggingFaceRepoID": "HuggingFace Repo ID",
"hfToken": "HuggingFace Schlüssel",
"hfTokenInvalid": "Falscher oder fehlender HF Schlüssel",
"huggingFacePlaceholder": "besitzer/model-name",
"hfTokenSaved": "HF Schlüssel gespeichert",
"hfTokenUnableToVerify": "Konnte den HF Schlüssel nicht validieren"
}, },
"parameters": { "parameters": {
"images": "Bilder", "images": "Bilder",
@ -686,7 +695,11 @@
"hands": "Hände", "hands": "Hände",
"dwOpenpose": "DW Openpose", "dwOpenpose": "DW Openpose",
"dwOpenposeDescription": "Posenschätzung mit DW Openpose", "dwOpenposeDescription": "Posenschätzung mit DW Openpose",
"selectCLIPVisionModel": "Wähle ein CLIP Vision Model aus" "selectCLIPVisionModel": "Wähle ein CLIP Vision Model aus",
"ipAdapterMethod": "Methode",
"composition": "Nur Komposition",
"full": "Voll",
"style": "Nur Style"
}, },
"queue": { "queue": {
"status": "Status", "status": "Status",
@ -717,7 +730,6 @@
"resume": "Wieder aufnehmen", "resume": "Wieder aufnehmen",
"item": "Auftrag", "item": "Auftrag",
"notReady": "Warteschlange noch nicht bereit", "notReady": "Warteschlange noch nicht bereit",
"queueCountPrediction": "{{promptsCount}} Prompts × {{iterations}} Iterationen -> {{count}} Generationen",
"clearQueueAlertDialog": "\"Die Warteschlange leeren\" stoppt den aktuellen Prozess und leert die Warteschlange komplett.", "clearQueueAlertDialog": "\"Die Warteschlange leeren\" stoppt den aktuellen Prozess und leert die Warteschlange komplett.",
"completedIn": "Fertig in", "completedIn": "Fertig in",
"cancelBatchSucceeded": "Stapel abgebrochen", "cancelBatchSucceeded": "Stapel abgebrochen",

View File

@ -2,6 +2,7 @@
"accessibility": { "accessibility": {
"about": "About", "about": "About",
"createIssue": "Create Issue", "createIssue": "Create Issue",
"submitSupportTicket": "Submit Support Ticket",
"invokeProgressBar": "Invoke progress bar", "invokeProgressBar": "Invoke progress bar",
"menu": "Menu", "menu": "Menu",
"mode": "Mode", "mode": "Mode",
@ -142,9 +143,15 @@
"blue": "Blue", "blue": "Blue",
"alpha": "Alpha", "alpha": "Alpha",
"selected": "Selected", "selected": "Selected",
"viewer": "Viewer",
"tab": "Tab", "tab": "Tab",
"close": "Close" "viewing": "Viewing",
"viewingDesc": "Review images in a large gallery view",
"editing": "Editing",
"editingDesc": "Edit on the Control Layers canvas",
"comparing": "Comparing",
"comparingDesc": "Comparing two images",
"enabled": "Enabled",
"disabled": "Disabled"
}, },
"controlnet": { "controlnet": {
"controlAdapter_one": "Control Adapter", "controlAdapter_one": "Control Adapter",
@ -259,7 +266,6 @@
"queue": "Queue", "queue": "Queue",
"queueFront": "Add to Front of Queue", "queueFront": "Add to Front of Queue",
"queueBack": "Add to Queue", "queueBack": "Add to Queue",
"queueCountPrediction": "{{promptsCount}} prompts \u00d7 {{iterations}} iterations -> {{count}} generations",
"queueEmpty": "Queue Empty", "queueEmpty": "Queue Empty",
"enqueueing": "Queueing Batch", "enqueueing": "Queueing Batch",
"resume": "Resume", "resume": "Resume",
@ -312,7 +318,13 @@
"batchFailedToQueue": "Failed to Queue Batch", "batchFailedToQueue": "Failed to Queue Batch",
"graphQueued": "Graph queued", "graphQueued": "Graph queued",
"graphFailedToQueue": "Failed to queue graph", "graphFailedToQueue": "Failed to queue graph",
"openQueue": "Open Queue" "openQueue": "Open Queue",
"prompts_one": "Prompt",
"prompts_other": "Prompts",
"iterations_one": "Iteration",
"iterations_other": "Iterations",
"generations_one": "Generation",
"generations_other": "Generations"
}, },
"invocationCache": { "invocationCache": {
"invocationCache": "Invocation Cache", "invocationCache": "Invocation Cache",
@ -366,9 +378,22 @@
"bulkDownloadFailed": "Download Failed", "bulkDownloadFailed": "Download Failed",
"problemDeletingImages": "Problem Deleting Images", "problemDeletingImages": "Problem Deleting Images",
"problemDeletingImagesDesc": "One or more images could not be deleted", "problemDeletingImagesDesc": "One or more images could not be deleted",
"switchTo": "Switch to {{ tab }} (Z)", "viewerImage": "Viewer Image",
"openFloatingViewer": "Open Floating Viewer", "compareImage": "Compare Image",
"closeFloatingViewer": "Close Floating Viewer" "openInViewer": "Open in Viewer",
"selectForCompare": "Select for Compare",
"selectAnImageToCompare": "Select an Image to Compare",
"slider": "Slider",
"sideBySide": "Side-by-Side",
"hover": "Hover",
"swapImages": "Swap Images",
"compareOptions": "Comparison Options",
"stretchToFit": "Stretch to Fit",
"exitCompare": "Exit Compare",
"compareHelp1": "Hold <Kbd>Alt</Kbd> while clicking a gallery image or using the arrow keys to change the compare image.",
"compareHelp2": "Press <Kbd>M</Kbd> to cycle through comparison modes.",
"compareHelp3": "Press <Kbd>C</Kbd> to swap the compared images.",
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit."
}, },
"hotkeys": { "hotkeys": {
"searchHotkeys": "Search Hotkeys", "searchHotkeys": "Search Hotkeys",
@ -770,10 +795,15 @@
"cannotConnectOutputToOutput": "Cannot connect output to output", "cannotConnectOutputToOutput": "Cannot connect output to output",
"cannotConnectToSelf": "Cannot connect to self", "cannotConnectToSelf": "Cannot connect to self",
"cannotDuplicateConnection": "Cannot create duplicate connections", "cannotDuplicateConnection": "Cannot create duplicate connections",
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
"missingNode": "Missing invocation node",
"missingInvocationTemplate": "Missing invocation template",
"missingFieldTemplate": "Missing field template",
"nodePack": "Node pack", "nodePack": "Node pack",
"collection": "Collection", "collection": "Collection",
"collectionFieldType": "{{name}} Collection", "singleFieldType": "{{name}} (Single)",
"collectionOrScalarFieldType": "{{name}} Collection|Scalar", "collectionFieldType": "{{name}} (Collection)",
"collectionOrScalarFieldType": "{{name}} (Single or Collection)",
"colorCodeEdges": "Color-Code Edges", "colorCodeEdges": "Color-Code Edges",
"colorCodeEdgesHelp": "Color-code edges according to their connected fields", "colorCodeEdgesHelp": "Color-code edges according to their connected fields",
"connectionWouldCreateCycle": "Connection would create a cycle", "connectionWouldCreateCycle": "Connection would create a cycle",
@ -875,6 +905,7 @@
"versionUnknown": " Version Unknown", "versionUnknown": " Version Unknown",
"workflow": "Workflow", "workflow": "Workflow",
"graph": "Graph", "graph": "Graph",
"noGraph": "No Graph",
"workflowAuthor": "Author", "workflowAuthor": "Author",
"workflowContact": "Contact", "workflowContact": "Contact",
"workflowDescription": "Short Description", "workflowDescription": "Short Description",
@ -887,7 +918,10 @@
"zoomInNodes": "Zoom In", "zoomInNodes": "Zoom In",
"zoomOutNodes": "Zoom Out", "zoomOutNodes": "Zoom Out",
"betaDesc": "This invocation is in beta. Until it is stable, it may have breaking changes during app updates. We plan to support this invocation long-term.", "betaDesc": "This invocation is in beta. Until it is stable, it may have breaking changes during app updates. We plan to support this invocation long-term.",
"prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time." "prototypeDesc": "This invocation is a prototype. It may have breaking changes during app updates and may be removed at any time.",
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
"modelAccessError": "Unable to find model {{key}}, resetting to default"
}, },
"parameters": { "parameters": {
"aspect": "Aspect", "aspect": "Aspect",
@ -935,17 +969,30 @@
"noModelSelected": "No model selected", "noModelSelected": "No model selected",
"noPrompts": "No prompts generated", "noPrompts": "No prompts generated",
"noNodesInGraph": "No nodes in graph", "noNodesInGraph": "No nodes in graph",
"systemDisconnected": "System disconnected" "systemDisconnected": "System disconnected",
"layer": {
"initialImageNoImageSelected": "no initial image selected",
"controlAdapterNoModelSelected": "no Control Adapter model selected",
"controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model",
"controlAdapterNoImageSelected": "no Control Adapter image selected",
"controlAdapterImageNotProcessed": "Control Adapter image not processed",
"t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of {{multiple}}",
"ipAdapterNoModelSelected": "no IP adapter selected",
"ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model",
"ipAdapterNoImageSelected": "no IP Adapter image selected",
"rgNoPromptsOrIPAdapters": "no text prompts or IP Adapters",
"rgNoRegion": "no region selected"
}
}, },
"maskBlur": "Mask Blur", "maskBlur": "Mask Blur",
"negativePromptPlaceholder": "Negative Prompt", "negativePromptPlaceholder": "Negative Prompt",
"globalNegativePromptPlaceholder": "Global Negative Prompt",
"noiseThreshold": "Noise Threshold", "noiseThreshold": "Noise Threshold",
"patchmatchDownScaleSize": "Downscale", "patchmatchDownScaleSize": "Downscale",
"perlinNoise": "Perlin Noise", "perlinNoise": "Perlin Noise",
"positivePromptPlaceholder": "Positive Prompt", "positivePromptPlaceholder": "Positive Prompt",
"globalPositivePromptPlaceholder": "Global Positive Prompt",
"iterations": "Iterations", "iterations": "Iterations",
"iterationsWithCount_one": "{{count}} Iteration",
"iterationsWithCount_other": "{{count}} Iterations",
"scale": "Scale", "scale": "Scale",
"scaleBeforeProcessing": "Scale Before Processing", "scaleBeforeProcessing": "Scale Before Processing",
"scaledHeight": "Scaled H", "scaledHeight": "Scaled H",
@ -1047,8 +1094,9 @@
}, },
"toast": { "toast": {
"addedToBoard": "Added to board", "addedToBoard": "Added to board",
"baseModelChangedCleared_one": "Base model changed, cleared or disabled {{count}} incompatible submodel", "baseModelChanged": "Base Model Changed",
"baseModelChangedCleared_other": "Base model changed, cleared or disabled {{count}} incompatible submodels", "baseModelChangedCleared_one": "Cleared or disabled {{count}} incompatible submodel",
"baseModelChangedCleared_other": "Cleared or disabled {{count}} incompatible submodels",
"canceled": "Processing Canceled", "canceled": "Processing Canceled",
"canvasCopiedClipboard": "Canvas Copied to Clipboard", "canvasCopiedClipboard": "Canvas Copied to Clipboard",
"canvasDownloaded": "Canvas Downloaded", "canvasDownloaded": "Canvas Downloaded",
@ -1069,10 +1117,17 @@
"metadataLoadFailed": "Failed to load metadata", "metadataLoadFailed": "Failed to load metadata",
"modelAddedSimple": "Model Added to Queue", "modelAddedSimple": "Model Added to Queue",
"modelImportCanceled": "Model Import Canceled", "modelImportCanceled": "Model Import Canceled",
"outOfMemoryError": "Out of Memory Error",
"outOfMemoryErrorDesc": "Your current generation settings exceed system capacity. Please adjust your settings and try again.",
"parameters": "Parameters", "parameters": "Parameters",
"parameterNotSet": "{{parameter}} not set", "parameterSet": "Parameter Recalled",
"parameterSet": "{{parameter}} set", "parameterSetDesc": "Recalled {{parameter}}",
"parametersNotSet": "Parameters Not Set", "parameterNotSet": "Parameter Not Recalled",
"parameterNotSetDesc": "Unable to recall {{parameter}}",
"parameterNotSetDescWithMessage": "Unable to recall {{parameter}}: {{message}}",
"parametersSet": "Parameters Recalled",
"parametersNotSet": "Parameters Not Recalled",
"errorCopied": "Error Copied",
"problemCopyingCanvas": "Problem Copying Canvas", "problemCopyingCanvas": "Problem Copying Canvas",
"problemCopyingCanvasDesc": "Unable to export base layer", "problemCopyingCanvasDesc": "Unable to export base layer",
"problemCopyingImage": "Unable to Copy Image", "problemCopyingImage": "Unable to Copy Image",
@ -1092,11 +1147,13 @@
"sentToImageToImage": "Sent To Image To Image", "sentToImageToImage": "Sent To Image To Image",
"sentToUnifiedCanvas": "Sent to Unified Canvas", "sentToUnifiedCanvas": "Sent to Unified Canvas",
"serverError": "Server Error", "serverError": "Server Error",
"sessionRef": "Session: {{sessionId}}",
"setAsCanvasInitialImage": "Set as canvas initial image", "setAsCanvasInitialImage": "Set as canvas initial image",
"setCanvasInitialImage": "Set canvas initial image", "setCanvasInitialImage": "Set canvas initial image",
"setControlImage": "Set as control image", "setControlImage": "Set as control image",
"setInitialImage": "Set as initial image", "setInitialImage": "Set as initial image",
"setNodeField": "Set as node field", "setNodeField": "Set as node field",
"somethingWentWrong": "Something Went Wrong",
"uploadFailed": "Upload failed", "uploadFailed": "Upload failed",
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image", "uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
"uploadInitialImage": "Upload Initial Image", "uploadInitialImage": "Upload Initial Image",
@ -1536,7 +1593,6 @@
"controlLayers": "Control Layers", "controlLayers": "Control Layers",
"globalMaskOpacity": "Global Mask Opacity", "globalMaskOpacity": "Global Mask Opacity",
"autoNegative": "Auto Negative", "autoNegative": "Auto Negative",
"toggleVisibility": "Toggle Layer Visibility",
"deletePrompt": "Delete Prompt", "deletePrompt": "Delete Prompt",
"resetRegion": "Reset Region", "resetRegion": "Reset Region",
"debugLayers": "Debug Layers", "debugLayers": "Debug Layers",
@ -1547,8 +1603,6 @@
"addIPAdapter": "Add $t(common.ipAdapter)", "addIPAdapter": "Add $t(common.ipAdapter)",
"regionalGuidance": "Regional Guidance", "regionalGuidance": "Regional Guidance",
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)", "regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
"controlNetLayer": "$t(common.controlNet) $t(unifiedCanvas.layer)",
"ipAdapterLayer": "$t(common.ipAdapter) $t(unifiedCanvas.layer)",
"opacity": "Opacity", "opacity": "Opacity",
"globalControlAdapter": "Global $t(controlnet.controlAdapter_one)", "globalControlAdapter": "Global $t(controlnet.controlAdapter_one)",
"globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)", "globalControlAdapterLayer": "Global $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
@ -1559,7 +1613,9 @@
"opacityFilter": "Opacity Filter", "opacityFilter": "Opacity Filter",
"clearProcessor": "Clear Processor", "clearProcessor": "Clear Processor",
"resetProcessor": "Reset Processor to Defaults", "resetProcessor": "Reset Processor to Defaults",
"noLayersAdded": "No Layers Added" "noLayersAdded": "No Layers Added",
"layers_one": "Layer",
"layers_other": "Layers"
}, },
"ui": { "ui": {
"tabs": { "tabs": {

View File

@ -25,7 +25,24 @@
"areYouSure": "¿Estas seguro?", "areYouSure": "¿Estas seguro?",
"batch": "Administrador de lotes", "batch": "Administrador de lotes",
"modelManager": "Administrador de modelos", "modelManager": "Administrador de modelos",
"communityLabel": "Comunidad" "communityLabel": "Comunidad",
"direction": "Dirección",
"ai": "Ia",
"add": "Añadir",
"auto": "Automático",
"copyError": "Error $t(gallery.copy)",
"details": "Detalles",
"or": "o",
"checkpoint": "Punto de control",
"controlNet": "ControlNet",
"aboutHeading": "Sea dueño de su poder creativo",
"advanced": "Avanzado",
"data": "Fecha",
"delete": "Borrar",
"copy": "Copiar",
"beta": "Beta",
"on": "En",
"aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:"
}, },
"gallery": { "gallery": {
"galleryImageSize": "Tamaño de la imagen", "galleryImageSize": "Tamaño de la imagen",
@ -365,7 +382,7 @@
"canvasMerged": "Lienzo consolidado", "canvasMerged": "Lienzo consolidado",
"sentToImageToImage": "Enviar hacia Imagen a Imagen", "sentToImageToImage": "Enviar hacia Imagen a Imagen",
"sentToUnifiedCanvas": "Enviar hacia Lienzo Consolidado", "sentToUnifiedCanvas": "Enviar hacia Lienzo Consolidado",
"parametersNotSet": "Parámetros no establecidos", "parametersNotSet": "Parámetros no recuperados",
"metadataLoadFailed": "Error al cargar metadatos", "metadataLoadFailed": "Error al cargar metadatos",
"serverError": "Error en el servidor", "serverError": "Error en el servidor",
"canceled": "Procesando la cancelación", "canceled": "Procesando la cancelación",
@ -373,7 +390,8 @@
"uploadFailedInvalidUploadDesc": "Debe ser una sola imagen PNG o JPEG", "uploadFailedInvalidUploadDesc": "Debe ser una sola imagen PNG o JPEG",
"parameterSet": "Conjunto de parámetros", "parameterSet": "Conjunto de parámetros",
"parameterNotSet": "Parámetro no configurado", "parameterNotSet": "Parámetro no configurado",
"problemCopyingImage": "No se puede copiar la imagen" "problemCopyingImage": "No se puede copiar la imagen",
"errorCopied": "Error al copiar"
}, },
"tooltip": { "tooltip": {
"feature": { "feature": {
@ -443,7 +461,13 @@
"previousImage": "Imagen anterior", "previousImage": "Imagen anterior",
"nextImage": "Siguiente imagen", "nextImage": "Siguiente imagen",
"showOptionsPanel": "Mostrar el panel lateral", "showOptionsPanel": "Mostrar el panel lateral",
"menu": "Menú" "menu": "Menú",
"showGalleryPanel": "Mostrar panel de galería",
"loadMore": "Cargar más",
"about": "Acerca de",
"createIssue": "Crear un problema",
"resetUI": "Interfaz de usuario $t(accessibility.reset)",
"mode": "Modo"
}, },
"nodes": { "nodes": {
"zoomInNodes": "Acercar", "zoomInNodes": "Acercar",
@ -456,5 +480,68 @@
"reloadNodeTemplates": "Recargar las plantillas de nodos", "reloadNodeTemplates": "Recargar las plantillas de nodos",
"loadWorkflow": "Cargar el flujo de trabajo", "loadWorkflow": "Cargar el flujo de trabajo",
"downloadWorkflow": "Descargar el flujo de trabajo en un archivo JSON" "downloadWorkflow": "Descargar el flujo de trabajo en un archivo JSON"
},
"boards": {
"autoAddBoard": "Agregar panel automáticamente",
"changeBoard": "Cambiar el panel",
"clearSearch": "Borrar la búsqueda",
"deleteBoard": "Borrar el panel",
"selectBoard": "Seleccionar un panel",
"uncategorized": "Sin categoría",
"cancel": "Cancelar",
"addBoard": "Agregar un panel",
"movingImagesToBoard_one": "Moviendo {{count}} imagen al panel:",
"movingImagesToBoard_many": "Moviendo {{count}} imágenes al panel:",
"movingImagesToBoard_other": "Moviendo {{count}} imágenes al panel:",
"bottomMessage": "Al eliminar este panel y las imágenes que contiene, se restablecerán las funciones que los estén utilizando actualmente.",
"deleteBoardAndImages": "Borrar el panel y las imágenes",
"loading": "Cargando...",
"deletedBoardsCannotbeRestored": "Los paneles eliminados no se pueden restaurar",
"move": "Mover",
"menuItemAutoAdd": "Agregar automáticamente a este panel",
"searchBoard": "Buscando paneles…",
"topMessage": "Este panel contiene imágenes utilizadas en las siguientes funciones:",
"downloadBoard": "Descargar panel",
"deleteBoardOnly": "Borrar solo el panel",
"myBoard": "Mi panel",
"noMatching": "No hay paneles que coincidan"
},
"accordions": {
"compositing": {
"title": "Composición",
"infillTab": "Relleno"
},
"generation": {
"title": "Generación"
},
"image": {
"title": "Imagen"
},
"control": {
"title": "Control"
},
"advanced": {
"options": "$t(accordions.advanced.title) opciones",
"title": "Avanzado"
}
},
"ui": {
"tabs": {
"generationTab": "$t(ui.tabs.generation) $t(common.tab)",
"canvas": "Lienzo",
"generation": "Generación",
"queue": "Cola",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
"workflows": "Flujos de trabajo",
"models": "Modelos",
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
"canvasTab": "$t(ui.tabs.canvas) $t(common.tab)",
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)"
}
},
"controlLayers": {
"layers_one": "Capa",
"layers_many": "Capas",
"layers_other": "Capas"
} }
} }

View File

@ -5,7 +5,7 @@
"reportBugLabel": "Segnala un errore", "reportBugLabel": "Segnala un errore",
"settingsLabel": "Impostazioni", "settingsLabel": "Impostazioni",
"img2img": "Immagine a Immagine", "img2img": "Immagine a Immagine",
"unifiedCanvas": "Tela unificata", "unifiedCanvas": "Tela",
"nodes": "Flussi di lavoro", "nodes": "Flussi di lavoro",
"upload": "Caricamento", "upload": "Caricamento",
"load": "Carica", "load": "Carica",
@ -74,7 +74,18 @@
"file": "File", "file": "File",
"toResolve": "Da risolvere", "toResolve": "Da risolvere",
"add": "Aggiungi", "add": "Aggiungi",
"loglevel": "Livello di log" "loglevel": "Livello di log",
"beta": "Beta",
"positivePrompt": "Prompt positivo",
"negativePrompt": "Prompt negativo",
"selected": "Selezionato",
"goTo": "Vai a",
"editor": "Editor",
"tab": "Scheda",
"viewing": "Visualizza",
"viewingDesc": "Rivedi le immagini in un'ampia vista della galleria",
"editing": "Modifica",
"editingDesc": "Modifica nell'area Livelli di controllo"
}, },
"gallery": { "gallery": {
"galleryImageSize": "Dimensione dell'immagine", "galleryImageSize": "Dimensione dell'immagine",
@ -180,8 +191,8 @@
"desc": "Mostra le informazioni sui metadati dell'immagine corrente" "desc": "Mostra le informazioni sui metadati dell'immagine corrente"
}, },
"sendToImageToImage": { "sendToImageToImage": {
"title": "Invia a Immagine a Immagine", "title": "Invia a Generazione da immagine",
"desc": "Invia l'immagine corrente a da Immagine a Immagine" "desc": "Invia l'immagine corrente a Generazione da immagine"
}, },
"deleteImage": { "deleteImage": {
"title": "Elimina immagine", "title": "Elimina immagine",
@ -334,6 +345,10 @@
"remixImage": { "remixImage": {
"desc": "Utilizza tutti i parametri tranne il seme dell'immagine corrente", "desc": "Utilizza tutti i parametri tranne il seme dell'immagine corrente",
"title": "Remixa l'immagine" "title": "Remixa l'immagine"
},
"toggleViewer": {
"title": "Attiva/disattiva il visualizzatore di immagini",
"desc": "Passa dal Visualizzatore immagini all'area di lavoro per la scheda corrente."
} }
}, },
"modelManager": { "modelManager": {
@ -471,8 +486,8 @@
"scaledHeight": "Altezza ridimensionata", "scaledHeight": "Altezza ridimensionata",
"infillMethod": "Metodo di riempimento", "infillMethod": "Metodo di riempimento",
"tileSize": "Dimensione piastrella", "tileSize": "Dimensione piastrella",
"sendToImg2Img": "Invia a Immagine a Immagine", "sendToImg2Img": "Invia a Generazione da immagine",
"sendToUnifiedCanvas": "Invia a Tela Unificata", "sendToUnifiedCanvas": "Invia alla Tela",
"downloadImage": "Scarica l'immagine", "downloadImage": "Scarica l'immagine",
"usePrompt": "Usa Prompt", "usePrompt": "Usa Prompt",
"useSeed": "Usa Seme", "useSeed": "Usa Seme",
@ -508,13 +523,24 @@
"incompatibleBaseModelForControlAdapter": "Il modello dell'adattatore di controllo #{{number}} non è compatibile con il modello principale.", "incompatibleBaseModelForControlAdapter": "Il modello dell'adattatore di controllo #{{number}} non è compatibile con il modello principale.",
"missingNodeTemplate": "Modello di nodo mancante", "missingNodeTemplate": "Modello di nodo mancante",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} ingresso mancante", "missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} ingresso mancante",
"missingFieldTemplate": "Modello di campo mancante" "missingFieldTemplate": "Modello di campo mancante",
"imageNotProcessedForControlAdapter": "L'immagine dell'adattatore di controllo #{{number}} non è stata elaborata",
"layer": {
"initialImageNoImageSelected": "Nessuna immagine iniziale selezionata",
"t2iAdapterIncompatibleDimensions": "L'adattatore T2I richiede che la dimensione dell'immagine sia un multiplo di {{multiple}}",
"controlAdapterNoModelSelected": "Nessun modello di Adattatore di Controllo selezionato",
"controlAdapterIncompatibleBaseModel": "Il modello base dell'adattatore di controllo non è compatibile",
"controlAdapterNoImageSelected": "Nessuna immagine dell'adattatore di controllo selezionata",
"controlAdapterImageNotProcessed": "Immagine dell'adattatore di controllo non elaborata",
"ipAdapterNoModelSelected": "Nessun adattatore IP selezionato",
"ipAdapterIncompatibleBaseModel": "Il modello base dell'adattatore IP non è compatibile",
"ipAdapterNoImageSelected": "Nessuna immagine dell'adattatore IP selezionata",
"rgNoPromptsOrIPAdapters": "Nessun prompt o adattatore IP",
"rgNoRegion": "Nessuna regione selezionata"
}
}, },
"useCpuNoise": "Usa la CPU per generare rumore", "useCpuNoise": "Usa la CPU per generare rumore",
"iterations": "Iterazioni", "iterations": "Iterazioni",
"iterationsWithCount_one": "{{count}} Iterazione",
"iterationsWithCount_many": "{{count}} Iterazioni",
"iterationsWithCount_other": "{{count}} Iterazioni",
"isAllowedToUpscale": { "isAllowedToUpscale": {
"useX2Model": "L'immagine è troppo grande per l'ampliamento con il modello x4, utilizza il modello x2", "useX2Model": "L'immagine è troppo grande per l'ampliamento con il modello x4, utilizza il modello x2",
"tooLarge": "L'immagine è troppo grande per l'ampliamento, seleziona un'immagine più piccola" "tooLarge": "L'immagine è troppo grande per l'ampliamento, seleziona un'immagine più piccola"
@ -534,7 +560,10 @@
"infillMosaicMinColor": "Colore minimo", "infillMosaicMinColor": "Colore minimo",
"infillMosaicMaxColor": "Colore massimo", "infillMosaicMaxColor": "Colore massimo",
"infillMosaicTileHeight": "Altezza piastrella", "infillMosaicTileHeight": "Altezza piastrella",
"infillColorValue": "Colore di riempimento" "infillColorValue": "Colore di riempimento",
"globalSettings": "Impostazioni globali",
"globalPositivePromptPlaceholder": "Prompt positivo globale",
"globalNegativePromptPlaceholder": "Prompt negativo globale"
}, },
"settings": { "settings": {
"models": "Modelli", "models": "Modelli",
@ -559,7 +588,7 @@
"intermediatesCleared_one": "Cancellata {{count}} immagine intermedia", "intermediatesCleared_one": "Cancellata {{count}} immagine intermedia",
"intermediatesCleared_many": "Cancellate {{count}} immagini intermedie", "intermediatesCleared_many": "Cancellate {{count}} immagini intermedie",
"intermediatesCleared_other": "Cancellate {{count}} immagini intermedie", "intermediatesCleared_other": "Cancellate {{count}} immagini intermedie",
"clearIntermediatesDesc1": "La cancellazione delle immagini intermedie ripristinerà lo stato di Tela Unificata e ControlNet.", "clearIntermediatesDesc1": "La cancellazione delle immagini intermedie ripristinerà lo stato della Tela e degli Adattatori di Controllo.",
"intermediatesClearedFailed": "Problema con la cancellazione delle immagini intermedie", "intermediatesClearedFailed": "Problema con la cancellazione delle immagini intermedie",
"clearIntermediatesWithCount_one": "Cancella {{count}} immagine intermedia", "clearIntermediatesWithCount_one": "Cancella {{count}} immagine intermedia",
"clearIntermediatesWithCount_many": "Cancella {{count}} immagini intermedie", "clearIntermediatesWithCount_many": "Cancella {{count}} immagini intermedie",
@ -575,8 +604,8 @@
"imageCopied": "Immagine copiata", "imageCopied": "Immagine copiata",
"imageNotLoadedDesc": "Impossibile trovare l'immagine", "imageNotLoadedDesc": "Impossibile trovare l'immagine",
"canvasMerged": "Tela unita", "canvasMerged": "Tela unita",
"sentToImageToImage": "Inviato a Immagine a Immagine", "sentToImageToImage": "Inviato a Generazione da immagine",
"sentToUnifiedCanvas": "Inviato a Tela Unificata", "sentToUnifiedCanvas": "Inviato alla Tela",
"parametersNotSet": "Parametri non impostati", "parametersNotSet": "Parametri non impostati",
"metadataLoadFailed": "Impossibile caricare i metadati", "metadataLoadFailed": "Impossibile caricare i metadati",
"serverError": "Errore del Server", "serverError": "Errore del Server",
@ -795,7 +824,7 @@
"float": "In virgola mobile", "float": "In virgola mobile",
"currentImageDescription": "Visualizza l'immagine corrente nell'editor dei nodi", "currentImageDescription": "Visualizza l'immagine corrente nell'editor dei nodi",
"fieldTypesMustMatch": "I tipi di campo devono corrispondere", "fieldTypesMustMatch": "I tipi di campo devono corrispondere",
"edge": "Bordo", "edge": "Collegamento",
"currentImage": "Immagine corrente", "currentImage": "Immagine corrente",
"integer": "Numero Intero", "integer": "Numero Intero",
"inputMayOnlyHaveOneConnection": "L'ingresso può avere solo una connessione", "inputMayOnlyHaveOneConnection": "L'ingresso può avere solo una connessione",
@ -808,8 +837,8 @@
"unableToUpdateNodes_other": "Impossibile aggiornare {{count}} nodi", "unableToUpdateNodes_other": "Impossibile aggiornare {{count}} nodi",
"addLinearView": "Aggiungi alla vista Lineare", "addLinearView": "Aggiungi alla vista Lineare",
"unknownErrorValidatingWorkflow": "Errore sconosciuto durante la convalida del flusso di lavoro", "unknownErrorValidatingWorkflow": "Errore sconosciuto durante la convalida del flusso di lavoro",
"collectionFieldType": "{{name}} Raccolta", "collectionFieldType": "{{name}} (Raccolta)",
"collectionOrScalarFieldType": "{{name}} Raccolta|Scalare", "collectionOrScalarFieldType": "{{name}} (Singola o Raccolta)",
"nodeVersion": "Versione Nodo", "nodeVersion": "Versione Nodo",
"inputFieldTypeParseError": "Impossibile analizzare il tipo di campo di input {{node}}.{{field}} ({{message}})", "inputFieldTypeParseError": "Impossibile analizzare il tipo di campo di input {{node}}.{{field}} ({{message}})",
"unsupportedArrayItemType": "Tipo di elemento dell'array non supportato \"{{type}}\"", "unsupportedArrayItemType": "Tipo di elemento dell'array non supportato \"{{type}}\"",
@ -845,7 +874,15 @@
"resetToDefaultValue": "Ripristina il valore predefinito", "resetToDefaultValue": "Ripristina il valore predefinito",
"noFieldsViewMode": "Questo flusso di lavoro non ha campi selezionati da visualizzare. Visualizza il flusso di lavoro completo per configurare i valori.", "noFieldsViewMode": "Questo flusso di lavoro non ha campi selezionati da visualizzare. Visualizza il flusso di lavoro completo per configurare i valori.",
"edit": "Modifica", "edit": "Modifica",
"graph": "Grafico" "graph": "Grafico",
"showEdgeLabelsHelp": "Mostra etichette sui collegamenti, che indicano i nodi collegati",
"showEdgeLabels": "Mostra le etichette del collegamento",
"cannotMixAndMatchCollectionItemTypes": "Impossibile combinare e abbinare i tipi di elementi della raccolta",
"noGraph": "Nessun grafico",
"missingNode": "Nodo di invocazione mancante",
"missingInvocationTemplate": "Modello di invocazione mancante",
"missingFieldTemplate": "Modello di campo mancante",
"singleFieldType": "{{name}} (Singola)"
}, },
"boards": { "boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca", "autoAddBoard": "Aggiungi automaticamente bacheca",
@ -922,7 +959,7 @@
"colorMapTileSize": "Dimensione piastrella", "colorMapTileSize": "Dimensione piastrella",
"mediapipeFaceDescription": "Rilevamento dei volti tramite Mediapipe", "mediapipeFaceDescription": "Rilevamento dei volti tramite Mediapipe",
"hedDescription": "Rilevamento dei bordi nidificati olisticamente", "hedDescription": "Rilevamento dei bordi nidificati olisticamente",
"setControlImageDimensions": "Imposta le dimensioni dell'immagine di controllo su L/A", "setControlImageDimensions": "Copia le dimensioni in L/A (ottimizza per il modello)",
"maxFaces": "Numero massimo di volti", "maxFaces": "Numero massimo di volti",
"addT2IAdapter": "Aggiungi $t(common.t2iAdapter)", "addT2IAdapter": "Aggiungi $t(common.t2iAdapter)",
"addControlNet": "Aggiungi $t(common.controlNet)", "addControlNet": "Aggiungi $t(common.controlNet)",
@ -951,12 +988,17 @@
"mediapipeFace": "Mediapipe Volto", "mediapipeFace": "Mediapipe Volto",
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))", "ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))", "t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))",
"selectCLIPVisionModel": "Seleziona un modello CLIP Vision" "selectCLIPVisionModel": "Seleziona un modello CLIP Vision",
"ipAdapterMethod": "Metodo",
"full": "Completo",
"composition": "Solo la composizione",
"style": "Solo lo stile",
"beginEndStepPercentShort": "Inizio/Fine %",
"setControlImageDimensionsForce": "Copia le dimensioni in L/A (ignora il modello)"
}, },
"queue": { "queue": {
"queueFront": "Aggiungi all'inizio della coda", "queueFront": "Aggiungi all'inizio della coda",
"queueBack": "Aggiungi alla coda", "queueBack": "Aggiungi alla coda",
"queueCountPrediction": "{{promptsCount}} prompt × {{iterations}} iterazioni -> {{count}} generazioni",
"queue": "Coda", "queue": "Coda",
"status": "Stato", "status": "Stato",
"pruneSucceeded": "Rimossi {{item_count}} elementi completati dalla coda", "pruneSucceeded": "Rimossi {{item_count}} elementi completati dalla coda",
@ -993,7 +1035,7 @@
"cancelBatchSucceeded": "Lotto annullato", "cancelBatchSucceeded": "Lotto annullato",
"clearTooltip": "Annulla e cancella tutti gli elementi", "clearTooltip": "Annulla e cancella tutti gli elementi",
"current": "Attuale", "current": "Attuale",
"pauseTooltip": "Sospende l'elaborazione", "pauseTooltip": "Sospendi l'elaborazione",
"failed": "Falliti", "failed": "Falliti",
"cancelItem": "Annulla l'elemento", "cancelItem": "Annulla l'elemento",
"next": "Prossimo", "next": "Prossimo",
@ -1011,7 +1053,16 @@
"graphFailedToQueue": "Impossibile mettere in coda il grafico", "graphFailedToQueue": "Impossibile mettere in coda il grafico",
"batchFieldValues": "Valori Campi Lotto", "batchFieldValues": "Valori Campi Lotto",
"time": "Tempo", "time": "Tempo",
"openQueue": "Apri coda" "openQueue": "Apri coda",
"iterations_one": "Iterazione",
"iterations_many": "Iterazioni",
"iterations_other": "Iterazioni",
"prompts_one": "Prompt",
"prompts_many": "Prompt",
"prompts_other": "Prompt",
"generations_one": "Generazione",
"generations_many": "Generazioni",
"generations_other": "Generazioni"
}, },
"models": { "models": {
"noMatchingModels": "Nessun modello corrispondente", "noMatchingModels": "Nessun modello corrispondente",
@ -1394,6 +1445,12 @@
"paragraphs": [ "paragraphs": [
"La dimensione del bordo del passaggio di coerenza." "La dimensione del bordo del passaggio di coerenza."
] ]
},
"ipAdapterMethod": {
"heading": "Metodo",
"paragraphs": [
"Metodo con cui applicare l'adattatore IP corrente."
]
} }
}, },
"sdxl": { "sdxl": {
@ -1522,5 +1579,55 @@
"compatibleEmbeddings": "Incorporamenti compatibili", "compatibleEmbeddings": "Incorporamenti compatibili",
"addPromptTrigger": "Aggiungi Trigger nel prompt", "addPromptTrigger": "Aggiungi Trigger nel prompt",
"noMatchingTriggers": "Nessun Trigger corrispondente" "noMatchingTriggers": "Nessun Trigger corrispondente"
},
"controlLayers": {
"opacityFilter": "Filtro opacità",
"deleteAll": "Cancella tutto",
"addLayer": "Aggiungi Livello",
"moveToFront": "Sposta in primo piano",
"moveToBack": "Sposta in fondo",
"moveForward": "Sposta avanti",
"moveBackward": "Sposta indietro",
"brushSize": "Dimensioni del pennello",
"globalMaskOpacity": "Opacità globale della maschera",
"autoNegative": "Auto Negativo",
"deletePrompt": "Cancella il prompt",
"debugLayers": "Debug dei Livelli",
"rectangle": "Rettangolo",
"maskPreviewColor": "Colore anteprima maschera",
"addPositivePrompt": "Aggiungi $t(common.positivePrompt)",
"addNegativePrompt": "Aggiungi $t(common.negativePrompt)",
"addIPAdapter": "Aggiungi $t(common.ipAdapter)",
"regionalGuidance": "Guida regionale",
"regionalGuidanceLayer": "$t(unifiedCanvas.layer) $t(controlLayers.regionalGuidance)",
"opacity": "Opacità",
"globalControlAdapter": "$t(controlnet.controlAdapter_one) Globale",
"globalControlAdapterLayer": "$t(controlnet.controlAdapter_one) - $t(unifiedCanvas.layer) Globale",
"globalIPAdapter": "$t(common.ipAdapter) Globale",
"globalIPAdapterLayer": "$t(common.ipAdapter) - $t(unifiedCanvas.layer) Globale",
"globalInitialImage": "Immagine iniziale",
"globalInitialImageLayer": "$t(controlLayers.globalInitialImage) - $t(unifiedCanvas.layer) Globale",
"clearProcessor": "Cancella processore",
"resetProcessor": "Ripristina il processore alle impostazioni predefinite",
"noLayersAdded": "Nessun livello aggiunto",
"resetRegion": "Reimposta la regione",
"controlLayers": "Livelli di controllo",
"layers_one": "Livello",
"layers_many": "Livelli",
"layers_other": "Livelli"
},
"ui": {
"tabs": {
"generation": "Generazione",
"generationTab": "$t(ui.tabs.generation) $t(common.tab)",
"canvas": "Tela",
"canvasTab": "$t(ui.tabs.canvas) $t(common.tab)",
"workflows": "Flussi di lavoro",
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
"models": "Modelli",
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
"queue": "Coda",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)"
}
} }
} }

View File

@ -570,7 +570,6 @@
"pauseSucceeded": "処理が一時停止されました", "pauseSucceeded": "処理が一時停止されました",
"queueFront": "キューの先頭へ追加", "queueFront": "キューの先頭へ追加",
"queueBack": "キューに追加", "queueBack": "キューに追加",
"queueCountPrediction": "{{promptsCount}} プロンプト × {{iterations}} イテレーション -> {{count}} 枚生成",
"pause": "一時停止", "pause": "一時停止",
"queue": "キュー", "queue": "キュー",
"pauseTooltip": "処理を一時停止", "pauseTooltip": "処理を一時停止",

View File

@ -505,7 +505,6 @@
"completed": "완성된", "completed": "완성된",
"queueBack": "Queue에 추가", "queueBack": "Queue에 추가",
"cancelFailed": "항목 취소 중 발생한 문제", "cancelFailed": "항목 취소 중 발생한 문제",
"queueCountPrediction": "Queue에 {{predicted}} 추가",
"batchQueued": "Batch Queued", "batchQueued": "Batch Queued",
"pauseFailed": "프로세서 중지 중 발생한 문제", "pauseFailed": "프로세서 중지 중 발생한 문제",
"clearFailed": "Queue 제거 중 발생한 문제", "clearFailed": "Queue 제거 중 발생한 문제",

View File

@ -6,7 +6,7 @@
"settingsLabel": "Instellingen", "settingsLabel": "Instellingen",
"img2img": "Afbeelding naar afbeelding", "img2img": "Afbeelding naar afbeelding",
"unifiedCanvas": "Centraal canvas", "unifiedCanvas": "Centraal canvas",
"nodes": "Werkstroom-editor", "nodes": "Werkstromen",
"upload": "Upload", "upload": "Upload",
"load": "Laad", "load": "Laad",
"statusDisconnected": "Niet verbonden", "statusDisconnected": "Niet verbonden",
@ -34,7 +34,60 @@
"controlNet": "ControlNet", "controlNet": "ControlNet",
"imageFailedToLoad": "Kan afbeelding niet laden", "imageFailedToLoad": "Kan afbeelding niet laden",
"learnMore": "Meer informatie", "learnMore": "Meer informatie",
"advanced": "Uitgebreid" "advanced": "Uitgebreid",
"file": "Bestand",
"installed": "Geïnstalleerd",
"notInstalled": "Niet $t(common.installed)",
"simple": "Eenvoudig",
"somethingWentWrong": "Er ging iets mis",
"add": "Voeg toe",
"checkpoint": "Checkpoint",
"details": "Details",
"outputs": "Uitvoeren",
"save": "Bewaar",
"nextPage": "Volgende pagina",
"blue": "Blauw",
"alpha": "Alfa",
"red": "Rood",
"editor": "Editor",
"folder": "Map",
"format": "structuur",
"goTo": "Ga naar",
"template": "Sjabloon",
"input": "Invoer",
"loglevel": "Logboekniveau",
"safetensors": "Safetensors",
"saveAs": "Bewaar als",
"created": "Gemaakt",
"green": "Groen",
"tab": "Tab",
"positivePrompt": "Positieve prompt",
"negativePrompt": "Negatieve prompt",
"selected": "Geselecteerd",
"orderBy": "Sorteer op",
"prevPage": "Vorige pagina",
"beta": "Bèta",
"copyError": "$t(gallery.copy) Fout",
"toResolve": "Op te lossen",
"aboutDesc": "Gebruik je Invoke voor het werk? Kijk dan naar:",
"aboutHeading": "Creatieve macht voor jou",
"copy": "Kopieer",
"data": "Gegevens",
"or": "of",
"updated": "Bijgewerkt",
"outpaint": "outpainten",
"viewing": "Bekijken",
"viewingDesc": "Beoordeel afbeelding in een grote galerijweergave",
"editing": "Bewerken",
"editingDesc": "Bewerk op het canvas Stuurlagen",
"ai": "ai",
"inpaint": "inpainten",
"unknown": "Onbekend",
"delete": "Verwijder",
"direction": "Richting",
"error": "Fout",
"localSystem": "Lokaal systeem",
"unknownError": "Onbekende fout"
}, },
"gallery": { "gallery": {
"galleryImageSize": "Afbeeldingsgrootte", "galleryImageSize": "Afbeeldingsgrootte",
@ -310,10 +363,41 @@
"modelSyncFailed": "Synchronisatie modellen mislukt", "modelSyncFailed": "Synchronisatie modellen mislukt",
"modelDeleteFailed": "Model kon niet verwijderd worden", "modelDeleteFailed": "Model kon niet verwijderd worden",
"convertingModelBegin": "Model aan het converteren. Even geduld.", "convertingModelBegin": "Model aan het converteren. Even geduld.",
"predictionType": "Soort voorspelling (voor Stable Diffusion 2.x-modellen en incidentele Stable Diffusion 1.x-modellen)", "predictionType": "Soort voorspelling",
"advanced": "Uitgebreid", "advanced": "Uitgebreid",
"modelType": "Soort model", "modelType": "Soort model",
"vaePrecision": "Nauwkeurigheid VAE" "vaePrecision": "Nauwkeurigheid VAE",
"loraTriggerPhrases": "LoRA-triggerzinnen",
"urlOrLocalPathHelper": "URL's zouden moeten wijzen naar een los bestand. Lokale paden kunnen wijzen naar een los bestand of map voor een individueel Diffusers-model.",
"modelName": "Modelnaam",
"path": "Pad",
"triggerPhrases": "Triggerzinnen",
"typePhraseHere": "Typ zin hier in",
"useDefaultSettings": "Gebruik standaardinstellingen",
"modelImageDeleteFailed": "Fout bij verwijderen modelafbeelding",
"modelImageUpdated": "Modelafbeelding bijgewerkt",
"modelImageUpdateFailed": "Fout bij bijwerken modelafbeelding",
"noMatchingModels": "Geen overeenkomende modellen",
"scanPlaceholder": "Pad naar een lokale map",
"noModelsInstalled": "Geen modellen geïnstalleerd",
"noModelsInstalledDesc1": "Installeer modellen met de",
"noModelSelected": "Geen model geselecteerd",
"starterModels": "Beginnermodellen",
"textualInversions": "Tekstuele omkeringen",
"upcastAttention": "Upcast-aandacht",
"uploadImage": "Upload afbeelding",
"mainModelTriggerPhrases": "Triggerzinnen hoofdmodel",
"urlOrLocalPath": "URL of lokaal pad",
"scanFolderHelper": "De map zal recursief worden ingelezen voor modellen. Dit kan enige tijd in beslag nemen voor erg grote mappen.",
"simpleModelPlaceholder": "URL of pad naar een lokaal pad of Diffusers-map",
"modelSettings": "Modelinstellingen",
"pathToConfig": "Pad naar configuratie",
"prune": "Snoei",
"pruneTooltip": "Snoei voltooide importeringen uit wachtrij",
"repoVariant": "Repovariant",
"scanFolder": "Lees map in",
"scanResults": "Resultaten inlezen",
"source": "Bron"
}, },
"parameters": { "parameters": {
"images": "Afbeeldingen", "images": "Afbeeldingen",
@ -353,13 +437,13 @@
"copyImage": "Kopieer afbeelding", "copyImage": "Kopieer afbeelding",
"denoisingStrength": "Sterkte ontruisen", "denoisingStrength": "Sterkte ontruisen",
"scheduler": "Planner", "scheduler": "Planner",
"seamlessXAxis": "X-as", "seamlessXAxis": "Naadloze tegels in x-as",
"seamlessYAxis": "Y-as", "seamlessYAxis": "Naadloze tegels in y-as",
"clipSkip": "Overslaan CLIP", "clipSkip": "Overslaan CLIP",
"negativePromptPlaceholder": "Negatieve prompt", "negativePromptPlaceholder": "Negatieve prompt",
"controlNetControlMode": "Aansturingsmodus", "controlNetControlMode": "Aansturingsmodus",
"positivePromptPlaceholder": "Positieve prompt", "positivePromptPlaceholder": "Positieve prompt",
"maskBlur": "Vervaag", "maskBlur": "Vervaging van masker",
"invoke": { "invoke": {
"noNodesInGraph": "Geen knooppunten in graaf", "noNodesInGraph": "Geen knooppunten in graaf",
"noModelSelected": "Geen model ingesteld", "noModelSelected": "Geen model ingesteld",
@ -369,11 +453,25 @@
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} invoer ontbreekt", "missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} invoer ontbreekt",
"noControlImageForControlAdapter": "Controle-adapter #{{number}} heeft geen controle-afbeelding", "noControlImageForControlAdapter": "Controle-adapter #{{number}} heeft geen controle-afbeelding",
"noModelForControlAdapter": "Control-adapter #{{number}} heeft geen model ingesteld staan.", "noModelForControlAdapter": "Control-adapter #{{number}} heeft geen model ingesteld staan.",
"incompatibleBaseModelForControlAdapter": "Model van controle-adapter #{{number}} is ongeldig in combinatie met het hoofdmodel.", "incompatibleBaseModelForControlAdapter": "Model van controle-adapter #{{number}} is niet compatibel met het hoofdmodel.",
"systemDisconnected": "Systeem is niet verbonden", "systemDisconnected": "Systeem is niet verbonden",
"missingNodeTemplate": "Knooppuntsjabloon ontbreekt", "missingNodeTemplate": "Knooppuntsjabloon ontbreekt",
"missingFieldTemplate": "Veldsjabloon ontbreekt", "missingFieldTemplate": "Veldsjabloon ontbreekt",
"addingImagesTo": "Bezig met toevoegen van afbeeldingen aan" "addingImagesTo": "Bezig met toevoegen van afbeeldingen aan",
"layer": {
"initialImageNoImageSelected": "geen initiële afbeelding geselecteerd",
"controlAdapterNoModelSelected": "geen controle-adaptermodel geselecteerd",
"controlAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor controle-adapter",
"controlAdapterNoImageSelected": "geen afbeelding voor controle-adapter geselecteerd",
"controlAdapterImageNotProcessed": "Afbeelding voor controle-adapter niet verwerkt",
"ipAdapterIncompatibleBaseModel": "niet-compatibele basismodel voor IP-adapter",
"ipAdapterNoImageSelected": "geen afbeelding voor IP-adapter geselecteerd",
"rgNoRegion": "geen gebied geselecteerd",
"rgNoPromptsOrIPAdapters": "geen tekstprompts of IP-adapters",
"t2iAdapterIncompatibleDimensions": "T2I-adapter vereist een afbeelding met afmetingen met een veelvoud van 64",
"ipAdapterNoModelSelected": "geen IP-adapter geselecteerd"
},
"imageNotProcessedForControlAdapter": "De afbeelding van controle-adapter #{{number}} is niet verwerkt"
}, },
"isAllowedToUpscale": { "isAllowedToUpscale": {
"useX2Model": "Afbeelding is te groot om te vergroten met het x4-model. Gebruik hiervoor het x2-model", "useX2Model": "Afbeelding is te groot om te vergroten met het x4-model. Gebruik hiervoor het x2-model",
@ -383,9 +481,26 @@
"useCpuNoise": "Gebruik CPU-ruis", "useCpuNoise": "Gebruik CPU-ruis",
"imageActions": "Afbeeldingshandeling", "imageActions": "Afbeeldingshandeling",
"iterations": "Iteraties", "iterations": "Iteraties",
"iterationsWithCount_one": "{{count}} iteratie", "coherenceMode": "Modus",
"iterationsWithCount_other": "{{count}} iteraties", "infillColorValue": "Vulkleur",
"coherenceMode": "Modus" "remixImage": "Meng afbeelding opnieuw",
"setToOptimalSize": "Optimaliseer grootte voor het model",
"setToOptimalSizeTooSmall": "$t(parameters.setToOptimalSize) (is mogelijk te klein)",
"aspect": "Beeldverhouding",
"infillMosaicTileWidth": "Breedte tegel",
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (is mogelijk te groot)",
"lockAspectRatio": "Zet beeldverhouding vast",
"infillMosaicTileHeight": "Hoogte tegel",
"globalNegativePromptPlaceholder": "Globale negatieve prompt",
"globalPositivePromptPlaceholder": "Globale positieve prompt",
"useSize": "Gebruik grootte",
"swapDimensions": "Wissel afmetingen om",
"globalSettings": "Globale instellingen",
"coherenceEdgeSize": "Randgrootte",
"coherenceMinDenoise": "Min. ontruising",
"infillMosaicMinColor": "Min. kleur",
"infillMosaicMaxColor": "Max. kleur",
"cfgRescaleMultiplier": "Vermenigvuldiger voor CFG-herschaling"
}, },
"settings": { "settings": {
"models": "Modellen", "models": "Modellen",
@ -412,7 +527,12 @@
"intermediatesCleared_one": "{{count}} tussentijdse afbeelding gewist", "intermediatesCleared_one": "{{count}} tussentijdse afbeelding gewist",
"intermediatesCleared_other": "{{count}} tussentijdse afbeeldingen gewist", "intermediatesCleared_other": "{{count}} tussentijdse afbeeldingen gewist",
"clearIntermediatesDesc1": "Als je tussentijdse afbeeldingen wist, dan wordt de staat hersteld van je canvas en van ControlNet.", "clearIntermediatesDesc1": "Als je tussentijdse afbeeldingen wist, dan wordt de staat hersteld van je canvas en van ControlNet.",
"intermediatesClearedFailed": "Fout bij wissen van tussentijdse afbeeldingen" "intermediatesClearedFailed": "Fout bij wissen van tussentijdse afbeeldingen",
"clearIntermediatesDisabled": "Wachtrij moet leeg zijn om tussentijdse afbeeldingen te kunnen leegmaken",
"enableInformationalPopovers": "Schakel informatieve hulpballonnen in",
"enableInvisibleWatermark": "Schakel onzichtbaar watermerk in",
"enableNSFWChecker": "Schakel NSFW-controle in",
"reloadingIn": "Opnieuw laden na"
}, },
"toast": { "toast": {
"uploadFailed": "Upload mislukt", "uploadFailed": "Upload mislukt",
@ -427,8 +547,8 @@
"connected": "Verbonden met server", "connected": "Verbonden met server",
"canceled": "Verwerking geannuleerd", "canceled": "Verwerking geannuleerd",
"uploadFailedInvalidUploadDesc": "Moet een enkele PNG- of JPEG-afbeelding zijn", "uploadFailedInvalidUploadDesc": "Moet een enkele PNG- of JPEG-afbeelding zijn",
"parameterNotSet": "Parameter niet ingesteld", "parameterNotSet": "{{parameter}} niet ingesteld",
"parameterSet": "Instellen parameters", "parameterSet": "{{parameter}} ingesteld",
"problemCopyingImage": "Kan Afbeelding Niet Kopiëren", "problemCopyingImage": "Kan Afbeelding Niet Kopiëren",
"baseModelChangedCleared_one": "Basismodel is gewijzigd: {{count}} niet-compatibel submodel weggehaald of uitgeschakeld", "baseModelChangedCleared_one": "Basismodel is gewijzigd: {{count}} niet-compatibel submodel weggehaald of uitgeschakeld",
"baseModelChangedCleared_other": "Basismodel is gewijzigd: {{count}} niet-compatibele submodellen weggehaald of uitgeschakeld", "baseModelChangedCleared_other": "Basismodel is gewijzigd: {{count}} niet-compatibele submodellen weggehaald of uitgeschakeld",
@ -445,11 +565,11 @@
"maskSavedAssets": "Masker bewaard in Assets", "maskSavedAssets": "Masker bewaard in Assets",
"problemDownloadingCanvas": "Fout bij downloaden van canvas", "problemDownloadingCanvas": "Fout bij downloaden van canvas",
"problemMergingCanvas": "Fout bij samenvoegen canvas", "problemMergingCanvas": "Fout bij samenvoegen canvas",
"setCanvasInitialImage": "Ingesteld als initiële canvasafbeelding", "setCanvasInitialImage": "Initiële canvasafbeelding ingesteld",
"imageUploaded": "Afbeelding geüpload", "imageUploaded": "Afbeelding geüpload",
"addedToBoard": "Toegevoegd aan bord", "addedToBoard": "Toegevoegd aan bord",
"workflowLoaded": "Werkstroom geladen", "workflowLoaded": "Werkstroom geladen",
"modelAddedSimple": "Model toegevoegd", "modelAddedSimple": "Model toegevoegd aan wachtrij",
"problemImportingMaskDesc": "Kan masker niet exporteren", "problemImportingMaskDesc": "Kan masker niet exporteren",
"problemCopyingCanvas": "Fout bij kopiëren canvas", "problemCopyingCanvas": "Fout bij kopiëren canvas",
"problemSavingCanvas": "Fout bij bewaren canvas", "problemSavingCanvas": "Fout bij bewaren canvas",
@ -461,7 +581,18 @@
"maskSentControlnetAssets": "Masker gestuurd naar ControlNet en Assets", "maskSentControlnetAssets": "Masker gestuurd naar ControlNet en Assets",
"canvasSavedGallery": "Canvas bewaard in galerij", "canvasSavedGallery": "Canvas bewaard in galerij",
"imageUploadFailed": "Fout bij uploaden afbeelding", "imageUploadFailed": "Fout bij uploaden afbeelding",
"problemImportingMask": "Fout bij importeren masker" "problemImportingMask": "Fout bij importeren masker",
"workflowDeleted": "Werkstroom verwijderd",
"invalidUpload": "Ongeldige upload",
"uploadInitialImage": "Initiële afbeelding uploaden",
"setAsCanvasInitialImage": "Ingesteld als initiële afbeelding voor canvas",
"problemRetrievingWorkflow": "Fout bij ophalen van werkstroom",
"parameters": "Parameters",
"modelImportCanceled": "Importeren model geannuleerd",
"problemDeletingWorkflow": "Fout bij verwijderen van werkstroom",
"prunedQueue": "Wachtrij gesnoeid",
"problemDownloadingImage": "Fout bij downloaden afbeelding",
"resetInitialImage": "Initiële afbeelding hersteld"
}, },
"tooltip": { "tooltip": {
"feature": { "feature": {
@ -535,7 +666,11 @@
"showOptionsPanel": "Toon zijscherm", "showOptionsPanel": "Toon zijscherm",
"menu": "Menu", "menu": "Menu",
"showGalleryPanel": "Toon deelscherm Galerij", "showGalleryPanel": "Toon deelscherm Galerij",
"loadMore": "Laad meer" "loadMore": "Laad meer",
"about": "Over",
"mode": "Modus",
"resetUI": "$t(accessibility.reset) UI",
"createIssue": "Maak probleem aan"
}, },
"nodes": { "nodes": {
"zoomOutNodes": "Uitzoomen", "zoomOutNodes": "Uitzoomen",
@ -549,7 +684,7 @@
"loadWorkflow": "Laad werkstroom", "loadWorkflow": "Laad werkstroom",
"downloadWorkflow": "Download JSON van werkstroom", "downloadWorkflow": "Download JSON van werkstroom",
"scheduler": "Planner", "scheduler": "Planner",
"missingTemplate": "Ontbrekende sjabloon", "missingTemplate": "Ongeldig knooppunt: knooppunt {{node}} van het soort {{type}} heeft een ontbrekend sjabloon (niet geïnstalleerd?)",
"workflowDescription": "Korte beschrijving", "workflowDescription": "Korte beschrijving",
"versionUnknown": " Versie onbekend", "versionUnknown": " Versie onbekend",
"noNodeSelected": "Geen knooppunt gekozen", "noNodeSelected": "Geen knooppunt gekozen",
@ -565,7 +700,7 @@
"integer": "Geheel getal", "integer": "Geheel getal",
"nodeTemplate": "Sjabloon knooppunt", "nodeTemplate": "Sjabloon knooppunt",
"nodeOpacity": "Dekking knooppunt", "nodeOpacity": "Dekking knooppunt",
"unableToLoadWorkflow": "Kan werkstroom niet valideren", "unableToLoadWorkflow": "Fout bij laden werkstroom",
"snapToGrid": "Lijn uit op raster", "snapToGrid": "Lijn uit op raster",
"noFieldsLinearview": "Geen velden toegevoegd aan lineaire weergave", "noFieldsLinearview": "Geen velden toegevoegd aan lineaire weergave",
"nodeSearch": "Zoek naar knooppunten", "nodeSearch": "Zoek naar knooppunten",
@ -616,11 +751,56 @@
"unknownField": "Onbekend veld", "unknownField": "Onbekend veld",
"colorCodeEdges": "Kleurgecodeerde randen", "colorCodeEdges": "Kleurgecodeerde randen",
"unknownNode": "Onbekend knooppunt", "unknownNode": "Onbekend knooppunt",
"mismatchedVersion": "Heeft niet-overeenkomende versie", "mismatchedVersion": "Ongeldig knooppunt: knooppunt {{node}} van het soort {{type}} heeft een niet-overeenkomende versie (probeer het bij te werken?)",
"addNodeToolTip": "Voeg knooppunt toe (Shift+A, spatie)", "addNodeToolTip": "Voeg knooppunt toe (Shift+A, spatie)",
"loadingNodes": "Bezig met laden van knooppunten...", "loadingNodes": "Bezig met laden van knooppunten...",
"snapToGridHelp": "Lijn knooppunten uit op raster bij verplaatsing", "snapToGridHelp": "Lijn knooppunten uit op raster bij verplaatsing",
"workflowSettings": "Instellingen werkstroomeditor" "workflowSettings": "Instellingen werkstroomeditor",
"addLinearView": "Voeg toe aan lineaire weergave",
"nodePack": "Knooppuntpakket",
"unknownInput": "Onbekende invoer: {{name}}",
"sourceNodeFieldDoesNotExist": "Ongeldige rand: bron-/uitvoerveld {{node}}.{{field}} bestaat niet",
"collectionFieldType": "Verzameling {{name}}",
"deletedInvalidEdge": "Ongeldige hoek {{source}} -> {{target}} verwijderd",
"graph": "Grafiek",
"targetNodeDoesNotExist": "Ongeldige rand: doel-/invoerknooppunt {{node}} bestaat niet",
"resetToDefaultValue": "Herstel naar standaardwaarden",
"editMode": "Bewerk in Werkstroom-editor",
"showEdgeLabels": "Toon randlabels",
"showEdgeLabelsHelp": "Toon labels aan randen, waarmee de verbonden knooppunten mee worden aangegeven",
"clearWorkflowDesc2": "Je huidige werkstroom heeft niet-bewaarde wijzigingen.",
"unableToParseFieldType": "fout bij bepalen soort veld",
"sourceNodeDoesNotExist": "Ongeldige rand: bron-/uitvoerknooppunt {{node}} bestaat niet",
"unsupportedArrayItemType": "niet-ondersteunde soort van het array-onderdeel \"{{type}}\"",
"targetNodeFieldDoesNotExist": "Ongeldige rand: doel-/invoerveld {{node}}.{{field}} bestaat niet",
"reorderLinearView": "Herorden lineaire weergave",
"newWorkflowDesc": "Een nieuwe werkstroom aanmaken?",
"collectionOrScalarFieldType": "Verzameling|scalair {{name}}",
"newWorkflow": "Nieuwe werkstroom",
"unknownErrorValidatingWorkflow": "Onbekende fout bij valideren werkstroom",
"unsupportedAnyOfLength": "te veel union-leden ({{count}})",
"unknownOutput": "Onbekende uitvoer: {{name}}",
"viewMode": "Gebruik in lineaire weergave",
"unableToExtractSchemaNameFromRef": "fout bij het extraheren van de schemanaam via de ref",
"unsupportedMismatchedUnion": "niet-overeenkomende soort CollectionOrScalar met basissoorten {{firstType}} en {{secondType}}",
"unknownNodeType": "Onbekend soort knooppunt",
"edit": "Bewerk",
"updateAllNodes": "Werk knooppunten bij",
"allNodesUpdated": "Alle knooppunten bijgewerkt",
"nodeVersion": "Knooppuntversie",
"newWorkflowDesc2": "Je huidige werkstroom heeft niet-bewaarde wijzigingen.",
"clearWorkflow": "Maak werkstroom leeg",
"clearWorkflowDesc": "Deze werkstroom leegmaken en met een nieuwe beginnen?",
"inputFieldTypeParseError": "Fout bij bepalen van het soort invoerveld {{node}}.{{field}} ({{message}})",
"outputFieldTypeParseError": "Fout bij het bepalen van het soort uitvoerveld {{node}}.{{field}} ({{message}})",
"unableToExtractEnumOptions": "fout bij extraheren enumeratie-opties",
"unknownFieldType": "Soort $t(nodes.unknownField): {{type}}",
"unableToGetWorkflowVersion": "Fout bij ophalen schemaversie van werkstroom",
"betaDesc": "Deze uitvoering is in bèta. Totdat deze stabiel is kunnen er wijzigingen voorkomen gedurende app-updates die zaken kapotmaken. We zijn van plan om deze uitvoering op lange termijn te gaan ondersteunen.",
"prototypeDesc": "Deze uitvoering is een prototype. Er kunnen wijzigingen voorkomen gedurende app-updates die zaken kapotmaken. Deze kunnen op een willekeurig moment verwijderd worden.",
"noFieldsViewMode": "Deze werkstroom heeft geen geselecteerde velden om te tonen. Bekijk de volledige werkstroom om de waarden te configureren.",
"unableToUpdateNodes_one": "Fout bij bijwerken van {{count}} knooppunt",
"unableToUpdateNodes_other": "Fout bij bijwerken van {{count}} knooppunten"
}, },
"controlnet": { "controlnet": {
"amult": "a_mult", "amult": "a_mult",
@ -693,9 +873,28 @@
"canny": "Canny", "canny": "Canny",
"depthZoeDescription": "Genereer diepteblad via Zoe", "depthZoeDescription": "Genereer diepteblad via Zoe",
"hedDescription": "Herkenning van holistisch-geneste randen", "hedDescription": "Herkenning van holistisch-geneste randen",
"setControlImageDimensions": "Stel afmetingen controle-afbeelding in op B/H", "setControlImageDimensions": "Kopieer grootte naar B/H (optimaliseer voor model)",
"scribble": "Krabbel", "scribble": "Krabbel",
"maxFaces": "Max. gezichten" "maxFaces": "Max. gezichten",
"dwOpenpose": "DW Openpose",
"depthAnything": "Depth Anything",
"base": "Basis",
"hands": "Handen",
"selectCLIPVisionModel": "Selecteer een CLIP Vision-model",
"modelSize": "Modelgrootte",
"small": "Klein",
"large": "Groot",
"resizeSimple": "Wijzig grootte (eenvoudig)",
"beginEndStepPercentShort": "Begin-/eind-%",
"depthAnythingDescription": "Genereren dieptekaart d.m.v. de techniek Depth Anything",
"face": "Gezicht",
"body": "Lichaam",
"dwOpenposeDescription": "Schatting menselijke pose d.m.v. DW Openpose",
"ipAdapterMethod": "Methode",
"full": "Volledig",
"style": "Alleen stijl",
"composition": "Alleen samenstelling",
"setControlImageDimensionsForce": "Kopieer grootte naar B/H (negeer model)"
}, },
"dynamicPrompts": { "dynamicPrompts": {
"seedBehaviour": { "seedBehaviour": {
@ -708,7 +907,10 @@
"maxPrompts": "Max. prompts", "maxPrompts": "Max. prompts",
"promptsWithCount_one": "{{count}} prompt", "promptsWithCount_one": "{{count}} prompt",
"promptsWithCount_other": "{{count}} prompts", "promptsWithCount_other": "{{count}} prompts",
"dynamicPrompts": "Dynamische prompts" "dynamicPrompts": "Dynamische prompts",
"showDynamicPrompts": "Toon dynamische prompts",
"loading": "Genereren van dynamische prompts...",
"promptsPreview": "Voorvertoning prompts"
}, },
"popovers": { "popovers": {
"noiseUseCPU": { "noiseUseCPU": {
@ -721,7 +923,7 @@
}, },
"paramScheduler": { "paramScheduler": {
"paragraphs": [ "paragraphs": [
"De planner bepaalt hoe ruis per iteratie wordt toegevoegd aan een afbeelding of hoe een monster wordt bijgewerkt op basis van de uitvoer van een model." "De planner gebruikt gedurende het genereringsproces."
], ],
"heading": "Planner" "heading": "Planner"
}, },
@ -808,8 +1010,8 @@
}, },
"clipSkip": { "clipSkip": {
"paragraphs": [ "paragraphs": [
"Kies hoeveel CLIP-modellagen je wilt overslaan.", "Aantal over te slaan CLIP-modellagen.",
"Bepaalde modellen werken beter met bepaalde Overslaan CLIP-instellingen." "Bepaalde modellen zijn beter geschikt met bepaalde Overslaan CLIP-instellingen."
], ],
"heading": "Overslaan CLIP" "heading": "Overslaan CLIP"
}, },
@ -940,7 +1142,6 @@
"completed": "Voltooid", "completed": "Voltooid",
"queueBack": "Voeg toe aan wachtrij", "queueBack": "Voeg toe aan wachtrij",
"cancelFailed": "Fout bij annuleren onderdeel", "cancelFailed": "Fout bij annuleren onderdeel",
"queueCountPrediction": "Voeg {{predicted}} toe aan wachtrij",
"batchQueued": "Reeks in wachtrij geplaatst", "batchQueued": "Reeks in wachtrij geplaatst",
"pauseFailed": "Fout bij onderbreken verwerker", "pauseFailed": "Fout bij onderbreken verwerker",
"clearFailed": "Fout bij wissen van wachtrij", "clearFailed": "Fout bij wissen van wachtrij",
@ -994,17 +1195,26 @@
"denoisingStrength": "Sterkte ontruising", "denoisingStrength": "Sterkte ontruising",
"refinermodel": "Verfijningsmodel", "refinermodel": "Verfijningsmodel",
"posAestheticScore": "Positieve esthetische score", "posAestheticScore": "Positieve esthetische score",
"concatPromptStyle": "Plak prompt- en stijltekst aan elkaar", "concatPromptStyle": "Koppelen van prompt en stijl",
"loading": "Bezig met laden...", "loading": "Bezig met laden...",
"steps": "Stappen", "steps": "Stappen",
"posStylePrompt": "Positieve-stijlprompt" "posStylePrompt": "Positieve-stijlprompt",
"freePromptStyle": "Handmatige stijlprompt",
"refinerSteps": "Aantal stappen verfijner"
}, },
"models": { "models": {
"noMatchingModels": "Geen overeenkomend modellen", "noMatchingModels": "Geen overeenkomend modellen",
"loading": "bezig met laden", "loading": "bezig met laden",
"noMatchingLoRAs": "Geen overeenkomende LoRA's", "noMatchingLoRAs": "Geen overeenkomende LoRA's",
"noModelsAvailable": "Geen modellen beschikbaar", "noModelsAvailable": "Geen modellen beschikbaar",
"selectModel": "Kies een model" "selectModel": "Kies een model",
"noLoRAsInstalled": "Geen LoRA's geïnstalleerd",
"noRefinerModelsInstalled": "Geen SDXL-verfijningsmodellen geïnstalleerd",
"defaultVAE": "Standaard-VAE",
"lora": "LoRA",
"esrganModel": "ESRGAN-model",
"addLora": "Voeg LoRA toe",
"concepts": "Concepten"
}, },
"boards": { "boards": {
"autoAddBoard": "Voeg automatisch bord toe", "autoAddBoard": "Voeg automatisch bord toe",
@ -1022,7 +1232,13 @@
"downloadBoard": "Download bord", "downloadBoard": "Download bord",
"changeBoard": "Wijzig bord", "changeBoard": "Wijzig bord",
"loading": "Bezig met laden...", "loading": "Bezig met laden...",
"clearSearch": "Maak zoekopdracht leeg" "clearSearch": "Maak zoekopdracht leeg",
"deleteBoard": "Verwijder bord",
"deleteBoardAndImages": "Verwijder bord en afbeeldingen",
"deleteBoardOnly": "Verwijder alleen bord",
"deletedBoardsCannotbeRestored": "Verwijderde borden kunnen niet worden hersteld",
"movingImagesToBoard_one": "Verplaatsen van {{count}} afbeelding naar bord:",
"movingImagesToBoard_other": "Verplaatsen van {{count}} afbeeldingen naar bord:"
}, },
"invocationCache": { "invocationCache": {
"disable": "Schakel uit", "disable": "Schakel uit",
@ -1039,5 +1255,39 @@
"clear": "Wis", "clear": "Wis",
"maxCacheSize": "Max. grootte cache", "maxCacheSize": "Max. grootte cache",
"cacheSize": "Grootte cache" "cacheSize": "Grootte cache"
},
"accordions": {
"generation": {
"title": "Genereren"
},
"image": {
"title": "Afbeelding"
},
"advanced": {
"title": "Geavanceerd",
"options": "$t(accordions.advanced.title) Opties"
},
"control": {
"title": "Besturing"
},
"compositing": {
"title": "Samenstellen",
"coherenceTab": "Coherentiefase",
"infillTab": "Invullen"
}
},
"hrf": {
"upscaleMethod": "Opschaalmethode",
"metadata": {
"strength": "Sterkte oplossing voor hoge resolutie",
"method": "Methode oplossing voor hoge resolutie",
"enabled": "Oplossing voor hoge resolutie ingeschakeld"
},
"hrf": "Oplossing voor hoge resolutie",
"enableHrf": "Schakel oplossing in voor hoge resolutie"
},
"prompt": {
"addPromptTrigger": "Voeg prompttrigger toe",
"compatibleEmbeddings": "Compatibele embeddings"
} }
} }

View File

@ -76,7 +76,18 @@
"localSystem": "Локальная система", "localSystem": "Локальная система",
"aboutDesc": "Используя Invoke для работы? Проверьте это:", "aboutDesc": "Используя Invoke для работы? Проверьте это:",
"add": "Добавить", "add": "Добавить",
"loglevel": "Уровень логов" "loglevel": "Уровень логов",
"beta": "Бета",
"selected": "Выбрано",
"positivePrompt": "Позитивный запрос",
"negativePrompt": "Негативный запрос",
"editor": "Редактор",
"goTo": "Перейти к",
"tab": "Вкладка",
"viewing": "Просмотр",
"editing": "Редактирование",
"viewingDesc": "Просмотр изображений в режиме большой галереи",
"editingDesc": "Редактировать на холсте слоёв управления"
}, },
"gallery": { "gallery": {
"galleryImageSize": "Размер изображений", "galleryImageSize": "Размер изображений",
@ -87,8 +98,8 @@
"deleteImagePermanent": "Удаленные изображения невозможно восстановить.", "deleteImagePermanent": "Удаленные изображения невозможно восстановить.",
"deleteImageBin": "Удаленные изображения будут отправлены в корзину вашей операционной системы.", "deleteImageBin": "Удаленные изображения будут отправлены в корзину вашей операционной системы.",
"deleteImage_one": "Удалить изображение", "deleteImage_one": "Удалить изображение",
"deleteImage_few": "", "deleteImage_few": "Удалить {{count}} изображения",
"deleteImage_many": "", "deleteImage_many": "Удалить {{count}} изображений",
"assets": "Ресурсы", "assets": "Ресурсы",
"autoAssignBoardOnClick": "Авто-назначение доски по клику", "autoAssignBoardOnClick": "Авто-назначение доски по клику",
"deleteSelection": "Удалить выделенное", "deleteSelection": "Удалить выделенное",
@ -336,6 +347,10 @@
"remixImage": { "remixImage": {
"desc": "Используйте все параметры, кроме сида из текущего изображения", "desc": "Используйте все параметры, кроме сида из текущего изображения",
"title": "Ремикс изображения" "title": "Ремикс изображения"
},
"toggleViewer": {
"title": "Переключить просмотр изображений",
"desc": "Переключение между средством просмотра изображений и рабочей областью для текущей вкладки."
} }
}, },
"modelManager": { "modelManager": {
@ -512,7 +527,8 @@
"missingNodeTemplate": "Отсутствует шаблон узла", "missingNodeTemplate": "Отсутствует шаблон узла",
"missingFieldTemplate": "Отсутствует шаблон поля", "missingFieldTemplate": "Отсутствует шаблон поля",
"addingImagesTo": "Добавление изображений в", "addingImagesTo": "Добавление изображений в",
"invoke": "Создать" "invoke": "Создать",
"imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается"
}, },
"isAllowedToUpscale": { "isAllowedToUpscale": {
"useX2Model": "Изображение слишком велико для увеличения с помощью модели x4. Используйте модель x2", "useX2Model": "Изображение слишком велико для увеличения с помощью модели x4. Используйте модель x2",
@ -523,9 +539,6 @@
"useCpuNoise": "Использовать шум CPU", "useCpuNoise": "Использовать шум CPU",
"imageActions": "Действия с изображениями", "imageActions": "Действия с изображениями",
"iterations": "Кол-во", "iterations": "Кол-во",
"iterationsWithCount_one": "{{count}} Интеграция",
"iterationsWithCount_few": "{{count}} Итерации",
"iterationsWithCount_many": "{{count}} Итераций",
"useSize": "Использовать размер", "useSize": "Использовать размер",
"coherenceMode": "Режим", "coherenceMode": "Режим",
"aspect": "Соотношение", "aspect": "Соотношение",
@ -541,7 +554,10 @@
"infillMosaicTileHeight": "Высота плиток", "infillMosaicTileHeight": "Высота плиток",
"infillMosaicMinColor": "Мин цвет", "infillMosaicMinColor": "Мин цвет",
"infillMosaicMaxColor": "Макс цвет", "infillMosaicMaxColor": "Макс цвет",
"infillColorValue": "Цвет заливки" "infillColorValue": "Цвет заливки",
"globalSettings": "Глобальные настройки",
"globalNegativePromptPlaceholder": "Глобальный негативный запрос",
"globalPositivePromptPlaceholder": "Глобальный запрос"
}, },
"settings": { "settings": {
"models": "Модели", "models": "Модели",
@ -706,7 +722,9 @@
"coherenceModeBoxBlur": "коробчатое размытие", "coherenceModeBoxBlur": "коробчатое размытие",
"discardCurrent": "Отбросить текущее", "discardCurrent": "Отбросить текущее",
"invertBrushSizeScrollDirection": "Инвертировать прокрутку для размера кисти", "invertBrushSizeScrollDirection": "Инвертировать прокрутку для размера кисти",
"initialFitImageSize": "Подогнать размер изображения при перебросе" "initialFitImageSize": "Подогнать размер изображения при перебросе",
"hideBoundingBox": "Скрыть ограничительную рамку",
"showBoundingBox": "Показать ограничительную рамку"
}, },
"accessibility": { "accessibility": {
"uploadImage": "Загрузить изображение", "uploadImage": "Загрузить изображение",
@ -849,7 +867,10 @@
"editMode": "Открыть в редакторе узлов", "editMode": "Открыть в редакторе узлов",
"resetToDefaultValue": "Сбросить к стандартному значкнию", "resetToDefaultValue": "Сбросить к стандартному значкнию",
"edit": "Редактировать", "edit": "Редактировать",
"noFieldsViewMode": "В этом рабочем процессе нет выбранных полей для отображения. Просмотрите полный рабочий процесс для настройки значений." "noFieldsViewMode": "В этом рабочем процессе нет выбранных полей для отображения. Просмотрите полный рабочий процесс для настройки значений.",
"graph": "График",
"showEdgeLabels": "Показать метки на ребрах",
"showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы"
}, },
"controlnet": { "controlnet": {
"amult": "a_mult", "amult": "a_mult",
@ -917,8 +938,8 @@
"lineartAnime": "Контурный рисунок в стиле аниме", "lineartAnime": "Контурный рисунок в стиле аниме",
"mediapipeFaceDescription": "Обнаружение лиц с помощью Mediapipe", "mediapipeFaceDescription": "Обнаружение лиц с помощью Mediapipe",
"hedDescription": "Целостное обнаружение границ", "hedDescription": "Целостное обнаружение границ",
"setControlImageDimensions": "Установите размеры контрольного изображения на Ш/В", "setControlImageDimensions": "Скопируйте размер в Ш/В (оптимизируйте для модели)",
"scribble": "каракули", "scribble": "Штрихи",
"maxFaces": "Макс Лица", "maxFaces": "Макс Лица",
"mlsdDescription": "Минималистичный детектор отрезков линии", "mlsdDescription": "Минималистичный детектор отрезков линии",
"resizeSimple": "Изменить размер (простой)", "resizeSimple": "Изменить размер (простой)",
@ -933,7 +954,18 @@
"small": "Маленький", "small": "Маленький",
"body": "Тело", "body": "Тело",
"hands": "Руки", "hands": "Руки",
"selectCLIPVisionModel": "Выбрать модель CLIP Vision" "selectCLIPVisionModel": "Выбрать модель CLIP Vision",
"ipAdapterMethod": "Метод",
"full": "Всё",
"mlsd": "M-LSD",
"h": "H",
"style": "Только стиль",
"dwOpenpose": "DW Openpose",
"pidi": "PIDI",
"composition": "Только композиция",
"hed": "HED",
"beginEndStepPercentShort": "Начало/конец %",
"setControlImageDimensionsForce": "Скопируйте размер в Ш/В (игнорируйте модель)"
}, },
"boards": { "boards": {
"autoAddBoard": "Авто добавление Доски", "autoAddBoard": "Авто добавление Доски",
@ -1312,6 +1344,12 @@
"paragraphs": [ "paragraphs": [
"Плавно укладывайте изображение вдоль вертикальной оси." "Плавно укладывайте изображение вдоль вертикальной оси."
] ]
},
"ipAdapterMethod": {
"heading": "Метод",
"paragraphs": [
"Метод, с помощью которого применяется текущий IP-адаптер."
]
} }
}, },
"metadata": { "metadata": {
@ -1359,7 +1397,6 @@
"completed": "Выполнено", "completed": "Выполнено",
"queueBack": "Добавить в очередь", "queueBack": "Добавить в очередь",
"cancelFailed": "Проблема с отменой элемента", "cancelFailed": "Проблема с отменой элемента",
"queueCountPrediction": "{{promptsCount}} запросов × {{iterations}} изображений -> {{count}} генераций",
"batchQueued": "Пакетная очередь", "batchQueued": "Пакетная очередь",
"pauseFailed": "Проблема с приостановкой рендеринга", "pauseFailed": "Проблема с приостановкой рендеринга",
"clearFailed": "Проблема с очисткой очереди", "clearFailed": "Проблема с очисткой очереди",
@ -1475,7 +1512,11 @@
"projectWorkflows": "Рабочие процессы проекта", "projectWorkflows": "Рабочие процессы проекта",
"defaultWorkflows": "Стандартные рабочие процессы", "defaultWorkflows": "Стандартные рабочие процессы",
"name": "Имя", "name": "Имя",
"noRecentWorkflows": "Нет последних рабочих процессов" "noRecentWorkflows": "Нет последних рабочих процессов",
"loadWorkflow": "Рабочий процесс $t(common.load)",
"convertGraph": "Конвертировать график",
"loadFromGraph": "Загрузка рабочего процесса из графика",
"autoLayout": "Автоматическое расположение"
}, },
"hrf": { "hrf": {
"enableHrf": "Включить исправление высокого разрешения", "enableHrf": "Включить исправление высокого разрешения",
@ -1528,5 +1569,55 @@
"addPromptTrigger": "Добавить триггер запроса", "addPromptTrigger": "Добавить триггер запроса",
"compatibleEmbeddings": "Совместимые встраивания", "compatibleEmbeddings": "Совместимые встраивания",
"noMatchingTriggers": "Нет соответствующих триггеров" "noMatchingTriggers": "Нет соответствующих триггеров"
},
"controlLayers": {
"moveToBack": "На задний план",
"moveForward": "Переместить вперёд",
"moveBackward": "Переместить назад",
"brushSize": "Размер кисти",
"controlLayers": "Слои управления",
"globalMaskOpacity": "Глобальная непрозрачность маски",
"autoNegative": "Авто негатив",
"deletePrompt": "Удалить запрос",
"resetRegion": "Сбросить регион",
"debugLayers": "Слои отладки",
"rectangle": "Прямоугольник",
"maskPreviewColor": "Цвет предпросмотра маски",
"addNegativePrompt": "Добавить $t(common.negativePrompt)",
"regionalGuidance": "Региональная точность",
"opacity": "Непрозрачность",
"globalControlAdapter": "Глобальный $t(controlnet.controlAdapter_one)",
"globalControlAdapterLayer": "Глобальный $t(controlnet.controlAdapter_one) $t(unifiedCanvas.layer)",
"globalIPAdapter": "Глобальный $t(common.ipAdapter)",
"globalIPAdapterLayer": "Глобальный $t(common.ipAdapter) $t(unifiedCanvas.layer)",
"opacityFilter": "Фильтр непрозрачности",
"deleteAll": "Удалить всё",
"addLayer": "Добавить слой",
"moveToFront": "На передний план",
"addPositivePrompt": "Добавить $t(common.positivePrompt)",
"addIPAdapter": "Добавить $t(common.ipAdapter)",
"regionalGuidanceLayer": "$t(controlLayers.regionalGuidance) $t(unifiedCanvas.layer)",
"resetProcessor": "Сброс процессора по умолчанию",
"clearProcessor": "Чистый процессор",
"globalInitialImage": "Глобальное исходное изображение",
"globalInitialImageLayer": "$t(controlLayers.globalInitialImage) $t(unifiedCanvas.layer)",
"noLayersAdded": "Без слоев",
"layers_one": "Слой",
"layers_few": "Слоя",
"layers_many": "Слоев"
},
"ui": {
"tabs": {
"generation": "Генерация",
"canvas": "Холст",
"workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)",
"models": "Модели",
"generationTab": "$t(ui.tabs.generation) $t(common.tab)",
"workflows": "Рабочие процессы",
"canvasTab": "$t(ui.tabs.canvas) $t(common.tab)",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
"queue": "Очередь"
}
} }
} }

View File

@ -66,7 +66,7 @@
"saveAs": "保存为", "saveAs": "保存为",
"ai": "ai", "ai": "ai",
"or": "或", "or": "或",
"aboutDesc": "使用 Invoke 工作?看:", "aboutDesc": "使用 Invoke 工作?来看看:",
"add": "添加", "add": "添加",
"loglevel": "日志级别", "loglevel": "日志级别",
"copy": "复制", "copy": "复制",
@ -445,7 +445,6 @@
"useX2Model": "图像太大,无法使用 x4 模型,使用 x2 模型作为替代", "useX2Model": "图像太大,无法使用 x4 模型,使用 x2 模型作为替代",
"tooLarge": "图像太大无法进行放大,请选择更小的图像" "tooLarge": "图像太大无法进行放大,请选择更小的图像"
}, },
"iterationsWithCount_other": "{{count}} 次迭代生成",
"cfgRescaleMultiplier": "CFG 重缩放倍数", "cfgRescaleMultiplier": "CFG 重缩放倍数",
"useSize": "使用尺寸", "useSize": "使用尺寸",
"setToOptimalSize": "优化模型大小", "setToOptimalSize": "优化模型大小",
@ -853,7 +852,6 @@
"pruneSucceeded": "从队列修剪 {{item_count}} 个已完成的项目", "pruneSucceeded": "从队列修剪 {{item_count}} 个已完成的项目",
"notReady": "无法排队", "notReady": "无法排队",
"batchFailedToQueue": "批次加入队列失败", "batchFailedToQueue": "批次加入队列失败",
"queueCountPrediction": "{{promptsCount}} 提示词 × {{iterations}} 迭代次数 -> {{count}} 次生成",
"batchQueued": "加入队列的批次", "batchQueued": "加入队列的批次",
"front": "前", "front": "前",
"pruneTooltip": "修剪 {{item_count}} 个已完成的项目", "pruneTooltip": "修剪 {{item_count}} 个已完成的项目",

View File

@ -1,3 +1,4 @@
/* eslint-disable no-console */
import fs from 'node:fs'; import fs from 'node:fs';
import openapiTS from 'openapi-typescript'; import openapiTS from 'openapi-typescript';

View File

@ -12,7 +12,6 @@ import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal'; import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal'; import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal'; import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import { FloatingImageViewer } from 'features/gallery/components/ImageViewer/FloatingImageViewer';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast'; import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { configChanged } from 'features/system/store/configSlice'; import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors'; import { languageSelector } from 'features/system/store/systemSelectors';
@ -22,10 +21,10 @@ import i18n from 'i18n';
import { size } from 'lodash-es'; import { size } from 'lodash-es';
import { memo, useCallback, useEffect } from 'react'; import { memo, useCallback, useEffect } from 'react';
import { ErrorBoundary } from 'react-error-boundary'; import { ErrorBoundary } from 'react-error-boundary';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback'; import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import PreselectedImage from './PreselectedImage'; import PreselectedImage from './PreselectedImage';
import Toaster from './Toaster';
const DEFAULT_CONFIG = {}; const DEFAULT_CONFIG = {};
@ -47,6 +46,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
useSocketIO(); useSocketIO();
useGlobalModifiersInit(); useGlobalModifiersInit();
useGlobalHotkeys(); useGlobalHotkeys();
useGetOpenAPISchemaQuery();
const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone(); const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone();
@ -95,9 +95,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => {
<DeleteImageModal /> <DeleteImageModal />
<ChangeBoardModal /> <ChangeBoardModal />
<DynamicPromptsModal /> <DynamicPromptsModal />
<Toaster />
<PreselectedImage selectedImage={selectedImage} /> <PreselectedImage selectedImage={selectedImage} />
<FloatingImageViewer />
</ErrorBoundary> </ErrorBoundary>
); );
}; };

View File

@ -1,5 +1,8 @@
import { Button, Flex, Heading, Link, Text, useToast } from '@invoke-ai/ui-library'; import { Button, Flex, Heading, Image, Link, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { toast } from 'features/toast/toast';
import newGithubIssueUrl from 'new-github-issue-url'; import newGithubIssueUrl from 'new-github-issue-url';
import InvokeLogoYellow from 'public/assets/images/invoke-symbol-ylw-lrg.svg';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiArrowSquareOutBold, PiCopyBold } from 'react-icons/pi'; import { PiArrowCounterClockwiseBold, PiArrowSquareOutBold, PiCopyBold } from 'react-icons/pi';
@ -11,31 +14,39 @@ type Props = {
}; };
const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => { const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
const toast = useToast();
const { t } = useTranslation(); const { t } = useTranslation();
const isLocal = useAppSelector((s) => s.config.isLocal);
const handleCopy = useCallback(() => { const handleCopy = useCallback(() => {
const text = JSON.stringify(serializeError(error), null, 2); const text = JSON.stringify(serializeError(error), null, 2);
navigator.clipboard.writeText(`\`\`\`\n${text}\n\`\`\``); navigator.clipboard.writeText(`\`\`\`\n${text}\n\`\`\``);
toast({ toast({
title: 'Error Copied', id: 'ERROR_COPIED',
title: t('toast.errorCopied'),
}); });
}, [error, toast]); }, [error, t]);
const url = useMemo( const url = useMemo(() => {
() => if (isLocal) {
newGithubIssueUrl({ return newGithubIssueUrl({
user: 'invoke-ai', user: 'invoke-ai',
repo: 'InvokeAI', repo: 'InvokeAI',
template: 'BUG_REPORT.yml', template: 'BUG_REPORT.yml',
title: `[bug]: ${error.name}: ${error.message}`, title: `[bug]: ${error.name}: ${error.message}`,
}), });
[error.message, error.name] } else {
); return 'https://support.invoke.ai/support/tickets/new';
}
}, [error.message, error.name, isLocal]);
return ( return (
<Flex layerStyle="body" w="100vw" h="100vh" alignItems="center" justifyContent="center" p={4}> <Flex layerStyle="body" w="100vw" h="100vh" alignItems="center" justifyContent="center" p={4}>
<Flex layerStyle="first" flexDir="column" borderRadius="base" justifyContent="center" gap={8} p={16}> <Flex layerStyle="first" flexDir="column" borderRadius="base" justifyContent="center" gap={8} p={16}>
<Heading>{t('common.somethingWentWrong')}</Heading> <Flex alignItems="center" gap="2">
<Image src={InvokeLogoYellow} alt="invoke-logo" w="24px" h="24px" minW="24px" minH="24px" userSelect="none" />
<Heading fontSize="2xl">{t('common.somethingWentWrong')}</Heading>
</Flex>
<Flex <Flex
layerStyle="second" layerStyle="second"
px={8} px={8}
@ -57,7 +68,9 @@ const AppErrorBoundaryFallback = ({ error, resetErrorBoundary }: Props) => {
{t('common.copyError')} {t('common.copyError')}
</Button> </Button>
<Link href={url} isExternal> <Link href={url} isExternal>
<Button leftIcon={<PiArrowSquareOutBold />}>{t('accessibility.createIssue')}</Button> <Button leftIcon={<PiArrowSquareOutBold />}>
{isLocal ? t('accessibility.createIssue') : t('accessibility.submitSupportTicket')}
</Button>
</Link> </Link>
</Flex> </Flex>
</Flex> </Flex>

View File

@ -19,6 +19,13 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
return extendTheme({ return extendTheme({
..._theme, ..._theme,
direction, direction,
shadows: {
..._theme.shadows,
selectedForCompare:
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-400)',
hoverSelectedForCompare:
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-300)',
},
}); });
}, [direction]); }, [direction]);

View File

@ -1,44 +0,0 @@
import { useToast } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
import type { MakeToastArg } from 'features/system/util/makeToast';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback, useEffect } from 'react';
/**
* Logical component. Watches the toast queue and makes toasts when the queue is not empty.
* @returns null
*/
const Toaster = () => {
const dispatch = useAppDispatch();
const toastQueue = useAppSelector((s) => s.system.toastQueue);
const toast = useToast();
useEffect(() => {
toastQueue.forEach((t) => {
toast(t);
});
toastQueue.length > 0 && dispatch(clearToastQueue());
}, [dispatch, toast, toastQueue]);
return null;
};
/**
* Returns a function that can be used to make a toast.
* @example
* const toaster = useAppToaster();
* toaster('Hello world!');
* toaster({ title: 'Hello world!', status: 'success' });
* @returns A function that can be used to make a toast.
* @see makeToast
* @see MakeToastArg
* @see UseToastOptions
*/
export const useAppToaster = () => {
const dispatch = useAppDispatch();
const toaster = useCallback((arg: MakeToastArg) => dispatch(addToast(makeToast(arg))), [dispatch]);
return toaster;
};
export default memo(Toaster);

View File

@ -6,8 +6,8 @@ import { useAppDispatch } from 'app/store/storeHooks';
import type { MapStore } from 'nanostores'; import type { MapStore } from 'nanostores';
import { atom, map } from 'nanostores'; import { atom, map } from 'nanostores';
import { useEffect, useMemo } from 'react'; import { useEffect, useMemo } from 'react';
import { setEventListeners } from 'services/events/setEventListeners';
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
import { setEventListeners } from 'services/events/util/setEventListeners';
import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client'; import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client';
import { io } from 'socket.io-client'; import { io } from 'socket.io-client';
@ -67,6 +67,8 @@ export const useSocketIO = () => {
if ($isDebugging.get() || import.meta.env.MODE === 'development') { if ($isDebugging.get() || import.meta.env.MODE === 'development') {
window.$socketOptions = $socketOptions; window.$socketOptions = $socketOptions;
// This is only enabled manually for debugging, console is allowed.
/* eslint-disable-next-line no-console */
console.log('Socket initialized', socket); console.log('Socket initialized', socket);
} }
@ -75,6 +77,8 @@ export const useSocketIO = () => {
return () => { return () => {
if ($isDebugging.get() || import.meta.env.MODE === 'development') { if ($isDebugging.get() || import.meta.env.MODE === 'development') {
window.$socketOptions = undefined; window.$socketOptions = undefined;
// This is only enabled manually for debugging, console is allowed.
/* eslint-disable-next-line no-console */
console.log('Socket teardown', socket); console.log('Socket teardown', socket);
} }
socket.disconnect(); socket.disconnect();

View File

@ -1,3 +1,6 @@
/* eslint-disable no-console */
// This is only enabled manually for debugging, console is allowed.
import type { Middleware, MiddlewareAPI } from '@reduxjs/toolkit'; import type { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
import { diff } from 'jsondiffpatch'; import { diff } from 'jsondiffpatch';

View File

@ -1,7 +1,6 @@
import type { UnknownAction } from '@reduxjs/toolkit'; import type { UnknownAction } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions'; import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo'; import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types'; import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions'; import { socketGeneratorProgress } from 'services/events/actions';
@ -25,13 +24,6 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
}; };
} }
if (nodeTemplatesBuilt.match(action)) {
return {
...action,
payload: '<Node templates omitted>',
};
}
if (socketGeneratorProgress.match(action)) { if (socketGeneratorProgress.match(action)) {
const sanitized = deepClone(action); const sanitized = deepClone(action);
if (sanitized.payload.data.progress_image) { if (sanitized.payload.data.progress_image) {

View File

@ -35,28 +35,22 @@ import { addImageUploadedFulfilledListener } from 'app/store/middleware/listener
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected'; import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded'; import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged'; import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected'; import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected'; import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress'; import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
import { addGraphExecutionStateCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete';
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete'; import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError'; import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
import { addInvocationRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError';
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted'; import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall'; import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad'; import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged'; import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
import { addSessionRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError';
import { addSocketSubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed';
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved'; import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested'; import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested'; import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested'; import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store'; import type { AppDispatch, RootState } from 'app/store/store';
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>; export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
@ -104,18 +98,13 @@ addCommitStagingAreaImageListener(startAppListening);
// Socket.IO // Socket.IO
addGeneratorProgressEventListener(startAppListening); addGeneratorProgressEventListener(startAppListening);
addGraphExecutionStateCompleteEventListener(startAppListening);
addInvocationCompleteEventListener(startAppListening); addInvocationCompleteEventListener(startAppListening);
addInvocationErrorEventListener(startAppListening); addInvocationErrorEventListener(startAppListening);
addInvocationStartedEventListener(startAppListening); addInvocationStartedEventListener(startAppListening);
addSocketConnectedEventListener(startAppListening); addSocketConnectedEventListener(startAppListening);
addSocketDisconnectedEventListener(startAppListening); addSocketDisconnectedEventListener(startAppListening);
addSocketSubscribedEventListener(startAppListening);
addSocketUnsubscribedEventListener(startAppListening);
addModelLoadEventListener(startAppListening); addModelLoadEventListener(startAppListening);
addModelInstallEventListener(startAppListening); addModelInstallEventListener(startAppListening);
addSessionRetrievalErrorEventListener(startAppListening);
addInvocationRetrievalErrorEventListener(startAppListening);
addSocketQueueItemStatusChangedEventListener(startAppListening); addSocketQueueItemStatusChangedEventListener(startAppListening);
addBulkDownloadListeners(startAppListening); addBulkDownloadListeners(startAppListening);

View File

@ -8,7 +8,7 @@ import {
resetCanvas, resetCanvas,
setInitialCanvasImage, setInitialCanvasImage,
} from 'features/canvas/store/canvasSlice'; } from 'features/canvas/store/canvasSlice';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
@ -30,22 +30,20 @@ export const addCommitStagingAreaImageListener = (startAppListening: AppStartLis
req.reset(); req.reset();
if (canceled > 0) { if (canceled > 0) {
log.debug(`Canceled ${canceled} canvas batches`); log.debug(`Canceled ${canceled} canvas batches`);
dispatch( toast({
addToast({ id: 'CANCEL_BATCH_SUCCEEDED',
title: t('queue.cancelBatchSucceeded'), title: t('queue.cancelBatchSucceeded'),
status: 'success', status: 'success',
}) });
);
} }
dispatch(canvasBatchIdsReset()); dispatch(canvasBatchIdsReset());
} catch { } catch {
log.error('Failed to cancel canvas batches'); log.error('Failed to cancel canvas batches');
dispatch( toast({
addToast({ id: 'CANCEL_BATCH_FAILED',
title: t('queue.cancelBatchFailed'), title: t('queue.cancelBatchFailed'),
status: 'error', status: 'error',
}) });
);
} }
}, },
}); });

View File

@ -1,8 +1,8 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { toast } from 'common/util/toast';
import { zPydanticValidationError } from 'features/system/store/zodSchemas'; import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { truncate, upperFirst } from 'lodash-es'; import { truncate, upperFirst } from 'lodash-es';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
@ -16,18 +16,15 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
const arg = action.meta.arg.originalArgs; const arg = action.meta.arg.originalArgs;
logger('queue').debug({ enqueueResult: parseify(response) }, 'Batch enqueued'); logger('queue').debug({ enqueueResult: parseify(response) }, 'Batch enqueued');
if (!toast.isActive('batch-queued')) {
toast({ toast({
id: 'batch-queued', id: 'QUEUE_BATCH_SUCCEEDED',
title: t('queue.batchQueued'), title: t('queue.batchQueued'),
status: 'success',
description: t('queue.batchQueuedDesc', { description: t('queue.batchQueuedDesc', {
count: response.enqueued, count: response.enqueued,
direction: arg.prepend ? t('queue.front') : t('queue.back'), direction: arg.prepend ? t('queue.front') : t('queue.back'),
}), }),
duration: 1000,
status: 'success',
}); });
}
}, },
}); });
@ -40,9 +37,10 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
if (!response) { if (!response) {
toast({ toast({
id: 'QUEUE_BATCH_FAILED',
title: t('queue.batchFailedToQueue'), title: t('queue.batchFailedToQueue'),
status: 'error', status: 'error',
description: 'Unknown Error', description: t('common.unknownError'),
}); });
logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue')); logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));
return; return;
@ -52,7 +50,7 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
if (result.success) { if (result.success) {
result.data.data.detail.map((e) => { result.data.data.detail.map((e) => {
toast({ toast({
id: 'batch-failed-to-queue', id: 'QUEUE_BATCH_FAILED',
title: truncate(upperFirst(e.msg), { length: 128 }), title: truncate(upperFirst(e.msg), { length: 128 }),
status: 'error', status: 'error',
description: truncate( description: truncate(
@ -64,9 +62,10 @@ export const addBatchEnqueuedListener = (startAppListening: AppStartListening) =
}); });
} else if (response.status !== 403) { } else if (response.status !== 403) {
toast({ toast({
id: 'QUEUE_BATCH_FAILED',
title: t('queue.batchFailedToQueue'), title: t('queue.batchFailedToQueue'),
description: t('common.unknownError'),
status: 'error', status: 'error',
description: t('common.unknownError'),
}); });
} }
logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue')); logger('queue').error({ batchConfig: parseify(arg), error: parseify(response) }, t('queue.batchFailedToQueue'));

View File

@ -21,7 +21,7 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
const { canvas, nodes, controlAdapters, controlLayers } = getState(); const { canvas, nodes, controlAdapters, controlLayers } = getState();
deleted_images.forEach((image_name) => { deleted_images.forEach((image_name) => {
const imageUsage = getImageUsage(canvas, nodes, controlAdapters, controlLayers.present, image_name); const imageUsage = getImageUsage(canvas, nodes.present, controlAdapters, controlLayers.present, image_name);
if (imageUsage.isCanvasImage && !wasCanvasReset) { if (imageUsage.isCanvasImage && !wasCanvasReset) {
dispatch(resetCanvas()); dispatch(resetCanvas());

View File

@ -1,13 +1,12 @@
import type { UseToastOptions } from '@invoke-ai/ui-library';
import { ExternalLink } from '@invoke-ai/ui-library'; import { ExternalLink } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { toast } from 'common/util/toast'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { import {
socketBulkDownloadCompleted, socketBulkDownloadComplete,
socketBulkDownloadFailed, socketBulkDownloadError,
socketBulkDownloadStarted, socketBulkDownloadStarted,
} from 'services/events/actions'; } from 'services/events/actions';
@ -28,7 +27,6 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
// Show the response message if it exists, otherwise show the default message // Show the response message if it exists, otherwise show the default message
description: action.payload.response || t('gallery.bulkDownloadRequestedDesc'), description: action.payload.response || t('gallery.bulkDownloadRequestedDesc'),
duration: null, duration: null,
isClosable: true,
}); });
}, },
}); });
@ -40,9 +38,9 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
// There isn't any toast to update if we get this event. // There isn't any toast to update if we get this event.
toast({ toast({
id: 'BULK_DOWNLOAD_REQUEST_FAILED',
title: t('gallery.bulkDownloadRequestFailed'), title: t('gallery.bulkDownloadRequestFailed'),
status: 'success', status: 'error',
isClosable: true,
}); });
}, },
}); });
@ -56,7 +54,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
}); });
startAppListening({ startAppListening({
actionCreator: socketBulkDownloadCompleted, actionCreator: socketBulkDownloadComplete,
effect: async (action) => { effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation completed'); log.debug(action.payload.data, 'Bulk download preparation completed');
@ -65,7 +63,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
// TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first // TODO(psyche): This URL may break in in some environments (e.g. Nvidia workbench) but we need to test it first
const url = `/api/v1/images/download/${bulk_download_item_name}`; const url = `/api/v1/images/download/${bulk_download_item_name}`;
const toastOptions: UseToastOptions = { toast({
id: bulk_download_item_name, id: bulk_download_item_name,
title: t('gallery.bulkDownloadReady', 'Download ready'), title: t('gallery.bulkDownloadReady', 'Download ready'),
status: 'success', status: 'success',
@ -77,38 +75,24 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
/> />
), ),
duration: null, duration: null,
isClosable: true, });
};
if (toast.isActive(bulk_download_item_name)) {
toast.update(bulk_download_item_name, toastOptions);
} else {
toast(toastOptions);
}
}, },
}); });
startAppListening({ startAppListening({
actionCreator: socketBulkDownloadFailed, actionCreator: socketBulkDownloadError,
effect: async (action) => { effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation failed'); log.debug(action.payload.data, 'Bulk download preparation failed');
const { bulk_download_item_name } = action.payload.data; const { bulk_download_item_name } = action.payload.data;
const toastOptions: UseToastOptions = { toast({
id: bulk_download_item_name, id: bulk_download_item_name,
title: t('gallery.bulkDownloadFailed'), title: t('gallery.bulkDownloadFailed'),
status: 'error', status: 'error',
description: action.payload.data.error, description: action.payload.data.error,
duration: null, duration: null,
isClosable: true, });
};
if (toast.isActive(bulk_download_item_name)) {
toast.update(bulk_download_item_name, toastOptions);
} else {
toast(toastOptions);
}
}, },
}); });
}; };

View File

@ -2,14 +2,14 @@ import { $logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasCopiedToClipboard } from 'features/canvas/store/actions'; import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard'; import { copyBlobToClipboard } from 'features/system/util/copyBlobToClipboard';
import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartListening) => { export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
actionCreator: canvasCopiedToClipboard, actionCreator: canvasCopiedToClipboard,
effect: async (action, { dispatch, getState }) => { effect: async (action, { getState }) => {
const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' }); const moduleLog = $logger.get().child({ namespace: 'canvasCopiedToClipboardListener' });
const state = getState(); const state = getState();
@ -19,22 +19,20 @@ export const addCanvasCopiedToClipboardListener = (startAppListening: AppStartLi
copyBlobToClipboard(blob); copyBlobToClipboard(blob);
} catch (err) { } catch (err) {
moduleLog.error(String(err)); moduleLog.error(String(err));
dispatch( toast({
addToast({ id: 'CANVAS_COPY_FAILED',
title: t('toast.problemCopyingCanvas'), title: t('toast.problemCopyingCanvas'),
description: t('toast.problemCopyingCanvasDesc'), description: t('toast.problemCopyingCanvasDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
dispatch( toast({
addToast({ id: 'CANVAS_COPY_SUCCEEDED',
title: t('toast.canvasCopiedClipboard'), title: t('toast.canvasCopiedClipboard'),
status: 'success', status: 'success',
}) });
);
}, },
}); });
}; };

View File

@ -3,13 +3,13 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { canvasDownloadedAsImage } from 'features/canvas/store/actions'; import { canvasDownloadedAsImage } from 'features/canvas/store/actions';
import { downloadBlob } from 'features/canvas/util/downloadBlob'; import { downloadBlob } from 'features/canvas/util/downloadBlob';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartListening) => { export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
actionCreator: canvasDownloadedAsImage, actionCreator: canvasDownloadedAsImage,
effect: async (action, { dispatch, getState }) => { effect: async (action, { getState }) => {
const moduleLog = $logger.get().child({ namespace: 'canvasSavedToGalleryListener' }); const moduleLog = $logger.get().child({ namespace: 'canvasSavedToGalleryListener' });
const state = getState(); const state = getState();
@ -18,18 +18,17 @@ export const addCanvasDownloadedAsImageListener = (startAppListening: AppStartLi
blob = await getBaseLayerBlob(state); blob = await getBaseLayerBlob(state);
} catch (err) { } catch (err) {
moduleLog.error(String(err)); moduleLog.error(String(err));
dispatch( toast({
addToast({ id: 'CANVAS_DOWNLOAD_FAILED',
title: t('toast.problemDownloadingCanvas'), title: t('toast.problemDownloadingCanvas'),
description: t('toast.problemDownloadingCanvasDesc'), description: t('toast.problemDownloadingCanvasDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
downloadBlob(blob, 'canvas.png'); downloadBlob(blob, 'canvas.png');
dispatch(addToast({ title: t('toast.canvasDownloaded'), status: 'success' })); toast({ id: 'CANVAS_DOWNLOAD_SUCCEEDED', title: t('toast.canvasDownloaded'), status: 'success' });
}, },
}); });
}; };

View File

@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { canvasImageToControlAdapter } from 'features/canvas/store/actions'; import { canvasImageToControlAdapter } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -20,13 +20,12 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi
blob = await getBaseLayerBlob(state, true); blob = await getBaseLayerBlob(state, true);
} catch (err) { } catch (err) {
log.error(String(err)); log.error(String(err));
dispatch( toast({
addToast({ id: 'PROBLEM_SAVING_CANVAS',
title: t('toast.problemSavingCanvas'), title: t('toast.problemSavingCanvas'),
description: t('toast.problemSavingCanvasDesc'), description: t('toast.problemSavingCanvasDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
@ -43,7 +42,7 @@ export const addCanvasImageToControlNetListener = (startAppListening: AppStartLi
crop_visible: false, crop_visible: false,
postUploadAction: { postUploadAction: {
type: 'TOAST', type: 'TOAST',
toastOptions: { title: t('toast.canvasSentControlnetAssets') }, title: t('toast.canvasSentControlnetAssets'),
}, },
}) })
).unwrap(); ).unwrap();

View File

@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasMaskSavedToGallery } from 'features/canvas/store/actions'; import { canvasMaskSavedToGallery } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData'; import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -29,13 +29,12 @@ export const addCanvasMaskSavedToGalleryListener = (startAppListening: AppStartL
if (!maskBlob) { if (!maskBlob) {
log.error('Problem getting mask layer blob'); log.error('Problem getting mask layer blob');
dispatch( toast({
addToast({ id: 'PROBLEM_SAVING_MASK',
title: t('toast.problemSavingMask'), title: t('toast.problemSavingMask'),
description: t('toast.problemSavingMaskDesc'), description: t('toast.problemSavingMaskDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
@ -52,7 +51,7 @@ export const addCanvasMaskSavedToGalleryListener = (startAppListening: AppStartL
crop_visible: true, crop_visible: true,
postUploadAction: { postUploadAction: {
type: 'TOAST', type: 'TOAST',
toastOptions: { title: t('toast.maskSavedAssets') }, title: t('toast.maskSavedAssets'),
}, },
}) })
); );

View File

@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { canvasMaskToControlAdapter } from 'features/canvas/store/actions'; import { canvasMaskToControlAdapter } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData'; import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -30,13 +30,12 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis
if (!maskBlob) { if (!maskBlob) {
log.error('Problem getting mask layer blob'); log.error('Problem getting mask layer blob');
dispatch( toast({
addToast({ id: 'PROBLEM_IMPORTING_MASK',
title: t('toast.problemImportingMask'), title: t('toast.problemImportingMask'),
description: t('toast.problemImportingMaskDesc'), description: t('toast.problemImportingMaskDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
@ -53,7 +52,7 @@ export const addCanvasMaskToControlNetListener = (startAppListening: AppStartLis
crop_visible: false, crop_visible: false,
postUploadAction: { postUploadAction: {
type: 'TOAST', type: 'TOAST',
toastOptions: { title: t('toast.maskSentControlnetAssets') }, title: t('toast.maskSentControlnetAssets'),
}, },
}) })
).unwrap(); ).unwrap();

View File

@ -4,7 +4,7 @@ import { canvasMerged } from 'features/canvas/store/actions';
import { $canvasBaseLayer } from 'features/canvas/store/canvasNanostore'; import { $canvasBaseLayer } from 'features/canvas/store/canvasNanostore';
import { setMergedCanvas } from 'features/canvas/store/canvasSlice'; import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
import { getFullBaseLayerBlob } from 'features/canvas/util/getFullBaseLayerBlob'; import { getFullBaseLayerBlob } from 'features/canvas/util/getFullBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -17,13 +17,12 @@ export const addCanvasMergedListener = (startAppListening: AppStartListening) =>
if (!blob) { if (!blob) {
moduleLog.error('Problem getting base layer blob'); moduleLog.error('Problem getting base layer blob');
dispatch( toast({
addToast({ id: 'PROBLEM_MERGING_CANVAS',
title: t('toast.problemMergingCanvas'), title: t('toast.problemMergingCanvas'),
description: t('toast.problemMergingCanvasDesc'), description: t('toast.problemMergingCanvasDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
@ -31,13 +30,12 @@ export const addCanvasMergedListener = (startAppListening: AppStartListening) =>
if (!canvasBaseLayer) { if (!canvasBaseLayer) {
moduleLog.error('Problem getting canvas base layer'); moduleLog.error('Problem getting canvas base layer');
dispatch( toast({
addToast({ id: 'PROBLEM_MERGING_CANVAS',
title: t('toast.problemMergingCanvas'), title: t('toast.problemMergingCanvas'),
description: t('toast.problemMergingCanvasDesc'), description: t('toast.problemMergingCanvasDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
@ -54,7 +52,7 @@ export const addCanvasMergedListener = (startAppListening: AppStartListening) =>
is_intermediate: true, is_intermediate: true,
postUploadAction: { postUploadAction: {
type: 'TOAST', type: 'TOAST',
toastOptions: { title: t('toast.canvasMerged') }, title: t('toast.canvasMerged'),
}, },
}) })
).unwrap(); ).unwrap();

View File

@ -1,8 +1,9 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { canvasSavedToGallery } from 'features/canvas/store/actions'; import { canvasSavedToGallery } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
@ -18,13 +19,12 @@ export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListe
blob = await getBaseLayerBlob(state); blob = await getBaseLayerBlob(state);
} catch (err) { } catch (err) {
log.error(String(err)); log.error(String(err));
dispatch( toast({
addToast({ id: 'CANVAS_SAVE_FAILED',
title: t('toast.problemSavingCanvas'), title: t('toast.problemSavingCanvas'),
description: t('toast.problemSavingCanvasDesc'), description: t('toast.problemSavingCanvasDesc'),
status: 'error', status: 'error',
}) });
);
return; return;
} }
@ -41,7 +41,10 @@ export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListe
crop_visible: true, crop_visible: true,
postUploadAction: { postUploadAction: {
type: 'TOAST', type: 'TOAST',
toastOptions: { title: t('toast.canvasSavedGallery') }, title: t('toast.canvasSavedGallery'),
},
metadata: {
_canvas_objects: parseify(state.canvas.layerState.objects),
}, },
}) })
); );

View File

@ -1,60 +1,56 @@
import { isAnyOf } from '@reduxjs/toolkit'; import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch } from 'app/store/store';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { import {
caLayerImageChanged, caLayerImageChanged,
caLayerIsProcessingImageChanged,
caLayerModelChanged, caLayerModelChanged,
caLayerProcessedImageChanged, caLayerProcessedImageChanged,
caLayerProcessorConfigChanged, caLayerProcessorConfigChanged,
caLayerProcessorPendingBatchIdChanged,
caLayerRecalled,
isControlAdapterLayer, isControlAdapterLayer,
} from 'features/controlLayers/store/controlLayersSlice'; } from 'features/controlLayers/store/controlLayersSlice';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters'; import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
import { isImageOutput } from 'features/nodes/types/common'; import { toast } from 'features/toast/toast';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next'; import { t } from 'i18next';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images'; import { getImageDTO } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types'; import type { BatchConfig } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions'; import { socketInvocationComplete } from 'services/events/actions';
import { assert } from 'tsafe';
const matcher = isAnyOf(caLayerImageChanged, caLayerProcessorConfigChanged, caLayerModelChanged); const matcher = isAnyOf(caLayerImageChanged, caLayerProcessorConfigChanged, caLayerModelChanged, caLayerRecalled);
const DEBOUNCE_MS = 300; const DEBOUNCE_MS = 300;
const log = logger('session'); const log = logger('session');
/**
* Simple helper to cancel a batch and reset the pending batch ID
*/
const cancelProcessorBatch = async (dispatch: AppDispatch, layerId: string, batchId: string) => {
const req = dispatch(queueApi.endpoints.cancelByBatchIds.initiate({ batch_ids: [batchId] }));
log.trace({ batchId }, 'Cancelling existing preprocessor batch');
try {
await req.unwrap();
} catch {
// no-op
} finally {
req.reset();
// Always reset the pending batch ID - the cancel req could fail if the batch doesn't exist
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
}
};
export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => { export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => {
startAppListening({ startAppListening({
matcher, matcher,
effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take }) => { effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take, signal }) => {
const { layerId } = action.payload; const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId;
const precheckLayerOriginal = getOriginalState() const state = getState();
.controlLayers.present.layers.filter(isControlAdapterLayer) const originalState = getOriginalState();
.find((l) => l.id === layerId);
const precheckLayer = getState()
.controlLayers.present.layers.filter(isControlAdapterLayer)
.find((l) => l.id === layerId);
// Conditions to bail
const layerDoesNotExist = !precheckLayer;
const layerHasNoImage = !precheckLayer?.controlAdapter.image;
const layerHasNoProcessorConfig = !precheckLayer?.controlAdapter.processorConfig;
const layerIsAlreadyProcessingImage = precheckLayer?.controlAdapter.isProcessingImage;
const areImageAndProcessorUnchanged =
isEqual(precheckLayer?.controlAdapter.image, precheckLayerOriginal?.controlAdapter.image) &&
isEqual(precheckLayer?.controlAdapter.processorConfig, precheckLayerOriginal?.controlAdapter.processorConfig);
if (
layerDoesNotExist ||
layerHasNoImage ||
layerHasNoProcessorConfig ||
areImageAndProcessorUnchanged ||
layerIsAlreadyProcessingImage
) {
return;
}
// Cancel any in-progress instances of this listener // Cancel any in-progress instances of this listener
cancelActiveListeners(); cancelActiveListeners();
@ -62,27 +58,55 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
// Delay before starting actual work // Delay before starting actual work
await delay(DEBOUNCE_MS); await delay(DEBOUNCE_MS);
dispatch(caLayerIsProcessingImageChanged({ layerId, isProcessingImage: true }));
// Double-check that we are still eligible for processing
const state = getState();
const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId); const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId);
const image = layer?.controlAdapter.image;
const config = layer?.controlAdapter.processorConfig;
// If we have no image or there is no processor config, bail if (!layer) {
if (!layer || !image || !config) {
return; return;
} }
// @ts-expect-error: TS isn't able to narrow the typing of buildNode and `config` will error... // We should only process if the processor settings or image have changed
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config); const originalLayer = originalState.controlLayers.present.layers
.filter(isControlAdapterLayer)
.find((l) => l.id === layerId);
const originalImage = originalLayer?.controlAdapter.image;
const originalConfig = originalLayer?.controlAdapter.processorConfig;
const image = layer.controlAdapter.image;
const config = layer.controlAdapter.processorConfig;
if (isEqual(config, originalConfig) && isEqual(image, originalImage)) {
// Neither config nor image have changed, we can bail
return;
}
if (!image || !config) {
// - If we have no image, we have nothing to process
// - If we have no processor config, we have nothing to process
// Clear the processed image and bail
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null }));
return;
}
// At this point, the user has stopped fiddling with the processor settings and there is a processor selected.
// If there is a pending processor batch, cancel it.
if (layer.controlAdapter.processorPendingBatchId) {
cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId);
}
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never);
const enqueueBatchArg: BatchConfig = { const enqueueBatchArg: BatchConfig = {
prepend: true, prepend: true,
batch: { batch: {
graph: { graph: {
nodes: { nodes: {
[processorNode.id]: { ...processorNode, is_intermediate: true }, [processorNode.id]: {
...processorNode,
// Control images are always intermediate - do not save to gallery
is_intermediate: true,
},
}, },
edges: [], edges: [],
}, },
@ -90,50 +114,55 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
}, },
}; };
try { // Kick off the processor batch
const req = dispatch( const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, { queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
fixedCacheKey: 'enqueueBatch', fixedCacheKey: 'enqueueBatch',
}) })
); );
try {
const enqueueResult = await req.unwrap(); const enqueueResult = await req.unwrap();
req.reset(); // TODO(psyche): Update the pydantic models, pretty sure we will _always_ have a batch_id here, but the model says it's optional
assert(enqueueResult.batch.batch_id, 'Batch ID not returned from queue');
dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: enqueueResult.batch.batch_id }));
log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued')); log.debug({ enqueueResult: parseify(enqueueResult) }, t('queue.graphQueued'));
// Wait for the processor node to complete
const [invocationCompleteAction] = await take( const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> => (action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) && socketInvocationComplete.match(action) &&
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id && action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.source_node_id === processorNode.id action.payload.data.invocation_source_id === processorNode.id
); );
// We still have to check the output type // We still have to check the output type
if (isImageOutput(invocationCompleteAction.payload.data.result)) { assert(
invocationCompleteAction.payload.data.result.type === 'image_output',
`Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
);
const { image_name } = invocationCompleteAction.payload.data.result.image; const { image_name } = invocationCompleteAction.payload.data.result.image;
// Wait for the ImageDTO to be received const imageDTO = await getImageDTO(image_name);
const [{ payload }] = await take( assert(imageDTO, "Failed to fetch processor output's image DTO");
(action) =>
imagesApi.endpoints.getImageDTO.matchFulfilled(action) && action.payload.image_name === image_name
);
const imageDTO = payload as ImageDTO;
// Whew! We made it. Update the layer with the processed image
log.debug({ layerId, imageDTO }, 'ControlNet image processed'); log.debug({ layerId, imageDTO }, 'ControlNet image processed');
dispatch(caLayerProcessedImageChanged({ layerId, imageDTO }));
// Update the processed image in the store dispatch(caLayerProcessorPendingBatchIdChanged({ layerId, batchId: null }));
dispatch(
caLayerProcessedImageChanged({
layerId,
imageDTO,
})
);
dispatch(caLayerIsProcessingImageChanged({ layerId, isProcessingImage: false }));
}
} catch (error) { } catch (error) {
console.log(error); if (signal.aborted) {
// The listener was canceled - we need to cancel the pending processor batch, if there is one (could have changed by now).
const pendingBatchId = getState()
.controlLayers.present.layers.filter(isControlAdapterLayer)
.find((l) => l.id === layerId)?.controlAdapter.processorPendingBatchId;
if (pendingBatchId) {
cancelProcessorBatch(dispatch, layerId, pendingBatchId);
}
log.trace('Control Adapter preprocessor cancelled');
} else {
// Some other error condition...
log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue')); log.error({ enqueueBatchArg: parseify(enqueueBatchArg) }, t('queue.graphFailedToQueue'));
dispatch(caLayerIsProcessingImageChanged({ layerId, isProcessingImage: false }));
if (error instanceof Object) { if (error instanceof Object) {
if ('data' in error && 'status' in error) { if ('data' in error && 'status' in error) {
@ -144,12 +173,14 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
} }
} }
dispatch( toast({
addToast({ id: 'GRAPH_QUEUE_FAILED',
title: t('queue.graphFailedToQueue'), title: t('queue.graphFailedToQueue'),
status: 'error', status: 'error',
}) });
); }
} finally {
req.reset();
} }
}, },
}); });

View File

@ -9,8 +9,7 @@ import {
selectControlAdapterById, selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice'; } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { isImageOutput } from 'features/nodes/types/common'; import { toast } from 'features/toast/toast';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next'; import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
@ -69,12 +68,12 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
const [invocationCompleteAction] = await take( const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> => (action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) && socketInvocationComplete.match(action) &&
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id && action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.source_node_id === nodeId action.payload.data.invocation_source_id === nodeId
); );
// We still have to check the output type // We still have to check the output type
if (isImageOutput(invocationCompleteAction.payload.data.result)) { if (invocationCompleteAction.payload.data.result.type === 'image_output') {
const { image_name } = invocationCompleteAction.payload.data.result.image; const { image_name } = invocationCompleteAction.payload.data.result.image;
// Wait for the ImageDTO to be received // Wait for the ImageDTO to be received
@ -108,12 +107,11 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
} }
} }
dispatch( toast({
addToast({ id: 'GRAPH_QUEUE_FAILED',
title: t('queue.graphFailedToQueue'), title: t('queue.graphFailedToQueue'),
status: 'error', status: 'error',
}) });
);
} }
}, },
}); });

Some files were not shown because too many files have changed in this diff Show More